diff --git a/pgtype/bool.go b/pgtype/bool.go index 4fcc67e3..4b6fbaf2 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -7,136 +7,27 @@ import ( "strconv" ) +type BoolScanner interface { + ScanBool(v bool, valid bool) error +} + type Bool struct { Bool bool Valid bool } -func (dst *Bool) Set(src interface{}) error { - if src == nil { +// ScanBool implements the BoolScanner interface. +func (dst *Bool) ScanBool(v bool, valid bool) error { + if !valid { *dst = Bool{} return nil } - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - switch value := src.(type) { - case bool: - *dst = Bool{Bool: value, Valid: true} - case string: - bb, err := strconv.ParseBool(value) - if err != nil { - return err - } - *dst = Bool{Bool: bb, Valid: true} - case *bool: - if value == nil { - *dst = Bool{} - } else { - return dst.Set(*value) - } - case *string: - if value == nil { - *dst = Bool{} - } else { - return dst.Set(*value) - } - default: - if originalSrc, ok := underlyingBoolType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to Bool", value) - } + *dst = Bool{Bool: v, Valid: true} return nil } -func (dst Bool) Get() interface{} { - if !dst.Valid { - return nil - } - - return dst.Bool -} - -func (src *Bool) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - switch v := dst.(type) { - case *bool: - *v = src.Bool - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } -} - -func (dst *Bool) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Bool{} - return nil - } - - if len(src) != 1 { - return fmt.Errorf("invalid length for bool: %v", len(src)) - } - - *dst = Bool{Bool: src[0] == 't', Valid: true} - return nil -} - -func (dst *Bool) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Bool{} - return nil - } - - if len(src) != 1 { - return fmt.Errorf("invalid length for bool: %v", len(src)) - } - - *dst = Bool{Bool: src[0] == 1, Valid: true} - return nil -} - -func (src Bool) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if src.Bool { - buf = append(buf, 't') - } else { - buf = append(buf, 'f') - } - - return buf, nil -} - -func (src Bool) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if src.Bool { - buf = append(buf, 1) - } else { - buf = append(buf, 0) - } - - return buf, nil -} - // Scan implements the database/sql Scanner interface. func (dst *Bool) Scan(src interface{}) error { if src == nil { @@ -149,11 +40,19 @@ func (dst *Bool) Scan(src interface{}) error { *dst = Bool{Bool: src, Valid: true} return nil case string: - return dst.DecodeText(nil, []byte(src)) + b, err := strconv.ParseBool(src) + if err != nil { + return err + } + *dst = Bool{Bool: b, Valid: true} + return nil case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) + b, err := strconv.ParseBool(string(src)) + if err != nil { + return err + } + *dst = Bool{Bool: b, Valid: true} + return nil } return fmt.Errorf("cannot scan %T", src) @@ -195,3 +94,204 @@ func (dst *Bool) UnmarshalJSON(b []byte) error { return nil } + +type BoolCodec struct{} + +func (BoolCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (BoolCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (BoolCodec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { + v, valid, err := convertToBoolForEncode(value) + if err != nil { + return nil, fmt.Errorf("cannot convert %v to bool: %v", value, err) + } + if !valid { + return nil, nil + } + if value == nil { + return nil, nil + } + + switch format { + case BinaryFormatCode: + if v { + buf = append(buf, 1) + } else { + buf = append(buf, 0) + } + return buf, nil + case TextFormatCode: + if v { + buf = append(buf, 't') + } else { + buf = append(buf, 'f') + } + return buf, nil + default: + return nil, fmt.Errorf("unknown format code: %v", format) + } +} + +func (BoolCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case *bool: + return scanPlanBinaryBoolToBool{} + case BoolScanner: + return scanPlanBinaryBoolToBoolScanner{} + } + case TextFormatCode: + switch target.(type) { + case *bool: + return scanPlanTextAnyToBool{} + case BoolScanner: + return scanPlanTextAnyToBoolScanner{} + } + } + + return nil +} + +func (c BoolCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + return c.DecodeValue(ci, oid, format, src) +} + +func (c BoolCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + var b bool + scanPlan := c.PlanScan(ci, oid, format, &b, true) + if scanPlan == nil { + return nil, fmt.Errorf("PlanScan did not find a plan") + } + err := scanPlan.Scan(ci, oid, format, src, &b) + if err != nil { + return nil, err + } + return b, nil +} + +func convertToBoolForEncode(v interface{}) (b bool, valid bool, err error) { + if v == nil { + return false, false, nil + } + + switch v := v.(type) { + case bool: + return v, true, nil + case *bool: + if v == nil { + return false, false, nil + } + return *v, true, nil + case string: + bb, err := strconv.ParseBool(v) + if err != nil { + return false, false, err + } + return bb, true, nil + case *string: + if v == nil { + return false, false, nil + } + bb, err := strconv.ParseBool(*v) + if err != nil { + return false, false, err + } + return bb, true, nil + default: + if originalvalue, ok := underlyingBoolType(v); ok { + return convertToBoolForEncode(originalvalue) + } + return false, false, fmt.Errorf("cannot convert %v to bool", v) + } +} + +type scanPlanBinaryBoolToBool struct{} + +func (scanPlanBinaryBoolToBool) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 1 { + return fmt.Errorf("invalid length for bool: %v", len(src)) + } + + p, ok := (dst).(*bool) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = src[0] == 1 + + return nil +} + +type scanPlanTextAnyToBool struct{} + +func (scanPlanTextAnyToBool) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 1 { + return fmt.Errorf("invalid length for bool: %v", len(src)) + } + + p, ok := (dst).(*bool) + if !ok { + return ErrScanTargetTypeChanged + } + + *p = src[0] == 't' + + return nil +} + +type scanPlanBinaryBoolToBoolScanner struct{} + +func (scanPlanBinaryBoolToBoolScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + s, ok := (dst).(BoolScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanBool(false, false) + } + + if len(src) != 1 { + return fmt.Errorf("invalid length for bool: %v", len(src)) + } + + return s.ScanBool(src[0] == 1, true) +} + +type scanPlanTextAnyToBoolScanner struct{} + +func (scanPlanTextAnyToBoolScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + s, ok := (dst).(BoolScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanBool(false, false) + } + + if len(src) != 1 { + return fmt.Errorf("invalid length for bool: %v", len(src)) + } + + return s.ScanBool(src[0] == 't', true) +} diff --git a/pgtype/bool_array.go b/pgtype/bool_array.go deleted file mode 100644 index a282fd6b..00000000 --- a/pgtype/bool_array.go +++ /dev/null @@ -1,504 +0,0 @@ -// Code generated by erb. DO NOT EDIT. - -package pgtype - -import ( - "database/sql/driver" - "encoding/binary" - "fmt" - "reflect" - - "github.com/jackc/pgio" -) - -type BoolArray struct { - Elements []Bool - Dimensions []ArrayDimension - Valid bool -} - -func (dst *BoolArray) Set(src interface{}) error { - // untyped nil and typed nil interfaces are different - if src == nil { - *dst = BoolArray{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - // Attempt to match to select common types: - switch value := src.(type) { - - case []bool: - if value == nil { - *dst = BoolArray{} - } else if len(value) == 0 { - *dst = BoolArray{Valid: true} - } else { - elements := make([]Bool, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = BoolArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []*bool: - if value == nil { - *dst = BoolArray{} - } else if len(value) == 0 { - *dst = BoolArray{Valid: true} - } else { - elements := make([]Bool, len(value)) - for i := range value { - if err := elements[i].Set(value[i]); err != nil { - return err - } - } - *dst = BoolArray{ - Elements: elements, - Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, - Valid: true, - } - } - - case []Bool: - if value == nil { - *dst = BoolArray{} - } else if len(value) == 0 { - *dst = BoolArray{Valid: true} - } else { - *dst = BoolArray{ - Elements: value, - Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, - Valid: true, - } - } - default: - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - reflectedValue := reflect.ValueOf(src) - if !reflectedValue.IsValid() || reflectedValue.IsZero() { - *dst = BoolArray{} - return nil - } - - dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) - if !ok { - return fmt.Errorf("cannot find dimensions of %v for BoolArray", src) - } - if elementsLength == 0 { - *dst = BoolArray{Valid: true} - return nil - } - if len(dimensions) == 0 { - if originalSrc, ok := underlyingSliceType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to BoolArray", src) - } - - *dst = BoolArray{ - Elements: make([]Bool, elementsLength), - Dimensions: dimensions, - Valid: true, - } - elementCount, err := dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - // Maybe the target was one dimension too far, try again: - if len(dst.Dimensions) > 1 { - dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] - elementsLength = 0 - for _, dim := range dst.Dimensions { - if elementsLength == 0 { - elementsLength = int(dim.Length) - } else { - elementsLength *= int(dim.Length) - } - } - dst.Elements = make([]Bool, elementsLength) - elementCount, err = dst.setRecursive(reflectedValue, 0, 0) - if err != nil { - return err - } - } else { - return err - } - } - if elementCount != len(dst.Elements) { - return fmt.Errorf("cannot convert %v to BoolArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) - } - } - - return nil -} - -func (dst *BoolArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(dst.Dimensions) == dimension { - break - } - - valueLen := value.Len() - if int32(valueLen) != dst.Dimensions[dimension].Length { - return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") - } - for i := 0; i < valueLen; i++ { - var err error - index, err = dst.setRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if !value.CanInterface() { - return 0, fmt.Errorf("cannot convert all values to BoolArray") - } - if err := dst.Elements[index].Set(value.Interface()); err != nil { - return 0, fmt.Errorf("%v in BoolArray", err) - } - index++ - - return index, nil -} - -func (dst BoolArray) Get() interface{} { - if !dst.Valid { - return nil - } - return dst -} - -func (src *BoolArray) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) - } - - if len(src.Dimensions) <= 1 { - // Attempt to match to select common types: - switch v := dst.(type) { - - case *[]bool: - *v = make([]bool, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - case *[]*bool: - *v = make([]*bool, len(src.Elements)) - for i := range src.Elements { - if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { - return err - } - } - return nil - - } - } - - // Try to convert to something AssignTo can use directly. - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - - // Fallback to reflection if an optimised match was not found. - // The reflection is necessary for arrays and multidimensional slices, - // but it comes with a 20-50% performance penalty for large arrays/slices - value := reflect.ValueOf(dst) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - switch value.Kind() { - case reflect.Array, reflect.Slice: - default: - return fmt.Errorf("cannot assign %T to %T", src, dst) - } - - if len(src.Elements) == 0 { - if value.Kind() == reflect.Slice { - value.Set(reflect.MakeSlice(value.Type(), 0, 0)) - return nil - } - } - - elementCount, err := src.assignToRecursive(value, 0, 0) - if err != nil { - return err - } - if elementCount != len(src.Elements) { - return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) - } - - return nil -} - -func (src *BoolArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { - switch kind := value.Kind(); kind { - case reflect.Array: - fallthrough - case reflect.Slice: - if len(src.Dimensions) == dimension { - break - } - - length := int(src.Dimensions[dimension].Length) - if reflect.Array == kind { - typ := value.Type() - if typ.Len() != length { - return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) - } - value.Set(reflect.New(typ).Elem()) - } else { - value.Set(reflect.MakeSlice(value.Type(), length, length)) - } - - var err error - for i := 0; i < length; i++ { - index, err = src.assignToRecursive(value.Index(i), index, dimension+1) - if err != nil { - return 0, err - } - } - - return index, nil - } - if len(src.Dimensions) != dimension { - return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) - } - if !value.CanAddr() { - return 0, fmt.Errorf("cannot assign all values from BoolArray") - } - addr := value.Addr() - if !addr.CanInterface() { - return 0, fmt.Errorf("cannot assign all values from BoolArray") - } - if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { - return 0, err - } - index++ - return index, nil -} - -func (dst *BoolArray) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = BoolArray{} - return nil - } - - uta, err := ParseUntypedTextArray(string(src)) - if err != nil { - return err - } - - var elements []Bool - - if len(uta.Elements) > 0 { - elements = make([]Bool, len(uta.Elements)) - - for i, s := range uta.Elements { - var elem Bool - var elemSrc []byte - if s != "NULL" || uta.Quoted[i] { - elemSrc = []byte(s) - } - err = elem.DecodeText(ci, elemSrc) - if err != nil { - return err - } - - elements[i] = elem - } - } - - *dst = BoolArray{Elements: elements, Dimensions: uta.Dimensions, Valid: true} - - return nil -} - -func (dst *BoolArray) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = BoolArray{} - return nil - } - - var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(ci, src) - if err != nil { - return err - } - - if len(arrayHeader.Dimensions) == 0 { - *dst = BoolArray{Dimensions: arrayHeader.Dimensions, Valid: true} - return nil - } - - elementCount := arrayHeader.Dimensions[0].Length - for _, d := range arrayHeader.Dimensions[1:] { - elementCount *= d.Length - } - - elements := make([]Bool, elementCount) - - for i := range elements { - elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - var elemSrc []byte - if elemLen >= 0 { - elemSrc = src[rp : rp+elemLen] - rp += elemLen - } - err = elements[i].DecodeBinary(ci, elemSrc) - if err != nil { - return err - } - } - - *dst = BoolArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Valid: true} - return nil -} - -func (src BoolArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - if len(src.Dimensions) == 0 { - return append(buf, '{', '}'), nil - } - - buf = EncodeTextArrayDimensions(buf, src.Dimensions) - - // dimElemCounts is the multiples of elements that each array lies on. For - // example, a single dimension array of length 4 would have a dimElemCounts of - // [4]. A multi-dimensional array of lengths [3,5,2] would have a - // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' - // or '}'. - dimElemCounts := make([]int, len(src.Dimensions)) - dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) - for i := len(src.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] - } - - inElemBuf := make([]byte, 0, 32) - for i, elem := range src.Elements { - if i > 0 { - buf = append(buf, ',') - } - - for _, dec := range dimElemCounts { - if i%dec == 0 { - buf = append(buf, '{') - } - } - - elemBuf, err := elem.EncodeText(ci, inElemBuf) - if err != nil { - return nil, err - } - if elemBuf == nil { - buf = append(buf, `NULL`...) - } else { - buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) - } - - for _, dec := range dimElemCounts { - if (i+1)%dec == 0 { - buf = append(buf, '}') - } - } - } - - return buf, nil -} - -func (src BoolArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("bool"); ok { - arrayHeader.ElementOID = int32(dt.OID) - } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "bool") - } - - for i := range src.Elements { - if !src.Elements[i].Valid { - arrayHeader.ContainsNull = true - break - } - } - - buf = arrayHeader.EncodeBinary(ci, buf) - - for i := range src.Elements { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) - if err != nil { - return nil, err - } - if elemBuf != nil { - buf = elemBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - } - - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *BoolArray) Scan(src interface{}) error { - if src == nil { - return dst.DecodeText(nil, nil) - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src BoolArray) Value() (driver.Value, error) { - buf, err := src.EncodeText(nil, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - - return string(buf), nil -} diff --git a/pgtype/bool_array_test.go b/pgtype/bool_array_test.go deleted file mode 100644 index 7de5612a..00000000 --- a/pgtype/bool_array_test.go +++ /dev/null @@ -1,283 +0,0 @@ -package pgtype_test - -import ( - "reflect" - "testing" - - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" -) - -func TestBoolArrayTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "bool[]", []interface{}{ - &pgtype.BoolArray{ - Elements: nil, - Dimensions: nil, - Valid: true, - }, - &pgtype.BoolArray{ - Elements: []pgtype.Bool{ - {Bool: true, Valid: true}, - {}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.BoolArray{}, - &pgtype.BoolArray{ - Elements: []pgtype.Bool{ - {Bool: true, Valid: true}, - {Bool: true, Valid: true}, - {Bool: false, Valid: true}, - {Bool: true, Valid: true}, - {}, - {Bool: false, Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, - Valid: true, - }, - &pgtype.BoolArray{ - Elements: []pgtype.Bool{ - {Bool: true, Valid: true}, - {Bool: false, Valid: true}, - {Bool: true, Valid: true}, - {Bool: false, Valid: true}, - }, - Dimensions: []pgtype.ArrayDimension{ - {Length: 2, LowerBound: 4}, - {Length: 2, LowerBound: 2}, - }, - Valid: true, - }, - }) -} - -func TestBoolArraySet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.BoolArray - }{ - { - source: []bool{true}, - result: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: (([]bool)(nil)), - result: pgtype.BoolArray{}, - }, - { - source: [][]bool{{true}, {false}}, - result: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Valid: true}, {Bool: false, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [][][][]bool{{{{true, false, true}}}, {{{false, true, false}}}}, - result: pgtype.BoolArray{ - Elements: []pgtype.Bool{ - {Bool: true, Valid: true}, - {Bool: false, Valid: true}, - {Bool: true, Valid: true}, - {Bool: false, Valid: true}, - {Bool: true, Valid: true}, - {Bool: false, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - { - source: [2][1]bool{{true}, {false}}, - result: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Valid: true}, {Bool: false, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - }, - { - source: [2][1][1][3]bool{{{{true, false, true}}}, {{{false, true, false}}}}, - result: pgtype.BoolArray{ - Elements: []pgtype.Bool{ - {Bool: true, Valid: true}, - {Bool: false, Valid: true}, - {Bool: true, Valid: true}, - {Bool: false, Valid: true}, - {Bool: true, Valid: true}, - {Bool: false, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - }, - } - - for i, tt := range successfulTests { - var r pgtype.BoolArray - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if !reflect.DeepEqual(r, tt.result) { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestBoolArrayAssignTo(t *testing.T) { - var boolSlice []bool - type _boolSlice []bool - var namedBoolSlice _boolSlice - var boolSliceDim2 [][]bool - var boolSliceDim4 [][][][]bool - var boolArrayDim2 [2][1]bool - var boolArrayDim4 [2][1][1][3]bool - - simpleTests := []struct { - src pgtype.BoolArray - dst interface{} - expected interface{} - }{ - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &boolSlice, - expected: []bool{true}, - }, - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &namedBoolSlice, - expected: _boolSlice{true}, - }, - { - src: pgtype.BoolArray{}, - dst: &boolSlice, - expected: (([]bool)(nil)), - }, - { - src: pgtype.BoolArray{Valid: true}, - dst: &boolSlice, - expected: []bool{}, - }, - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Valid: true}, {Bool: false, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - expected: [][]bool{{true}, {false}}, - dst: &boolSliceDim2, - }, - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{ - {Bool: true, Valid: true}, - {Bool: false, Valid: true}, - {Bool: true, Valid: true}, - {Bool: false, Valid: true}, - {Bool: true, Valid: true}, - {Bool: false, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - expected: [][][][]bool{{{{true, false, true}}}, {{{false, true, false}}}}, - dst: &boolSliceDim4, - }, - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Valid: true}, {Bool: false, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - expected: [2][1]bool{{true}, {false}}, - dst: &boolArrayDim2, - }, - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{ - {Bool: true, Valid: true}, - {Bool: false, Valid: true}, - {Bool: true, Valid: true}, - {Bool: false, Valid: true}, - {Bool: true, Valid: true}, - {Bool: false, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{ - {LowerBound: 1, Length: 2}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 1}, - {LowerBound: 1, Length: 3}}, - Valid: true}, - expected: [2][1][1][3]bool{{{{true, false, true}}}, {{{false, true, false}}}}, - dst: &boolArrayDim4, - }, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.BoolArray - dst interface{} - }{ - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, - Valid: true, - }, - dst: &boolSlice, - }, - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Valid: true}, {Bool: false, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &boolArrayDim2, - }, - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Valid: true}, {Bool: false, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}, {LowerBound: 1, Length: 2}}, - Valid: true}, - dst: &boolSlice, - }, - { - src: pgtype.BoolArray{ - Elements: []pgtype.Bool{{Bool: true, Valid: true}, {Bool: false, Valid: true}}, - Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 2}, {LowerBound: 1, Length: 1}}, - Valid: true}, - dst: &boolArrayDim4, - }, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } - -} diff --git a/pgtype/bool_test.go b/pgtype/bool_test.go index 9a07491f..ec8c31d9 100644 --- a/pgtype/bool_test.go +++ b/pgtype/bool_test.go @@ -1,101 +1,21 @@ package pgtype_test import ( - "reflect" "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" ) -func TestBoolTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "bool", []interface{}{ - &pgtype.Bool{Bool: false, Valid: true}, - &pgtype.Bool{Bool: true, Valid: true}, - &pgtype.Bool{Bool: false}, +func TestBoolCodec(t *testing.T) { + testPgxCodec(t, "bool", []PgxTranscodeTestCase{ + {true, new(bool), isExpectedEq(true)}, + {false, new(bool), isExpectedEq(false)}, + {true, new(pgtype.Bool), isExpectedEq(pgtype.Bool{Bool: true, Valid: true})}, + {pgtype.Bool{}, new(pgtype.Bool), isExpectedEq(pgtype.Bool{})}, + {nil, new(*bool), isExpectedEq((*bool)(nil))}, }) } -func TestBoolSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result pgtype.Bool - }{ - {source: true, result: pgtype.Bool{Bool: true, Valid: true}}, - {source: false, result: pgtype.Bool{Bool: false, Valid: true}}, - {source: "true", result: pgtype.Bool{Bool: true, Valid: true}}, - {source: "false", result: pgtype.Bool{Bool: false, Valid: true}}, - {source: "t", result: pgtype.Bool{Bool: true, Valid: true}}, - {source: "f", result: pgtype.Bool{Bool: false, Valid: true}}, - {source: _bool(true), result: pgtype.Bool{Bool: true, Valid: true}}, - {source: _bool(false), result: pgtype.Bool{Bool: false, Valid: true}}, - {source: nil, result: pgtype.Bool{}}, - } - - for i, tt := range successfulTests { - var r pgtype.Bool - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestBoolAssignTo(t *testing.T) { - var b bool - var _b _bool - var pb *bool - var _pb *_bool - - simpleTests := []struct { - src pgtype.Bool - dst interface{} - expected interface{} - }{ - {src: pgtype.Bool{Bool: false, Valid: true}, dst: &b, expected: false}, - {src: pgtype.Bool{Bool: true, Valid: true}, dst: &b, expected: true}, - {src: pgtype.Bool{Bool: false, Valid: true}, dst: &_b, expected: _bool(false)}, - {src: pgtype.Bool{Bool: true, Valid: true}, dst: &_b, expected: _bool(true)}, - {src: pgtype.Bool{Bool: false}, dst: &pb, expected: ((*bool)(nil))}, - {src: pgtype.Bool{Bool: false}, dst: &_pb, expected: ((*_bool)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Bool - dst interface{} - expected interface{} - }{ - {src: pgtype.Bool{Bool: true, Valid: true}, dst: &pb, expected: true}, - {src: pgtype.Bool{Bool: true, Valid: true}, dst: &_pb, expected: _bool(true)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } -} - func TestBoolMarshalJSON(t *testing.T) { successfulTests := []struct { source pgtype.Bool diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index fe3fae44..372f755b 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -292,7 +292,7 @@ func NewConnInfo() *ConnInfo { ci := newConnInfo() ci.RegisterDataType(DataType{Value: &ACLItemArray{}, Name: "_aclitem", OID: ACLItemArrayOID}) - ci.RegisterDataType(DataType{Value: &BoolArray{}, Name: "_bool", OID: BoolArrayOID}) + ci.RegisterDataType(DataType{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementCodec: BoolCodec{}, ElementOID: BoolOID}}) ci.RegisterDataType(DataType{Value: &BPCharArray{}, Name: "_bpchar", OID: BPCharArrayOID}) ci.RegisterDataType(DataType{Value: &ByteaArray{}, Name: "_bytea", OID: ByteaArrayOID}) ci.RegisterDataType(DataType{Value: &CIDRArray{}, Name: "_cidr", OID: CIDRArrayOID}) @@ -311,7 +311,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &VarcharArray{}, Name: "_varchar", OID: VarcharArrayOID}) ci.RegisterDataType(DataType{Value: &ACLItem{}, Name: "aclitem", OID: ACLItemOID}) ci.RegisterDataType(DataType{Value: &Bit{}, Name: "bit", OID: BitOID}) - ci.RegisterDataType(DataType{Value: &Bool{}, Name: "bool", OID: BoolOID}) + ci.RegisterDataType(DataType{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) ci.RegisterDataType(DataType{Value: &Box{}, Name: "box", OID: BoxOID}) ci.RegisterDataType(DataType{Value: &BPChar{}, Name: "bpchar", OID: BPCharOID}) ci.RegisterDataType(DataType{Value: &Bytea{}, Name: "bytea", OID: ByteaOID}) diff --git a/pgtype/zzz.bool.go b/pgtype/zzz.bool.go deleted file mode 100644 index e6ed52de..00000000 --- a/pgtype/zzz.bool.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Bool) BinaryFormatSupported() bool { - return true -} - -func (Bool) TextFormatSupported() bool { - return true -} - -func (Bool) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Bool) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Bool) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -}