diff --git a/query_test.go b/query_test.go index 68c798b2..556fecef 100644 --- a/query_test.go +++ b/query_test.go @@ -1551,6 +1551,238 @@ func TestConnSimpleProtocol(t *testing.T) { } } + { + tests := []struct { + expected []string + }{ + {[]string(nil)}, + {[]string{}}, + {[]string{"test", "foo", "bar"}}, + {[]string{`foo'bar"\baz;quz`, `foo'bar"\baz;quz`}}, + } + for i, tt := range tests { + var actual []string + err := conn.QueryRow( + context.Background(), + "select $1::text[]", + pgx.QuerySimpleProtocol(true), + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + } + + { + tests := []struct { + expected []int16 + }{ + {[]int16(nil)}, + {[]int16{}}, + {[]int16{1, 2, 3}}, + } + for i, tt := range tests { + var actual []int16 + err := conn.QueryRow( + context.Background(), + "select $1::smallint[]", + pgx.QuerySimpleProtocol(true), + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + } + + { + tests := []struct { + expected []int32 + }{ + {[]int32(nil)}, + {[]int32{}}, + {[]int32{1, 2, 3}}, + } + for i, tt := range tests { + var actual []int32 + err := conn.QueryRow( + context.Background(), + "select $1::int[]", + pgx.QuerySimpleProtocol(true), + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + } + + { + tests := []struct { + expected []int64 + }{ + {[]int64(nil)}, + {[]int64{}}, + {[]int64{1, 2, 3}}, + } + for i, tt := range tests { + var actual []int64 + err := conn.QueryRow( + context.Background(), + "select $1::bigint[]", + pgx.QuerySimpleProtocol(true), + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + } + + { + tests := []struct { + expected []int + }{ + {[]int(nil)}, + {[]int{}}, + {[]int{1, 2, 3}}, + } + for i, tt := range tests { + var actual []int + err := conn.QueryRow( + context.Background(), + "select $1::bigint[]", + pgx.QuerySimpleProtocol(true), + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + } + + { + tests := []struct { + expected []uint16 + }{ + {[]uint16(nil)}, + {[]uint16{}}, + {[]uint16{1, 2, 3}}, + } + for i, tt := range tests { + var actual []uint16 + err := conn.QueryRow( + context.Background(), + "select $1::smallint[]", + pgx.QuerySimpleProtocol(true), + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + } + + { + tests := []struct { + expected []uint32 + }{ + {[]uint32(nil)}, + {[]uint32{}}, + {[]uint32{1, 2, 3}}, + } + for i, tt := range tests { + var actual []uint32 + err := conn.QueryRow( + context.Background(), + "select $1::bigint[]", + pgx.QuerySimpleProtocol(true), + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + } + + { + tests := []struct { + expected []uint64 + }{ + {[]uint64(nil)}, + {[]uint64{}}, + {[]uint64{1, 2, 3}}, + } + for i, tt := range tests { + var actual []uint64 + err := conn.QueryRow( + context.Background(), + "select $1::bigint[]", + pgx.QuerySimpleProtocol(true), + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + } + + { + tests := []struct { + expected []uint + }{ + {[]uint(nil)}, + {[]uint{}}, + {[]uint{1, 2, 3}}, + } + for i, tt := range tests { + var actual []uint + err := conn.QueryRow( + context.Background(), + "select $1::bigint[]", + pgx.QuerySimpleProtocol(true), + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + } + + { + tests := []struct { + expected []float32 + }{ + {[]float32(nil)}, + {[]float32{}}, + {[]float32{1, 2, 3}}, + } + for i, tt := range tests { + var actual []float32 + err := conn.QueryRow( + context.Background(), + "select $1::float4[]", + pgx.QuerySimpleProtocol(true), + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + } + + { + tests := []struct { + expected []float64 + }{ + {[]float64(nil)}, + {[]float64{}}, + {[]float64{1, 2, 3}}, + } + for i, tt := range tests { + var actual []float64 + err := conn.QueryRow( + context.Background(), + "select $1::float8[]", + pgx.QuerySimpleProtocol(true), + tt.expected, + ).Scan(&actual) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expected, actual, "%d", i) + } + } + // Test high-level type { diff --git a/values.go b/values.go index da9bb70a..afb1c46f 100644 --- a/values.go +++ b/values.go @@ -30,6 +30,24 @@ func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, e return nil, nil } + setArgAndEncodeTextToString := func(t interface { + pgtype.Value + pgtype.TextEncoder + }) (interface{}, error) { + err := t.Set(arg) + if err != nil { + return nil, err + } + buf, err := t.EncodeText(ci, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + return string(buf), nil + } + switch arg := arg.(type) { // https://github.com/jackc/pgx/issues/409 Changed JSON and JSONB to surface @@ -68,8 +86,8 @@ func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, e return nil, nil } return string(buf), nil - case int64: - return arg, nil + case float32: + return float64(arg), nil case float64: return arg, nil case bool: @@ -86,6 +104,8 @@ func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, e return int64(arg), nil case int32: return int64(arg), nil + case int64: + return arg, nil case int: return int64(arg), nil case uint8: @@ -104,8 +124,28 @@ func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, e return nil, errors.Errorf("arg too big for int64: %v", arg) } return int64(arg), nil - case float32: - return float64(arg), nil + case []string: + return setArgAndEncodeTextToString(&pgtype.TextArray{}) + case []float32: + return setArgAndEncodeTextToString(&pgtype.Float4Array{}) + case []float64: + return setArgAndEncodeTextToString(&pgtype.Float8Array{}) + case []int16: + return setArgAndEncodeTextToString(&pgtype.Int8Array{}) + case []int32: + return setArgAndEncodeTextToString(&pgtype.Int8Array{}) + case []int64: + return setArgAndEncodeTextToString(&pgtype.Int8Array{}) + case []int: + return setArgAndEncodeTextToString(&pgtype.Int8Array{}) + case []uint16: + return setArgAndEncodeTextToString(&pgtype.Int8Array{}) + case []uint32: + return setArgAndEncodeTextToString(&pgtype.Int8Array{}) + case []uint64: + return setArgAndEncodeTextToString(&pgtype.Int8Array{}) + case []uint: + return setArgAndEncodeTextToString(&pgtype.Int8Array{}) } refVal := reflect.ValueOf(arg)