mirror of https://github.com/jackc/pgx.git
wip
parent
66b79e9408
commit
3366699bea
|
@ -0,0 +1,75 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type Int2 int32
|
||||
|
||||
func (i *Int2) DecodeText(r io.Reader) error {
|
||||
size, err := pgio.ReadInt32(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if size == -1 {
|
||||
return fmt.Errorf("invalid length for int2: %v", size)
|
||||
}
|
||||
|
||||
buf := make([]byte, int(size))
|
||||
_, err = r.Read(buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
n, err := strconv.ParseInt(string(buf), 10, 16)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*i = Int2(n)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *Int2) DecodeBinary(r io.Reader) error {
|
||||
size, err := pgio.ReadInt32(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if size != 2 {
|
||||
return fmt.Errorf("invalid length for int2: %v", size)
|
||||
}
|
||||
|
||||
n, err := pgio.ReadInt16(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*i = Int2(n)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i Int2) EncodeText(w io.Writer) error {
|
||||
s := strconv.FormatInt(int64(i), 10)
|
||||
_, err := pgio.WriteInt32(w, int32(len(s)))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
_, err = w.Write([]byte(s))
|
||||
return err
|
||||
}
|
||||
|
||||
func (i Int2) EncodeBinary(w io.Writer) error {
|
||||
_, err := pgio.WriteInt32(w, 2)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = pgio.WriteInt16(w, int16(i))
|
||||
return err
|
||||
}
|
|
@ -1,85 +0,0 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type Int4range struct {
|
||||
Lower int32
|
||||
Upper int32
|
||||
LowerType BoundType
|
||||
UpperType BoundType
|
||||
}
|
||||
|
||||
func (r *Int4range) ParseText(src string) error {
|
||||
utr, err := ParseUntypedTextRange(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.LowerType = utr.LowerType
|
||||
r.UpperType = utr.UpperType
|
||||
|
||||
if r.LowerType == Empty {
|
||||
return nil
|
||||
}
|
||||
|
||||
if r.LowerType == Inclusive || r.LowerType == Exclusive {
|
||||
n, err := strconv.ParseInt(utr.Lower, 10, 32)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.Lower = int32(n)
|
||||
}
|
||||
|
||||
if r.UpperType == Inclusive || r.UpperType == Exclusive {
|
||||
n, err := strconv.ParseInt(utr.Upper, 10, 32)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.Upper = int32(n)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Int4range) ParseBinary(src []byte) error {
|
||||
ubr, err := ParseUntypedBinaryRange(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.LowerType = ubr.LowerType
|
||||
r.UpperType = ubr.UpperType
|
||||
|
||||
if r.LowerType == Empty {
|
||||
return nil
|
||||
}
|
||||
|
||||
if r.LowerType == Inclusive || r.LowerType == Exclusive {
|
||||
if len(ubr.Lower) != 4 {
|
||||
return fmt.Errorf("invalid length for lower int4: %v", len(ubr.Lower))
|
||||
}
|
||||
r.Lower = int32(binary.BigEndian.Uint32(ubr.Lower))
|
||||
}
|
||||
|
||||
if r.UpperType == Inclusive || r.UpperType == Exclusive {
|
||||
if len(ubr.Upper) != 4 {
|
||||
return fmt.Errorf("invalid length for upper int4: %v", len(ubr.Upper))
|
||||
}
|
||||
r.Upper = int32(binary.BigEndian.Uint32(ubr.Upper))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Int4range) FormatText(w io.Writer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Int4range) FormatBinary(w io.Writer) error {
|
||||
return nil
|
||||
}
|
|
@ -1,186 +0,0 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
)
|
||||
|
||||
// func TestInt4rangeText(t *testing.T) {
|
||||
// conns := mustConnectAll(t)
|
||||
// defer mustCloseAll(t, conns)
|
||||
|
||||
// tests := []struct {
|
||||
// name string
|
||||
// sql string
|
||||
// args []interface{}
|
||||
// err error
|
||||
// result Int4range
|
||||
// }{
|
||||
// {
|
||||
// name: "Normal",
|
||||
// sql: "select $1::int4range",
|
||||
// args: []interface{}{&Int4range{Lower: 1, Upper: 10, LowerType: Inclusive, UpperType: Exclusive}},
|
||||
// err: nil,
|
||||
// result: Int4range{Lower: 1, Upper: 10, LowerType: Inclusive, UpperType: Exclusive},
|
||||
// },
|
||||
// {
|
||||
// name: "Negative",
|
||||
// sql: "select int4range(-42, -5)",
|
||||
// args: []interface{}{&Int4range{Lower: -42, Upper: -5, LowerType: Inclusive, UpperType: Exclusive}},
|
||||
// err: nil,
|
||||
// result: Int4range{Lower: -42, Upper: -5, LowerType: Inclusive, UpperType: Exclusive},
|
||||
// },
|
||||
// {
|
||||
// name: "Normalized Bounds",
|
||||
// sql: "select int4range(1, 10, '(]')",
|
||||
// args: []interface{}{Int4range{Lower: 1, Upper: 10, LowerType: Exclusive, UpperType: Inclusive}},
|
||||
// err: nil,
|
||||
// result: Int4range{Lower: 2, Upper: 11, LowerType: Inclusive, UpperType: Exclusive},
|
||||
// },
|
||||
// }
|
||||
|
||||
// for _, conn := range conns {
|
||||
// for _, tt := range tests {
|
||||
// var r Int4range
|
||||
// var s string
|
||||
// err := conn.QueryRow(tt.sql, tt.args...).Scan(&s)
|
||||
// if err != tt.err {
|
||||
// t.Errorf("%s %s: %v", conn.DriverName(), tt.name, err)
|
||||
// }
|
||||
|
||||
// err = r.ParseText(s)
|
||||
// if err != nil {
|
||||
// t.Errorf("%s %s: %v", conn.DriverName(), tt.name, err)
|
||||
// }
|
||||
|
||||
// if r != tt.result {
|
||||
// t.Errorf("%s %s: expected %#v, got %#v", conn.DriverName(), tt.name, tt.result, r)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
func TestInt4rangeParseText(t *testing.T) {
|
||||
conns := mustConnectAll(t)
|
||||
defer mustCloseAll(t, conns)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sql string
|
||||
args []interface{}
|
||||
err error
|
||||
result Int4range
|
||||
}{
|
||||
{
|
||||
name: "Scan",
|
||||
sql: "select int4range(1, 10)",
|
||||
args: []interface{}{},
|
||||
err: nil,
|
||||
result: Int4range{Lower: 1, Upper: 10, LowerType: Inclusive, UpperType: Exclusive},
|
||||
},
|
||||
{
|
||||
name: "Scan Negative",
|
||||
sql: "select int4range(-42, -5)",
|
||||
args: []interface{}{},
|
||||
err: nil,
|
||||
result: Int4range{Lower: -42, Upper: -5, LowerType: Inclusive, UpperType: Exclusive},
|
||||
},
|
||||
{
|
||||
name: "Scan Normalized Bounds",
|
||||
sql: "select int4range(1, 10, '(]')",
|
||||
args: []interface{}{},
|
||||
err: nil,
|
||||
result: Int4range{Lower: 2, Upper: 11, LowerType: Inclusive, UpperType: Exclusive},
|
||||
},
|
||||
}
|
||||
|
||||
for _, conn := range conns {
|
||||
for _, tt := range tests {
|
||||
var r Int4range
|
||||
var s string
|
||||
err := conn.QueryRow(tt.sql, tt.args...).Scan(&s)
|
||||
if err != tt.err {
|
||||
t.Errorf("%s %s: %v", conn.DriverName(), tt.name, err)
|
||||
}
|
||||
|
||||
err = r.ParseText(s)
|
||||
if err != nil {
|
||||
t.Errorf("%s %s: %v", conn.DriverName(), tt.name, err)
|
||||
}
|
||||
|
||||
if r != tt.result {
|
||||
t.Errorf("%s %s: expected %#v, got %#v", conn.DriverName(), tt.name, tt.result, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInt4rangeParseBinary(t *testing.T) {
|
||||
config, err := pgx.ParseURI(os.Getenv("DATABASE_URL"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
conn, err := pgx.Connect(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer mustClose(t, conn)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sql string
|
||||
args []interface{}
|
||||
err error
|
||||
result Int4range
|
||||
}{
|
||||
{
|
||||
name: "Scan",
|
||||
sql: "select int4range(1, 10)",
|
||||
args: []interface{}{},
|
||||
err: nil,
|
||||
result: Int4range{Lower: 1, Upper: 10, LowerType: Inclusive, UpperType: Exclusive},
|
||||
},
|
||||
{
|
||||
name: "Scan Negative",
|
||||
sql: "select int4range(-42, -5)",
|
||||
args: []interface{}{},
|
||||
err: nil,
|
||||
result: Int4range{Lower: -42, Upper: -5, LowerType: Inclusive, UpperType: Exclusive},
|
||||
},
|
||||
{
|
||||
name: "Scan Normalized Bounds",
|
||||
sql: "select int4range(1, 10, '(]')",
|
||||
args: []interface{}{},
|
||||
err: nil,
|
||||
result: Int4range{Lower: 2, Upper: 11, LowerType: Inclusive, UpperType: Exclusive},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
ps, err := conn.Prepare(tt.sql, tt.sql)
|
||||
if err != nil {
|
||||
t.Errorf("conn.Prepare failed: %v", err)
|
||||
continue
|
||||
}
|
||||
ps.FieldDescriptions[0].FormatCode = pgx.BinaryFormatCode
|
||||
|
||||
var r Int4range
|
||||
var buf []byte
|
||||
err = conn.QueryRow(tt.sql, tt.args...).Scan(&buf)
|
||||
if err != tt.err {
|
||||
t.Errorf("%s: %v", tt.name, err)
|
||||
}
|
||||
|
||||
err = r.ParseBinary(buf)
|
||||
if err != nil {
|
||||
t.Errorf("%s: %v", tt.name, err)
|
||||
}
|
||||
|
||||
if r != tt.result {
|
||||
t.Errorf("%s: expected %#v, got %#v", tt.name, tt.result, r)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type Int8 int64
|
||||
|
||||
func (i *Int8) DecodeText(r io.Reader) error {
|
||||
size, err := pgio.ReadInt32(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if size == -1 {
|
||||
return fmt.Errorf("invalid length for int8: %v", size)
|
||||
}
|
||||
|
||||
buf := make([]byte, int(size))
|
||||
_, err = r.Read(buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
n, err := strconv.ParseInt(string(buf), 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*i = Int8(n)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *Int8) DecodeBinary(r io.Reader) error {
|
||||
size, err := pgio.ReadInt32(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if size != 8 {
|
||||
return fmt.Errorf("invalid length for int8: %v", size)
|
||||
}
|
||||
|
||||
n, err := pgio.ReadInt64(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*i = Int8(n)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i Int8) EncodeText(w io.Writer) error {
|
||||
s := strconv.FormatInt(int64(i), 10)
|
||||
_, err := pgio.WriteInt32(w, int32(len(s)))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
_, err = w.Write([]byte(s))
|
||||
return err
|
||||
}
|
||||
|
||||
func (i Int8) EncodeBinary(w io.Writer) error {
|
||||
_, err := pgio.WriteInt32(w, 8)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = pgio.WriteInt64(w, int64(i))
|
||||
return err
|
||||
}
|
|
@ -1,141 +0,0 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// type FieldDescription interface {
|
||||
// Name() string
|
||||
// Table() uint32
|
||||
// AttributeNumber() int16
|
||||
// DataType() uint32
|
||||
// DataTypeSize() int16
|
||||
// DataTypeName() string
|
||||
// Modifier() int32
|
||||
// FormatCode() int16
|
||||
// }
|
||||
|
||||
// Remember need to delegate for server controlled format like inet
|
||||
|
||||
// Or separate interfaces for raw bytes and preprocessed by pgx?
|
||||
|
||||
// Or interface{} like database/sql - and just pre-process into more things
|
||||
|
||||
// type ScannerV3 interface {
|
||||
// ScanPgxV3(fieldDescription FieldDescription, src interface{}) error
|
||||
// }
|
||||
|
||||
// // Encoders could also return interface{} to delegate to internal pgx
|
||||
|
||||
// type TextEncoderV3 interface {
|
||||
// EncodeTextPgxV3(oid uint32) (interface{}, error)
|
||||
// }
|
||||
|
||||
// type BinaryEncoderV3 interface {
|
||||
// EncodeBinaryPgxV3(oid uint32) (interface{}, error)
|
||||
// }
|
||||
|
||||
// const (
|
||||
// Int4OID = 23
|
||||
// )
|
||||
|
||||
type Status byte
|
||||
|
||||
const (
|
||||
Undefined Status = iota
|
||||
Null
|
||||
Present
|
||||
)
|
||||
|
||||
func (s Status) String() string {
|
||||
switch s {
|
||||
case Undefined:
|
||||
return "Undefined"
|
||||
case Null:
|
||||
return "Null"
|
||||
case Present:
|
||||
return "Present"
|
||||
}
|
||||
|
||||
return "Invalid status"
|
||||
}
|
||||
|
||||
type Int32Box struct {
|
||||
Value2 int32
|
||||
Status Status
|
||||
}
|
||||
|
||||
func (s *Int32Box) ScanPgxV3(fieldDescription interface{}, src interface{}) error {
|
||||
switch v := src.(type) {
|
||||
case int64:
|
||||
s.Value2 = int32(v)
|
||||
s.Status = Present
|
||||
default:
|
||||
return fmt.Errorf("cannot scan %v (%T)", v, v)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Int32Box) Scan(src interface{}) error {
|
||||
switch v := src.(type) {
|
||||
case int64:
|
||||
s.Value2 = int32(v)
|
||||
s.Status = Present
|
||||
// TODO - should this have to accept all integer types?
|
||||
case int32:
|
||||
s.Value2 = int32(v)
|
||||
s.Status = Present
|
||||
default:
|
||||
return fmt.Errorf("cannot scan %v (%T)", v, v)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s Int32Box) Value() (driver.Value, error) {
|
||||
// if !n.Valid {
|
||||
|
||||
// return nil, nil
|
||||
|
||||
// }
|
||||
|
||||
return int64(s.Value2), nil
|
||||
|
||||
}
|
||||
|
||||
type StringBox struct {
|
||||
Value string
|
||||
Status Status
|
||||
}
|
||||
|
||||
func (s *StringBox) ScanPgxV3(fieldDescription interface{}, src interface{}) error {
|
||||
switch v := src.(type) {
|
||||
case string:
|
||||
s.Value = v
|
||||
s.Status = Present
|
||||
case []byte:
|
||||
s.Value = string(v)
|
||||
s.Status = Present
|
||||
default:
|
||||
return fmt.Errorf("cannot scan %v (%T)", v, v)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *StringBox) Scan(src interface{}) error {
|
||||
switch v := src.(type) {
|
||||
case string:
|
||||
s.Value = v
|
||||
s.Status = Present
|
||||
case []byte:
|
||||
s.Value = string(v)
|
||||
s.Status = Present
|
||||
default:
|
||||
return fmt.Errorf("cannot scan %v (%T)", v, v)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,186 +0,0 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
_ "github.com/jackc/pgx/stdlib"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
func mustConnectAll(t testing.TB) []QueryRowCloser {
|
||||
return []QueryRowCloser{
|
||||
mustConnectPgx(t),
|
||||
mustConnectDatabaseSQL(t, "github.com/lib/pq"),
|
||||
mustConnectDatabaseSQL(t, "github.com/jackc/pgx/stdlib"),
|
||||
}
|
||||
}
|
||||
|
||||
func mustCloseAll(t testing.TB, conns []QueryRowCloser) {
|
||||
for _, conn := range conns {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func mustConnectPgx(t testing.TB) QueryRowCloser {
|
||||
config, err := pgx.ParseURI(os.Getenv("DATABASE_URL"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
conn, err := pgx.Connect(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return &PgxConn{conn: conn}
|
||||
}
|
||||
|
||||
func mustClose(t testing.TB, conn interface {
|
||||
Close() error
|
||||
}) {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func mustConnectDatabaseSQL(t testing.TB, driverName string) QueryRowCloser {
|
||||
var sqlDriverName string
|
||||
switch driverName {
|
||||
case "github.com/lib/pq":
|
||||
sqlDriverName = "postgres"
|
||||
case "github.com/jackc/pgx/stdlib":
|
||||
sqlDriverName = "pgx"
|
||||
default:
|
||||
t.Fatalf("Unknown driver %v", driverName)
|
||||
}
|
||||
|
||||
db, err := sql.Open(sqlDriverName, os.Getenv("DATABASE_URL"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return &DatabaseSQLConn{db: db, name: driverName}
|
||||
}
|
||||
|
||||
type QueryRowScanner interface {
|
||||
Scan(dest ...interface{}) error
|
||||
}
|
||||
|
||||
type QueryRowCloser interface {
|
||||
QueryRow(query string, args ...interface{}) QueryRowScanner
|
||||
Close() error
|
||||
DriverName() string
|
||||
}
|
||||
|
||||
type DatabaseSQLConn struct {
|
||||
db *sql.DB
|
||||
name string
|
||||
}
|
||||
|
||||
func (c *DatabaseSQLConn) QueryRow(query string, args ...interface{}) QueryRowScanner {
|
||||
return c.db.QueryRow(query, args...)
|
||||
}
|
||||
|
||||
func (c *DatabaseSQLConn) Close() error {
|
||||
return c.db.Close()
|
||||
}
|
||||
|
||||
func (c *DatabaseSQLConn) DriverName() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
type PgxConn struct {
|
||||
conn *pgx.Conn
|
||||
}
|
||||
|
||||
func (c *PgxConn) QueryRow(query string, args ...interface{}) QueryRowScanner {
|
||||
return c.conn.QueryRow(query, args...)
|
||||
}
|
||||
|
||||
func (c *PgxConn) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *PgxConn) DriverName() string {
|
||||
return "github.com/jackc/pgx"
|
||||
}
|
||||
|
||||
// Test scan lib/pq
|
||||
// Test encode lib/pq
|
||||
// Test scan pgx/stdlib
|
||||
// Test encode pgx/stdlib
|
||||
// Test scan pgx binary
|
||||
// Test scan pgx text
|
||||
// Test encode pgx
|
||||
|
||||
func TestInt32BoxScan(t *testing.T) {
|
||||
conns := mustConnectAll(t)
|
||||
defer mustCloseAll(t, conns)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sql string
|
||||
args []interface{}
|
||||
err error
|
||||
result Int32Box
|
||||
}{
|
||||
{
|
||||
name: "Scan",
|
||||
sql: "select 42",
|
||||
args: []interface{}{},
|
||||
err: nil,
|
||||
result: Int32Box{Status: Present, Value2: 42},
|
||||
},
|
||||
{
|
||||
name: "Encode",
|
||||
sql: "select $1::int4",
|
||||
args: []interface{}{&Int32Box{Status: Present, Value2: 42}},
|
||||
err: nil,
|
||||
result: Int32Box{Status: Present, Value2: 42},
|
||||
},
|
||||
}
|
||||
|
||||
for _, conn := range conns {
|
||||
for _, tt := range tests {
|
||||
var n Int32Box
|
||||
err := conn.QueryRow(tt.sql, tt.args...).Scan(&n)
|
||||
if err != tt.err {
|
||||
t.Errorf("%s %s: %v", conn.DriverName(), tt.name, err)
|
||||
}
|
||||
|
||||
if n.Status != tt.result.Status {
|
||||
t.Errorf("%s %s: expected Status %v, got %v", conn.DriverName(), tt.name, tt.result.Status, n.Status)
|
||||
}
|
||||
if n.Value2 != tt.result.Value2 {
|
||||
t.Errorf("%s %s: expected Value %v, got %v", conn.DriverName(), tt.name, tt.result.Value2, n.Value2)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringBoxScan(t *testing.T) {
|
||||
conns := mustConnectAll(t)
|
||||
defer mustCloseAll(t, conns)
|
||||
|
||||
for _, conn := range conns {
|
||||
var n StringBox
|
||||
err := conn.QueryRow("select 'Hello, world'").Scan(&n)
|
||||
if err != nil {
|
||||
t.Errorf("%s: %v", conn.DriverName(), err)
|
||||
}
|
||||
|
||||
if n.Status != Present {
|
||||
t.Errorf("%s: expected Status %v, got %v", conn.DriverName(), Present, n.Status)
|
||||
}
|
||||
if n.Value != "Hello, world" {
|
||||
t.Errorf("%s: expected Value %v, got %v", "Hello, world", conn.DriverName(), n.Value)
|
||||
}
|
||||
}
|
||||
}
|
286
pgtype/range.go
286
pgtype/range.go
|
@ -1,286 +0,0 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
type BoundType byte
|
||||
|
||||
const (
|
||||
Inclusive = BoundType('i')
|
||||
Exclusive = BoundType('e')
|
||||
Unbounded = BoundType('U')
|
||||
Empty = BoundType('E')
|
||||
)
|
||||
|
||||
type UntypedTextRange struct {
|
||||
Lower string
|
||||
Upper string
|
||||
LowerType BoundType
|
||||
UpperType BoundType
|
||||
}
|
||||
|
||||
func ParseUntypedTextRange(src string) (*UntypedTextRange, error) {
|
||||
utr := &UntypedTextRange{}
|
||||
if src == "empty" {
|
||||
utr.LowerType = 'E'
|
||||
utr.UpperType = 'E'
|
||||
return utr, nil
|
||||
}
|
||||
|
||||
buf := bytes.NewBufferString(src)
|
||||
|
||||
skipWhitespace(buf)
|
||||
|
||||
r, _, err := buf.ReadRune()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid lower bound: %v", err)
|
||||
}
|
||||
switch r {
|
||||
case '(':
|
||||
utr.LowerType = Exclusive
|
||||
case '[':
|
||||
utr.LowerType = Inclusive
|
||||
default:
|
||||
return nil, fmt.Errorf("missing lower bound, instead got: %v", string(r))
|
||||
}
|
||||
|
||||
r, _, err = buf.ReadRune()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid lower value: %v", err)
|
||||
}
|
||||
buf.UnreadRune()
|
||||
|
||||
if r == ',' {
|
||||
utr.LowerType = Unbounded
|
||||
} else {
|
||||
utr.Lower, err = rangeParseValue(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid lower value: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
r, _, err = buf.ReadRune()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("missing range separator: %v", err)
|
||||
}
|
||||
if r != ',' {
|
||||
return nil, fmt.Errorf("missing range separator: %v", r)
|
||||
}
|
||||
|
||||
r, _, err = buf.ReadRune()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid upper value: %v", err)
|
||||
}
|
||||
buf.UnreadRune()
|
||||
|
||||
if r == ')' || r == ']' {
|
||||
utr.UpperType = Unbounded
|
||||
} else {
|
||||
utr.Upper, err = rangeParseValue(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid upper value: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
r, _, err = buf.ReadRune()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("missing upper bound: %v", err)
|
||||
}
|
||||
switch r {
|
||||
case ')':
|
||||
utr.UpperType = Exclusive
|
||||
case ']':
|
||||
utr.UpperType = Inclusive
|
||||
default:
|
||||
return nil, fmt.Errorf("missing upper bound, instead got: %v", string(r))
|
||||
}
|
||||
|
||||
skipWhitespace(buf)
|
||||
|
||||
if buf.Len() > 0 {
|
||||
return nil, fmt.Errorf("unexpected trailing data: %v", buf.String())
|
||||
}
|
||||
|
||||
return utr, nil
|
||||
}
|
||||
|
||||
func skipWhitespace(buf *bytes.Buffer) {
|
||||
var r rune
|
||||
var err error
|
||||
for r, _, _ = buf.ReadRune(); unicode.IsSpace(r); r, _, _ = buf.ReadRune() {
|
||||
}
|
||||
|
||||
if err != io.EOF {
|
||||
buf.UnreadRune()
|
||||
}
|
||||
}
|
||||
|
||||
func rangeParseValue(buf *bytes.Buffer) (string, error) {
|
||||
r, _, err := buf.ReadRune()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if r == '"' {
|
||||
return rangeParseQuotedValue(buf)
|
||||
}
|
||||
buf.UnreadRune()
|
||||
|
||||
s := &bytes.Buffer{}
|
||||
|
||||
for {
|
||||
r, _, err := buf.ReadRune()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
switch r {
|
||||
case '\\':
|
||||
r, _, err = buf.ReadRune()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
case ',', '[', ']', '(', ')':
|
||||
buf.UnreadRune()
|
||||
return s.String(), nil
|
||||
}
|
||||
|
||||
s.WriteRune(r)
|
||||
}
|
||||
}
|
||||
|
||||
func rangeParseQuotedValue(buf *bytes.Buffer) (string, error) {
|
||||
s := &bytes.Buffer{}
|
||||
|
||||
for {
|
||||
r, _, err := buf.ReadRune()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
switch r {
|
||||
case '\\':
|
||||
r, _, err = buf.ReadRune()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
case '"':
|
||||
r, _, err = buf.ReadRune()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if r != '"' {
|
||||
buf.UnreadRune()
|
||||
return s.String(), nil
|
||||
}
|
||||
}
|
||||
s.WriteRune(r)
|
||||
}
|
||||
}
|
||||
|
||||
type UntypedBinaryRange struct {
|
||||
Lower []byte
|
||||
Upper []byte
|
||||
LowerType BoundType
|
||||
UpperType BoundType
|
||||
}
|
||||
|
||||
// 0 = () = 00000
|
||||
// 1 = empty = 00001
|
||||
// 2 = [) = 00010
|
||||
// 4 = (] = 00100
|
||||
// 6 = [] = 00110
|
||||
// 8 = ) = 01000
|
||||
// 12 = ] = 01100
|
||||
// 16 = ( = 10000
|
||||
// 18 = [ = 10010
|
||||
// 24 = = 11000
|
||||
|
||||
const emptyMask = 1
|
||||
const lowerInclusiveMask = 2
|
||||
const upperInclusiveMask = 4
|
||||
const lowerUnboundedMask = 8
|
||||
const upperUnboundedMask = 16
|
||||
|
||||
func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) {
|
||||
ubr := &UntypedBinaryRange{}
|
||||
|
||||
if len(src) == 0 {
|
||||
return nil, fmt.Errorf("range too short: %v", len(src))
|
||||
}
|
||||
|
||||
rangeType := src[0]
|
||||
rp := 1
|
||||
|
||||
if rangeType&emptyMask > 0 {
|
||||
if len(src[rp:]) > 0 {
|
||||
return nil, fmt.Errorf("unexpected trailing bytes parsing empty range: %v", len(src[rp:]))
|
||||
}
|
||||
ubr.LowerType = Empty
|
||||
ubr.UpperType = Empty
|
||||
return ubr, nil
|
||||
}
|
||||
|
||||
if rangeType&lowerInclusiveMask > 0 {
|
||||
ubr.LowerType = Inclusive
|
||||
} else if rangeType&lowerUnboundedMask > 0 {
|
||||
ubr.LowerType = Unbounded
|
||||
} else {
|
||||
ubr.LowerType = Exclusive
|
||||
}
|
||||
|
||||
if rangeType&upperInclusiveMask > 0 {
|
||||
ubr.UpperType = Inclusive
|
||||
} else if rangeType&upperUnboundedMask > 0 {
|
||||
ubr.UpperType = Unbounded
|
||||
} else {
|
||||
ubr.UpperType = Exclusive
|
||||
}
|
||||
|
||||
if ubr.LowerType == Unbounded && ubr.UpperType == Unbounded {
|
||||
if len(src[rp:]) > 0 {
|
||||
return nil, fmt.Errorf("unexpected trailing bytes parsing unbounded range: %v", len(src[rp:]))
|
||||
}
|
||||
return ubr, nil
|
||||
}
|
||||
|
||||
if len(src[rp:]) < 4 {
|
||||
return nil, fmt.Errorf("too few bytes for size: %v", src[rp:])
|
||||
}
|
||||
valueLen := int(binary.BigEndian.Uint32(src[rp:]))
|
||||
rp += 4
|
||||
|
||||
val := src[rp : rp+valueLen]
|
||||
rp += valueLen
|
||||
|
||||
if ubr.LowerType != Unbounded {
|
||||
ubr.Lower = val
|
||||
} else {
|
||||
ubr.Upper = val
|
||||
if len(src[rp:]) > 0 {
|
||||
return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:]))
|
||||
}
|
||||
return ubr, nil
|
||||
}
|
||||
|
||||
if ubr.UpperType != Unbounded {
|
||||
if len(src[rp:]) < 4 {
|
||||
return nil, fmt.Errorf("too few bytes for size: %v", src[rp:])
|
||||
}
|
||||
valueLen := int(binary.BigEndian.Uint32(src[rp:]))
|
||||
rp += 4
|
||||
ubr.Upper = src[rp : rp+valueLen]
|
||||
rp += valueLen
|
||||
}
|
||||
|
||||
if len(src[rp:]) > 0 {
|
||||
return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:]))
|
||||
}
|
||||
|
||||
return ubr, nil
|
||||
|
||||
}
|
|
@ -1,177 +0,0 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseUntypedTextRange(t *testing.T) {
|
||||
tests := []struct {
|
||||
src string
|
||||
result UntypedTextRange
|
||||
err error
|
||||
}{
|
||||
{
|
||||
src: `[1,2)`,
|
||||
result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
src: `[1,2]`,
|
||||
result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Inclusive},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
src: `(1,3)`,
|
||||
result: UntypedTextRange{Lower: "1", Upper: "3", LowerType: Exclusive, UpperType: Exclusive},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
src: ` [1,2) `,
|
||||
result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
src: `[ foo , bar )`,
|
||||
result: UntypedTextRange{Lower: " foo ", Upper: " bar ", LowerType: Inclusive, UpperType: Exclusive},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
src: `["foo","bar")`,
|
||||
result: UntypedTextRange{Lower: "foo", Upper: "bar", LowerType: Inclusive, UpperType: Exclusive},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
src: `["f""oo","b""ar")`,
|
||||
result: UntypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
src: `["f""oo","b""ar")`,
|
||||
result: UntypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
src: `["","bar")`,
|
||||
result: UntypedTextRange{Lower: ``, Upper: `bar`, LowerType: Inclusive, UpperType: Exclusive},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
src: `[f\"oo\,,b\\ar\))`,
|
||||
result: UntypedTextRange{Lower: `f"oo,`, Upper: `b\ar)`, LowerType: Inclusive, UpperType: Exclusive},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
src: `empty`,
|
||||
result: UntypedTextRange{Lower: "", Upper: "", LowerType: Empty, UpperType: Empty},
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
r, err := ParseUntypedTextRange(tt.src)
|
||||
if err != tt.err {
|
||||
t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if r.LowerType != tt.result.LowerType {
|
||||
t.Errorf("%d. `%v`: expected result lower type %v, got %v", i, tt.src, string(tt.result.LowerType), string(r.LowerType))
|
||||
}
|
||||
|
||||
if r.UpperType != tt.result.UpperType {
|
||||
t.Errorf("%d. `%v`: expected result upper type %v, got %v", i, tt.src, string(tt.result.UpperType), string(r.UpperType))
|
||||
}
|
||||
|
||||
if r.Lower != tt.result.Lower {
|
||||
t.Errorf("%d. `%v`: expected result lower %v, got %v", i, tt.src, tt.result.Lower, r.Lower)
|
||||
}
|
||||
|
||||
if r.Upper != tt.result.Upper {
|
||||
t.Errorf("%d. `%v`: expected result upper %v, got %v", i, tt.src, tt.result.Upper, r.Upper)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseUntypedBinaryRange(t *testing.T) {
|
||||
tests := []struct {
|
||||
src []byte
|
||||
result UntypedBinaryRange
|
||||
err error
|
||||
}{
|
||||
{
|
||||
src: []byte{0, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5},
|
||||
result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Exclusive},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
src: []byte{1},
|
||||
result: UntypedBinaryRange{Lower: nil, Upper: nil, LowerType: Empty, UpperType: Empty},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
src: []byte{2, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5},
|
||||
result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Exclusive},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
src: []byte{4, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5},
|
||||
result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Inclusive},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
src: []byte{6, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5},
|
||||
result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Inclusive},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
src: []byte{8, 0, 0, 0, 2, 0, 5},
|
||||
result: UntypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Exclusive},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
src: []byte{12, 0, 0, 0, 2, 0, 5},
|
||||
result: UntypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Inclusive},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
src: []byte{16, 0, 0, 0, 2, 0, 4},
|
||||
result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Exclusive, UpperType: Unbounded},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
src: []byte{18, 0, 0, 0, 2, 0, 4},
|
||||
result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Inclusive, UpperType: Unbounded},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
src: []byte{24},
|
||||
result: UntypedBinaryRange{Lower: nil, Upper: nil, LowerType: Unbounded, UpperType: Unbounded},
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
r, err := ParseUntypedBinaryRange(tt.src)
|
||||
if err != tt.err {
|
||||
t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if r.LowerType != tt.result.LowerType {
|
||||
t.Errorf("%d. `%v`: expected result lower type %v, got %v", i, tt.src, string(tt.result.LowerType), string(r.LowerType))
|
||||
}
|
||||
|
||||
if r.UpperType != tt.result.UpperType {
|
||||
t.Errorf("%d. `%v`: expected result upper type %v, got %v", i, tt.src, string(tt.result.UpperType), string(r.UpperType))
|
||||
}
|
||||
|
||||
if bytes.Compare(r.Lower, tt.result.Lower) != 0 {
|
||||
t.Errorf("%d. `%v`: expected result lower %v, got %v", i, tt.src, tt.result.Lower, r.Lower)
|
||||
}
|
||||
|
||||
if bytes.Compare(r.Upper, tt.result.Upper) != 0 {
|
||||
t.Errorf("%d. `%v`: expected result upper %v, got %v", i, tt.src, tt.result.Upper, r.Upper)
|
||||
}
|
||||
}
|
||||
}
|
34
values.go
34
values.go
|
@ -1467,17 +1467,26 @@ func decodeInt8(vr *ValueReader) int64 {
|
|||
return 0
|
||||
}
|
||||
|
||||
if vr.Type().FormatCode != BinaryFormatCode {
|
||||
vr.err = errRewoundLen
|
||||
|
||||
var n pgtype.Int8
|
||||
var err error
|
||||
switch vr.Type().FormatCode {
|
||||
case TextFormatCode:
|
||||
err = n.DecodeText(&valueReader2{vr})
|
||||
case BinaryFormatCode:
|
||||
err = n.DecodeBinary(&valueReader2{vr})
|
||||
default:
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return 0
|
||||
}
|
||||
|
||||
if vr.Len() != 8 {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8: %d", vr.Len())))
|
||||
if err != nil {
|
||||
vr.Fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
return vr.ReadInt64()
|
||||
return int64(n)
|
||||
}
|
||||
|
||||
func decodeChar(vr *ValueReader) Char {
|
||||
|
@ -1515,17 +1524,26 @@ func decodeInt2(vr *ValueReader) int16 {
|
|||
return 0
|
||||
}
|
||||
|
||||
if vr.Type().FormatCode != BinaryFormatCode {
|
||||
vr.err = errRewoundLen
|
||||
|
||||
var n pgtype.Int2
|
||||
var err error
|
||||
switch vr.Type().FormatCode {
|
||||
case TextFormatCode:
|
||||
err = n.DecodeText(&valueReader2{vr})
|
||||
case BinaryFormatCode:
|
||||
err = n.DecodeBinary(&valueReader2{vr})
|
||||
default:
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return 0
|
||||
}
|
||||
|
||||
if vr.Len() != 2 {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2: %d", vr.Len())))
|
||||
if err != nil {
|
||||
vr.Fatal(err)
|
||||
return 0
|
||||
}
|
||||
|
||||
return vr.ReadInt16()
|
||||
return int16(n)
|
||||
}
|
||||
|
||||
func encodeInt(w *WriteBuf, oid OID, value int) error {
|
||||
|
|
Loading…
Reference in New Issue