package pgx import ( "bytes" "database/sql/driver" "encoding/json" "fmt" "math" "net" "reflect" "strconv" "strings" "time" ) // PostgreSQL oids for common types const ( BoolOid = 16 ByteaOid = 17 Int8Oid = 20 Int2Oid = 21 Int4Oid = 23 TextOid = 25 OidOid = 26 JsonOid = 114 CidrOid = 650 CidrArrayOid = 651 Float4Oid = 700 Float8Oid = 701 InetOid = 869 BoolArrayOid = 1000 Int2ArrayOid = 1005 Int4ArrayOid = 1007 TextArrayOid = 1009 VarcharArrayOid = 1015 Int8ArrayOid = 1016 Float4ArrayOid = 1021 Float8ArrayOid = 1022 InetArrayOid = 1041 VarcharOid = 1043 DateOid = 1082 TimestampOid = 1114 TimestampArrayOid = 1115 TimestampTzOid = 1184 TimestampTzArrayOid = 1185 UuidOid = 2950 JsonbOid = 3802 ) // PostgreSQL format codes const ( TextFormatCode = 0 BinaryFormatCode = 1 ) // DefaultTypeFormats maps type names to their default requested format (text // or binary). In theory the Scanner interface should be the one to determine // the format of the returned values. However, the query has already been // executed by the time Scan is called so it has no chance to set the format. 
// So for types that should be returned in binary th var DefaultTypeFormats map[string]int16 func init() { DefaultTypeFormats = map[string]int16{ "_bool": BinaryFormatCode, "_cidr": BinaryFormatCode, "_float4": BinaryFormatCode, "_float8": BinaryFormatCode, "_inet": BinaryFormatCode, "_int2": BinaryFormatCode, "_int4": BinaryFormatCode, "_int8": BinaryFormatCode, "_text": BinaryFormatCode, "_timestamp": BinaryFormatCode, "_timestamptz": BinaryFormatCode, "_varchar": BinaryFormatCode, "bool": BinaryFormatCode, "bytea": BinaryFormatCode, "cidr": BinaryFormatCode, "date": BinaryFormatCode, "float4": BinaryFormatCode, "float8": BinaryFormatCode, "inet": BinaryFormatCode, "int2": BinaryFormatCode, "int4": BinaryFormatCode, "int8": BinaryFormatCode, "oid": BinaryFormatCode, "timestamp": BinaryFormatCode, "timestamptz": BinaryFormatCode, } } type SerializationError string func (e SerializationError) Error() string { return string(e) } // Scanner is an interface used to decode values from the PostgreSQL server. type Scanner interface { // Scan MUST check r.Type().DataType (to check by OID) or // r.Type().DataTypeName (to check by name) to ensure that it is scanning an // expected column type. It also MUST check r.Type().FormatCode before // decoding. It should not assume that it was called on a data type or format // that it understands. Scan(r *ValueReader) error } // Encoder is an interface used to encode values for transmission to the // PostgreSQL server. type Encoder interface { // Encode writes the value to w. // // If the value is NULL an int32(-1) should be written. // // Encode MUST check oid to see if the parameter data type is compatible. If // this is not done, the PostgreSQL server may detect the error if the // expected data size or format of the encoded data does not match. But if // the encoded data is a valid representation of the data type PostgreSQL // expects such as date and int4, incorrect data may be stored. 
Encode(w *WriteBuf, oid Oid) error // FormatCode returns the format that the encoder writes the value. It must be // either pgx.TextFormatCode or pgx.BinaryFormatCode. FormatCode() int16 } // NullFloat32 represents an float4 that may be null. NullFloat32 implements the // Scanner and Encoder interfaces so it may be used both as an argument to // Query[Row] and a destination for Scan. // // If Valid is false then the value is NULL. type NullFloat32 struct { Float32 float32 Valid bool // Valid is true if Float32 is not NULL } func (n *NullFloat32) Scan(vr *ValueReader) error { if vr.Type().DataType != Float4Oid { return SerializationError(fmt.Sprintf("NullFloat32.Scan cannot decode OID %d", vr.Type().DataType)) } if vr.Len() == -1 { n.Float32, n.Valid = 0, false return nil } n.Valid = true n.Float32 = decodeFloat4(vr) return vr.Err() } func (n NullFloat32) FormatCode() int16 { return BinaryFormatCode } func (n NullFloat32) Encode(w *WriteBuf, oid Oid) error { if oid != Float4Oid { return SerializationError(fmt.Sprintf("NullFloat32.Encode cannot encode into OID %d", oid)) } if !n.Valid { w.WriteInt32(-1) return nil } return encodeFloat32(w, oid, n.Float32) } // NullFloat64 represents an float8 that may be null. NullFloat64 implements the // Scanner and Encoder interfaces so it may be used both as an argument to // Query[Row] and a destination for Scan. // // If Valid is false then the value is NULL. 
type NullFloat64 struct { Float64 float64 Valid bool // Valid is true if Float64 is not NULL } func (n *NullFloat64) Scan(vr *ValueReader) error { if vr.Type().DataType != Float8Oid { return SerializationError(fmt.Sprintf("NullFloat64.Scan cannot decode OID %d", vr.Type().DataType)) } if vr.Len() == -1 { n.Float64, n.Valid = 0, false return nil } n.Valid = true n.Float64 = decodeFloat8(vr) return vr.Err() } func (n NullFloat64) FormatCode() int16 { return BinaryFormatCode } func (n NullFloat64) Encode(w *WriteBuf, oid Oid) error { if oid != Float8Oid { return SerializationError(fmt.Sprintf("NullFloat64.Encode cannot encode into OID %d", oid)) } if !n.Valid { w.WriteInt32(-1) return nil } return encodeFloat64(w, oid, n.Float64) } // NullString represents an string that may be null. NullString implements the // Scanner Encoder interfaces so it may be used both as an argument to // Query[Row] and a destination for Scan. // // If Valid is false then the value is NULL. type NullString struct { String string Valid bool // Valid is true if String is not NULL } func (s *NullString) Scan(vr *ValueReader) error { // Not checking oid as so we can scan anything into into a NullString - may revisit this decision later if vr.Len() == -1 { s.String, s.Valid = "", false return nil } s.Valid = true s.String = decodeText(vr) return vr.Err() } func (n NullString) FormatCode() int16 { return TextFormatCode } func (s NullString) Encode(w *WriteBuf, oid Oid) error { if !s.Valid { w.WriteInt32(-1) return nil } return encodeString(w, oid, s.String) } // NullInt16 represents an smallint that may be null. NullInt16 implements the // Scanner and Encoder interfaces so it may be used both as an argument to // Query[Row] and a destination for Scan for prepared and unprepared queries. // // If Valid is false then the value is NULL. 
type NullInt16 struct { Int16 int16 Valid bool // Valid is true if Int16 is not NULL } func (n *NullInt16) Scan(vr *ValueReader) error { if vr.Type().DataType != Int2Oid { return SerializationError(fmt.Sprintf("NullInt16.Scan cannot decode OID %d", vr.Type().DataType)) } if vr.Len() == -1 { n.Int16, n.Valid = 0, false return nil } n.Valid = true n.Int16 = decodeInt2(vr) return vr.Err() } func (n NullInt16) FormatCode() int16 { return BinaryFormatCode } func (n NullInt16) Encode(w *WriteBuf, oid Oid) error { if oid != Int2Oid { return SerializationError(fmt.Sprintf("NullInt16.Encode cannot encode into OID %d", oid)) } if !n.Valid { w.WriteInt32(-1) return nil } return encodeInt16(w, oid, n.Int16) } // NullInt32 represents an integer that may be null. NullInt32 implements the // Scanner and Encoder interfaces so it may be used both as an argument to // Query[Row] and a destination for Scan. // // If Valid is false then the value is NULL. type NullInt32 struct { Int32 int32 Valid bool // Valid is true if Int32 is not NULL } func (n *NullInt32) Scan(vr *ValueReader) error { if vr.Type().DataType != Int4Oid { return SerializationError(fmt.Sprintf("NullInt32.Scan cannot decode OID %d", vr.Type().DataType)) } if vr.Len() == -1 { n.Int32, n.Valid = 0, false return nil } n.Valid = true n.Int32 = decodeInt4(vr) return vr.Err() } func (n NullInt32) FormatCode() int16 { return BinaryFormatCode } func (n NullInt32) Encode(w *WriteBuf, oid Oid) error { if oid != Int4Oid { return SerializationError(fmt.Sprintf("NullInt32.Encode cannot encode into OID %d", oid)) } if !n.Valid { w.WriteInt32(-1) return nil } return encodeInt32(w, oid, n.Int32) } // NullInt64 represents an bigint that may be null. NullInt64 implements the // Scanner and Encoder interfaces so it may be used both as an argument to // Query[Row] and a destination for Scan. // // If Valid is false then the value is NULL. 
type NullInt64 struct { Int64 int64 Valid bool // Valid is true if Int64 is not NULL } func (n *NullInt64) Scan(vr *ValueReader) error { if vr.Type().DataType != Int8Oid { return SerializationError(fmt.Sprintf("NullInt64.Scan cannot decode OID %d", vr.Type().DataType)) } if vr.Len() == -1 { n.Int64, n.Valid = 0, false return nil } n.Valid = true n.Int64 = decodeInt8(vr) return vr.Err() } func (n NullInt64) FormatCode() int16 { return BinaryFormatCode } func (n NullInt64) Encode(w *WriteBuf, oid Oid) error { if oid != Int8Oid { return SerializationError(fmt.Sprintf("NullInt64.Encode cannot encode into OID %d", oid)) } if !n.Valid { w.WriteInt32(-1) return nil } return encodeInt64(w, oid, n.Int64) } // NullBool represents an bool that may be null. NullBool implements the Scanner // and Encoder interfaces so it may be used both as an argument to Query[Row] // and a destination for Scan. // // If Valid is false then the value is NULL. type NullBool struct { Bool bool Valid bool // Valid is true if Bool is not NULL } func (n *NullBool) Scan(vr *ValueReader) error { if vr.Type().DataType != BoolOid { return SerializationError(fmt.Sprintf("NullBool.Scan cannot decode OID %d", vr.Type().DataType)) } if vr.Len() == -1 { n.Bool, n.Valid = false, false return nil } n.Valid = true n.Bool = decodeBool(vr) return vr.Err() } func (n NullBool) FormatCode() int16 { return BinaryFormatCode } func (n NullBool) Encode(w *WriteBuf, oid Oid) error { if oid != BoolOid { return SerializationError(fmt.Sprintf("NullBool.Encode cannot encode into OID %d", oid)) } if !n.Valid { w.WriteInt32(-1) return nil } return encodeBool(w, oid, n.Bool) } // NullTime represents an time.Time that may be null. NullTime implements the // Scanner and Encoder interfaces so it may be used both as an argument to // Query[Row] and a destination for Scan. It corresponds with the PostgreSQL // types timestamptz, timestamp, and date. // // If Valid is false then the value is NULL. 
type NullTime struct { Time time.Time Valid bool // Valid is true if Time is not NULL } func (n *NullTime) Scan(vr *ValueReader) error { oid := vr.Type().DataType if oid != TimestampTzOid && oid != TimestampOid && oid != DateOid { return SerializationError(fmt.Sprintf("NullTime.Scan cannot decode OID %d", vr.Type().DataType)) } if vr.Len() == -1 { n.Time, n.Valid = time.Time{}, false return nil } n.Valid = true switch oid { case TimestampTzOid: n.Time = decodeTimestampTz(vr) case TimestampOid: n.Time = decodeTimestamp(vr) case DateOid: n.Time = decodeDate(vr) } return vr.Err() } func (n NullTime) FormatCode() int16 { return BinaryFormatCode } func (n NullTime) Encode(w *WriteBuf, oid Oid) error { if oid != TimestampTzOid && oid != TimestampOid && oid != DateOid { return SerializationError(fmt.Sprintf("NullTime.Encode cannot encode into OID %d", oid)) } if !n.Valid { w.WriteInt32(-1) return nil } return encodeTime(w, oid, n.Time) } // Hstore represents an hstore column. It does not support a null column or null // key values (use NullHstore for this). Hstore implements the Scanner and // Encoder interfaces so it may be used both as an argument to Query[Row] and a // destination for Scan. 
type Hstore map[string]string func (h *Hstore) Scan(vr *ValueReader) error { //oid for hstore not standardized, so we check its type name if vr.Type().DataTypeName != "hstore" { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode type %s into Hstore", vr.Type().DataTypeName))) return nil } if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null column into Hstore")) return nil } switch vr.Type().FormatCode { case TextFormatCode: m, err := parseHstoreToMap(vr.ReadString(vr.Len())) if err != nil { vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode hstore column: %v", err))) return nil } hm := Hstore(m) *h = hm return nil case BinaryFormatCode: vr.Fatal(ProtocolError("Can't decode binary hstore")) return nil default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return nil } } func (h Hstore) FormatCode() int16 { return TextFormatCode } func (h Hstore) Encode(w *WriteBuf, oid Oid) error { var buf bytes.Buffer i := 0 for k, v := range h { i++ ks := strings.Replace(k, `\`, `\\`, -1) ks = strings.Replace(ks, `"`, `\"`, -1) vs := strings.Replace(v, `\`, `\\`, -1) vs = strings.Replace(vs, `"`, `\"`, -1) buf.WriteString(fmt.Sprintf(`"%s"=>"%s"`, ks, vs)) if i < len(h) { buf.WriteString(", ") } } w.WriteInt32(int32(buf.Len())) w.WriteBytes(buf.Bytes()) return nil } // NullHstore represents an hstore column that can be null or have null values // associated with its keys. NullHstore implements the Scanner and Encoder // interfaces so it may be used both as an argument to Query[Row] and a // destination for Scan. // // If Valid is false, then the value of the entire hstore column is NULL // If any of the NullString values in Store has Valid set to false, the key // appears in the hstore column, but its value is explicitly set to NULL. 
type NullHstore struct { Hstore map[string]NullString Valid bool } func (h *NullHstore) Scan(vr *ValueReader) error { //oid for hstore not standardized, so we check its type name if vr.Type().DataTypeName != "hstore" { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode type %s into NullHstore", vr.Type().DataTypeName))) return nil } if vr.Len() == -1 { h.Valid = false return nil } switch vr.Type().FormatCode { case TextFormatCode: store, err := parseHstoreToNullHstore(vr.ReadString(vr.Len())) if err != nil { vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode hstore column: %v", err))) return nil } h.Valid = true h.Hstore = store return nil case BinaryFormatCode: vr.Fatal(ProtocolError("Can't decode binary hstore")) return nil default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return nil } } func (h NullHstore) FormatCode() int16 { return TextFormatCode } func (h NullHstore) Encode(w *WriteBuf, oid Oid) error { var buf bytes.Buffer if !h.Valid { w.WriteInt32(-1) return nil } i := 0 for k, v := range h.Hstore { i++ ks := strings.Replace(k, `\`, `\\`, -1) ks = strings.Replace(ks, `"`, `\"`, -1) if v.Valid { vs := strings.Replace(v.String, `\`, `\\`, -1) vs = strings.Replace(vs, `"`, `\"`, -1) buf.WriteString(fmt.Sprintf(`"%s"=>"%s"`, ks, vs)) } else { buf.WriteString(fmt.Sprintf(`"%s"=>NULL`, ks)) } if i < len(h.Hstore) { buf.WriteString(", ") } } w.WriteInt32(int32(buf.Len())) w.WriteBytes(buf.Bytes()) return nil } // Encode encodes arg into wbuf as the type oid. This allows implementations // of the Encoder interface to delegate the actual work of encoding to the // built-in functionality. 
func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { if arg == nil { wbuf.WriteInt32(-1) return nil } switch arg := arg.(type) { case Encoder: return arg.Encode(wbuf, oid) case driver.Valuer: v, err := arg.Value() if err != nil { return err } return Encode(wbuf, oid, v) case string: return encodeString(wbuf, oid, arg) case []byte: return encodeByteSlice(wbuf, oid, arg) } if v := reflect.ValueOf(arg); v.Kind() == reflect.Ptr { if v.IsNil() { wbuf.WriteInt32(-1) return nil } else { arg = v.Elem().Interface() return Encode(wbuf, oid, arg) } } if oid == JsonOid || oid == JsonbOid { return encodeJson(wbuf, oid, arg) } switch arg := arg.(type) { case []string: return encodeStringSlice(wbuf, oid, arg) case bool: return encodeBool(wbuf, oid, arg) case []bool: return encodeBoolSlice(wbuf, oid, arg) case int8: return encodeInt8(wbuf, oid, arg) case uint8: return encodeUInt8(wbuf, oid, arg) case int16: return encodeInt16(wbuf, oid, arg) case []int16: return encodeInt16Slice(wbuf, oid, arg) case uint16: return encodeUInt16(wbuf, oid, arg) case []uint16: return encodeUInt16Slice(wbuf, oid, arg) case int32: return encodeInt32(wbuf, oid, arg) case []int32: return encodeInt32Slice(wbuf, oid, arg) case uint32: return encodeUInt32(wbuf, oid, arg) case []uint32: return encodeUInt32Slice(wbuf, oid, arg) case int64: return encodeInt64(wbuf, oid, arg) case []int64: return encodeInt64Slice(wbuf, oid, arg) case uint64: return encodeUInt64(wbuf, oid, arg) case []uint64: return encodeUInt64Slice(wbuf, oid, arg) case int: return encodeInt(wbuf, oid, arg) case float32: return encodeFloat32(wbuf, oid, arg) case []float32: return encodeFloat32Slice(wbuf, oid, arg) case float64: return encodeFloat64(wbuf, oid, arg) case []float64: return encodeFloat64Slice(wbuf, oid, arg) case time.Time: return encodeTime(wbuf, oid, arg) case []time.Time: return encodeTimeSlice(wbuf, oid, arg) case net.IP: return encodeIP(wbuf, oid, arg) case []net.IP: return encodeIPSlice(wbuf, oid, arg) case 
net.IPNet: return encodeIPNet(wbuf, oid, arg) case []net.IPNet: return encodeIPNetSlice(wbuf, oid, arg) case Oid: return encodeOid(wbuf, oid, arg) default: return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) } } // Decode decodes from vr into d. d must be a pointer. This allows // implementations of the Decoder interface to delegate the actual work of // decoding to the built-in functionality. func Decode(vr *ValueReader, d interface{}) error { switch v := d.(type) { case *bool: *v = decodeBool(vr) case *int64: *v = decodeInt8(vr) case *int16: *v = decodeInt2(vr) case *int32: *v = decodeInt4(vr) case *uint32: var valInt int32 switch vr.Type().DataType { case Int2Oid: valInt = int32(decodeInt2(vr)) case Int4Oid: valInt = decodeInt4(vr) default: return fmt.Errorf("Can't convert OID %v to uint32", vr.Type().DataType) } if valInt < 0 { return fmt.Errorf("%d is less than zero for uint32", valInt) } *v = uint32(valInt) case *uint64: var valInt int64 switch vr.Type().DataType { case Int2Oid: valInt = int64(decodeInt2(vr)) case Int4Oid: valInt = int64(decodeInt4(vr)) case Int8Oid: valInt = decodeInt8(vr) default: return fmt.Errorf("Can't convert OID %v to uint64", vr.Type().DataType) } if valInt < 0 { return fmt.Errorf("%d is less than zero for uint64", valInt) } *v = uint64(valInt) case *Oid: *v = decodeOid(vr) case *string: *v = decodeText(vr) case *float32: *v = decodeFloat4(vr) case *float64: *v = decodeFloat8(vr) case *[]bool: *v = decodeBoolArray(vr) case *[]int16: *v = decodeInt2Array(vr) case *[]uint16: *v = decodeInt2ArrayToUInt(vr) case *[]int32: *v = decodeInt4Array(vr) case *[]uint32: *v = decodeInt4ArrayToUInt(vr) case *[]int64: *v = decodeInt8Array(vr) case *[]uint64: *v = decodeInt8ArrayToUInt(vr) case *[]float32: *v = decodeFloat4Array(vr) case *[]float64: *v = decodeFloat8Array(vr) case *[]string: *v = decodeTextArray(vr) case *[]time.Time: *v = 
decodeTimestampArray(vr) case *time.Time: switch vr.Type().DataType { case DateOid: *v = decodeDate(vr) case TimestampTzOid: *v = decodeTimestampTz(vr) case TimestampOid: *v = decodeTimestamp(vr) default: return fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType) } case *net.IPNet: *v = decodeInet(vr) case *[]net.IPNet: *v = decodeInetArray(vr) default: // if d is a pointer to pointer, strip the pointer and try again if v := reflect.ValueOf(d); v.Kind() == reflect.Ptr { if el := v.Elem(); el.Kind() == reflect.Ptr { // -1 is a null value if vr.Len() == -1 { if !el.IsNil() { // if the destination pointer is not nil, nil it out el.Set(reflect.Zero(el.Type())) } return nil } else { if el.IsNil() { // allocate destination el.Set(reflect.New(el.Type().Elem())) } d = el.Interface() return Decode(vr, d) } } } return fmt.Errorf("Scan cannot decode into %T", d) } return nil } func decodeBool(vr *ValueReader) bool { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into bool")) return false } if vr.Type().DataType != BoolOid { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into bool", vr.Type().DataType))) return false } if vr.Type().FormatCode != BinaryFormatCode { vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return false } if vr.Len() != 1 { vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool: %d", vr.Len()))) return false } b := vr.ReadByte() return b != 0 } func encodeBool(w *WriteBuf, oid Oid, value bool) error { if oid != BoolOid { return fmt.Errorf("cannot encode Go %s into oid %d", "bool", oid) } w.WriteInt32(1) var n byte if value { n = 1 } w.WriteByte(n) return nil } func decodeInt8(vr *ValueReader) int64 { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into int64")) return 0 } if vr.Type().DataType != Int8Oid { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int8", vr.Type().DataType))) return 0 } if 
vr.Type().FormatCode != BinaryFormatCode { vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0 } if vr.Len() != 8 { vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8: %d", vr.Len()))) return 0 } return vr.ReadInt64() } func decodeInt2(vr *ValueReader) int16 { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into int16")) return 0 } if vr.Type().DataType != Int2Oid { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int16", vr.Type().DataType))) return 0 } if vr.Type().FormatCode != BinaryFormatCode { vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0 } if vr.Len() != 2 { vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2: %d", vr.Len()))) return 0 } return vr.ReadInt16() } func encodeInt8(w *WriteBuf, oid Oid, value int8) error { switch oid { case Int2Oid: w.WriteInt32(2) w.WriteInt16(int16(value)) case Int4Oid: w.WriteInt32(4) w.WriteInt32(int32(value)) case Int8Oid: w.WriteInt32(8) w.WriteInt64(int64(value)) default: return fmt.Errorf("cannot encode %s into oid %v", "int8", oid) } return nil } func encodeUInt8(w *WriteBuf, oid Oid, value uint8) error { switch oid { case Int2Oid: w.WriteInt32(2) w.WriteInt16(int16(value)) case Int4Oid: w.WriteInt32(4) w.WriteInt32(int32(value)) case Int8Oid: w.WriteInt32(8) w.WriteInt64(int64(value)) default: return fmt.Errorf("cannot encode %s into oid %v", "uint8", oid) } return nil } func encodeInt16(w *WriteBuf, oid Oid, value int16) error { switch oid { case Int2Oid: w.WriteInt32(2) w.WriteInt16(value) case Int4Oid: w.WriteInt32(4) w.WriteInt32(int32(value)) case Int8Oid: w.WriteInt32(8) w.WriteInt64(int64(value)) default: return fmt.Errorf("cannot encode %s into oid %v", "int16", oid) } return nil } func encodeUInt16(w *WriteBuf, oid Oid, value uint16) error { switch oid { case Int2Oid: if value <= math.MaxInt16 { w.WriteInt32(2) 
w.WriteInt16(int16(value)) } else { return fmt.Errorf("%d is larger than max int16 %d", value, math.MaxInt16) } case Int4Oid: w.WriteInt32(4) w.WriteInt32(int32(value)) case Int8Oid: w.WriteInt32(8) w.WriteInt64(int64(value)) default: return fmt.Errorf("cannot encode %s into oid %v", "int16", oid) } return nil } func encodeInt32(w *WriteBuf, oid Oid, value int32) error { switch oid { case Int2Oid: if value <= math.MaxInt16 { w.WriteInt32(2) w.WriteInt16(int16(value)) } else { return fmt.Errorf("%d is larger than max int16 %d", value, math.MaxInt16) } case Int4Oid: w.WriteInt32(4) w.WriteInt32(value) case Int8Oid: w.WriteInt32(8) w.WriteInt64(int64(value)) default: return fmt.Errorf("cannot encode %s into oid %v", "int32", oid) } return nil } func encodeUInt32(w *WriteBuf, oid Oid, value uint32) error { switch oid { case Int2Oid: if value <= math.MaxInt16 { w.WriteInt32(2) w.WriteInt16(int16(value)) } else { return fmt.Errorf("%d is larger than max int16 %d", value, math.MaxInt16) } case Int4Oid: if value <= math.MaxInt32 { w.WriteInt32(4) w.WriteInt32(int32(value)) } else { return fmt.Errorf("%d is larger than max int32 %d", value, math.MaxInt32) } case Int8Oid: w.WriteInt32(8) w.WriteInt64(int64(value)) default: return fmt.Errorf("cannot encode %s into oid %v", "uint32", oid) } return nil } func encodeInt64(w *WriteBuf, oid Oid, value int64) error { switch oid { case Int2Oid: if value <= math.MaxInt16 { w.WriteInt32(2) w.WriteInt16(int16(value)) } else { return fmt.Errorf("%d is larger than max int16 %d", value, math.MaxInt16) } case Int4Oid: if value <= math.MaxInt32 { w.WriteInt32(4) w.WriteInt32(int32(value)) } else { return fmt.Errorf("%d is larger than max int32 %d", value, math.MaxInt32) } case Int8Oid: w.WriteInt32(8) w.WriteInt64(value) default: return fmt.Errorf("cannot encode %s into oid %v", "int64", oid) } return nil } func encodeUInt64(w *WriteBuf, oid Oid, value uint64) error { switch oid { case Int2Oid: if value <= math.MaxInt16 { w.WriteInt32(2) 
w.WriteInt16(int16(value)) } else { return fmt.Errorf("%d is larger than max int16 %d", value, math.MaxInt16) } case Int4Oid: if value <= math.MaxInt32 { w.WriteInt32(4) w.WriteInt32(int32(value)) } else { return fmt.Errorf("%d is larger than max int32 %d", value, math.MaxInt32) } case Int8Oid: if value <= math.MaxInt64 { w.WriteInt32(8) w.WriteInt64(int64(value)) } else { return fmt.Errorf("%d is larger than max int64 %d", value, int64(math.MaxInt64)) } default: return fmt.Errorf("cannot encode %s into oid %v", "uint64", oid) } return nil } func encodeInt(w *WriteBuf, oid Oid, value int) error { switch oid { case Int2Oid: if value <= math.MaxInt16 { w.WriteInt32(2) w.WriteInt16(int16(value)) } else { return fmt.Errorf("%d is larger than max int16 %d", value, math.MaxInt16) } case Int4Oid: if value <= math.MaxInt32 { w.WriteInt32(4) w.WriteInt32(int32(value)) } else { return fmt.Errorf("%d is larger than max int32 %d", value, math.MaxInt32) } case Int8Oid: if int64(value) <= int64(math.MaxInt64) { w.WriteInt32(8) w.WriteInt64(int64(value)) } else { return fmt.Errorf("%d is larger than max int64 %d", value, int64(math.MaxInt64)) } default: return fmt.Errorf("cannot encode %s into oid %v", "uint64", oid) } return nil } func decodeInt4(vr *ValueReader) int32 { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into int32")) return 0 } if vr.Type().DataType != Int4Oid { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int32", vr.Type().DataType))) return 0 } if vr.Type().FormatCode != BinaryFormatCode { vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0 } if vr.Len() != 4 { vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4: %d", vr.Len()))) return 0 } return vr.ReadInt32() } func decodeOid(vr *ValueReader) Oid { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into Oid")) return 0 } if vr.Type().DataType != OidOid { 
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Oid", vr.Type().DataType))) return 0 } // Oid needs to decode text format because it is used in loadPgTypes switch vr.Type().FormatCode { case TextFormatCode: s := vr.ReadString(vr.Len()) n, err := strconv.ParseInt(s, 10, 32) if err != nil { vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid Oid: %v", s))) } return Oid(n) case BinaryFormatCode: if vr.Len() != 4 { vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an Oid: %d", vr.Len()))) return 0 } return Oid(vr.ReadInt32()) default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return Oid(0) } } func encodeOid(w *WriteBuf, oid Oid, value Oid) error { if oid != OidOid { return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Oid", oid) } w.WriteInt32(4) w.WriteInt32(int32(value)) return nil } func decodeFloat4(vr *ValueReader) float32 { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into float32")) return 0 } if vr.Type().DataType != Float4Oid { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into float32", vr.Type().DataType))) return 0 } if vr.Type().FormatCode != BinaryFormatCode { vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0 } if vr.Len() != 4 { vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", vr.Len()))) return 0 } i := vr.ReadInt32() return math.Float32frombits(uint32(i)) } func encodeFloat32(w *WriteBuf, oid Oid, value float32) error { switch oid { case Float4Oid: w.WriteInt32(4) w.WriteInt32(int32(math.Float32bits(value))) case Float8Oid: w.WriteInt32(8) w.WriteInt64(int64(math.Float64bits(float64(value)))) default: return fmt.Errorf("cannot encode %s into oid %v", "float32", oid) } return nil } func decodeFloat8(vr *ValueReader) float64 { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into float64")) return 0 } if 
vr.Type().DataType != Float8Oid {
		vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into float64", vr.Type().DataType)))
		return 0
	}

	if vr.Type().FormatCode != BinaryFormatCode {
		vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
		return 0
	}

	if vr.Len() != 8 {
		vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float8: %d", vr.Len())))
		return 0
	}

	// The 8-byte payload is the IEEE 754 bit pattern of the float.
	i := vr.ReadInt64()
	return math.Float64frombits(uint64(i))
}

