diff --git a/pgtype/numeric_array.go b/pgtype/numeric_array.go new file mode 100644 index 00000000..b147e6a2 --- /dev/null +++ b/pgtype/numeric_array.go @@ -0,0 +1,357 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type NumericArray struct { + Elements []Numeric + Dimensions []ArrayDimension + Status Status +} + +func (dst *NumericArray) Set(src interface{}) error { + switch value := src.(type) { + + case []float32: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []float64: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Numeric", value) + } + + return nil +} + +func (dst *NumericArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *NumericArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + + case *[]float32: + *v = make([]float32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]float64: + *v = make([]float64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return nullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *NumericArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = NumericArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Numeric + + if len(uta.Elements) > 0 { + elements = make([]Numeric, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Numeric + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = NumericArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *NumericArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = NumericArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = NumericArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Numeric, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = NumericArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *NumericArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if len(src.Dimensions) == 0 { + _, err := io.WriteString(w, "{}") + return false, err + } + + err := EncodeTextArrayDimensions(w, src.Dimensions) + if err != nil { + return false, err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(w, ',') + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(w, '{') + if err != nil { + return false, err + } + } + } + + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else { + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(w, '}') + if err != nil { + return false, err + } + } + } + } + + return false, nil +} + +func (src *NumericArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("numeric"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "numeric") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(ci, w) + if err != nil { + return false, err + } + + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } + } + } + + return false, err +} + +// Scan implements the database/sql Scanner interface. +func (dst *NumericArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *NumericArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/numeric_array_test.go b/pgtype/numeric_array_test.go new file mode 100644 index 00000000..af2e8e51 --- /dev/null +++ b/pgtype/numeric_array_test.go @@ -0,0 +1,159 @@ +package pgtype_test + +import ( + "math/big" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestNumericArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "numeric[]", []interface{}{ + &pgtype.NumericArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}, + pgtype.Numeric{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.NumericArray{Status: pgtype.Null}, + &pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}, + pgtype.Numeric{Int: big.NewInt(2), Status: pgtype.Present}, + pgtype.Numeric{Int: big.NewInt(3), Status: pgtype.Present}, + pgtype.Numeric{Int: big.NewInt(4), Status: pgtype.Present}, + pgtype.Numeric{Status: pgtype.Null}, + pgtype.Numeric{Int: big.NewInt(6), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}, + pgtype.Numeric{Int: big.NewInt(2), Status: pgtype.Present}, + pgtype.Numeric{Int: big.NewInt(3), Status: pgtype.Present}, + pgtype.Numeric{Int: big.NewInt(4), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestNumericArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.NumericArray + }{ + { + source: []float32{1}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []float64{1}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]float32)(nil)), + result: pgtype.NumericArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.NumericArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestNumericArrayAssignTo(t *testing.T) { + var float32Slice []float32 + var float64Slice []float64 + + simpleTests := []struct { + src pgtype.NumericArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float32Slice, + expected: []float32{1}, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float64Slice, + expected: []float64{1}, + }, + { + src: pgtype.NumericArray{Status: pgtype.Null}, + dst: &float32Slice, + expected: (([]float32)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.NumericArray + dst interface{} + }{ + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float32Slice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index d7e28641..208b1f00 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -216,6 +216,7 @@ func init() { "_int2": &Int2Array{}, "_int4": &Int4Array{}, "_int8": &Int8Array{}, + "_numeric": &NumericArray{}, "_text": &TextArray{}, "_timestamp": &TimestampArray{}, "_timestamptz": &TimestamptzArray{}, diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index 52612466..2e36b8b3 100644 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -14,4 +14,5 @@ erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[] erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go erb pgtype_array_type=AclitemArray pgtype_element_type=Aclitem go_array_types=[]string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go +erb pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]float64 element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go goimports -w *_array.go diff --git a/v3.md b/v3.md index b79ce9cd..f1ec1990 100644 --- a/v3.md +++ b/v3.md @@ -68,7 +68,6 @@ something like: select array[1,2,3], array[4,5,6,7] pgtype TODO: -numeric[] point line lseg