diff --git a/values.go b/values.go index cf20fdee..0dc5b641 100644 --- a/values.go +++ b/values.go @@ -2,7 +2,6 @@ package pgx import ( "bytes" - "encoding/hex" "fmt" "math" "strconv" @@ -573,29 +572,18 @@ func decodeBool(vr *ValueReader) bool { return false } - switch vr.Type().FormatCode { - case TextFormatCode: - s := vr.ReadString(vr.Len()) - switch s { - case "t": - return true - case "f": - return false - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid bool: %v", s))) - return false - } - case BinaryFormatCode: - if vr.Len() != 1 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool: %d", vr.Len()))) - return false - } - b := vr.ReadByte() - return b != 0 - default: + if vr.Type().FormatCode != BinaryFormatCode { vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return false } + + if vr.Len() != 1 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool: %d", vr.Len()))) + return false + } + + b := vr.ReadByte() + return b != 0 } func encodeBool(w *WriteBuf, value interface{}) error { @@ -627,25 +615,17 @@ func decodeInt8(vr *ValueReader) int64 { return 0 } - switch vr.Type().FormatCode { - case TextFormatCode: - s := vr.ReadString(vr.Len()) - n, err := strconv.ParseInt(s, 10, 64) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid int8: %v", s))) - return 0 - } - return n - case BinaryFormatCode: - if vr.Len() != 8 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8: %d", vr.Len()))) - return 0 - } - return vr.ReadInt64() - default: + if vr.Type().FormatCode != BinaryFormatCode { vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0 } + + if vr.Len() != 8 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8: %d", vr.Len()))) + return 0 + } + + return vr.ReadInt64() } func encodeInt8(w *WriteBuf, value interface{}) error { @@ -693,25 +673,17 @@ func decodeInt2(vr *ValueReader) int16 { return 0 } - switch vr.Type().FormatCode { - case TextFormatCode: - s := vr.ReadString(vr.Len()) - n, err := strconv.ParseInt(s, 10, 16) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid int2: %v", s))) - return 0 - } - return int16(n) - case BinaryFormatCode: - if vr.Len() != 2 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2: %d", vr.Len()))) - return 0 - } - return vr.ReadInt16() - default: + if vr.Type().FormatCode != BinaryFormatCode { vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0 } + + if vr.Len() != 2 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2: %d", vr.Len()))) + return 0 + } + + return vr.ReadInt16() } func encodeInt2(w *WriteBuf, value interface{}) error { @@ -774,24 +746,17 @@ func decodeInt4(vr *ValueReader) int32 { return 0 } - switch vr.Type().FormatCode { - case TextFormatCode: - s := vr.ReadString(vr.Len()) - n, err := strconv.ParseInt(s, 10, 32) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid int4: %v", s))) - } - return int32(n) - case BinaryFormatCode: - if vr.Len() != 4 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4: %d", vr.Len()))) - return 0 - } - return vr.ReadInt32() - default: + if vr.Type().FormatCode != BinaryFormatCode { vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0 } + + if vr.Len() != 4 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4: %d", vr.Len()))) + return 0 + } + + return vr.ReadInt32() } func encodeInt4(w *WriteBuf, value interface{}) error { @@ -848,6 +813,7 @@ func decodeOid(vr *ValueReader) Oid { return 0 } + // Oid needs to decode text format because it is used in loadPgTypes switch vr.Type().FormatCode { case TextFormatCode: s := vr.ReadString(vr.Len()) @@ -891,27 +857,18 @@ func decodeFloat4(vr *ValueReader) float32 { return 0 } - switch vr.Type().FormatCode { - case TextFormatCode: - s := vr.ReadString(vr.Len()) - n, err := strconv.ParseFloat(s, 32) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid float4: %v", s))) - return 0 - } - return float32(n) - case BinaryFormatCode: - if vr.Len() != 4 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", vr.Len()))) - return 0 - } - - i := vr.ReadInt32() - return math.Float32frombits(uint32(i)) - default: + if vr.Type().FormatCode != BinaryFormatCode { vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0 } + + if vr.Len() != 4 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", vr.Len()))) + return 0 + } + + i := vr.ReadInt32() + return math.Float32frombits(uint32(i)) } func encodeFloat4(w *WriteBuf, value interface{}) error { @@ -946,27 +903,18 @@ func decodeFloat8(vr *ValueReader) float64 { return 0 } - switch vr.Type().FormatCode { - case TextFormatCode: - s := vr.ReadString(vr.Len()) - v, err := strconv.ParseFloat(s, 64) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid float8: %v", s))) - return 0 - } - return v - case BinaryFormatCode: - if vr.Len() != 8 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float8: %d", vr.Len()))) - return 0 - } - - i := vr.ReadInt64() - return math.Float64frombits(uint64(i)) - default: + if vr.Type().FormatCode != BinaryFormatCode { vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0 } + + if vr.Len() != 8 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float8: %d", vr.Len()))) + return 0 + } + + i := vr.ReadInt64() + return math.Float64frombits(uint64(i)) } func encodeFloat8(w *WriteBuf, value interface{}) error { @@ -1021,21 +969,12 @@ func decodeBytea(vr *ValueReader) []byte { return nil } - switch vr.Type().FormatCode { - case TextFormatCode: - s := vr.ReadString(vr.Len()) - b, err := hex.DecodeString(s[2:]) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode byte array: %v - %v", err, s))) - return nil - } - return b - case BinaryFormatCode: - return vr.ReadBytes(vr.Len()) - default: + if vr.Type().FormatCode != BinaryFormatCode { vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return nil } + + return vr.ReadBytes(vr.Len()) } func encodeBytea(w *WriteBuf, value interface{}) error { @@ -1063,25 +1002,16 @@ func decodeDate(vr *ValueReader) time.Time { return zeroTime } - switch vr.Type().FormatCode { - case TextFormatCode: - s := vr.ReadString(vr.Len()) - t, err := time.ParseInLocation("2006-01-02", s, time.Local) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode date: %v", s))) - return zeroTime - } - return t - case BinaryFormatCode: - if vr.Len() != 4 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an date: %d", vr.Len()))) - } - dayOffset := vr.ReadInt32() - return time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.Local) - default: + if vr.Type().FormatCode != BinaryFormatCode { vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return zeroTime } + + if vr.Len() != 4 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an date: %d", vr.Len()))) + } + dayOffset := vr.ReadInt32() + return time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.Local) } func encodeDate(w *WriteBuf, value interface{}) error { @@ -1109,26 +1039,19 @@ func decodeTimestampTz(vr *ValueReader) time.Time { return zeroTime } - switch vr.Type().FormatCode { - case TextFormatCode: - s := vr.ReadString(vr.Len()) - t, err := time.Parse("2006-01-02 15:04:05.999999-07", s) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode timestamptz: %v - %v", err, s))) - return zeroTime - } - return t - case BinaryFormatCode: - if vr.Len() != 8 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamptz: %d", vr.Len()))) - } - microsecSinceY2K := vr.ReadInt64() - microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K - return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) - default: + if vr.Type().FormatCode != BinaryFormatCode { vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return zeroTime } + + if vr.Len() != 8 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamptz: %d", vr.Len()))) + return zeroTime + } + + microsecSinceY2K := vr.ReadInt64() + microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K + return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) } func encodeTimestampTz(w *WriteBuf, value interface{}) error { @@ -1159,26 +1082,18 @@ func decodeTimestamp(vr *ValueReader) time.Time { return zeroTime } - switch vr.Type().FormatCode { - case TextFormatCode: - s := vr.ReadString(vr.Len()) - t, err := time.ParseInLocation("2006-01-02 15:04:05.999999", s, time.Local) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode timestamp: %v - %v", err, s))) - return zeroTime - } - return t - case BinaryFormatCode: - if vr.Len() != 8 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamp: %d", vr.Len()))) - } - microsecSinceY2K := vr.ReadInt64() - microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K - return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) - default: + if vr.Type().FormatCode != BinaryFormatCode { vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return zeroTime } + + if vr.Len() != 8 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamp: %d", vr.Len()))) + } + + microsecSinceY2K := vr.ReadInt64() + microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K + return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) } func encodeTimestamp(w *WriteBuf, value interface{}) error {