// encodeFloat64 writes value as an 8-byte float8 preceded by its int32 size.
// It returns an error for any oid other than Float8Oid.
func encodeFloat64(w *WriteBuf, oid Oid, value float64) error {
	switch oid {
	case Float8Oid:
		w.WriteInt32(8)
		w.WriteInt64(int64(math.Float64bits(value)))
	default:
		return fmt.Errorf("cannot encode %s into oid %v", "float64", oid)
	}

	return nil
}

// decodeText reads the entire value as a string. A NULL value (length -1) is
// reported through vr.Fatal and "" is returned.
func decodeText(vr *ValueReader) string {
	if vr.Len() == -1 {
		vr.Fatal(ProtocolError("Cannot decode null into string"))
		return ""
	}

	return vr.ReadString(vr.Len())
}

// encodeString writes value as an int32 byte length followed by its bytes.
// The oid is not inspected.
func encodeString(w *WriteBuf, oid Oid, value string) error {
	w.WriteInt32(int32(len(value)))
	w.WriteBytes([]byte(value))
	return nil
}

// decodeBytea reads a binary-format bytea value. It returns nil for NULL and
// reports an unexpected oid or format code through vr.Fatal.
func decodeBytea(vr *ValueReader) []byte {
	if vr.Len() == -1 {
		return nil
	}

	if vr.Type().DataType != ByteaOid {
		vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []byte", vr.Type().DataType)))
		return nil
	}

	if vr.Type().FormatCode != BinaryFormatCode {
		vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
		return nil
	}

	return vr.ReadBytes(vr.Len())
}

