mirror of https://github.com/jackc/pgx.git
support different bool string representations
parent
6defa2a607
commit
c27b9b49ea
|
@ -1,10 +1,12 @@
|
||||||
package pgtype
|
package pgtype
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type BoolScanner interface {
|
type BoolScanner interface {
|
||||||
|
@ -264,8 +266,8 @@ func (scanPlanTextAnyToBool) Scan(src []byte, dst any) error {
|
||||||
return fmt.Errorf("cannot scan NULL into %T", dst)
|
return fmt.Errorf("cannot scan NULL into %T", dst)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(src) != 1 {
|
if len(src) == 0 {
|
||||||
return fmt.Errorf("invalid length for bool: %v", len(src))
|
return fmt.Errorf("cannot scan empty string into %T", dst)
|
||||||
}
|
}
|
||||||
|
|
||||||
p, ok := (dst).(*bool)
|
p, ok := (dst).(*bool)
|
||||||
|
@ -273,7 +275,12 @@ func (scanPlanTextAnyToBool) Scan(src []byte, dst any) error {
|
||||||
return ErrScanTargetTypeChanged
|
return ErrScanTargetTypeChanged
|
||||||
}
|
}
|
||||||
|
|
||||||
*p = src[0] == 't'
|
v, err := planTextToBool(src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*p = v
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -309,9 +316,28 @@ func (scanPlanTextAnyToBoolScanner) Scan(src []byte, dst any) error {
|
||||||
return s.ScanBool(Bool{})
|
return s.ScanBool(Bool{})
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(src) != 1 {
|
if len(src) == 0 {
|
||||||
return fmt.Errorf("invalid length for bool: %v", len(src))
|
return fmt.Errorf("cannot scan empty string into %T", dst)
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.ScanBool(Bool{Bool: src[0] == 't', Valid: true})
|
v, err := planTextToBool(src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.ScanBool(Bool{Bool: v, Valid: true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://www.postgresql.org/docs/11/datatype-boolean.html
|
||||||
|
func planTextToBool(src []byte) (bool, error) {
|
||||||
|
s := string(bytes.ToLower(bytes.TrimSpace(src)))
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix("true", s), strings.HasPrefix("yes", s), s == "on", s == "1":
|
||||||
|
return true, nil
|
||||||
|
case strings.HasPrefix("false", s), strings.HasPrefix("no", s), strings.HasPrefix("off", s), s == "0":
|
||||||
|
return false, nil
|
||||||
|
default:
|
||||||
|
return false, fmt.Errorf("unknown boolean string representation %q", src)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -453,6 +453,60 @@ func TestMapScanNullToWrongType(t *testing.T) {
|
||||||
assert.False(t, pn.Valid)
|
assert.False(t, pn.Valid)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMapScanTextToBool(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
src []byte
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"t", []byte("t"), true},
|
||||||
|
{"f", []byte("f"), false},
|
||||||
|
{"y", []byte("y"), true},
|
||||||
|
{"n", []byte("n"), false},
|
||||||
|
{"1", []byte("1"), true},
|
||||||
|
{"0", []byte("0"), false},
|
||||||
|
{"true", []byte("true"), true},
|
||||||
|
{"false", []byte("false"), false},
|
||||||
|
{"yes", []byte("yes"), true},
|
||||||
|
{"no", []byte("no"), false},
|
||||||
|
{"on", []byte("on"), true},
|
||||||
|
{"off", []byte("off"), false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
m := pgtype.NewMap()
|
||||||
|
|
||||||
|
var v bool
|
||||||
|
err := m.Scan(pgtype.BoolOID, pgx.TextFormatCode, tt.src, &v)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.want, v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapScanTextToBoolError(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
src []byte
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"nil", nil, "cannot scan NULL into *bool"},
|
||||||
|
{"empty", []byte{}, "cannot scan empty string into *bool"},
|
||||||
|
{"foo", []byte("foo"), "unknown boolean string representation \"foo\""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
m := pgtype.NewMap()
|
||||||
|
|
||||||
|
var v bool
|
||||||
|
err := m.Scan(pgtype.BoolOID, pgx.TextFormatCode, tt.src, &v)
|
||||||
|
require.ErrorContains(t, err, tt.want)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type databaseValuerUUID [16]byte
|
type databaseValuerUUID [16]byte
|
||||||
|
|
||||||
func (v databaseValuerUUID) Value() (driver.Value, error) {
|
func (v databaseValuerUUID) Value() (driver.Value, error) {
|
||||||
|
|
Loading…
Reference in New Issue