mirror of https://github.com/jackc/pgx.git
Remove old Scanner and Encoder system
parent
7bb1f3677d
commit
26d57356f7
127
bench_test.go
127
bench_test.go
|
@ -8,6 +8,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
func BenchmarkConnPool(b *testing.B) {
|
||||
|
@ -49,126 +50,6 @@ func BenchmarkConnPoolQueryRow(b *testing.B) {
|
|||
}
|
||||
}
|
||||
|
||||
func BenchmarkNullXWithNullValues(b *testing.B) {
|
||||
conn := mustConnect(b, *defaultConnConfig)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
_, err := conn.Prepare("selectNulls", "select 1::int4, 'johnsmith', null::text, null::text, null::text, null::date, null::timestamptz")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var record struct {
|
||||
id int32
|
||||
userName string
|
||||
email pgx.NullString
|
||||
name pgx.NullString
|
||||
sex pgx.NullString
|
||||
birthDate pgx.NullTime
|
||||
lastLoginTime pgx.NullTime
|
||||
}
|
||||
|
||||
err = conn.QueryRow("selectNulls").Scan(
|
||||
&record.id,
|
||||
&record.userName,
|
||||
&record.email,
|
||||
&record.name,
|
||||
&record.sex,
|
||||
&record.birthDate,
|
||||
&record.lastLoginTime,
|
||||
)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// These checks both ensure that the correct data was returned
|
||||
// and provide a benchmark of accessing the returned values.
|
||||
if record.id != 1 {
|
||||
b.Fatalf("bad value for id: %v", record.id)
|
||||
}
|
||||
if record.userName != "johnsmith" {
|
||||
b.Fatalf("bad value for userName: %v", record.userName)
|
||||
}
|
||||
if record.email.Valid {
|
||||
b.Fatalf("bad value for email: %v", record.email)
|
||||
}
|
||||
if record.name.Valid {
|
||||
b.Fatalf("bad value for name: %v", record.name)
|
||||
}
|
||||
if record.sex.Valid {
|
||||
b.Fatalf("bad value for sex: %v", record.sex)
|
||||
}
|
||||
if record.birthDate.Valid {
|
||||
b.Fatalf("bad value for birthDate: %v", record.birthDate)
|
||||
}
|
||||
if record.lastLoginTime.Valid {
|
||||
b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNullXWithPresentValues(b *testing.B) {
|
||||
conn := mustConnect(b, *defaultConnConfig)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
_, err := conn.Prepare("selectNulls", "select 1::int4, 'johnsmith', 'johnsmith@example.com', 'John Smith', 'male', '1970-01-01'::date, '2015-01-01 00:00:00'::timestamptz")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var record struct {
|
||||
id int32
|
||||
userName string
|
||||
email pgx.NullString
|
||||
name pgx.NullString
|
||||
sex pgx.NullString
|
||||
birthDate pgx.NullTime
|
||||
lastLoginTime pgx.NullTime
|
||||
}
|
||||
|
||||
err = conn.QueryRow("selectNulls").Scan(
|
||||
&record.id,
|
||||
&record.userName,
|
||||
&record.email,
|
||||
&record.name,
|
||||
&record.sex,
|
||||
&record.birthDate,
|
||||
&record.lastLoginTime,
|
||||
)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// These checks both ensure that the correct data was returned
|
||||
// and provide a benchmark of accessing the returned values.
|
||||
if record.id != 1 {
|
||||
b.Fatalf("bad value for id: %v", record.id)
|
||||
}
|
||||
if record.userName != "johnsmith" {
|
||||
b.Fatalf("bad value for userName: %v", record.userName)
|
||||
}
|
||||
if !record.email.Valid || record.email.String != "johnsmith@example.com" {
|
||||
b.Fatalf("bad value for email: %v", record.email)
|
||||
}
|
||||
if !record.name.Valid || record.name.String != "John Smith" {
|
||||
b.Fatalf("bad value for name: %v", record.name)
|
||||
}
|
||||
if !record.sex.Valid || record.sex.String != "male" {
|
||||
b.Fatalf("bad value for sex: %v", record.sex)
|
||||
}
|
||||
if !record.birthDate.Valid || record.birthDate.Time != time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local) {
|
||||
b.Fatalf("bad value for birthDate: %v", record.birthDate)
|
||||
}
|
||||
if !record.lastLoginTime.Valid || record.lastLoginTime.Time != time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local) {
|
||||
b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPointerPointerWithNullValues(b *testing.B) {
|
||||
conn := mustConnect(b, *defaultConnConfig)
|
||||
defer closeConn(b, conn)
|
||||
|
@ -475,12 +356,12 @@ func newBenchmarkWriteTableCopyToSrc(count int) pgx.CopyToSource {
|
|||
row: []interface{}{
|
||||
"varchar_1",
|
||||
"varchar_2",
|
||||
pgx.NullString{},
|
||||
pgtype.Text{},
|
||||
time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local),
|
||||
pgx.NullTime{},
|
||||
pgtype.Date{},
|
||||
1,
|
||||
2,
|
||||
pgx.NullInt32{},
|
||||
pgtype.Int4{},
|
||||
time.Date(2001, 1, 1, 0, 0, 0, 0, time.Local),
|
||||
time.Date(2002, 1, 1, 0, 0, 0, 0, time.Local),
|
||||
true,
|
||||
|
|
4
conn.go
4
conn.go
|
@ -1003,9 +1003,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||
|
||||
wbuf.WriteInt16(int16(len(ps.ParameterOids)))
|
||||
for i, oid := range ps.ParameterOids {
|
||||
switch arg := arguments[i].(type) {
|
||||
case Encoder:
|
||||
wbuf.WriteInt16(arg.FormatCode())
|
||||
switch arguments[i].(type) {
|
||||
case pgtype.BinaryEncoder:
|
||||
wbuf.WriteInt16(BinaryFormatCode)
|
||||
case pgtype.TextEncoder:
|
||||
|
|
|
@ -1,78 +1,63 @@
|
|||
package pgx_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
var pointRegexp *regexp.Regexp = regexp.MustCompile(`^\((.*),(.*)\)$`)
|
||||
|
||||
// NullPoint represents a point that may be null.
|
||||
//
|
||||
// If Valid is false then the value is NULL.
|
||||
type NullPoint struct {
|
||||
X, Y float64 // Coordinates of point
|
||||
Valid bool // Valid is true if not NULL
|
||||
// Point represents a point that may be null.
|
||||
type Point struct {
|
||||
X, Y float64 // Coordinates of point
|
||||
Status pgtype.Status
|
||||
}
|
||||
|
||||
func (p *NullPoint) ScanPgx(vr *pgx.ValueReader) error {
|
||||
if vr.Type().DataTypeName != "point" {
|
||||
return pgx.SerializationError(fmt.Sprintf("NullPoint.Scan cannot decode %s (Oid %d)", vr.Type().DataTypeName, vr.Type().DataType))
|
||||
}
|
||||
|
||||
if vr.Len() == -1 {
|
||||
p.X, p.Y, p.Valid = 0, 0, false
|
||||
func (dst *Point) DecodeText(src []byte) error {
|
||||
if src == nil {
|
||||
*dst = Point{Status: pgtype.Null}
|
||||
return nil
|
||||
}
|
||||
|
||||
switch vr.Type().FormatCode {
|
||||
case pgx.TextFormatCode:
|
||||
s := vr.ReadString(vr.Len())
|
||||
match := pointRegexp.FindStringSubmatch(s)
|
||||
if match == nil {
|
||||
return pgx.SerializationError(fmt.Sprintf("Received invalid point: %v", s))
|
||||
}
|
||||
|
||||
var err error
|
||||
p.X, err = strconv.ParseFloat(match[1], 64)
|
||||
if err != nil {
|
||||
return pgx.SerializationError(fmt.Sprintf("Received invalid point: %v", s))
|
||||
}
|
||||
p.Y, err = strconv.ParseFloat(match[2], 64)
|
||||
if err != nil {
|
||||
return pgx.SerializationError(fmt.Sprintf("Received invalid point: %v", s))
|
||||
}
|
||||
case pgx.BinaryFormatCode:
|
||||
return errors.New("binary format not implemented")
|
||||
default:
|
||||
return fmt.Errorf("unknown format %v", vr.Type().FormatCode)
|
||||
s := string(src)
|
||||
match := pointRegexp.FindStringSubmatch(s)
|
||||
if match == nil {
|
||||
return fmt.Errorf("Received invalid point: %v", s)
|
||||
}
|
||||
|
||||
p.Valid = true
|
||||
return vr.Err()
|
||||
}
|
||||
|
||||
func (p NullPoint) FormatCode() int16 { return pgx.BinaryFormatCode }
|
||||
|
||||
func (p NullPoint) Encode(w *pgx.WriteBuf, oid pgx.Oid) error {
|
||||
if !p.Valid {
|
||||
w.WriteInt32(-1)
|
||||
return nil
|
||||
x, err := strconv.ParseFloat(match[1], 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Received invalid point: %v", s)
|
||||
}
|
||||
y, err := strconv.ParseFloat(match[2], 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Received invalid point: %v", s)
|
||||
}
|
||||
|
||||
s := fmt.Sprintf("point(%v,%v)", p.X, p.Y)
|
||||
w.WriteInt32(int32(len(s)))
|
||||
w.WriteBytes([]byte(s))
|
||||
*dst = Point{X: x, Y: y, Status: pgtype.Present}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p NullPoint) String() string {
|
||||
if p.Valid {
|
||||
func (src Point) EncodeText(w io.Writer) (bool, error) {
|
||||
switch src.Status {
|
||||
case pgtype.Null:
|
||||
return true, nil
|
||||
case pgtype.Undefined:
|
||||
return false, fmt.Errorf("undefined")
|
||||
}
|
||||
|
||||
_, err := io.WriteString(w, fmt.Sprintf("point(%v,%v)", src.X, src.Y))
|
||||
return false, err
|
||||
}
|
||||
|
||||
func (p Point) String() string {
|
||||
if p.Status == pgtype.Present {
|
||||
return fmt.Sprintf("%v, %v", p.X, p.Y)
|
||||
}
|
||||
return "null point"
|
||||
|
@ -85,7 +70,7 @@ func Example_CustomType() {
|
|||
return
|
||||
}
|
||||
|
||||
var p NullPoint
|
||||
var p Point
|
||||
err = conn.QueryRow("select null::point").Scan(&p)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
|
|
10
query.go
10
query.go
|
@ -211,16 +211,6 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
|
|||
*b = nil
|
||||
}
|
||||
}
|
||||
} else if s, ok := d.(Scanner); ok {
|
||||
err = s.Scan(vr)
|
||||
if err != nil {
|
||||
rows.Fatal(scanArgError{col: i, err: err})
|
||||
}
|
||||
} else if s, ok := d.(PgxScanner); ok {
|
||||
err = s.ScanPgx(vr)
|
||||
if err != nil {
|
||||
rows.Fatal(scanArgError{col: i, err: err})
|
||||
}
|
||||
} else if s, ok := d.(pgtype.BinaryDecoder); ok && vr.Type().FormatCode == BinaryFormatCode {
|
||||
err = s.DecodeBinary(vr.bytes())
|
||||
if err != nil {
|
||||
|
|
103
query_test.go
103
query_test.go
|
@ -270,44 +270,6 @@ func TestConnQueryScanIgnoreColumn(t *testing.T) {
|
|||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnQueryScanner(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
rows, err := conn.Query("select null::int8, 1::int8")
|
||||
if err != nil {
|
||||
t.Fatalf("conn.Query failed: %v", err)
|
||||
}
|
||||
|
||||
ok := rows.Next()
|
||||
if !ok {
|
||||
t.Fatal("rows.Next terminated early")
|
||||
}
|
||||
|
||||
var n, m pgx.NullInt64
|
||||
err = rows.Scan(&n, &m)
|
||||
if err != nil {
|
||||
t.Fatalf("rows.Scan failed: %v", err)
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
if n.Valid {
|
||||
t.Error("Null should not be valid, but it was")
|
||||
}
|
||||
|
||||
if !m.Valid {
|
||||
t.Error("1 should be valid, but it wasn't")
|
||||
}
|
||||
|
||||
if m.Int64 != 1 {
|
||||
t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnQueryErrorWhileReturningRows(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -339,42 +301,6 @@ func TestConnQueryErrorWhileReturningRows(t *testing.T) {
|
|||
|
||||
}
|
||||
|
||||
func TestConnQueryEncoder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
n := pgx.NullInt64{Int64: 1, Valid: true}
|
||||
|
||||
rows, err := conn.Query("select $1::int8", &n)
|
||||
if err != nil {
|
||||
t.Fatalf("conn.Query failed: %v", err)
|
||||
}
|
||||
|
||||
ok := rows.Next()
|
||||
if !ok {
|
||||
t.Fatal("rows.Next terminated early")
|
||||
}
|
||||
|
||||
var m pgx.NullInt64
|
||||
err = rows.Scan(&m)
|
||||
if err != nil {
|
||||
t.Fatalf("rows.Scan failed: %v", err)
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
if !m.Valid {
|
||||
t.Error("m should be valid, but it wasn't")
|
||||
}
|
||||
|
||||
if m.Int64 != 1 {
|
||||
t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryEncodeError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -397,35 +323,6 @@ func TestQueryEncodeError(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// Ensure that an argument that implements Encoder works when the parameter type
|
||||
// is a core type.
|
||||
type coreEncoder struct{}
|
||||
|
||||
func (n coreEncoder) FormatCode() int16 { return pgx.TextFormatCode }
|
||||
|
||||
func (n *coreEncoder) Encode(w *pgx.WriteBuf, oid pgx.Oid) error {
|
||||
w.WriteInt32(int32(2))
|
||||
w.WriteBytes([]byte("42"))
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestQueryEncodeCoreTextFormatError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
var n int32
|
||||
err := conn.QueryRow("select $1::integer", &coreEncoder{}).Scan(&n)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected conn.QueryRow error: %v", err)
|
||||
}
|
||||
|
||||
if n != 42 {
|
||||
t.Errorf("Expected 42, got %v", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryRowCoreTypes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
375
values.go
375
values.go
|
@ -159,245 +159,6 @@ func (e SerializationError) Error() string {
|
|||
return string(e)
|
||||
}
|
||||
|
||||
// Deprecated: Scanner is an interface used to decode values from the PostgreSQL
|
||||
// server. To allow types to support pgx and database/sql.Scan this interface
|
||||
// has been deprecated in favor of PgxScanner.
|
||||
type Scanner interface {
|
||||
// Scan MUST check r.Type().DataType (to check by Oid) or
|
||||
// r.Type().DataTypeName (to check by name) to ensure that it is scanning an
|
||||
// expected column type. It also MUST check r.Type().FormatCode before
|
||||
// decoding. It should not assume that it was called on a data type or format
|
||||
// that it understands.
|
||||
Scan(r *ValueReader) error
|
||||
}
|
||||
|
||||
// PgxScanner is an interface used to decode values from the PostgreSQL server.
|
||||
// It is used exactly the same as the Scanner interface. It simply has renamed
|
||||
// the method.
|
||||
type PgxScanner interface {
|
||||
// ScanPgx MUST check r.Type().DataType (to check by Oid) or
|
||||
// r.Type().DataTypeName (to check by name) to ensure that it is scanning an
|
||||
// expected column type. It also MUST check r.Type().FormatCode before
|
||||
// decoding. It should not assume that it was called on a data type or format
|
||||
// that it understands.
|
||||
ScanPgx(r *ValueReader) error
|
||||
}
|
||||
|
||||
// Encoder is an interface used to encode values for transmission to the
|
||||
// PostgreSQL server.
|
||||
type Encoder interface {
|
||||
// Encode writes the value to w.
|
||||
//
|
||||
// If the value is NULL an int32(-1) should be written.
|
||||
//
|
||||
// Encode MUST check oid to see if the parameter data type is compatible. If
|
||||
// this is not done, the PostgreSQL server may detect the error if the
|
||||
// expected data size or format of the encoded data does not match. But if
|
||||
// the encoded data is a valid representation of the data type PostgreSQL
|
||||
// expects such as date and int4, incorrect data may be stored.
|
||||
Encode(w *WriteBuf, oid Oid) error
|
||||
|
||||
// FormatCode returns the format that the encoder writes the value. It must be
|
||||
// either pgx.TextFormatCode or pgx.BinaryFormatCode.
|
||||
FormatCode() int16
|
||||
}
|
||||
|
||||
// NullFloat32 represents an float4 that may be null. NullFloat32 implements the
|
||||
// Scanner and Encoder interfaces so it may be used both as an argument to
|
||||
// Query[Row] and a destination for Scan.
|
||||
//
|
||||
// If Valid is false then the value is NULL.
|
||||
type NullFloat32 struct {
|
||||
Float32 float32
|
||||
Valid bool // Valid is true if Float32 is not NULL
|
||||
}
|
||||
|
||||
func (n *NullFloat32) Scan(vr *ValueReader) error {
|
||||
if vr.Type().DataType != Float4Oid {
|
||||
return SerializationError(fmt.Sprintf("NullFloat32.Scan cannot decode Oid %d", vr.Type().DataType))
|
||||
}
|
||||
|
||||
if vr.Len() == -1 {
|
||||
n.Float32, n.Valid = 0, false
|
||||
return nil
|
||||
}
|
||||
n.Valid = true
|
||||
n.Float32 = decodeFloat4(vr)
|
||||
return vr.Err()
|
||||
}
|
||||
|
||||
func (n NullFloat32) FormatCode() int16 { return BinaryFormatCode }
|
||||
|
||||
func (n NullFloat32) Encode(w *WriteBuf, oid Oid) error {
|
||||
if oid != Float4Oid {
|
||||
return SerializationError(fmt.Sprintf("NullFloat32.Encode cannot encode into Oid %d", oid))
|
||||
}
|
||||
|
||||
if !n.Valid {
|
||||
w.WriteInt32(-1)
|
||||
return nil
|
||||
}
|
||||
|
||||
return encodeFloat32(w, oid, n.Float32)
|
||||
}
|
||||
|
||||
// NullFloat64 represents an float8 that may be null. NullFloat64 implements the
|
||||
// Scanner and Encoder interfaces so it may be used both as an argument to
|
||||
// Query[Row] and a destination for Scan.
|
||||
//
|
||||
// If Valid is false then the value is NULL.
|
||||
type NullFloat64 struct {
|
||||
Float64 float64
|
||||
Valid bool // Valid is true if Float64 is not NULL
|
||||
}
|
||||
|
||||
func (n *NullFloat64) Scan(vr *ValueReader) error {
|
||||
if vr.Type().DataType != Float8Oid {
|
||||
return SerializationError(fmt.Sprintf("NullFloat64.Scan cannot decode Oid %d", vr.Type().DataType))
|
||||
}
|
||||
|
||||
if vr.Len() == -1 {
|
||||
n.Float64, n.Valid = 0, false
|
||||
return nil
|
||||
}
|
||||
n.Valid = true
|
||||
n.Float64 = decodeFloat8(vr)
|
||||
return vr.Err()
|
||||
}
|
||||
|
||||
func (n NullFloat64) FormatCode() int16 { return BinaryFormatCode }
|
||||
|
||||
func (n NullFloat64) Encode(w *WriteBuf, oid Oid) error {
|
||||
if oid != Float8Oid {
|
||||
return SerializationError(fmt.Sprintf("NullFloat64.Encode cannot encode into Oid %d", oid))
|
||||
}
|
||||
|
||||
if !n.Valid {
|
||||
w.WriteInt32(-1)
|
||||
return nil
|
||||
}
|
||||
|
||||
return encodeFloat64(w, oid, n.Float64)
|
||||
}
|
||||
|
||||
// NullString represents an string that may be null. NullString implements the
|
||||
// Scanner Encoder interfaces so it may be used both as an argument to
|
||||
// Query[Row] and a destination for Scan.
|
||||
//
|
||||
// If Valid is false then the value is NULL.
|
||||
type NullString struct {
|
||||
String string
|
||||
Valid bool // Valid is true if String is not NULL
|
||||
}
|
||||
|
||||
func (n *NullString) Scan(vr *ValueReader) error {
|
||||
// Not checking oid as so we can scan anything into into a NullString - may revisit this decision later
|
||||
|
||||
if vr.Len() == -1 {
|
||||
n.String, n.Valid = "", false
|
||||
return nil
|
||||
}
|
||||
|
||||
n.Valid = true
|
||||
n.String = decodeText(vr)
|
||||
return vr.Err()
|
||||
}
|
||||
|
||||
func (n NullString) FormatCode() int16 { return TextFormatCode }
|
||||
|
||||
func (s NullString) Encode(w *WriteBuf, oid Oid) error {
|
||||
if !s.Valid {
|
||||
w.WriteInt32(-1)
|
||||
return nil
|
||||
}
|
||||
|
||||
return encodeString(w, oid, s.String)
|
||||
}
|
||||
|
||||
// NullInt16 represents a smallint that may be null. NullInt16 implements the
|
||||
// Scanner and Encoder interfaces so it may be used both as an argument to
|
||||
// Query[Row] and a destination for Scan for prepared and unprepared queries.
|
||||
//
|
||||
// If Valid is false then the value is NULL.
|
||||
type NullInt16 struct {
|
||||
Int16 int16
|
||||
Valid bool // Valid is true if Int16 is not NULL
|
||||
}
|
||||
|
||||
func (n *NullInt16) Scan(vr *ValueReader) error {
|
||||
if vr.Type().DataType != Int2Oid {
|
||||
return SerializationError(fmt.Sprintf("NullInt16.Scan cannot decode Oid %d", vr.Type().DataType))
|
||||
}
|
||||
|
||||
if vr.Len() == -1 {
|
||||
n.Int16, n.Valid = 0, false
|
||||
return nil
|
||||
}
|
||||
n.Valid = true
|
||||
n.Int16 = decodeInt2(vr)
|
||||
return vr.Err()
|
||||
}
|
||||
|
||||
func (n NullInt16) FormatCode() int16 { return BinaryFormatCode }
|
||||
|
||||
func (n NullInt16) Encode(w *WriteBuf, oid Oid) error {
|
||||
if oid != Int2Oid {
|
||||
return SerializationError(fmt.Sprintf("NullInt16.Encode cannot encode into Oid %d", oid))
|
||||
}
|
||||
|
||||
if !n.Valid {
|
||||
w.WriteInt32(-1)
|
||||
return nil
|
||||
}
|
||||
|
||||
w.WriteInt32(2)
|
||||
|
||||
_, err := pgtype.Int2{Int: n.Int16, Status: pgtype.Present}.EncodeBinary(w)
|
||||
return err
|
||||
}
|
||||
|
||||
// NullInt32 represents an integer that may be null. NullInt32 implements the
|
||||
// Scanner and Encoder interfaces so it may be used both as an argument to
|
||||
// Query[Row] and a destination for Scan.
|
||||
//
|
||||
// If Valid is false then the value is NULL.
|
||||
type NullInt32 struct {
|
||||
Int32 int32
|
||||
Valid bool // Valid is true if Int32 is not NULL
|
||||
}
|
||||
|
||||
func (n *NullInt32) Scan(vr *ValueReader) error {
|
||||
if vr.Type().DataType != Int4Oid {
|
||||
return SerializationError(fmt.Sprintf("NullInt32.Scan cannot decode Oid %d", vr.Type().DataType))
|
||||
}
|
||||
|
||||
if vr.Len() == -1 {
|
||||
n.Int32, n.Valid = 0, false
|
||||
return nil
|
||||
}
|
||||
n.Valid = true
|
||||
n.Int32 = decodeInt4(vr)
|
||||
return vr.Err()
|
||||
}
|
||||
|
||||
func (n NullInt32) FormatCode() int16 { return BinaryFormatCode }
|
||||
|
||||
func (n NullInt32) Encode(w *WriteBuf, oid Oid) error {
|
||||
if oid != Int4Oid {
|
||||
return SerializationError(fmt.Sprintf("NullInt32.Encode cannot encode into Oid %d", oid))
|
||||
}
|
||||
|
||||
if !n.Valid {
|
||||
w.WriteInt32(-1)
|
||||
return nil
|
||||
}
|
||||
|
||||
w.WriteInt32(4)
|
||||
|
||||
_, err := pgtype.Int4{Int: n.Int32, Status: pgtype.Present}.EncodeBinary(w)
|
||||
return err
|
||||
}
|
||||
|
||||
// Oid (Object Identifier Type) is, according to https://www.postgresql.org/docs/current/static/datatype-oid.html,
|
||||
// used internally by PostgreSQL as a primary key for various system tables. It is currently implemented
|
||||
// as an unsigned four-byte integer. Its definition can be found in src/include/postgres_ext.h
|
||||
|
@ -442,140 +203,6 @@ func (src Oid) EncodeBinary(w io.Writer) (bool, error) {
|
|||
return false, err
|
||||
}
|
||||
|
||||
// NullInt64 represents an bigint that may be null. NullInt64 implements the
|
||||
// Scanner and Encoder interfaces so it may be used both as an argument to
|
||||
// Query[Row] and a destination for Scan.
|
||||
//
|
||||
// If Valid is false then the value is NULL.
|
||||
type NullInt64 struct {
|
||||
Int64 int64
|
||||
Valid bool // Valid is true if Int64 is not NULL
|
||||
}
|
||||
|
||||
func (n *NullInt64) Scan(vr *ValueReader) error {
|
||||
if vr.Type().DataType != Int8Oid {
|
||||
return SerializationError(fmt.Sprintf("NullInt64.Scan cannot decode Oid %d", vr.Type().DataType))
|
||||
}
|
||||
|
||||
if vr.Len() == -1 {
|
||||
n.Int64, n.Valid = 0, false
|
||||
return nil
|
||||
}
|
||||
n.Valid = true
|
||||
n.Int64 = decodeInt8(vr)
|
||||
return vr.Err()
|
||||
}
|
||||
|
||||
func (n NullInt64) FormatCode() int16 { return BinaryFormatCode }
|
||||
|
||||
func (n NullInt64) Encode(w *WriteBuf, oid Oid) error {
|
||||
if oid != Int8Oid {
|
||||
return SerializationError(fmt.Sprintf("NullInt64.Encode cannot encode into Oid %d", oid))
|
||||
}
|
||||
|
||||
if !n.Valid {
|
||||
w.WriteInt32(-1)
|
||||
return nil
|
||||
}
|
||||
|
||||
w.WriteInt32(8)
|
||||
|
||||
_, err := pgtype.Int8{Int: n.Int64, Status: pgtype.Present}.EncodeBinary(w)
|
||||
return err
|
||||
}
|
||||
|
||||
// NullBool represents an bool that may be null. NullBool implements the Scanner
|
||||
// and Encoder interfaces so it may be used both as an argument to Query[Row]
|
||||
// and a destination for Scan.
|
||||
//
|
||||
// If Valid is false then the value is NULL.
|
||||
type NullBool struct {
|
||||
Bool bool
|
||||
Valid bool // Valid is true if Bool is not NULL
|
||||
}
|
||||
|
||||
func (n *NullBool) Scan(vr *ValueReader) error {
|
||||
if vr.Type().DataType != BoolOid {
|
||||
return SerializationError(fmt.Sprintf("NullBool.Scan cannot decode Oid %d", vr.Type().DataType))
|
||||
}
|
||||
|
||||
if vr.Len() == -1 {
|
||||
n.Bool, n.Valid = false, false
|
||||
return nil
|
||||
}
|
||||
n.Valid = true
|
||||
n.Bool = decodeBool(vr)
|
||||
return vr.Err()
|
||||
}
|
||||
|
||||
func (n NullBool) FormatCode() int16 { return BinaryFormatCode }
|
||||
|
||||
func (n NullBool) Encode(w *WriteBuf, oid Oid) error {
|
||||
if oid != BoolOid {
|
||||
return SerializationError(fmt.Sprintf("NullBool.Encode cannot encode into Oid %d", oid))
|
||||
}
|
||||
|
||||
if !n.Valid {
|
||||
w.WriteInt32(-1)
|
||||
return nil
|
||||
}
|
||||
|
||||
w.WriteInt32(1)
|
||||
|
||||
_, err := pgtype.Bool{Bool: n.Bool, Status: pgtype.Present}.EncodeBinary(w)
|
||||
return err
|
||||
}
|
||||
|
||||
// NullTime represents an time.Time that may be null. NullTime implements the
|
||||
// Scanner and Encoder interfaces so it may be used both as an argument to
|
||||
// Query[Row] and a destination for Scan. It corresponds with the PostgreSQL
|
||||
// types timestamptz, timestamp, and date.
|
||||
//
|
||||
// If Valid is false then the value is NULL.
|
||||
type NullTime struct {
|
||||
Time time.Time
|
||||
Valid bool // Valid is true if Time is not NULL
|
||||
}
|
||||
|
||||
func (n *NullTime) Scan(vr *ValueReader) error {
|
||||
oid := vr.Type().DataType
|
||||
if oid != TimestampTzOid && oid != TimestampOid && oid != DateOid {
|
||||
return SerializationError(fmt.Sprintf("NullTime.Scan cannot decode Oid %d", vr.Type().DataType))
|
||||
}
|
||||
|
||||
if vr.Len() == -1 {
|
||||
n.Time, n.Valid = time.Time{}, false
|
||||
return nil
|
||||
}
|
||||
|
||||
n.Valid = true
|
||||
switch oid {
|
||||
case TimestampTzOid:
|
||||
n.Time = decodeTimestampTz(vr)
|
||||
case TimestampOid:
|
||||
n.Time = decodeTimestamp(vr)
|
||||
case DateOid:
|
||||
n.Time = decodeDate(vr)
|
||||
}
|
||||
|
||||
return vr.Err()
|
||||
}
|
||||
|
||||
func (n NullTime) FormatCode() int16 { return BinaryFormatCode }
|
||||
|
||||
func (n NullTime) Encode(w *WriteBuf, oid Oid) error {
|
||||
if oid != TimestampTzOid && oid != TimestampOid && oid != DateOid {
|
||||
return SerializationError(fmt.Sprintf("NullTime.Encode cannot encode into Oid %d", oid))
|
||||
}
|
||||
|
||||
if !n.Valid {
|
||||
w.WriteInt32(-1)
|
||||
return nil
|
||||
}
|
||||
|
||||
return encodeTime(w, oid, n.Time)
|
||||
}
|
||||
|
||||
// Encode encodes arg into wbuf as the type oid. This allows implementations
|
||||
// of the Encoder interface to delegate the actual work of encoding to the
|
||||
// built-in functionality.
|
||||
|
@ -586,8 +213,6 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error {
|
|||
}
|
||||
|
||||
switch arg := arg.(type) {
|
||||
case Encoder:
|
||||
return arg.Encode(wbuf, oid)
|
||||
case pgtype.BinaryEncoder:
|
||||
buf := &bytes.Buffer{}
|
||||
null, err := arg.EncodeBinary(buf)
|
||||
|
|
142
values_test.go
142
values_test.go
|
@ -4,7 +4,6 @@ import (
|
|||
"bytes"
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -558,70 +557,6 @@ func TestInetCidrTranscodeWithJustIP(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestNullX(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
type allTypes struct {
|
||||
s pgx.NullString
|
||||
i16 pgx.NullInt16
|
||||
i32 pgx.NullInt32
|
||||
i64 pgx.NullInt64
|
||||
f32 pgx.NullFloat32
|
||||
f64 pgx.NullFloat64
|
||||
b pgx.NullBool
|
||||
t pgx.NullTime
|
||||
}
|
||||
|
||||
var actual, zero allTypes
|
||||
|
||||
tests := []struct {
|
||||
sql string
|
||||
queryArgs []interface{}
|
||||
scanArgs []interface{}
|
||||
expected allTypes
|
||||
}{
|
||||
{"select $1::text", []interface{}{pgx.NullString{String: "foo", Valid: true}}, []interface{}{&actual.s}, allTypes{s: pgx.NullString{String: "foo", Valid: true}}},
|
||||
{"select $1::text", []interface{}{pgx.NullString{String: "foo", Valid: false}}, []interface{}{&actual.s}, allTypes{s: pgx.NullString{String: "", Valid: false}}},
|
||||
{"select $1::int2", []interface{}{pgx.NullInt16{Int16: 1, Valid: true}}, []interface{}{&actual.i16}, allTypes{i16: pgx.NullInt16{Int16: 1, Valid: true}}},
|
||||
{"select $1::int2", []interface{}{pgx.NullInt16{Int16: 1, Valid: false}}, []interface{}{&actual.i16}, allTypes{i16: pgx.NullInt16{Int16: 0, Valid: false}}},
|
||||
{"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 1, Valid: true}}},
|
||||
{"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: false}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 0, Valid: false}}},
|
||||
{"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 1, Valid: true}}},
|
||||
{"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: false}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 0, Valid: false}}},
|
||||
{"select $1::float4", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: true}}, []interface{}{&actual.f32}, allTypes{f32: pgx.NullFloat32{Float32: 1.23, Valid: true}}},
|
||||
{"select $1::float4", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: false}}, []interface{}{&actual.f32}, allTypes{f32: pgx.NullFloat32{Float32: 0, Valid: false}}},
|
||||
{"select $1::float8", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.f64}, allTypes{f64: pgx.NullFloat64{Float64: 1.23, Valid: true}}},
|
||||
{"select $1::float8", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: false}}, []interface{}{&actual.f64}, allTypes{f64: pgx.NullFloat64{Float64: 0, Valid: false}}},
|
||||
{"select $1::bool", []interface{}{pgx.NullBool{Bool: true, Valid: true}}, []interface{}{&actual.b}, allTypes{b: pgx.NullBool{Bool: true, Valid: true}}},
|
||||
{"select $1::bool", []interface{}{pgx.NullBool{Bool: true, Valid: false}}, []interface{}{&actual.b}, allTypes{b: pgx.NullBool{Bool: false, Valid: false}}},
|
||||
{"select $1::timestamptz", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}},
|
||||
{"select $1::timestamptz", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}},
|
||||
{"select $1::timestamp", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}},
|
||||
{"select $1::timestamp", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}},
|
||||
{"select $1::date", []interface{}{pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}},
|
||||
{"select $1::date", []interface{}{pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}},
|
||||
{"select 42::int4, $1::float8", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.i32, &actual.f64}, allTypes{i32: pgx.NullInt32{Int32: 42, Valid: true}, f64: pgx.NullFloat64{Float64: 1.23, Valid: true}}},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
actual = zero
|
||||
|
||||
err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...)
|
||||
if err != nil {
|
||||
t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, tt.sql, tt.queryArgs)
|
||||
}
|
||||
|
||||
if actual != tt.expected {
|
||||
t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", i, tt.expected, actual, tt.sql, tt.queryArgs)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestArrayDecoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -736,36 +671,6 @@ func TestArrayDecoding(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
type shortScanner struct{}
|
||||
|
||||
func (*shortScanner) Scan(r *pgx.ValueReader) error {
|
||||
r.ReadByte()
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestShortScanner(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
rows, err := conn.Query("select 'ab', 'cd' union select 'cd', 'ef'")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var s1, s2 shortScanner
|
||||
err = rows.Scan(&s1, &s2)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestEmptyArrayDecoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -814,53 +719,6 @@ func TestEmptyArrayDecoding(t *testing.T) {
|
|||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestNullXMismatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
type allTypes struct {
|
||||
s pgx.NullString
|
||||
i16 pgx.NullInt16
|
||||
i32 pgx.NullInt32
|
||||
i64 pgx.NullInt64
|
||||
f32 pgx.NullFloat32
|
||||
f64 pgx.NullFloat64
|
||||
b pgx.NullBool
|
||||
t pgx.NullTime
|
||||
}
|
||||
|
||||
var actual, zero allTypes
|
||||
|
||||
tests := []struct {
|
||||
sql string
|
||||
queryArgs []interface{}
|
||||
scanArgs []interface{}
|
||||
err string
|
||||
}{
|
||||
{"select $1::date", []interface{}{pgx.NullString{String: "foo", Valid: true}}, []interface{}{&actual.s}, "invalid input syntax for type date"},
|
||||
{"select $1::date", []interface{}{pgx.NullInt16{Int16: 1, Valid: true}}, []interface{}{&actual.i16}, "cannot encode into Oid 1082"},
|
||||
{"select $1::date", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, "cannot encode into Oid 1082"},
|
||||
{"select $1::date", []interface{}{pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, "cannot encode into Oid 1082"},
|
||||
{"select $1::date", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: true}}, []interface{}{&actual.f32}, "cannot encode into Oid 1082"},
|
||||
{"select $1::date", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.f64}, "cannot encode into Oid 1082"},
|
||||
{"select $1::date", []interface{}{pgx.NullBool{Bool: true, Valid: true}}, []interface{}{&actual.b}, "cannot encode into Oid 1082"},
|
||||
{"select $1::int4", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, "cannot encode into Oid 23"},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
actual = zero
|
||||
|
||||
err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...)
|
||||
if err == nil || !strings.Contains(err.Error(), tt.err) {
|
||||
t.Errorf(`%d. Expected error to contain "%s", but it didn't: %v`, i, tt.err, err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPointerPointer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
Loading…
Reference in New Issue