diff --git a/conn.go b/conn.go index b0f2cac0..f6a1e902 100644 --- a/conn.go +++ b/conn.go @@ -760,7 +760,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} switch arg := arguments[i].(type) { case BinaryEncoder: - err = arg.EncodeBinary(wbuf) + err = arg.EncodeBinary(wbuf, &ps.FieldDescriptions[i]) case TextEncoder: var s string var status byte diff --git a/values.go b/values.go index 77edfb22..94932851 100644 --- a/values.go +++ b/values.go @@ -63,8 +63,14 @@ type TextEncoder interface { // BinaryEncoder is an interface used to encode values in binary format for // transmission to the PostgreSQL server. It is used by prepared queries. type BinaryEncoder interface { - // EncodeBinary writes the binary value to w - EncodeBinary(w *WriteBuf) error + // EncodeBinary writes the binary value to w. + // + // EncodeBinary MUST check fd.DataType to see if the parameter data type is + // compatible. If this is not done, the PostgreSQL server may detect the + // error if the expected data size or format of the encoded data does not + // match. But if the encoded data is a valid representation of the data type + // PostgreSQL expects such as date and int4, incorrect data may be stored. + EncodeBinary(w *WriteBuf, fd *FieldDescription) error } // NullFloat32 represents an float4 that may be null. @@ -96,7 +102,11 @@ func (n NullFloat32) EncodeText() (string, byte, error) { } } -func (n NullFloat32) EncodeBinary(w *WriteBuf) error { +func (n NullFloat32) EncodeBinary(w *WriteBuf, fd *FieldDescription) error { + if fd.DataType != Float4Oid { + return SerializationError(fmt.Sprintf("NullFloat32.EncodeBinary cannot encode into OID %d", fd.DataType)) + } + if !n.Valid { w.WriteInt32(-1) return nil @@ -134,7 +144,11 @@ func (n NullFloat64) EncodeText() (string, byte, error) { } } -func (n NullFloat64) EncodeBinary(w *WriteBuf) error { +func (n NullFloat64) EncodeBinary(w *WriteBuf, fd *FieldDescription) error { + if fd.DataType != Float8Oid { + return SerializationError(fmt.Sprintf("NullFloat64.EncodeBinary cannot encode into OID %d", fd.DataType)) + } + if !n.Valid { w.WriteInt32(-1) return nil @@ -201,7 +215,11 @@ func (n NullInt16) EncodeText() (string, byte, error) { } } -func (n NullInt16) EncodeBinary(w *WriteBuf) error { +func (n NullInt16) EncodeBinary(w *WriteBuf, fd *FieldDescription) error { + if fd.DataType != Int2Oid { + return SerializationError(fmt.Sprintf("NullInt16.EncodeBinary cannot encode into OID %d", fd.DataType)) + } + if !n.Valid { w.WriteInt32(-1) return nil @@ -239,7 +257,11 @@ func (n NullInt32) EncodeText() (string, byte, error) { } } -func (n NullInt32) EncodeBinary(w *WriteBuf) error { +func (n NullInt32) EncodeBinary(w *WriteBuf, fd *FieldDescription) error { + if fd.DataType != Int4Oid { + return SerializationError(fmt.Sprintf("NullInt32.EncodeBinary cannot encode into OID %d", fd.DataType)) + } + if !n.Valid { w.WriteInt32(-1) return nil @@ -277,7 +299,11 @@ func (n NullInt64) EncodeText() (string, byte, error) { } } -func (n NullInt64) EncodeBinary(w *WriteBuf) error { +func (n NullInt64) EncodeBinary(w *WriteBuf, fd *FieldDescription) error { + if fd.DataType != Int8Oid { + return SerializationError(fmt.Sprintf("NullInt64.EncodeBinary cannot encode into OID %d", fd.DataType)) + } + if !n.Valid { w.WriteInt32(-1) return nil @@ -315,7 +341,11 @@ func (n NullBool) EncodeText() (string, byte, error) { } } -func (n NullBool) EncodeBinary(w *WriteBuf) error { +func (n NullBool) EncodeBinary(w *WriteBuf, fd *FieldDescription) error { + if fd.DataType != BoolOid { + return SerializationError(fmt.Sprintf("NullBool.EncodeBinary cannot encode into OID %d", fd.DataType)) + } + if !n.Valid { w.WriteInt32(-1) return nil @@ -354,7 +384,11 @@ func (n NullTime) EncodeText() (string, byte, error) { } } -func (n NullTime) EncodeBinary(w *WriteBuf) error { +func (n NullTime) EncodeBinary(w *WriteBuf, fd *FieldDescription) error { + if fd.DataType != TimestampTzOid { + return SerializationError(fmt.Sprintf("NullTime.EncodeBinary cannot encode into OID %d", fd.DataType)) + } + if !n.Valid { w.WriteInt32(-1) return nil diff --git a/values_test.go b/values_test.go index 0c208836..bb896774 100644 --- a/values_test.go +++ b/values_test.go @@ -250,3 +250,53 @@ func TestNullX(t *testing.T) { } } } + +func TestNullXMismatch(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + type allTypes struct { + s pgx.NullString + i16 pgx.NullInt16 + i32 pgx.NullInt32 + i64 pgx.NullInt64 + f32 pgx.NullFloat32 + f64 pgx.NullFloat64 + b pgx.NullBool + t pgx.NullTime + } + + var actual, zero allTypes + + tests := []struct { + sql string + queryArgs []interface{} + scanArgs []interface{} + err string + }{ + {"select $1::date", []interface{}{pgx.NullString{String: "foo", Valid: true}}, []interface{}{&actual.s}, "invalid input syntax for type date"}, + {"select $1::date", []interface{}{pgx.NullInt16{Int16: 1, Valid: true}}, []interface{}{&actual.i16}, "cannot encode into OID 1082"}, + {"select $1::date", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, "cannot encode into OID 1082"}, + {"select $1::date", []interface{}{pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, "cannot encode into OID 1082"}, + {"select $1::date", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: true}}, []interface{}{&actual.f32}, "cannot encode into OID 1082"}, + {"select $1::date", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.f64}, "cannot encode into OID 1082"}, + {"select $1::date", []interface{}{pgx.NullBool{Bool: true, Valid: true}}, []interface{}{&actual.b}, "cannot encode into OID 1082"}, + {"select $1::date", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, "cannot encode into OID 1082"}, + } + + for i, tt := range tests { + psName := fmt.Sprintf("ps%d", i) + mustPrepare(t, conn, psName, tt.sql) + + actual = zero + + err := conn.QueryRow(psName, tt.queryArgs...).Scan(tt.scanArgs...) + if err == nil || !strings.Contains(err.Error(), tt.err) { + t.Errorf(`%d. Expected error to contain "%s", but it didn't: %v`, i, tt.err, err) + } + + ensureConnValid(t, conn) + } +}