// encodeByteSlice writes value as an int32 byte length followed by the raw
// bytes. The oid is not inspected.
func encodeByteSlice(w *WriteBuf, oid Oid, value []byte) error {
	w.WriteInt32(int32(len(value)))
	w.WriteBytes(value)
	return nil
}

// decodeJson unmarshals a json or jsonb value into d using encoding/json.
// A NULL value leaves d untouched and returns nil. An unexpected oid or an
// unmarshal failure is reported through vr.Fatal and also returned.
func decodeJson(vr *ValueReader, d interface{}) error {
	if vr.Len() == -1 {
		return nil
	}

	if vr.Type().DataType != JsonOid && vr.Type().DataType != JsonbOid {
		err := ProtocolError(fmt.Sprintf("Cannot decode oid %v into json", vr.Type().DataType))
		vr.Fatal(err)
		// Stop here: the payload is not JSON, so there is nothing sensible to
		// unmarshal.
		return err
	}

	buf := vr.ReadBytes(vr.Len())
	err := json.Unmarshal(buf, d)
	if err != nil {
		vr.Fatal(err)
	}
	return err
}

// encodeJson marshals value with encoding/json and writes it as an int32
// length followed by the JSON bytes. It returns an error if oid is neither
// JsonOid nor JsonbOid, or if marshalling fails.
func encodeJson(w *WriteBuf, oid Oid, value interface{}) error {
	if oid != JsonOid && oid != JsonbOid {
		return fmt.Errorf("cannot encode JSON into oid %v", oid)
	}

	s, err := json.Marshal(value)
	if err != nil {
		// Keep the underlying cause so callers can see why marshalling failed.
		return fmt.Errorf("Failed to encode json from type: %T: %v", value, err)
	}

	w.WriteInt32(int32(len(s)))
	w.WriteBytes(s)

	return nil
}

// decodeDate reads a binary-format date. The 4-byte payload is treated as a
// day offset from 2000-01-01 and the result is midnight of that day in
// time.Local. NULL, an unexpected oid or format code, and an invalid size are
// reported through vr.Fatal and yield the zero time.
func decodeDate(vr *ValueReader) time.Time {
	var zeroTime time.Time

	if vr.Len() == -1 {
		vr.Fatal(ProtocolError("Cannot decode null into time.Time"))
		return zeroTime
	}

	if vr.Type().DataType != DateOid {
		vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType)))
		return zeroTime
	}

	if vr.Type().FormatCode != BinaryFormatCode {
		vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
		return zeroTime
	}

	if vr.Len() != 4 {
		vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an date: %d", vr.Len())))
		// Do not attempt to read a 4-byte offset from a value of another size.
		return zeroTime
	}
	dayOffset := vr.ReadInt32()
	return time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.Local)
}

// encodeTime writes value as a date (4-byte day count since 2000-01-01) or as
// a timestamp/timestamptz (8-byte microsecond count since 2000-01-01 UTC),
// depending on oid. Any other oid yields an error.
func encodeTime(w *WriteBuf, oid Oid, value time.Time) error {
	switch oid {
	case DateOid:
		// Rebuild the calendar date at midnight UTC so that the difference to
		// the epoch is a whole number of days.
		tUnix := time.Date(value.Year(), value.Month(), value.Day(), 0, 0, 0, 0, time.UTC).Unix()
		dateEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).Unix()

		secSinceDateEpoch := tUnix - dateEpoch
		daysSinceDateEpoch := secSinceDateEpoch / 86400

		w.WriteInt32(4)
		w.WriteInt32(int32(daysSinceDateEpoch))

		return nil
	case TimestampTzOid, TimestampOid:
		microsecSinceUnixEpoch := value.Unix()*1000000 + int64(value.Nanosecond())/1000
		microsecSinceY2K := microsecSinceUnixEpoch - microsecFromUnixEpochToY2K

		w.WriteInt32(8)
		w.WriteInt64(microsecSinceY2K)

		return nil
	default:
		return fmt.Errorf("cannot encode %s into oid %v", "time.Time", oid)
	}
}

// microsecFromUnixEpochToY2K is the number of microseconds between the Unix
// epoch and 2000-01-01 00:00:00 UTC.
const microsecFromUnixEpochToY2K = 946684800 * 1000000

// decodeTimestampTz reads a binary-format timestamptz: an 8-byte microsecond
// count relative to 2000-01-01 UTC. NULL, an unexpected oid or format code,
// and an invalid size are reported through vr.Fatal and yield the zero time.
func decodeTimestampTz(vr *ValueReader) time.Time {
	var zeroTime time.Time

	if vr.Len() == -1 {
		vr.Fatal(ProtocolError("Cannot decode null into time.Time"))
		return zeroTime
	}

	if vr.Type().DataType != TimestampTzOid {
		vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType)))
		return zeroTime
	}

	if vr.Type().FormatCode != BinaryFormatCode {
		vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
		return zeroTime
	}

	if vr.Len() != 8 {
		vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamptz: %d", vr.Len())))
		return zeroTime
	}

	microsecSinceY2K := vr.ReadInt64()
	microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K
	return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000)
}

// decodeTimestamp reads a binary-format timestamp: an 8-byte microsecond
// count relative to 2000-01-01. NULL, an unexpected oid or format code, and an
// invalid size are reported through vr.Fatal and yield the zero time.
func decodeTimestamp(vr *ValueReader) time.Time {
	var zeroTime time.Time

	if vr.Len() == -1 {
		vr.Fatal(ProtocolError("Cannot decode null into timestamp"))
		return zeroTime
	}

	if vr.Type().DataType != TimestampOid {
		vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType)))
		return zeroTime
	}

	if vr.Type().FormatCode != BinaryFormatCode {
		vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
		return zeroTime
	}

	if vr.Len() != 8 {
		vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamp: %d", vr.Len())))
		return zeroTime
	}

	microsecSinceY2K := vr.ReadInt64()
	microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K
	return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000)
}

// decodeInet reads a binary-format inet or cidr value into a net.IPNet. The
// payload is four header bytes (family, mask bits, is_cidr, address length)
// followed by the address, so only sizes 8 and 20 are accepted. Problems are
// reported through vr.Fatal and yield the zero net.IPNet.
func decodeInet(vr *ValueReader) net.IPNet {
	var zero net.IPNet

	if vr.Len() == -1 {
		vr.Fatal(ProtocolError("Cannot decode null into net.IPNet"))
		return zero
	}

	if vr.Type().FormatCode != BinaryFormatCode {
		vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
		return zero
	}

	pgType := vr.Type()
	if vr.Len() != 8 && vr.Len() != 20 {
		vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for a %s: %d", pgType.Name, vr.Len())))
		return zero
	}

	if pgType.DataType != InetOid && pgType.DataType != CidrOid {
		vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into %s", pgType.DataType, pgType.Name)))
		return zero
	}

	vr.ReadByte() // ignore family
	bits := vr.ReadByte()
	vr.ReadByte() // ignore is_cidr
	addressLength := vr.ReadByte()

	var ipnet net.IPNet
	ipnet.IP = vr.ReadBytes(int32(addressLength))
	ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8)

	return ipnet
}

// encodeIPNet writes value as a binary inet/cidr: int32 size, address family,
// number of mask bits, an is_cidr byte, the address length and the address
// bytes. It returns an error if oid is neither InetOid nor CidrOid or if
// value.IP is not exactly net.IPv4len or net.IPv6len bytes long.
func encodeIPNet(w *WriteBuf, oid Oid, value net.IPNet) error {
	if oid != InetOid && oid != CidrOid {
		return fmt.Errorf("cannot encode %s into oid %v", "net.IPNet", oid)
	}

	var size int32
	var family byte
	switch len(value.IP) {
	case net.IPv4len:
		size = 8
		family = w.conn.pgsql_af_inet
	case net.IPv6len:
		size = 20
		family = w.conn.pgsql_af_inet6
	default:
		return fmt.Errorf("Unexpected IP length: %v", len(value.IP))
	}

	w.WriteInt32(size)
	w.WriteByte(family)
	ones, _ := value.Mask.Size()
	w.WriteByte(byte(ones))
	w.WriteByte(0) // is_cidr is ignored on server
	w.WriteByte(byte(len(value.IP)))
	w.WriteBytes(value.IP)

	return nil
}

// encodeIP writes value as an inet/cidr with a full-length mask (every bit of
// the address set) by delegating to encodeIPNet.
func encodeIP(w *WriteBuf, oid Oid, value net.IP) error {
	if oid != InetOid && oid != CidrOid {
		return fmt.Errorf("cannot encode %s into oid %v", "net.IP", oid)
	}

	var ipnet net.IPNet
	ipnet.IP = value
	bitCount := len(value) * 8
	ipnet.Mask = net.CIDRMask(bitCount, bitCount)
	return encodeIPNet(w, oid, ipnet)
}

// decode1dArrayHeader reads a binary array header and returns the number of
// elements. Arrays with zero dimensions yield length 0. More than one
// dimension, or a first-element index other than 1, yields a ProtocolError.
func decode1dArrayHeader(vr *ValueReader) (length int32, err error) {
	numDims := vr.ReadInt32()
	if numDims > 1 {
		return 0, ProtocolError(fmt.Sprintf("Expected array to have 0 or 1 dimension, but it had %v", numDims))
	}

	vr.ReadInt32() // 0 if no nulls / 1 if there is one or more nulls -- but we don't care
	vr.ReadInt32() // element oid

	if numDims == 0 {
		return 0, nil
	}

	length = vr.ReadInt32()

	idxFirstElem := vr.ReadInt32()
	if idxFirstElem != 1 {
		return 0, ProtocolError(fmt.Sprintf("Expected array's first element to start a index 1, but it is %d", idxFirstElem))
	}

	return length, nil
}

// decodeBoolArray reads a binary-format bool[] value. It returns nil for NULL.
// An unexpected oid or format code, a bad header, a NULL element or an element
// whose size is not 1 is reported through vr.Fatal and yields nil.
func decodeBoolArray(vr *ValueReader) []bool {
	if vr.Len() == -1 {
		return nil
	}

	if vr.Type().DataType != BoolArrayOid {
		vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []bool", vr.Type().DataType)))
		return nil
	}

	if vr.Type().FormatCode != BinaryFormatCode {
		vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
		return nil
	}

	numElems, err := decode1dArrayHeader(vr)
	if err != nil {
		vr.Fatal(err)
		return nil
	}

	a := make([]bool, int(numElems))
	for i := 0; i < len(a); i++ {
		elSize := vr.ReadInt32()
		switch elSize {
		case 1:
			if vr.ReadByte() == 1 {
				a[i] = true
			}
		case -1:
			vr.Fatal(ProtocolError("Cannot decode null element"))
			return nil
		default:
			vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool element: %d", elSize)))
			return nil
		}
	}

	return a
}

// encodeBoolSlice writes slice as a one-dimensional bool[] value. It returns
// an error if oid is not BoolArrayOid.
func encodeBoolSlice(w *WriteBuf, oid Oid, slice []bool) error {
	if oid != BoolArrayOid {
		return fmt.Errorf("cannot encode Go %s into oid %d", "[]bool", oid)
	}

	// Each element is a 4-byte size prefix plus a 1-byte payload.
	encodeArrayHeader(w, BoolOid, len(slice), 5)
	for _, v := range slice {
		w.WriteInt32(1)
		var b byte
		if v {
			b = 1
		}
		w.WriteByte(b)
	}

	return nil
}

// decodeInt2Array reads a binary-format int2[] value. It returns nil for NULL.
// An unexpected oid or format code, a bad header, a NULL element or an element
// whose size is not 2 is reported through vr.Fatal and yields nil.
func decodeInt2Array(vr *ValueReader) []int16 {
	if vr.Len() == -1 {
		return nil
	}

	if vr.Type().DataType != Int2ArrayOid {
		vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int16", vr.Type().DataType)))
		return nil
	}

	if vr.Type().FormatCode != BinaryFormatCode {
		vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
		return nil
	}

	numElems, err := decode1dArrayHeader(vr)
	if err != nil {
		vr.Fatal(err)
		return nil
	}

	a := make([]int16, int(numElems))
	for i := 0; i < len(a); i++ {
		elSize := vr.ReadInt32()
		switch elSize {
		case 2:
			a[i] = vr.ReadInt16()
		case -1:
			vr.Fatal(ProtocolError("Cannot decode null element"))
			return nil
		default:
			vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2 element: %d", elSize)))
			return nil
		}
	}

	return a
}

// decodeInt2ArrayToUInt reads a binary-format int2[] value into []uint16.
// In addition to the checks performed for []int16, a negative element is
// reported through vr.Fatal and yields nil.
func decodeInt2ArrayToUInt(vr *ValueReader) []uint16 {
	if vr.Len() == -1 {
		return nil
	}

	if vr.Type().DataType != Int2ArrayOid {
		vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []uint16", vr.Type().DataType)))
		return nil
	}

	if vr.Type().FormatCode != BinaryFormatCode {
		vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
		return nil
	}

	numElems, err := decode1dArrayHeader(vr)
	if err != nil {
		vr.Fatal(err)
		return nil
	}

	a := make([]uint16, int(numElems))
	for i := 0; i < len(a); i++ {
		elSize := vr.ReadInt32()
		switch elSize {
		case 2:
			tmp := vr.ReadInt16()
			if tmp < 0 {
				vr.Fatal(ProtocolError(fmt.Sprintf("%d is less than zero for uint16", tmp)))
				return nil
			}
			a[i] = uint16(tmp)
		case -1:
			vr.Fatal(ProtocolError("Cannot decode null element"))
			return nil
		default:
			vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2 element: %d", elSize)))
			return nil
		}
	}

	return a
}

// encodeInt16Slice writes slice as a one-dimensional int2[] value. It returns
// an error if oid is not Int2ArrayOid.
func encodeInt16Slice(w *WriteBuf, oid Oid, slice []int16) error {
	if oid != Int2ArrayOid {
		return fmt.Errorf("cannot encode Go %s into oid %d", "[]int16", oid)
	}

	encodeArrayHeader(w, Int2Oid, len(slice), 6)
	for _, v := range slice {
		w.WriteInt32(2)
		w.WriteInt16(v)
	}

	return nil
}

// encodeUInt16Slice writes slice as a one-dimensional int2[] value. It returns
// an error if oid is not Int2ArrayOid or if an element exceeds math.MaxInt16.
// The array header and any earlier elements have already been written to w
// when the range error is returned.
func encodeUInt16Slice(w *WriteBuf, oid Oid, slice []uint16) error {
	if oid != Int2ArrayOid {
		return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint16", oid)
	}

	encodeArrayHeader(w, Int2Oid, len(slice), 6)
	for _, v := range slice {
		if v <= math.MaxInt16 {
			w.WriteInt32(2)
			w.WriteInt16(int16(v))
		} else {
			return fmt.Errorf("%d is larger than max smallint %d", v, math.MaxInt16)
		}
	}

	return nil
}

// decodeInt4Array reads a binary-format int4[] value. It returns nil for NULL.
// An unexpected oid or format code, a bad header, a NULL element or an element
// whose size is not 4 is reported through vr.Fatal and yields nil.
func decodeInt4Array(vr *ValueReader) []int32 {
	if vr.Len() == -1 {
		return nil
	}

	if vr.Type().DataType != Int4ArrayOid {
		vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int32", vr.Type().DataType)))
		return nil
	}

	if vr.Type().FormatCode != BinaryFormatCode {
		vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
		return nil
	}

	numElems, err := decode1dArrayHeader(vr)
	if err != nil {
		vr.Fatal(err)
		return nil
	}

	a := make([]int32, int(numElems))
	for i := 0; i < len(a); i++ {
		elSize := vr.ReadInt32()
		switch elSize {
		case 4:
			a[i] = vr.ReadInt32()
		case -1:
			vr.Fatal(ProtocolError("Cannot decode null element"))
			return nil
		default:
			vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4 element: %d", elSize)))
			return nil
		}
	}

	return a
}

// decodeInt4ArrayToUInt reads a binary-format int4[] value into []uint32.
// In addition to the checks performed for []int32, a negative element is
// reported through vr.Fatal and yields nil.
func decodeInt4ArrayToUInt(vr *ValueReader) []uint32 {
	if vr.Len() == -1 {
		return nil
	}

	if vr.Type().DataType != Int4ArrayOid {
		vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []uint32", vr.Type().DataType)))
		return nil
	}

	if vr.Type().FormatCode != BinaryFormatCode {
		vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
		return nil
	}

	numElems, err := decode1dArrayHeader(vr)
	if err != nil {
		vr.Fatal(err)
		return nil
	}

	a := make([]uint32, int(numElems))
	for i := 0; i < len(a); i++ {
		elSize := vr.ReadInt32()
		switch elSize {
		case 4:
			tmp := vr.ReadInt32()
			if tmp < 0 {
				vr.Fatal(ProtocolError(fmt.Sprintf("%d is less than zero for uint32", tmp)))
				return nil
			}
			a[i] = uint32(tmp)
		case -1:
			vr.Fatal(ProtocolError("Cannot decode null element"))
			return nil
		default:
			vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4 element: %d", elSize)))
			return nil
		}
	}

	return a
}

// encodeInt32Slice writes slice as a one-dimensional int4[] value. It returns
// an error if oid is not Int4ArrayOid.
func encodeInt32Slice(w *WriteBuf, oid Oid, slice []int32) error {
	if oid != Int4ArrayOid {
		return fmt.Errorf("cannot encode Go %s into oid %d", "[]int32", oid)
	}

	encodeArrayHeader(w, Int4Oid, len(slice), 8)
	for _, v := range slice {
		w.WriteInt32(4)
		w.WriteInt32(v)
	}

	return nil
}

// encodeUInt32Slice writes slice as a one-dimensional int4[] value. It returns
// an error if oid is not Int4ArrayOid or if an element exceeds math.MaxInt32.
// The array header and any earlier elements have already been written to w
// when the range error is returned.
func encodeUInt32Slice(w *WriteBuf, oid Oid, slice []uint32) error {
	if oid != Int4ArrayOid {
		return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint32", oid)
	}

	encodeArrayHeader(w, Int4Oid, len(slice), 8)
	for _, v := range slice {
		if v <= math.MaxInt32 {
			w.WriteInt32(4)
			w.WriteInt32(int32(v))
		} else {
			return fmt.Errorf("%d is larger than max integer %d", v, math.MaxInt32)
		}
	}

	return nil
}

// decodeInt8Array reads a binary-format int8[] value. It returns nil for NULL.
// An unexpected oid or format code, a bad header, a NULL element or an element
// whose size is not 8 is reported through vr.Fatal and yields nil.
func decodeInt8Array(vr *ValueReader) []int64 {
	if vr.Len() == -1 {
		return nil
	}

	if vr.Type().DataType != Int8ArrayOid {
		vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int64", vr.Type().DataType)))
		return nil
	}

	if vr.Type().FormatCode != BinaryFormatCode {
		vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
		return nil
	}

	numElems, err := decode1dArrayHeader(vr)
	if err != nil {
		vr.Fatal(err)
		return nil
	}

	a := make([]int64, int(numElems))
	for i := 0; i < len(a); i++ {
		elSize := vr.ReadInt32()
		switch elSize {
		case 8:
			a[i] = vr.ReadInt64()
		case -1:
			vr.Fatal(ProtocolError("Cannot decode null element"))
			return nil
		default:
			vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8 element: %d", elSize)))
			return nil
		}
	}

	return a
}

// decodeInt8ArrayToUInt reads a binary-format int8[] value into []uint64.
// In addition to the checks performed for []int64, a negative element is
// reported through vr.Fatal and yields nil.
func decodeInt8ArrayToUInt(vr *ValueReader) []uint64 {
	if vr.Len() == -1 {
		return nil
	}

	if vr.Type().DataType != Int8ArrayOid {
		vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []uint64", vr.Type().DataType)))
		return nil
	}

	if vr.Type().FormatCode != BinaryFormatCode {
		vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
		return nil
	}

	numElems, err := decode1dArrayHeader(vr)
	if err != nil {
		vr.Fatal(err)
		return nil
	}

	a := make([]uint64, int(numElems))
	for i := 0; i < len(a); i++ {
		elSize := vr.ReadInt32()
		switch elSize {
		case 8:
			tmp := vr.ReadInt64()
			if tmp < 0 {
				vr.Fatal(ProtocolError(fmt.Sprintf("%d is less than zero for uint64", tmp)))
				return nil
			}
			a[i] = uint64(tmp)
		case -1:
			vr.Fatal(ProtocolError("Cannot decode null element"))
			return nil
		default:
			vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8 element: %d", elSize)))
			return nil
		}
	}

	return a
}

// encodeInt64Slice writes slice as a one-dimensional int8[] value. It returns
// an error if oid is not Int8ArrayOid.
func encodeInt64Slice(w *WriteBuf, oid Oid, slice []int64) error {
	if oid != Int8ArrayOid {
		return fmt.Errorf("cannot encode Go %s into oid %d", "[]int64", oid)
	}

	encodeArrayHeader(w, Int8Oid, len(slice), 12)
	for _, v := range slice {
		w.WriteInt32(8)
		w.WriteInt64(v)
	}

	return nil
}

// encodeUInt64Slice writes slice as a one-dimensional int8[] value. It returns
// an error if oid is not Int8ArrayOid or if an element exceeds math.MaxInt64.
// The array header and any earlier elements have already been written to w
// when the range error is returned.
func encodeUInt64Slice(w *WriteBuf, oid Oid, slice []uint64) error {
	if oid != Int8ArrayOid {
		return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint64", oid)
	}

	encodeArrayHeader(w, Int8Oid, len(slice), 12)
	for _, v := range slice {
		if v <= math.MaxInt64 {
			w.WriteInt32(8)
			w.WriteInt64(int64(v))
		} else {
			return fmt.Errorf("%d is larger than max bigint %d", v, int64(math.MaxInt64))
		}
	}

	return nil
}

// decodeFloat4Array reads a binary-format float4[] value. It returns nil for
// NULL. An unexpected oid or format code, a bad header, a NULL element or an
// element whose size is not 4 is reported through vr.Fatal and yields nil.
func decodeFloat4Array(vr *ValueReader) []float32 {
	if vr.Len() == -1 {
		return nil
	}

	if vr.Type().DataType != Float4ArrayOid {
		vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []float32", vr.Type().DataType)))
		return nil
	}

	if vr.Type().FormatCode != BinaryFormatCode {
		vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
		return nil
	}

	numElems, err := decode1dArrayHeader(vr)
	if err != nil {
		vr.Fatal(err)
		return nil
	}

	a := make([]float32, int(numElems))
	for i := 0; i < len(a); i++ {
		elSize := vr.ReadInt32()
		switch elSize {
		case 4:
			n := vr.ReadInt32()
			a[i] = math.Float32frombits(uint32(n))
		case -1:
			vr.Fatal(ProtocolError("Cannot decode null element"))
			return nil
		default:
			vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4 element: %d", elSize)))
			return nil
		}
	}

	return a
}

// encodeFloat32Slice writes slice as a one-dimensional float4[] value. It
// returns an error if oid is not Float4ArrayOid.
func encodeFloat32Slice(w *WriteBuf, oid Oid, slice []float32) error {
	if oid != Float4ArrayOid {
		return fmt.Errorf("cannot encode Go %s into oid %d", "[]float32", oid)
	}

	encodeArrayHeader(w, Float4Oid, len(slice), 8)
	for _, v := range slice {
		w.WriteInt32(4)
		w.WriteInt32(int32(math.Float32bits(v)))
	}

	return nil
}

// decodeFloat8Array reads a binary-format float8[] value. It returns nil for
// NULL. An unexpected oid or format code, a bad header, a NULL element or an
// element whose size is not 8 is reported through vr.Fatal and yields nil.
func decodeFloat8Array(vr *ValueReader) []float64 {
	if vr.Len() == -1 {
		return nil
	}

	if vr.Type().DataType != Float8ArrayOid {
		vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []float64", vr.Type().DataType)))
		return nil
	}

	if vr.Type().FormatCode != BinaryFormatCode {
		vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
		return nil
	}

	numElems, err := decode1dArrayHeader(vr)
	if err != nil {
		vr.Fatal(err)
		return nil
	}

	a := make([]float64, int(numElems))
	for i := 0; i < len(a); i++ {
		elSize := vr.ReadInt32()
		switch elSize {
		case 8:
			n := vr.ReadInt64()
			a[i] = math.Float64frombits(uint64(n))
		case -1:
			vr.Fatal(ProtocolError("Cannot decode null element"))
			return nil
		default:
			vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float8 element: %d", elSize)))
			return nil
		}
	}

	return a
}

// encodeFloat64Slice writes slice as a one-dimensional float8[] value. It
// returns an error if oid is not Float8ArrayOid.
func encodeFloat64Slice(w *WriteBuf, oid Oid, slice []float64) error {
	if oid != Float8ArrayOid {
		return fmt.Errorf("cannot encode Go %s into oid %d", "[]float64", oid)
	}

	encodeArrayHeader(w, Float8Oid, len(slice), 12)
	for _, v := range slice {
		w.WriteInt32(8)
		w.WriteInt64(int64(math.Float64bits(v)))
	}

	return nil
}

// decodeTextArray reads a binary-format text[] or varchar[] value. It returns
// nil for NULL. An unexpected oid or format code, a bad header or a NULL
// element is reported through vr.Fatal and yields nil.
func decodeTextArray(vr *ValueReader) []string {
	if vr.Len() == -1 {
		return nil
	}

	if vr.Type().DataType != TextArrayOid && vr.Type().DataType != VarcharArrayOid {
		vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []string", vr.Type().DataType)))
		return nil
	}

	if vr.Type().FormatCode != BinaryFormatCode {
		vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
		return nil
	}

	numElems, err := decode1dArrayHeader(vr)
	if err != nil {
		vr.Fatal(err)
		return nil
	}

	a := make([]string, int(numElems))
	for i := 0; i < len(a); i++ {
		elSize := vr.ReadInt32()
		if elSize == -1 {
			vr.Fatal(ProtocolError("Cannot decode null element"))
			return nil
		}

		a[i] = vr.ReadString(elSize)
	}

	return a
}

