diff --git a/query_test.go b/query_test.go index 9e31a009..439d808a 100644 --- a/query_test.go +++ b/query_test.go @@ -453,7 +453,7 @@ func TestQueryRowErrors(t *testing.T) { {"select $1", []interface{}{"Jack"}, []interface{}{&actual.i16}, "could not determine data type of parameter $1 (SQLSTATE 42P18)"}, {"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`}, {"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"}, - {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "Expected type oid 21 but received type oid 25"}, + {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "Cannot decode oid 25 into int16"}, {"select $1::int8range", []interface{}{int(705)}, []interface{}{&actual.s}, "Cannot encode int into oid 3926 - int must implement Encoder or be converted to a string"}, } diff --git a/values.go b/values.go index cd2c4e09..71185999 100644 --- a/values.go +++ b/values.go @@ -418,6 +418,11 @@ func decodeBool(vr *ValueReader) bool { return false } + if vr.Type().DataType != BoolOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into bool", vr.Type().DataType))) + return false + } + switch vr.Type().FormatCode { case TextFormatCode: s := vr.ReadString(vr.Len()) @@ -468,7 +473,7 @@ func decodeInt8(vr *ValueReader) int64 { } if vr.Type().DataType != Int8Oid { - vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int8Oid, vr.Type().DataType))) + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int8", vr.Type().DataType))) return 0 } @@ -534,7 +539,7 @@ func decodeInt2(vr *ValueReader) int16 { } if vr.Type().DataType != Int2Oid { - vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int2Oid, vr.Type().DataType))) + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int16", vr.Type().DataType))) return 0 } @@ -615,7 +620,7 @@ func decodeInt4(vr *ValueReader) int32 { } if vr.Type().DataType != Int4Oid { - vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int4Oid, vr.Type().DataType))) + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int32", vr.Type().DataType))) return 0 } @@ -689,7 +694,7 @@ func decodeOid(vr *ValueReader) Oid { } if vr.Type().DataType != OidOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", OidOid, vr.Type().DataType))) + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Oid", vr.Type().DataType))) return 0 } @@ -731,6 +736,11 @@ func decodeFloat4(vr *ValueReader) float32 { return 0 } + if vr.Type().DataType != Float4Oid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into float32", vr.Type().DataType))) + return 0 + } + switch vr.Type().FormatCode { case TextFormatCode: s := vr.ReadString(vr.Len()) @@ -783,6 +793,11 @@ func decodeFloat8(vr *ValueReader) float64 { return 0 } + if vr.Type().DataType != Float8Oid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into float64", vr.Type().DataType))) + return 0 + } + switch vr.Type().FormatCode { case TextFormatCode: s := vr.ReadString(vr.Len()) @@ -852,6 +867,11 @@ func decodeBytea(vr *ValueReader) []byte { return nil } + if vr.Type().DataType != ByteaOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []byte", vr.Type().DataType))) + return nil + } + switch vr.Type().FormatCode { case TextFormatCode: s := vr.ReadString(vr.Len()) @@ -890,7 +910,7 @@ func decodeDate(vr *ValueReader) time.Time { } if vr.Type().DataType != DateOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", DateOid, vr.Type().DataType))) + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType))) return zeroTime } @@ -936,7 +956,7 @@ func decodeTimestampTz(vr *ValueReader) time.Time { } if vr.Type().DataType != TimestampTzOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", TimestampTzOid, vr.Type().DataType))) + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType))) return zeroTime } @@ -986,7 +1006,7 @@ func decodeTimestamp(vr *ValueReader) time.Time { } if vr.Type().DataType != TimestampOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", TimestampOid, vr.Type().DataType))) + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType))) return zeroTime } @@ -1048,7 +1068,7 @@ func decodeInt2Array(vr *ValueReader) []int16 { } if vr.Type().DataType != Int2ArrayOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int2ArrayOid, vr.Type().DataType))) + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int16", vr.Type().DataType))) return nil } @@ -1110,7 +1130,7 @@ func decodeInt4Array(vr *ValueReader) []int32 { } if vr.Type().DataType != Int4ArrayOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int4ArrayOid, vr.Type().DataType))) + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int32", vr.Type().DataType))) return nil } @@ -1172,7 +1192,7 @@ func decodeInt8Array(vr *ValueReader) []int64 { } if vr.Type().DataType != Int8ArrayOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Int8ArrayOid, vr.Type().DataType))) + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int64", vr.Type().DataType))) return nil } @@ -1234,7 +1254,7 @@ func decodeFloat4Array(vr *ValueReader) []float32 { } if vr.Type().DataType != Float4ArrayOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Float4ArrayOid, vr.Type().DataType))) + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []float32", vr.Type().DataType))) return nil } @@ -1300,7 +1320,7 @@ func decodeFloat8Array(vr *ValueReader) []float64 { } if vr.Type().DataType != Float8ArrayOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", Float8ArrayOid, vr.Type().DataType))) + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []float64", vr.Type().DataType))) return nil } @@ -1366,7 +1386,7 @@ func decodeTextArray(vr *ValueReader) []string { } if vr.Type().DataType != TextArrayOid && vr.Type().DataType != VarcharArrayOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v or %v but received type oid %v", TextArrayOid, VarcharArrayOid, vr.Type().DataType))) + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []string", vr.Type().DataType))) return nil }