diff --git a/pgtype/int2.go b/pgtype/int2.go new file mode 100644 index 00000000..cf096a62 --- /dev/null +++ b/pgtype/int2.go @@ -0,0 +1,75 @@ +package pgtype + +import ( + "fmt" + "io" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +type Int2 int32 + +func (i *Int2) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + return fmt.Errorf("invalid length for int2: %v", size) + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + n, err := strconv.ParseInt(string(buf), 10, 16) + if err != nil { + return err + } + + *i = Int2(n) + return nil +} + +func (i *Int2) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size != 2 { + return fmt.Errorf("invalid length for int2: %v", size) + } + + n, err := pgio.ReadInt16(r) + if err != nil { + return err + } + + *i = Int2(n) + return nil +} + +func (i Int2) EncodeText(w io.Writer) error { + s := strconv.FormatInt(int64(i), 10) + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + _, err = w.Write([]byte(s)) + return err +} + +func (i Int2) EncodeBinary(w io.Writer) error { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = pgio.WriteInt16(w, int16(i)) + return err +} diff --git a/pgtype/int4range.go b/pgtype/int4range.go deleted file mode 100644 index 12b86566..00000000 --- a/pgtype/int4range.go +++ /dev/null @@ -1,85 +0,0 @@ -package pgtype - -import ( - "encoding/binary" - "fmt" - "io" - "strconv" -) - -type Int4range struct { - Lower int32 - Upper int32 - LowerType BoundType - UpperType BoundType -} - -func (r *Int4range) ParseText(src string) error { - utr, err := ParseUntypedTextRange(src) - if err != nil { - return err - } - - r.LowerType = utr.LowerType - r.UpperType = utr.UpperType - - if r.LowerType == Empty { - return nil - } - - if r.LowerType == Inclusive || r.LowerType == Exclusive { - n, err := strconv.ParseInt(utr.Lower, 10, 32) - if err != nil { - return err - } - r.Lower = int32(n) - } - - if r.UpperType == Inclusive || r.UpperType == Exclusive { - n, err := strconv.ParseInt(utr.Upper, 10, 32) - if err != nil { - return err - } - r.Upper = int32(n) - } - - return nil -} - -func (r *Int4range) ParseBinary(src []byte) error { - ubr, err := ParseUntypedBinaryRange(src) - if err != nil { - return err - } - - r.LowerType = ubr.LowerType - r.UpperType = ubr.UpperType - - if r.LowerType == Empty { - return nil - } - - if r.LowerType == Inclusive || r.LowerType == Exclusive { - if len(ubr.Lower) != 4 { - return fmt.Errorf("invalid length for lower int4: %v", len(ubr.Lower)) - } - r.Lower = int32(binary.BigEndian.Uint32(ubr.Lower)) - } - - if r.UpperType == Inclusive || r.UpperType == Exclusive { - if len(ubr.Upper) != 4 { - return fmt.Errorf("invalid length for upper int4: %v", len(ubr.Upper)) - } - r.Upper = int32(binary.BigEndian.Uint32(ubr.Upper)) - } - - return nil -} - -func (r *Int4range) FormatText(w io.Writer) error { - return nil -} - -func (r *Int4range) FormatBinary(w io.Writer) error { - return nil -} diff --git a/pgtype/int4range_test.go b/pgtype/int4range_test.go deleted file mode 100644 index dae66f7a..00000000 --- a/pgtype/int4range_test.go +++ /dev/null @@ -1,186 +0,0 @@ -package pgtype - -import ( - "os" - "testing" - - "github.com/jackc/pgx" -) - -// func TestInt4rangeText(t *testing.T) { -// conns := mustConnectAll(t) -// defer mustCloseAll(t, conns) - -// tests := []struct { -// name string -// sql string -// args []interface{} -// err error -// result Int4range -// }{ -// { -// name: "Normal", -// sql: "select $1::int4range", -// args: []interface{}{&Int4range{Lower: 1, Upper: 10, LowerType: Inclusive, UpperType: Exclusive}}, -// err: nil, -// result: Int4range{Lower: 1, Upper: 10, LowerType: Inclusive, UpperType: Exclusive}, -// }, -// { -// name: "Negative", -// sql: "select int4range(-42, -5)", -// args: []interface{}{&Int4range{Lower: -42, Upper: -5, LowerType: Inclusive, UpperType: Exclusive}}, -// err: nil, -// result: Int4range{Lower: -42, Upper: -5, LowerType: Inclusive, UpperType: Exclusive}, -// }, -// { -// name: "Normalized Bounds", -// sql: "select int4range(1, 10, '(]')", -// args: []interface{}{Int4range{Lower: 1, Upper: 10, LowerType: Exclusive, UpperType: Inclusive}}, -// err: nil, -// result: Int4range{Lower: 2, Upper: 11, LowerType: Inclusive, UpperType: Exclusive}, -// }, -// } - -// for _, conn := range conns { -// for _, tt := range tests { -// var r Int4range -// var s string -// err := conn.QueryRow(tt.sql, tt.args...).Scan(&s) -// if err != tt.err { -// t.Errorf("%s %s: %v", conn.DriverName(), tt.name, err) -// } - -// err = r.ParseText(s) -// if err != nil { -// t.Errorf("%s %s: %v", conn.DriverName(), tt.name, err) -// } - -// if r != tt.result { -// t.Errorf("%s %s: expected %#v, got %#v", conn.DriverName(), tt.name, tt.result, r) -// } -// } -// } -// } - -func TestInt4rangeParseText(t *testing.T) { - conns := mustConnectAll(t) - defer mustCloseAll(t, conns) - - tests := []struct { - name string - sql string - args []interface{} - err error - result Int4range - }{ - { - name: "Scan", - sql: "select int4range(1, 10)", - args: []interface{}{}, - err: nil, - result: Int4range{Lower: 1, Upper: 10, LowerType: Inclusive, UpperType: Exclusive}, - }, - { - name: "Scan Negative", - sql: "select int4range(-42, -5)", - args: []interface{}{}, - err: nil, - result: Int4range{Lower: -42, Upper: -5, LowerType: Inclusive, UpperType: Exclusive}, - }, - { - name: "Scan Normalized Bounds", - sql: "select int4range(1, 10, '(]')", - args: []interface{}{}, - err: nil, - result: Int4range{Lower: 2, Upper: 11, LowerType: Inclusive, UpperType: Exclusive}, - }, - } - - for _, conn := range conns { - for _, tt := range tests { - var r Int4range - var s string - err := conn.QueryRow(tt.sql, tt.args...).Scan(&s) - if err != tt.err { - t.Errorf("%s %s: %v", conn.DriverName(), tt.name, err) - } - - err = r.ParseText(s) - if err != nil { - t.Errorf("%s %s: %v", conn.DriverName(), tt.name, err) - } - - if r != tt.result { - t.Errorf("%s %s: expected %#v, got %#v", conn.DriverName(), tt.name, tt.result, r) - } - } - } -} - -func TestInt4rangeParseBinary(t *testing.T) { - config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) - if err != nil { - t.Fatal(err) - } - - conn, err := pgx.Connect(config) - if err != nil { - t.Fatal(err) - } - defer mustClose(t, conn) - - tests := []struct { - name string - sql string - args []interface{} - err error - result Int4range - }{ - { - name: "Scan", - sql: "select int4range(1, 10)", - args: []interface{}{}, - err: nil, - result: Int4range{Lower: 1, Upper: 10, LowerType: Inclusive, UpperType: Exclusive}, - }, - { - name: "Scan Negative", - sql: "select int4range(-42, -5)", - args: []interface{}{}, - err: nil, - result: Int4range{Lower: -42, Upper: -5, LowerType: Inclusive, UpperType: Exclusive}, - }, - { - name: "Scan Normalized Bounds", - sql: "select int4range(1, 10, '(]')", - args: []interface{}{}, - err: nil, - result: Int4range{Lower: 2, Upper: 11, LowerType: Inclusive, UpperType: Exclusive}, - }, - } - - for _, tt := range tests { - ps, err := conn.Prepare(tt.sql, tt.sql) - if err != nil { - t.Errorf("conn.Prepare failed: %v", err) - continue - } - ps.FieldDescriptions[0].FormatCode = pgx.BinaryFormatCode - - var r Int4range - var buf []byte - err = conn.QueryRow(tt.sql, tt.args...).Scan(&buf) - if err != tt.err { - t.Errorf("%s: %v", tt.name, err) - } - - err = r.ParseBinary(buf) - if err != nil { - t.Errorf("%s: %v", tt.name, err) - } - - if r != tt.result { - t.Errorf("%s: expected %#v, got %#v", tt.name, tt.result, r) - } - } -} diff --git a/pgtype/int8.go b/pgtype/int8.go new file mode 100644 index 00000000..5592a13b --- /dev/null +++ b/pgtype/int8.go @@ -0,0 +1,75 @@ +package pgtype + +import ( + "fmt" + "io" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +type Int8 int64 + +func (i *Int8) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + return fmt.Errorf("invalid length for int8: %v", size) + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + n, err := strconv.ParseInt(string(buf), 10, 64) + if err != nil { + return err + } + + *i = Int8(n) + return nil +} + +func (i *Int8) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size != 8 { + return fmt.Errorf("invalid length for int8: %v", size) + } + + n, err := pgio.ReadInt64(r) + if err != nil { + return err + } + + *i = Int8(n) + return nil +} + +func (i Int8) EncodeText(w io.Writer) error { + s := strconv.FormatInt(int64(i), 10) + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + _, err = w.Write([]byte(s)) + return err +} + +func (i Int8) EncodeBinary(w io.Writer) error { + _, err := pgio.WriteInt32(w, 8) + if err != nil { + return err + } + + _, err = pgio.WriteInt64(w, int64(i)) + return err +} diff --git a/pgtype/pgxtype.go b/pgtype/pgxtype.go deleted file mode 100644 index 859332ea..00000000 --- a/pgtype/pgxtype.go +++ /dev/null @@ -1,141 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "fmt" -) - -// type FieldDescription interface { -// Name() string -// Table() uint32 -// AttributeNumber() int16 -// DataType() uint32 -// DataTypeSize() int16 -// DataTypeName() string -// Modifier() int32 -// FormatCode() int16 -// } - -// Remember need to delegate for server controlled format like inet - -// Or separate interfaces for raw bytes and preprocessed by pgx? - -// Or interface{} like database/sql - and just pre-process into more things - -// type ScannerV3 interface { -// ScanPgxV3(fieldDescription FieldDescription, src interface{}) error -// } - -// // Encoders could also return interface{} to delegate to internal pgx - -// type TextEncoderV3 interface { -// EncodeTextPgxV3(oid uint32) (interface{}, error) -// } - -// type BinaryEncoderV3 interface { -// EncodeBinaryPgxV3(oid uint32) (interface{}, error) -// } - -// const ( -// Int4OID = 23 -// ) - -type Status byte - -const ( - Undefined Status = iota - Null - Present -) - -func (s Status) String() string { - switch s { - case Undefined: - return "Undefined" - case Null: - return "Null" - case Present: - return "Present" - } - - return "Invalid status" -} - -type Int32Box struct { - Value2 int32 - Status Status -} - -func (s *Int32Box) ScanPgxV3(fieldDescription interface{}, src interface{}) error { - switch v := src.(type) { - case int64: - s.Value2 = int32(v) - s.Status = Present - default: - return fmt.Errorf("cannot scan %v (%T)", v, v) - } - - return nil -} - -func (s *Int32Box) Scan(src interface{}) error { - switch v := src.(type) { - case int64: - s.Value2 = int32(v) - s.Status = Present - // TODO - should this have to accept all integer types? - case int32: - s.Value2 = int32(v) - s.Status = Present - default: - return fmt.Errorf("cannot scan %v (%T)", v, v) - } - - return nil -} - -func (s Int32Box) Value() (driver.Value, error) { - // if !n.Valid { - - // return nil, nil - - // } - - return int64(s.Value2), nil - -} - -type StringBox struct { - Value string - Status Status -} - -func (s *StringBox) ScanPgxV3(fieldDescription interface{}, src interface{}) error { - switch v := src.(type) { - case string: - s.Value = v - s.Status = Present - case []byte: - s.Value = string(v) - s.Status = Present - default: - return fmt.Errorf("cannot scan %v (%T)", v, v) - } - - return nil -} - -func (s *StringBox) Scan(src interface{}) error { - switch v := src.(type) { - case string: - s.Value = v - s.Status = Present - case []byte: - s.Value = string(v) - s.Status = Present - default: - return fmt.Errorf("cannot scan %v (%T)", v, v) - } - - return nil -} diff --git a/pgtype/pgxtype_test.go b/pgtype/pgxtype_test.go deleted file mode 100644 index 30f4b88a..00000000 --- a/pgtype/pgxtype_test.go +++ /dev/null @@ -1,186 +0,0 @@ -package pgtype - -import ( - "database/sql" - "os" - "testing" - - "github.com/jackc/pgx" - _ "github.com/jackc/pgx/stdlib" - _ "github.com/lib/pq" -) - -func mustConnectAll(t testing.TB) []QueryRowCloser { - return []QueryRowCloser{ - mustConnectPgx(t), - mustConnectDatabaseSQL(t, "github.com/lib/pq"), - mustConnectDatabaseSQL(t, "github.com/jackc/pgx/stdlib"), - } -} - -func mustCloseAll(t testing.TB, conns []QueryRowCloser) { - for _, conn := range conns { - err := conn.Close() - if err != nil { - t.Error(err) - } - } -} - -func mustConnectPgx(t testing.TB) QueryRowCloser { - config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) - if err != nil { - t.Fatal(err) - } - - conn, err := pgx.Connect(config) - if err != nil { - t.Fatal(err) - } - - return &PgxConn{conn: conn} -} - -func mustClose(t testing.TB, conn interface { - Close() error -}) { - err := conn.Close() - if err != nil { - t.Fatal(err) - } -} - -func mustConnectDatabaseSQL(t testing.TB, driverName string) QueryRowCloser { - var sqlDriverName string - switch driverName { - case "github.com/lib/pq": - sqlDriverName = "postgres" - case "github.com/jackc/pgx/stdlib": - sqlDriverName = "pgx" - default: - t.Fatalf("Unknown driver %v", driverName) - } - - db, err := sql.Open(sqlDriverName, os.Getenv("DATABASE_URL")) - if err != nil { - t.Fatal(err) - } - - return &DatabaseSQLConn{db: db, name: driverName} -} - -type QueryRowScanner interface { - Scan(dest ...interface{}) error -} - -type QueryRowCloser interface { - QueryRow(query string, args ...interface{}) QueryRowScanner - Close() error - DriverName() string -} - -type DatabaseSQLConn struct { - db *sql.DB - name string -} - -func (c *DatabaseSQLConn) QueryRow(query string, args ...interface{}) QueryRowScanner { - return c.db.QueryRow(query, args...) -} - -func (c *DatabaseSQLConn) Close() error { - return c.db.Close() -} - -func (c *DatabaseSQLConn) DriverName() string { - return c.name -} - -type PgxConn struct { - conn *pgx.Conn -} - -func (c *PgxConn) QueryRow(query string, args ...interface{}) QueryRowScanner { - return c.conn.QueryRow(query, args...) -} - -func (c *PgxConn) Close() error { - return c.conn.Close() -} - -func (c *PgxConn) DriverName() string { - return "github.com/jackc/pgx" -} - -// Test scan lib/pq -// Test encode lib/pq -// Test scan pgx/stdlib -// Test encode pgx/stdlib -// Test scan pgx binary -// Test scan pgx text -// Test encode pgx - -func TestInt32BoxScan(t *testing.T) { - conns := mustConnectAll(t) - defer mustCloseAll(t, conns) - - tests := []struct { - name string - sql string - args []interface{} - err error - result Int32Box - }{ - { - name: "Scan", - sql: "select 42", - args: []interface{}{}, - err: nil, - result: Int32Box{Status: Present, Value2: 42}, - }, - { - name: "Encode", - sql: "select $1::int4", - args: []interface{}{&Int32Box{Status: Present, Value2: 42}}, - err: nil, - result: Int32Box{Status: Present, Value2: 42}, - }, - } - - for _, conn := range conns { - for _, tt := range tests { - var n Int32Box - err := conn.QueryRow(tt.sql, tt.args...).Scan(&n) - if err != tt.err { - t.Errorf("%s %s: %v", conn.DriverName(), tt.name, err) - } - - if n.Status != tt.result.Status { - t.Errorf("%s %s: expected Status %v, got %v", conn.DriverName(), tt.name, tt.result.Status, n.Status) - } - if n.Value2 != tt.result.Value2 { - t.Errorf("%s %s: expected Value %v, got %v", conn.DriverName(), tt.name, tt.result.Value2, n.Value2) - } - } - } -} - -func TestStringBoxScan(t *testing.T) { - conns := mustConnectAll(t) - defer mustCloseAll(t, conns) - - for _, conn := range conns { - var n StringBox - err := conn.QueryRow("select 'Hello, world'").Scan(&n) - if err != nil { - t.Errorf("%s: %v", conn.DriverName(), err) - } - - if n.Status != Present { - t.Errorf("%s: expected Status %v, got %v", conn.DriverName(), Present, n.Status) - } - if n.Value != "Hello, world" { - t.Errorf("%s: expected Value %v, got %v", "Hello, world", conn.DriverName(), n.Value) - } - } -} diff --git a/pgtype/range.go b/pgtype/range.go deleted file mode 100644 index 9137ab74..00000000 --- a/pgtype/range.go +++ /dev/null @@ -1,286 +0,0 @@ -package pgtype - -import ( - "bytes" - "encoding/binary" - "fmt" - "io" - "unicode" -) - -type BoundType byte - -const ( - Inclusive = BoundType('i') - Exclusive = BoundType('e') - Unbounded = BoundType('U') - Empty = BoundType('E') -) - -type UntypedTextRange struct { - Lower string - Upper string - LowerType BoundType - UpperType BoundType -} - -func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { - utr := &UntypedTextRange{} - if src == "empty" { - utr.LowerType = 'E' - utr.UpperType = 'E' - return utr, nil - } - - buf := bytes.NewBufferString(src) - - skipWhitespace(buf) - - r, _, err := buf.ReadRune() - if err != nil { - return nil, fmt.Errorf("invalid lower bound: %v", err) - } - switch r { - case '(': - utr.LowerType = Exclusive - case '[': - utr.LowerType = Inclusive - default: - return nil, fmt.Errorf("missing lower bound, instead got: %v", string(r)) - } - - r, _, err = buf.ReadRune() - if err != nil { - return nil, fmt.Errorf("invalid lower value: %v", err) - } - buf.UnreadRune() - - if r == ',' { - utr.LowerType = Unbounded - } else { - utr.Lower, err = rangeParseValue(buf) - if err != nil { - return nil, fmt.Errorf("invalid lower value: %v", err) - } - } - - r, _, err = buf.ReadRune() - if err != nil { - return nil, fmt.Errorf("missing range separator: %v", err) - } - if r != ',' { - return nil, fmt.Errorf("missing range separator: %v", r) - } - - r, _, err = buf.ReadRune() - if err != nil { - return nil, fmt.Errorf("invalid upper value: %v", err) - } - buf.UnreadRune() - - if r == ')' || r == ']' { - utr.UpperType = Unbounded - } else { - utr.Upper, err = rangeParseValue(buf) - if err != nil { - return nil, fmt.Errorf("invalid upper value: %v", err) - } - } - - r, _, err = buf.ReadRune() - if err != nil { - return nil, fmt.Errorf("missing upper bound: %v", err) - } - switch r { - case ')': - utr.UpperType = Exclusive - case ']': - utr.UpperType = Inclusive - default: - return nil, fmt.Errorf("missing upper bound, instead got: %v", string(r)) - } - - skipWhitespace(buf) - - if buf.Len() > 0 { - return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) - } - - return utr, nil -} - -func skipWhitespace(buf *bytes.Buffer) { - var r rune - var err error - for r, _, _ = buf.ReadRune(); unicode.IsSpace(r); r, _, _ = buf.ReadRune() { - } - - if err != io.EOF { - buf.UnreadRune() - } -} - -func rangeParseValue(buf *bytes.Buffer) (string, error) { - r, _, err := buf.ReadRune() - if err != nil { - return "", err - } - if r == '"' { - return rangeParseQuotedValue(buf) - } - buf.UnreadRune() - - s := &bytes.Buffer{} - - for { - r, _, err := buf.ReadRune() - if err != nil { - return "", err - } - - switch r { - case '\\': - r, _, err = buf.ReadRune() - if err != nil { - return "", err - } - case ',', '[', ']', '(', ')': - buf.UnreadRune() - return s.String(), nil - } - - s.WriteRune(r) - } -} - -func rangeParseQuotedValue(buf *bytes.Buffer) (string, error) { - s := &bytes.Buffer{} - - for { - r, _, err := buf.ReadRune() - if err != nil { - return "", err - } - - switch r { - case '\\': - r, _, err = buf.ReadRune() - if err != nil { - return "", err - } - case '"': - r, _, err = buf.ReadRune() - if err != nil { - return "", err - } - if r != '"' { - buf.UnreadRune() - return s.String(), nil - } - } - s.WriteRune(r) - } -} - -type UntypedBinaryRange struct { - Lower []byte - Upper []byte - LowerType BoundType - UpperType BoundType -} - -// 0 = () = 00000 -// 1 = empty = 00001 -// 2 = [) = 00010 -// 4 = (] = 00100 -// 6 = [] = 00110 -// 8 = ) = 01000 -// 12 = ] = 01100 -// 16 = ( = 10000 -// 18 = [ = 10010 -// 24 = = 11000 - -const emptyMask = 1 -const lowerInclusiveMask = 2 -const upperInclusiveMask = 4 -const lowerUnboundedMask = 8 -const upperUnboundedMask = 16 - -func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { - ubr := &UntypedBinaryRange{} - - if len(src) == 0 { - return nil, fmt.Errorf("range too short: %v", len(src)) - } - - rangeType := src[0] - rp := 1 - - if rangeType&emptyMask > 0 { - if len(src[rp:]) > 0 { - return nil, fmt.Errorf("unexpected trailing bytes parsing empty range: %v", len(src[rp:])) - } - ubr.LowerType = Empty - ubr.UpperType = Empty - return ubr, nil - } - - if rangeType&lowerInclusiveMask > 0 { - ubr.LowerType = Inclusive - } else if rangeType&lowerUnboundedMask > 0 { - ubr.LowerType = Unbounded - } else { - ubr.LowerType = Exclusive - } - - if rangeType&upperInclusiveMask > 0 { - ubr.UpperType = Inclusive - } else if rangeType&upperUnboundedMask > 0 { - ubr.UpperType = Unbounded - } else { - ubr.UpperType = Exclusive - } - - if ubr.LowerType == Unbounded && ubr.UpperType == Unbounded { - if len(src[rp:]) > 0 { - return nil, fmt.Errorf("unexpected trailing bytes parsing unbounded range: %v", len(src[rp:])) - } - return ubr, nil - } - - if len(src[rp:]) < 4 { - return nil, fmt.Errorf("too few bytes for size: %v", src[rp:]) - } - valueLen := int(binary.BigEndian.Uint32(src[rp:])) - rp += 4 - - val := src[rp : rp+valueLen] - rp += valueLen - - if ubr.LowerType != Unbounded { - ubr.Lower = val - } else { - ubr.Upper = val - if len(src[rp:]) > 0 { - return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) - } - return ubr, nil - } - - if ubr.UpperType != Unbounded { - if len(src[rp:]) < 4 { - return nil, fmt.Errorf("too few bytes for size: %v", src[rp:]) - } - valueLen := int(binary.BigEndian.Uint32(src[rp:])) - rp += 4 - ubr.Upper = src[rp : rp+valueLen] - rp += valueLen - } - - if len(src[rp:]) > 0 { - return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) - } - - return ubr, nil - -} diff --git a/pgtype/range_test.go b/pgtype/range_test.go deleted file mode 100644 index 9e16df59..00000000 --- a/pgtype/range_test.go +++ /dev/null @@ -1,177 +0,0 @@ -package pgtype - -import ( - "bytes" - "testing" -) - -func TestParseUntypedTextRange(t *testing.T) { - tests := []struct { - src string - result UntypedTextRange - err error - }{ - { - src: `[1,2)`, - result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: `[1,2]`, - result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Inclusive}, - err: nil, - }, - { - src: `(1,3)`, - result: UntypedTextRange{Lower: "1", Upper: "3", LowerType: Exclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: ` [1,2) `, - result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: `[ foo , bar )`, - result: UntypedTextRange{Lower: " foo ", Upper: " bar ", LowerType: Inclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: `["foo","bar")`, - result: UntypedTextRange{Lower: "foo", Upper: "bar", LowerType: Inclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: `["f""oo","b""ar")`, - result: UntypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: `["f""oo","b""ar")`, - result: UntypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: `["","bar")`, - result: UntypedTextRange{Lower: ``, Upper: `bar`, LowerType: Inclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: `[f\"oo\,,b\\ar\))`, - result: UntypedTextRange{Lower: `f"oo,`, Upper: `b\ar)`, LowerType: Inclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: `empty`, - result: UntypedTextRange{Lower: "", Upper: "", LowerType: Empty, UpperType: Empty}, - err: nil, - }, - } - - for i, tt := range tests { - r, err := ParseUntypedTextRange(tt.src) - if err != tt.err { - t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err) - continue - } - - if r.LowerType != tt.result.LowerType { - t.Errorf("%d. `%v`: expected result lower type %v, got %v", i, tt.src, string(tt.result.LowerType), string(r.LowerType)) - } - - if r.UpperType != tt.result.UpperType { - t.Errorf("%d. `%v`: expected result upper type %v, got %v", i, tt.src, string(tt.result.UpperType), string(r.UpperType)) - } - - if r.Lower != tt.result.Lower { - t.Errorf("%d. `%v`: expected result lower %v, got %v", i, tt.src, tt.result.Lower, r.Lower) - } - - if r.Upper != tt.result.Upper { - t.Errorf("%d. `%v`: expected result upper %v, got %v", i, tt.src, tt.result.Upper, r.Upper) - } - } -} - -func TestParseUntypedBinaryRange(t *testing.T) { - tests := []struct { - src []byte - result UntypedBinaryRange - err error - }{ - { - src: []byte{0, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: []byte{1}, - result: UntypedBinaryRange{Lower: nil, Upper: nil, LowerType: Empty, UpperType: Empty}, - err: nil, - }, - { - src: []byte{2, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Exclusive}, - err: nil, - }, - { - src: []byte{4, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Inclusive}, - err: nil, - }, - { - src: []byte{6, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Inclusive}, - err: nil, - }, - { - src: []byte{8, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Exclusive}, - err: nil, - }, - { - src: []byte{12, 0, 0, 0, 2, 0, 5}, - result: UntypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Inclusive}, - err: nil, - }, - { - src: []byte{16, 0, 0, 0, 2, 0, 4}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Exclusive, UpperType: Unbounded}, - err: nil, - }, - { - src: []byte{18, 0, 0, 0, 2, 0, 4}, - result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Inclusive, UpperType: Unbounded}, - err: nil, - }, - { - src: []byte{24}, - result: UntypedBinaryRange{Lower: nil, Upper: nil, LowerType: Unbounded, UpperType: Unbounded}, - err: nil, - }, - } - - for i, tt := range tests { - r, err := ParseUntypedBinaryRange(tt.src) - if err != tt.err { - t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err) - continue - } - - if r.LowerType != tt.result.LowerType { - t.Errorf("%d. `%v`: expected result lower type %v, got %v", i, tt.src, string(tt.result.LowerType), string(r.LowerType)) - } - - if r.UpperType != tt.result.UpperType { - t.Errorf("%d. `%v`: expected result upper type %v, got %v", i, tt.src, string(tt.result.UpperType), string(r.UpperType)) - } - - if bytes.Compare(r.Lower, tt.result.Lower) != 0 { - t.Errorf("%d. `%v`: expected result lower %v, got %v", i, tt.src, tt.result.Lower, r.Lower) - } - - if bytes.Compare(r.Upper, tt.result.Upper) != 0 { - t.Errorf("%d. `%v`: expected result upper %v, got %v", i, tt.src, tt.result.Upper, r.Upper) - } - } -} diff --git a/values.go b/values.go index 85c7ad3d..3614febb 100644 --- a/values.go +++ b/values.go @@ -1467,17 +1467,26 @@ func decodeInt8(vr *ValueReader) int64 { return 0 } - if vr.Type().FormatCode != BinaryFormatCode { + vr.err = errRewoundLen + + var n pgtype.Int8 + var err error + switch vr.Type().FormatCode { + case TextFormatCode: + err = n.DecodeText(&valueReader2{vr}) + case BinaryFormatCode: + err = n.DecodeBinary(&valueReader2{vr}) + default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0 } - if vr.Len() != 8 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8: %d", vr.Len()))) + if err != nil { + vr.Fatal(err) return 0 } - return vr.ReadInt64() + return int64(n) } func decodeChar(vr *ValueReader) Char { @@ -1515,17 +1524,26 @@ func decodeInt2(vr *ValueReader) int16 { return 0 } - if vr.Type().FormatCode != BinaryFormatCode { + vr.err = errRewoundLen + + var n pgtype.Int2 + var err error + switch vr.Type().FormatCode { + case TextFormatCode: + err = n.DecodeText(&valueReader2{vr}) + case BinaryFormatCode: + err = n.DecodeBinary(&valueReader2{vr}) + default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0 } - if vr.Len() != 2 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2: %d", vr.Len()))) + if err != nil { + vr.Fatal(err) return 0 } - return vr.ReadInt16() + return int16(n) } func encodeInt(w *WriteBuf, oid OID, value int) error {