mirror of https://github.com/jackc/pgx.git
Error detection for mismatched types
parent
95301ea276
commit
6c1c819a5e
2
conn.go
2
conn.go
|
@ -760,7 +760,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||
|
||||
switch arg := arguments[i].(type) {
|
||||
case BinaryEncoder:
|
||||
err = arg.EncodeBinary(wbuf)
|
||||
err = arg.EncodeBinary(wbuf, &ps.FieldDescriptions[i])
|
||||
case TextEncoder:
|
||||
var s string
|
||||
var status byte
|
||||
|
|
52
values.go
52
values.go
|
@ -63,8 +63,14 @@ type TextEncoder interface {
|
|||
// BinaryEncoder is an interface used to encode values in binary format for
|
||||
// transmission to the PostgreSQL server. It is used by prepared queries.
|
||||
type BinaryEncoder interface {
|
||||
// EncodeBinary writes the binary value to w
|
||||
EncodeBinary(w *WriteBuf) error
|
||||
// EncodeBinary writes the binary value to w.
|
||||
//
|
||||
// EncodeBinary MUST check fd.DataType to see if the parameter data type is
|
||||
// compatible. If this is not done, the PostgreSQL server may detect the
|
||||
// error if the expected data size or format of the encoded data does not
|
||||
// match. But if the encoded data is a valid representation of the data type
|
||||
// PostgreSQL expects such as date and int4, incorrect data may be stored.
|
||||
EncodeBinary(w *WriteBuf, fd *FieldDescription) error
|
||||
}
|
||||
|
||||
// NullFloat32 represents an float4 that may be null.
|
||||
|
@ -96,7 +102,11 @@ func (n NullFloat32) EncodeText() (string, byte, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (n NullFloat32) EncodeBinary(w *WriteBuf) error {
|
||||
func (n NullFloat32) EncodeBinary(w *WriteBuf, fd *FieldDescription) error {
|
||||
if fd.DataType != Float4Oid {
|
||||
return SerializationError(fmt.Sprintf("NullFloat32.EncodeBinary cannot encode into OID %d", fd.DataType))
|
||||
}
|
||||
|
||||
if !n.Valid {
|
||||
w.WriteInt32(-1)
|
||||
return nil
|
||||
|
@ -134,7 +144,11 @@ func (n NullFloat64) EncodeText() (string, byte, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (n NullFloat64) EncodeBinary(w *WriteBuf) error {
|
||||
func (n NullFloat64) EncodeBinary(w *WriteBuf, fd *FieldDescription) error {
|
||||
if fd.DataType != Float8Oid {
|
||||
return SerializationError(fmt.Sprintf("NullFloat64.EncodeBinary cannot encode into OID %d", fd.DataType))
|
||||
}
|
||||
|
||||
if !n.Valid {
|
||||
w.WriteInt32(-1)
|
||||
return nil
|
||||
|
@ -201,7 +215,11 @@ func (n NullInt16) EncodeText() (string, byte, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (n NullInt16) EncodeBinary(w *WriteBuf) error {
|
||||
func (n NullInt16) EncodeBinary(w *WriteBuf, fd *FieldDescription) error {
|
||||
if fd.DataType != Int2Oid {
|
||||
return SerializationError(fmt.Sprintf("NullInt16.EncodeBinary cannot encode into OID %d", fd.DataType))
|
||||
}
|
||||
|
||||
if !n.Valid {
|
||||
w.WriteInt32(-1)
|
||||
return nil
|
||||
|
@ -239,7 +257,11 @@ func (n NullInt32) EncodeText() (string, byte, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (n NullInt32) EncodeBinary(w *WriteBuf) error {
|
||||
func (n NullInt32) EncodeBinary(w *WriteBuf, fd *FieldDescription) error {
|
||||
if fd.DataType != Int4Oid {
|
||||
return SerializationError(fmt.Sprintf("NullInt32.EncodeBinary cannot encode into OID %d", fd.DataType))
|
||||
}
|
||||
|
||||
if !n.Valid {
|
||||
w.WriteInt32(-1)
|
||||
return nil
|
||||
|
@ -277,7 +299,11 @@ func (n NullInt64) EncodeText() (string, byte, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (n NullInt64) EncodeBinary(w *WriteBuf) error {
|
||||
func (n NullInt64) EncodeBinary(w *WriteBuf, fd *FieldDescription) error {
|
||||
if fd.DataType != Int8Oid {
|
||||
return SerializationError(fmt.Sprintf("NullInt64.EncodeBinary cannot encode into OID %d", fd.DataType))
|
||||
}
|
||||
|
||||
if !n.Valid {
|
||||
w.WriteInt32(-1)
|
||||
return nil
|
||||
|
@ -315,7 +341,11 @@ func (n NullBool) EncodeText() (string, byte, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (n NullBool) EncodeBinary(w *WriteBuf) error {
|
||||
func (n NullBool) EncodeBinary(w *WriteBuf, fd *FieldDescription) error {
|
||||
if fd.DataType != BoolOid {
|
||||
return SerializationError(fmt.Sprintf("NullBool.EncodeBinary cannot encode into OID %d", fd.DataType))
|
||||
}
|
||||
|
||||
if !n.Valid {
|
||||
w.WriteInt32(-1)
|
||||
return nil
|
||||
|
@ -354,7 +384,11 @@ func (n NullTime) EncodeText() (string, byte, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (n NullTime) EncodeBinary(w *WriteBuf) error {
|
||||
func (n NullTime) EncodeBinary(w *WriteBuf, fd *FieldDescription) error {
|
||||
if fd.DataType != TimestampTzOid {
|
||||
return SerializationError(fmt.Sprintf("NullTime.EncodeBinary cannot encode into OID %d", fd.DataType))
|
||||
}
|
||||
|
||||
if !n.Valid {
|
||||
w.WriteInt32(-1)
|
||||
return nil
|
||||
|
|
|
@ -250,3 +250,53 @@ func TestNullX(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNullXMismatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
type allTypes struct {
|
||||
s pgx.NullString
|
||||
i16 pgx.NullInt16
|
||||
i32 pgx.NullInt32
|
||||
i64 pgx.NullInt64
|
||||
f32 pgx.NullFloat32
|
||||
f64 pgx.NullFloat64
|
||||
b pgx.NullBool
|
||||
t pgx.NullTime
|
||||
}
|
||||
|
||||
var actual, zero allTypes
|
||||
|
||||
tests := []struct {
|
||||
sql string
|
||||
queryArgs []interface{}
|
||||
scanArgs []interface{}
|
||||
err string
|
||||
}{
|
||||
{"select $1::date", []interface{}{pgx.NullString{String: "foo", Valid: true}}, []interface{}{&actual.s}, "invalid input syntax for type date"},
|
||||
{"select $1::date", []interface{}{pgx.NullInt16{Int16: 1, Valid: true}}, []interface{}{&actual.i16}, "cannot encode into OID 1082"},
|
||||
{"select $1::date", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, "cannot encode into OID 1082"},
|
||||
{"select $1::date", []interface{}{pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, "cannot encode into OID 1082"},
|
||||
{"select $1::date", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: true}}, []interface{}{&actual.f32}, "cannot encode into OID 1082"},
|
||||
{"select $1::date", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.f64}, "cannot encode into OID 1082"},
|
||||
{"select $1::date", []interface{}{pgx.NullBool{Bool: true, Valid: true}}, []interface{}{&actual.b}, "cannot encode into OID 1082"},
|
||||
{"select $1::date", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, "cannot encode into OID 1082"},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
psName := fmt.Sprintf("ps%d", i)
|
||||
mustPrepare(t, conn, psName, tt.sql)
|
||||
|
||||
actual = zero
|
||||
|
||||
err := conn.QueryRow(psName, tt.queryArgs...).Scan(tt.scanArgs...)
|
||||
if err == nil || !strings.Contains(err.Error(), tt.err) {
|
||||
t.Errorf(`%d. Expected error to contain "%s", but it didn't: %v`, i, tt.err, err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue