diff --git a/bench_test.go b/bench_test.go index b08c2b4e..348c840c 100644 --- a/bench_test.go +++ b/bench_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" ) func BenchmarkConnPool(b *testing.B) { @@ -49,126 +50,6 @@ func BenchmarkConnPoolQueryRow(b *testing.B) { } } -func BenchmarkNullXWithNullValues(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - - _, err := conn.Prepare("selectNulls", "select 1::int4, 'johnsmith', null::text, null::text, null::text, null::date, null::timestamptz") - if err != nil { - b.Fatal(err) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - var record struct { - id int32 - userName string - email pgx.NullString - name pgx.NullString - sex pgx.NullString - birthDate pgx.NullTime - lastLoginTime pgx.NullTime - } - - err = conn.QueryRow("selectNulls").Scan( - &record.id, - &record.userName, - &record.email, - &record.name, - &record.sex, - &record.birthDate, - &record.lastLoginTime, - ) - if err != nil { - b.Fatal(err) - } - - // These checks both ensure that the correct data was returned - // and provide a benchmark of accessing the returned values. - if record.id != 1 { - b.Fatalf("bad value for id: %v", record.id) - } - if record.userName != "johnsmith" { - b.Fatalf("bad value for userName: %v", record.userName) - } - if record.email.Valid { - b.Fatalf("bad value for email: %v", record.email) - } - if record.name.Valid { - b.Fatalf("bad value for name: %v", record.name) - } - if record.sex.Valid { - b.Fatalf("bad value for sex: %v", record.sex) - } - if record.birthDate.Valid { - b.Fatalf("bad value for birthDate: %v", record.birthDate) - } - if record.lastLoginTime.Valid { - b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime) - } - } -} - -func BenchmarkNullXWithPresentValues(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - - _, err := conn.Prepare("selectNulls", "select 1::int4, 'johnsmith', 'johnsmith@example.com', 'John Smith', 'male', '1970-01-01'::date, '2015-01-01 00:00:00'::timestamptz") - if err != nil { - b.Fatal(err) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - var record struct { - id int32 - userName string - email pgx.NullString - name pgx.NullString - sex pgx.NullString - birthDate pgx.NullTime - lastLoginTime pgx.NullTime - } - - err = conn.QueryRow("selectNulls").Scan( - &record.id, - &record.userName, - &record.email, - &record.name, - &record.sex, - &record.birthDate, - &record.lastLoginTime, - ) - if err != nil { - b.Fatal(err) - } - - // These checks both ensure that the correct data was returned - // and provide a benchmark of accessing the returned values. - if record.id != 1 { - b.Fatalf("bad value for id: %v", record.id) - } - if record.userName != "johnsmith" { - b.Fatalf("bad value for userName: %v", record.userName) - } - if !record.email.Valid || record.email.String != "johnsmith@example.com" { - b.Fatalf("bad value for email: %v", record.email) - } - if !record.name.Valid || record.name.String != "John Smith" { - b.Fatalf("bad value for name: %v", record.name) - } - if !record.sex.Valid || record.sex.String != "male" { - b.Fatalf("bad value for sex: %v", record.sex) - } - if !record.birthDate.Valid || record.birthDate.Time != time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local) { - b.Fatalf("bad value for birthDate: %v", record.birthDate) - } - if !record.lastLoginTime.Valid || record.lastLoginTime.Time != time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local) { - b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime) - } - } -} - func BenchmarkPointerPointerWithNullValues(b *testing.B) { conn := mustConnect(b, *defaultConnConfig) defer closeConn(b, conn) @@ -475,12 +356,12 @@ func newBenchmarkWriteTableCopyToSrc(count int) pgx.CopyToSource { row: []interface{}{ "varchar_1", "varchar_2", - pgx.NullString{}, + pgtype.Text{}, time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), - pgx.NullTime{}, + pgtype.Date{}, 1, 2, - pgx.NullInt32{}, + pgtype.Int4{}, time.Date(2001, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2002, 1, 1, 0, 0, 0, 0, time.Local), true, diff --git a/conn.go b/conn.go index d541e942..ae83fc77 100644 --- a/conn.go +++ b/conn.go @@ -1003,9 +1003,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteInt16(int16(len(ps.ParameterOids))) for i, oid := range ps.ParameterOids { - switch arg := arguments[i].(type) { - case Encoder: - wbuf.WriteInt16(arg.FormatCode()) + switch arguments[i].(type) { case pgtype.BinaryEncoder: wbuf.WriteInt16(BinaryFormatCode) case pgtype.TextEncoder: diff --git a/example_custom_type_test.go b/example_custom_type_test.go index 74fbab67..71110f85 100644 --- a/example_custom_type_test.go +++ b/example_custom_type_test.go @@ -1,78 +1,63 @@ package pgx_test import ( - "errors" "fmt" + "io" "regexp" "strconv" "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" ) var pointRegexp *regexp.Regexp = regexp.MustCompile(`^\((.*),(.*)\)$`) -// NullPoint represents a point that may be null. -// -// If Valid is false then the value is NULL. -type NullPoint struct { - X, Y float64 // Coordinates of point - Valid bool // Valid is true if not NULL +// Point represents a point that may be null. +type Point struct { + X, Y float64 // Coordinates of point + Status pgtype.Status } -func (p *NullPoint) ScanPgx(vr *pgx.ValueReader) error { - if vr.Type().DataTypeName != "point" { - return pgx.SerializationError(fmt.Sprintf("NullPoint.Scan cannot decode %s (Oid %d)", vr.Type().DataTypeName, vr.Type().DataType)) - } - - if vr.Len() == -1 { - p.X, p.Y, p.Valid = 0, 0, false +func (dst *Point) DecodeText(src []byte) error { + if src == nil { + *dst = Point{Status: pgtype.Null} return nil } - switch vr.Type().FormatCode { - case pgx.TextFormatCode: - s := vr.ReadString(vr.Len()) - match := pointRegexp.FindStringSubmatch(s) - if match == nil { - return pgx.SerializationError(fmt.Sprintf("Received invalid point: %v", s)) - } - - var err error - p.X, err = strconv.ParseFloat(match[1], 64) - if err != nil { - return pgx.SerializationError(fmt.Sprintf("Received invalid point: %v", s)) - } - p.Y, err = strconv.ParseFloat(match[2], 64) - if err != nil { - return pgx.SerializationError(fmt.Sprintf("Received invalid point: %v", s)) - } - case pgx.BinaryFormatCode: - return errors.New("binary format not implemented") - default: - return fmt.Errorf("unknown format %v", vr.Type().FormatCode) + s := string(src) + match := pointRegexp.FindStringSubmatch(s) + if match == nil { + return fmt.Errorf("Received invalid point: %v", s) } - p.Valid = true - return vr.Err() -} - -func (p NullPoint) FormatCode() int16 { return pgx.BinaryFormatCode } - -func (p NullPoint) Encode(w *pgx.WriteBuf, oid pgx.Oid) error { - if !p.Valid { - w.WriteInt32(-1) - return nil + x, err := strconv.ParseFloat(match[1], 64) + if err != nil { + return fmt.Errorf("Received invalid point: %v", s) + } + y, err := strconv.ParseFloat(match[2], 64) + if err != nil { + return fmt.Errorf("Received invalid point: %v", s) } - s := fmt.Sprintf("point(%v,%v)", p.X, p.Y) - w.WriteInt32(int32(len(s))) - w.WriteBytes([]byte(s)) + *dst = Point{X: x, Y: y, Status: pgtype.Present} return nil } -func (p NullPoint) String() string { - if p.Valid { +func (src Point) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case pgtype.Null: + return true, nil + case pgtype.Undefined: + return false, fmt.Errorf("undefined") + } + + _, err := io.WriteString(w, fmt.Sprintf("point(%v,%v)", src.X, src.Y)) + return false, err +} + +func (p Point) String() string { + if p.Status == pgtype.Present { return fmt.Sprintf("%v, %v", p.X, p.Y) } return "null point" @@ -85,7 +70,7 @@ func Example_CustomType() { return } - var p NullPoint + var p Point err = conn.QueryRow("select null::point").Scan(&p) if err != nil { fmt.Println(err) diff --git a/query.go b/query.go index d8caa08d..63ce91ed 100644 --- a/query.go +++ b/query.go @@ -211,16 +211,6 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { *b = nil } } - } else if s, ok := d.(Scanner); ok { - err = s.Scan(vr) - if err != nil { - rows.Fatal(scanArgError{col: i, err: err}) - } - } else if s, ok := d.(PgxScanner); ok { - err = s.ScanPgx(vr) - if err != nil { - rows.Fatal(scanArgError{col: i, err: err}) - } } else if s, ok := d.(pgtype.BinaryDecoder); ok && vr.Type().FormatCode == BinaryFormatCode { err = s.DecodeBinary(vr.bytes()) if err != nil { diff --git a/query_test.go b/query_test.go index 46b012cf..8838329c 100644 --- a/query_test.go +++ b/query_test.go @@ -270,44 +270,6 @@ func TestConnQueryScanIgnoreColumn(t *testing.T) { ensureConnValid(t, conn) } -func TestConnQueryScanner(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - rows, err := conn.Query("select null::int8, 1::int8") - if err != nil { - t.Fatalf("conn.Query failed: %v", err) - } - - ok := rows.Next() - if !ok { - t.Fatal("rows.Next terminated early") - } - - var n, m pgx.NullInt64 - err = rows.Scan(&n, &m) - if err != nil { - t.Fatalf("rows.Scan failed: %v", err) - } - rows.Close() - - if n.Valid { - t.Error("Null should not be valid, but it was") - } - - if !m.Valid { - t.Error("1 should be valid, but it wasn't") - } - - if m.Int64 != 1 { - t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64) - } - - ensureConnValid(t, conn) -} - func TestConnQueryErrorWhileReturningRows(t *testing.T) { t.Parallel() @@ -339,42 +301,6 @@ func TestConnQueryErrorWhileReturningRows(t *testing.T) { } -func TestConnQueryEncoder(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - n := pgx.NullInt64{Int64: 1, Valid: true} - - rows, err := conn.Query("select $1::int8", &n) - if err != nil { - t.Fatalf("conn.Query failed: %v", err) - } - - ok := rows.Next() - if !ok { - t.Fatal("rows.Next terminated early") - } - - var m pgx.NullInt64 - err = rows.Scan(&m) - if err != nil { - t.Fatalf("rows.Scan failed: %v", err) - } - rows.Close() - - if !m.Valid { - t.Error("m should be valid, but it wasn't") - } - - if m.Int64 != 1 { - t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64) - } - - ensureConnValid(t, conn) -} - func TestQueryEncodeError(t *testing.T) { t.Parallel() @@ -397,35 +323,6 @@ func TestQueryEncodeError(t *testing.T) { } } -// Ensure that an argument that implements Encoder works when the parameter type -// is a core type. -type coreEncoder struct{} - -func (n coreEncoder) FormatCode() int16 { return pgx.TextFormatCode } - -func (n *coreEncoder) Encode(w *pgx.WriteBuf, oid pgx.Oid) error { - w.WriteInt32(int32(2)) - w.WriteBytes([]byte("42")) - return nil -} - -func TestQueryEncodeCoreTextFormatError(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - var n int32 - err := conn.QueryRow("select $1::integer", &coreEncoder{}).Scan(&n) - if err != nil { - t.Fatalf("Unexpected conn.QueryRow error: %v", err) - } - - if n != 42 { - t.Errorf("Expected 42, got %v", n) - } -} - func TestQueryRowCoreTypes(t *testing.T) { t.Parallel() diff --git a/values.go b/values.go index e1c8f731..e2b30087 100644 --- a/values.go +++ b/values.go @@ -159,245 +159,6 @@ func (e SerializationError) Error() string { return string(e) } -// Deprecated: Scanner is an interface used to decode values from the PostgreSQL -// server. To allow types to support pgx and database/sql.Scan this interface -// has been deprecated in favor of PgxScanner. -type Scanner interface { - // Scan MUST check r.Type().DataType (to check by Oid) or - // r.Type().DataTypeName (to check by name) to ensure that it is scanning an - // expected column type. It also MUST check r.Type().FormatCode before - // decoding. It should not assume that it was called on a data type or format - // that it understands. - Scan(r *ValueReader) error -} - -// PgxScanner is an interface used to decode values from the PostgreSQL server. -// It is used exactly the same as the Scanner interface. It simply has renamed -// the method. -type PgxScanner interface { - // ScanPgx MUST check r.Type().DataType (to check by Oid) or - // r.Type().DataTypeName (to check by name) to ensure that it is scanning an - // expected column type. It also MUST check r.Type().FormatCode before - // decoding. It should not assume that it was called on a data type or format - // that it understands. - ScanPgx(r *ValueReader) error -} - -// Encoder is an interface used to encode values for transmission to the -// PostgreSQL server. -type Encoder interface { - // Encode writes the value to w. - // - // If the value is NULL an int32(-1) should be written. - // - // Encode MUST check oid to see if the parameter data type is compatible. If - // this is not done, the PostgreSQL server may detect the error if the - // expected data size or format of the encoded data does not match. But if - // the encoded data is a valid representation of the data type PostgreSQL - // expects such as date and int4, incorrect data may be stored. - Encode(w *WriteBuf, oid Oid) error - - // FormatCode returns the format that the encoder writes the value. It must be - // either pgx.TextFormatCode or pgx.BinaryFormatCode. - FormatCode() int16 -} - -// NullFloat32 represents an float4 that may be null. NullFloat32 implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. -// -// If Valid is false then the value is NULL. -type NullFloat32 struct { - Float32 float32 - Valid bool // Valid is true if Float32 is not NULL -} - -func (n *NullFloat32) Scan(vr *ValueReader) error { - if vr.Type().DataType != Float4Oid { - return SerializationError(fmt.Sprintf("NullFloat32.Scan cannot decode Oid %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Float32, n.Valid = 0, false - return nil - } - n.Valid = true - n.Float32 = decodeFloat4(vr) - return vr.Err() -} - -func (n NullFloat32) FormatCode() int16 { return BinaryFormatCode } - -func (n NullFloat32) Encode(w *WriteBuf, oid Oid) error { - if oid != Float4Oid { - return SerializationError(fmt.Sprintf("NullFloat32.Encode cannot encode into Oid %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeFloat32(w, oid, n.Float32) -} - -// NullFloat64 represents an float8 that may be null. NullFloat64 implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. -// -// If Valid is false then the value is NULL. -type NullFloat64 struct { - Float64 float64 - Valid bool // Valid is true if Float64 is not NULL -} - -func (n *NullFloat64) Scan(vr *ValueReader) error { - if vr.Type().DataType != Float8Oid { - return SerializationError(fmt.Sprintf("NullFloat64.Scan cannot decode Oid %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Float64, n.Valid = 0, false - return nil - } - n.Valid = true - n.Float64 = decodeFloat8(vr) - return vr.Err() -} - -func (n NullFloat64) FormatCode() int16 { return BinaryFormatCode } - -func (n NullFloat64) Encode(w *WriteBuf, oid Oid) error { - if oid != Float8Oid { - return SerializationError(fmt.Sprintf("NullFloat64.Encode cannot encode into Oid %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeFloat64(w, oid, n.Float64) -} - -// NullString represents an string that may be null. NullString implements the -// Scanner Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. -// -// If Valid is false then the value is NULL. -type NullString struct { - String string - Valid bool // Valid is true if String is not NULL -} - -func (n *NullString) Scan(vr *ValueReader) error { - // Not checking oid as so we can scan anything into into a NullString - may revisit this decision later - - if vr.Len() == -1 { - n.String, n.Valid = "", false - return nil - } - - n.Valid = true - n.String = decodeText(vr) - return vr.Err() -} - -func (n NullString) FormatCode() int16 { return TextFormatCode } - -func (s NullString) Encode(w *WriteBuf, oid Oid) error { - if !s.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeString(w, oid, s.String) -} - -// NullInt16 represents a smallint that may be null. NullInt16 implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan for prepared and unprepared queries. -// -// If Valid is false then the value is NULL. -type NullInt16 struct { - Int16 int16 - Valid bool // Valid is true if Int16 is not NULL -} - -func (n *NullInt16) Scan(vr *ValueReader) error { - if vr.Type().DataType != Int2Oid { - return SerializationError(fmt.Sprintf("NullInt16.Scan cannot decode Oid %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Int16, n.Valid = 0, false - return nil - } - n.Valid = true - n.Int16 = decodeInt2(vr) - return vr.Err() -} - -func (n NullInt16) FormatCode() int16 { return BinaryFormatCode } - -func (n NullInt16) Encode(w *WriteBuf, oid Oid) error { - if oid != Int2Oid { - return SerializationError(fmt.Sprintf("NullInt16.Encode cannot encode into Oid %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - w.WriteInt32(2) - - _, err := pgtype.Int2{Int: n.Int16, Status: pgtype.Present}.EncodeBinary(w) - return err -} - -// NullInt32 represents an integer that may be null. NullInt32 implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. -// -// If Valid is false then the value is NULL. -type NullInt32 struct { - Int32 int32 - Valid bool // Valid is true if Int32 is not NULL -} - -func (n *NullInt32) Scan(vr *ValueReader) error { - if vr.Type().DataType != Int4Oid { - return SerializationError(fmt.Sprintf("NullInt32.Scan cannot decode Oid %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Int32, n.Valid = 0, false - return nil - } - n.Valid = true - n.Int32 = decodeInt4(vr) - return vr.Err() -} - -func (n NullInt32) FormatCode() int16 { return BinaryFormatCode } - -func (n NullInt32) Encode(w *WriteBuf, oid Oid) error { - if oid != Int4Oid { - return SerializationError(fmt.Sprintf("NullInt32.Encode cannot encode into Oid %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - w.WriteInt32(4) - - _, err := pgtype.Int4{Int: n.Int32, Status: pgtype.Present}.EncodeBinary(w) - return err -} - // Oid (Object Identifier Type) is, according to https://www.postgresql.org/docs/current/static/datatype-oid.html, // used internally by PostgreSQL as a primary key for various system tables. It is currently implemented // as an unsigned four-byte integer. Its definition can be found in src/include/postgres_ext.h @@ -442,140 +203,6 @@ func (src Oid) EncodeBinary(w io.Writer) (bool, error) { return false, err } -// NullInt64 represents an bigint that may be null. NullInt64 implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. -// -// If Valid is false then the value is NULL. -type NullInt64 struct { - Int64 int64 - Valid bool // Valid is true if Int64 is not NULL -} - -func (n *NullInt64) Scan(vr *ValueReader) error { - if vr.Type().DataType != Int8Oid { - return SerializationError(fmt.Sprintf("NullInt64.Scan cannot decode Oid %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Int64, n.Valid = 0, false - return nil - } - n.Valid = true - n.Int64 = decodeInt8(vr) - return vr.Err() -} - -func (n NullInt64) FormatCode() int16 { return BinaryFormatCode } - -func (n NullInt64) Encode(w *WriteBuf, oid Oid) error { - if oid != Int8Oid { - return SerializationError(fmt.Sprintf("NullInt64.Encode cannot encode into Oid %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - w.WriteInt32(8) - - _, err := pgtype.Int8{Int: n.Int64, Status: pgtype.Present}.EncodeBinary(w) - return err -} - -// NullBool represents an bool that may be null. NullBool implements the Scanner -// and Encoder interfaces so it may be used both as an argument to Query[Row] -// and a destination for Scan. -// -// If Valid is false then the value is NULL. -type NullBool struct { - Bool bool - Valid bool // Valid is true if Bool is not NULL -} - -func (n *NullBool) Scan(vr *ValueReader) error { - if vr.Type().DataType != BoolOid { - return SerializationError(fmt.Sprintf("NullBool.Scan cannot decode Oid %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Bool, n.Valid = false, false - return nil - } - n.Valid = true - n.Bool = decodeBool(vr) - return vr.Err() -} - -func (n NullBool) FormatCode() int16 { return BinaryFormatCode } - -func (n NullBool) Encode(w *WriteBuf, oid Oid) error { - if oid != BoolOid { - return SerializationError(fmt.Sprintf("NullBool.Encode cannot encode into Oid %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - w.WriteInt32(1) - - _, err := pgtype.Bool{Bool: n.Bool, Status: pgtype.Present}.EncodeBinary(w) - return err -} - -// NullTime represents an time.Time that may be null. NullTime implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. It corresponds with the PostgreSQL -// types timestamptz, timestamp, and date. -// -// If Valid is false then the value is NULL. -type NullTime struct { - Time time.Time - Valid bool // Valid is true if Time is not NULL -} - -func (n *NullTime) Scan(vr *ValueReader) error { - oid := vr.Type().DataType - if oid != TimestampTzOid && oid != TimestampOid && oid != DateOid { - return SerializationError(fmt.Sprintf("NullTime.Scan cannot decode Oid %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Time, n.Valid = time.Time{}, false - return nil - } - - n.Valid = true - switch oid { - case TimestampTzOid: - n.Time = decodeTimestampTz(vr) - case TimestampOid: - n.Time = decodeTimestamp(vr) - case DateOid: - n.Time = decodeDate(vr) - } - - return vr.Err() -} - -func (n NullTime) FormatCode() int16 { return BinaryFormatCode } - -func (n NullTime) Encode(w *WriteBuf, oid Oid) error { - if oid != TimestampTzOid && oid != TimestampOid && oid != DateOid { - return SerializationError(fmt.Sprintf("NullTime.Encode cannot encode into Oid %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeTime(w, oid, n.Time) -} - // Encode encodes arg into wbuf as the type oid. This allows implementations // of the Encoder interface to delegate the actual work of encoding to the // built-in functionality. @@ -586,8 +213,6 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { } switch arg := arg.(type) { - case Encoder: - return arg.Encode(wbuf, oid) case pgtype.BinaryEncoder: buf := &bytes.Buffer{} null, err := arg.EncodeBinary(buf) diff --git a/values_test.go b/values_test.go index 7b82d456..69a91d4e 100644 --- a/values_test.go +++ b/values_test.go @@ -4,7 +4,6 @@ import ( "bytes" "net" "reflect" - "strings" "testing" "time" @@ -558,70 +557,6 @@ func TestInetCidrTranscodeWithJustIP(t *testing.T) { } } -func TestNullX(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - type allTypes struct { - s pgx.NullString - i16 pgx.NullInt16 - i32 pgx.NullInt32 - i64 pgx.NullInt64 - f32 pgx.NullFloat32 - f64 pgx.NullFloat64 - b pgx.NullBool - t pgx.NullTime - } - - var actual, zero allTypes - - tests := []struct { - sql string - queryArgs []interface{} - scanArgs []interface{} - expected allTypes - }{ - {"select $1::text", []interface{}{pgx.NullString{String: "foo", Valid: true}}, []interface{}{&actual.s}, allTypes{s: pgx.NullString{String: "foo", Valid: true}}}, - {"select $1::text", []interface{}{pgx.NullString{String: "foo", Valid: false}}, []interface{}{&actual.s}, allTypes{s: pgx.NullString{String: "", Valid: false}}}, - {"select $1::int2", []interface{}{pgx.NullInt16{Int16: 1, Valid: true}}, []interface{}{&actual.i16}, allTypes{i16: pgx.NullInt16{Int16: 1, Valid: true}}}, - {"select $1::int2", []interface{}{pgx.NullInt16{Int16: 1, Valid: false}}, []interface{}{&actual.i16}, allTypes{i16: pgx.NullInt16{Int16: 0, Valid: false}}}, - {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 1, Valid: true}}}, - {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: false}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 0, Valid: false}}}, - {"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 1, Valid: true}}}, - {"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: false}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 0, Valid: false}}}, - {"select $1::float4", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: true}}, []interface{}{&actual.f32}, allTypes{f32: pgx.NullFloat32{Float32: 1.23, Valid: true}}}, - {"select $1::float4", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: false}}, []interface{}{&actual.f32}, allTypes{f32: pgx.NullFloat32{Float32: 0, Valid: false}}}, - {"select $1::float8", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.f64}, allTypes{f64: pgx.NullFloat64{Float64: 1.23, Valid: true}}}, - {"select $1::float8", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: false}}, []interface{}{&actual.f64}, allTypes{f64: pgx.NullFloat64{Float64: 0, Valid: false}}}, - {"select $1::bool", []interface{}{pgx.NullBool{Bool: true, Valid: true}}, []interface{}{&actual.b}, allTypes{b: pgx.NullBool{Bool: true, Valid: true}}}, - {"select $1::bool", []interface{}{pgx.NullBool{Bool: true, Valid: false}}, []interface{}{&actual.b}, allTypes{b: pgx.NullBool{Bool: false, Valid: false}}}, - {"select $1::timestamptz", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}}, - {"select $1::timestamptz", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}}, - {"select $1::timestamp", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}}, - {"select $1::timestamp", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}}, - {"select $1::date", []interface{}{pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}}, - {"select $1::date", []interface{}{pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}}, - {"select 42::int4, $1::float8", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.i32, &actual.f64}, allTypes{i32: pgx.NullInt32{Int32: 42, Valid: true}, f64: pgx.NullFloat64{Float64: 1.23, Valid: true}}}, - } - - for i, tt := range tests { - actual = zero - - err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...) - if err != nil { - t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, tt.sql, tt.queryArgs) - } - - if actual != tt.expected { - t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", i, tt.expected, actual, tt.sql, tt.queryArgs) - } - - ensureConnValid(t, conn) - } -} - func TestArrayDecoding(t *testing.T) { t.Parallel() @@ -736,36 +671,6 @@ func TestArrayDecoding(t *testing.T) { } } -type shortScanner struct{} - -func (*shortScanner) Scan(r *pgx.ValueReader) error { - r.ReadByte() - return nil -} - -func TestShortScanner(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - rows, err := conn.Query("select 'ab', 'cd' union select 'cd', 'ef'") - if err != nil { - t.Error(err) - } - defer rows.Close() - - for rows.Next() { - var s1, s2 shortScanner - err = rows.Scan(&s1, &s2) - if err != nil { - t.Error(err) - } - } - - ensureConnValid(t, conn) -} - func TestEmptyArrayDecoding(t *testing.T) { t.Parallel() @@ -814,53 +719,6 @@ func TestEmptyArrayDecoding(t *testing.T) { ensureConnValid(t, conn) } -func TestNullXMismatch(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - type allTypes struct { - s pgx.NullString - i16 pgx.NullInt16 - i32 pgx.NullInt32 - i64 pgx.NullInt64 - f32 pgx.NullFloat32 - f64 pgx.NullFloat64 - b pgx.NullBool - t pgx.NullTime - } - - var actual, zero allTypes - - tests := []struct { - sql string - queryArgs []interface{} - scanArgs []interface{} - err string - }{ - {"select $1::date", []interface{}{pgx.NullString{String: "foo", Valid: true}}, []interface{}{&actual.s}, "invalid input syntax for type date"}, - {"select $1::date", []interface{}{pgx.NullInt16{Int16: 1, Valid: true}}, []interface{}{&actual.i16}, "cannot encode into Oid 1082"}, - {"select $1::date", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, "cannot encode into Oid 1082"}, - {"select $1::date", []interface{}{pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, "cannot encode into Oid 1082"}, - {"select $1::date", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: true}}, []interface{}{&actual.f32}, "cannot encode into Oid 1082"}, - {"select $1::date", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.f64}, "cannot encode into Oid 1082"}, - {"select $1::date", []interface{}{pgx.NullBool{Bool: true, Valid: true}}, []interface{}{&actual.b}, "cannot encode into Oid 1082"}, - {"select $1::int4", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, "cannot encode into Oid 23"}, - } - - for i, tt := range tests { - actual = zero - - err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...) - if err == nil || !strings.Contains(err.Error(), tt.err) { - t.Errorf(`%d. Expected error to contain "%s", but it didn't: %v`, i, tt.err, err) - } - - ensureConnValid(t, conn) - } -} - func TestPointerPointer(t *testing.T) { t.Parallel()