// encodeStringSlice writes slice as a one-dimensional text[] or varchar[]
// value, choosing the element oid from the array oid. Any other oid yields an
// error.
func encodeStringSlice(w *WriteBuf, oid Oid, slice []string) error {
	var elOid Oid
	switch oid {
	case VarcharArrayOid:
		elOid = VarcharOid
	case TextArrayOid:
		elOid = TextOid
	default:
		return fmt.Errorf("cannot encode Go %s into oid %d", "[]string", oid)
	}

	var totalStringSize int
	for _, v := range slice {
		totalStringSize += len(v)
	}

	// 20 header bytes, a 4-byte size prefix per element, plus the string data.
	size := 20 + len(slice)*4 + totalStringSize
	w.WriteInt32(int32(size))

	w.WriteInt32(1)                 // number of dimensions
	w.WriteInt32(0)                 // no nulls
	w.WriteInt32(int32(elOid))      // type of elements
	w.WriteInt32(int32(len(slice))) // number of elements
	w.WriteInt32(1)                 // index of first element

	for _, v := range slice {
		w.WriteInt32(int32(len(v)))
		w.WriteBytes([]byte(v))
	}

	return nil
}

// decodeTimestampArray reads a binary-format timestamp[] or timestamptz[]
// value. Each element is an 8-byte microsecond count relative to 2000-01-01.
// It returns nil for NULL. An unexpected oid or format code, a bad header, a
// NULL element or a wrongly sized element is reported through vr.Fatal and
// yields nil.
func decodeTimestampArray(vr *ValueReader) []time.Time {
	if vr.Len() == -1 {
		return nil
	}

	if vr.Type().DataType != TimestampArrayOid && vr.Type().DataType != TimestampTzArrayOid {
		vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []time.Time", vr.Type().DataType)))
		return nil
	}

	if vr.Type().FormatCode != BinaryFormatCode {
		vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
		return nil
	}

	numElems, err := decode1dArrayHeader(vr)
	if err != nil {
		vr.Fatal(err)
		return nil
	}

	a := make([]time.Time, int(numElems))
	for i := 0; i < len(a); i++ {
		elSize := vr.ReadInt32()
		switch elSize {
		case 8:
			microsecSinceY2K := vr.ReadInt64()
			microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K
			a[i] = time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000)
		case -1:
			vr.Fatal(ProtocolError("Cannot decode null element"))
			return nil
		default:
			vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an time.Time element: %d", elSize)))
			return nil
		}
	}

	return a
}

// encodeTimeSlice writes slice as a one-dimensional timestamp[] or
// timestamptz[] value, choosing the element oid from the array oid. Any other
// oid yields an error.
func encodeTimeSlice(w *WriteBuf, oid Oid, slice []time.Time) error {
	var elOid Oid
	switch oid {
	case TimestampArrayOid:
		elOid = TimestampOid
	case TimestampTzArrayOid:
		elOid = TimestampTzOid
	default:
		return fmt.Errorf("cannot encode Go %s into oid %d", "[]time.Time", oid)
	}

	encodeArrayHeader(w, int(elOid), len(slice), 12)
	for _, t := range slice {
		w.WriteInt32(8)
		microsecSinceUnixEpoch := t.Unix()*1000000 + int64(t.Nanosecond())/1000
		microsecSinceY2K := microsecSinceUnixEpoch - microsecFromUnixEpochToY2K
		w.WriteInt64(microsecSinceY2K)
	}

	return nil
}

// decodeInetArray reads a binary-format inet[] or cidr[] value. It returns nil
// for NULL. An unexpected oid or format code, a bad header or a NULL element
// is reported through vr.Fatal and yields nil. Element sizes are not
// validated beyond the NULL check.
func decodeInetArray(vr *ValueReader) []net.IPNet {
	if vr.Len() == -1 {
		return nil
	}

	if vr.Type().DataType != InetArrayOid && vr.Type().DataType != CidrArrayOid {
		vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []net.IP", vr.Type().DataType)))
		return nil
	}

	if vr.Type().FormatCode != BinaryFormatCode {
		vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
		return nil
	}

	numElems, err := decode1dArrayHeader(vr)
	if err != nil {
		vr.Fatal(err)
		return nil
	}

	a := make([]net.IPNet, int(numElems))
	for i := 0; i < len(a); i++ {
		elSize := vr.ReadInt32()
		if elSize == -1 {
			vr.Fatal(ProtocolError("Cannot decode null element"))
			return nil
		}

		vr.ReadByte() // ignore family
		bits := vr.ReadByte()
		vr.ReadByte() // ignore is_cidr
		addressLength := vr.ReadByte()

		var ipnet net.IPNet
		ipnet.IP = vr.ReadBytes(int32(addressLength))
		ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8)

		a[i] = ipnet
	}

	return a
}

// encodeIPNetSlice writes slice as a one-dimensional inet[] or cidr[] value,
// choosing the element oid from the array oid. It returns an error for any
// other oid, or if an element cannot be encoded by encodeIPNet; in the latter
// case the header and earlier elements have already been written to w.
func encodeIPNetSlice(w *WriteBuf, oid Oid, slice []net.IPNet) error {
	var elOid Oid
	switch oid {
	case InetArrayOid:
		elOid = InetOid
	case CidrArrayOid:
		elOid = CidrOid
	default:
		return fmt.Errorf("cannot encode Go %s into oid %d", "[]net.IPNet", oid)
	}

	size := int32(20) // array header size
	for _, ipnet := range slice {
		size += 4 + 4 + int32(len(ipnet.IP)) // size of element + inet/cidr metadata + IP bytes
	}
	w.WriteInt32(size)

	w.WriteInt32(1)                 // number of dimensions
	w.WriteInt32(0)                 // no nulls
	w.WriteInt32(int32(elOid))      // type of elements
	w.WriteInt32(int32(len(slice))) // number of elements
	w.WriteInt32(1)                 // index of first element

	for _, ipnet := range slice {
		if err := encodeIPNet(w, elOid, ipnet); err != nil {
			return err
		}
	}

	return nil
}

// encodeIPSlice writes slice as a one-dimensional inet[] or cidr[] value,
// choosing the element oid from the array oid. It returns an error for any
// other oid, or if an element cannot be encoded by encodeIP; in the latter
// case the header and earlier elements have already been written to w.
func encodeIPSlice(w *WriteBuf, oid Oid, slice []net.IP) error {
	var elOid Oid
	switch oid {
	case InetArrayOid:
		elOid = InetOid
	case CidrArrayOid:
		elOid = CidrOid
	default:
		return fmt.Errorf("cannot encode Go %s into oid %d", "[]net.IP", oid)
	}

	size := int32(20) // array header size
	for _, ip := range slice {
		size += 4 + 4 + int32(len(ip)) // size of element + inet/cidr metadata + IP bytes
	}
	w.WriteInt32(size)

	w.WriteInt32(1)                 // number of dimensions
	w.WriteInt32(0)                 // no nulls
	w.WriteInt32(int32(elOid))      // type of elements
	w.WriteInt32(int32(len(slice))) // number of elements
	w.WriteInt32(1)                 // index of first element

	for _, ip := range slice {
		if err := encodeIP(w, elOid, ip); err != nil {
			return err
		}
	}

	return nil
}

// encodeArrayHeader writes the size prefix and header of a one-dimensional
// array with no NULL elements. oid is the element type, length the number of
// elements and sizePerItem the bytes each element occupies including its
// 4-byte size prefix (e.g. callers pass 5 for a 1-byte bool). The five int32
// header fields account for the constant 20.
func encodeArrayHeader(w *WriteBuf, oid, length, sizePerItem int) {
	w.WriteInt32(int32(20 + length*sizePerItem))
	w.WriteInt32(1)             // number of dimensions
	w.WriteInt32(0)             // no nulls
	w.WriteInt32(int32(oid))    // type of elements
	w.WriteInt32(int32(length)) // number of elements
	w.WriteInt32(1)             // index of first element
}