diff --git a/pgtype/bool.go b/pgtype/bool.go index e7be27e2..71caffa7 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -1,10 +1,12 @@ package pgtype import ( + "bytes" "database/sql/driver" "encoding/json" "fmt" "strconv" + "strings" ) type BoolScanner interface { @@ -264,8 +266,8 @@ func (scanPlanTextAnyToBool) Scan(src []byte, dst any) error { return fmt.Errorf("cannot scan NULL into %T", dst) } - if len(src) != 1 { - return fmt.Errorf("invalid length for bool: %v", len(src)) + if len(src) == 0 { + return fmt.Errorf("cannot scan empty string into %T", dst) } p, ok := (dst).(*bool) @@ -273,7 +275,12 @@ func (scanPlanTextAnyToBool) Scan(src []byte, dst any) error { return ErrScanTargetTypeChanged } - *p = src[0] == 't' + v, err := planTextToBool(src) + if err != nil { + return err + } + + *p = v return nil } @@ -309,9 +316,28 @@ func (scanPlanTextAnyToBoolScanner) Scan(src []byte, dst any) error { return s.ScanBool(Bool{}) } - if len(src) != 1 { - return fmt.Errorf("invalid length for bool: %v", len(src)) + if len(src) == 0 { + return fmt.Errorf("cannot scan empty string into %T", dst) } - return s.ScanBool(Bool{Bool: src[0] == 't', Valid: true}) + v, err := planTextToBool(src) + if err != nil { + return err + } + + return s.ScanBool(Bool{Bool: v, Valid: true}) +} + +// https://www.postgresql.org/docs/11/datatype-boolean.html +func planTextToBool(src []byte) (bool, error) { + s := string(bytes.ToLower(bytes.TrimSpace(src))) + + switch { + case strings.HasPrefix("true", s), strings.HasPrefix("yes", s), s == "on", s == "1": + return true, nil + case strings.HasPrefix("false", s), strings.HasPrefix("no", s), strings.HasPrefix("off", s), s == "0": + return false, nil + default: + return false, fmt.Errorf("unknown boolean string representation %q", src) + } } diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 341ce5d8..5affc0c7 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -453,6 +453,60 @@ func TestMapScanNullToWrongType(t *testing.T) { assert.False(t, pn.Valid) } +func TestMapScanTextToBool(t *testing.T) { + tests := []struct { + name string + src []byte + want bool + }{ + {"t", []byte("t"), true}, + {"f", []byte("f"), false}, + {"y", []byte("y"), true}, + {"n", []byte("n"), false}, + {"1", []byte("1"), true}, + {"0", []byte("0"), false}, + {"true", []byte("true"), true}, + {"false", []byte("false"), false}, + {"yes", []byte("yes"), true}, + {"no", []byte("no"), false}, + {"on", []byte("on"), true}, + {"off", []byte("off"), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := pgtype.NewMap() + + var v bool + err := m.Scan(pgtype.BoolOID, pgx.TextFormatCode, tt.src, &v) + require.NoError(t, err) + assert.Equal(t, tt.want, v) + }) + } +} + +func TestMapScanTextToBoolError(t *testing.T) { + tests := []struct { + name string + src []byte + want string + }{ + {"nil", nil, "cannot scan NULL into *bool"}, + {"empty", []byte{}, "cannot scan empty string into *bool"}, + {"foo", []byte("foo"), "unknown boolean string representation \"foo\""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := pgtype.NewMap() + + var v bool + err := m.Scan(pgtype.BoolOID, pgx.TextFormatCode, tt.src, &v) + require.ErrorContains(t, err, tt.want) + }) + } +} + type databaseValuerUUID [16]byte func (v databaseValuerUUID) Value() (driver.Value, error) {