diff --git a/.gitignore b/.gitignore index cb0cd901..0ff00800 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,4 @@ _testmain.go *.exe conn_config_test.go +.envrc diff --git a/.travis.yml b/.travis.yml index d9ea43b0..971b46a9 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,7 @@ language: go go: - - 1.7.4 - - 1.6.4 + - 1.8 - tip # Derived from https://github.com/lib/pq/blob/master/.travis.yml @@ -28,6 +27,8 @@ before_install: - sudo /etc/init.d/postgresql restart env: + global: + - PGX_TEST_DATABASE=postgres://pgx_md5:secret@127.0.0.1/pgx_test matrix: - PGVERSION=9.6 - PGVERSION=9.5 @@ -40,7 +41,7 @@ env: before_script: - mv conn_config_test.go.travis conn_config_test.go - psql -U postgres -c 'create database pgx_test' - - "[[ \"${PGVERSION}\" = '9.0' ]] && psql -U postgres -f /usr/share/postgresql/9.0/contrib/hstore.sql pgx_test || psql -U postgres pgx_test -c 'create extension hstore'" + - psql -U postgres pgx_test -c 'create extension hstore' - psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'" - psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'" - psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'" @@ -51,9 +52,14 @@ install: - go get -u github.com/shopspring/decimal - go get -u gopkg.in/inconshreveable/log15.v2 - go get -u github.com/jackc/fake + - go get -u github.com/lib/pq + - go get -u github.com/hashicorp/go-version + - go get -u github.com/satori/go.uuid + - go get -u github.com/Sirupsen/logrus + - go get -u github.com/pkg/errors script: - - go test -v -race -short ./... + - go test -v -race ./... matrix: allow_failures: diff --git a/CHANGELOG.md b/CHANGELOG.md index 88c663b0..1b4ab492 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,62 @@ -# Unreleased +# Unreleased V3 + +## Changes + +* Pid to PID in accordance with Go naming conventions. 
+* Conn.Pid changed to accessor method Conn.PID() +* Conn.SecretKey removed +* Remove Conn.TxStatus +* Logger interface reduced to single Log method. +* Replace BeginIso with BeginEx. BeginEx adds support for read/write mode and deferrable mode. +* Transaction isolation level constants are now typed strings instead of bare strings. +* Conn.WaitForNotification now takes context.Context instead of time.Duration for cancellation support. +* Conn.WaitForNotification no longer automatically pings internally every 15 seconds. +* ReplicationConn.WaitForReplicationMessage now takes context.Context instead of time.Duration for cancellation support. +* Reject scanning binary format values into a string (e.g. binary encoded timestamptz to string). See https://github.com/jackc/pgx/issues/219 and https://github.com/jackc/pgx/issues/228 +* No longer can read raw bytes of any value into a []byte. Use pgtype.GenericBinary if this functionality is needed. +* Remove CopyTo (functionality is now in CopyFrom) +* OID constants moved from pgx to pgtype package +* Replaced Scanner, Encoder, and PgxScanner interfaces with pgtype system +* Removed ValueReader +* ConnPool.Close no longer waits for all acquired connections to be released. Instead, it immediately closes all available connections, and closes acquired connections when they are released in the same manner as ConnPool.Reset. +* Removed Rows.Fatal(error) +* Removed Rows.AfterClose() +* Removed Rows.Conn() +* Removed Tx.AfterClose() +* Removed Tx.Conn() +* Use Go casing convention for OID, UUID, JSON(B), ACLItem, CID, TID, XID, and CIDR +* Replaced stdlib.OpenFromConnPool with DriverConfig system + +## Features + +* Entirely revamped pluggable type system that supports approximately 60 PostgreSQL types. 
+* Types support database/sql interfaces and therefore can be used with other drivers +* Added context methods supporting cancellation where appropriate +* Added simple query protocol support +* Added single round-trip query mode +* Added batch query operations +* Added OnNotice +* github.com/pkg/errors used where possible for errors +* Added stdlib.DriverConfig which directly allows full configuration of underlying pgx connections without needing to use a pgx.ConnPool +* Added AcquireConn and ReleaseConn to stdlib to allow acquiring a connection from a database/sql connection. + +# 2.11.0 (June 5, 2017) + +## Fixes + +* Fix race with concurrent execution of stdlib.OpenFromConnPool (Terin Stock) + +## Features + +* .pgpass support (j7b) +* Add missing CopyFrom delegators to Tx and ConnPool (Jack Christensen) +* Add ParseConnectionString (James Lawrence) + +## Performance + +* Optimize HStore encoding (ReneĢ Kroon) + +# 2.10.0 (March 17, 2017) ## Fixes diff --git a/README.md b/README.md index 965de95e..570afce3 100644 --- a/README.md +++ b/README.md @@ -1,63 +1,57 @@ [![](https://godoc.org/github.com/jackc/pgx?status.svg)](https://godoc.org/github.com/jackc/pgx) -# Pgx +# pgx - PostgreSQL Driver and Toolkit -## Master Branch - -This is the `master` branch which tracks the stable release of the current -version. At the moment this is `v2`. The `v3` branch which is currently in beta. -General release is planned for July. `v3` is considered to be stable in the -sense of lack of known bugs, but the API is not considered stable until general -release. No further changes are planned, but the beta process may surface -desirable changes. If possible API changes are acceptable, then `v3` is the -recommended branch for new development. Regardless, please lock to the `v2` or -`v3` branch as when `v3` is released breaking changes will be applied to the -master branch. - -Pgx is a pure Go database connection library designed specifically for -PostgreSQL. 
Pgx is different from other drivers such as -[pq](http://godoc.org/github.com/lib/pq) because, while it can operate as a -database/sql compatible driver, pgx is primarily intended to be used directly. -It offers a native interface similar to database/sql that offers better -performance and more features. +pgx is a pure Go driver and toolkit for PostgreSQL. pgx is different from other drivers such as [pq](http://godoc.org/github.com/lib/pq) because, while it can operate as a database/sql compatible driver, pgx is also usable directly. It offers a native interface similar to database/sql that offers better performance and more features. ## Features -Pgx supports many additional features beyond what is available through database/sql. +pgx supports many additional features beyond what is available through database/sql. -* Listen / notify -* Transaction isolation level control +* Support for approximately 60 different PostgreSQL types +* Batch queries +* Single-round trip query mode * Full TLS connection control * Binary format support for custom types (can be much faster) -* Copy from protocol support for faster bulk data loads -* Logging support -* Configurable connection pool with after connect hooks to do arbitrary connection setup +* Copy protocol support for faster bulk data loads +* Extendable logging support including built-in support for log15 and logrus +* Connection pool with after connect hook to do arbitrary connection setup +* Listen / notify * PostgreSQL array to Go slice mapping for integers, floats, and strings * Hstore support * JSON and JSONB support * Maps inet and cidr PostgreSQL types to net.IPNet and net.IP * Large object support -* Null mapping to Null* struct or pointer to pointer. +* NULL mapping to Null* struct or pointer to pointer. 
* Supports database/sql.Scanner and database/sql/driver.Valuer interfaces for custom types * Logical replication connections, including receiving WAL and sending standby status updates +* Notice response handling (this is different than listen / notify) ## Performance -Pgx performs roughly equivalent to [pq](http://godoc.org/github.com/lib/pq) and -[go-pg](https://github.com/go-pg/pg) for selecting a single column from a single -row, but it is substantially faster when selecting multiple entire rows (6893 -queries/sec for pgx vs. 3968 queries/sec for pq -- 73% faster). +pgx performs roughly equivalent to [go-pg](https://github.com/go-pg/pg) and is almost always faster than [pq](http://godoc.org/github.com/lib/pq). When parsing large result sets the percentage difference can be significant (16483 queries/sec for pgx vs. 10106 queries/sec for pq -- 63% faster). -See this [gist](https://gist.github.com/jackc/d282f39e088b495fba3e) for the -underlying benchmark results or checkout -[go_db_bench](https://github.com/jackc/go_db_bench) to run tests for yourself. +In many use cases a significant cause of latency is network round trips between the application and the server. pgx supports query batching to bundle multiple queries into a single round trip. Even in the case of a connection with the lowest possible latency, a local Unix domain socket, batching as few as three queries together can yield an improvement of 57%. With a typical network connection the results can be even more substantial. -## database/sql +See this [gist](https://gist.github.com/jackc/4996e8648a0c59839bff644f49d6e434) for the underlying benchmark results or checkout [go_db_bench](https://github.com/jackc/go_db_bench) to run tests for yourself. -Import the ```github.com/jackc/pgx/stdlib``` package to use pgx as a driver for -database/sql. It is possible to retrieve a pgx connection from database/sql on -demand. 
This allows using the database/sql interface in most places, but using -pgx directly when more performance or PostgreSQL specific features are needed. +In addition to the native driver, pgx also includes a number of packages that provide additional functionality. + +## github.com/jackc/pgxstdlib + +database/sql compatibility layer for pgx. pgx can be used as a normal database/sql driver, but at any time the native interface may be acquired for more performance or PostgreSQL specific functionality. + +## github.com/jackc/pgx/pgtype + +Approximately 60 PostgreSQL types are supported including uuid, hstore, json, bytea, numeric, interval, inet, and arrays. These types support database/sql interfaces and are usable even outside of pgx. They are fully tested in pgx and pq. They also support a higher performance interface when used with the pgx driver. + +## github.com/jackc/pgx/pgproto3 + +pgproto3 provides standalone encoding and decoding of the PostgreSQL v3 wire protocol. This is useful for implementing very low level PostgreSQL tooling. + +## github.com/jackc/pgx/pgmock + +pgmock offers the ability to create a server that mocks the PostgreSQL wire protocol. This is used internally to test pgx by purposely inducing unusual errors. pgproto3 and pgmock together provide most of the foundational tooling required to implement a PostgreSQL proxy or MitM (such as for a custom connection pooler). ## Documentation @@ -132,6 +126,8 @@ Change the following settings in your postgresql.conf: max_wal_senders=5 max_replication_slots=5 +Set `replicationConnConfig` appropriately in `conn_config_test.go`. + ## Version Policy -pgx follows semantic versioning for the documented public API on stable releases. Branch `v2` is the latest stable release. `master` can contain new features or behavior that will change or be removed before being merged to the stable `v2` branch (in practice, this occurs very rarely). 
+pgx follows semantic versioning for the documented public API on stable releases. Branch `v3` is the latest stable release. `master` can contain new features or behavior that will change or be removed before being merged to the stable `v3` branch (in practice, this occurs very rarely). `v2` is the previous stable release. diff --git a/aclitem_parse_test.go b/aclitem_parse_test.go deleted file mode 100644 index 5c7c748f..00000000 --- a/aclitem_parse_test.go +++ /dev/null @@ -1,126 +0,0 @@ -package pgx - -import ( - "reflect" - "testing" -) - -func TestEscapeAclItem(t *testing.T) { - tests := []struct { - input string - expected string - }{ - { - "foo", - "foo", - }, - { - `foo, "\}`, - `foo\, \"\\\}`, - }, - } - - for i, tt := range tests { - actual, err := escapeAclItem(tt.input) - - if err != nil { - t.Errorf("%d. Unexpected error %v", i, err) - } - - if actual != tt.expected { - t.Errorf("%d.\nexpected: %s,\nactual: %s", i, tt.expected, actual) - } - } -} - -func TestParseAclItemArray(t *testing.T) { - tests := []struct { - input string - expected []AclItem - errMsg string - }{ - { - "", - []AclItem{}, - "", - }, - { - "one", - []AclItem{"one"}, - "", - }, - { - `"one"`, - []AclItem{"one"}, - "", - }, - { - "one,two,three", - []AclItem{"one", "two", "three"}, - "", - }, - { - `"one","two","three"`, - []AclItem{"one", "two", "three"}, - "", - }, - { - `"one",two,"three"`, - []AclItem{"one", "two", "three"}, - "", - }, - { - `one,two,"three"`, - []AclItem{"one", "two", "three"}, - "", - }, - { - `"one","two",three`, - []AclItem{"one", "two", "three"}, - "", - }, - { - `"one","t w o",three`, - []AclItem{"one", "t w o", "three"}, - "", - }, - { - `"one","t, w o\"\}\\",three`, - []AclItem{"one", `t, w o"}\`, "three"}, - "", - }, - { - `"one","two",three"`, - []AclItem{"one", "two", `three"`}, - "", - }, - { - `"one","two,"three"`, - nil, - "unexpected rune after quoted value", - }, - { - `"one","two","three`, - nil, - "unexpected end of quoted value", - }, - } - - 
for i, tt := range tests { - actual, err := parseAclItemArray(tt.input) - - if err != nil { - if tt.errMsg == "" { - t.Errorf("%d. Unexpected error %v", i, err) - } else if err.Error() != tt.errMsg { - t.Errorf("%d. Expected error %v did not match actual error %v", i, tt.errMsg, err.Error()) - } - } else if tt.errMsg != "" { - t.Errorf("%d. Expected error not returned: \"%v\"", i, tt.errMsg) - } - - if !reflect.DeepEqual(actual, tt.expected) { - t.Errorf("%d. Expected %v did not match actual %v", i, tt.expected, actual) - } - } -} diff --git a/batch.go b/batch.go new file mode 100644 index 00000000..fc6f0d03 --- /dev/null +++ b/batch.go @@ -0,0 +1,246 @@ +package pgx + +import ( + "context" + + "github.com/jackc/pgx/pgproto3" + "github.com/jackc/pgx/pgtype" +) + +type batchItem struct { + query string + arguments []interface{} + parameterOIDs []pgtype.OID + resultFormatCodes []int16 +} + +// Batch queries are a way of bundling multiple queries together to avoid +// unnecessary network round trips. +type Batch struct { + conn *Conn + connPool *ConnPool + items []*batchItem + resultsRead int + sent bool + ctx context.Context + err error +} + +// BeginBatch returns a *Batch query for c. +func (c *Conn) BeginBatch() *Batch { + return &Batch{conn: c} +} + +// Conn returns the underlying connection that b will or was performed on. +func (b *Batch) Conn() *Conn { + return b.conn +} + +// Queue queues a query to batch b. parameterOIDs are required if there are +// parameters and query is not the name of a prepared statement. +// resultFormatCodes are required if there is a result. +func (b *Batch) Queue(query string, arguments []interface{}, parameterOIDs []pgtype.OID, resultFormatCodes []int16) { + b.items = append(b.items, &batchItem{ + query: query, + arguments: arguments, + parameterOIDs: parameterOIDs, + resultFormatCodes: resultFormatCodes, + }) +} + +// Send sends all queued queries to the server at once. All queries are wrapped +// in a transaction. 
The transaction can optionally be configured with +// txOptions. The context is in effect until the Batch is closed. +func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error { + if b.err != nil { + return b.err + } + + b.ctx = ctx + + err := b.conn.waitForPreviousCancelQuery(ctx) + if err != nil { + return err + } + + if err := b.conn.ensureConnectionReadyForQuery(); err != nil { + return err + } + + err = b.conn.initContext(ctx) + if err != nil { + return err + } + + buf := appendQuery(b.conn.wbuf, txOptions.beginSQL()) + + for _, bi := range b.items { + var psName string + var psParameterOIDs []pgtype.OID + + if ps, ok := b.conn.preparedStatements[bi.query]; ok { + psName = ps.Name + psParameterOIDs = ps.ParameterOIDs + } else { + psParameterOIDs = bi.parameterOIDs + buf = appendParse(buf, "", bi.query, psParameterOIDs) + } + + var err error + buf, err = appendBind(buf, "", psName, b.conn.ConnInfo, psParameterOIDs, bi.arguments, bi.resultFormatCodes) + if err != nil { + return err + } + + buf = appendDescribe(buf, 'P', "") + buf = appendExecute(buf, "", 0) + } + + buf = appendSync(buf) + buf = appendQuery(buf, "commit") + + n, err := b.conn.conn.Write(buf) + if err != nil { + if fatalWriteErr(n, err) { + b.conn.die(err) + } + return err + } + + // expect ReadyForQuery from sync and from commit + b.conn.pendingReadyForQueryCount = b.conn.pendingReadyForQueryCount + 2 + + b.sent = true + + for { + msg, err := b.conn.rxMsg() + if err != nil { + return err + } + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + return nil + default: + if err := b.conn.processContextFreeMsg(msg); err != nil { + return err + } + } + } + + return nil +} + +// ExecResults reads the results from the next query in the batch as if the +// query has been sent with Exec. 
+func (b *Batch) ExecResults() (CommandTag, error) { + if b.err != nil { + return "", b.err + } + + select { + case <-b.ctx.Done(): + b.die(b.ctx.Err()) + return "", b.ctx.Err() + default: + } + + b.resultsRead++ + + for { + msg, err := b.conn.rxMsg() + if err != nil { + return "", err + } + + switch msg := msg.(type) { + case *pgproto3.CommandComplete: + return CommandTag(msg.CommandTag), nil + default: + if err := b.conn.processContextFreeMsg(msg); err != nil { + return "", err + } + } + } +} + +// QueryResults reads the results from the next query in the batch as if the +// query has been sent with Query. +func (b *Batch) QueryResults() (*Rows, error) { + rows := b.conn.getRows("batch query", nil) + + if b.err != nil { + rows.fatal(b.err) + return rows, b.err + } + + select { + case <-b.ctx.Done(): + b.die(b.ctx.Err()) + rows.fatal(b.err) + return rows, b.ctx.Err() + default: + } + + b.resultsRead++ + + fieldDescriptions, err := b.conn.readUntilRowDescription() + if err != nil { + b.die(err) + rows.fatal(b.err) + return rows, err + } + + rows.batch = b + rows.fields = fieldDescriptions + return rows, nil +} + +// QueryRowResults reads the results from the next query in the batch as if the +// query has been sent with QueryRow. +func (b *Batch) QueryRowResults() *Row { + rows, _ := b.QueryResults() + return (*Row)(rows) + +} + +// Close closes the batch operation. Any error that occured during a batch +// operation may have made it impossible to resyncronize the connection with the +// server. In this case the underlying connection will have been closed. 
+func (b *Batch) Close() (err error) { + if b.err != nil { + return b.err + } + + defer func() { + err = b.conn.termContext(err) + if b.conn != nil && b.connPool != nil { + b.connPool.Release(b.conn) + } + }() + + for i := b.resultsRead; i < len(b.items); i++ { + if _, err = b.ExecResults(); err != nil { + return err + } + } + + if err = b.conn.ensureConnectionReadyForQuery(); err != nil { + return err + } + + return nil +} + +func (b *Batch) die(err error) { + if b.err != nil { + return + } + + b.err = err + b.conn.die(err) + + if b.conn != nil && b.connPool != nil { + b.connPool.Release(b.conn) + } +} diff --git a/batch_test.go b/batch_test.go new file mode 100644 index 00000000..e12e4f32 --- /dev/null +++ b/batch_test.go @@ -0,0 +1,478 @@ +package pgx_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" +) + +func TestConnBeginBatch(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + sql := `create temporary table ledger( + id serial primary key, + description varchar not null, + amount int not null +);` + mustExec(t, conn, sql) + + batch := conn.BeginBatch() + batch.Queue("insert into ledger(description, amount) values($1, $2)", + []interface{}{"q1", 1}, + []pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID}, + nil, + ) + batch.Queue("insert into ledger(description, amount) values($1, $2)", + []interface{}{"q2", 2}, + []pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID}, + nil, + ) + batch.Queue("insert into ledger(description, amount) values($1, $2)", + []interface{}{"q3", 3}, + []pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID}, + nil, + ) + batch.Queue("select id, description, amount from ledger order by id", + nil, + nil, + []int16{pgx.BinaryFormatCode, pgx.TextFormatCode, pgx.BinaryFormatCode}, + ) + batch.Queue("select sum(amount) from ledger", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + + err := batch.Send(context.Background(), nil) + if err != nil { + 
t.Fatal(err) + } + + ct, err := batch.ExecResults() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 1 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) + } + + ct, err = batch.ExecResults() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 1 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) + } + + rows, err := batch.QueryResults() + if err != nil { + t.Error(err) + } + + var id int32 + var description string + var amount int32 + if !rows.Next() { + t.Fatal("expected a row to be available") + } + if err := rows.Scan(&id, &description, &amount); err != nil { + t.Fatal(err) + } + if id != 1 { + t.Errorf("id => %v, want %v", id, 1) + } + if description != "q1" { + t.Errorf("description => %v, want %v", description, "q1") + } + if amount != 1 { + t.Errorf("amount => %v, want %v", amount, 1) + } + + if !rows.Next() { + t.Fatal("expected a row to be available") + } + if err := rows.Scan(&id, &description, &amount); err != nil { + t.Fatal(err) + } + if id != 2 { + t.Errorf("id => %v, want %v", id, 2) + } + if description != "q2" { + t.Errorf("description => %v, want %v", description, "q2") + } + if amount != 2 { + t.Errorf("amount => %v, want %v", amount, 2) + } + + if !rows.Next() { + t.Fatal("expected a row to be available") + } + if err := rows.Scan(&id, &description, &amount); err != nil { + t.Fatal(err) + } + if id != 3 { + t.Errorf("id => %v, want %v", id, 3) + } + if description != "q3" { + t.Errorf("description => %v, want %v", description, "q3") + } + if amount != 3 { + t.Errorf("amount => %v, want %v", amount, 3) + } + + if rows.Next() { + t.Fatal("did not expect a row to be available") + } + + if rows.Err() != nil { + t.Fatal(rows.Err()) + } + + err = batch.QueryRowResults().Scan(&amount) + if err != nil { + t.Error(err) + } + if amount != 6 { + t.Errorf("amount => %v, want %v", amount, 6) + } + + err = batch.Close() + if err != nil { + t.Fatal(err) + } + + ensureConnValid(t, conn) +} + +func 
TestConnBeginBatchWithPreparedStatement(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + _, err := conn.Prepare("ps1", "select n from generate_series(0,$1::int) n") + if err != nil { + t.Fatal(err) + } + + batch := conn.BeginBatch() + + queryCount := 3 + for i := 0; i < queryCount; i++ { + batch.Queue("ps1", + []interface{}{5}, + nil, + []int16{pgx.BinaryFormatCode}, + ) + } + + err = batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < queryCount; i++ { + rows, err := batch.QueryResults() + if err != nil { + t.Fatal(err) + } + + for k := 0; rows.Next(); k++ { + var n int + if err := rows.Scan(&n); err != nil { + t.Fatal(err) + } + if n != k { + t.Fatalf("n => %v, want %v", n, k) + } + } + + if rows.Err() != nil { + t.Fatal(rows.Err()) + } + } + + err = batch.Close() + if err != nil { + t.Fatal(err) + } + + ensureConnValid(t, conn) +} + +func TestConnBeginBatchContextCancelBeforeExecResults(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + + sql := `create temporary table ledger( + id serial primary key, + description varchar not null, + amount int not null +);` + mustExec(t, conn, sql) + + batch := conn.BeginBatch() + batch.Queue("insert into ledger(description, amount) values($1, $2)", + []interface{}{"q1", 1}, + []pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID}, + nil, + ) + batch.Queue("select pg_sleep(2)", + nil, + nil, + nil, + ) + + ctx, cancelFn := context.WithCancel(context.Background()) + + err := batch.Send(ctx, nil) + if err != nil { + t.Fatal(err) + } + + cancelFn() + + _, err = batch.ExecResults() + if err != context.Canceled { + t.Errorf("err => %v, want %v", err, context.Canceled) + } + + if conn.IsAlive() { + t.Error("conn should be dead, but was alive") + } +} + +func TestConnBeginBatchContextCancelBeforeQueryResults(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + + batch := 
conn.BeginBatch() + batch.Queue("select pg_sleep(2)", + nil, + nil, + nil, + ) + batch.Queue("select pg_sleep(2)", + nil, + nil, + nil, + ) + + ctx, cancelFn := context.WithCancel(context.Background()) + + err := batch.Send(ctx, nil) + if err != nil { + t.Fatal(err) + } + + cancelFn() + + _, err = batch.QueryResults() + if err != context.Canceled { + t.Errorf("err => %v, want %v", err, context.Canceled) + } + + if conn.IsAlive() { + t.Error("conn should be dead, but was alive") + } +} + +func TestConnBeginBatchContextCancelBeforeFinish(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + + batch := conn.BeginBatch() + batch.Queue("select pg_sleep(2)", + nil, + nil, + nil, + ) + batch.Queue("select pg_sleep(2)", + nil, + nil, + nil, + ) + + ctx, cancelFn := context.WithCancel(context.Background()) + + err := batch.Send(ctx, nil) + if err != nil { + t.Fatal(err) + } + + cancelFn() + + err = batch.Close() + if err != context.Canceled { + t.Errorf("err => %v, want %v", err, context.Canceled) + } + + if conn.IsAlive() { + t.Error("conn should be dead, but was alive") + } +} + +func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + batch := conn.BeginBatch() + batch.Queue("select n from generate_series(0,5) n", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + batch.Queue("select n from generate_series(0,5) n", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + + err := batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + rows, err := batch.QueryResults() + if err != nil { + t.Error(err) + } + + for i := 0; i < 3; i++ { + if !rows.Next() { + t.Error("expected a row to be available") + } + + var n int + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + if n != i { + t.Errorf("n => %v, want %v", n, i) + } + } + + rows.Close() + + rows, err = batch.QueryResults() + if err != nil { + t.Error(err) + } + + for i 
:= 0; rows.Next(); i++ { + var n int + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + if n != i { + t.Errorf("n => %v, want %v", n, i) + } + } + + if rows.Err() != nil { + t.Error(rows.Err()) + } + + err = batch.Close() + if err != nil { + t.Fatal(err) + } + + ensureConnValid(t, conn) +} + +func TestConnBeginBatchQueryError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + batch := conn.BeginBatch() + batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + batch.Queue("select n from generate_series(0,5) n", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + + err := batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + rows, err := batch.QueryResults() + if err != nil { + t.Error(err) + } + + for i := 0; rows.Next(); i++ { + var n int + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + if n != i { + t.Errorf("n => %v, want %v", n, i) + } + } + + if pgErr, ok := rows.Err().(pgx.PgError); !(ok && pgErr.Code == "22012") { + t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012) + } + + err = batch.Close() + if pgErr, ok := err.(pgx.PgError); !(ok && pgErr.Code == "22012") { + t.Errorf("rows.Err() => %v, want error code %v", err, 22012) + } + + if conn.IsAlive() { + t.Error("conn should be dead, but was alive") + } +} + +func TestConnBeginBatchQuerySyntaxError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + batch := conn.BeginBatch() + batch.Queue("select 1 1", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + + err := batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + var n int32 + err = batch.QueryRowResults().Scan(&n) + if pgErr, ok := err.(pgx.PgError); !(ok && pgErr.Code == "42601") { + t.Errorf("rows.Err() => %v, want error code %v", err, 42601) + } + + err = batch.Close() + if err == 
nil { + t.Error("Expected error") + } + + if conn.IsAlive() { + t.Error("conn should be dead, but was alive") + } +} diff --git a/bench-tmp_test.go b/bench-tmp_test.go new file mode 100644 index 00000000..a8e3f7db --- /dev/null +++ b/bench-tmp_test.go @@ -0,0 +1,55 @@ +package pgx_test + +import ( + "testing" +) + +func BenchmarkPgtypeInt4ParseBinary(b *testing.B) { + conn := mustConnect(b, *defaultConnConfig) + defer closeConn(b, conn) + + _, err := conn.Prepare("selectBinary", "select n::int4 from generate_series(1, 100) n") + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + var n int32 + + rows, err := conn.Query("selectBinary") + if err != nil { + b.Fatal(err) + } + + for rows.Next() { + err := rows.Scan(&n) + if err != nil { + b.Fatal(err) + } + } + + if rows.Err() != nil { + b.Fatal(rows.Err()) + } + } +} + +func BenchmarkPgtypeInt4EncodeBinary(b *testing.B) { + conn := mustConnect(b, *defaultConnConfig) + defer closeConn(b, conn) + + _, err := conn.Prepare("encodeBinary", "select $1::int4, $2::int4, $3::int4, $4::int4, $5::int4, $6::int4, $7::int4") + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + rows, err := conn.Query("encodeBinary", int32(i), int32(i), int32(i), int32(i), int32(i), int32(i), int32(i)) + if err != nil { + b.Fatal(err) + } + rows.Close() + } +} diff --git a/bench_test.go b/bench_test.go index 30e31e2a..7f82891e 100644 --- a/bench_test.go +++ b/bench_test.go @@ -2,13 +2,14 @@ package pgx_test import ( "bytes" + "context" "fmt" "strings" "testing" "time" "github.com/jackc/pgx" - log "gopkg.in/inconshreveable/log15.v2" + "github.com/jackc/pgx/pgtype" ) func BenchmarkConnPool(b *testing.B) { @@ -50,126 +51,6 @@ func BenchmarkConnPoolQueryRow(b *testing.B) { } } -func BenchmarkNullXWithNullValues(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - - _, err := conn.Prepare("selectNulls", "select 1::int4, 'johnsmith', null::text, 
null::text, null::text, null::date, null::timestamptz") - if err != nil { - b.Fatal(err) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - var record struct { - id int32 - userName string - email pgx.NullString - name pgx.NullString - sex pgx.NullString - birthDate pgx.NullTime - lastLoginTime pgx.NullTime - } - - err = conn.QueryRow("selectNulls").Scan( - &record.id, - &record.userName, - &record.email, - &record.name, - &record.sex, - &record.birthDate, - &record.lastLoginTime, - ) - if err != nil { - b.Fatal(err) - } - - // These checks both ensure that the correct data was returned - // and provide a benchmark of accessing the returned values. - if record.id != 1 { - b.Fatalf("bad value for id: %v", record.id) - } - if record.userName != "johnsmith" { - b.Fatalf("bad value for userName: %v", record.userName) - } - if record.email.Valid { - b.Fatalf("bad value for email: %v", record.email) - } - if record.name.Valid { - b.Fatalf("bad value for name: %v", record.name) - } - if record.sex.Valid { - b.Fatalf("bad value for sex: %v", record.sex) - } - if record.birthDate.Valid { - b.Fatalf("bad value for birthDate: %v", record.birthDate) - } - if record.lastLoginTime.Valid { - b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime) - } - } -} - -func BenchmarkNullXWithPresentValues(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - - _, err := conn.Prepare("selectNulls", "select 1::int4, 'johnsmith', 'johnsmith@example.com', 'John Smith', 'male', '1970-01-01'::date, '2015-01-01 00:00:00'::timestamptz") - if err != nil { - b.Fatal(err) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - var record struct { - id int32 - userName string - email pgx.NullString - name pgx.NullString - sex pgx.NullString - birthDate pgx.NullTime - lastLoginTime pgx.NullTime - } - - err = conn.QueryRow("selectNulls").Scan( - &record.id, - &record.userName, - &record.email, - &record.name, - &record.sex, - &record.birthDate, - 
&record.lastLoginTime, - ) - if err != nil { - b.Fatal(err) - } - - // These checks both ensure that the correct data was returned - // and provide a benchmark of accessing the returned values. - if record.id != 1 { - b.Fatalf("bad value for id: %v", record.id) - } - if record.userName != "johnsmith" { - b.Fatalf("bad value for userName: %v", record.userName) - } - if !record.email.Valid || record.email.String != "johnsmith@example.com" { - b.Fatalf("bad value for email: %v", record.email) - } - if !record.name.Valid || record.name.String != "John Smith" { - b.Fatalf("bad value for name: %v", record.name) - } - if !record.sex.Valid || record.sex.String != "male" { - b.Fatalf("bad value for sex: %v", record.sex) - } - if !record.birthDate.Valid || record.birthDate.Time != time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local) { - b.Fatalf("bad value for birthDate: %v", record.birthDate) - } - if !record.lastLoginTime.Valid || record.lastLoginTime.Time != time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local) { - b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime) - } - } -} - func BenchmarkPointerPointerWithNullValues(b *testing.B) { conn := mustConnect(b, *defaultConnConfig) defer closeConn(b, conn) @@ -297,71 +178,51 @@ func BenchmarkSelectWithoutLogging(b *testing.B) { benchmarkSelectWithLog(b, conn) } -func BenchmarkSelectWithLoggingTraceWithLog15(b *testing.B) { - connConfig := *defaultConnConfig +type discardLogger struct{} - logger := log.New() - lvl, err := log.LvlFromString("debug") - if err != nil { - b.Fatal(err) - } - logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler())) - connConfig.Logger = logger - connConfig.LogLevel = pgx.LogLevelTrace - conn := mustConnect(b, connConfig) +func (dl discardLogger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {} + +func BenchmarkSelectWithLoggingTraceDiscard(b *testing.B) { + conn := mustConnect(b, *defaultConnConfig) defer closeConn(b, conn) + var logger discardLogger + 
conn.SetLogger(logger) + conn.SetLogLevel(pgx.LogLevelTrace) + benchmarkSelectWithLog(b, conn) } -func BenchmarkSelectWithLoggingDebugWithLog15(b *testing.B) { - connConfig := *defaultConnConfig - - logger := log.New() - lvl, err := log.LvlFromString("debug") - if err != nil { - b.Fatal(err) - } - logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler())) - connConfig.Logger = logger - connConfig.LogLevel = pgx.LogLevelDebug - conn := mustConnect(b, connConfig) +func BenchmarkSelectWithLoggingDebugWithDiscard(b *testing.B) { + conn := mustConnect(b, *defaultConnConfig) defer closeConn(b, conn) + var logger discardLogger + conn.SetLogger(logger) + conn.SetLogLevel(pgx.LogLevelDebug) + benchmarkSelectWithLog(b, conn) } -func BenchmarkSelectWithLoggingInfoWithLog15(b *testing.B) { - connConfig := *defaultConnConfig - - logger := log.New() - lvl, err := log.LvlFromString("info") - if err != nil { - b.Fatal(err) - } - logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler())) - connConfig.Logger = logger - connConfig.LogLevel = pgx.LogLevelInfo - conn := mustConnect(b, connConfig) +func BenchmarkSelectWithLoggingInfoWithDiscard(b *testing.B) { + conn := mustConnect(b, *defaultConnConfig) defer closeConn(b, conn) + var logger discardLogger + conn.SetLogger(logger) + conn.SetLogLevel(pgx.LogLevelInfo) + benchmarkSelectWithLog(b, conn) } -func BenchmarkSelectWithLoggingErrorWithLog15(b *testing.B) { - connConfig := *defaultConnConfig - - logger := log.New() - lvl, err := log.LvlFromString("error") - if err != nil { - b.Fatal(err) - } - logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler())) - connConfig.Logger = logger - connConfig.LogLevel = pgx.LogLevelError - conn := mustConnect(b, connConfig) +func BenchmarkSelectWithLoggingErrorWithDiscard(b *testing.B) { + conn := mustConnect(b, *defaultConnConfig) defer closeConn(b, conn) + var logger discardLogger + conn.SetLogger(logger) + conn.SetLogLevel(pgx.LogLevelError) + benchmarkSelectWithLog(b, 
conn) } @@ -422,20 +283,6 @@ func benchmarkSelectWithLog(b *testing.B, conn *pgx.Conn) { } } -func BenchmarkLog15Discard(b *testing.B) { - logger := log.New() - lvl, err := log.LvlFromString("error") - if err != nil { - b.Fatal(err) - } - logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler())) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - logger.Debug("benchmark", "i", i, "b.N", b.N) - } -} - const benchmarkWriteTableCreateSQL = `drop table if exists t; create table t( @@ -510,12 +357,12 @@ func newBenchmarkWriteTableCopyFromSrc(count int) pgx.CopyFromSource { row: []interface{}{ "varchar_1", "varchar_2", - pgx.NullString{}, + pgtype.Text{}, time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), - pgx.NullTime{}, + pgtype.Date{}, 1, 2, - pgx.NullInt32{}, + pgtype.Int4{}, time.Date(2001, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2002, 1, 1, 0, 0, 0, 0, time.Local), true, @@ -763,3 +610,92 @@ func BenchmarkWrite10000RowsViaMultiInsert(b *testing.B) { func BenchmarkWrite10000RowsViaCopy(b *testing.B) { benchmarkWriteNRowsViaCopy(b, 10000) } + +func BenchmarkMultipleQueriesNonBatch(b *testing.B) { + config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: 5} + pool, err := pgx.NewConnPool(config) + if err != nil { + b.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + queryCount := 3 + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for j := 0; j < queryCount; j++ { + rows, err := pool.Query("select n from generate_series(0, 5) n") + if err != nil { + b.Fatal(err) + } + + for k := 0; rows.Next(); k++ { + var n int + if err := rows.Scan(&n); err != nil { + b.Fatal(err) + } + if n != k { + b.Fatalf("n => %v, want %v", n, k) + } + } + + if rows.Err() != nil { + b.Fatal(rows.Err()) + } + } + } +} + +func BenchmarkMultipleQueriesBatch(b *testing.B) { + config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: 5} + pool, err := pgx.NewConnPool(config) + if err != nil { + b.Fatalf("Unable to create 
connection pool: %v", err) + } + defer pool.Close() + + queryCount := 3 + + b.ResetTimer() + for i := 0; i < b.N; i++ { + batch := pool.BeginBatch() + for j := 0; j < queryCount; j++ { + batch.Queue("select n from generate_series(0,5) n", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + } + + err := batch.Send(context.Background(), nil) + if err != nil { + b.Fatal(err) + } + + for j := 0; j < queryCount; j++ { + rows, err := batch.QueryResults() + if err != nil { + b.Fatal(err) + } + + for k := 0; rows.Next(); k++ { + var n int + if err := rows.Scan(&n); err != nil { + b.Fatal(err) + } + if n != k { + b.Fatalf("n => %v, want %v", n, k) + } + } + + if rows.Err() != nil { + b.Fatal(rows.Err()) + } + } + + err = batch.Close() + if err != nil { + b.Fatal(err) + } + } +} diff --git a/chunkreader/chunkreader.go b/chunkreader/chunkreader.go new file mode 100644 index 00000000..f8d437b2 --- /dev/null +++ b/chunkreader/chunkreader.go @@ -0,0 +1,89 @@ +package chunkreader + +import ( + "io" +) + +type ChunkReader struct { + r io.Reader + + buf []byte + rp, wp int // buf read position and write position + + options Options +} + +type Options struct { + MinBufLen int // Minimum buffer length +} + +func NewChunkReader(r io.Reader) *ChunkReader { + cr, err := NewChunkReaderEx(r, Options{}) + if err != nil { + panic("default options can't be bad") + } + + return cr +} + +func NewChunkReaderEx(r io.Reader, options Options) (*ChunkReader, error) { + if options.MinBufLen == 0 { + options.MinBufLen = 4096 + } + + return &ChunkReader{ + r: r, + buf: make([]byte, options.MinBufLen), + options: options, + }, nil +} + +// Next returns buf filled with the next n bytes. If an error occurs, buf will +// be nil. 
+func (r *ChunkReader) Next(n int) (buf []byte, err error) { + // n bytes already in buf + if (r.wp - r.rp) >= n { + buf = r.buf[r.rp : r.rp+n] + r.rp += n + return buf, err + } + + // available space in buf is less than n + if len(r.buf) < n { + r.copyBufContents(r.newBuf(n)) + } + + // buf is large enough, but need to shift filled area to start to make enough contiguous space + minReadCount := n - (r.wp - r.rp) + if (len(r.buf) - r.wp) < minReadCount { + newBuf := r.newBuf(n) + r.copyBufContents(newBuf) + } + + if err := r.appendAtLeast(minReadCount); err != nil { + return nil, err + } + + buf = r.buf[r.rp : r.rp+n] + r.rp += n + return buf, nil +} + +func (r *ChunkReader) appendAtLeast(fillLen int) error { + n, err := io.ReadAtLeast(r.r, r.buf[r.wp:], fillLen) + r.wp += n + return err +} + +func (r *ChunkReader) newBuf(size int) []byte { + if size < r.options.MinBufLen { + size = r.options.MinBufLen + } + return make([]byte, size) +} + +func (r *ChunkReader) copyBufContents(dest []byte) { + r.wp = copy(dest, r.buf[r.rp:r.wp]) + r.rp = 0 + r.buf = dest +} diff --git a/chunkreader/chunkreader_test.go b/chunkreader/chunkreader_test.go new file mode 100644 index 00000000..3be07e3c --- /dev/null +++ b/chunkreader/chunkreader_test.go @@ -0,0 +1,96 @@ +package chunkreader + +import ( + "bytes" + "testing" +) + +func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) { + server := &bytes.Buffer{} + r, err := NewChunkReaderEx(server, Options{MinBufLen: 4}) + if err != nil { + t.Fatal(err) + } + + src := []byte{1, 2, 3, 4} + server.Write(src) + + n1, err := r.Next(2) + if err != nil { + t.Fatal(err) + } + if bytes.Compare(n1, src[0:2]) != 0 { + t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:2], n1) + } + + n2, err := r.Next(2) + if err != nil { + t.Fatal(err) + } + if bytes.Compare(n2, src[2:4]) != 0 { + t.Fatalf("Expected read bytes to be %v, but they were %v", src[2:4], n2) + } + + if bytes.Compare(r.buf, src) != 0 { + t.Fatalf("Expected 
r.buf to be %v, but it was %v", src, r.buf) + } + if r.rp != 4 { + t.Fatalf("Expected r.rp to be %v, but it was %v", 4, r.rp) + } + if r.wp != 4 { + t.Fatalf("Expected r.wp to be %v, but it was %v", 4, r.wp) + } +} + +func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) { + server := &bytes.Buffer{} + r, err := NewChunkReaderEx(server, Options{MinBufLen: 4}) + if err != nil { + t.Fatal(err) + } + + src := []byte{1, 2, 3, 4, 5, 6, 7, 8} + server.Write(src) + + n1, err := r.Next(5) + if err != nil { + t.Fatal(err) + } + if bytes.Compare(n1, src[0:5]) != 0 { + t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:5], n1) + } + if len(r.buf) != 5 { + t.Fatalf("Expected len(r.buf) to be %v, but it was %v", 5, len(r.buf)) + } +} + +func TestChunkReaderDoesNotReuseBuf(t *testing.T) { + server := &bytes.Buffer{} + r, err := NewChunkReaderEx(server, Options{MinBufLen: 4}) + if err != nil { + t.Fatal(err) + } + + src := []byte{1, 2, 3, 4, 5, 6, 7, 8} + server.Write(src) + + n1, err := r.Next(4) + if err != nil { + t.Fatal(err) + } + if bytes.Compare(n1, src[0:4]) != 0 { + t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:4], n1) + } + + n2, err := r.Next(4) + if err != nil { + t.Fatal(err) + } + if bytes.Compare(n2, src[4:8]) != 0 { + t.Fatalf("Expected read bytes to be %v, but they were %v", src[4:8], n2) + } + + if bytes.Compare(n1, src[0:4]) != 0 { + t.Fatalf("Expected KeepLast to prevent Next from overwriting buf, expected %v but it was %v", src[0:4], n1) + } +} diff --git a/conn.go b/conn.go index a2d60e7e..0d51228d 100644 --- a/conn.go +++ b/conn.go @@ -1,12 +1,11 @@ package pgx import ( - "bufio" + "context" "crypto/md5" "crypto/tls" "encoding/binary" "encoding/hex" - "errors" "fmt" "io" "net" @@ -17,9 +16,45 @@ import ( "regexp" "strconv" "strings" + "sync" + "sync/atomic" "time" + + "github.com/pkg/errors" + + "github.com/jackc/pgx/pgio" + "github.com/jackc/pgx/pgproto3" + "github.com/jackc/pgx/pgtype" ) +const ( + 
connStatusUninitialized = iota + connStatusClosed + connStatusIdle + connStatusBusy +) + +// minimalConnInfo has just enough static type information to establish the +// connection and retrieve the type data. +var minimalConnInfo *pgtype.ConnInfo + +func init() { + minimalConnInfo = pgtype.NewConnInfo() + minimalConnInfo.InitializeDataTypes(map[string]pgtype.OID{ + "int4": pgtype.Int4OID, + "name": pgtype.NameOID, + "oid": pgtype.OIDOID, + "text": pgtype.TextOID, + }) +} + +// NoticeHandler is a function that can handle notices received from the +// PostgreSQL server. Notices can be received at any time, usually during +// handling of a query response. The *Conn is provided so the handler is aware +// of the origin of the notice, but it must not invoke any query method. Be +// aware that this is distinct from LISTEN/NOTIFY notification. +type NoticeHandler func(*Conn, *Notice) + // DialFunc is a function that can be used to connect to a PostgreSQL server type DialFunc func(network, addr string) (net.Conn, error) @@ -37,37 +72,63 @@ type ConnConfig struct { LogLevel int Dial DialFunc RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) + OnNotice NoticeHandler // Callback function called when a notice response is received. +} + +func (cc *ConnConfig) networkAddress() (network, address string) { + network = "tcp" + address = fmt.Sprintf("%s:%d", cc.Host, cc.Port) + // See if host is a valid path, if yes connect with a socket + if _, err := os.Stat(cc.Host); err == nil { + // For backward compatibility accept socket file paths -- but directories are now preferred + network = "unix" + address = cc.Host + if !strings.Contains(address, "/.s.PGSQL.") { + address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10) + } + } + + return network, address } // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. 
// Use ConnPool to manage access to multiple database connections from multiple // goroutines. type Conn struct { - conn net.Conn // the underlying TCP or unix domain socket connection - lastActivityTime time.Time // the last time the connection was used - reader *bufio.Reader // buffered reader to improve read performance - wbuf [1024]byte - writeBuf WriteBuf - Pid int32 // backend pid - SecretKey int32 // key to use to send a cancel query message to the server + conn net.Conn // the underlying TCP or unix domain socket connection + lastActivityTime time.Time // the last time the connection was used + wbuf []byte + pid uint32 // backend pid + secretKey uint32 // key to use to send a cancel query message to the server RuntimeParams map[string]string // parameters that have been reported by the server - PgTypes map[Oid]PgType // oids to PgTypes config ConnConfig // config used when establishing this connection - TxStatus byte + txStatus byte preparedStatements map[string]*PreparedStatement channels map[string]struct{} notifications []*Notification - alive bool - causeOfDeath error logger Logger logLevel int - mr msgReader fp *fastpath - pgsqlAfInet *byte - pgsqlAfInet6 *byte - busy bool poolResetCount int preallocatedRows []Rows + onNotice NoticeHandler + + mux sync.Mutex + status byte // One of connStatus* constants + causeOfDeath error + + pendingReadyForQueryCount int // numer of ReadyForQuery messages expected + cancelQueryInProgress int32 + cancelQueryCompleted chan struct{} + + // context support + ctxInProgress bool + doneChan chan struct{} + closedChan chan error + + ConnInfo *pgtype.ConnInfo + + frontend *pgproto3.Frontend } // PreparedStatement is a description of a prepared statement @@ -75,27 +136,21 @@ type PreparedStatement struct { Name string SQL string FieldDescriptions []FieldDescription - ParameterOids []Oid + ParameterOIDs []pgtype.OID } // PrepareExOptions is an option struct that can be passed to PrepareEx type PrepareExOptions struct { - 
ParameterOids []Oid + ParameterOIDs []pgtype.OID } // Notification is a message received from the PostgreSQL LISTEN/NOTIFY system type Notification struct { - Pid int32 // backend pid that sent the notification + PID uint32 // backend pid that sent the notification Channel string // channel from which notification was received Payload string } -// PgType is information about PostgreSQL type and how to encode and decode it -type PgType struct { - Name string // name of type e.g. int4, text, date - DefaultFormat int16 // default format (text or binary) this type will be requested in -} - // CommandTag is the result of an Exec function type CommandTag string @@ -127,9 +182,6 @@ func (ident Identifier) Sanitize() string { // ErrNoRows occurs when rows are expected but none are returned. var ErrNoRows = errors.New("no rows in result set") -// ErrNotificationTimeout occurs when WaitForNotification times out. -var ErrNotificationTimeout = errors.New("notification timeout") - // ErrDeadConn occurs on an attempt to use a dead connection var ErrDeadConn = errors.New("conn is dead") @@ -155,29 +207,14 @@ func (e ProtocolError) Error() string { // config.Host must be specified. config.User will default to the OS user name. // Other config fields are optional. 
func Connect(config ConnConfig) (c *Conn, err error) { - return connect(config, nil, nil, nil) + return connect(config, minimalConnInfo) } -func connect(config ConnConfig, pgTypes map[Oid]PgType, pgsqlAfInet *byte, pgsqlAfInet6 *byte) (c *Conn, err error) { +func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) { c = new(Conn) c.config = config - - if pgTypes != nil { - c.PgTypes = make(map[Oid]PgType, len(pgTypes)) - for k, v := range pgTypes { - c.PgTypes[k] = v - } - } - - if pgsqlAfInet != nil { - c.pgsqlAfInet = new(byte) - *c.pgsqlAfInet = *pgsqlAfInet - } - if pgsqlAfInet6 != nil { - c.pgsqlAfInet6 = new(byte) - *c.pgsqlAfInet6 = *pgsqlAfInet6 - } + c.ConnInfo = connInfo if c.config.LogLevel != 0 { c.logLevel = c.config.LogLevel @@ -186,8 +223,6 @@ func connect(config ConnConfig, pgTypes map[Oid]PgType, pgsqlAfInet *byte, pgsql c.logLevel = LogLevelDebug } c.logger = c.config.Logger - c.mr.log = c.log - c.mr.shouldLog = c.shouldLog if c.config.User == "" { user, err := user.Current() @@ -196,46 +231,38 @@ func connect(config ConnConfig, pgTypes map[Oid]PgType, pgsqlAfInet *byte, pgsql } c.config.User = user.Username if c.shouldLog(LogLevelDebug) { - c.log(LogLevelDebug, "Using default connection config", "User", c.config.User) + c.log(LogLevelDebug, "Using default connection config", map[string]interface{}{"User": c.config.User}) } } if c.config.Port == 0 { c.config.Port = 5432 if c.shouldLog(LogLevelDebug) { - c.log(LogLevelDebug, "Using default connection config", "Port", c.config.Port) + c.log(LogLevelDebug, "Using default connection config", map[string]interface{}{"Port": c.config.Port}) } } - network := "tcp" - address := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port) - // See if host is a valid path, if yes connect with a socket - if _, err := os.Stat(c.config.Host); err == nil { - // For backward compatibility accept socket file paths -- but directories are now preferred - network = "unix" - address = c.config.Host - if 
!strings.Contains(address, "/.s.PGSQL.") { - address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(c.config.Port), 10) - } - } + c.onNotice = config.OnNotice + + network, address := c.config.networkAddress() if c.config.Dial == nil { c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial } if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address)) + c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"network": network, "address": address}) } err = c.connect(config, network, address, config.TLSConfig) if err != nil && config.UseFallbackTLS { if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, fmt.Sprintf("Connect with TLSConfig failed, trying FallbackTLSConfig: %v", err)) + c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err}) } err = c.connect(config, network, address, config.FallbackTLSConfig) } if err != nil { if c.shouldLog(LogLevelError) { - c.log(LogLevelError, fmt.Sprintf("Connect failed: %v", err)) + c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err}) } return nil, err } @@ -251,88 +278,95 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl defer func() { if c != nil && err != nil { c.conn.Close() - c.alive = false + c.mux.Lock() + c.status = connStatusClosed + c.mux.Unlock() } }() c.RuntimeParams = make(map[string]string) c.preparedStatements = make(map[string]*PreparedStatement) c.channels = make(map[string]struct{}) - c.alive = true c.lastActivityTime = time.Now() + c.cancelQueryCompleted = make(chan struct{}, 1) + c.doneChan = make(chan struct{}) + c.closedChan = make(chan error) + c.wbuf = make([]byte, 0, 1024) + + c.mux.Lock() + c.status = connStatusIdle + c.mux.Unlock() if tlsConfig != nil { if c.shouldLog(LogLevelDebug) { - c.log(LogLevelDebug, "Starting TLS handshake") + c.log(LogLevelDebug, "starting TLS handshake", 
nil) } if err := c.startTLS(tlsConfig); err != nil { return err } } - c.reader = bufio.NewReader(c.conn) - c.mr.reader = c.reader + c.frontend, err = pgproto3.NewFrontend(c.conn, c.conn) + if err != nil { + return err + } - msg := newStartupMessage() + startupMsg := pgproto3.StartupMessage{ + ProtocolVersion: pgproto3.ProtocolVersionNumber, + Parameters: make(map[string]string), + } // Default to disabling TLS renegotiation. // // Go does not support (https://github.com/golang/go/issues/5742) // PostgreSQL recommends disabling (http://www.postgresql.org/docs/9.4/static/runtime-config-connection.html#GUC-SSL-RENEGOTIATION-LIMIT) if tlsConfig != nil { - msg.options["ssl_renegotiation_limit"] = "0" + startupMsg.Parameters["ssl_renegotiation_limit"] = "0" } // Copy default run-time params for k, v := range config.RuntimeParams { - msg.options[k] = v + startupMsg.Parameters[k] = v } - msg.options["user"] = c.config.User + startupMsg.Parameters["user"] = c.config.User if c.config.Database != "" { - msg.options["database"] = c.config.Database + startupMsg.Parameters["database"] = c.config.Database } - if err = c.txStartupMessage(msg); err != nil { + if _, err := c.conn.Write(startupMsg.Encode(nil)); err != nil { return err } + c.pendingReadyForQueryCount = 1 + for { - var t byte - var r *msgReader - t, r, err = c.rxMsg() + msg, err := c.rxMsg() if err != nil { return err } - switch t { - case backendKeyData: - c.rxBackendKeyData(r) - case authenticationX: - if err = c.rxAuthenticationX(r); err != nil { + switch msg := msg.(type) { + case *pgproto3.BackendKeyData: + c.rxBackendKeyData(msg) + case *pgproto3.Authentication: + if err = c.rxAuthenticationX(msg); err != nil { return err } - case readyForQuery: - c.rxReadyForQuery(r) + case *pgproto3.ReadyForQuery: + c.rxReadyForQuery(msg) if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "Connection established") + c.log(LogLevelInfo, "connection established", nil) } // Replication connections can't execute the queries to // 
populate the c.PgTypes and c.pgsqlAfInet - if _, ok := msg.options["replication"]; ok { + if _, ok := config.RuntimeParams["replication"]; ok { return nil } - if c.PgTypes == nil { - err = c.loadPgTypes() - if err != nil { - return err - } - } - - if c.pgsqlAfInet == nil || c.pgsqlAfInet6 == nil { - err = c.loadInetConstants() + if c.ConnInfo == minimalConnInfo { + err = c.initConnInfo() if err != nil { return err } @@ -340,77 +374,146 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl return nil default: - if err = c.processContextFreeMsg(t, r); err != nil { + if err = c.processContextFreeMsg(msg); err != nil { return err } } } } -func (c *Conn) loadPgTypes() error { +func (c *Conn) initConnInfo() error { + nameOIDs := make(map[string]pgtype.OID, 256) + rows, err := c.Query(`select t.oid, t.typname from pg_type t left join pg_type base_type on t.typelem=base_type.oid where ( - t.typtype='b' - and (base_type.oid is null or base_type.typtype='b') - ) - or t.typname in('record');`) + t.typtype in('b', 'p', 'r', 'e') + and (base_type.oid is null or base_type.typtype in('b', 'p', 'r')) + )`) if err != nil { return err } - c.PgTypes = make(map[Oid]PgType, 128) - for rows.Next() { - var oid Oid - var t PgType + var oid pgtype.OID + var name pgtype.Text + if err := rows.Scan(&oid, &name); err != nil { + return err + } - rows.Scan(&oid, &t.Name) - - // The zero value is text format so we ignore any types without a default type format - t.DefaultFormat, _ = DefaultTypeFormats[t.Name] - - c.PgTypes[oid] = t + nameOIDs[name.String] = oid } - return rows.Err() + if rows.Err() != nil { + return rows.Err() + } + + c.ConnInfo = pgtype.NewConnInfo() + c.ConnInfo.InitializeDataTypes(nameOIDs) + return nil } -// Family is needed for binary encoding of inet/cidr. The constant is based on -// the server's definition of AF_INET. In theory, this could differ between -// platforms, so request an IPv4 and an IPv6 inet and get the family from that. 
-func (c *Conn) loadInetConstants() error { - var ipv4, ipv6 []byte - - err := c.QueryRow("select '127.0.0.1'::inet, '1::'::inet").Scan(&ipv4, &ipv6) - if err != nil { - return err - } - - c.pgsqlAfInet = &ipv4[0] - c.pgsqlAfInet6 = &ipv6[0] - - return nil +// PID returns the backend PID for this connection. +func (c *Conn) PID() uint32 { + return c.pid } // Close closes a connection. It is safe to call Close on a already closed // connection. func (c *Conn) Close() (err error) { - if !c.IsAlive() { + c.mux.Lock() + defer c.mux.Unlock() + + if c.status < connStatusIdle { return nil } + c.status = connStatusClosed - wbuf := newWriteBuf(c, 'X') - wbuf.closeMsg() + defer func() { + c.conn.Close() + c.causeOfDeath = errors.New("Closed") + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, "closed connection", nil) + } + }() - _, err = c.conn.Write(wbuf.buf) - - c.die(errors.New("Closed")) - if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "Closed connection") + err = c.conn.SetDeadline(time.Time{}) + if err != nil && c.shouldLog(LogLevelWarn) { + c.log(LogLevelWarn, "failed to clear deadlines to send close message", map[string]interface{}{"err": err}) + return err } - return err + + _, err = c.conn.Write([]byte{'X', 0, 0, 0, 4}) + if err != nil && c.shouldLog(LogLevelWarn) { + c.log(LogLevelWarn, "failed to send terminate message", map[string]interface{}{"err": err}) + return err + } + + err = c.conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + if err != nil && c.shouldLog(LogLevelWarn) { + c.log(LogLevelWarn, "failed to set read deadline to finish closing", map[string]interface{}{"err": err}) + return err + } + + _, err = c.conn.Read(make([]byte, 1)) + if err != io.EOF { + return err + } + + return nil +} + +// Merge returns a new ConnConfig with the attributes of old and other +// combined. When an attribute is set on both, other takes precedence. 
+// +// As a security precaution, if the other TLSConfig is nil, all old TLS +// attributes will be preserved. +func (old ConnConfig) Merge(other ConnConfig) ConnConfig { + cc := old + + if other.Host != "" { + cc.Host = other.Host + } + if other.Port != 0 { + cc.Port = other.Port + } + if other.Database != "" { + cc.Database = other.Database + } + if other.User != "" { + cc.User = other.User + } + if other.Password != "" { + cc.Password = other.Password + } + + if other.TLSConfig != nil { + cc.TLSConfig = other.TLSConfig + cc.UseFallbackTLS = other.UseFallbackTLS + cc.FallbackTLSConfig = other.FallbackTLSConfig + } + + if other.Logger != nil { + cc.Logger = other.Logger + } + if other.LogLevel != 0 { + cc.LogLevel = other.LogLevel + } + + if other.Dial != nil { + cc.Dial = other.Dial + } + + cc.RuntimeParams = make(map[string]string) + for k, v := range old.RuntimeParams { + cc.RuntimeParams[k] = v + } + for k, v := range other.RuntimeParams { + cc.RuntimeParams[k] = v + } + + return cc } // ParseURI parses a database URI into ConnConfig @@ -626,7 +729,7 @@ func configSSL(sslmode string, cc *ConnConfig) error { // name and sql arguments. This allows a code path to Prepare and Query/Exec without // concern for if the statement has already been prepared. func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { - return c.PrepareEx(name, sql, nil) + return c.PrepareEx(context.Background(), name, sql, nil) } // PrepareEx creates a prepared statement with name and sql. sql can contain placeholders @@ -636,83 +739,95 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { // PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same // name and sql arguments. This allows a code path to PrepareEx and Query/Exec without // concern for if the statement has already been prepared. 
-func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { +func (c *Conn) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { + err = c.waitForPreviousCancelQuery(ctx) + if err != nil { + return nil, err + } + + err = c.initContext(ctx) + if err != nil { + return nil, err + } + + ps, err = c.prepareEx(name, sql, opts) + err = c.termContext(err) + return ps, err +} + +func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { if name != "" { if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql { return ps, nil } } + if err := c.ensureConnectionReadyForQuery(); err != nil { + return nil, err + } + if c.shouldLog(LogLevelError) { defer func() { if err != nil { - c.log(LogLevelError, fmt.Sprintf("Prepare `%s` as `%s` failed: %v", name, sql, err)) + c.log(LogLevelError, "prepareEx failed", map[string]interface{}{"err": err, "name": name, "sql": sql}) } }() } - // parse - wbuf := newWriteBuf(c, 'P') - wbuf.WriteCString(name) - wbuf.WriteCString(sql) - - if opts != nil { - if len(opts.ParameterOids) > 65535 { - return nil, fmt.Errorf("Number of PrepareExOptions ParameterOids must be between 0 and 65535, received %d", len(opts.ParameterOids)) - } - wbuf.WriteInt16(int16(len(opts.ParameterOids))) - for _, oid := range opts.ParameterOids { - wbuf.WriteInt32(int32(oid)) - } - } else { - wbuf.WriteInt16(0) + if opts == nil { + opts = &PrepareExOptions{} } - // describe - wbuf.startMsg('D') - wbuf.WriteByte('S') - wbuf.WriteCString(name) + if len(opts.ParameterOIDs) > 65535 { + return nil, errors.Errorf("Number of PrepareExOptions ParameterOIDs must be between 0 and 65535, received %d", len(opts.ParameterOIDs)) + } - // sync - wbuf.startMsg('S') - wbuf.closeMsg() + buf := appendParse(c.wbuf, name, sql, opts.ParameterOIDs) + buf = appendDescribe(buf, 'S', name) + buf = appendSync(buf) - _, err = c.conn.Write(wbuf.buf) + n, 
err := c.conn.Write(buf) if err != nil { - c.die(err) + if fatalWriteErr(n, err) { + c.die(err) + } return nil, err } + c.pendingReadyForQueryCount++ ps = &PreparedStatement{Name: name, SQL: sql} var softErr error for { - var t byte - var r *msgReader - t, r, err := c.rxMsg() + msg, err := c.rxMsg() if err != nil { return nil, err } - switch t { - case parseComplete: - case parameterDescription: - ps.ParameterOids = c.rxParameterDescription(r) + switch msg := msg.(type) { + case *pgproto3.ParameterDescription: + ps.ParameterOIDs = c.rxParameterDescription(msg) - if len(ps.ParameterOids) > 65535 && softErr == nil { - softErr = fmt.Errorf("PostgreSQL supports maximum of 65535 parameters, received %d", len(ps.ParameterOids)) + if len(ps.ParameterOIDs) > 65535 && softErr == nil { + softErr = errors.Errorf("PostgreSQL supports maximum of 65535 parameters, received %d", len(ps.ParameterOIDs)) } - case rowDescription: - ps.FieldDescriptions = c.rxRowDescription(r) + case *pgproto3.RowDescription: + ps.FieldDescriptions = c.rxRowDescription(msg) for i := range ps.FieldDescriptions { - t, _ := c.PgTypes[ps.FieldDescriptions[i].DataType] - ps.FieldDescriptions[i].DataTypeName = t.Name - ps.FieldDescriptions[i].FormatCode = t.DefaultFormat + if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok { + ps.FieldDescriptions[i].DataTypeName = dt.Name + if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { + ps.FieldDescriptions[i].FormatCode = BinaryFormatCode + } else { + ps.FieldDescriptions[i].FormatCode = TextFormatCode + } + } else { + return nil, errors.Errorf("unknown oid: %d", ps.FieldDescriptions[i].DataType) + } } - case noData: - case readyForQuery: - c.rxReadyForQuery(r) + case *pgproto3.ReadyForQuery: + c.rxReadyForQuery(msg) if softErr == nil { c.preparedStatements[name] = ps @@ -720,7 +835,7 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared return ps, softErr default: - if e := c.processContextFreeMsg(t, r); e != 
nil && softErr == nil { + if e := c.processContextFreeMsg(msg); e != nil && softErr == nil { softErr = e } } @@ -728,37 +843,62 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared } // Deallocate released a prepared statement -func (c *Conn) Deallocate(name string) (err error) { +func (c *Conn) Deallocate(name string) error { + return c.deallocateContext(context.Background(), name) +} + +// TODO - consider making this public +func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) { + err = c.waitForPreviousCancelQuery(ctx) + if err != nil { + return err + } + + err = c.initContext(ctx) + if err != nil { + return err + } + defer func() { + err = c.termContext(err) + }() + + if err := c.ensureConnectionReadyForQuery(); err != nil { + return err + } + delete(c.preparedStatements, name) // close - wbuf := newWriteBuf(c, 'C') - wbuf.WriteByte('S') - wbuf.WriteCString(name) + buf := c.wbuf + buf = append(buf, 'C') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, 'S') + buf = append(buf, name...) + buf = append(buf, 0) + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) // flush - wbuf.startMsg('H') - wbuf.closeMsg() + buf = append(buf, 'H') + buf = pgio.AppendInt32(buf, 4) - _, err = c.conn.Write(wbuf.buf) + _, err = c.conn.Write(buf) if err != nil { c.die(err) return err } for { - var t byte - var r *msgReader - t, r, err := c.rxMsg() + msg, err := c.rxMsg() if err != nil { return err } - switch t { - case closeComplete: + switch msg.(type) { + case *pgproto3.CloseComplete: return nil default: - err = c.processContextFreeMsg(t, r) + err = c.processContextFreeMsg(msg) if err != nil { return err } @@ -789,9 +929,8 @@ func (c *Conn) Unlisten(channel string) error { return nil } -// WaitForNotification waits for a PostgreSQL notification for up to timeout. 
-// If the timeout occurs it returns pgx.ErrNotificationTimeout -func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error) { +// WaitForNotification waits for a PostgreSQL notification. +func (c *Conn) WaitForNotification(ctx context.Context) (notification *Notification, err error) { // Return already received notification immediately if len(c.notifications) > 0 { notification := c.notifications[0] @@ -799,86 +938,40 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error) return notification, nil } - stopTime := time.Now().Add(timeout) - - for { - now := time.Now() - - if now.After(stopTime) { - return nil, ErrNotificationTimeout - } - - // If there has been no activity on this connection for a while send a nop message just to ensure - // the connection is alive - nextEnsureAliveTime := c.lastActivityTime.Add(15 * time.Second) - if nextEnsureAliveTime.Before(now) { - // If the server can't respond to a nop in 15 seconds, assume it's dead - err := c.conn.SetReadDeadline(now.Add(15 * time.Second)) - if err != nil { - return nil, err - } - - _, err = c.Exec("--;") - if err != nil { - return nil, err - } - - c.lastActivityTime = now - } - - var deadline time.Time - if stopTime.Before(nextEnsureAliveTime) { - deadline = stopTime - } else { - deadline = nextEnsureAliveTime - } - - notification, err := c.waitForNotification(deadline) - if err != ErrNotificationTimeout { - return notification, err - } + err = c.waitForPreviousCancelQuery(ctx) + if err != nil { + return nil, err } -} -func (c *Conn) waitForNotification(deadline time.Time) (*Notification, error) { - var zeroTime time.Time + err = c.initContext(ctx) + if err != nil { + return nil, err + } + defer func() { + err = c.termContext(err) + }() + + if err = c.lock(); err != nil { + return nil, err + } + defer func() { + if unlockErr := c.unlock(); unlockErr != nil && err == nil { + err = unlockErr + } + }() + + if err := c.ensureConnectionReadyForQuery(); err != nil 
{ + return nil, err + } for { - // Use SetReadDeadline to implement the timeout. SetReadDeadline will - // cause operations to fail with a *net.OpError that has a Timeout() - // of true. Because the normal pgx rxMsg path considers any error to - // have potentially corrupted the state of the connection, it dies - // on any errors. So to avoid timeout errors in rxMsg we set the - // deadline and peek into the reader. If a timeout error occurs there - // we don't break the pgx connection. If the Peek returns that data - // is available then we turn off the read deadline before the rxMsg. - err := c.conn.SetReadDeadline(deadline) + msg, err := c.rxMsg() if err != nil { return nil, err } - // Wait until there is a byte available before continuing onto the normal msg reading path - _, err = c.reader.Peek(1) + err = c.processContextFreeMsg(msg) if err != nil { - c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline - if err, ok := err.(*net.OpError); ok && err.Timeout() { - return nil, ErrNotificationTimeout - } - return nil, err - } - - err = c.conn.SetReadDeadline(zeroTime) - if err != nil { - return nil, err - } - - var t byte - var r *msgReader - if t, r, err = c.rxMsg(); err == nil { - if err = c.processContextFreeMsg(t, r); err != nil { - return nil, err - } - } else { return nil, err } @@ -891,10 +984,14 @@ func (c *Conn) waitForNotification(deadline time.Time) (*Notification, error) { } func (c *Conn) IsAlive() bool { - return c.alive + c.mux.Lock() + defer c.mux.Unlock() + return c.status >= connStatusIdle } func (c *Conn) CauseOfDeath() error { + c.mux.Lock() + defer c.mux.Unlock() return c.causeOfDeath } @@ -906,17 +1003,19 @@ func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) { } func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error { + if err := c.ensureConnectionReadyForQuery(); err != nil { + return err + } if len(args) == 0 { - wbuf 
:= newWriteBuf(c, 'Q') - wbuf.WriteCString(sql) - wbuf.closeMsg() + buf := appendQuery(c.wbuf, sql) - _, err := c.conn.Write(wbuf.buf) + _, err := c.conn.Write(buf) if err != nil { c.die(err) return err } + c.pendingReadyForQueryCount++ return nil } @@ -930,168 +1029,105 @@ func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error { } func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}) (err error) { - if len(ps.ParameterOids) != len(arguments) { - return fmt.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOids), len(arguments)) + if len(ps.ParameterOIDs) != len(arguments) { + return errors.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOIDs), len(arguments)) } - // bind - wbuf := newWriteBuf(c, 'B') - wbuf.WriteByte(0) - wbuf.WriteCString(ps.Name) - - wbuf.WriteInt16(int16(len(ps.ParameterOids))) - for i, oid := range ps.ParameterOids { - switch arg := arguments[i].(type) { - case Encoder: - wbuf.WriteInt16(arg.FormatCode()) - case string, *string: - wbuf.WriteInt16(TextFormatCode) - default: - switch oid { - case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, ByteaArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid, InetArrayOid, CidrArrayOid, RecordOid, JsonOid, JsonbOid: - wbuf.WriteInt16(BinaryFormatCode) - default: - wbuf.WriteInt16(TextFormatCode) - } - } + if err := c.ensureConnectionReadyForQuery(); err != nil { + return err } - wbuf.WriteInt16(int16(len(arguments))) - for i, oid := range ps.ParameterOids { - if err := Encode(wbuf, oid, arguments[i]); err != nil { - return err - } + resultFormatCodes := make([]int16, len(ps.FieldDescriptions)) + for i, fd := range ps.FieldDescriptions { + 
resultFormatCodes[i] = fd.FormatCode } - - wbuf.WriteInt16(int16(len(ps.FieldDescriptions))) - for _, fd := range ps.FieldDescriptions { - wbuf.WriteInt16(fd.FormatCode) - } - - // execute - wbuf.startMsg('E') - wbuf.WriteByte(0) - wbuf.WriteInt32(0) - - // sync - wbuf.startMsg('S') - wbuf.closeMsg() - - _, err = c.conn.Write(wbuf.buf) + buf, err := appendBind(c.wbuf, "", ps.Name, c.ConnInfo, ps.ParameterOIDs, arguments, resultFormatCodes) if err != nil { - c.die(err) + return err } - return err + buf = appendExecute(buf, "", 0) + buf = appendSync(buf) + + n, err := c.conn.Write(buf) + if err != nil { + if fatalWriteErr(n, err) { + c.die(err) + } + return err + } + c.pendingReadyForQueryCount++ + + return nil +} + +// fatalWriteError takes the response of a net.Conn.Write and determines if it is fatal +func fatalWriteErr(bytesWritten int, err error) bool { + // Partial writes break the connection + if bytesWritten > 0 { + return true + } + + netErr, is := err.(net.Error) + return !(is && netErr.Timeout()) } // Exec executes sql. sql can be either a prepared statement name or an SQL string. // arguments should be referenced positionally from the sql string as $1, $2, etc. 
func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) { - if err = c.lock(); err != nil { - return commandTag, err - } - - startTime := time.Now() - c.lastActivityTime = startTime - - defer func() { - if err == nil { - if c.shouldLog(LogLevelInfo) { - endTime := time.Now() - c.log(LogLevelInfo, "Exec", "sql", sql, "args", logQueryArgs(arguments), "time", endTime.Sub(startTime), "commandTag", commandTag) - } - } else { - if c.shouldLog(LogLevelError) { - c.log(LogLevelError, "Exec", "sql", sql, "args", logQueryArgs(arguments), "error", err) - } - } - - if unlockErr := c.unlock(); unlockErr != nil && err == nil { - err = unlockErr - } - }() - - if err = c.sendQuery(sql, arguments...); err != nil { - return - } - - var softErr error - - for { - var t byte - var r *msgReader - t, r, err = c.rxMsg() - if err != nil { - return commandTag, err - } - - switch t { - case readyForQuery: - c.rxReadyForQuery(r) - return commandTag, softErr - case rowDescription: - case dataRow: - case bindComplete: - case commandComplete: - commandTag = CommandTag(r.readCString()) - default: - if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil { - softErr = e - } - } - } + return c.ExecEx(context.Background(), sql, nil, arguments...) } // Processes messages that are not exclusive to one context such as -// authentication or query response. The response to these messages -// is the same regardless of when they occur. -func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) { - switch t { - case 'S': - c.rxParameterStatus(r) - return nil - case errorResponse: - return c.rxErrorResponse(r) - case noticeResponse: - return nil - case emptyQueryResponse: - return nil - case notificationResponse: - c.rxNotificationResponse(r) - return nil - default: - return fmt.Errorf("Received unknown message type: %c", t) +// authentication or query response. The response to these messages is the same +// regardless of when they occur. 
It also ignores messages that are only +// meaningful in a given context. These messages can occur due to a context +// deadline interrupting message processing. For example, an interrupted query +// may have left DataRow messages on the wire. +func (c *Conn) processContextFreeMsg(msg pgproto3.BackendMessage) (err error) { + switch msg := msg.(type) { + case *pgproto3.ErrorResponse: + return c.rxErrorResponse(msg) + case *pgproto3.NoticeResponse: + c.rxNoticeResponse(msg) + case *pgproto3.NotificationResponse: + c.rxNotificationResponse(msg) + case *pgproto3.ReadyForQuery: + c.rxReadyForQuery(msg) + case *pgproto3.ParameterStatus: + c.rxParameterStatus(msg) } + + return nil } -func (c *Conn) rxMsg() (t byte, r *msgReader, err error) { - if !c.alive { - return 0, nil, ErrDeadConn +func (c *Conn) rxMsg() (pgproto3.BackendMessage, error) { + if !c.IsAlive() { + return nil, ErrDeadConn } - t, err = c.mr.rxMsg() + msg, err := c.frontend.Receive() if err != nil { - c.die(err) + if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { + c.die(err) + } + return nil, err } c.lastActivityTime = time.Now() - if c.shouldLog(LogLevelTrace) { - c.log(LogLevelTrace, "rxMsg", "type", string(t), "msgBytesRemaining", c.mr.msgBytesRemaining) - } + // fmt.Printf("rxMsg: %#v\n", msg) - return t, &c.mr, err + return msg, nil } -func (c *Conn) rxAuthenticationX(r *msgReader) (err error) { - switch r.readInt32() { - case 0: // AuthenticationOk - case 3: // AuthenticationCleartextPassword +func (c *Conn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { + switch msg.Type { + case pgproto3.AuthTypeOk: + case pgproto3.AuthTypeCleartextPassword: err = c.txPasswordMessage(c.config.Password) - case 5: // AuthenticationMD5Password - salt := r.readString(4) - digestedPassword := "md5" + hexMD5(hexMD5(c.config.Password+c.config.User)+salt) + case pgproto3.AuthTypeMD5Password: + digestedPassword := "md5" + hexMD5(hexMD5(c.config.Password+c.config.User)+string(msg.Salt[:])) err = 
c.txPasswordMessage(digestedPassword) default: err = errors.New("Received unknown authentication message") @@ -1106,114 +1142,103 @@ func hexMD5(s string) string { return hex.EncodeToString(hash.Sum(nil)) } -func (c *Conn) rxParameterStatus(r *msgReader) { - key := r.readCString() - value := r.readCString() - c.RuntimeParams[key] = value +func (c *Conn) rxParameterStatus(msg *pgproto3.ParameterStatus) { + c.RuntimeParams[msg.Name] = msg.Value } -func (c *Conn) rxErrorResponse(r *msgReader) (err PgError) { - for { - switch r.readByte() { - case 'S': - err.Severity = r.readCString() - case 'C': - err.Code = r.readCString() - case 'M': - err.Message = r.readCString() - case 'D': - err.Detail = r.readCString() - case 'H': - err.Hint = r.readCString() - case 'P': - s := r.readCString() - n, _ := strconv.ParseInt(s, 10, 32) - err.Position = int32(n) - case 'p': - s := r.readCString() - n, _ := strconv.ParseInt(s, 10, 32) - err.InternalPosition = int32(n) - case 'q': - err.InternalQuery = r.readCString() - case 'W': - err.Where = r.readCString() - case 's': - err.SchemaName = r.readCString() - case 't': - err.TableName = r.readCString() - case 'c': - err.ColumnName = r.readCString() - case 'd': - err.DataTypeName = r.readCString() - case 'n': - err.ConstraintName = r.readCString() - case 'F': - err.File = r.readCString() - case 'L': - s := r.readCString() - n, _ := strconv.ParseInt(s, 10, 32) - err.Line = int32(n) - case 'R': - err.Routine = r.readCString() - - case 0: // End of error message - if err.Severity == "FATAL" { - c.die(err) - } - return - default: // Ignore other error fields - r.readCString() - } +func (c *Conn) rxErrorResponse(msg *pgproto3.ErrorResponse) PgError { + err := PgError{ + Severity: msg.Severity, + Code: msg.Code, + Message: msg.Message, + Detail: msg.Detail, + Hint: msg.Hint, + Position: msg.Position, + InternalPosition: msg.InternalPosition, + InternalQuery: msg.InternalQuery, + Where: msg.Where, + SchemaName: msg.SchemaName, + TableName: 
msg.TableName, + ColumnName: msg.ColumnName, + DataTypeName: msg.DataTypeName, + ConstraintName: msg.ConstraintName, + File: msg.File, + Line: msg.Line, + Routine: msg.Routine, } -} -func (c *Conn) rxBackendKeyData(r *msgReader) { - c.Pid = r.readInt32() - c.SecretKey = r.readInt32() -} - -func (c *Conn) rxReadyForQuery(r *msgReader) { - c.TxStatus = r.readByte() -} - -func (c *Conn) rxRowDescription(r *msgReader) (fields []FieldDescription) { - fieldCount := r.readInt16() - fields = make([]FieldDescription, fieldCount) - for i := int16(0); i < fieldCount; i++ { - f := &fields[i] - f.Name = r.readCString() - f.Table = r.readOid() - f.AttributeNumber = r.readInt16() - f.DataType = r.readOid() - f.DataTypeSize = r.readInt16() - f.Modifier = r.readInt32() - f.FormatCode = r.readInt16() + if err.Severity == "FATAL" { + c.die(err) } - return + + return err } -func (c *Conn) rxParameterDescription(r *msgReader) (parameters []Oid) { - // Internally, PostgreSQL supports greater than 64k parameters to a prepared - // statement. But the parameter description uses a 16-bit integer for the - // count of parameters. If there are more than 64K parameters, this count is - // wrong. So read the count, ignore it, and compute the proper value from - // the size of the message. 
- r.readInt16() - parameterCount := r.msgBytesRemaining / 4 - - parameters = make([]Oid, 0, parameterCount) - - for i := int32(0); i < parameterCount; i++ { - parameters = append(parameters, r.readOid()) +func (c *Conn) rxNoticeResponse(msg *pgproto3.NoticeResponse) { + if c.onNotice == nil { + return } - return + + notice := &Notice{ + Severity: msg.Severity, + Code: msg.Code, + Message: msg.Message, + Detail: msg.Detail, + Hint: msg.Hint, + Position: msg.Position, + InternalPosition: msg.InternalPosition, + InternalQuery: msg.InternalQuery, + Where: msg.Where, + SchemaName: msg.SchemaName, + TableName: msg.TableName, + ColumnName: msg.ColumnName, + DataTypeName: msg.DataTypeName, + ConstraintName: msg.ConstraintName, + File: msg.File, + Line: msg.Line, + Routine: msg.Routine, + } + + c.onNotice(c, notice) } -func (c *Conn) rxNotificationResponse(r *msgReader) { +func (c *Conn) rxBackendKeyData(msg *pgproto3.BackendKeyData) { + c.pid = msg.ProcessID + c.secretKey = msg.SecretKey +} + +func (c *Conn) rxReadyForQuery(msg *pgproto3.ReadyForQuery) { + c.pendingReadyForQueryCount-- + c.txStatus = msg.TxStatus +} + +func (c *Conn) rxRowDescription(msg *pgproto3.RowDescription) []FieldDescription { + fields := make([]FieldDescription, len(msg.Fields)) + for i := 0; i < len(fields); i++ { + fields[i].Name = msg.Fields[i].Name + fields[i].Table = pgtype.OID(msg.Fields[i].TableOID) + fields[i].AttributeNumber = msg.Fields[i].TableAttributeNumber + fields[i].DataType = pgtype.OID(msg.Fields[i].DataTypeOID) + fields[i].DataTypeSize = msg.Fields[i].DataTypeSize + fields[i].Modifier = msg.Fields[i].TypeModifier + fields[i].FormatCode = msg.Fields[i].Format + } + return fields +} + +func (c *Conn) rxParameterDescription(msg *pgproto3.ParameterDescription) []pgtype.OID { + parameters := make([]pgtype.OID, len(msg.ParameterOIDs)) + for i := 0; i < len(parameters); i++ { + parameters[i] = pgtype.OID(msg.ParameterOIDs[i]) + } + return parameters +} + +func (c *Conn) 
rxNotificationResponse(msg *pgproto3.NotificationResponse) { n := new(Notification) - n.Pid = r.readInt32() - n.Channel = r.readCString() - n.Payload = r.readCString() + n.PID = msg.PID + n.Channel = msg.Channel + n.Payload = msg.Payload c.notifications = append(c.notifications, n) } @@ -1237,40 +1262,54 @@ func (c *Conn) startTLS(tlsConfig *tls.Config) (err error) { return nil } -func (c *Conn) txStartupMessage(msg *startupMessage) error { - _, err := c.conn.Write(msg.Bytes()) - return err -} - func (c *Conn) txPasswordMessage(password string) (err error) { - wbuf := newWriteBuf(c, 'p') - wbuf.WriteCString(password) - wbuf.closeMsg() + buf := c.wbuf + buf = append(buf, 'p') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, password...) + buf = append(buf, 0) + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - _, err = c.conn.Write(wbuf.buf) + _, err = c.conn.Write(buf) return err } func (c *Conn) die(err error) { - c.alive = false + c.mux.Lock() + defer c.mux.Unlock() + + if c.status == connStatusClosed { + return + } + + c.status = connStatusClosed c.causeOfDeath = err c.conn.Close() } func (c *Conn) lock() error { - if c.busy { + c.mux.Lock() + defer c.mux.Unlock() + + if c.status != connStatusIdle { return ErrConnBusy } - c.busy = true + + c.status = connStatusBusy return nil } func (c *Conn) unlock() error { - if !c.busy { + c.mux.Lock() + defer c.mux.Unlock() + + if c.status != connStatusBusy { return errors.New("unlock conn that is not busy") } - c.busy = false + + c.status = connStatusIdle return nil } @@ -1278,23 +1317,15 @@ func (c *Conn) shouldLog(lvl int) bool { return c.logger != nil && c.logLevel >= lvl } -func (c *Conn) log(lvl int, msg string, ctx ...interface{}) { - if c.Pid != 0 { - ctx = append(ctx, "pid", c.Pid) +func (c *Conn) log(lvl LogLevel, msg string, data map[string]interface{}) { + if data == nil { + data = map[string]interface{}{} + } + if c.pid != 0 { + data["pid"] = c.PID } - switch lvl { - case LogLevelTrace: - 
c.logger.Debug(msg, ctx...) - case LogLevelDebug: - c.logger.Debug(msg, ctx...) - case LogLevelInfo: - c.logger.Info(msg, ctx...) - case LogLevelWarn: - c.logger.Warn(msg, ctx...) - case LogLevelError: - c.logger.Error(msg, ctx...) - } + c.logger.Log(lvl, msg, data) } // SetLogger replaces the current logger and returns the previous logger. @@ -1320,3 +1351,278 @@ func (c *Conn) SetLogLevel(lvl int) (int, error) { func quoteIdentifier(s string) string { return `"` + strings.Replace(s, `"`, `""`, -1) + `"` } + +// cancelQuery sends a cancel request to the PostgreSQL server. It returns an +// error if unable to deliver the cancel request, but lack of an error does not +// ensure that the query was canceled. As specified in the documentation, there +// is no way to be sure a query was canceled. See +// https://www.postgresql.org/docs/current/static/protocol-flow.html#AEN112861 +func (c *Conn) cancelQuery() { + if !atomic.CompareAndSwapInt32(&c.cancelQueryInProgress, 0, 1) { + panic("cancelQuery when cancelQueryInProgress") + } + + if err := c.conn.SetDeadline(time.Now()); err != nil { + c.Close() // Close connection if unable to set deadline + return + } + + doCancel := func() error { + network, address := c.config.networkAddress() + cancelConn, err := c.config.Dial(network, address) + if err != nil { + return err + } + defer cancelConn.Close() + + // If server doesn't process cancellation request in bounded time then abort. 
+ err = cancelConn.SetDeadline(time.Now().Add(15 * time.Second)) + if err != nil { + return err + } + + buf := make([]byte, 16) + binary.BigEndian.PutUint32(buf[0:4], 16) + binary.BigEndian.PutUint32(buf[4:8], 80877102) + binary.BigEndian.PutUint32(buf[8:12], uint32(c.pid)) + binary.BigEndian.PutUint32(buf[12:16], uint32(c.secretKey)) + _, err = cancelConn.Write(buf) + if err != nil { + return err + } + + _, err = cancelConn.Read(buf) + if err != io.EOF { + return errors.Errorf("Server failed to close connection after cancel query request: %v %v", err, buf) + } + + return nil + } + + go func() { + err := doCancel() + if err != nil { + c.Close() // Something is very wrong. Terminate the connection. + } + c.cancelQueryCompleted <- struct{}{} + }() +} + +func (c *Conn) Ping(ctx context.Context) error { + _, err := c.ExecEx(ctx, ";", nil) + return err +} + +func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (CommandTag, error) { + err := c.waitForPreviousCancelQuery(ctx) + if err != nil { + return "", err + } + + if err := c.lock(); err != nil { + return "", err + } + defer c.unlock() + + startTime := time.Now() + c.lastActivityTime = startTime + + commandTag, err := c.execEx(ctx, sql, options, arguments...) 
+ if err != nil { + if c.shouldLog(LogLevelError) { + c.log(LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err}) + } + return commandTag, err + } + + if c.shouldLog(LogLevelInfo) { + endTime := time.Now() + c.log(LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag}) + } + + return commandTag, err +} + +func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) { + err = c.initContext(ctx) + if err != nil { + return "", err + } + defer func() { + err = c.termContext(err) + }() + + if options != nil && options.SimpleProtocol { + err = c.sanitizeAndSendSimpleQuery(sql, arguments...) + if err != nil { + return "", err + } + } else if options != nil && len(options.ParameterOIDs) > 0 { + buf, err := c.buildOneRoundTripExec(c.wbuf, sql, options, arguments) + if err != nil { + return "", err + } + + buf = appendSync(buf) + + n, err := c.conn.Write(buf) + if err != nil && fatalWriteErr(n, err) { + c.die(err) + return "", err + } + c.pendingReadyForQueryCount++ + } else { + if len(arguments) > 0 { + ps, ok := c.preparedStatements[sql] + if !ok { + var err error + ps, err = c.prepareEx("", sql, nil) + if err != nil { + return "", err + } + } + + err = c.sendPreparedQuery(ps, arguments...) 
+ if err != nil { + return "", err + } + } else { + if err = c.sendQuery(sql, arguments...); err != nil { + return + } + } + } + + var softErr error + + for { + msg, err := c.rxMsg() + if err != nil { + return commandTag, err + } + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + c.rxReadyForQuery(msg) + return commandTag, softErr + case *pgproto3.CommandComplete: + commandTag = CommandTag(msg.CommandTag) + default: + if e := c.processContextFreeMsg(msg); e != nil && softErr == nil { + softErr = e + } + } + } +} + +func (c *Conn) buildOneRoundTripExec(buf []byte, sql string, options *QueryExOptions, arguments []interface{}) ([]byte, error) { + if len(arguments) != len(options.ParameterOIDs) { + return nil, errors.Errorf("mismatched number of arguments (%d) and options.ParameterOIDs (%d)", len(arguments), len(options.ParameterOIDs)) + } + + if len(options.ParameterOIDs) > 65535 { + return nil, errors.Errorf("Number of QueryExOptions ParameterOIDs must be between 0 and 65535, received %d", len(options.ParameterOIDs)) + } + + buf = appendParse(buf, "", sql, options.ParameterOIDs) + buf, err := appendBind(buf, "", "", c.ConnInfo, options.ParameterOIDs, arguments, nil) + if err != nil { + return nil, err + } + buf = appendExecute(buf, "", 0) + + return buf, nil +} + +func (c *Conn) initContext(ctx context.Context) error { + if c.ctxInProgress { + return errors.New("ctx already in progress") + } + + if ctx.Done() == nil { + return nil + } + + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + c.ctxInProgress = true + + go c.contextHandler(ctx) + + return nil +} + +func (c *Conn) termContext(opErr error) error { + if !c.ctxInProgress { + return opErr + } + + var err error + + select { + case err = <-c.closedChan: + if opErr == nil { + err = nil + } + case c.doneChan <- struct{}{}: + err = opErr + } + + c.ctxInProgress = false + return err +} + +func (c *Conn) contextHandler(ctx context.Context) { + select { + case <-ctx.Done(): + 
c.cancelQuery() + c.closedChan <- ctx.Err() + case <-c.doneChan: + } +} + +func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error { + if atomic.LoadInt32(&c.cancelQueryInProgress) == 0 { + return nil + } + + select { + case <-c.cancelQueryCompleted: + atomic.StoreInt32(&c.cancelQueryInProgress, 0) + if err := c.conn.SetDeadline(time.Time{}); err != nil { + c.Close() // Close connection if unable to disable deadline + return err + } + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (c *Conn) ensureConnectionReadyForQuery() error { + for c.pendingReadyForQueryCount > 0 { + msg, err := c.rxMsg() + if err != nil { + return err + } + + switch msg := msg.(type) { + case *pgproto3.ErrorResponse: + pgErr := c.rxErrorResponse(msg) + if pgErr.Severity == "FATAL" { + return pgErr + } + default: + err = c.processContextFreeMsg(msg) + if err != nil { + return err + } + } + } + + return nil +} diff --git a/conn_config_test.go.example b/conn_config_test.go.example index cac798b7..4f6a5e5a 100644 --- a/conn_config_test.go.example +++ b/conn_config_test.go.example @@ -23,3 +23,5 @@ var replicationConnConfig *pgx.ConnConfig = nil // var invalidUserConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"} // var tlsConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}} // var customDialerConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} +// var replicationConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_replication", Password: "secret", Database: "pgx_test"} + diff --git a/conn_pool.go b/conn_pool.go index 1913699e..5fa923b7 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -1,9 +1,13 @@ package pgx import ( - "errors" + "context" "sync" "time" + + "github.com/pkg/errors" + + 
"github.com/jackc/pgx/pgtype" ) type ConnPoolConfig struct { @@ -27,11 +31,7 @@ type ConnPool struct { closed bool preparedStatements map[string]*PreparedStatement acquireTimeout time.Duration - pgTypes map[Oid]PgType - pgsqlAfInet *byte - pgsqlAfInet6 *byte - txAfterClose func(tx *Tx) - rowsAfterClose func(rows *Rows) + connInfo *pgtype.ConnInfo } type ConnPoolStat struct { @@ -48,6 +48,7 @@ var ErrAcquireTimeout = errors.New("timeout acquiring connection from pool") func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) { p = new(ConnPool) p.config = config.ConnConfig + p.connInfo = minimalConnInfo p.maxConnections = config.MaxConnections if p.maxConnections == 0 { p.maxConnections = 5 @@ -73,14 +74,6 @@ func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) { p.logLevel = LogLevelNone } - p.txAfterClose = func(tx *Tx) { - p.Release(tx.Conn()) - } - - p.rowsAfterClose = func(rows *Rows) { - p.Release(rows.Conn()) - } - p.allConnections = make([]*Conn, 0, p.maxConnections) p.availableConnections = make([]*Conn, 0, p.maxConnections) p.preparedStatements = make(map[string]*PreparedStatement) @@ -94,6 +87,7 @@ func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) { } p.allConnections = append(p.allConnections, c) p.availableConnections = append(p.availableConnections, c) + p.connInfo = c.ConnInfo.DeepCopy() return } @@ -161,7 +155,7 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) { } // All connections are in use and we cannot create more if p.logLevel >= LogLevelWarn { - p.logger.Warn("All connections in pool are busy - waiting...") + p.logger.Log(LogLevelWarn, "waiting for available connection", nil) } // Wait until there is an available connection OR room to create a new connection @@ -181,7 +175,11 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) { // Release gives up use of a connection. 
func (p *ConnPool) Release(conn *Conn) { - if conn.TxStatus != 'I' { + if conn.ctxInProgress { + panic("should never release when context is in progress") + } + + if conn.txStatus != 'I' { conn.Exec("rollback") } @@ -223,25 +221,21 @@ func (p *ConnPool) removeFromAllConnections(conn *Conn) bool { return false } -// Close ends the use of a connection pool. It prevents any new connections -// from being acquired, waits until all acquired connections are released, -// then closes all underlying connections. +// Close ends the use of a connection pool. It prevents any new connections from +// being acquired and closes available underlying connections. Any acquired +// connections will be closed when they are released. func (p *ConnPool) Close() { p.cond.L.Lock() defer p.cond.L.Unlock() p.closed = true - // Wait until all connections are released - if len(p.availableConnections) != len(p.allConnections) { - for len(p.availableConnections) != len(p.allConnections) { - p.cond.Wait() - } - } - - for _, c := range p.allConnections { + for _, c := range p.availableConnections { _ = c.Close() } + + // This will cause any checked out connections to be closed on release + p.resetCount++ } // Reset closes all open connections, but leaves the pool open. It is intended @@ -289,7 +283,7 @@ func (p *ConnPool) Stat() (s ConnPoolStat) { } func (p *ConnPool) createConnection() (*Conn, error) { - c, err := connect(p.config, p.pgTypes, p.pgsqlAfInet, p.pgsqlAfInet6) + c, err := connect(p.config, p.connInfo) if err != nil { return nil, err } @@ -324,10 +318,6 @@ func (p *ConnPool) createConnectionUnlocked() (*Conn, error) { // afterConnectionCreated executes (if it is) afterConnect() callback and prepares // all the known statements for the new connection. 
func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) { - p.pgTypes = c.PgTypes - p.pgsqlAfInet = c.pgsqlAfInet - p.pgsqlAfInet6 = c.pgsqlAfInet6 - if p.afterConnect != nil { err := p.afterConnect(c) if err != nil { @@ -357,6 +347,16 @@ func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag Comman return c.Exec(sql, arguments...) } +func (p *ConnPool) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) { + var c *Conn + if c, err = p.Acquire(); err != nil { + return + } + defer p.Release(c) + + return c.ExecEx(ctx, sql, options, arguments...) +} + // Query acquires a connection and delegates the call to that connection. When // *Rows are closed, the connection is released automatically. func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) { @@ -372,7 +372,25 @@ func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) { return rows, err } - rows.AfterClose(p.rowsAfterClose) + rows.connPool = p + + return rows, nil +} + +func (p *ConnPool) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (*Rows, error) { + c, err := p.Acquire() + if err != nil { + // Because checking for errors can be deferred to the *Rows, build one with the error + return &Rows{closed: true, err: err}, err + } + + rows, err := c.QueryEx(ctx, sql, options, args...) + if err != nil { + p.Release(c) + return rows, err + } + + rows.connPool = p return rows, nil } @@ -385,10 +403,15 @@ func (p *ConnPool) QueryRow(sql string, args ...interface{}) *Row { return (*Row)(rows) } +func (p *ConnPool) QueryRowEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) *Row { + rows, _ := p.QueryEx(ctx, sql, options, args...) + return (*Row)(rows) +} + // Begin acquires a connection and begins a transaction on it. When the // transaction is closed the connection will be automatically released. 
func (p *ConnPool) Begin() (*Tx, error) { - return p.BeginIso("") + return p.BeginEx(context.Background(), nil) } // Prepare creates a prepared statement on a connection in the pool to test the @@ -403,7 +426,7 @@ func (p *ConnPool) Begin() (*Tx, error) { // the same name and sql arguments. This allows a code path to Prepare and // Query/Exec/PrepareEx without concern for if the statement has already been prepared. func (p *ConnPool) Prepare(name, sql string) (*PreparedStatement, error) { - return p.PrepareEx(name, sql, nil) + return p.PrepareEx(context.Background(), name, sql, nil) } // PrepareEx creates a prepared statement on a connection in the pool to test the @@ -417,7 +440,7 @@ func (p *ConnPool) Prepare(name, sql string) (*PreparedStatement, error) { // PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same // name and sql arguments. This allows a code path to PrepareEx and Query/Exec/Prepare without // concern for if the statement has already been prepared. -func (p *ConnPool) PrepareEx(name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) { +func (p *ConnPool) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) { p.cond.L.Lock() defer p.cond.L.Unlock() @@ -439,13 +462,13 @@ func (p *ConnPool) PrepareEx(name, sql string, opts *PrepareExOptions) (*Prepare return ps, nil } - ps, err := c.PrepareEx(name, sql, opts) + ps, err := c.PrepareEx(ctx, name, sql, opts) if err != nil { return nil, err } for _, c := range p.availableConnections { - _, err := c.PrepareEx(name, sql, opts) + _, err := c.PrepareEx(ctx, name, sql, opts) if err != nil { return nil, err } @@ -474,17 +497,17 @@ func (p *ConnPool) Deallocate(name string) (err error) { return nil } -// BeginIso acquires a connection and begins a transaction in isolation mode iso -// on it. When the transaction is closed the connection will be automatically -// released. 
-func (p *ConnPool) BeginIso(iso string) (*Tx, error) { +// BeginEx acquires a connection and starts a transaction with txOptions +// determining the transaction mode. When the transaction is closed the +// connection will be automatically released. +func (p *ConnPool) BeginEx(ctx context.Context, txOptions *TxOptions) (*Tx, error) { for { c, err := p.Acquire() if err != nil { return nil, err } - tx, err := c.BeginIso(iso) + tx, err := c.BeginEx(ctx, txOptions) if err != nil { alive := c.IsAlive() p.Release(c) @@ -499,25 +522,12 @@ func (p *ConnPool) BeginIso(iso string) (*Tx, error) { continue } - tx.AfterClose(p.txAfterClose) + tx.connPool = p return tx, nil } } -// Deprecated. Use CopyFrom instead. CopyTo acquires a connection, delegates the -// call to that connection, and releases the connection. -func (p *ConnPool) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) { - c, err := p.Acquire() - if err != nil { - return 0, err - } - defer p.Release(c) - - return c.CopyTo(tableName, columnNames, rowSrc) -} - -// CopyFrom acquires a connection, delegates the call to that connection, and -// releases the connection. +// CopyFrom acquires a connection, delegates the call to that connection, and releases the connection func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) { c, err := p.Acquire() if err != nil { @@ -527,3 +537,10 @@ func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc C return c.CopyFrom(tableName, columnNames, rowSrc) } + +// BeginBatch acquires a connection and begins a batch on that connection. When +// *Batch is finished, the connection is released automatically. 
+func (p *ConnPool) BeginBatch() *Batch { + c, err := p.Acquire() + return &Batch{conn: c, connPool: p, err: err} +} diff --git a/conn_pool_test.go b/conn_pool_test.go index ab76bfb7..ccc38ba9 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -1,13 +1,15 @@ package pgx_test import ( - "errors" + "context" "fmt" "net" "sync" "testing" "time" + "github.com/pkg/errors" + "github.com/jackc/pgx" ) @@ -297,7 +299,7 @@ func TestPoolWithoutAcquireTimeoutSet(t *testing.T) { // ... then try to consume 1 more. It should hang forever. // To unblock it we release the previously taken connection in a goroutine. stopDeadWaitTimeout := 5 * time.Second - timer := time.AfterFunc(stopDeadWaitTimeout, func() { + timer := time.AfterFunc(stopDeadWaitTimeout+100*time.Millisecond, func() { releaseAllConnections(pool, allConnections) }) defer timer.Stop() @@ -328,14 +330,14 @@ func TestPoolReleaseWithTransactions(t *testing.T) { t.Fatal("Did not receive expected error") } - if conn.TxStatus != 'E' { - t.Fatalf("Expected TxStatus to be 'E', instead it was '%c'", conn.TxStatus) + if conn.TxStatus() != 'E' { + t.Fatalf("Expected TxStatus to be 'E', instead it was '%c'", conn.TxStatus()) } pool.Release(conn) - if conn.TxStatus != 'I' { - t.Fatalf("Expected release to rollback errored transaction, but it did not: '%c'", conn.TxStatus) + if conn.TxStatus() != 'I' { + t.Fatalf("Expected release to rollback errored transaction, but it did not: '%c'", conn.TxStatus()) } conn, err = pool.Acquire() @@ -343,14 +345,14 @@ func TestPoolReleaseWithTransactions(t *testing.T) { t.Fatalf("Unable to acquire connection: %v", err) } mustExec(t, conn, "begin") - if conn.TxStatus != 'T' { - t.Fatalf("Expected txStatus to be 'T', instead it was '%c'", conn.TxStatus) + if conn.TxStatus() != 'T' { + t.Fatalf("Expected txStatus to be 'T', instead it was '%c'", conn.TxStatus()) } pool.Release(conn) - if conn.TxStatus != 'I' { - t.Fatalf("Expected release to rollback uncommitted transaction, but it did not: 
'%c'", conn.TxStatus) + if conn.TxStatus() != 'I' { + t.Fatalf("Expected release to rollback uncommitted transaction, but it did not: '%c'", conn.TxStatus()) + } } @@ -428,7 +430,7 @@ func TestPoolReleaseDiscardsDeadConnections(t *testing.T) { } }() - if _, err = c2.Exec("select pg_terminate_backend($1)", c1.Pid); err != nil { + if _, err = c2.Exec("select pg_terminate_backend($1)", c1.PID()); err != nil { t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) } @@ -635,9 +637,9 @@ func TestConnPoolTransactionIso(t *testing.T) { pool := createConnPool(t, 2) defer pool.Close() - tx, err := pool.BeginIso(pgx.Serializable) + tx, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable}) if err != nil { - t.Fatalf("pool.Begin failed: %v", err) + t.Fatalf("pool.BeginEx failed: %v", err) } defer tx.Rollback() @@ -674,7 +676,7 @@ func TestConnPoolBeginRetry(t *testing.T) { pool.Release(victimConn) // Terminate connection that was released to pool - if _, err = killerConn.Exec("select pg_terminate_backend($1)", victimConn.Pid); err != nil { + if _, err = killerConn.Exec("select pg_terminate_backend($1)", victimConn.PID()); err != nil { t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) } @@ -686,13 +688,13 @@ } defer tx.Rollback() - var txPid int32 - err = tx.QueryRow("select pg_backend_pid()").Scan(&txPid) + var txPID uint32 + err = tx.QueryRow("select pg_backend_pid()").Scan(&txPID) if err != nil { t.Fatalf("tx.QueryRow Scan failed: %v", err) } - if txPid == victimConn.Pid { - t.Error("Expected txPid to defer from killed conn pid, but it didn't") + if txPID == victimConn.PID() { + t.Error("Expected txPID to differ from killed conn pid, but it didn't") + } }() } @@ -980,3 +982,70 @@ func TestConnPoolPrepareWhenConnIsAlreadyAcquired(t *testing.T) { t.Errorf("Expected error calling deallocated prepared statement, but got: %v", err) } } + +func TestConnPoolBeginBatch(t *testing.T) { + 
t.Parallel() + + pool := createConnPool(t, 2) + defer pool.Close() + + batch := pool.BeginBatch() + batch.Queue("select n from generate_series(0,5) n", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + batch.Queue("select n from generate_series(0,5) n", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + + err := batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + rows, err := batch.QueryResults() + if err != nil { + t.Error(err) + } + + for i := 0; rows.Next(); i++ { + var n int + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + if n != i { + t.Errorf("n => %v, want %v", n, i) + } + } + + if rows.Err() != nil { + t.Error(rows.Err()) + } + + rows, err = batch.QueryResults() + if err != nil { + t.Error(err) + } + + for i := 0; rows.Next(); i++ { + var n int + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + if n != i { + t.Errorf("n => %v, want %v", n, i) + } + } + + if rows.Err() != nil { + t.Error(rows.Err()) + } + + err = batch.Close() + if err != nil { + t.Fatal(err) + } +} diff --git a/conn_test.go b/conn_test.go index cfb99561..d9369a1a 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1,6 +1,7 @@ package pgx_test import ( + "context" "crypto/tls" "fmt" "net" @@ -13,6 +14,7 @@ import ( "time" "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" ) func TestConnect(t *testing.T) { @@ -27,14 +29,10 @@ func TestConnect(t *testing.T) { t.Error("Runtime parameters not stored") } - if conn.Pid == 0 { + if conn.PID() == 0 { t.Error("Backend PID not stored") } - if conn.SecretKey == 0 { - t.Error("Backend secret key not stored") - } - var currentDB string err = conn.QueryRow("select current_database()").Scan(¤tDB) if err != nil { @@ -752,7 +750,7 @@ func TestParseConnectionString(t *testing.T) { } func TestParseEnvLibpq(t *testing.T) { - pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME"} + pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", 
"PGAPPNAME", "PGSSLMODE"} savedEnv := make(map[string]string) for _, n := range pgEnvvars { @@ -1035,6 +1033,169 @@ func TestExecFailure(t *testing.T) { } } +func TestExecExContextWithoutCancelation(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + commandTag, err := conn.ExecEx(ctx, "create temporary table foo(id integer primary key);", nil) + if err != nil { + t.Fatal(err) + } + if commandTag != "CREATE TABLE" { + t.Fatalf("Unexpected results from ExecEx: %v", commandTag) + } +} + +func TestExecExContextFailureWithoutCancelation(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + if _, err := conn.ExecEx(ctx, "selct;", nil); err == nil { + t.Fatal("Expected SQL syntax error") + } + + rows, _ := conn.Query("select 1") + rows.Close() + if rows.Err() != nil { + t.Fatalf("ExecEx failure appears to have broken connection: %v", rows.Err()) + } +} + +func TestExecExContextCancelationCancelsQuery(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + time.Sleep(500 * time.Millisecond) + cancelFunc() + }() + + _, err := conn.ExecEx(ctx, "select pg_sleep(60)", nil) + if err != context.Canceled { + t.Fatalf("Expected context.Canceled err, got %v", err) + } + + ensureConnValid(t, conn) +} + +func TestExecExExtendedProtocol(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + commandTag, err := conn.ExecEx(ctx, "create temporary table foo(name varchar primary key);", nil) + if err != nil { + t.Fatal(err) + } + if commandTag != 
"CREATE TABLE" { + t.Fatalf("Unexpected results from ExecEx: %v", commandTag) + } + + commandTag, err = conn.ExecEx( + ctx, + "insert into foo(name) values($1);", + nil, + "bar", + ) + if err != nil { + t.Fatal(err) + } + if commandTag != "INSERT 0 1" { + t.Fatalf("Unexpected results from ExecEx: %v", commandTag) + } + + ensureConnValid(t, conn) +} + +func TestExecExSimpleProtocol(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + commandTag, err := conn.ExecEx(ctx, "create temporary table foo(name varchar primary key);", nil) + if err != nil { + t.Fatal(err) + } + if commandTag != "CREATE TABLE" { + t.Fatalf("Unexpected results from ExecEx: %v", commandTag) + } + + commandTag, err = conn.ExecEx( + ctx, + "insert into foo(name) values($1);", + &pgx.QueryExOptions{SimpleProtocol: true}, + "bar'; drop table foo;--", + ) + if err != nil { + t.Fatal(err) + } + if commandTag != "INSERT 0 1" { + t.Fatalf("Unexpected results from ExecEx: %v", commandTag) + } +} + +func TestConnExecExSuppliedCorrectParameterOIDs(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, "create temporary table foo(name varchar primary key);") + + commandTag, err := conn.ExecEx( + context.Background(), + "insert into foo(name) values($1);", + &pgx.QueryExOptions{ParameterOIDs: []pgtype.OID{pgtype.VarcharOID}}, + "bar'; drop table foo;--", + ) + if err != nil { + t.Fatal(err) + } + if commandTag != "INSERT 0 1" { + t.Fatalf("Unexpected results from ExecEx: %v", commandTag) + } +} + +func TestConnExecExSuppliedIncorrectParameterOIDs(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, "create temporary table foo(name varchar primary key);") + + _, err := conn.ExecEx( + context.Background(), + "insert into 
foo(name) values($1);", + &pgx.QueryExOptions{ParameterOIDs: []pgtype.OID{pgtype.Int4OID}}, + "bar'; drop table foo;--", + ) + if err == nil { + t.Fatal("expected error but got none") + } +} + func TestPrepare(t *testing.T) { t.Parallel() @@ -1206,7 +1367,7 @@ func TestPrepareEx(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - _, err := conn.PrepareEx("test", "select $1", &pgx.PrepareExOptions{ParameterOids: []pgx.Oid{pgx.TextOid}}) + _, err := conn.PrepareEx(context.Background(), "test", "select $1", &pgx.PrepareExOptions{ParameterOIDs: []pgtype.OID{pgtype.TextOID}}) if err != nil { t.Errorf("Unable to prepare statement: %v", err) return @@ -1244,7 +1405,7 @@ func TestListenNotify(t *testing.T) { mustExec(t, notifier, "notify chat") // when notification is waiting on the socket to be read - notification, err := listener.WaitForNotification(time.Second) + notification, err := listener.WaitForNotification(context.Background()) if err != nil { t.Fatalf("Unexpected error on WaitForNotification: %v", err) } @@ -1259,7 +1420,10 @@ func TestListenNotify(t *testing.T) { if rows.Err() != nil { t.Fatalf("Unexpected error on Query: %v", rows.Err()) } - notification, err = listener.WaitForNotification(0) + + ctx, cancelFn := context.WithCancel(context.Background()) + cancelFn() + notification, err = listener.WaitForNotification(ctx) if err != nil { t.Fatalf("Unexpected error on WaitForNotification: %v", err) } @@ -1268,8 +1432,9 @@ func TestListenNotify(t *testing.T) { } // when timeout occurs - notification, err = listener.WaitForNotification(time.Millisecond) - if err != pgx.ErrNotificationTimeout { + ctx, _ = context.WithTimeout(context.Background(), time.Millisecond) + notification, err = listener.WaitForNotification(ctx) + if err != context.DeadlineExceeded { t.Errorf("WaitForNotification returned the wrong kind of error: %v", err) } if notification != nil { @@ -1278,7 +1443,7 @@ func TestListenNotify(t *testing.T) { // listener can 
listen again after a timeout mustExec(t, notifier, "notify chat") - notification, err = listener.WaitForNotification(time.Second) + notification, err = listener.WaitForNotification(context.Background()) if err != nil { t.Fatalf("Unexpected error on WaitForNotification: %v", err) } @@ -1303,7 +1468,7 @@ func TestUnlistenSpecificChannel(t *testing.T) { mustExec(t, notifier, "notify unlisten_test") // when notification is waiting on the socket to be read - notification, err := listener.WaitForNotification(time.Second) + notification, err := listener.WaitForNotification(context.Background()) if err != nil { t.Fatalf("Unexpected error on WaitForNotification: %v", err) } @@ -1323,8 +1488,10 @@ func TestUnlistenSpecificChannel(t *testing.T) { if rows.Err() != nil { t.Fatalf("Unexpected error on Query: %v", rows.Err()) } - notification, err = listener.WaitForNotification(100 * time.Millisecond) - if err != pgx.ErrNotificationTimeout { + + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + notification, err = listener.WaitForNotification(ctx) + if err != context.DeadlineExceeded { t.Errorf("WaitForNotification returned the wrong kind of error: %v", err) } } @@ -1376,13 +1543,9 @@ func TestListenNotifyWhileBusyIsSafe(t *testing.T) { } }() - notifierDone := make(chan bool) go func() { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - defer func() { - notifierDone <- true - }() for i := 0; i < 100000; i++ { mustExec(t, conn, "notify busysafe, 'hello'") @@ -1406,7 +1569,8 @@ func TestListenNotifySelfNotification(t *testing.T) { // Notify self and WaitForNotification immediately mustExec(t, conn, "notify self") - notification, err := conn.WaitForNotification(time.Second) + ctx, _ := context.WithTimeout(context.Background(), time.Second) + notification, err := conn.WaitForNotification(ctx) if err != nil { t.Fatalf("Unexpected error on WaitForNotification: %v", err) } @@ -1423,7 +1587,8 @@ func TestListenNotifySelfNotification(t 
*testing.T) { t.Fatalf("Unexpected error on Query: %v", rows.Err()) } - notification, err = conn.WaitForNotification(time.Second) + ctx, _ = context.WithTimeout(context.Background(), time.Second) + notification, err = conn.WaitForNotification(ctx) if err != nil { t.Fatalf("Unexpected error on WaitForNotification: %v", err) } @@ -1474,7 +1639,7 @@ func TestFatalRxError(t *testing.T) { } defer otherConn.Close() - if _, err := otherConn.Exec("select pg_terminate_backend($1)", conn.Pid); err != nil { + if _, err := otherConn.Exec("select pg_terminate_backend($1)", conn.PID()); err != nil { t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) } @@ -1500,7 +1665,7 @@ func TestFatalTxError(t *testing.T) { } defer otherConn.Close() - _, err = otherConn.Exec("select pg_terminate_backend($1)", conn.Pid) + _, err = otherConn.Exec("select pg_terminate_backend($1)", conn.PID()) if err != nil { t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) } @@ -1611,26 +1776,17 @@ func TestCatchSimultaneousConnectionQueryAndExec(t *testing.T) { } type testLog struct { - lvl int - msg string - ctx []interface{} + lvl pgx.LogLevel + msg string + data map[string]interface{} } type testLogger struct { logs []testLog } -func (l *testLogger) Debug(msg string, ctx ...interface{}) { - l.logs = append(l.logs, testLog{lvl: pgx.LogLevelDebug, msg: msg, ctx: ctx}) -} -func (l *testLogger) Info(msg string, ctx ...interface{}) { - l.logs = append(l.logs, testLog{lvl: pgx.LogLevelInfo, msg: msg, ctx: ctx}) -} -func (l *testLogger) Warn(msg string, ctx ...interface{}) { - l.logs = append(l.logs, testLog{lvl: pgx.LogLevelWarn, msg: msg, ctx: ctx}) -} -func (l *testLogger) Error(msg string, ctx ...interface{}) { - l.logs = append(l.logs, testLog{lvl: pgx.LogLevelError, msg: msg, ctx: ctx}) +func (l *testLogger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) { + l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: data}) } func TestSetLogger(t *testing.T) { @@ 
-1742,3 +1898,30 @@ func TestIdentifierSanitize(t *testing.T) { } } } + +func TestConnOnNotice(t *testing.T) { + t.Parallel() + + var msg string + + connConfig := *defaultConnConfig + connConfig.OnNotice = func(c *pgx.Conn, notice *pgx.Notice) { + msg = notice.Message + } + conn := mustConnect(t, connConfig) + defer closeConn(t, conn) + + _, err := conn.Exec(`do $$ +begin + raise notice 'hello, world'; +end$$;`) + if err != nil { + t.Fatal(err) + } + + if msg != "hello, world" { + t.Errorf("msg => %v, want %v", msg, "hello, world") + } + + ensureConnValid(t, conn) +} diff --git a/copy_from.go b/copy_from.go index 1f8a2306..8b7c3d5b 100644 --- a/copy_from.go +++ b/copy_from.go @@ -3,6 +3,10 @@ package pgx import ( "bytes" "fmt" + + "github.com/jackc/pgx/pgio" + "github.com/jackc/pgx/pgproto3" + "github.com/pkg/errors" ) // CopyFromRows returns a CopyFromSource interface over the provided rows slice @@ -54,25 +58,25 @@ type copyFrom struct { func (ct *copyFrom) readUntilReadyForQuery() { for { - t, r, err := ct.conn.rxMsg() + msg, err := ct.conn.rxMsg() if err != nil { ct.readerErrChan <- err close(ct.readerErrChan) return } - switch t { - case readyForQuery: - ct.conn.rxReadyForQuery(r) + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + ct.conn.rxReadyForQuery(msg) close(ct.readerErrChan) return - case commandComplete: - case errorResponse: - ct.readerErrChan <- ct.conn.rxErrorResponse(r) + case *pgproto3.CommandComplete: + case *pgproto3.ErrorResponse: + ct.readerErrChan <- ct.conn.rxErrorResponse(msg) default: - err = ct.conn.processContextFreeMsg(t, r) + err = ct.conn.processContextFreeMsg(msg) if err != nil { - ct.readerErrChan <- ct.conn.processContextFreeMsg(t, r) + ct.readerErrChan <- ct.conn.processContextFreeMsg(msg) } } } @@ -87,14 +91,14 @@ func (ct *copyFrom) waitForReaderDone() error { func (ct *copyFrom) run() (int, error) { quotedTableName := ct.tableName.Sanitize() - buf := &bytes.Buffer{} + cbuf := &bytes.Buffer{} for i, cn := range 
ct.columnNames { if i != 0 { - buf.WriteString(", ") + cbuf.WriteString(", ") } - buf.WriteString(quoteIdentifier(cn)) + cbuf.WriteString(quoteIdentifier(cn)) } - quotedColumnNames := buf.String() + quotedColumnNames := cbuf.String() ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName)) if err != nil { @@ -114,11 +118,14 @@ func (ct *copyFrom) run() (int, error) { go ct.readUntilReadyForQuery() defer ct.waitForReaderDone() - wbuf := newWriteBuf(ct.conn, copyData) + buf := ct.conn.wbuf + buf = append(buf, copyData) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - wbuf.WriteBytes([]byte("PGCOPY\n\377\r\n\000")) - wbuf.WriteInt32(0) - wbuf.WriteInt32(0) + buf = append(buf, "PGCOPY\n\377\r\n\000"...) + buf = pgio.AppendInt32(buf, 0) + buf = pgio.AppendInt32(buf, 0) var sentCount int @@ -129,18 +136,16 @@ func (ct *copyFrom) run() (int, error) { default: } - if len(wbuf.buf) > 65536 { - wbuf.closeMsg() - _, err = ct.conn.conn.Write(wbuf.buf) + if len(buf) > 65536 { + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + _, err = ct.conn.conn.Write(buf) if err != nil { ct.conn.die(err) return 0, err } // Directly manipulate wbuf to reset to reuse the same buffer - wbuf.buf = wbuf.buf[0:5] - wbuf.buf[0] = copyData - wbuf.sizeIdx = 1 + buf = buf[0:5] } sentCount++ @@ -152,12 +157,12 @@ func (ct *copyFrom) run() (int, error) { } if len(values) != len(ct.columnNames) { ct.cancelCopyIn() - return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) + return 0, errors.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) } - wbuf.WriteInt16(int16(len(ct.columnNames))) + buf = pgio.AppendInt16(buf, int16(len(ct.columnNames))) for i, val := range values { - err = Encode(wbuf, ps.FieldDescriptions[i].DataType, val) + buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.FieldDescriptions[i].DataType, val) if err != nil { ct.cancelCopyIn() return 0, err @@ -171,11 
+176,13 @@ func (ct *copyFrom) run() (int, error) { return 0, ct.rowSrc.Err() } - wbuf.WriteInt16(-1) // terminate the copy stream + buf = pgio.AppendInt16(buf, -1) // terminate the copy stream + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - wbuf.startMsg(copyDone) - wbuf.closeMsg() - _, err = ct.conn.conn.Write(wbuf.buf) + buf = append(buf, copyDone) + buf = pgio.AppendInt32(buf, 4) + + _, err = ct.conn.conn.Write(buf) if err != nil { ct.conn.die(err) return 0, err @@ -190,18 +197,16 @@ func (ct *copyFrom) run() (int, error) { func (c *Conn) readUntilCopyInResponse() error { for { - var t byte - var r *msgReader - t, r, err := c.rxMsg() + msg, err := c.rxMsg() if err != nil { return err } - switch t { - case copyInResponse: + switch msg := msg.(type) { + case *pgproto3.CopyInResponse: return nil default: - err = c.processContextFreeMsg(t, r) + err = c.processContextFreeMsg(msg) if err != nil { return err } @@ -210,10 +215,15 @@ func (c *Conn) readUntilCopyInResponse() error { } func (ct *copyFrom) cancelCopyIn() error { - wbuf := newWriteBuf(ct.conn, copyFail) - wbuf.WriteCString("client error: abort") - wbuf.closeMsg() - _, err := ct.conn.conn.Write(wbuf.buf) + buf := ct.conn.wbuf + buf = append(buf, copyFail) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, "client error: abort"...) 
+ buf = append(buf, 0) + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + _, err := ct.conn.conn.Write(buf) if err != nil { ct.conn.die(err) return err diff --git a/copy_from_test.go b/copy_from_test.go index 54da6989..ec674855 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -1,12 +1,12 @@ package pgx_test import ( - "fmt" "reflect" "testing" "time" "github.com/jackc/pgx" + "github.com/pkg/errors" ) func TestConnCopyFromSmall(t *testing.T) { @@ -26,7 +26,7 @@ func TestConnCopyFromSmall(t *testing.T) { )`) inputRows := [][]interface{}{ - {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)}, + {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)}, {nil, nil, nil, nil, nil, nil, nil}, } @@ -83,7 +83,7 @@ func TestConnCopyFromLarge(t *testing.T) { inputRows := [][]interface{}{} for i := 0; i < 10000; i++ { - inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}}) + inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}}) } copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows)) @@ -125,8 +125,8 @@ func TestConnCopyFromJSON(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - for _, oid := range []pgx.Oid{pgx.JsonOid, pgx.JsonbOid} { - if _, ok := conn.PgTypes[oid]; !ok { + for _, typeName := range []string{"json", "jsonb"} { + if _, ok := conn.ConnInfo.DataTypeForName(typeName); !ok { return // No JSON/JSONB type -- must be running against old PostgreSQL } } @@ -174,6 +174,28 @@ 
func TestConnCopyFromJSON(t *testing.T) { ensureConnValid(t, conn) } +type clientFailSource struct { + count int + err error +} + +func (cfs *clientFailSource) Next() bool { + cfs.count++ + return cfs.count < 100 +} + +func (cfs *clientFailSource) Values() ([]interface{}, error) { + if cfs.count == 3 { + cfs.err = errors.Errorf("client error") + return nil, cfs.err + } + return []interface{}{make([]byte, 100000)}, nil +} + +func (cfs *clientFailSource) Err() error { + return cfs.err +} + func TestConnCopyFromFailServerSideMidway(t *testing.T) { t.Parallel() @@ -302,28 +324,6 @@ func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) { ensureConnValid(t, conn) } -type clientFailSource struct { - count int - err error -} - -func (cfs *clientFailSource) Next() bool { - cfs.count++ - return cfs.count < 100 -} - -func (cfs *clientFailSource) Values() ([]interface{}, error) { - if cfs.count == 3 { - cfs.err = fmt.Errorf("client error") - return nil, cfs.err - } - return []interface{}{make([]byte, 100000)}, nil -} - -func (cfs *clientFailSource) Err() error { - return cfs.err -} - func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) { t.Parallel() @@ -381,7 +381,7 @@ func (cfs *clientFinalErrSource) Values() ([]interface{}, error) { } func (cfs *clientFinalErrSource) Err() error { - return fmt.Errorf("final error") + return errors.Errorf("final error") } func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) { diff --git a/copy_to.go b/copy_to.go deleted file mode 100644 index 229e9a43..00000000 --- a/copy_to.go +++ /dev/null @@ -1,222 +0,0 @@ -package pgx - -import ( - "bytes" - "fmt" -) - -// Deprecated. Use CopyFromRows instead. CopyToRows returns a CopyToSource -// interface over the provided rows slice making it usable by *Conn.CopyTo. 
-func CopyToRows(rows [][]interface{}) CopyToSource { - return ©ToRows{rows: rows, idx: -1} -} - -type copyToRows struct { - rows [][]interface{} - idx int -} - -func (ctr *copyToRows) Next() bool { - ctr.idx++ - return ctr.idx < len(ctr.rows) -} - -func (ctr *copyToRows) Values() ([]interface{}, error) { - return ctr.rows[ctr.idx], nil -} - -func (ctr *copyToRows) Err() error { - return nil -} - -// Deprecated. Use CopyFromSource instead. CopyToSource is the interface used by -// *Conn.CopyTo as the source for copy data. -type CopyToSource interface { - // Next returns true if there is another row and makes the next row data - // available to Values(). When there are no more rows available or an error - // has occurred it returns false. - Next() bool - - // Values returns the values for the current row. - Values() ([]interface{}, error) - - // Err returns any error that has been encountered by the CopyToSource. If - // this is not nil *Conn.CopyTo will abort the copy. - Err() error -} - -type copyTo struct { - conn *Conn - tableName string - columnNames []string - rowSrc CopyToSource - readerErrChan chan error -} - -func (ct *copyTo) readUntilReadyForQuery() { - for { - t, r, err := ct.conn.rxMsg() - if err != nil { - ct.readerErrChan <- err - close(ct.readerErrChan) - return - } - - switch t { - case readyForQuery: - ct.conn.rxReadyForQuery(r) - close(ct.readerErrChan) - return - case commandComplete: - case errorResponse: - ct.readerErrChan <- ct.conn.rxErrorResponse(r) - default: - err = ct.conn.processContextFreeMsg(t, r) - if err != nil { - ct.readerErrChan <- ct.conn.processContextFreeMsg(t, r) - } - } - } -} - -func (ct *copyTo) waitForReaderDone() error { - var err error - for err = range ct.readerErrChan { - } - return err -} - -func (ct *copyTo) run() (int, error) { - quotedTableName := quoteIdentifier(ct.tableName) - buf := &bytes.Buffer{} - for i, cn := range ct.columnNames { - if i != 0 { - buf.WriteString(", ") - } - 
buf.WriteString(quoteIdentifier(cn)) - } - quotedColumnNames := buf.String() - - ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName)) - if err != nil { - return 0, err - } - - err = ct.conn.sendSimpleQuery(fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames)) - if err != nil { - return 0, err - } - - err = ct.conn.readUntilCopyInResponse() - if err != nil { - return 0, err - } - - go ct.readUntilReadyForQuery() - defer ct.waitForReaderDone() - - wbuf := newWriteBuf(ct.conn, copyData) - - wbuf.WriteBytes([]byte("PGCOPY\n\377\r\n\000")) - wbuf.WriteInt32(0) - wbuf.WriteInt32(0) - - var sentCount int - - for ct.rowSrc.Next() { - select { - case err = <-ct.readerErrChan: - return 0, err - default: - } - - if len(wbuf.buf) > 65536 { - wbuf.closeMsg() - _, err = ct.conn.conn.Write(wbuf.buf) - if err != nil { - ct.conn.die(err) - return 0, err - } - - // Directly manipulate wbuf to reset to reuse the same buffer - wbuf.buf = wbuf.buf[0:5] - wbuf.buf[0] = copyData - wbuf.sizeIdx = 1 - } - - sentCount++ - - values, err := ct.rowSrc.Values() - if err != nil { - ct.cancelCopyIn() - return 0, err - } - if len(values) != len(ct.columnNames) { - ct.cancelCopyIn() - return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) - } - - wbuf.WriteInt16(int16(len(ct.columnNames))) - for i, val := range values { - err = Encode(wbuf, ps.FieldDescriptions[i].DataType, val) - if err != nil { - ct.cancelCopyIn() - return 0, err - } - - } - } - - if ct.rowSrc.Err() != nil { - ct.cancelCopyIn() - return 0, ct.rowSrc.Err() - } - - wbuf.WriteInt16(-1) // terminate the copy stream - - wbuf.startMsg(copyDone) - wbuf.closeMsg() - _, err = ct.conn.conn.Write(wbuf.buf) - if err != nil { - ct.conn.die(err) - return 0, err - } - - err = ct.waitForReaderDone() - if err != nil { - return 0, err - } - return sentCount, nil -} - -func (ct *copyTo) cancelCopyIn() error { - wbuf := 
newWriteBuf(ct.conn, copyFail) - wbuf.WriteCString("client error: abort") - wbuf.closeMsg() - _, err := ct.conn.conn.Write(wbuf.buf) - if err != nil { - ct.conn.die(err) - return err - } - - return nil -} - -// Deprecated. Use CopyFrom instead. CopyTo uses the PostgreSQL copy protocol to -// perform bulk data insertion. It returns the number of rows copied and an -// error. -// -// CopyTo requires all values use the binary format. Almost all types -// implemented by pgx use the binary format by default. Types implementing -// Encoder can only be used if they encode to the binary format. -func (c *Conn) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) { - ct := ©To{ - conn: c, - tableName: tableName, - columnNames: columnNames, - rowSrc: rowSrc, - readerErrChan: make(chan error), - } - - return ct.run() -} diff --git a/copy_to_test.go b/copy_to_test.go deleted file mode 100644 index ac270426..00000000 --- a/copy_to_test.go +++ /dev/null @@ -1,367 +0,0 @@ -package pgx_test - -import ( - "reflect" - "testing" - "time" - - "github.com/jackc/pgx" -) - -func TestConnCopyToSmall(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - mustExec(t, conn, `create temporary table foo( - a int2, - b int4, - c int8, - d varchar, - e text, - f date, - g timestamptz - )`) - - inputRows := [][]interface{}{ - {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)}, - {nil, nil, nil, nil, nil, nil, nil}, - } - - copyCount, err := conn.CopyTo("foo", []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyToRows(inputRows)) - if err != nil { - t.Errorf("Unexpected error for CopyTo: %v", err) - } - if copyCount != len(inputRows) { - t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount) - } - - rows, err := conn.Query("select * from foo") - if err != nil { - t.Errorf("Unexpected error for 
Query: %v", err) - } - - var outputRows [][]interface{} - for rows.Next() { - row, err := rows.Values() - if err != nil { - t.Errorf("Unexpected error for rows.Values(): %v", err) - } - outputRows = append(outputRows, row) - } - - if rows.Err() != nil { - t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) - } - - if !reflect.DeepEqual(inputRows, outputRows) { - t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) - } - - ensureConnValid(t, conn) -} - -func TestConnCopyToLarge(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - mustExec(t, conn, `create temporary table foo( - a int2, - b int4, - c int8, - d varchar, - e text, - f date, - g timestamptz, - h bytea - )`) - - inputRows := [][]interface{}{} - - for i := 0; i < 10000; i++ { - inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}}) - } - - copyCount, err := conn.CopyTo("foo", []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyToRows(inputRows)) - if err != nil { - t.Errorf("Unexpected error for CopyTo: %v", err) - } - if copyCount != len(inputRows) { - t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount) - } - - rows, err := conn.Query("select * from foo") - if err != nil { - t.Errorf("Unexpected error for Query: %v", err) - } - - var outputRows [][]interface{} - for rows.Next() { - row, err := rows.Values() - if err != nil { - t.Errorf("Unexpected error for rows.Values(): %v", err) - } - outputRows = append(outputRows, row) - } - - if rows.Err() != nil { - t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) - } - - if !reflect.DeepEqual(inputRows, outputRows) { - t.Errorf("Input rows and output rows do not equal") - } - - ensureConnValid(t, conn) -} - -func TestConnCopyToJSON(t *testing.T) { - 
t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - for _, oid := range []pgx.Oid{pgx.JsonOid, pgx.JsonbOid} { - if _, ok := conn.PgTypes[oid]; !ok { - return // No JSON/JSONB type -- must be running against old PostgreSQL - } - } - - mustExec(t, conn, `create temporary table foo( - a json, - b jsonb - )`) - - inputRows := [][]interface{}{ - {map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}}, - {nil, nil}, - } - - copyCount, err := conn.CopyTo("foo", []string{"a", "b"}, pgx.CopyToRows(inputRows)) - if err != nil { - t.Errorf("Unexpected error for CopyTo: %v", err) - } - if copyCount != len(inputRows) { - t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount) - } - - rows, err := conn.Query("select * from foo") - if err != nil { - t.Errorf("Unexpected error for Query: %v", err) - } - - var outputRows [][]interface{} - for rows.Next() { - row, err := rows.Values() - if err != nil { - t.Errorf("Unexpected error for rows.Values(): %v", err) - } - outputRows = append(outputRows, row) - } - - if rows.Err() != nil { - t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) - } - - if !reflect.DeepEqual(inputRows, outputRows) { - t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) - } - - ensureConnValid(t, conn) -} - -func TestConnCopyToFailServerSideMidway(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - mustExec(t, conn, `create temporary table foo( - a int4, - b varchar not null - )`) - - inputRows := [][]interface{}{ - {int32(1), "abc"}, - {int32(2), nil}, // this row should trigger a failure - {int32(3), "def"}, - } - - copyCount, err := conn.CopyTo("foo", []string{"a", "b"}, pgx.CopyToRows(inputRows)) - if err == nil { - t.Errorf("Expected CopyTo return error, but it did not") - } - if _, ok := err.(pgx.PgError); !ok { - t.Errorf("Expected CopyTo return pgx.PgError, but 
instead it returned: %v", err) - } - if copyCount != 0 { - t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount) - } - - rows, err := conn.Query("select * from foo") - if err != nil { - t.Errorf("Unexpected error for Query: %v", err) - } - - var outputRows [][]interface{} - for rows.Next() { - row, err := rows.Values() - if err != nil { - t.Errorf("Unexpected error for rows.Values(): %v", err) - } - outputRows = append(outputRows, row) - } - - if rows.Err() != nil { - t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) - } - - if len(outputRows) != 0 { - t.Errorf("Expected 0 rows, but got %v", outputRows) - } - - ensureConnValid(t, conn) -} - -func TestConnCopyToFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - mustExec(t, conn, `create temporary table foo( - a bytea not null - )`) - - startTime := time.Now() - - copyCount, err := conn.CopyTo("foo", []string{"a"}, &failSource{}) - if err == nil { - t.Errorf("Expected CopyTo return error, but it did not") - } - if _, ok := err.(pgx.PgError); !ok { - t.Errorf("Expected CopyTo return pgx.PgError, but instead it returned: %v", err) - } - if copyCount != 0 { - t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount) - } - - endTime := time.Now() - copyTime := endTime.Sub(startTime) - if copyTime > time.Second { - t.Errorf("Failing CopyTo shouldn't have taken so long: %v", copyTime) - } - - rows, err := conn.Query("select * from foo") - if err != nil { - t.Errorf("Unexpected error for Query: %v", err) - } - - var outputRows [][]interface{} - for rows.Next() { - row, err := rows.Values() - if err != nil { - t.Errorf("Unexpected error for rows.Values(): %v", err) - } - outputRows = append(outputRows, row) - } - - if rows.Err() != nil { - t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) - } - - if len(outputRows) != 0 { - t.Errorf("Expected 0 rows, but got %v", 
outputRows) - } - - ensureConnValid(t, conn) -} - -func TestConnCopyToCopyToSourceErrorMidway(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - mustExec(t, conn, `create temporary table foo( - a bytea not null - )`) - - copyCount, err := conn.CopyTo("foo", []string{"a"}, &clientFailSource{}) - if err == nil { - t.Errorf("Expected CopyTo return error, but it did not") - } - if copyCount != 0 { - t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount) - } - - rows, err := conn.Query("select * from foo") - if err != nil { - t.Errorf("Unexpected error for Query: %v", err) - } - - var outputRows [][]interface{} - for rows.Next() { - row, err := rows.Values() - if err != nil { - t.Errorf("Unexpected error for rows.Values(): %v", err) - } - outputRows = append(outputRows, row) - } - - if rows.Err() != nil { - t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) - } - - if len(outputRows) != 0 { - t.Errorf("Expected 0 rows, but got %v", outputRows) - } - - ensureConnValid(t, conn) -} - -func TestConnCopyToCopyToSourceErrorEnd(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - mustExec(t, conn, `create temporary table foo( - a bytea not null - )`) - - copyCount, err := conn.CopyTo("foo", []string{"a"}, &clientFinalErrSource{}) - if err == nil { - t.Errorf("Expected CopyTo return error, but it did not") - } - if copyCount != 0 { - t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount) - } - - rows, err := conn.Query("select * from foo") - if err != nil { - t.Errorf("Unexpected error for Query: %v", err) - } - - var outputRows [][]interface{} - for rows.Next() { - row, err := rows.Values() - if err != nil { - t.Errorf("Unexpected error for rows.Values(): %v", err) - } - outputRows = append(outputRows, row) - } - - if rows.Err() != nil { - t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) - } - - if 
len(outputRows) != 0 { - t.Errorf("Expected 0 rows, but got %v", outputRows) - } - - ensureConnValid(t, conn) -} diff --git a/doc.go b/doc.go index 566c7254..c61329d9 100644 --- a/doc.go +++ b/doc.go @@ -62,17 +62,15 @@ Use Exec to execute a query that does not return a result set. Connection Pool -Connection pool usage is explicit and configurable. In pgx, a connection can -be created and managed directly, or a connection pool with a configurable -maximum connections can be used. Also, the connection pool offers an after -connect hook that allows every connection to be automatically setup before -being made available in the connection pool. This is especially useful to -ensure all connections have the same prepared statements available or to -change any other connection settings. +Connection pool usage is explicit and configurable. In pgx, a connection can be +created and managed directly, or a connection pool with a configurable maximum +connections can be used. The connection pool offers an after connect hook that +allows every connection to be automatically setup before being made available in +the connection pool. -It delegates Query, QueryRow, Exec, and Begin functions to an automatically -checked out and released connection so you can avoid manually acquiring and -releasing connections when you do not need that level of control. +It delegates methods such as QueryRow to an automatically checked out and +released connection so you can avoid manually acquiring and releasing +connections when you do not need that level of control. var name string var weight int64 @@ -117,11 +115,11 @@ particular: Null Mapping -pgx can map nulls in two ways. The first is Null* types that have a data field -and a valid field. They work in a similar fashion to database/sql. The second -is to use a pointer to a pointer. +pgx can map nulls in two ways. The first is package pgtype provides types that +have a data field and a status field. 
They work in a similar fashion to +database/sql. The second is to use a pointer to a pointer. - var foo pgx.NullString + var foo pgtype.Varchar var bar *string err := conn.QueryRow("select foo, bar from widgets where id=$1", 42).Scan(&a, &b) if err != nil { @@ -133,20 +131,15 @@ Array Mapping pgx maps between int16, int32, int64, float32, float64, and string Go slices and the equivalent PostgreSQL array type. Go slices of native types do not support nulls, so if a PostgreSQL array that contains a null is read into a -native Go slice an error will occur. - -Hstore Mapping - -pgx includes an Hstore type and a NullHstore type. Hstore is simply a -map[string]string and is preferred when the hstore contains no nulls. NullHstore -follows the Null* pattern and supports null values. +native Go slice an error will occur. The pgtype package includes many more +array types for PostgreSQL types that do not directly map to native Go types. JSON and JSONB Mapping pgx includes built-in support to marshal and unmarshal between Go types and the PostgreSQL JSON and JSONB. -Inet and Cidr Mapping +Inet and CIDR Mapping pgx encodes from net.IPNet to and from inet and cidr PostgreSQL types. In addition, as a convenience pgx will encode from a net.IP; it will assume a /32 @@ -155,25 +148,10 @@ netmask for IPv4 and a /128 for IPv6. Custom Type Support pgx includes support for the common data types like integers, floats, strings, -dates, and times that have direct mappings between Go and SQL. Support can be -added for additional types like point, hstore, numeric, etc. that do not have -direct mappings in Go by the types implementing ScannerPgx and Encoder. - -Custom types can support text or binary formats. Binary format can provide a -large performance increase. The natural place for deciding the format for a -value would be in ScannerPgx as it is responsible for decoding the returned -data. 
However, that is impossible as the query has already been sent by the time -the ScannerPgx is invoked. The solution to this is the global -DefaultTypeFormats. If a custom type prefers binary format it should register it -there. - - pgx.DefaultTypeFormats["point"] = pgx.BinaryFormatCode - -Note that the type is referred to by name, not by OID. This is because custom -PostgreSQL types like hstore will have different OIDs on different servers. When -pgx establishes a connection it queries the pg_type table for all types. It then -matches the names in DefaultTypeFormats with the returned OIDs and stores it in -Conn.PgTypes. +dates, and times that have direct mappings between Go and SQL. In addition, +pgx uses the github.com/jackc/pgx/pgtype library to support more types. See +documention for that library for instructions on how to implement custom +types. See example_custom_type_test.go for an example of a custom type for the PostgreSQL point type. @@ -184,15 +162,12 @@ and database/sql/driver.Valuer interfaces. Raw Bytes Mapping []byte passed as arguments to Query, QueryRow, and Exec are passed unmodified -to PostgreSQL. In like manner, a *[]byte passed to Scan will be filled with -the raw bytes returned by PostgreSQL. This can be especially useful for reading -varchar, text, json, and jsonb values directly into a []byte and avoiding the -type conversion from string. +to PostgreSQL. Transactions -Transactions are started by calling Begin or BeginIso. The BeginIso variant -creates a transaction with a specified isolation level. +Transactions are started by calling Begin or BeginEx. The BeginEx variant +can create a transaction with a specified isolation level. tx, err := conn.Begin() if err != nil { @@ -257,9 +232,8 @@ connection. Logging pgx defines a simple logger interface. Connections optionally accept a logger -that satisfies this interface. 
The log15 package -(http://gopkg.in/inconshreveable/log15.v2) satisfies this interface and it is -simple to define adapters for other loggers. Set LogLevel to control logging -verbosity. +that satisfies this interface. Set LogLevel to control logging verbosity. +Adapters for github.com/inconshreveable/log15, github.com/Sirupsen/logrus, and +the testing log are provided in the log directory. */ package pgx diff --git a/example_custom_type_test.go b/example_custom_type_test.go index 86b6cdf2..d3cc9085 100644 --- a/example_custom_type_test.go +++ b/example_custom_type_test.go @@ -1,81 +1,74 @@ package pgx_test import ( - "errors" "fmt" "regexp" "strconv" "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" + "github.com/pkg/errors" ) var pointRegexp *regexp.Regexp = regexp.MustCompile(`^\((.*),(.*)\)$`) -// NullPoint represents a point that may be null. -// -// If Valid is false then the value is NULL. -type NullPoint struct { - X, Y float64 // Coordinates of point - Valid bool // Valid is true if not NULL +// Point represents a point that may be null. 
+type Point struct { + X, Y float64 // Coordinates of point + Status pgtype.Status } -func (p *NullPoint) ScanPgx(vr *pgx.ValueReader) error { - if vr.Type().DataTypeName != "point" { - return pgx.SerializationError(fmt.Sprintf("NullPoint.Scan cannot decode %s (OID %d)", vr.Type().DataTypeName, vr.Type().DataType)) - } +func (dst *Point) Set(src interface{}) error { + return errors.Errorf("cannot convert %v to Point", src) +} - if vr.Len() == -1 { - p.X, p.Y, p.Valid = 0, 0, false +func (dst *Point) Get() interface{} { + switch dst.Status { + case pgtype.Present: + return dst + case pgtype.Null: return nil - } - - switch vr.Type().FormatCode { - case pgx.TextFormatCode: - s := vr.ReadString(vr.Len()) - match := pointRegexp.FindStringSubmatch(s) - if match == nil { - return pgx.SerializationError(fmt.Sprintf("Received invalid point: %v", s)) - } - - var err error - p.X, err = strconv.ParseFloat(match[1], 64) - if err != nil { - return pgx.SerializationError(fmt.Sprintf("Received invalid point: %v", s)) - } - p.Y, err = strconv.ParseFloat(match[2], 64) - if err != nil { - return pgx.SerializationError(fmt.Sprintf("Received invalid point: %v", s)) - } - case pgx.BinaryFormatCode: - return errors.New("binary format not implemented") default: - return fmt.Errorf("unknown format %v", vr.Type().FormatCode) + return dst.Status } - - p.Valid = true - return vr.Err() } -func (p NullPoint) FormatCode() int16 { return pgx.TextFormatCode } +func (src *Point) AssignTo(dst interface{}) error { + return errors.Errorf("cannot assign %v to %T", src, dst) +} -func (p NullPoint) Encode(w *pgx.WriteBuf, oid pgx.Oid) error { - if !p.Valid { - w.WriteInt32(-1) +func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + if src == nil { + *dst = Point{Status: pgtype.Null} return nil } - s := fmt.Sprintf("(%v,%v)", p.X, p.Y) - w.WriteInt32(int32(len(s))) - w.WriteBytes([]byte(s)) + s := string(src) + match := pointRegexp.FindStringSubmatch(s) + if match == nil { + return 
errors.Errorf("Received invalid point: %v", s) + } + + x, err := strconv.ParseFloat(match[1], 64) + if err != nil { + return errors.Errorf("Received invalid point: %v", s) + } + y, err := strconv.ParseFloat(match[2], 64) + if err != nil { + return errors.Errorf("Received invalid point: %v", s) + } + + *dst = Point{X: x, Y: y, Status: pgtype.Present} return nil } -func (p NullPoint) String() string { - if p.Valid { - return fmt.Sprintf("%v, %v", p.X, p.Y) +func (src *Point) String() string { + if src.Status == pgtype.Null { + return "null point" } - return "null point" + + return fmt.Sprintf("%.1f, %.1f", src.X, src.Y) } func Example_CustomType() { @@ -85,22 +78,22 @@ func Example_CustomType() { return } - var p NullPoint - err = conn.QueryRow("select null::point").Scan(&p) + // Override registered handler for point + conn.ConnInfo.RegisterDataType(pgtype.DataType{ + Value: &Point{}, + Name: "point", + OID: 600, + }) + + p := &Point{} + err = conn.QueryRow("select null::point").Scan(p) if err != nil { fmt.Println(err) return } fmt.Println(p) - err = conn.QueryRow("select point(1.5,2.5)").Scan(&p) - if err != nil { - fmt.Println(err) - return - } - fmt.Println(p) - - err = conn.QueryRow("select $1::point", &NullPoint{X: 0.5, Y: 0.75, Valid: true}).Scan(&p) + err = conn.QueryRow("select point(1.5,2.5)").Scan(p) if err != nil { fmt.Println(err) return @@ -109,5 +102,4 @@ func Example_CustomType() { // Output: // null point // 1.5, 2.5 - // 0.5, 0.75 } diff --git a/example_json_test.go b/example_json_test.go index c1534158..09e27cff 100644 --- a/example_json_test.go +++ b/example_json_test.go @@ -2,6 +2,7 @@ package pgx_test import ( "fmt" + "github.com/jackc/pgx" ) @@ -12,13 +13,6 @@ func Example_JSON() { return } - if _, ok := conn.PgTypes[pgx.JsonOid]; !ok { - // No JSON type -- must be running against very old PostgreSQL - // Pretend it works - fmt.Println("John", 42) - return - } - type person struct { Name string `json:"name"` Age int `json:"age"` diff --git 
a/examples/chat/main.go b/examples/chat/main.go index 517508cc..69ef456b 100644 --- a/examples/chat/main.go +++ b/examples/chat/main.go @@ -2,10 +2,11 @@ package main import ( "bufio" + "context" "fmt" - "github.com/jackc/pgx" "os" - "time" + + "github.com/jackc/pgx" ) var pool *pgx.ConnPool @@ -58,16 +59,13 @@ func listen() { conn.Listen("chat") for { - notification, err := conn.WaitForNotification(time.Second) - if err == pgx.ErrNotificationTimeout { - continue - } + notification, err := conn.WaitForNotification(context.Background()) if err != nil { fmt.Fprintln(os.Stderr, "Error waiting for notification:", err) os.Exit(1) } - fmt.Println("PID:", notification.Pid, "Channel:", notification.Channel, "Payload:", notification.Payload) + fmt.Println("PID:", notification.PID, "Channel:", notification.Channel, "Payload:", notification.Payload) } } diff --git a/examples/url_shortener/main.go b/examples/url_shortener/main.go index 695a5be6..c6576a3a 100644 --- a/examples/url_shortener/main.go +++ b/examples/url_shortener/main.go @@ -1,11 +1,13 @@ package main import ( - "github.com/jackc/pgx" - log "gopkg.in/inconshreveable/log15.v2" "io/ioutil" "net/http" "os" + + "github.com/jackc/pgx" + "github.com/jackc/pgx/log/log15adapter" + log "gopkg.in/inconshreveable/log15.v2" ) var pool *pgx.ConnPool @@ -89,6 +91,8 @@ func urlHandler(w http.ResponseWriter, req *http.Request) { } func main() { + logger := log15adapter.NewLogger(log.New("module", "pgx")) + var err error connPoolConfig := pgx.ConnPoolConfig{ ConnConfig: pgx.ConnConfig{ @@ -96,7 +100,7 @@ func main() { User: "jack", Password: "jack", Database: "url_shortener", - Logger: log.New("module", "pgx"), + Logger: logger, }, MaxConnections: 5, AfterConnect: afterConnect, diff --git a/fastpath.go b/fastpath.go index 19b98784..06e1354a 100644 --- a/fastpath.go +++ b/fastpath.go @@ -2,29 +2,33 @@ package pgx import ( "encoding/binary" + + "github.com/jackc/pgx/pgio" + "github.com/jackc/pgx/pgproto3" + 
"github.com/jackc/pgx/pgtype" ) func newFastpath(cn *Conn) *fastpath { - return &fastpath{cn: cn, fns: make(map[string]Oid)} + return &fastpath{cn: cn, fns: make(map[string]pgtype.OID)} } type fastpath struct { cn *Conn - fns map[string]Oid + fns map[string]pgtype.OID } -func (f *fastpath) functionOID(name string) Oid { +func (f *fastpath) functionOID(name string) pgtype.OID { return f.fns[name] } -func (f *fastpath) addFunction(name string, oid Oid) { +func (f *fastpath) addFunction(name string, oid pgtype.OID) { f.fns[name] = oid } func (f *fastpath) addFunctions(rows *Rows) error { for rows.Next() { var name string - var oid Oid + var oid pgtype.OID if err := rows.Scan(&name, &oid); err != nil { return err } @@ -47,41 +51,46 @@ func fpInt64Arg(n int64) fpArg { return res } -func (f *fastpath) Call(oid Oid, args []fpArg) (res []byte, err error) { - wbuf := newWriteBuf(f.cn, 'F') // function call - wbuf.WriteInt32(int32(oid)) // function object id - wbuf.WriteInt16(1) // # of argument format codes - wbuf.WriteInt16(1) // format code: binary - wbuf.WriteInt16(int16(len(args))) // # of arguments - for _, arg := range args { - wbuf.WriteInt32(int32(len(arg))) // length of argument - wbuf.WriteBytes(arg) // argument value +func (f *fastpath) Call(oid pgtype.OID, args []fpArg) (res []byte, err error) { + if err := f.cn.ensureConnectionReadyForQuery(); err != nil { + return nil, err } - wbuf.WriteInt16(1) // response format code (binary) - wbuf.closeMsg() - if _, err := f.cn.conn.Write(wbuf.buf); err != nil { + buf := f.cn.wbuf + buf = append(buf, 'F') // function call + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf = pgio.AppendInt32(buf, int32(oid)) // function object id + buf = pgio.AppendInt16(buf, 1) // # of argument format codes + buf = pgio.AppendInt16(buf, 1) // format code: binary + buf = pgio.AppendInt16(buf, int16(len(args))) // # of arguments + for _, arg := range args { + buf = pgio.AppendInt32(buf, int32(len(arg))) // length of argument + buf = 
append(buf, arg...) // argument value + } + buf = pgio.AppendInt16(buf, 1) // response format code (binary) + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + if _, err := f.cn.conn.Write(buf); err != nil { return nil, err } for { - var t byte - var r *msgReader - t, r, err = f.cn.rxMsg() + msg, err := f.cn.rxMsg() if err != nil { return nil, err } - switch t { - case 'V': // FunctionCallResponse - data := r.readBytes(r.readInt32()) - res = make([]byte, len(data)) - copy(res, data) - case 'Z': // Ready for query - f.cn.rxReadyForQuery(r) + switch msg := msg.(type) { + case *pgproto3.FunctionCallResponse: + res = make([]byte, len(msg.Result)) + copy(res, msg.Result) + case *pgproto3.ReadyForQuery: + f.cn.rxReadyForQuery(msg) // done - return + return res, err default: - if err := f.cn.processContextFreeMsg(t, r); err != nil { + if err := f.cn.processContextFreeMsg(msg); err != nil { return nil, err } } diff --git a/helper_test.go b/helper_test.go index eff731e8..78063107 100644 --- a/helper_test.go +++ b/helper_test.go @@ -1,8 +1,9 @@ package pgx_test import ( - "github.com/jackc/pgx" "testing" + + "github.com/jackc/pgx" ) func mustConnect(t testing.TB, config pgx.ConnConfig) *pgx.Conn { @@ -21,7 +22,6 @@ func mustReplicationConnect(t testing.TB, config pgx.ConnConfig) *pgx.Replicatio return conn } - func closeConn(t testing.TB, conn *pgx.Conn) { err := conn.Close() if err != nil { diff --git a/hstore.go b/hstore.go deleted file mode 100644 index 0ab9f779..00000000 --- a/hstore.go +++ /dev/null @@ -1,222 +0,0 @@ -package pgx - -import ( - "bytes" - "errors" - "fmt" - "unicode" - "unicode/utf8" -) - -const ( - hsPre = iota - hsKey - hsSep - hsVal - hsNul - hsNext -) - -type hstoreParser struct { - str string - pos int -} - -func newHSP(in string) *hstoreParser { - return &hstoreParser{ - pos: 0, - str: in, - } -} - -func (p *hstoreParser) Consume() (r rune, end bool) { - if p.pos >= len(p.str) { - end = true - return - } - r, w := 
utf8.DecodeRuneInString(p.str[p.pos:]) - p.pos += w - return -} - -func (p *hstoreParser) Peek() (r rune, end bool) { - if p.pos >= len(p.str) { - end = true - return - } - r, _ = utf8.DecodeRuneInString(p.str[p.pos:]) - return -} - -func parseHstoreToMap(s string) (m map[string]string, err error) { - keys, values, err := ParseHstore(s) - if err != nil { - return - } - m = make(map[string]string, len(keys)) - for i, key := range keys { - if !values[i].Valid { - err = fmt.Errorf("key '%s' has NULL value", key) - m = nil - return - } - m[key] = values[i].String - } - return -} - -func parseHstoreToNullHstore(s string) (store map[string]NullString, err error) { - keys, values, err := ParseHstore(s) - if err != nil { - return - } - - store = make(map[string]NullString, len(keys)) - - for i, key := range keys { - store[key] = values[i] - } - return -} - -// ParseHstore parses the string representation of an hstore column (the same -// you would get from an ordinary SELECT) into two slices of keys and values. 
it -// is used internally in the default parsing of hstores, but is exported for use -// in handling custom data structures backed by an hstore column without the -// overhead of creating a map[string]string -func ParseHstore(s string) (k []string, v []NullString, err error) { - if s == "" { - return - } - - buf := bytes.Buffer{} - keys := []string{} - values := []NullString{} - p := newHSP(s) - - r, end := p.Consume() - state := hsPre - - for !end { - switch state { - case hsPre: - if r == '"' { - state = hsKey - } else { - err = errors.New("String does not begin with \"") - } - case hsKey: - switch r { - case '"': //End of the key - if buf.Len() == 0 { - err = errors.New("Empty Key is invalid") - } else { - keys = append(keys, buf.String()) - buf = bytes.Buffer{} - state = hsSep - } - case '\\': //Potential escaped character - n, end := p.Consume() - switch { - case end: - err = errors.New("Found EOS in key, expecting character or \"") - case n == '"', n == '\\': - buf.WriteRune(n) - default: - buf.WriteRune(r) - buf.WriteRune(n) - } - default: //Any other character - buf.WriteRune(r) - } - case hsSep: - if r == '=' { - r, end = p.Consume() - switch { - case end: - err = errors.New("Found EOS after '=', expecting '>'") - case r == '>': - r, end = p.Consume() - switch { - case end: - err = errors.New("Found EOS after '=>', expecting '\"' or 'NULL'") - case r == '"': - state = hsVal - case r == 'N': - state = hsNul - default: - err = fmt.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r) - } - default: - err = fmt.Errorf("Invalid character after '=', expecting '>'") - } - } else { - err = fmt.Errorf("Invalid character '%c' after value, expecting '='", r) - } - case hsVal: - switch r { - case '"': //End of the value - values = append(values, NullString{String: buf.String(), Valid: true}) - buf = bytes.Buffer{} - state = hsNext - case '\\': //Potential escaped character - n, end := p.Consume() - switch { - case end: - err = errors.New("Found EOS 
in key, expecting character or \"") - case n == '"', n == '\\': - buf.WriteRune(n) - default: - buf.WriteRune(r) - buf.WriteRune(n) - } - default: //Any other character - buf.WriteRune(r) - } - case hsNul: - nulBuf := make([]rune, 3) - nulBuf[0] = r - for i := 1; i < 3; i++ { - r, end = p.Consume() - if end { - err = errors.New("Found EOS in NULL value") - return - } - nulBuf[i] = r - } - if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' { - values = append(values, NullString{String: "", Valid: false}) - state = hsNext - } else { - err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf)) - } - case hsNext: - if r == ',' { - r, end = p.Consume() - switch { - case end: - err = errors.New("Found EOS after ',', expcting space") - case (unicode.IsSpace(r)): - r, end = p.Consume() - state = hsKey - default: - err = fmt.Errorf("Invalid character '%c' after ', ', expecting \"", r) - } - } else { - err = fmt.Errorf("Invalid character '%c' after value, expecting ','", r) - } - } - - if err != nil { - return - } - r, end = p.Consume() - } - if state != hsNext { - err = errors.New("Improperly formatted hstore") - return - } - k = keys - v = values - return -} diff --git a/hstore_test.go b/hstore_test.go deleted file mode 100644 index c948f0cd..00000000 --- a/hstore_test.go +++ /dev/null @@ -1,181 +0,0 @@ -package pgx_test - -import ( - "github.com/jackc/pgx" - "testing" -) - -func TestHstoreTranscode(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - type test struct { - hstore pgx.Hstore - description string - } - - tests := []test{ - {pgx.Hstore{}, "empty"}, - {pgx.Hstore{"foo": "bar"}, "single key/value"}, - {pgx.Hstore{"foo": "bar", "baz": "quz"}, "multiple key/values"}, - {pgx.Hstore{"NULL": "bar"}, `string "NULL" key`}, - {pgx.Hstore{"foo": "NULL"}, `string "NULL" value`}, - } - - specialStringTests := []struct { - input string - description string - }{ - {`"`, `double quote (")`}, - {`'`, `single 
quote (')`}, - {`\`, `backslash (\)`}, - {`\\`, `multiple backslashes (\\)`}, - {`=>`, `separator (=>)`}, - {` `, `space`}, - {`\ / / \\ => " ' " '`, `multiple special characters`}, - } - for _, sst := range specialStringTests { - tests = append(tests, test{pgx.Hstore{sst.input + "foo": "bar"}, "key with " + sst.description + " at beginning"}) - tests = append(tests, test{pgx.Hstore{"foo" + sst.input + "foo": "bar"}, "key with " + sst.description + " in middle"}) - tests = append(tests, test{pgx.Hstore{"foo" + sst.input: "bar"}, "key with " + sst.description + " at end"}) - tests = append(tests, test{pgx.Hstore{sst.input: "bar"}, "key is " + sst.description}) - - tests = append(tests, test{pgx.Hstore{"foo": sst.input + "bar"}, "value with " + sst.description + " at beginning"}) - tests = append(tests, test{pgx.Hstore{"foo": "bar" + sst.input + "bar"}, "value with " + sst.description + " in middle"}) - tests = append(tests, test{pgx.Hstore{"foo": "bar" + sst.input}, "value with " + sst.description + " at end"}) - tests = append(tests, test{pgx.Hstore{"foo": sst.input}, "value is " + sst.description}) - } - - for _, tt := range tests { - var result pgx.Hstore - err := conn.QueryRow("select $1::hstore", tt.hstore).Scan(&result) - if err != nil { - t.Errorf(`%s: QueryRow.Scan returned an error: %v`, tt.description, err) - } - - for key, inValue := range tt.hstore { - outValue, ok := result[key] - if ok { - if inValue != outValue { - t.Errorf(`%s: Key %s mismatch - expected %s, received %s`, tt.description, key, inValue, outValue) - } - } else { - t.Errorf(`%s: Missing key %s`, tt.description, key) - } - } - - ensureConnValid(t, conn) - } -} - -func TestNullHstoreTranscode(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - type test struct { - nullHstore pgx.NullHstore - description string - } - - tests := []test{ - {pgx.NullHstore{}, "null"}, - {pgx.NullHstore{Valid: true}, "empty"}, - {pgx.NullHstore{ - Hstore: 
map[string]pgx.NullString{"foo": {String: "bar", Valid: true}}, - Valid: true}, - "single key/value"}, - {pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": {String: "bar", Valid: true}, "baz": {String: "quz", Valid: true}}, - Valid: true}, - "multiple key/values"}, - {pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"NULL": {String: "bar", Valid: true}}, - Valid: true}, - `string "NULL" key`}, - {pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": {String: "NULL", Valid: true}}, - Valid: true}, - `string "NULL" value`}, - {pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": {String: "", Valid: false}}, - Valid: true}, - `NULL value`}, - } - - specialStringTests := []struct { - input string - description string - }{ - {`"`, `double quote (")`}, - {`'`, `single quote (')`}, - {`\`, `backslash (\)`}, - {`\\`, `multiple backslashes (\\)`}, - {`=>`, `separator (=>)`}, - {` `, `space`}, - {`\ / / \\ => " ' " '`, `multiple special characters`}, - } - for _, sst := range specialStringTests { - tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{sst.input + "foo": {String: "bar", Valid: true}}, - Valid: true}, - "key with " + sst.description + " at beginning"}) - tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo" + sst.input + "foo": {String: "bar", Valid: true}}, - Valid: true}, - "key with " + sst.description + " in middle"}) - tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo" + sst.input: {String: "bar", Valid: true}}, - Valid: true}, - "key with " + sst.description + " at end"}) - tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{sst.input: {String: "bar", Valid: true}}, - Valid: true}, - "key is " + sst.description}) - - tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": {String: sst.input + "bar", Valid: true}}, - Valid: true}, - "value with " + sst.description + " at beginning"}) - 
tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": {String: "bar" + sst.input + "bar", Valid: true}}, - Valid: true}, - "value with " + sst.description + " in middle"}) - tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": {String: "bar" + sst.input, Valid: true}}, - Valid: true}, - "value with " + sst.description + " at end"}) - tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": {String: sst.input, Valid: true}}, - Valid: true}, - "value is " + sst.description}) - } - - for _, tt := range tests { - var result pgx.NullHstore - err := conn.QueryRow("select $1::hstore", tt.nullHstore).Scan(&result) - if err != nil { - t.Errorf(`%s: QueryRow.Scan returned an error: %v`, tt.description, err) - } - - if result.Valid != tt.nullHstore.Valid { - t.Errorf(`%s: Valid mismatch - expected %v, received %v`, tt.description, tt.nullHstore.Valid, result.Valid) - } - - for key, inValue := range tt.nullHstore.Hstore { - outValue, ok := result.Hstore[key] - if ok { - if inValue != outValue { - t.Errorf(`%s: Key %s mismatch - expected %v, received %v`, tt.description, key, inValue, outValue) - } - } else { - t.Errorf(`%s: Missing key %s`, tt.description, key) - } - } - - ensureConnValid(t, conn) - } -} diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go new file mode 100644 index 00000000..53543b89 --- /dev/null +++ b/internal/sanitize/sanitize.go @@ -0,0 +1,237 @@ +package sanitize + +import ( + "bytes" + "encoding/hex" + "strconv" + "strings" + "time" + "unicode/utf8" + + "github.com/pkg/errors" +) + +// Part is either a string or an int. A string is raw SQL. An int is a +// argument placeholder. 
+type Part interface{} + +type Query struct { + Parts []Part +} + +func (q *Query) Sanitize(args ...interface{}) (string, error) { + argUse := make([]bool, len(args)) + buf := &bytes.Buffer{} + + for _, part := range q.Parts { + var str string + switch part := part.(type) { + case string: + str = part + case int: + argIdx := part - 1 + if argIdx >= len(args) { + return "", errors.Errorf("insufficient arguments") + } + arg := args[argIdx] + switch arg := arg.(type) { + case nil: + str = "null" + case int64: + str = strconv.FormatInt(arg, 10) + case float64: + str = strconv.FormatFloat(arg, 'f', -1, 64) + case bool: + str = strconv.FormatBool(arg) + case []byte: + str = QuoteBytes(arg) + case string: + str = QuoteString(arg) + case time.Time: + str = arg.Format("'2006-01-02 15:04:05.999999999Z07:00:00'") + default: + return "", errors.Errorf("invalid arg type: %T", arg) + } + argUse[argIdx] = true + default: + return "", errors.Errorf("invalid Part type: %T", part) + } + buf.WriteString(str) + } + + for i, used := range argUse { + if !used { + return "", errors.Errorf("unused argument: %d", i) + } + } + return buf.String(), nil +} + +func NewQuery(sql string) (*Query, error) { + l := &sqlLexer{ + src: sql, + stateFn: rawState, + } + + for l.stateFn != nil { + l.stateFn = l.stateFn(l) + } + + query := &Query{Parts: l.parts} + + return query, nil +} + +func QuoteString(str string) string { + return "'" + strings.Replace(str, "'", "''", -1) + "'" +} + +func QuoteBytes(buf []byte) string { + return `'\x` + hex.EncodeToString(buf) + "'" +} + +type sqlLexer struct { + src string + start int + pos int + stateFn stateFn + parts []Part +} + +type stateFn func(*sqlLexer) stateFn + +func rawState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case 'e', 'E': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '\'' { + l.pos += width + return escapeStringState + } + case '\'': + return 
singleQuoteState + case '"': + return doubleQuoteState + case '$': + nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:]) + if '0' <= nextRune && nextRune <= '9' { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos-width]) + } + l.start = l.pos + return placeholderState + } + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func singleQuoteState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\'': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '\'' { + return rawState + } + l.pos += width + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func doubleQuoteState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '"': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '"' { + return rawState + } + l.pos += width + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +// placeholderState consumes a placeholder value. The $ must have already has +// already been consumed. The first rune must be a digit. 
+func placeholderState(l *sqlLexer) stateFn { + num := 0 + + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + if '0' <= r && r <= '9' { + num *= 10 + num += int(r - '0') + } else { + l.parts = append(l.parts, num) + l.pos -= width + l.start = l.pos + return rawState + } + } +} + +func escapeStringState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\\': + _, width = utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + case '\'': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '\'' { + return rawState + } + l.pos += width + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +// SanitizeSQL replaces placeholder values with args. It quotes and escapes args +// as necessary. This function is only safe when standard_conforming_strings is +// on. +func SanitizeSQL(sql string, args ...interface{}) (string, error) { + query, err := NewQuery(sql) + if err != nil { + return "", err + } + return query.Sanitize(args...) 
+} diff --git a/internal/sanitize/sanitize_test.go b/internal/sanitize/sanitize_test.go new file mode 100644 index 00000000..9597840e --- /dev/null +++ b/internal/sanitize/sanitize_test.go @@ -0,0 +1,175 @@ +package sanitize_test + +import ( + "testing" + + "github.com/jackc/pgx/internal/sanitize" +) + +func TestNewQuery(t *testing.T) { + successTests := []struct { + sql string + expected sanitize.Query + }{ + { + sql: "select 42", + expected: sanitize.Query{Parts: []sanitize.Part{"select 42"}}, + }, + { + sql: "select $1", + expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + }, + { + sql: "select 'quoted $42', $1", + expected: sanitize.Query{Parts: []sanitize.Part{"select 'quoted $42', ", 1}}, + }, + { + sql: `select "doubled quoted $42", $1`, + expected: sanitize.Query{Parts: []sanitize.Part{`select "doubled quoted $42", `, 1}}, + }, + { + sql: "select 'foo''bar', $1", + expected: sanitize.Query{Parts: []sanitize.Part{"select 'foo''bar', ", 1}}, + }, + { + sql: `select "foo""bar", $1`, + expected: sanitize.Query{Parts: []sanitize.Part{`select "foo""bar", `, 1}}, + }, + { + sql: "select '''', $1", + expected: sanitize.Query{Parts: []sanitize.Part{"select '''', ", 1}}, + }, + { + sql: `select """", $1`, + expected: sanitize.Query{Parts: []sanitize.Part{`select """", `, 1}}, + }, + { + sql: "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11", + expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2, ", ", 3, ", ", 4, ", ", 5, ", ", 6, ", ", 7, ", ", 8, ", ", 9, ", ", 10, ", ", 11}}, + }, + { + sql: `select "adsf""$1""adsf", $1, 'foo''$$12bar', $2, '$3'`, + expected: sanitize.Query{Parts: []sanitize.Part{`select "adsf""$1""adsf", `, 1, `, 'foo''$$12bar', `, 2, `, '$3'`}}, + }, + { + sql: `select E'escape string\' $42', $1`, + expected: sanitize.Query{Parts: []sanitize.Part{`select E'escape string\' $42', `, 1}}, + }, + { + sql: `select e'escape string\' $42', $1`, + expected: sanitize.Query{Parts: []sanitize.Part{`select e'escape 
string\' $42', `, 1}}, + }, + } + + for i, tt := range successTests { + query, err := sanitize.NewQuery(tt.sql) + if err != nil { + t.Errorf("%d. %v", i, err) + } + + if len(query.Parts) == len(tt.expected.Parts) { + for j := range query.Parts { + if query.Parts[j] != tt.expected.Parts[j] { + t.Errorf("%d. expected part %d to be %v but it was %v", i, j, tt.expected.Parts[j], query.Parts[j]) + } + } + } else { + t.Errorf("%d. expected query parts to be %v but it was %v", i, tt.expected.Parts, query.Parts) + } + } +} + +func TestQuerySanitize(t *testing.T) { + successfulTests := []struct { + query sanitize.Query + args []interface{} + expected string + }{ + { + query: sanitize.Query{Parts: []sanitize.Part{"select 42"}}, + args: []interface{}{}, + expected: `select 42`, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []interface{}{int64(42)}, + expected: `select 42`, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []interface{}{float64(1.23)}, + expected: `select 1.23`, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []interface{}{true}, + expected: `select true`, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []interface{}{[]byte{0, 1, 2, 3, 255}}, + expected: `select '\x00010203ff'`, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []interface{}{nil}, + expected: `select null`, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []interface{}{"foobar"}, + expected: `select 'foobar'`, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []interface{}{"foo'bar"}, + expected: `select 'foo''bar'`, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []interface{}{`foo\'bar`}, + expected: `select 'foo\''bar'`, + }, + } + + for i, tt := range successfulTests { + actual, err := tt.query.Sanitize(tt.args...) 
+ if err != nil { + t.Errorf("%d. %v", i, err) + continue + } + + if tt.expected != actual { + t.Errorf("%d. expected %s, but got %s", i, tt.expected, actual) + } + } + + errorTests := []struct { + query sanitize.Query + args []interface{} + expected string + }{ + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2}}, + args: []interface{}{int64(42)}, + expected: `insufficient arguments`, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select 'foo'"}}, + args: []interface{}{int64(42)}, + expected: `unused argument: 0`, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []interface{}{42}, + expected: `invalid arg type: int`, + }, + } + + for i, tt := range errorTests { + _, err := tt.query.Sanitize(tt.args...) + if err == nil || err.Error() != tt.expected { + t.Errorf("%d. expected error %v, got %v", i, tt.expected, err) + } + } +} diff --git a/large_objects.go b/large_objects.go index a4922ef1..e109bce2 100644 --- a/large_objects.go +++ b/large_objects.go @@ -2,6 +2,8 @@ package pgx import ( "io" + + "github.com/jackc/pgx/pgtype" ) // LargeObjects is a structure used to access the large objects API. It is only @@ -14,20 +16,20 @@ type LargeObjects struct { fp *fastpath } -const largeObjectFns = `select proname, oid from pg_catalog.pg_proc +const largeObjectFns = `select proname, oid from pg_catalog.pg_proc where proname in ( -'lo_open', -'lo_close', -'lo_create', -'lo_unlink', -'lo_lseek', -'lo_lseek64', -'lo_tell', -'lo_tell64', -'lo_truncate', -'lo_truncate64', -'loread', -'lowrite') +'lo_open', +'lo_close', +'lo_create', +'lo_unlink', +'lo_lseek', +'lo_lseek64', +'lo_tell', +'lo_tell64', +'lo_truncate', +'lo_truncate64', +'loread', +'lowrite') and pronamespace = (select oid from pg_catalog.pg_namespace where nspname = 'pg_catalog')` // LargeObjects returns a LargeObjects instance for the transaction. @@ -60,19 +62,19 @@ const ( // Create creates a new large object. 
If id is zero, the server assigns an // unused OID. -func (o *LargeObjects) Create(id Oid) (Oid, error) { - newOid, err := fpInt32(o.fp.CallFn("lo_create", []fpArg{fpIntArg(int32(id))})) - return Oid(newOid), err +func (o *LargeObjects) Create(id pgtype.OID) (pgtype.OID, error) { + newOID, err := fpInt32(o.fp.CallFn("lo_create", []fpArg{fpIntArg(int32(id))})) + return pgtype.OID(newOID), err } // Open opens an existing large object with the given mode. -func (o *LargeObjects) Open(oid Oid, mode LargeObjectMode) (*LargeObject, error) { +func (o *LargeObjects) Open(oid pgtype.OID, mode LargeObjectMode) (*LargeObject, error) { fd, err := fpInt32(o.fp.CallFn("lo_open", []fpArg{fpIntArg(int32(oid)), fpIntArg(int32(mode))})) return &LargeObject{fd: fd, lo: o}, err } // Unlink removes a large object from the database. -func (o *LargeObjects) Unlink(oid Oid) error { +func (o *LargeObjects) Unlink(oid pgtype.OID) error { _, err := o.fp.CallFn("lo_unlink", []fpArg{fpIntArg(int32(oid))}) return err } diff --git a/log/log15adapter/adapter.go b/log/log15adapter/adapter.go new file mode 100644 index 00000000..8623a380 --- /dev/null +++ b/log/log15adapter/adapter.go @@ -0,0 +1,47 @@ +// Package log15adapter provides a logger that writes to a github.com/inconshreveable/log15.Logger +// log. +package log15adapter + +import ( + "github.com/jackc/pgx" +) + +// Log15Logger interface defines the subset of +// github.com/inconshreveable/log15.Logger that this adapter uses. 
+type Log15Logger interface { + Debug(msg string, ctx ...interface{}) + Info(msg string, ctx ...interface{}) + Warn(msg string, ctx ...interface{}) + Error(msg string, ctx ...interface{}) + Crit(msg string, ctx ...interface{}) +} + +type Logger struct { + l Log15Logger +} + +func NewLogger(l Log15Logger) *Logger { + return &Logger{l: l} +} + +func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) { + logArgs := make([]interface{}, 0, len(data)) + for k, v := range data { + logArgs = append(logArgs, k, v) + } + + switch level { + case pgx.LogLevelTrace: + l.l.Debug(msg, append(logArgs, "PGX_LOG_LEVEL", level)...) + case pgx.LogLevelDebug: + l.l.Debug(msg, logArgs...) + case pgx.LogLevelInfo: + l.l.Info(msg, logArgs...) + case pgx.LogLevelWarn: + l.l.Warn(msg, logArgs...) + case pgx.LogLevelError: + l.l.Error(msg, logArgs...) + default: + l.l.Error(msg, append(logArgs, "INVALID_PGX_LOG_LEVEL", level)...) + } +} diff --git a/log/logrusadapter/adapter.go b/log/logrusadapter/adapter.go new file mode 100644 index 00000000..6084c36c --- /dev/null +++ b/log/logrusadapter/adapter.go @@ -0,0 +1,40 @@ +// Package logrusadapter provides a logger that writes to a github.com/Sirupsen/logrus.Logger +// log. 
+package logrusadapter + +import ( + "github.com/Sirupsen/logrus" + "github.com/jackc/pgx" +) + +type Logger struct { + l *logrus.Logger +} + +func NewLogger(l *logrus.Logger) *Logger { + return &Logger{l: l} +} + +func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) { + var logger logrus.FieldLogger + if data != nil { + logger = l.l.WithFields(data) + } else { + logger = l.l + } + + switch level { + case pgx.LogLevelTrace: + logger.WithField("PGX_LOG_LEVEL", level).Debug(msg) + case pgx.LogLevelDebug: + logger.Debug(msg) + case pgx.LogLevelInfo: + logger.Info(msg) + case pgx.LogLevelWarn: + logger.Warn(msg) + case pgx.LogLevelError: + logger.Error(msg) + default: + logger.WithField("INVALID_PGX_LOG_LEVEL", level).Error(msg) + } +} diff --git a/log/testingadapter/adapter.go b/log/testingadapter/adapter.go new file mode 100644 index 00000000..6c9cde83 --- /dev/null +++ b/log/testingadapter/adapter.go @@ -0,0 +1,32 @@ +// Package testingadapter provides a logger that writes to a test or benchmark +// log. +package testingadapter + +import ( + "fmt" + + "github.com/jackc/pgx" +) + +// TestingLogger interface defines the subset of testing.TB methods used by this +// adapter. +type TestingLogger interface { + Log(args ...interface{}) +} + +type Logger struct { + l TestingLogger +} + +func NewLogger(l TestingLogger) *Logger { + return &Logger{l: l} +} + +func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) { + logArgs := make([]interface{}, 0, 2+len(data)) + logArgs = append(logArgs, level, msg) + for k, v := range data { + logArgs = append(logArgs, fmt.Sprintf("%s=%v", k, v)) + } + l.l.Log(logArgs...) 
+} diff --git a/logger.go b/logger.go index 4423325c..528698b1 100644 --- a/logger.go +++ b/logger.go @@ -2,13 +2,13 @@ package pgx import ( "encoding/hex" - "errors" "fmt" + + "github.com/pkg/errors" ) // The values for log levels are chosen such that the zero value means that no -// log level was specified and we can default to LogLevelDebug to preserve -// the behavior that existed prior to log level introduction. +// log level was specified. const ( LogLevelTrace = 6 LogLevelDebug = 5 @@ -18,16 +18,33 @@ const ( LogLevelNone = 1 ) +// LogLevel represents the pgx logging level. See LogLevel* constants for +// possible values. +type LogLevel int + +func (ll LogLevel) String() string { + switch ll { + case LogLevelTrace: + return "trace" + case LogLevelDebug: + return "debug" + case LogLevelInfo: + return "info" + case LogLevelWarn: + return "warn" + case LogLevelError: + return "error" + case LogLevelNone: + return "none" + default: + return fmt.Sprintf("invalid level %d", ll) + } +} + // Logger is the interface used to get logging from pgx internals. -// https://github.com/inconshreveable/log15 is the recommended logging package. -// This logging interface was extracted from there. However, it should be simple -// to adapt any logger to this interface. type Logger interface { - // Log a message at the given level with context key/value pairs - Debug(msg string, ctx ...interface{}) - Info(msg string, ctx ...interface{}) - Warn(msg string, ctx ...interface{}) - Error(msg string, ctx ...interface{}) + // Log a message at the given level with data key/value pairs. data may be nil. 
+ Log(level LogLevel, msg string, data map[string]interface{}) } // LogLevelFromString converts log level string to constant @@ -39,7 +56,7 @@ type Logger interface { // warn // error // none -func LogLevelFromString(s string) (int, error) { +func LogLevelFromString(s string) (LogLevel, error) { switch s { case "trace": return LogLevelTrace, nil diff --git a/messages.go b/messages.go index 317ba273..53a5a67c 100644 --- a/messages.go +++ b/messages.go @@ -1,66 +1,24 @@ package pgx import ( - "encoding/binary" + "github.com/jackc/pgx/pgio" + "github.com/jackc/pgx/pgtype" ) const ( - protocolVersionNumber = 196608 // 3.0 + copyData = 'd' + copyFail = 'f' + copyDone = 'c' ) -const ( - backendKeyData = 'K' - authenticationX = 'R' - readyForQuery = 'Z' - rowDescription = 'T' - dataRow = 'D' - commandComplete = 'C' - errorResponse = 'E' - noticeResponse = 'N' - parseComplete = '1' - parameterDescription = 't' - bindComplete = '2' - notificationResponse = 'A' - emptyQueryResponse = 'I' - noData = 'n' - closeComplete = '3' - flush = 'H' - copyInResponse = 'G' - copyData = 'd' - copyFail = 'f' - copyDone = 'c' -) - -type startupMessage struct { - options map[string]string -} - -func newStartupMessage() *startupMessage { - return &startupMessage{map[string]string{}} -} - -func (s *startupMessage) Bytes() (buf []byte) { - buf = make([]byte, 8, 128) - binary.BigEndian.PutUint32(buf[4:8], uint32(protocolVersionNumber)) - for key, value := range s.options { - buf = append(buf, key...) - buf = append(buf, 0) - buf = append(buf, value...) - buf = append(buf, 0) - } - buf = append(buf, ("\000")...) 
- binary.BigEndian.PutUint32(buf[0:4], uint32(len(buf))) - return buf -} - type FieldDescription struct { Name string - Table Oid - AttributeNumber int16 - DataType Oid + Table pgtype.OID + AttributeNumber uint16 + DataType pgtype.OID DataTypeSize int16 DataTypeName string - Modifier int32 + Modifier uint32 FormatCode int16 } @@ -91,69 +49,114 @@ func (pe PgError) Error() string { return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" } -func newWriteBuf(c *Conn, t byte) *WriteBuf { - buf := append(c.wbuf[0:0], t, 0, 0, 0, 0) - c.writeBuf = WriteBuf{buf: buf, sizeIdx: 1, conn: c} - return &c.writeBuf +// Notice represents a notice response message reported by the PostgreSQL +// server. Be aware that this is distinct from LISTEN/NOTIFY notification. +type Notice PgError + +// appendParse appends a PostgreSQL wire protocol parse message to buf and returns it. +func appendParse(buf []byte, name string, query string, parameterOIDs []pgtype.OID) []byte { + buf = append(buf, 'P') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, name...) + buf = append(buf, 0) + buf = append(buf, query...) + buf = append(buf, 0) + + buf = pgio.AppendInt16(buf, int16(len(parameterOIDs))) + for _, oid := range parameterOIDs { + buf = pgio.AppendUint32(buf, uint32(oid)) + } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + return buf } -// WriteBuf is used build messages to send to the PostgreSQL server. It is used -// by the Encoder interface when implementing custom encoders. -type WriteBuf struct { - buf []byte - sizeIdx int - conn *Conn +// appendDescribe appends a PostgreSQL wire protocol describe message to buf and returns it. +func appendDescribe(buf []byte, objectType byte, name string) []byte { + buf = append(buf, 'D') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, objectType) + buf = append(buf, name...) 
+ buf = append(buf, 0) + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + return buf } -func (wb *WriteBuf) startMsg(t byte) { - wb.closeMsg() - wb.buf = append(wb.buf, t, 0, 0, 0, 0) - wb.sizeIdx = len(wb.buf) - 4 +// appendSync appends a PostgreSQL wire protocol sync message to buf and returns it. +func appendSync(buf []byte) []byte { + buf = append(buf, 'S') + buf = pgio.AppendInt32(buf, 4) + + return buf } -func (wb *WriteBuf) closeMsg() { - binary.BigEndian.PutUint32(wb.buf[wb.sizeIdx:wb.sizeIdx+4], uint32(len(wb.buf)-wb.sizeIdx)) +// appendBind appends a PostgreSQL wire protocol bind message to buf and returns it. +func appendBind( + buf []byte, + destinationPortal, + preparedStatement string, + connInfo *pgtype.ConnInfo, + parameterOIDs []pgtype.OID, + arguments []interface{}, + resultFormatCodes []int16, +) ([]byte, error) { + buf = append(buf, 'B') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, destinationPortal...) + buf = append(buf, 0) + buf = append(buf, preparedStatement...) + buf = append(buf, 0) + + buf = pgio.AppendInt16(buf, int16(len(parameterOIDs))) + for i, oid := range parameterOIDs { + buf = pgio.AppendInt16(buf, chooseParameterFormatCode(connInfo, oid, arguments[i])) + } + + buf = pgio.AppendInt16(buf, int16(len(arguments))) + for i, oid := range parameterOIDs { + var err error + buf, err = encodePreparedStatementArgument(connInfo, buf, oid, arguments[i]) + if err != nil { + return nil, err + } + } + + buf = pgio.AppendInt16(buf, int16(len(resultFormatCodes))) + for _, fc := range resultFormatCodes { + buf = pgio.AppendInt16(buf, fc) + } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + return buf, nil } -func (wb *WriteBuf) WriteByte(b byte) { - wb.buf = append(wb.buf, b) +// appendExecute appends a PostgreSQL wire protocol execute message to buf and returns it. 
+func appendExecute(buf []byte, portal string, maxRows uint32) []byte { + buf = append(buf, 'E') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf = append(buf, portal...) + buf = append(buf, 0) + buf = pgio.AppendUint32(buf, maxRows) + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + return buf } -func (wb *WriteBuf) WriteCString(s string) { - wb.buf = append(wb.buf, []byte(s)...) - wb.buf = append(wb.buf, 0) -} +// appendQuery appends a PostgreSQL wire protocol query message to buf and returns it. +func appendQuery(buf []byte, query string) []byte { + buf = append(buf, 'Q') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, query...) + buf = append(buf, 0) + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) -func (wb *WriteBuf) WriteInt16(n int16) { - b := make([]byte, 2) - binary.BigEndian.PutUint16(b, uint16(n)) - wb.buf = append(wb.buf, b...) -} - -func (wb *WriteBuf) WriteUint16(n uint16) { - b := make([]byte, 2) - binary.BigEndian.PutUint16(b, n) - wb.buf = append(wb.buf, b...) -} - -func (wb *WriteBuf) WriteInt32(n int32) { - b := make([]byte, 4) - binary.BigEndian.PutUint32(b, uint32(n)) - wb.buf = append(wb.buf, b...) -} - -func (wb *WriteBuf) WriteUint32(n uint32) { - b := make([]byte, 4) - binary.BigEndian.PutUint32(b, n) - wb.buf = append(wb.buf, b...) -} - -func (wb *WriteBuf) WriteInt64(n int64) { - b := make([]byte, 8) - binary.BigEndian.PutUint64(b, uint64(n)) - wb.buf = append(wb.buf, b...) -} - -func (wb *WriteBuf) WriteBytes(b []byte) { - wb.buf = append(wb.buf, b...) + return buf } diff --git a/msg_reader.go b/msg_reader.go deleted file mode 100644 index 21db5d26..00000000 --- a/msg_reader.go +++ /dev/null @@ -1,316 +0,0 @@ -package pgx - -import ( - "bufio" - "encoding/binary" - "errors" - "io" -) - -// msgReader is a helper that reads values from a PostgreSQL message. 
-type msgReader struct { - reader *bufio.Reader - msgBytesRemaining int32 - err error - log func(lvl int, msg string, ctx ...interface{}) - shouldLog func(lvl int) bool -} - -// Err returns any error that the msgReader has experienced -func (r *msgReader) Err() error { - return r.err -} - -// fatal tells rc that a Fatal error has occurred -func (r *msgReader) fatal(err error) { - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgBytesRemaining", r.msgBytesRemaining) - } - r.err = err -} - -// rxMsg reads the type and size of the next message. -func (r *msgReader) rxMsg() (byte, error) { - if r.err != nil { - return 0, r.err - } - - if r.msgBytesRemaining > 0 { - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining) - } - - _, err := r.reader.Discard(int(r.msgBytesRemaining)) - if err != nil { - return 0, err - } - } - - b, err := r.reader.Peek(5) - if err != nil { - r.fatal(err) - return 0, err - } - msgType := b[0] - r.msgBytesRemaining = int32(binary.BigEndian.Uint32(b[1:])) - 4 - r.reader.Discard(5) - return msgType, nil -} - -func (r *msgReader) readByte() byte { - if r.err != nil { - return 0 - } - - r.msgBytesRemaining-- - if r.msgBytesRemaining < 0 { - r.fatal(errors.New("read past end of message")) - return 0 - } - - b, err := r.reader.ReadByte() - if err != nil { - r.fatal(err) - return 0 - } - - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readByte", "value", b, "byteAsString", string(b), "msgBytesRemaining", r.msgBytesRemaining) - } - - return b -} - -func (r *msgReader) readInt16() int16 { - if r.err != nil { - return 0 - } - - r.msgBytesRemaining -= 2 - if r.msgBytesRemaining < 0 { - r.fatal(errors.New("read past end of message")) - return 0 - } - - b, err := r.reader.Peek(2) - if err != nil { - r.fatal(err) - return 0 - } - - n := int16(binary.BigEndian.Uint16(b)) - - r.reader.Discard(2) - - if 
r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgBytesRemaining", r.msgBytesRemaining) - } - - return n -} - -func (r *msgReader) readInt32() int32 { - if r.err != nil { - return 0 - } - - r.msgBytesRemaining -= 4 - if r.msgBytesRemaining < 0 { - r.fatal(errors.New("read past end of message")) - return 0 - } - - b, err := r.reader.Peek(4) - if err != nil { - r.fatal(err) - return 0 - } - - n := int32(binary.BigEndian.Uint32(b)) - - r.reader.Discard(4) - - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgBytesRemaining", r.msgBytesRemaining) - } - - return n -} - -func (r *msgReader) readUint16() uint16 { - if r.err != nil { - return 0 - } - - r.msgBytesRemaining -= 2 - if r.msgBytesRemaining < 0 { - r.fatal(errors.New("read past end of message")) - return 0 - } - - b, err := r.reader.Peek(2) - if err != nil { - r.fatal(err) - return 0 - } - - n := uint16(binary.BigEndian.Uint16(b)) - - r.reader.Discard(2) - - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgBytesRemaining", r.msgBytesRemaining) - } - - return n -} - -func (r *msgReader) readUint32() uint32 { - if r.err != nil { - return 0 - } - - r.msgBytesRemaining -= 4 - if r.msgBytesRemaining < 0 { - r.fatal(errors.New("read past end of message")) - return 0 - } - - b, err := r.reader.Peek(4) - if err != nil { - r.fatal(err) - return 0 - } - - n := uint32(binary.BigEndian.Uint32(b)) - - r.reader.Discard(4) - - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgBytesRemaining", r.msgBytesRemaining) - } - - return n -} - -func (r *msgReader) readInt64() int64 { - if r.err != nil { - return 0 - } - - r.msgBytesRemaining -= 8 - if r.msgBytesRemaining < 0 { - r.fatal(errors.New("read past end of message")) - return 0 - } - - b, err := r.reader.Peek(8) - if err != nil { - r.fatal(err) - return 0 - } - - n := int64(binary.BigEndian.Uint64(b)) - 
- r.reader.Discard(8) - - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgBytesRemaining", r.msgBytesRemaining) - } - - return n -} - -func (r *msgReader) readOid() Oid { - return Oid(r.readInt32()) -} - -// readCString reads a null terminated string -func (r *msgReader) readCString() string { - if r.err != nil { - return "" - } - - b, err := r.reader.ReadBytes(0) - if err != nil { - r.fatal(err) - return "" - } - - r.msgBytesRemaining -= int32(len(b)) - if r.msgBytesRemaining < 0 { - r.fatal(errors.New("read past end of message")) - return "" - } - - s := string(b[0 : len(b)-1]) - - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgBytesRemaining", r.msgBytesRemaining) - } - - return s -} - -// readString reads count bytes and returns as string -func (r *msgReader) readString(countI32 int32) string { - if r.err != nil { - return "" - } - - r.msgBytesRemaining -= countI32 - if r.msgBytesRemaining < 0 { - r.fatal(errors.New("read past end of message")) - return "" - } - - count := int(countI32) - var s string - - if r.reader.Buffered() >= count { - buf, _ := r.reader.Peek(count) - s = string(buf) - r.reader.Discard(count) - } else { - buf := make([]byte, count) - _, err := io.ReadFull(r.reader, buf) - if err != nil { - r.fatal(err) - return "" - } - s = string(buf) - } - - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgBytesRemaining", r.msgBytesRemaining) - } - - return s -} - -// readBytes reads count bytes and returns as []byte -func (r *msgReader) readBytes(count int32) []byte { - if r.err != nil { - return nil - } - - r.msgBytesRemaining -= count - if r.msgBytesRemaining < 0 { - r.fatal(errors.New("read past end of message")) - return nil - } - - b := make([]byte, int(count)) - - _, err := io.ReadFull(r.reader, b) - if err != nil { - r.fatal(err) - return nil - } - - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, 
"msgReader.readBytes", "value", b, "msgBytesRemaining", r.msgBytesRemaining) - } - - return b -} diff --git a/pgio/doc.go b/pgio/doc.go new file mode 100644 index 00000000..ef2dcc7f --- /dev/null +++ b/pgio/doc.go @@ -0,0 +1,6 @@ +// Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol. +/* +pgio provides functions for appending integers to a []byte while doing byte +order conversion. +*/ +package pgio diff --git a/pgio/write.go b/pgio/write.go new file mode 100644 index 00000000..96aedf9d --- /dev/null +++ b/pgio/write.go @@ -0,0 +1,40 @@ +package pgio + +import "encoding/binary" + +func AppendUint16(buf []byte, n uint16) []byte { + wp := len(buf) + buf = append(buf, 0, 0) + binary.BigEndian.PutUint16(buf[wp:], n) + return buf +} + +func AppendUint32(buf []byte, n uint32) []byte { + wp := len(buf) + buf = append(buf, 0, 0, 0, 0) + binary.BigEndian.PutUint32(buf[wp:], n) + return buf +} + +func AppendUint64(buf []byte, n uint64) []byte { + wp := len(buf) + buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0) + binary.BigEndian.PutUint64(buf[wp:], n) + return buf +} + +func AppendInt16(buf []byte, n int16) []byte { + return AppendUint16(buf, uint16(n)) +} + +func AppendInt32(buf []byte, n int32) []byte { + return AppendUint32(buf, uint32(n)) +} + +func AppendInt64(buf []byte, n int64) []byte { + return AppendUint64(buf, uint64(n)) +} + +func SetInt32(buf []byte, n int32) { + binary.BigEndian.PutUint32(buf, uint32(n)) +} diff --git a/pgio/write_test.go b/pgio/write_test.go new file mode 100644 index 00000000..bd50e71c --- /dev/null +++ b/pgio/write_test.go @@ -0,0 +1,78 @@ +package pgio + +import ( + "reflect" + "testing" +) + +func TestAppendUint16NilBuf(t *testing.T) { + buf := AppendUint16(nil, 1) + if !reflect.DeepEqual(buf, []byte{0, 1}) { + t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1}) + } +} + +func TestAppendUint16EmptyBuf(t *testing.T) { + buf := []byte{} + buf = AppendUint16(buf, 1) + if !reflect.DeepEqual(buf, 
[]byte{0, 1}) { + t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1}) + } +} + +func TestAppendUint16BufWithCapacityDoesNotAllocate(t *testing.T) { + buf := make([]byte, 0, 4) + AppendUint16(buf, 1) + buf = buf[0:2] + if !reflect.DeepEqual(buf, []byte{0, 1}) { + t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1}) + } +} + +func TestAppendUint32NilBuf(t *testing.T) { + buf := AppendUint32(nil, 1) + if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) { + t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1}) + } +} + +func TestAppendUint32EmptyBuf(t *testing.T) { + buf := []byte{} + buf = AppendUint32(buf, 1) + if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) { + t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1}) + } +} + +func TestAppendUint32BufWithCapacityDoesNotAllocate(t *testing.T) { + buf := make([]byte, 0, 4) + AppendUint32(buf, 1) + buf = buf[0:4] + if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) { + t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1}) + } +} + +func TestAppendUint64NilBuf(t *testing.T) { + buf := AppendUint64(nil, 1) + if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) { + t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) + } +} + +func TestAppendUint64EmptyBuf(t *testing.T) { + buf := []byte{} + buf = AppendUint64(buf, 1) + if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) { + t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) + } +} + +func TestAppendUint64BufWithCapacityDoesNotAllocate(t *testing.T) { + buf := make([]byte, 0, 8) + AppendUint64(buf, 1) + buf = buf[0:8] + if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) { + t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) + } +} diff --git a/pgmock/pgmock.go b/pgmock/pgmock.go new file mode 100644 index 00000000..fe78b009 --- /dev/null +++ b/pgmock/pgmock.go @@ -0,0 +1,501 @@ 
+package pgmock + +import ( + "io" + "net" + "reflect" + + "github.com/pkg/errors" + + "github.com/jackc/pgx/pgproto3" + "github.com/jackc/pgx/pgtype" +) + +type Server struct { + ln net.Listener + controller Controller +} + +func NewServer(controller Controller) (*Server, error) { + ln, err := net.Listen("tcp", "127.0.0.1:") + if err != nil { + return nil, err + } + + server := &Server{ + ln: ln, + controller: controller, + } + + return server, nil +} + +func (s *Server) Addr() net.Addr { + return s.ln.Addr() +} + +func (s *Server) ServeOne() error { + conn, err := s.ln.Accept() + if err != nil { + return err + } + defer conn.Close() + + s.Close() + + backend, err := pgproto3.NewBackend(conn, conn) + if err != nil { + conn.Close() + return err + } + + return s.controller.Serve(backend) +} + +func (s *Server) Close() error { + err := s.ln.Close() + if err != nil { + return err + } + + return nil +} + +type Controller interface { + Serve(backend *pgproto3.Backend) error +} + +type Step interface { + Step(*pgproto3.Backend) error +} + +type Script struct { + Steps []Step +} + +func (s *Script) Run(backend *pgproto3.Backend) error { + for _, step := range s.Steps { + err := step.Step(backend) + if err != nil { + return err + } + } + + return nil +} + +func (s *Script) Serve(backend *pgproto3.Backend) error { + for _, step := range s.Steps { + err := step.Step(backend) + if err != nil { + return err + } + } + + return nil +} + +func (s *Script) Step(backend *pgproto3.Backend) error { + return s.Serve(backend) +} + +type expectMessageStep struct { + want pgproto3.FrontendMessage + any bool +} + +func (e *expectMessageStep) Step(backend *pgproto3.Backend) error { + msg, err := backend.Receive() + if err != nil { + return err + } + + if e.any && reflect.TypeOf(msg) == reflect.TypeOf(e.want) { + return nil + } + + if !reflect.DeepEqual(msg, e.want) { + return errors.Errorf("msg => %#v, e.want => %#v", msg, e.want) + } + + return nil +} + +type expectStartupMessageStep 
struct { + want *pgproto3.StartupMessage + any bool +} + +func (e *expectStartupMessageStep) Step(backend *pgproto3.Backend) error { + msg, err := backend.ReceiveStartupMessage() + if err != nil { + return err + } + + if e.any { + return nil + } + + if !reflect.DeepEqual(msg, e.want) { + return errors.Errorf("msg => %#v, e.want => %#v", msg, e.want) + } + + return nil +} + +func ExpectMessage(want pgproto3.FrontendMessage) Step { + return expectMessage(want, false) +} + +func ExpectAnyMessage(want pgproto3.FrontendMessage) Step { + return expectMessage(want, true) +} + +func expectMessage(want pgproto3.FrontendMessage, any bool) Step { + if want, ok := want.(*pgproto3.StartupMessage); ok { + return &expectStartupMessageStep{want: want, any: any} + } + + return &expectMessageStep{want: want, any: any} +} + +type sendMessageStep struct { + msg pgproto3.BackendMessage +} + +func (e *sendMessageStep) Step(backend *pgproto3.Backend) error { + return backend.Send(e.msg) +} + +func SendMessage(msg pgproto3.BackendMessage) Step { + return &sendMessageStep{msg: msg} +} + +type waitForCloseMessageStep struct{} + +func (e *waitForCloseMessageStep) Step(backend *pgproto3.Backend) error { + for { + msg, err := backend.Receive() + if err == io.EOF { + return nil + } else if err != nil { + return err + } + + if _, ok := msg.(*pgproto3.Terminate); ok { + return nil + } + } +} + +func WaitForClose() Step { + return &waitForCloseMessageStep{} +} + +func AcceptUnauthenticatedConnRequestSteps() []Step { + return []Step{ + ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}), + SendMessage(&pgproto3.Authentication{Type: pgproto3.AuthTypeOk}), + SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}), + SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + } +} + +func PgxInitSteps() []Step { + steps := []Step{ + ExpectMessage(&pgproto3.Parse{ + Query: "select t.oid, t.typname\nfrom pg_type t\nleft join 
pg_type base_type on t.typelem=base_type.oid\nwhere (\n\t t.typtype in('b', 'p', 'r', 'e')\n\t and (base_type.oid is null or base_type.typtype in('b', 'p', 'r'))\n\t)", + }), + ExpectMessage(&pgproto3.Describe{ + ObjectType: 'S', + }), + ExpectMessage(&pgproto3.Sync{}), + SendMessage(&pgproto3.ParseComplete{}), + SendMessage(&pgproto3.ParameterDescription{}), + SendMessage(&pgproto3.RowDescription{ + Fields: []pgproto3.FieldDescription{ + {Name: "oid", + TableOID: 1247, + TableAttributeNumber: 65534, + DataTypeOID: 26, + DataTypeSize: 4, + TypeModifier: 4294967295, + Format: 0, + }, + {Name: "typname", + TableOID: 1247, + TableAttributeNumber: 1, + DataTypeOID: 19, + DataTypeSize: 64, + TypeModifier: 4294967295, + Format: 0, + }, + }, + }), + SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + ExpectMessage(&pgproto3.Bind{ + ResultFormatCodes: []int16{1, 1}, + }), + ExpectMessage(&pgproto3.Execute{}), + ExpectMessage(&pgproto3.Sync{}), + SendMessage(&pgproto3.BindComplete{}), + } + + rowVals := []struct { + oid pgtype.OID + name string + }{ + {16, "bool"}, + {17, "bytea"}, + {18, "char"}, + {19, "name"}, + {20, "int8"}, + {21, "int2"}, + {22, "int2vector"}, + {23, "int4"}, + {24, "regproc"}, + {25, "text"}, + {26, "oid"}, + {27, "tid"}, + {28, "xid"}, + {29, "cid"}, + {30, "oidvector"}, + {114, "json"}, + {142, "xml"}, + {143, "_xml"}, + {199, "_json"}, + {194, "pg_node_tree"}, + {32, "pg_ddl_command"}, + {210, "smgr"}, + {600, "point"}, + {601, "lseg"}, + {602, "path"}, + {603, "box"}, + {604, "polygon"}, + {628, "line"}, + {629, "_line"}, + {700, "float4"}, + {701, "float8"}, + {702, "abstime"}, + {703, "reltime"}, + {704, "tinterval"}, + {705, "unknown"}, + {718, "circle"}, + {719, "_circle"}, + {790, "money"}, + {791, "_money"}, + {829, "macaddr"}, + {869, "inet"}, + {650, "cidr"}, + {1000, "_bool"}, + {1001, "_bytea"}, + {1002, "_char"}, + {1003, "_name"}, + {1005, "_int2"}, + {1006, "_int2vector"}, + {1007, "_int4"}, + {1008, "_regproc"}, + {1009, 
"_text"}, + {1028, "_oid"}, + {1010, "_tid"}, + {1011, "_xid"}, + {1012, "_cid"}, + {1013, "_oidvector"}, + {1014, "_bpchar"}, + {1015, "_varchar"}, + {1016, "_int8"}, + {1017, "_point"}, + {1018, "_lseg"}, + {1019, "_path"}, + {1020, "_box"}, + {1021, "_float4"}, + {1022, "_float8"}, + {1023, "_abstime"}, + {1024, "_reltime"}, + {1025, "_tinterval"}, + {1027, "_polygon"}, + {1033, "aclitem"}, + {1034, "_aclitem"}, + {1040, "_macaddr"}, + {1041, "_inet"}, + {651, "_cidr"}, + {1263, "_cstring"}, + {1042, "bpchar"}, + {1043, "varchar"}, + {1082, "date"}, + {1083, "time"}, + {1114, "timestamp"}, + {1115, "_timestamp"}, + {1182, "_date"}, + {1183, "_time"}, + {1184, "timestamptz"}, + {1185, "_timestamptz"}, + {1186, "interval"}, + {1187, "_interval"}, + {1231, "_numeric"}, + {1266, "timetz"}, + {1270, "_timetz"}, + {1560, "bit"}, + {1561, "_bit"}, + {1562, "varbit"}, + {1563, "_varbit"}, + {1700, "numeric"}, + {1790, "refcursor"}, + {2201, "_refcursor"}, + {2202, "regprocedure"}, + {2203, "regoper"}, + {2204, "regoperator"}, + {2205, "regclass"}, + {2206, "regtype"}, + {4096, "regrole"}, + {4089, "regnamespace"}, + {2207, "_regprocedure"}, + {2208, "_regoper"}, + {2209, "_regoperator"}, + {2210, "_regclass"}, + {2211, "_regtype"}, + {4097, "_regrole"}, + {4090, "_regnamespace"}, + {2950, "uuid"}, + {2951, "_uuid"}, + {3220, "pg_lsn"}, + {3221, "_pg_lsn"}, + {3614, "tsvector"}, + {3642, "gtsvector"}, + {3615, "tsquery"}, + {3734, "regconfig"}, + {3769, "regdictionary"}, + {3643, "_tsvector"}, + {3644, "_gtsvector"}, + {3645, "_tsquery"}, + {3735, "_regconfig"}, + {3770, "_regdictionary"}, + {3802, "jsonb"}, + {3807, "_jsonb"}, + {2970, "txid_snapshot"}, + {2949, "_txid_snapshot"}, + {3904, "int4range"}, + {3905, "_int4range"}, + {3906, "numrange"}, + {3907, "_numrange"}, + {3908, "tsrange"}, + {3909, "_tsrange"}, + {3910, "tstzrange"}, + {3911, "_tstzrange"}, + {3912, "daterange"}, + {3913, "_daterange"}, + {3926, "int8range"}, + {3927, "_int8range"}, + {2249, 
"record"}, + {2287, "_record"}, + {2275, "cstring"}, + {2276, "any"}, + {2277, "anyarray"}, + {2278, "void"}, + {2279, "trigger"}, + {3838, "event_trigger"}, + {2280, "language_handler"}, + {2281, "internal"}, + {2282, "opaque"}, + {2283, "anyelement"}, + {2776, "anynonarray"}, + {3500, "anyenum"}, + {3115, "fdw_handler"}, + {325, "index_am_handler"}, + {3310, "tsm_handler"}, + {3831, "anyrange"}, + {51367, "gbtreekey4"}, + {51370, "_gbtreekey4"}, + {51371, "gbtreekey8"}, + {51374, "_gbtreekey8"}, + {51375, "gbtreekey16"}, + {51378, "_gbtreekey16"}, + {51379, "gbtreekey32"}, + {51382, "_gbtreekey32"}, + {51383, "gbtreekey_var"}, + {51386, "_gbtreekey_var"}, + {51921, "hstore"}, + {51926, "_hstore"}, + {52005, "ghstore"}, + {52008, "_ghstore"}, + } + + for _, rv := range rowVals { + step := SendMessage(mustBuildDataRow([]interface{}{rv.oid, rv.name}, []int16{pgproto3.BinaryFormat})) + steps = append(steps, step) + } + + steps = append(steps, SendMessage(&pgproto3.CommandComplete{CommandTag: "SELECT 163"})) + steps = append(steps, SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'})) + + return steps +} + +type dataRowValue struct { + Value interface{} + FormatCode int16 +} + +func mustBuildDataRow(values []interface{}, formatCodes []int16) *pgproto3.DataRow { + dr, err := buildDataRow(values, formatCodes) + if err != nil { + panic(err) + } + + return dr +} + +func buildDataRow(values []interface{}, formatCodes []int16) (*pgproto3.DataRow, error) { + dr := &pgproto3.DataRow{ + Values: make([][]byte, len(values)), + } + + if len(formatCodes) == 1 { + for i := 1; i < len(values); i++ { + formatCodes = append(formatCodes, formatCodes[0]) + } + } + + for i := range values { + switch v := values[i].(type) { + case string: + values[i] = &pgtype.Text{String: v, Status: pgtype.Present} + case int16: + values[i] = &pgtype.Int2{Int: v, Status: pgtype.Present} + case int32: + values[i] = &pgtype.Int4{Int: v, Status: pgtype.Present} + case int64: + values[i] = &pgtype.Int8{Int: 
v, Status: pgtype.Present} + } + } + + for i := range values { + switch formatCodes[i] { + case pgproto3.TextFormat: + if e, ok := values[i].(pgtype.TextEncoder); ok { + buf, err := e.EncodeText(nil, nil) + if err != nil { + return nil, errors.Errorf("failed to encode values[%d]", i) + } + dr.Values[i] = buf + } else { + return nil, errors.Errorf("values[%d] does not implement TextExcoder", i) + } + + case pgproto3.BinaryFormat: + if e, ok := values[i].(pgtype.BinaryEncoder); ok { + buf, err := e.EncodeBinary(nil, nil) + if err != nil { + return nil, errors.Errorf("failed to encode values[%d]", i) + } + dr.Values[i] = buf + } else { + return nil, errors.Errorf("values[%d] does not implement BinaryEncoder", i) + } + default: + return nil, errors.New("unknown FormatCode") + } + } + + return dr, nil +} diff --git a/pgproto3/authentication.go b/pgproto3/authentication.go new file mode 100644 index 00000000..77750b86 --- /dev/null +++ b/pgproto3/authentication.go @@ -0,0 +1,54 @@ +package pgproto3 + +import ( + "encoding/binary" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +const ( + AuthTypeOk = 0 + AuthTypeCleartextPassword = 3 + AuthTypeMD5Password = 5 +) + +type Authentication struct { + Type uint32 + + // MD5Password fields + Salt [4]byte +} + +func (*Authentication) Backend() {} + +func (dst *Authentication) Decode(src []byte) error { + *dst = Authentication{Type: binary.BigEndian.Uint32(src[:4])} + + switch dst.Type { + case AuthTypeOk: + case AuthTypeCleartextPassword: + case AuthTypeMD5Password: + copy(dst.Salt[:], src[4:8]) + default: + return errors.Errorf("unknown authentication type: %d", dst.Type) + } + + return nil +} + +func (src *Authentication) Encode(dst []byte) []byte { + dst = append(dst, 'R') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + dst = pgio.AppendUint32(dst, src.Type) + + switch src.Type { + case AuthTypeMD5Password: + dst = append(dst, src.Salt[:]...) 
+ } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} diff --git a/pgproto3/backend.go b/pgproto3/backend.go new file mode 100644 index 00000000..9a7ef342 --- /dev/null +++ b/pgproto3/backend.go @@ -0,0 +1,101 @@ +package pgproto3 + +import ( + "encoding/binary" + "io" + + "github.com/jackc/pgx/chunkreader" + "github.com/pkg/errors" +) + +type Backend struct { + cr *chunkreader.ChunkReader + w io.Writer + + // Frontend message flyweights + bind Bind + _close Close + describe Describe + execute Execute + flush Flush + parse Parse + passwordMessage PasswordMessage + query Query + startupMessage StartupMessage + sync Sync + terminate Terminate +} + +func NewBackend(r io.Reader, w io.Writer) (*Backend, error) { + cr := chunkreader.NewChunkReader(r) + return &Backend{cr: cr, w: w}, nil +} + +func (b *Backend) Send(msg BackendMessage) error { + _, err := b.w.Write(msg.Encode(nil)) + return err +} + +func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) { + buf, err := b.cr.Next(4) + if err != nil { + return nil, err + } + msgSize := int(binary.BigEndian.Uint32(buf) - 4) + + buf, err = b.cr.Next(msgSize) + if err != nil { + return nil, err + } + + err = b.startupMessage.Decode(buf) + if err != nil { + return nil, err + } + + return &b.startupMessage, nil +} + +func (b *Backend) Receive() (FrontendMessage, error) { + header, err := b.cr.Next(5) + if err != nil { + return nil, err + } + + msgType := header[0] + bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4 + + var msg FrontendMessage + switch msgType { + case 'B': + msg = &b.bind + case 'C': + msg = &b._close + case 'D': + msg = &b.describe + case 'E': + msg = &b.execute + case 'H': + msg = &b.flush + case 'P': + msg = &b.parse + case 'p': + msg = &b.passwordMessage + case 'Q': + msg = &b.query + case 'S': + msg = &b.sync + case 'X': + msg = &b.terminate + default: + return nil, errors.Errorf("unknown message type: %c", msgType) + } + + msgBody, err := b.cr.Next(bodyLen) + if err != 
nil { + return nil, err + } + + err = msg.Decode(msgBody) + return msg, err +} diff --git a/pgproto3/backend_key_data.go b/pgproto3/backend_key_data.go new file mode 100644 index 00000000..5a478f10 --- /dev/null +++ b/pgproto3/backend_key_data.go @@ -0,0 +1,46 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgx/pgio" +) + +type BackendKeyData struct { + ProcessID uint32 + SecretKey uint32 +} + +func (*BackendKeyData) Backend() {} + +func (dst *BackendKeyData) Decode(src []byte) error { + if len(src) != 8 { + return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)} + } + + dst.ProcessID = binary.BigEndian.Uint32(src[:4]) + dst.SecretKey = binary.BigEndian.Uint32(src[4:]) + + return nil +} + +func (src *BackendKeyData) Encode(dst []byte) []byte { + dst = append(dst, 'K') + dst = pgio.AppendUint32(dst, 12) + dst = pgio.AppendUint32(dst, src.ProcessID) + dst = pgio.AppendUint32(dst, src.SecretKey) + return dst +} + +func (src *BackendKeyData) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ProcessID uint32 + SecretKey uint32 + }{ + Type: "BackendKeyData", + ProcessID: src.ProcessID, + SecretKey: src.SecretKey, + }) +} diff --git a/pgproto3/big_endian.go b/pgproto3/big_endian.go new file mode 100644 index 00000000..f7bdb97e --- /dev/null +++ b/pgproto3/big_endian.go @@ -0,0 +1,37 @@ +package pgproto3 + +import ( + "encoding/binary" +) + +type BigEndianBuf [8]byte + +func (b BigEndianBuf) Int16(n int16) []byte { + buf := b[0:2] + binary.BigEndian.PutUint16(buf, uint16(n)) + return buf +} + +func (b BigEndianBuf) Uint16(n uint16) []byte { + buf := b[0:2] + binary.BigEndian.PutUint16(buf, n) + return buf +} + +func (b BigEndianBuf) Int32(n int32) []byte { + buf := b[0:4] + binary.BigEndian.PutUint32(buf, uint32(n)) + return buf +} + +func (b BigEndianBuf) Uint32(n uint32) []byte { + buf := b[0:4] + binary.BigEndian.PutUint32(buf, n) + return buf +} + +func (b 
BigEndianBuf) Int64(n int64) []byte { + buf := b[0:8] + binary.BigEndian.PutUint64(buf, uint64(n)) + return buf +} diff --git a/pgproto3/bind.go b/pgproto3/bind.go new file mode 100644 index 00000000..cceee6ab --- /dev/null +++ b/pgproto3/bind.go @@ -0,0 +1,171 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "encoding/json" + + "github.com/jackc/pgx/pgio" +) + +type Bind struct { + DestinationPortal string + PreparedStatement string + ParameterFormatCodes []int16 + Parameters [][]byte + ResultFormatCodes []int16 +} + +func (*Bind) Frontend() {} + +func (dst *Bind) Decode(src []byte) error { + *dst = Bind{} + + idx := bytes.IndexByte(src, 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + dst.DestinationPortal = string(src[:idx]) + rp := idx + 1 + + idx = bytes.IndexByte(src[rp:], 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + dst.PreparedStatement = string(src[rp : rp+idx]) + rp += idx + 1 + + if len(src[rp:]) < 2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + parameterFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + if parameterFormatCodeCount > 0 { + dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount) + + if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + for i := 0; i < parameterFormatCodeCount; i++ { + dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + } + } + + if len(src[rp:]) < 2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + parameterCount := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + if parameterCount > 0 { + dst.Parameters = make([][]byte, parameterCount) + + for i := 0; i < parameterCount; i++ { + if len(src[rp:]) < 4 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + + msgSize := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + // null + if msgSize == -1 { + continue 
+ } + + if len(src[rp:]) < msgSize { + return &invalidMessageFormatErr{messageType: "Bind"} + } + + dst.Parameters[i] = src[rp : rp+msgSize] + rp += msgSize + } + } + + if len(src[rp:]) < 2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + resultFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + dst.ResultFormatCodes = make([]int16, resultFormatCodeCount) + if len(src[rp:]) < len(dst.ResultFormatCodes)*2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + for i := 0; i < resultFormatCodeCount; i++ { + dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + } + + return nil +} + +func (src *Bind) Encode(dst []byte) []byte { + dst = append(dst, 'B') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.DestinationPortal...) + dst = append(dst, 0) + dst = append(dst, src.PreparedStatement...) + dst = append(dst, 0) + + dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes))) + for _, fc := range src.ParameterFormatCodes { + dst = pgio.AppendInt16(dst, fc) + } + + dst = pgio.AppendUint16(dst, uint16(len(src.Parameters))) + for _, p := range src.Parameters { + if p == nil { + dst = pgio.AppendInt32(dst, -1) + continue + } + + dst = pgio.AppendInt32(dst, int32(len(p))) + dst = append(dst, p...) 
+ } + + dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes))) + for _, fc := range src.ResultFormatCodes { + dst = pgio.AppendInt16(dst, fc) + } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +func (src *Bind) MarshalJSON() ([]byte, error) { + formattedParameters := make([]map[string]string, len(src.Parameters)) + for i, p := range src.Parameters { + if p == nil { + continue + } + + if src.ParameterFormatCodes[i] == 0 { + formattedParameters[i] = map[string]string{"text": string(p)} + } else { + formattedParameters[i] = map[string]string{"binary": hex.EncodeToString(p)} + } + } + + return json.Marshal(struct { + Type string + DestinationPortal string + PreparedStatement string + ParameterFormatCodes []int16 + Parameters []map[string]string + ResultFormatCodes []int16 + }{ + Type: "Bind", + DestinationPortal: src.DestinationPortal, + PreparedStatement: src.PreparedStatement, + ParameterFormatCodes: src.ParameterFormatCodes, + Parameters: formattedParameters, + ResultFormatCodes: src.ResultFormatCodes, + }) +} diff --git a/pgproto3/bind_complete.go b/pgproto3/bind_complete.go new file mode 100644 index 00000000..60360519 --- /dev/null +++ b/pgproto3/bind_complete.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type BindComplete struct{} + +func (*BindComplete) Backend() {} + +func (dst *BindComplete) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *BindComplete) Encode(dst []byte) []byte { + return append(dst, '2', 0, 0, 0, 4) +} + +func (src *BindComplete) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "BindComplete", + }) +} diff --git a/pgproto3/close.go b/pgproto3/close.go new file mode 100644 index 00000000..5ff4c886 --- /dev/null +++ b/pgproto3/close.go @@ -0,0 +1,59 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" + + 
"github.com/jackc/pgx/pgio" +) + +type Close struct { + ObjectType byte // 'S' = prepared statement, 'P' = portal + Name string +} + +func (*Close) Frontend() {} + +func (dst *Close) Decode(src []byte) error { + if len(src) < 2 { + return &invalidMessageFormatErr{messageType: "Close"} + } + + dst.ObjectType = src[0] + rp := 1 + + idx := bytes.IndexByte(src[rp:], 0) + if idx != len(src[rp:])-1 { + return &invalidMessageFormatErr{messageType: "Close"} + } + + dst.Name = string(src[rp : len(src)-1]) + + return nil +} + +func (src *Close) Encode(dst []byte) []byte { + dst = append(dst, 'C') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.ObjectType) + dst = append(dst, src.Name...) + dst = append(dst, 0) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +func (src *Close) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ObjectType string + Name string + }{ + Type: "Close", + ObjectType: string(src.ObjectType), + Name: src.Name, + }) +} diff --git a/pgproto3/close_complete.go b/pgproto3/close_complete.go new file mode 100644 index 00000000..db793c94 --- /dev/null +++ b/pgproto3/close_complete.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type CloseComplete struct{} + +func (*CloseComplete) Backend() {} + +func (dst *CloseComplete) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *CloseComplete) Encode(dst []byte) []byte { + return append(dst, '3', 0, 0, 0, 4) +} + +func (src *CloseComplete) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "CloseComplete", + }) +} diff --git a/pgproto3/command_complete.go b/pgproto3/command_complete.go new file mode 100644 index 00000000..85848532 --- /dev/null +++ b/pgproto3/command_complete.go @@ -0,0 +1,48 @@ +package pgproto3 + +import ( + "bytes" + 
"encoding/json" + + "github.com/jackc/pgx/pgio" +) + +type CommandComplete struct { + CommandTag string +} + +func (*CommandComplete) Backend() {} + +func (dst *CommandComplete) Decode(src []byte) error { + idx := bytes.IndexByte(src, 0) + if idx != len(src)-1 { + return &invalidMessageFormatErr{messageType: "CommandComplete"} + } + + dst.CommandTag = string(src[:idx]) + + return nil +} + +func (src *CommandComplete) Encode(dst []byte) []byte { + dst = append(dst, 'C') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.CommandTag...) + dst = append(dst, 0) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +func (src *CommandComplete) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + CommandTag string + }{ + Type: "CommandComplete", + CommandTag: src.CommandTag, + }) +} diff --git a/pgproto3/copy_both_response.go b/pgproto3/copy_both_response.go new file mode 100644 index 00000000..2862a34f --- /dev/null +++ b/pgproto3/copy_both_response.go @@ -0,0 +1,65 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgx/pgio" +) + +type CopyBothResponse struct { + OverallFormat byte + ColumnFormatCodes []uint16 +} + +func (*CopyBothResponse) Backend() {} + +func (dst *CopyBothResponse) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 3 { + return &invalidMessageFormatErr{messageType: "CopyBothResponse"} + } + + overallFormat := buf.Next(1)[0] + + columnCount := int(binary.BigEndian.Uint16(buf.Next(2))) + if buf.Len() != columnCount*2 { + return &invalidMessageFormatErr{messageType: "CopyBothResponse"} + } + + columnFormatCodes := make([]uint16, columnCount) + for i := 0; i < columnCount; i++ { + columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) + } + + *dst = CopyBothResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes} + + return nil +} + +func (src *CopyBothResponse) Encode(dst []byte) []byte 
{ + dst = append(dst, 'W') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) + for _, fc := range src.ColumnFormatCodes { + dst = pgio.AppendUint16(dst, fc) + } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +func (src *CopyBothResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ColumnFormatCodes []uint16 + }{ + Type: "CopyBothResponse", + ColumnFormatCodes: src.ColumnFormatCodes, + }) +} diff --git a/pgproto3/copy_data.go b/pgproto3/copy_data.go new file mode 100644 index 00000000..fab139e6 --- /dev/null +++ b/pgproto3/copy_data.go @@ -0,0 +1,37 @@ +package pgproto3 + +import ( + "encoding/hex" + "encoding/json" + + "github.com/jackc/pgx/pgio" +) + +type CopyData struct { + Data []byte +} + +func (*CopyData) Backend() {} +func (*CopyData) Frontend() {} + +func (dst *CopyData) Decode(src []byte) error { + dst.Data = src + return nil +} + +func (src *CopyData) Encode(dst []byte) []byte { + dst = append(dst, 'd') + dst = pgio.AppendInt32(dst, int32(4+len(src.Data))) + dst = append(dst, src.Data...) 
+ return dst +} + +func (src *CopyData) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data string + }{ + Type: "CopyData", + Data: hex.EncodeToString(src.Data), + }) +} diff --git a/pgproto3/copy_in_response.go b/pgproto3/copy_in_response.go new file mode 100644 index 00000000..54083cd6 --- /dev/null +++ b/pgproto3/copy_in_response.go @@ -0,0 +1,66 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgx/pgio" +) + +type CopyInResponse struct { + OverallFormat byte + ColumnFormatCodes []uint16 +} + +func (*CopyInResponse) Backend() {} + +func (dst *CopyInResponse) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 3 { + return &invalidMessageFormatErr{messageType: "CopyInResponse"} + } + + overallFormat := buf.Next(1)[0] + + columnCount := int(binary.BigEndian.Uint16(buf.Next(2))) + if buf.Len() != columnCount*2 { + return &invalidMessageFormatErr{messageType: "CopyInResponse"} + } + + columnFormatCodes := make([]uint16, columnCount) + for i := 0; i < columnCount; i++ { + columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) + } + + *dst = CopyInResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes} + + return nil +} + +func (src *CopyInResponse) Encode(dst []byte) []byte { + dst = append(dst, 'G') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.OverallFormat) + dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) + for _, fc := range src.ColumnFormatCodes { + dst = pgio.AppendUint16(dst, fc) + } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +func (src *CopyInResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ColumnFormatCodes []uint16 + }{ + Type: "CopyInResponse", + ColumnFormatCodes: src.ColumnFormatCodes, + }) +} diff --git a/pgproto3/copy_out_response.go b/pgproto3/copy_out_response.go new file mode 100644 index 00000000..eaa33b8b --- /dev/null +++ 
b/pgproto3/copy_out_response.go @@ -0,0 +1,66 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgx/pgio" +) + +type CopyOutResponse struct { + OverallFormat byte + ColumnFormatCodes []uint16 +} + +func (*CopyOutResponse) Backend() {} + +func (dst *CopyOutResponse) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 3 { + return &invalidMessageFormatErr{messageType: "CopyOutResponse"} + } + + overallFormat := buf.Next(1)[0] + + columnCount := int(binary.BigEndian.Uint16(buf.Next(2))) + if buf.Len() != columnCount*2 { + return &invalidMessageFormatErr{messageType: "CopyOutResponse"} + } + + columnFormatCodes := make([]uint16, columnCount) + for i := 0; i < columnCount; i++ { + columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) + } + + *dst = CopyOutResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes} + + return nil +} + +func (src *CopyOutResponse) Encode(dst []byte) []byte { + dst = append(dst, 'H') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.OverallFormat) + dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) + for _, fc := range src.ColumnFormatCodes { + dst = pgio.AppendUint16(dst, fc) + } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +func (src *CopyOutResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ColumnFormatCodes []uint16 + }{ + Type: "CopyOutResponse", + ColumnFormatCodes: src.ColumnFormatCodes, + }) +} diff --git a/pgproto3/data_row.go b/pgproto3/data_row.go new file mode 100644 index 00000000..e46d3cc0 --- /dev/null +++ b/pgproto3/data_row.go @@ -0,0 +1,112 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/hex" + "encoding/json" + + "github.com/jackc/pgx/pgio" +) + +type DataRow struct { + Values [][]byte +} + +func (*DataRow) Backend() {} + +func (dst *DataRow) Decode(src []byte) error { + if len(src) < 2 { + return &invalidMessageFormatErr{messageType: 
"DataRow"} + } + rp := 0 + fieldCount := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + // If the capacity of the values slice is too small OR substantially too + // large reallocate. This is to avoid one row with many columns from + // permanently allocating memory. + if cap(dst.Values) < fieldCount || cap(dst.Values)-fieldCount > 32 { + newCap := 32 + if newCap < fieldCount { + newCap = fieldCount + } + dst.Values = make([][]byte, fieldCount, newCap) + } else { + dst.Values = dst.Values[:fieldCount] + } + + for i := 0; i < fieldCount; i++ { + if len(src[rp:]) < 4 { + return &invalidMessageFormatErr{messageType: "DataRow"} + } + + msgSize := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + // null + if msgSize == -1 { + dst.Values[i] = nil + } else { + if len(src[rp:]) < msgSize { + return &invalidMessageFormatErr{messageType: "DataRow"} + } + + dst.Values[i] = src[rp : rp+msgSize] + rp += msgSize + } + } + + return nil +} + +func (src *DataRow) Encode(dst []byte) []byte { + dst = append(dst, 'D') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = pgio.AppendUint16(dst, uint16(len(src.Values))) + for _, v := range src.Values { + if v == nil { + dst = pgio.AppendInt32(dst, -1) + continue + } + + dst = pgio.AppendInt32(dst, int32(len(v))) + dst = append(dst, v...) 
+ } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +func (src *DataRow) MarshalJSON() ([]byte, error) { + formattedValues := make([]map[string]string, len(src.Values)) + for i, v := range src.Values { + if v == nil { + continue + } + + var hasNonPrintable bool + for _, b := range v { + if b < 32 { + hasNonPrintable = true + break + } + } + + if hasNonPrintable { + formattedValues[i] = map[string]string{"binary": hex.EncodeToString(v)} + } else { + formattedValues[i] = map[string]string{"text": string(v)} + } + } + + return json.Marshal(struct { + Type string + Values []map[string]string + }{ + Type: "DataRow", + Values: formattedValues, + }) +} diff --git a/pgproto3/describe.go b/pgproto3/describe.go new file mode 100644 index 00000000..bb7bc056 --- /dev/null +++ b/pgproto3/describe.go @@ -0,0 +1,59 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" + + "github.com/jackc/pgx/pgio" +) + +type Describe struct { + ObjectType byte // 'S' = prepared statement, 'P' = portal + Name string +} + +func (*Describe) Frontend() {} + +func (dst *Describe) Decode(src []byte) error { + if len(src) < 2 { + return &invalidMessageFormatErr{messageType: "Describe"} + } + + dst.ObjectType = src[0] + rp := 1 + + idx := bytes.IndexByte(src[rp:], 0) + if idx != len(src[rp:])-1 { + return &invalidMessageFormatErr{messageType: "Describe"} + } + + dst.Name = string(src[rp : len(src)-1]) + + return nil +} + +func (src *Describe) Encode(dst []byte) []byte { + dst = append(dst, 'D') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.ObjectType) + dst = append(dst, src.Name...) 
+ dst = append(dst, 0) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +func (src *Describe) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ObjectType string + Name string + }{ + Type: "Describe", + ObjectType: string(src.ObjectType), + Name: src.Name, + }) +} diff --git a/pgproto3/empty_query_response.go b/pgproto3/empty_query_response.go new file mode 100644 index 00000000..d283b06d --- /dev/null +++ b/pgproto3/empty_query_response.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type EmptyQueryResponse struct{} + +func (*EmptyQueryResponse) Backend() {} + +func (dst *EmptyQueryResponse) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "EmptyQueryResponse", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *EmptyQueryResponse) Encode(dst []byte) []byte { + return append(dst, 'I', 0, 0, 0, 4) +} + +func (src *EmptyQueryResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "EmptyQueryResponse", + }) +} diff --git a/pgproto3/error_response.go b/pgproto3/error_response.go new file mode 100644 index 00000000..160234f2 --- /dev/null +++ b/pgproto3/error_response.go @@ -0,0 +1,197 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "strconv" +) + +type ErrorResponse struct { + Severity string + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string + + UnknownFields map[byte]string +} + +func (*ErrorResponse) Backend() {} + +func (dst *ErrorResponse) Decode(src []byte) error { + *dst = ErrorResponse{} + + buf := bytes.NewBuffer(src) + + for { + k, err := buf.ReadByte() + if err != nil { + return err + } + if k == 0 { + break + } + + vb, err 
:= buf.ReadBytes(0) + if err != nil { + return err + } + v := string(vb[:len(vb)-1]) + + switch k { + case 'S': + dst.Severity = v + case 'C': + dst.Code = v + case 'M': + dst.Message = v + case 'D': + dst.Detail = v + case 'H': + dst.Hint = v + case 'P': + s := v + n, _ := strconv.ParseInt(s, 10, 32) + dst.Position = int32(n) + case 'p': + s := v + n, _ := strconv.ParseInt(s, 10, 32) + dst.InternalPosition = int32(n) + case 'q': + dst.InternalQuery = v + case 'W': + dst.Where = v + case 's': + dst.SchemaName = v + case 't': + dst.TableName = v + case 'c': + dst.ColumnName = v + case 'd': + dst.DataTypeName = v + case 'n': + dst.ConstraintName = v + case 'F': + dst.File = v + case 'L': + s := v + n, _ := strconv.ParseInt(s, 10, 32) + dst.Line = int32(n) + case 'R': + dst.Routine = v + + default: + if dst.UnknownFields == nil { + dst.UnknownFields = make(map[byte]string) + } + dst.UnknownFields[k] = v + } + } + + return nil +} + +func (src *ErrorResponse) Encode(dst []byte) []byte { + return append(dst, src.marshalBinary('E')...) 
+} + +func (src *ErrorResponse) marshalBinary(typeByte byte) []byte { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte(typeByte) + buf.Write(bigEndian.Uint32(0)) + + if src.Severity != "" { + buf.WriteString(src.Severity) + buf.WriteByte(0) + } + if src.Code != "" { + buf.WriteString(src.Code) + buf.WriteByte(0) + } + if src.Message != "" { + buf.WriteString(src.Message) + buf.WriteByte(0) + } + if src.Detail != "" { + buf.WriteString(src.Detail) + buf.WriteByte(0) + } + if src.Hint != "" { + buf.WriteString(src.Hint) + buf.WriteByte(0) + } + if src.Position != 0 { + buf.WriteString(strconv.Itoa(int(src.Position))) + buf.WriteByte(0) + } + if src.InternalPosition != 0 { + buf.WriteString(strconv.Itoa(int(src.InternalPosition))) + buf.WriteByte(0) + } + if src.InternalQuery != "" { + buf.WriteString(src.InternalQuery) + buf.WriteByte(0) + } + if src.Where != "" { + buf.WriteString(src.Where) + buf.WriteByte(0) + } + if src.SchemaName != "" { + buf.WriteString(src.SchemaName) + buf.WriteByte(0) + } + if src.TableName != "" { + buf.WriteString(src.TableName) + buf.WriteByte(0) + } + if src.ColumnName != "" { + buf.WriteString(src.ColumnName) + buf.WriteByte(0) + } + if src.DataTypeName != "" { + buf.WriteString(src.DataTypeName) + buf.WriteByte(0) + } + if src.ConstraintName != "" { + buf.WriteString(src.ConstraintName) + buf.WriteByte(0) + } + if src.File != "" { + buf.WriteString(src.File) + buf.WriteByte(0) + } + if src.Line != 0 { + buf.WriteString(strconv.Itoa(int(src.Line))) + buf.WriteByte(0) + } + if src.Routine != "" { + buf.WriteString(src.Routine) + buf.WriteByte(0) + } + + for k, v := range src.UnknownFields { + buf.WriteByte(k) + buf.WriteByte(0) + buf.WriteString(v) + buf.WriteByte(0) + } + buf.WriteByte(0) + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes() +} diff --git a/pgproto3/execute.go b/pgproto3/execute.go new file mode 100644 index 00000000..76da9943 --- /dev/null +++ 
b/pgproto3/execute.go @@ -0,0 +1,60 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgx/pgio" +) + +type Execute struct { + Portal string + MaxRows uint32 +} + +func (*Execute) Frontend() {} + +func (dst *Execute) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + dst.Portal = string(b[:len(b)-1]) + + if buf.Len() < 4 { + return &invalidMessageFormatErr{messageType: "Execute"} + } + dst.MaxRows = binary.BigEndian.Uint32(buf.Next(4)) + + return nil +} + +func (src *Execute) Encode(dst []byte) []byte { + dst = append(dst, 'E') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.Portal...) + dst = append(dst, 0) + + dst = pgio.AppendUint32(dst, src.MaxRows) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +func (src *Execute) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Portal string + MaxRows uint32 + }{ + Type: "Execute", + Portal: src.Portal, + MaxRows: src.MaxRows, + }) +} diff --git a/pgproto3/flush.go b/pgproto3/flush.go new file mode 100644 index 00000000..7fd5e987 --- /dev/null +++ b/pgproto3/flush.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type Flush struct{} + +func (*Flush) Frontend() {} + +func (dst *Flush) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "Flush", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *Flush) Encode(dst []byte) []byte { + return append(dst, 'H', 0, 0, 0, 4) +} + +func (src *Flush) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "Flush", + }) +} diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go new file mode 100644 index 00000000..c8ab5f15 --- /dev/null +++ b/pgproto3/frontend.go @@ -0,0 +1,113 @@ +package pgproto3 + +import ( + "encoding/binary" + "io" + + 
"github.com/jackc/pgx/chunkreader" + "github.com/pkg/errors" +) + +type Frontend struct { + cr *chunkreader.ChunkReader + w io.Writer + + // Backend message flyweights + authentication Authentication + backendKeyData BackendKeyData + bindComplete BindComplete + closeComplete CloseComplete + commandComplete CommandComplete + copyBothResponse CopyBothResponse + copyData CopyData + copyInResponse CopyInResponse + copyOutResponse CopyOutResponse + dataRow DataRow + emptyQueryResponse EmptyQueryResponse + errorResponse ErrorResponse + functionCallResponse FunctionCallResponse + noData NoData + noticeResponse NoticeResponse + notificationResponse NotificationResponse + parameterDescription ParameterDescription + parameterStatus ParameterStatus + parseComplete ParseComplete + readyForQuery ReadyForQuery + rowDescription RowDescription +} + +func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) { + cr := chunkreader.NewChunkReader(r) + return &Frontend{cr: cr, w: w}, nil +} + +func (b *Frontend) Send(msg FrontendMessage) error { + _, err := b.w.Write(msg.Encode(nil)) + return err +} + +func (b *Frontend) Receive() (BackendMessage, error) { + header, err := b.cr.Next(5) + if err != nil { + return nil, err + } + + msgType := header[0] + bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4 + + var msg BackendMessage + switch msgType { + case '1': + msg = &b.parseComplete + case '2': + msg = &b.bindComplete + case '3': + msg = &b.closeComplete + case 'A': + msg = &b.notificationResponse + case 'C': + msg = &b.commandComplete + case 'd': + msg = &b.copyData + case 'D': + msg = &b.dataRow + case 'E': + msg = &b.errorResponse + case 'G': + msg = &b.copyInResponse + case 'H': + msg = &b.copyOutResponse + case 'I': + msg = &b.emptyQueryResponse + case 'K': + msg = &b.backendKeyData + case 'n': + msg = &b.noData + case 'N': + msg = &b.noticeResponse + case 'R': + msg = &b.authentication + case 'S': + msg = &b.parameterStatus + case 't': + msg = &b.parameterDescription + 
case 'T': + msg = &b.rowDescription + case 'V': + msg = &b.functionCallResponse + case 'W': + msg = &b.copyBothResponse + case 'Z': + msg = &b.readyForQuery + default: + return nil, errors.Errorf("unknown message type: %c", msgType) + } + + msgBody, err := b.cr.Next(bodyLen) + if err != nil { + return nil, err + } + + err = msg.Decode(msgBody) + return msg, err +} diff --git a/pgproto3/function_call_response.go b/pgproto3/function_call_response.go new file mode 100644 index 00000000..bb325b69 --- /dev/null +++ b/pgproto3/function_call_response.go @@ -0,0 +1,78 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/hex" + "encoding/json" + + "github.com/jackc/pgx/pgio" +) + +type FunctionCallResponse struct { + Result []byte +} + +func (*FunctionCallResponse) Backend() {} + +func (dst *FunctionCallResponse) Decode(src []byte) error { + if len(src) < 4 { + return &invalidMessageFormatErr{messageType: "FunctionCallResponse"} + } + rp := 0 + resultSize := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + if resultSize == -1 { + dst.Result = nil + return nil + } + + if len(src[rp:]) != resultSize { + return &invalidMessageFormatErr{messageType: "FunctionCallResponse"} + } + + dst.Result = src[rp:] + return nil +} + +func (src *FunctionCallResponse) Encode(dst []byte) []byte { + dst = append(dst, 'V') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + if src.Result == nil { + dst = pgio.AppendInt32(dst, -1) + } else { + dst = pgio.AppendInt32(dst, int32(len(src.Result))) + dst = append(dst, src.Result...) 
+ } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +func (src *FunctionCallResponse) MarshalJSON() ([]byte, error) { + var formattedValue map[string]string + var hasNonPrintable bool + for _, b := range src.Result { + if b < 32 { + hasNonPrintable = true + break + } + } + + if hasNonPrintable { + formattedValue = map[string]string{"binary": hex.EncodeToString(src.Result)} + } else { + formattedValue = map[string]string{"text": string(src.Result)} + } + + return json.Marshal(struct { + Type string + Result map[string]string + }{ + Type: "FunctionCallResponse", + Result: formattedValue, + }) +} diff --git a/pgproto3/no_data.go b/pgproto3/no_data.go new file mode 100644 index 00000000..1fb47c2a --- /dev/null +++ b/pgproto3/no_data.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type NoData struct{} + +func (*NoData) Backend() {} + +func (dst *NoData) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "NoData", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *NoData) Encode(dst []byte) []byte { + return append(dst, 'n', 0, 0, 0, 4) +} + +func (src *NoData) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "NoData", + }) +} diff --git a/pgproto3/notice_response.go b/pgproto3/notice_response.go new file mode 100644 index 00000000..e4595aa5 --- /dev/null +++ b/pgproto3/notice_response.go @@ -0,0 +1,13 @@ +package pgproto3 + +type NoticeResponse ErrorResponse + +func (*NoticeResponse) Backend() {} + +func (dst *NoticeResponse) Decode(src []byte) error { + return (*ErrorResponse)(dst).Decode(src) +} + +func (src *NoticeResponse) Encode(dst []byte) []byte { + return append(dst, (*ErrorResponse)(src).marshalBinary('N')...) 
+} diff --git a/pgproto3/notification_response.go b/pgproto3/notification_response.go new file mode 100644 index 00000000..b14007b4 --- /dev/null +++ b/pgproto3/notification_response.go @@ -0,0 +1,67 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgx/pgio" +) + +type NotificationResponse struct { + PID uint32 + Channel string + Payload string +} + +func (*NotificationResponse) Backend() {} + +func (dst *NotificationResponse) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + pid := binary.BigEndian.Uint32(buf.Next(4)) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + channel := string(b[:len(b)-1]) + + b, err = buf.ReadBytes(0) + if err != nil { + return err + } + payload := string(b[:len(b)-1]) + + *dst = NotificationResponse{PID: pid, Channel: channel, Payload: payload} + return nil +} + +func (src *NotificationResponse) Encode(dst []byte) []byte { + dst = append(dst, 'A') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.Channel...) + dst = append(dst, 0) + dst = append(dst, src.Payload...) 
+ dst = append(dst, 0) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +func (src *NotificationResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + PID uint32 + Channel string + Payload string + }{ + Type: "NotificationResponse", + PID: src.PID, + Channel: src.Channel, + Payload: src.Payload, + }) +} diff --git a/pgproto3/parameter_description.go b/pgproto3/parameter_description.go new file mode 100644 index 00000000..1fa3c927 --- /dev/null +++ b/pgproto3/parameter_description.go @@ -0,0 +1,61 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgx/pgio" +) + +type ParameterDescription struct { + ParameterOIDs []uint32 +} + +func (*ParameterDescription) Backend() {} + +func (dst *ParameterDescription) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 2 { + return &invalidMessageFormatErr{messageType: "ParameterDescription"} + } + + // Reported parameter count will be incorrect when number of args is greater than uint16 + buf.Next(2) + // Instead infer parameter count by remaining size of message + parameterCount := buf.Len() / 4 + + *dst = ParameterDescription{ParameterOIDs: make([]uint32, parameterCount)} + + for i := 0; i < parameterCount; i++ { + dst.ParameterOIDs[i] = binary.BigEndian.Uint32(buf.Next(4)) + } + + return nil +} + +func (src *ParameterDescription) Encode(dst []byte) []byte { + dst = append(dst, 't') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) + for _, oid := range src.ParameterOIDs { + dst = pgio.AppendUint32(dst, oid) + } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +func (src *ParameterDescription) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ParameterOIDs []uint32 + }{ + Type: "ParameterDescription", + ParameterOIDs: src.ParameterOIDs, + }) +} diff --git 
a/pgproto3/parameter_status.go b/pgproto3/parameter_status.go new file mode 100644 index 00000000..b3bac33f --- /dev/null +++ b/pgproto3/parameter_status.go @@ -0,0 +1,61 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" + + "github.com/jackc/pgx/pgio" +) + +type ParameterStatus struct { + Name string + Value string +} + +func (*ParameterStatus) Backend() {} + +func (dst *ParameterStatus) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + name := string(b[:len(b)-1]) + + b, err = buf.ReadBytes(0) + if err != nil { + return err + } + value := string(b[:len(b)-1]) + + *dst = ParameterStatus{Name: name, Value: value} + return nil +} + +func (src *ParameterStatus) Encode(dst []byte) []byte { + dst = append(dst, 'S') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.Name...) + dst = append(dst, 0) + dst = append(dst, src.Value...) + dst = append(dst, 0) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +func (ps *ParameterStatus) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Name string + Value string + }{ + Type: "ParameterStatus", + Name: ps.Name, + Value: ps.Value, + }) +} diff --git a/pgproto3/parse.go b/pgproto3/parse.go new file mode 100644 index 00000000..ca4834c6 --- /dev/null +++ b/pgproto3/parse.go @@ -0,0 +1,83 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgx/pgio" +) + +type Parse struct { + Name string + Query string + ParameterOIDs []uint32 +} + +func (*Parse) Frontend() {} + +func (dst *Parse) Decode(src []byte) error { + *dst = Parse{} + + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + dst.Name = string(b[:len(b)-1]) + + b, err = buf.ReadBytes(0) + if err != nil { + return err + } + dst.Query = string(b[:len(b)-1]) + + if buf.Len() < 2 { + return &invalidMessageFormatErr{messageType: 
"Parse"} + } + parameterOIDCount := int(binary.BigEndian.Uint16(buf.Next(2))) + + for i := 0; i < parameterOIDCount; i++ { + if buf.Len() < 4 { + return &invalidMessageFormatErr{messageType: "Parse"} + } + dst.ParameterOIDs = append(dst.ParameterOIDs, binary.BigEndian.Uint32(buf.Next(4))) + } + + return nil +} + +func (src *Parse) Encode(dst []byte) []byte { + dst = append(dst, 'P') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.Name...) + dst = append(dst, 0) + dst = append(dst, src.Query...) + dst = append(dst, 0) + + dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) + for _, oid := range src.ParameterOIDs { + dst = pgio.AppendUint32(dst, oid) + } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +func (src *Parse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Name string + Query string + ParameterOIDs []uint32 + }{ + Type: "Parse", + Name: src.Name, + Query: src.Query, + ParameterOIDs: src.ParameterOIDs, + }) +} diff --git a/pgproto3/parse_complete.go b/pgproto3/parse_complete.go new file mode 100644 index 00000000..462a89ba --- /dev/null +++ b/pgproto3/parse_complete.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type ParseComplete struct{} + +func (*ParseComplete) Backend() {} + +func (dst *ParseComplete) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "ParseComplete", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *ParseComplete) Encode(dst []byte) []byte { + return append(dst, '1', 0, 0, 0, 4) +} + +func (src *ParseComplete) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "ParseComplete", + }) +} diff --git a/pgproto3/password_message.go b/pgproto3/password_message.go new file mode 100644 index 00000000..2ad3fe4a --- /dev/null +++ b/pgproto3/password_message.go @@ -0,0 +1,46 @@ +package pgproto3 + +import ( + "bytes" + 
"encoding/json" + + "github.com/jackc/pgx/pgio" +) + +type PasswordMessage struct { + Password string +} + +func (*PasswordMessage) Frontend() {} + +func (dst *PasswordMessage) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + dst.Password = string(b[:len(b)-1]) + + return nil +} + +func (src *PasswordMessage) Encode(dst []byte) []byte { + dst = append(dst, 'p') + dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1)) + + dst = append(dst, src.Password...) + dst = append(dst, 0) + + return dst +} + +func (src *PasswordMessage) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Password string + }{ + Type: "PasswordMessage", + Password: src.Password, + }) +} diff --git a/pgproto3/pgproto3.go b/pgproto3/pgproto3.go new file mode 100644 index 00000000..fe7b085b --- /dev/null +++ b/pgproto3/pgproto3.go @@ -0,0 +1,42 @@ +package pgproto3 + +import "fmt" + +// Message is the interface implemented by an object that can decode and encode +// a particular PostgreSQL message. +type Message interface { + // Decode is allowed and expected to retain a reference to data after + // returning (unlike encoding.BinaryUnmarshaler). + Decode(data []byte) error + + // Encode appends itself to dst and returns the new buffer. 
+ Encode(dst []byte) []byte +} + +type FrontendMessage interface { + Message + Frontend() // no-op method to distinguish frontend from backend methods +} + +type BackendMessage interface { + Message + Backend() // no-op method to distinguish frontend from backend methods +} + +type invalidMessageLenErr struct { + messageType string + expectedLen int + actualLen int +} + +func (e *invalidMessageLenErr) Error() string { + return fmt.Sprintf("%s body must have length of %d, but it is %d", e.messageType, e.expectedLen, e.actualLen) +} + +type invalidMessageFormatErr struct { + messageType string +} + +func (e *invalidMessageFormatErr) Error() string { + return fmt.Sprintf("%s body is invalid", e.messageType) +} diff --git a/pgproto3/query.go b/pgproto3/query.go new file mode 100644 index 00000000..d80c0fb4 --- /dev/null +++ b/pgproto3/query.go @@ -0,0 +1,45 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" + + "github.com/jackc/pgx/pgio" +) + +type Query struct { + String string +} + +func (*Query) Frontend() {} + +func (dst *Query) Decode(src []byte) error { + i := bytes.IndexByte(src, 0) + if i != len(src)-1 { + return &invalidMessageFormatErr{messageType: "Query"} + } + + dst.String = string(src[:i]) + + return nil +} + +func (src *Query) Encode(dst []byte) []byte { + dst = append(dst, 'Q') + dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1)) + + dst = append(dst, src.String...) 
+ dst = append(dst, 0) + + return dst +} + +func (src *Query) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + String string + }{ + Type: "Query", + String: src.String, + }) +} diff --git a/pgproto3/ready_for_query.go b/pgproto3/ready_for_query.go new file mode 100644 index 00000000..63b902bd --- /dev/null +++ b/pgproto3/ready_for_query.go @@ -0,0 +1,35 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type ReadyForQuery struct { + TxStatus byte +} + +func (*ReadyForQuery) Backend() {} + +func (dst *ReadyForQuery) Decode(src []byte) error { + if len(src) != 1 { + return &invalidMessageLenErr{messageType: "ReadyForQuery", expectedLen: 1, actualLen: len(src)} + } + + dst.TxStatus = src[0] + + return nil +} + +func (src *ReadyForQuery) Encode(dst []byte) []byte { + return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus) +} + +func (src *ReadyForQuery) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + TxStatus string + }{ + Type: "ReadyForQuery", + TxStatus: string(src.TxStatus), + }) +} diff --git a/pgproto3/row_description.go b/pgproto3/row_description.go new file mode 100644 index 00000000..d0df11b0 --- /dev/null +++ b/pgproto3/row_description.go @@ -0,0 +1,100 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgx/pgio" +) + +const ( + TextFormat = 0 + BinaryFormat = 1 +) + +type FieldDescription struct { + Name string + TableOID uint32 + TableAttributeNumber uint16 + DataTypeOID uint32 + DataTypeSize int16 + TypeModifier uint32 + Format int16 +} + +type RowDescription struct { + Fields []FieldDescription +} + +func (*RowDescription) Backend() {} + +func (dst *RowDescription) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 2 { + return &invalidMessageFormatErr{messageType: "RowDescription"} + } + fieldCount := int(binary.BigEndian.Uint16(buf.Next(2))) + + *dst = RowDescription{Fields: make([]FieldDescription, fieldCount)} + + 
for i := 0; i < fieldCount; i++ { + var fd FieldDescription + bName, err := buf.ReadBytes(0) + if err != nil { + return err + } + fd.Name = string(bName[:len(bName)-1]) + + // Since buf.Next() doesn't return an error if we hit the end of the buffer + // check Len ahead of time + if buf.Len() < 18 { + return &invalidMessageFormatErr{messageType: "RowDescription"} + } + + fd.TableOID = binary.BigEndian.Uint32(buf.Next(4)) + fd.TableAttributeNumber = binary.BigEndian.Uint16(buf.Next(2)) + fd.DataTypeOID = binary.BigEndian.Uint32(buf.Next(4)) + fd.DataTypeSize = int16(binary.BigEndian.Uint16(buf.Next(2))) + fd.TypeModifier = binary.BigEndian.Uint32(buf.Next(4)) + fd.Format = int16(binary.BigEndian.Uint16(buf.Next(2))) + + dst.Fields[i] = fd + } + + return nil +} + +func (src *RowDescription) Encode(dst []byte) []byte { + dst = append(dst, 'T') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = pgio.AppendUint16(dst, uint16(len(src.Fields))) + for _, fd := range src.Fields { + dst = append(dst, fd.Name...) 
+ dst = append(dst, 0) + + dst = pgio.AppendUint32(dst, fd.TableOID) + dst = pgio.AppendUint16(dst, fd.TableAttributeNumber) + dst = pgio.AppendUint32(dst, fd.DataTypeOID) + dst = pgio.AppendInt16(dst, fd.DataTypeSize) + dst = pgio.AppendUint32(dst, fd.TypeModifier) + dst = pgio.AppendInt16(dst, fd.Format) + } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +func (src *RowDescription) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Fields []FieldDescription + }{ + Type: "RowDescription", + Fields: src.Fields, + }) +} diff --git a/pgproto3/startup_message.go b/pgproto3/startup_message.go new file mode 100644 index 00000000..6c5d4f99 --- /dev/null +++ b/pgproto3/startup_message.go @@ -0,0 +1,97 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +const ( + ProtocolVersionNumber = 196608 // 3.0 + sslRequestNumber = 80877103 +) + +type StartupMessage struct { + ProtocolVersion uint32 + Parameters map[string]string +} + +func (*StartupMessage) Frontend() {} + +func (dst *StartupMessage) Decode(src []byte) error { + if len(src) < 4 { + return errors.Errorf("startup message too short") + } + + dst.ProtocolVersion = binary.BigEndian.Uint32(src) + rp := 4 + + if dst.ProtocolVersion == sslRequestNumber { + return errors.Errorf("can't handle ssl connection request") + } + + if dst.ProtocolVersion != ProtocolVersionNumber { + return errors.Errorf("Bad startup message version number. 
Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion) + } + + dst.Parameters = make(map[string]string) + for { + idx := bytes.IndexByte(src[rp:], 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "StartupMessage"} + } + key := string(src[rp : rp+idx]) + rp += idx + 1 + + idx = bytes.IndexByte(src[rp:], 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "StartupMessage"} + } + value := string(src[rp : rp+idx]) + rp += idx + 1 + + dst.Parameters[key] = value + + if len(src[rp:]) == 1 { + if src[rp] != 0 { + return errors.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp]) + } + break + } + } + + return nil +} + +func (src *StartupMessage) Encode(dst []byte) []byte { + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = pgio.AppendUint32(dst, src.ProtocolVersion) + for k, v := range src.Parameters { + dst = append(dst, k...) + dst = append(dst, 0) + dst = append(dst, v...) + dst = append(dst, 0) + } + dst = append(dst, 0) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +func (src *StartupMessage) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ProtocolVersion uint32 + Parameters map[string]string + }{ + Type: "StartupMessage", + ProtocolVersion: src.ProtocolVersion, + Parameters: src.Parameters, + }) +} diff --git a/pgproto3/sync.go b/pgproto3/sync.go new file mode 100644 index 00000000..85f4749a --- /dev/null +++ b/pgproto3/sync.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type Sync struct{} + +func (*Sync) Frontend() {} + +func (dst *Sync) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "Sync", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *Sync) Encode(dst []byte) []byte { + return append(dst, 'S', 0, 0, 0, 4) +} + +func (src *Sync) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "Sync", + }) +} diff --git 
a/pgproto3/terminate.go b/pgproto3/terminate.go new file mode 100644 index 00000000..0a3310da --- /dev/null +++ b/pgproto3/terminate.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type Terminate struct{} + +func (*Terminate) Frontend() {} + +func (dst *Terminate) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "Terminate", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *Terminate) Encode(dst []byte) []byte { + return append(dst, 'X', 0, 0, 0, 4) +} + +func (src *Terminate) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "Terminate", + }) +} diff --git a/pgtype/aclitem.go b/pgtype/aclitem.go new file mode 100644 index 00000000..35269e91 --- /dev/null +++ b/pgtype/aclitem.go @@ -0,0 +1,126 @@ +package pgtype + +import ( + "database/sql/driver" + + "github.com/pkg/errors" +) + +// ACLItem is used for PostgreSQL's aclitem data type. A sample aclitem +// might look like this: +// +// postgres=arwdDxt/postgres +// +// Note, however, that because the user/role name part of an aclitem is +// an identifier, it follows all the usual formatting rules for SQL +// identifiers: if it contains spaces and other special characters, +// it should appear in double-quotes: +// +// postgres=arwdDxt/"role with spaces" +// +type ACLItem struct { + String string + Status Status +} + +func (dst *ACLItem) Set(src interface{}) error { + switch value := src.(type) { + case string: + *dst = ACLItem{String: value, Status: Present} + case *string: + if value == nil { + *dst = ACLItem{Status: Null} + } else { + *dst = ACLItem{String: *value, Status: Present} + } + default: + if originalSrc, ok := underlyingStringType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to ACLItem", value) + } + + return nil +} + +func (dst *ACLItem) Get() interface{} { + switch dst.Status { + case Present: + return dst.String + case Null: + 
return nil + default: + return dst.Status + } +} + +func (src *ACLItem) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *string: + *v = src.String + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *ACLItem) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = ACLItem{Status: Null} + return nil + } + + *dst = ACLItem{String: string(src), Status: Present} + return nil +} + +func (src *ACLItem) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, src.String...), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *ACLItem) Scan(src interface{}) error { + if src == nil { + *dst = ACLItem{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src *ACLItem) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.String, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/aclitem_array.go b/pgtype/aclitem_array.go new file mode 100644 index 00000000..fe0af434 --- /dev/null +++ b/pgtype/aclitem_array.go @@ -0,0 +1,206 @@ +package pgtype + +import ( + "database/sql/driver" + + "github.com/pkg/errors" +) + +type ACLItemArray struct { + Elements []ACLItem + Dimensions []ArrayDimension + Status Status +} + +func (dst *ACLItemArray) Set(src interface{}) error { + switch value := src.(type) { + + case []string: + if value == nil { + *dst = ACLItemArray{Status: Null} + } else if len(value) == 0 { + *dst = ACLItemArray{Status: Present} + } else { + elements := make([]ACLItem, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = ACLItemArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to ACLItem", value) + } + + return nil +} + +func (dst *ACLItemArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *ACLItemArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *ACLItemArray) 
DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = ACLItemArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []ACLItem + + if len(uta.Elements) > 0 { + elements = make([]ACLItem, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem ACLItem + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = ACLItemArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (src *ACLItemArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) 
+ } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *ACLItemArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *ACLItemArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/aclitem_array_test.go b/pgtype/aclitem_array_test.go new file mode 100644 index 00000000..c01eaa13 --- /dev/null +++ b/pgtype/aclitem_array_test.go @@ -0,0 +1,152 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestACLItemArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "aclitem[]", []interface{}{ + &pgtype.ACLItemArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.ACLItemArray{Status: pgtype.Null}, + &pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, + pgtype.ACLItem{String: "=r/postgres", 
Status: pgtype.Present}, + pgtype.ACLItem{Status: pgtype.Null}, + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestACLItemArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.ACLItemArray + }{ + { + source: []string{"=r/postgres"}, + result: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]string)(nil)), + result: pgtype.ACLItemArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.ACLItemArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestACLItemArrayAssignTo(t *testing.T) { + var stringSlice []string + type _stringSlice []string + var namedStringSlice _stringSlice + + simpleTests := []struct { + src pgtype.ACLItemArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + expected: 
[]string{"=r/postgres"}, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedStringSlice, + expected: _stringSlice{"=r/postgres"}, + }, + { + src: pgtype.ACLItemArray{Status: pgtype.Null}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.ACLItemArray + dst interface{} + }{ + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/aclitem_test.go b/pgtype/aclitem_test.go new file mode 100644 index 00000000..65399a30 --- /dev/null +++ b/pgtype/aclitem_test.go @@ -0,0 +1,97 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestACLItemTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "aclitem", []interface{}{ + &pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + &pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, + &pgtype.ACLItem{Status: pgtype.Null}, + }) +} + +func TestACLItemSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.ACLItem + }{ + {source: 
"postgres=arwdDxt/postgres", result: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + {source: (*string)(nil), result: pgtype.ACLItem{Status: pgtype.Null}}, + } + + for i, tt := range successfulTests { + var d pgtype.ACLItem + err := d.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if d != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestACLItemAssignTo(t *testing.T) { + var s string + var ps *string + + simpleTests := []struct { + src pgtype.ACLItem + dst interface{} + expected interface{} + }{ + {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &s, expected: "postgres=arwdDxt/postgres"}, + {src: pgtype.ACLItem{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.ACLItem + dst interface{} + expected interface{} + }{ + {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &ps, expected: "postgres=arwdDxt/postgres"}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.ACLItem + dst interface{} + }{ + {src: pgtype.ACLItem{Status: pgtype.Null}, dst: &s}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, 
tt.src, tt.dst) + } + } +} diff --git a/pgtype/array.go b/pgtype/array.go new file mode 100644 index 00000000..5b852ed5 --- /dev/null +++ b/pgtype/array.go @@ -0,0 +1,352 @@ +package pgtype + +import ( + "bytes" + "encoding/binary" + "io" + "strconv" + "strings" + "unicode" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +// Information on the internals of PostgreSQL arrays can be found in +// src/include/utils/array.h and src/backend/utils/adt/arrayfuncs.c. Of +// particular interest is the array_send function. + +type ArrayHeader struct { + ContainsNull bool + ElementOID int32 + Dimensions []ArrayDimension +} + +type ArrayDimension struct { + Length int32 + LowerBound int32 +} + +func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) { + if len(src) < 12 { + return 0, errors.Errorf("array header too short: %d", len(src)) + } + + rp := 0 + + numDims := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + dst.ContainsNull = binary.BigEndian.Uint32(src[rp:]) == 1 + rp += 4 + + dst.ElementOID = int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + if numDims > 0 { + dst.Dimensions = make([]ArrayDimension, numDims) + } + if len(src) < 12+numDims*8 { + return 0, errors.Errorf("array header too short for %d dimensions: %d", numDims, len(src)) + } + for i := range dst.Dimensions { + dst.Dimensions[i].Length = int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + dst.Dimensions[i].LowerBound = int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + } + + return rp, nil +} + +func (src *ArrayHeader) EncodeBinary(ci *ConnInfo, buf []byte) []byte { + buf = pgio.AppendInt32(buf, int32(len(src.Dimensions))) + + var containsNull int32 + if src.ContainsNull { + containsNull = 1 + } + buf = pgio.AppendInt32(buf, containsNull) + + buf = pgio.AppendInt32(buf, src.ElementOID) + + for i := range src.Dimensions { + buf = pgio.AppendInt32(buf, src.Dimensions[i].Length) + buf = pgio.AppendInt32(buf, src.Dimensions[i].LowerBound) + } + + return buf +} + 
+type UntypedTextArray struct { + Elements []string + Dimensions []ArrayDimension +} + +func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { + dst := &UntypedTextArray{} + + buf := bytes.NewBufferString(src) + + skipWhitespace(buf) + + r, _, err := buf.ReadRune() + if err != nil { + return nil, errors.Errorf("invalid array: %v", err) + } + + var explicitDimensions []ArrayDimension + + // Array has explicit dimensions + if r == '[' { + buf.UnreadRune() + + for { + r, _, err = buf.ReadRune() + if err != nil { + return nil, errors.Errorf("invalid array: %v", err) + } + + if r == '=' { + break + } else if r != '[' { + return nil, errors.Errorf("invalid array, expected '[' or '=' got %v", r) + } + + lower, err := arrayParseInteger(buf) + if err != nil { + return nil, errors.Errorf("invalid array: %v", err) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, errors.Errorf("invalid array: %v", err) + } + + if r != ':' { + return nil, errors.Errorf("invalid array, expected ':' got %v", r) + } + + upper, err := arrayParseInteger(buf) + if err != nil { + return nil, errors.Errorf("invalid array: %v", err) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, errors.Errorf("invalid array: %v", err) + } + + if r != ']' { + return nil, errors.Errorf("invalid array, expected ']' got %v", r) + } + + explicitDimensions = append(explicitDimensions, ArrayDimension{LowerBound: lower, Length: upper - lower + 1}) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, errors.Errorf("invalid array: %v", err) + } + } + + if r != '{' { + return nil, errors.Errorf("invalid array, expected '{': %v", err) + } + + implicitDimensions := []ArrayDimension{{LowerBound: 1, Length: 0}} + + // Consume all initial opening brackets. This provides number of dimensions. 
+ for { + r, _, err = buf.ReadRune() + if err != nil { + return nil, errors.Errorf("invalid array: %v", err) + } + + if r == '{' { + implicitDimensions[len(implicitDimensions)-1].Length = 1 + implicitDimensions = append(implicitDimensions, ArrayDimension{LowerBound: 1}) + } else { + buf.UnreadRune() + break + } + } + currentDim := len(implicitDimensions) - 1 + counterDim := currentDim + + for { + r, _, err = buf.ReadRune() + if err != nil { + return nil, errors.Errorf("invalid array: %v", err) + } + + switch r { + case '{': + if currentDim == counterDim { + implicitDimensions[currentDim].Length++ + } + currentDim++ + case ',': + case '}': + currentDim-- + if currentDim < counterDim { + counterDim = currentDim + } + default: + buf.UnreadRune() + value, err := arrayParseValue(buf) + if err != nil { + return nil, errors.Errorf("invalid array value: %v", err) + } + if currentDim == counterDim { + implicitDimensions[currentDim].Length++ + } + dst.Elements = append(dst.Elements, value) + } + + if currentDim < 0 { + break + } + } + + skipWhitespace(buf) + + if buf.Len() > 0 { + return nil, errors.Errorf("unexpected trailing data: %v", buf.String()) + } + + if len(dst.Elements) == 0 { + dst.Dimensions = nil + } else if len(explicitDimensions) > 0 { + dst.Dimensions = explicitDimensions + } else { + dst.Dimensions = implicitDimensions + } + + return dst, nil +} + +func skipWhitespace(buf *bytes.Buffer) { + var r rune + var err error + for r, _, _ = buf.ReadRune(); unicode.IsSpace(r); r, _, _ = buf.ReadRune() { + } + + if err != io.EOF { + buf.UnreadRune() + } +} + +func arrayParseValue(buf *bytes.Buffer) (string, error) { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + if r == '"' { + return arrayParseQuotedValue(buf) + } + buf.UnreadRune() + + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case ',', '}': + buf.UnreadRune() + return s.String(), nil + } + + s.WriteRune(r) + } +} + 
+func arrayParseQuotedValue(buf *bytes.Buffer) (string, error) { + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case '\\': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + case '"': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + buf.UnreadRune() + return s.String(), nil + } + s.WriteRune(r) + } +} + +func arrayParseInteger(buf *bytes.Buffer) (int32, error) { + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return 0, err + } + + if '0' <= r && r <= '9' { + s.WriteRune(r) + } else { + buf.UnreadRune() + n, err := strconv.ParseInt(s.String(), 10, 32) + if err != nil { + return 0, err + } + return int32(n), nil + } + } +} + +func EncodeTextArrayDimensions(buf []byte, dimensions []ArrayDimension) []byte { + var customDimensions bool + for _, dim := range dimensions { + if dim.LowerBound != 1 { + customDimensions = true + } + } + + if !customDimensions { + return buf + } + + for _, dim := range dimensions { + buf = append(buf, '[') + buf = append(buf, strconv.FormatInt(int64(dim.LowerBound), 10)...) + buf = append(buf, ':') + buf = append(buf, strconv.FormatInt(int64(dim.LowerBound+dim.Length-1), 10)...) 
+ buf = append(buf, ']') + } + + return append(buf, '=') +} + +var quoteArrayReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) + +func quoteArrayElement(src string) string { + return `"` + quoteArrayReplacer.Replace(src) + `"` +} + +func QuoteArrayElementIfNeeded(src string) string { + if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || src[0] == ' ' || src[len(src)-1] == ' ' || strings.ContainsAny(src, `{},"\`) { + return quoteArrayElement(src) + } + return src +} diff --git a/pgtype/array_test.go b/pgtype/array_test.go new file mode 100644 index 00000000..d1cdb4c5 --- /dev/null +++ b/pgtype/array_test.go @@ -0,0 +1,105 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestParseUntypedTextArray(t *testing.T) { + tests := []struct { + source string + result pgtype.UntypedTextArray + }{ + { + source: "{}", + result: pgtype.UntypedTextArray{ + Elements: nil, + Dimensions: nil, + }, + }, + { + source: "{1}", + result: pgtype.UntypedTextArray{ + Elements: []string{"1"}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: "{a,b}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b"}, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + }, + }, + { + source: `{"NULL"}`, + result: pgtype.UntypedTextArray{ + Elements: []string{"NULL"}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: `{""}`, + result: pgtype.UntypedTextArray{ + Elements: []string{""}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: `{"He said, \"Hello.\""}`, + result: pgtype.UntypedTextArray{ + Elements: []string{`He said, "Hello."`}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: "{{a,b},{c,d},{e,f}}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b", "c", "d", "e", "f"}, + Dimensions: []pgtype.ArrayDimension{{Length: 3, 
LowerBound: 1}, {Length: 2, LowerBound: 1}}, + }, + }, + { + source: "{{{a,b},{c,d},{e,f}},{{a,b},{c,d},{e,f}}}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b", "c", "d", "e", "f", "a", "b", "c", "d", "e", "f"}, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 1}, + {Length: 3, LowerBound: 1}, + {Length: 2, LowerBound: 1}, + }, + }, + }, + { + source: "[4:4]={1}", + result: pgtype.UntypedTextArray{ + Elements: []string{"1"}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 4}}, + }, + }, + { + source: "[4:5][2:3]={{a,b},{c,d}}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b", "c", "d"}, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + }, + }, + } + + for i, tt := range tests { + r, err := pgtype.ParseUntypedTextArray(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + continue + } + + if !reflect.DeepEqual(*r, tt.result) { + t.Errorf("%d: expected %+v to be parsed to %+v, but it was %+v", i, tt.source, tt.result, *r) + } + } +} diff --git a/pgtype/bool.go b/pgtype/bool.go new file mode 100644 index 00000000..3a3eef48 --- /dev/null +++ b/pgtype/bool.go @@ -0,0 +1,159 @@ +package pgtype + +import ( + "database/sql/driver" + "strconv" + + "github.com/pkg/errors" +) + +type Bool struct { + Bool bool + Status Status +} + +func (dst *Bool) Set(src interface{}) error { + switch value := src.(type) { + case bool: + *dst = Bool{Bool: value, Status: Present} + case string: + bb, err := strconv.ParseBool(value) + if err != nil { + return err + } + *dst = Bool{Bool: bb, Status: Present} + default: + if originalSrc, ok := underlyingBoolType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Bool", value) + } + + return nil +} + +func (dst *Bool) Get() interface{} { + switch dst.Status { + case Present: + return dst.Bool + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Bool) 
AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *bool: + *v = src.Bool + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *Bool) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Bool{Status: Null} + return nil + } + + if len(src) != 1 { + return errors.Errorf("invalid length for bool: %v", len(src)) + } + + *dst = Bool{Bool: src[0] == 't', Status: Present} + return nil +} + +func (dst *Bool) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Bool{Status: Null} + return nil + } + + if len(src) != 1 { + return errors.Errorf("invalid length for bool: %v", len(src)) + } + + *dst = Bool{Bool: src[0] == 1, Status: Present} + return nil +} + +func (src *Bool) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if src.Bool { + buf = append(buf, 't') + } else { + buf = append(buf, 'f') + } + + return buf, nil +} + +func (src *Bool) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if src.Bool { + buf = append(buf, 1) + } else { + buf = append(buf, 0) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. 
+func (dst *Bool) Scan(src interface{}) error { + if src == nil { + *dst = Bool{Status: Null} + return nil + } + + switch src := src.(type) { + case bool: + *dst = Bool{Bool: src, Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Bool) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.Bool, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/bool_array.go b/pgtype/bool_array.go new file mode 100644 index 00000000..e23c27e5 --- /dev/null +++ b/pgtype/bool_array.go @@ -0,0 +1,294 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type BoolArray struct { + Elements []Bool + Dimensions []ArrayDimension + Status Status +} + +func (dst *BoolArray) Set(src interface{}) error { + switch value := src.(type) { + + case []bool: + if value == nil { + *dst = BoolArray{Status: Null} + } else if len(value) == 0 { + *dst = BoolArray{Status: Present} + } else { + elements := make([]Bool, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = BoolArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Bool", value) + } + + return nil +} + +func (dst *BoolArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *BoolArray) AssignTo(dst interface{}) error { + 
switch src.Status { + case Present: + switch v := dst.(type) { + + case *[]bool: + *v = make([]bool, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *BoolArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = BoolArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Bool + + if len(uta.Elements) > 0 { + elements = make([]Bool, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Bool + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = BoolArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *BoolArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = BoolArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = BoolArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Bool, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = BoolArray{Elements: elements, 
Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *BoolArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) 
+ } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src *BoolArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("bool"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, errors.Errorf("unable to find oid for type name %v", "bool") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *BoolArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src *BoolArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/bool_array_test.go b/pgtype/bool_array_test.go new file mode 100644 index 00000000..87886da6 --- /dev/null +++ b/pgtype/bool_array_test.go @@ -0,0 +1,153 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestBoolArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "bool[]", []interface{}{ + &pgtype.BoolArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.BoolArray{ + Elements: []pgtype.Bool{ + pgtype.Bool{Bool: true, Status: pgtype.Present}, + pgtype.Bool{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.BoolArray{Status: pgtype.Null}, + &pgtype.BoolArray{ + Elements: []pgtype.Bool{ + pgtype.Bool{Bool: true, Status: pgtype.Present}, + pgtype.Bool{Bool: true, Status: pgtype.Present}, + pgtype.Bool{Bool: false, Status: pgtype.Present}, + pgtype.Bool{Bool: true, Status: pgtype.Present}, + pgtype.Bool{Status: pgtype.Null}, + pgtype.Bool{Bool: false, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.BoolArray{ + Elements: []pgtype.Bool{ + pgtype.Bool{Bool: true, Status: pgtype.Present}, + pgtype.Bool{Bool: false, Status: pgtype.Present}, + pgtype.Bool{Bool: true, Status: pgtype.Present}, + pgtype.Bool{Bool: false, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestBoolArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.BoolArray + }{ + { + 
source: []bool{true}, + result: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]bool)(nil)), + result: pgtype.BoolArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.BoolArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestBoolArrayAssignTo(t *testing.T) { + var boolSlice []bool + type _boolSlice []bool + var namedBoolSlice _boolSlice + + simpleTests := []struct { + src pgtype.BoolArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &boolSlice, + expected: []bool{true}, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedBoolSlice, + expected: _boolSlice{true}, + }, + { + src: pgtype.BoolArray{Status: pgtype.Null}, + dst: &boolSlice, + expected: (([]bool)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.BoolArray + dst interface{} + }{ + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: 
&boolSlice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/bool_test.go b/pgtype/bool_test.go new file mode 100644 index 00000000..2712e3b0 --- /dev/null +++ b/pgtype/bool_test.go @@ -0,0 +1,96 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestBoolTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "bool", []interface{}{ + &pgtype.Bool{Bool: false, Status: pgtype.Present}, + &pgtype.Bool{Bool: true, Status: pgtype.Present}, + &pgtype.Bool{Bool: false, Status: pgtype.Null}, + }) +} + +func TestBoolSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Bool + }{ + {source: true, result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: false, result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + {source: "true", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: "false", result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + {source: "t", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: "f", result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + {source: _bool(true), result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: _bool(false), result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Bool + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestBoolAssignTo(t *testing.T) { + var b bool + var _b _bool + var pb *bool + var _pb *_bool + + simpleTests := []struct { + src pgtype.Bool + dst interface{} + expected interface{} + }{ + {src: pgtype.Bool{Bool: false, 
Status: pgtype.Present}, dst: &b, expected: false}, + {src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &b, expected: true}, + {src: pgtype.Bool{Bool: false, Status: pgtype.Present}, dst: &_b, expected: _bool(false)}, + {src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &_b, expected: _bool(true)}, + {src: pgtype.Bool{Bool: false, Status: pgtype.Null}, dst: &pb, expected: ((*bool)(nil))}, + {src: pgtype.Bool{Bool: false, Status: pgtype.Null}, dst: &_pb, expected: ((*_bool)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Bool + dst interface{} + expected interface{} + }{ + {src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &pb, expected: true}, + {src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &_pb, expected: _bool(true)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/box.go b/pgtype/box.go new file mode 100644 index 00000000..83df0499 --- /dev/null +++ b/pgtype/box.go @@ -0,0 +1,162 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type Box struct { + P [2]Vec2 + Status Status +} + +func (dst *Box) Set(src interface{}) error { + return errors.Errorf("cannot convert %v to Box", src) +} + +func (dst *Box) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + 
default: + return dst.Status + } +} + +func (src *Box) AssignTo(dst interface{}) error { + return errors.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Box) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Box{Status: Null} + return nil + } + + if len(src) < 11 { + return errors.Errorf("invalid length for Box: %v", len(src)) + } + + str := string(src[1:]) + + var end int + end = strings.IndexByte(str, ',') + + x1, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y1, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+3:] + end = strings.IndexByte(str, ',') + + x2, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1 : len(str)-1] + + y2, err := strconv.ParseFloat(str, 64) + if err != nil { + return err + } + + *dst = Box{P: [2]Vec2{{x1, y1}, {x2, y2}}, Status: Present} + return nil +} + +func (dst *Box) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Box{Status: Null} + return nil + } + + if len(src) != 32 { + return errors.Errorf("invalid length for Box: %v", len(src)) + } + + x1 := binary.BigEndian.Uint64(src) + y1 := binary.BigEndian.Uint64(src[8:]) + x2 := binary.BigEndian.Uint64(src[16:]) + y2 := binary.BigEndian.Uint64(src[24:]) + + *dst = Box{ + P: [2]Vec2{ + {math.Float64frombits(x1), math.Float64frombits(y1)}, + {math.Float64frombits(x2), math.Float64frombits(y2)}, + }, + Status: Present, + } + return nil +} + +func (src *Box) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = append(buf, fmt.Sprintf(`(%f,%f),(%f,%f)`, + src.P[0].X, src.P[0].Y, src.P[1].X, src.P[1].Y)...) 
+ return buf, nil +} + +func (src *Box) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].Y)) + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Box) Scan(src interface{}) error { + if src == nil { + *dst = Box{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Box) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/box_test.go b/pgtype/box_test.go new file mode 100644 index 00000000..f26cda68 --- /dev/null +++ b/pgtype/box_test.go @@ -0,0 +1,34 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestBoxTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "box", []interface{}{ + &pgtype.Box{ + P: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}}, + Status: pgtype.Present, + }, + &pgtype.Box{ + P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Status: pgtype.Present, + }, + &pgtype.Box{Status: pgtype.Null}, + }) +} + +func TestBoxNormalize(t *testing.T) { + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + { + SQL: "select '3.14, 1.678, 7.1, 5.234'::box", + Value: &pgtype.Box{ + P: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}}, + Status: pgtype.Present, + }, + }, + }) +} diff --git a/pgtype/bytea.go b/pgtype/bytea.go new file mode 100644 index 
00000000..c7117f48 --- /dev/null +++ b/pgtype/bytea.go @@ -0,0 +1,156 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/hex" + + "github.com/pkg/errors" +) + +type Bytea struct { + Bytes []byte + Status Status +} + +func (dst *Bytea) Set(src interface{}) error { + if src == nil { + *dst = Bytea{Status: Null} + return nil + } + + switch value := src.(type) { + case []byte: + if value != nil { + *dst = Bytea{Bytes: value, Status: Present} + } else { + *dst = Bytea{Status: Null} + } + default: + if originalSrc, ok := underlyingBytesType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Bytea", value) + } + + return nil +} + +func (dst *Bytea) Get() interface{} { + switch dst.Status { + case Present: + return dst.Bytes + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Bytea) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *[]byte: + buf := make([]byte, len(src.Bytes)) + copy(buf, src.Bytes) + *v = buf + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +// DecodeText only supports the hex format. This has been the default since +// PostgreSQL 9.0. 
+func (dst *Bytea) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Bytea{Status: Null} + return nil + } + + if len(src) < 2 || src[0] != '\\' || src[1] != 'x' { + return errors.Errorf("invalid hex format") + } + + buf := make([]byte, (len(src)-2)/2) + _, err := hex.Decode(buf, src[2:]) + if err != nil { + return err + } + + *dst = Bytea{Bytes: buf, Status: Present} + return nil +} + +func (dst *Bytea) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Bytea{Status: Null} + return nil + } + + *dst = Bytea{Bytes: src, Status: Present} + return nil +} + +func (src *Bytea) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = append(buf, `\x`...) + buf = append(buf, hex.EncodeToString(src.Bytes)...) + return buf, nil +} + +func (src *Bytea) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, src.Bytes...), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Bytea) Scan(src interface{}) error { + if src == nil { + *dst = Bytea{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + buf := make([]byte, len(src)) + copy(buf, src) + *dst = Bytea{Bytes: buf, Status: Present} + return nil + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src *Bytea) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.Bytes, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/bytea_array.go b/pgtype/bytea_array.go new file mode 100644 index 00000000..f2842179 --- /dev/null +++ b/pgtype/bytea_array.go @@ -0,0 +1,294 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type ByteaArray struct { + Elements []Bytea + Dimensions []ArrayDimension + Status Status +} + +func (dst *ByteaArray) Set(src interface{}) error { + switch value := src.(type) { + + case [][]byte: + if value == nil { + *dst = ByteaArray{Status: Null} + } else if len(value) == 0 { + *dst = ByteaArray{Status: Present} + } else { + elements := make([]Bytea, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = ByteaArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Bytea", value) + } + + return nil +} + +func (dst *ByteaArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *ByteaArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + + case *[][]byte: + *v = make([][]byte, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst 
*ByteaArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = ByteaArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Bytea + + if len(uta.Elements) > 0 { + elements = make([]Bytea, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Bytea + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = ByteaArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *ByteaArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = ByteaArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = ByteaArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Bytea, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = ByteaArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *ByteaArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. 
For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src *ByteaArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("bytea"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, errors.Errorf("unable to find oid for type name %v", "bytea") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the 
database/sql Scanner interface. +func (dst *ByteaArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *ByteaArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/bytea_array_test.go b/pgtype/bytea_array_test.go new file mode 100644 index 00000000..451c2461 --- /dev/null +++ b/pgtype/bytea_array_test.go @@ -0,0 +1,120 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestByteaArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "bytea[]", []interface{}{ + &pgtype.ByteaArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + pgtype.Bytea{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.ByteaArray{Status: pgtype.Null}, + &pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + pgtype.Bytea{Status: pgtype.Null}, + pgtype.Bytea{Bytes: []byte{1}, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, 
+ }, + &pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: []byte{1}, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestByteaArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.ByteaArray + }{ + { + source: [][]byte{{1, 2, 3}}, + result: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([][]byte)(nil)), + result: pgtype.ByteaArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.ByteaArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestByteaArrayAssignTo(t *testing.T) { + var byteByteSlice [][]byte + + simpleTests := []struct { + src pgtype.ByteaArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &byteByteSlice, + expected: [][]byte{{1, 2, 3}}, + }, + { + src: pgtype.ByteaArray{Status: pgtype.Null}, + dst: &byteByteSlice, + expected: (([][]byte)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to 
assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/bytea_test.go b/pgtype/bytea_test.go new file mode 100644 index 00000000..fd5a0dec --- /dev/null +++ b/pgtype/bytea_test.go @@ -0,0 +1,73 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestByteaTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "bytea", []interface{}{ + &pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + &pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}, + &pgtype.Bytea{Bytes: nil, Status: pgtype.Null}, + }) +} + +func TestByteaSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Bytea + }{ + {source: []byte{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, + {source: []byte{}, result: pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}}, + {source: []byte(nil), result: pgtype.Bytea{Status: pgtype.Null}}, + {source: _byteSlice{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, + {source: _byteSlice(nil), result: pgtype.Bytea{Status: pgtype.Null}}, + } + + for i, tt := range successfulTests { + var r pgtype.Bytea + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestByteaAssignTo(t *testing.T) { + var buf []byte + var _buf _byteSlice + var pbuf *[]byte + var _pbuf *_byteSlice + + simpleTests := []struct { + src pgtype.Bytea + dst interface{} + expected interface{} + }{ + {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &buf, expected: []byte{1, 2, 3}}, + {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &_buf, expected: _byteSlice{1, 2, 3}}, + {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: 
&pbuf, expected: &[]byte{1, 2, 3}}, + {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &_pbuf, expected: &_byteSlice{1, 2, 3}}, + {src: pgtype.Bytea{Status: pgtype.Null}, dst: &pbuf, expected: ((*[]byte)(nil))}, + {src: pgtype.Bytea{Status: pgtype.Null}, dst: &_pbuf, expected: ((*_byteSlice)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/cid.go b/pgtype/cid.go new file mode 100644 index 00000000..0ed54f44 --- /dev/null +++ b/pgtype/cid.go @@ -0,0 +1,61 @@ +package pgtype + +import ( + "database/sql/driver" +) + +// CID is PostgreSQL's Command Identifier type. +// +// When one does +// +// select cmin, cmax, * from some_table; +// +// it is the data type of the cmin and cmax hidden system columns. +// +// It is currently implemented as an unsigned four byte integer. +// Its definition can be found in src/include/c.h as CommandId +// in the PostgreSQL sources. +type CID pguint32 + +// Set converts from src to dst. Note that as CID is not a general +// number type Set does not do automatic type conversion as other number +// types do. +func (dst *CID) Set(src interface{}) error { + return (*pguint32)(dst).Set(src) +} + +func (dst *CID) Get() interface{} { + return (*pguint32)(dst).Get() +} + +// AssignTo assigns from src to dst. Note that as CID is not a general number +// type AssignTo does not do automatic type conversion as other number types do. 
+func (src *CID) AssignTo(dst interface{}) error { + return (*pguint32)(src).AssignTo(dst) +} + +func (dst *CID) DecodeText(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeText(ci, src) +} + +func (dst *CID) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeBinary(ci, src) +} + +func (src *CID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*pguint32)(src).EncodeText(ci, buf) +} + +func (src *CID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*pguint32)(src).EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *CID) Scan(src interface{}) error { + return (*pguint32)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *CID) Value() (driver.Value, error) { + return (*pguint32)(src).Value() +} diff --git a/pgtype/cid_test.go b/pgtype/cid_test.go new file mode 100644 index 00000000..0dfc56d4 --- /dev/null +++ b/pgtype/cid_test.go @@ -0,0 +1,108 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestCIDTranscode(t *testing.T) { + pgTypeName := "cid" + values := []interface{}{ + &pgtype.CID{Uint: 42, Status: pgtype.Present}, + &pgtype.CID{Status: pgtype.Null}, + } + eqFunc := func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + } + + testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + + // No direct conversion from int to cid, convert through text + testutil.TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc) + + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) + } +} + +func TestCIDSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.CID + }{ + {source: uint32(1), result: 
pgtype.CID{Uint: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.CID + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestCIDAssignTo(t *testing.T) { + var ui32 uint32 + var pui32 *uint32 + + simpleTests := []struct { + src pgtype.CID + dst interface{} + expected interface{} + }{ + {src: pgtype.CID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.CID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.CID + dst interface{} + expected interface{} + }{ + {src: pgtype.CID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.CID + dst interface{} + }{ + {src: pgtype.CID{Status: pgtype.Null}, dst: &ui32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/cidr.go b/pgtype/cidr.go new file mode 100644 index 00000000..519b9cae --- /dev/null +++ b/pgtype/cidr.go @@ -0,0 +1,31 @@ +package pgtype + +type CIDR Inet + +func (dst *CIDR) Set(src interface{}) error { 
+ return (*Inet)(dst).Set(src) +} + +func (dst *CIDR) Get() interface{} { + return (*Inet)(dst).Get() +} + +func (src *CIDR) AssignTo(dst interface{}) error { + return (*Inet)(src).AssignTo(dst) +} + +func (dst *CIDR) DecodeText(ci *ConnInfo, src []byte) error { + return (*Inet)(dst).DecodeText(ci, src) +} + +func (dst *CIDR) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Inet)(dst).DecodeBinary(ci, src) +} + +func (src *CIDR) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Inet)(src).EncodeText(ci, buf) +} + +func (src *CIDR) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Inet)(src).EncodeBinary(ci, buf) +} diff --git a/pgtype/cidr_array.go b/pgtype/cidr_array.go new file mode 100644 index 00000000..2373da46 --- /dev/null +++ b/pgtype/cidr_array.go @@ -0,0 +1,323 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "net" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type CIDRArray struct { + Elements []CIDR + Dimensions []ArrayDimension + Status Status +} + +func (dst *CIDRArray) Set(src interface{}) error { + switch value := src.(type) { + + case []*net.IPNet: + if value == nil { + *dst = CIDRArray{Status: Null} + } else if len(value) == 0 { + *dst = CIDRArray{Status: Present} + } else { + elements := make([]CIDR, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = CIDRArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []net.IP: + if value == nil { + *dst = CIDRArray{Status: Null} + } else if len(value) == 0 { + *dst = CIDRArray{Status: Present} + } else { + elements := make([]CIDR, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = CIDRArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + 
Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to CIDR", value) + } + + return nil +} + +func (dst *CIDRArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *CIDRArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + + case *[]*net.IPNet: + *v = make([]*net.IPNet, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]net.IP: + *v = make([]net.IP, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *CIDRArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = CIDRArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []CIDR + + if len(uta.Elements) > 0 { + elements = make([]CIDR, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem CIDR + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = CIDRArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *CIDRArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = CIDRArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + 
} + + if len(arrayHeader.Dimensions) == 0 { + *dst = CIDRArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]CIDR, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = CIDRArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *CIDRArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) 
+ } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src *CIDRArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("cidr"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, errors.Errorf("unable to find oid for type name %v", "cidr") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *CIDRArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src *CIDRArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/cidr_array_test.go b/pgtype/cidr_array_test.go new file mode 100644 index 00000000..70d3f65b --- /dev/null +++ b/pgtype/cidr_array_test.go @@ -0,0 +1,165 @@ +package pgtype_test + +import ( + "net" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestCIDRArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "cidr[]", []interface{}{ + &pgtype.CIDRArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + pgtype.CIDR{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.CIDR{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.CIDRArray{Status: pgtype.Null}, + &pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + pgtype.CIDR{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.CIDR{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.CIDR{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.CIDR{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + pgtype.CIDR{Status: pgtype.Null}, + pgtype.CIDR{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + pgtype.CIDR{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.CIDR{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.CIDR{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.CIDR{IPNet: mustParseCIDR(t, 
"2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestCIDRArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.CIDRArray + }{ + { + source: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]*net.IPNet)(nil)), + result: pgtype.CIDRArray{Status: pgtype.Null}, + }, + { + source: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]net.IP)(nil)), + result: pgtype.CIDRArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.CIDRArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestCIDRArrayAssignTo(t *testing.T) { + var ipnetSlice []*net.IPNet + var ipSlice []net.IP + + simpleTests := []struct { + src pgtype.CIDRArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipnetSlice, + expected: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, + }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 
1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipnetSlice, + expected: []*net.IPNet{nil}, + }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipSlice, + expected: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, + }, + { + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipSlice, + expected: []net.IP{nil}, + }, + { + src: pgtype.CIDRArray{Status: pgtype.Null}, + dst: &ipnetSlice, + expected: (([]*net.IPNet)(nil)), + }, + { + src: pgtype.CIDRArray{Status: pgtype.Null}, + dst: &ipSlice, + expected: (([]net.IP)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/circle.go b/pgtype/circle.go new file mode 100644 index 00000000..97ecbf31 --- /dev/null +++ b/pgtype/circle.go @@ -0,0 +1,146 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type Circle struct { + P Vec2 + R float64 + Status Status +} + +func (dst *Circle) Set(src interface{}) error { + return errors.Errorf("cannot convert %v to Circle", src) +} + +func (dst *Circle) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Circle) AssignTo(dst interface{}) error { + return errors.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Circle) DecodeText(ci *ConnInfo, src 
[]byte) error { + if src == nil { + *dst = Circle{Status: Null} + return nil + } + + if len(src) < 9 { + return errors.Errorf("invalid length for Circle: %v", len(src)) + } + + str := string(src[2:]) + end := strings.IndexByte(str, ',') + x, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+2 : len(str)-1] + + r, err := strconv.ParseFloat(str, 64) + if err != nil { + return err + } + + *dst = Circle{P: Vec2{x, y}, R: r, Status: Present} + return nil +} + +func (dst *Circle) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Circle{Status: Null} + return nil + } + + if len(src) != 24 { + return errors.Errorf("invalid length for Circle: %v", len(src)) + } + + x := binary.BigEndian.Uint64(src) + y := binary.BigEndian.Uint64(src[8:]) + r := binary.BigEndian.Uint64(src[16:]) + + *dst = Circle{ + P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, + R: math.Float64frombits(r), + Status: Present, + } + return nil +} + +func (src *Circle) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = append(buf, fmt.Sprintf(`<(%f,%f),%f>`, src.P.X, src.P.Y, src.R)...) + return buf, nil +} + +func (src *Circle) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = pgio.AppendUint64(buf, math.Float64bits(src.P.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P.Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.R)) + return buf, nil +} + +// Scan implements the database/sql Scanner interface. 
+func (dst *Circle) Scan(src interface{}) error { + if src == nil { + *dst = Circle{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Circle) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/circle_test.go b/pgtype/circle_test.go new file mode 100644 index 00000000..2747d4f5 --- /dev/null +++ b/pgtype/circle_test.go @@ -0,0 +1,16 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestCircleTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "circle", []interface{}{ + &pgtype.Circle{P: pgtype.Vec2{1.234, 5.6789}, R: 3.5, Status: pgtype.Present}, + &pgtype.Circle{P: pgtype.Vec2{-1.234, -5.6789}, R: 12.9, Status: pgtype.Present}, + &pgtype.Circle{Status: pgtype.Null}, + }) +} diff --git a/pgtype/convert.go b/pgtype/convert.go new file mode 100644 index 00000000..5dfb738e --- /dev/null +++ b/pgtype/convert.go @@ -0,0 +1,424 @@ +package pgtype + +import ( + "math" + "reflect" + "time" + + "github.com/pkg/errors" +) + +const maxUint = ^uint(0) +const maxInt = int(maxUint >> 1) +const minInt = -maxInt - 1 + +// underlyingNumberType gets the underlying type that can be converted to Int2, Int4, Int8, Float4, or Float8 +func underlyingNumberType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Int: + convVal := int(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int8: + convVal := int8(refVal.Int()) + return 
convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int16: + convVal := int16(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int32: + convVal := int32(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int64: + convVal := int64(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint: + convVal := uint(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint8: + convVal := uint8(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint16: + convVal := uint16(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint32: + convVal := uint32(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint64: + convVal := uint64(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Float32: + convVal := float32(refVal.Float()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Float64: + convVal := refVal.Float() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.String: + convVal := refVal.String() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + + return nil, false +} + +// underlyingBoolType gets the underlying type that can be converted to Bool +func underlyingBoolType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Bool: + convVal := refVal.Bool() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + + return nil, false +} + +// underlyingBytesType gets the underlying type that can be converted to []byte +func underlyingBytesType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + 
switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Slice: + if refVal.Type().Elem().Kind() == reflect.Uint8 { + convVal := refVal.Bytes() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + } + + return nil, false +} + +// underlyingStringType gets the underlying type that can be converted to String +func underlyingStringType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.String: + convVal := refVal.String() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + + return nil, false +} + +// underlyingPtrType dereferences a pointer +func underlyingPtrType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + } + + return nil, false +} + +// underlyingTimeType gets the underlying type that can be converted to time.Time +func underlyingTimeType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return time.Time{}, false + } + convVal := refVal.Elem().Interface() + return convVal, true + } + + timeType := reflect.TypeOf(time.Time{}) + if refVal.Type().ConvertibleTo(timeType) { + return refVal.Convert(timeType).Interface(), true + } + + return time.Time{}, false +} + +// underlyingSliceType gets the underlying slice type +func underlyingSliceType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case 
reflect.Slice: + baseSliceType := reflect.SliceOf(refVal.Type().Elem()) + if refVal.Type().ConvertibleTo(baseSliceType) { + convVal := refVal.Convert(baseSliceType) + return convVal.Interface(), reflect.TypeOf(convVal.Interface()) != refVal.Type() + } + } + + return nil, false +} + +func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { + if srcStatus == Present { + switch v := dst.(type) { + case *int: + if srcVal < int64(minInt) { + return errors.Errorf("%d is less than minimum value for int", srcVal) + } else if srcVal > int64(maxInt) { + return errors.Errorf("%d is greater than maximum value for int", srcVal) + } + *v = int(srcVal) + case *int8: + if srcVal < math.MinInt8 { + return errors.Errorf("%d is less than minimum value for int8", srcVal) + } else if srcVal > math.MaxInt8 { + return errors.Errorf("%d is greater than maximum value for int8", srcVal) + } + *v = int8(srcVal) + case *int16: + if srcVal < math.MinInt16 { + return errors.Errorf("%d is less than minimum value for int16", srcVal) + } else if srcVal > math.MaxInt16 { + return errors.Errorf("%d is greater than maximum value for int16", srcVal) + } + *v = int16(srcVal) + case *int32: + if srcVal < math.MinInt32 { + return errors.Errorf("%d is less than minimum value for int32", srcVal) + } else if srcVal > math.MaxInt32 { + return errors.Errorf("%d is greater than maximum value for int32", srcVal) + } + *v = int32(srcVal) + case *int64: + if srcVal < math.MinInt64 { + return errors.Errorf("%d is less than minimum value for int64", srcVal) + } else if srcVal > math.MaxInt64 { + return errors.Errorf("%d is greater than maximum value for int64", srcVal) + } + *v = int64(srcVal) + case *uint: + if srcVal < 0 { + return errors.Errorf("%d is less than zero for uint", srcVal) + } else if uint64(srcVal) > uint64(maxUint) { + return errors.Errorf("%d is greater than maximum value for uint", srcVal) + } + *v = uint(srcVal) + case *uint8: + if srcVal < 0 { + return errors.Errorf("%d is 
less than zero for uint8", srcVal) + } else if srcVal > math.MaxUint8 { + return errors.Errorf("%d is greater than maximum value for uint8", srcVal) + } + *v = uint8(srcVal) + case *uint16: + if srcVal < 0 { + return errors.Errorf("%d is less than zero for uint32", srcVal) + } else if srcVal > math.MaxUint16 { + return errors.Errorf("%d is greater than maximum value for uint16", srcVal) + } + *v = uint16(srcVal) + case *uint32: + if srcVal < 0 { + return errors.Errorf("%d is less than zero for uint32", srcVal) + } else if srcVal > math.MaxUint32 { + return errors.Errorf("%d is greater than maximum value for uint32", srcVal) + } + *v = uint32(srcVal) + case *uint64: + if srcVal < 0 { + return errors.Errorf("%d is less than zero for uint64", srcVal) + } + *v = uint64(srcVal) + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return int64AssignTo(srcVal, srcStatus, el.Interface()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if el.OverflowInt(int64(srcVal)) { + return errors.Errorf("cannot put %d into %T", srcVal, dst) + } + el.SetInt(int64(srcVal)) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if srcVal < 0 { + return errors.Errorf("%d is less than zero for %T", srcVal, dst) + } + if el.OverflowUint(uint64(srcVal)) { + return errors.Errorf("cannot put %d into %T", srcVal, dst) + } + el.SetUint(uint64(srcVal)) + return nil + } + } + return errors.Errorf("cannot assign %v into %T", srcVal, dst) + } + return nil + } + + // if dst is a pointer to pointer and srcStatus is not Present, nil it out + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + if el.Kind() == reflect.Ptr { + el.Set(reflect.Zero(el.Type())) + return nil + } + } 
+ + return errors.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) +} + +func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error { + if srcStatus == Present { + switch v := dst.(type) { + case *float32: + *v = float32(srcVal) + case *float64: + *v = srcVal + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return float64AssignTo(srcVal, srcStatus, el.Interface()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + i64 := int64(srcVal) + if float64(i64) == srcVal { + return int64AssignTo(i64, srcStatus, dst) + } + } + } + return errors.Errorf("cannot assign %v into %T", srcVal, dst) + } + return nil + } + + // if dst is a pointer to pointer and srcStatus is not Present, nil it out + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + if el.Kind() == reflect.Ptr { + el.Set(reflect.Zero(el.Type())) + return nil + } + } + + return errors.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) +} + +func NullAssignTo(dst interface{}) error { + dstPtr := reflect.ValueOf(dst) + + // AssignTo dst must always be a pointer + if dstPtr.Kind() != reflect.Ptr { + return errors.Errorf("cannot assign NULL to %T", dst) + } + + dstVal := dstPtr.Elem() + + switch dstVal.Kind() { + case reflect.Ptr, reflect.Slice, reflect.Map: + dstVal.Set(reflect.Zero(dstVal.Type())) + return nil + } + + return errors.Errorf("cannot assign NULL to %T", dst) +} + +var kindTypes map[reflect.Kind]reflect.Type + +// GetAssignToDstType attempts to convert dst to something AssignTo can assign +// to. If dst is a pointer to pointer it allocates a value and returns the +// dereferences pointer. 
If dst is a named type such as *Foo where Foo is type +// Foo int16, it converts dst to *int16. + +// GetAssignToDstType returns the converted dst and a bool representing if any +// change was made. +func GetAssignToDstType(dst interface{}) (interface{}, bool) { + dstPtr := reflect.ValueOf(dst) + + // AssignTo dst must always be a pointer + if dstPtr.Kind() != reflect.Ptr { + return nil, false + } + + dstVal := dstPtr.Elem() + + // if dst is a pointer to pointer, allocate space and try again with the dereferenced pointer + if dstVal.Kind() == reflect.Ptr { + dstVal.Set(reflect.New(dstVal.Type().Elem())) + return dstVal.Interface(), true + } + + // if dst is pointer to a base type that has been renamed + if baseValType, ok := kindTypes[dstVal.Kind()]; ok { + nextDst := dstPtr.Convert(reflect.PtrTo(baseValType)) + return nextDst.Interface(), dstPtr.Type() != nextDst.Type() + } + + if dstVal.Kind() == reflect.Slice { + if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok { + baseSliceType := reflect.PtrTo(reflect.SliceOf(baseElemType)) + nextDst := dstPtr.Convert(baseSliceType) + return nextDst.Interface(), dstPtr.Type() != nextDst.Type() + } + } + + return nil, false +} + +func init() { + kindTypes = map[reflect.Kind]reflect.Type{ + reflect.Bool: reflect.TypeOf(false), + reflect.Float32: reflect.TypeOf(float32(0)), + reflect.Float64: reflect.TypeOf(float64(0)), + reflect.Int: reflect.TypeOf(int(0)), + reflect.Int8: reflect.TypeOf(int8(0)), + reflect.Int16: reflect.TypeOf(int16(0)), + reflect.Int32: reflect.TypeOf(int32(0)), + reflect.Int64: reflect.TypeOf(int64(0)), + reflect.Uint: reflect.TypeOf(uint(0)), + reflect.Uint8: reflect.TypeOf(uint8(0)), + reflect.Uint16: reflect.TypeOf(uint16(0)), + reflect.Uint32: reflect.TypeOf(uint32(0)), + reflect.Uint64: reflect.TypeOf(uint64(0)), + reflect.String: reflect.TypeOf(""), + } +} diff --git a/pgtype/database_sql.go b/pgtype/database_sql.go new file mode 100644 index 00000000..969536dd --- /dev/null +++ 
b/pgtype/database_sql.go @@ -0,0 +1,42 @@ +package pgtype + +import ( + "database/sql/driver" + + "github.com/pkg/errors" +) + +func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) { + if valuer, ok := src.(driver.Valuer); ok { + return valuer.Value() + } + + if textEncoder, ok := src.(TextEncoder); ok { + buf, err := textEncoder.EncodeText(ci, nil) + if err != nil { + return nil, err + } + return string(buf), nil + } + + if binaryEncoder, ok := src.(BinaryEncoder); ok { + buf, err := binaryEncoder.EncodeBinary(ci, nil) + if err != nil { + return nil, err + } + return buf, nil + } + + return nil, errors.New("cannot convert to database/sql compatible value") +} + +func EncodeValueText(src TextEncoder) (interface{}, error) { + buf, err := src.EncodeText(nil, make([]byte, 0, 32)) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + return string(buf), err +} diff --git a/pgtype/date.go b/pgtype/date.go new file mode 100644 index 00000000..f1c0d8bd --- /dev/null +++ b/pgtype/date.go @@ -0,0 +1,209 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "time" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type Date struct { + Time time.Time + Status Status + InfinityModifier InfinityModifier +} + +const ( + negativeInfinityDayOffset = -2147483648 + infinityDayOffset = 2147483647 +) + +func (dst *Date) Set(src interface{}) error { + if src == nil { + *dst = Date{Status: Null} + return nil + } + + switch value := src.(type) { + case time.Time: + *dst = Date{Time: value, Status: Present} + default: + if originalSrc, ok := underlyingTimeType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Date", value) + } + + return nil +} + +func (dst *Date) Get() interface{} { + switch dst.Status { + case Present: + if dst.InfinityModifier != None { + return dst.InfinityModifier + } + return dst.Time + case Null: + return nil + default: + return dst.Status + } +} + 
+func (src *Date) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *time.Time: + if src.InfinityModifier != None { + return errors.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.Time + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *Date) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Date{Status: Null} + return nil + } + + sbuf := string(src) + switch sbuf { + case "infinity": + *dst = Date{Status: Present, InfinityModifier: Infinity} + case "-infinity": + *dst = Date{Status: Present, InfinityModifier: -Infinity} + default: + t, err := time.ParseInLocation("2006-01-02", sbuf, time.UTC) + if err != nil { + return err + } + + *dst = Date{Time: t, Status: Present} + } + + return nil +} + +func (dst *Date) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Date{Status: Null} + return nil + } + + if len(src) != 4 { + return errors.Errorf("invalid length for date: %v", len(src)) + } + + dayOffset := int32(binary.BigEndian.Uint32(src)) + + switch dayOffset { + case infinityDayOffset: + *dst = Date{Status: Present, InfinityModifier: Infinity} + case negativeInfinityDayOffset: + *dst = Date{Status: Present, InfinityModifier: -Infinity} + default: + t := time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.UTC) + *dst = Date{Time: t, Status: Present} + } + + return nil +} + +func (src *Date) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var s string + + switch src.InfinityModifier { + case None: + s = src.Time.Format("2006-01-02") + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + return append(buf, s...), nil +} + +func 
(src *Date) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var daysSinceDateEpoch int32 + switch src.InfinityModifier { + case None: + tUnix := time.Date(src.Time.Year(), src.Time.Month(), src.Time.Day(), 0, 0, 0, 0, time.UTC).Unix() + dateEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).Unix() + + secSinceDateEpoch := tUnix - dateEpoch + daysSinceDateEpoch = int32(secSinceDateEpoch / 86400) + case Infinity: + daysSinceDateEpoch = infinityDayOffset + case NegativeInfinity: + daysSinceDateEpoch = negativeInfinityDayOffset + } + + return pgio.AppendInt32(buf, daysSinceDateEpoch), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Date) Scan(src interface{}) error { + if src == nil { + *dst = Date{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + case time.Time: + *dst = Date{Time: src, Status: Present} + return nil + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src *Date) Value() (driver.Value, error) { + switch src.Status { + case Present: + if src.InfinityModifier != None { + return src.InfinityModifier.String(), nil + } + return src.Time, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/date_array.go b/pgtype/date_array.go new file mode 100644 index 00000000..383945e7 --- /dev/null +++ b/pgtype/date_array.go @@ -0,0 +1,295 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "time" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type DateArray struct { + Elements []Date + Dimensions []ArrayDimension + Status Status +} + +func (dst *DateArray) Set(src interface{}) error { + switch value := src.(type) { + + case []time.Time: + if value == nil { + *dst = DateArray{Status: Null} + } else if len(value) == 0 { + *dst = DateArray{Status: Present} + } else { + elements := make([]Date, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = DateArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Date", value) + } + + return nil +} + +func (dst *DateArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *DateArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + + case *[]time.Time: + *v = make([]time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return 
NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *DateArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = DateArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Date + + if len(uta.Elements) > 0 { + elements = make([]Date, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Date + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = DateArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *DateArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = DateArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = DateArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Date, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = DateArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *DateArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) 
+ + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) 
+ } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src *DateArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("date"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, errors.Errorf("unable to find oid for type name %v", "date") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *DateArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src *DateArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/date_array_test.go b/pgtype/date_array_test.go new file mode 100644 index 00000000..74ebfbbe --- /dev/null +++ b/pgtype/date_array_test.go @@ -0,0 +1,143 @@ +package pgtype_test + +import ( + "reflect" + "testing" + "time" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestDateArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "date[]", []interface{}{ + &pgtype.DateArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.DateArray{ + Elements: []pgtype.Date{ + pgtype.Date{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.DateArray{Status: pgtype.Null}, + &pgtype.DateArray{ + Elements: []pgtype.Date{ + pgtype.Date{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Status: pgtype.Null}, + pgtype.Date{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.DateArray{ + Elements: []pgtype.Date{ + pgtype.Date{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2015, 2, 3, 0, 0, 0, 0, 
time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestDateArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.DateArray + }{ + { + source: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + result: pgtype.DateArray{ + Elements: []pgtype.Date{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]time.Time)(nil)), + result: pgtype.DateArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.DateArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestDateArrayAssignTo(t *testing.T) { + var timeSlice []time.Time + + simpleTests := []struct { + src pgtype.DateArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &timeSlice, + expected: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + }, + { + src: pgtype.DateArray{Status: pgtype.Null}, + dst: &timeSlice, + expected: (([]time.Time)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, 
dst) + } + } + + errorTests := []struct { + src pgtype.DateArray + dst interface{} + }{ + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &timeSlice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/date_test.go b/pgtype/date_test.go new file mode 100644 index 00000000..d98e1652 --- /dev/null +++ b/pgtype/date_test.go @@ -0,0 +1,118 @@ +package pgtype_test + +import ( + "reflect" + "testing" + "time" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestDateTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "date", []interface{}{ + &pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Date{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Date{Status: pgtype.Null}, + &pgtype.Date{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, + &pgtype.Date{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + }, func(a, b interface{}) bool { + at := a.(pgtype.Date) + bt := b.(pgtype.Date) + + return at.Time.Equal(bt.Time) && at.Status == bt.Status && at.InfinityModifier == bt.InfinityModifier + }) +} + +func TestDateSet(t *testing.T) { + type _time time.Time + + successfulTests := []struct { + source interface{} + result pgtype.Date + }{ + {source: time.Date(1900, 1, 1, 0, 
0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var d pgtype.Date + err := d.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if d != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestDateAssignTo(t *testing.T) { + var tim time.Time + var ptim *time.Time + + simpleTests := []struct { + src pgtype.Date + dst interface{} + expected interface{} + }{ + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + {src: pgtype.Date{Time: time.Time{}, Status: pgtype.Null}, dst: &ptim, expected: ((*time.Time)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to 
assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Date + dst interface{} + expected interface{} + }{ + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Date + dst interface{} + }{ + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, dst: &tim}, + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, dst: &tim}, + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Null}, dst: &tim}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/daterange.go b/pgtype/daterange.go new file mode 100644 index 00000000..47cd7e46 --- /dev/null +++ b/pgtype/daterange.go @@ -0,0 +1,250 @@ +package pgtype + +import ( + "database/sql/driver" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type Daterange struct { + Lower Date + Upper Date + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *Daterange) Set(src interface{}) error { + return errors.Errorf("cannot convert %v to Daterange", src) +} + +func (dst *Daterange) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src 
*Daterange) AssignTo(dst interface{}) error { + return errors.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Daterange) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Daterange{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Daterange{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Daterange) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Daterange{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Daterange{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Daterange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + buf = append(buf, '(') + case Inclusive: + buf = append(buf, '[') + case Empty: + return append(buf, "empty"...), nil + default: + return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) + } + + var err error + + if src.LowerType != 
Unbounded { + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + buf = append(buf, ',') + + if src.UpperType != Unbounded { + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + buf = append(buf, ')') + case Inclusive: + buf = append(buf, ']') + default: + return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) + } + + return buf, nil +} + +func (src Daterange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + return append(buf, emptyMask), nil + default: + return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) + } + + buf = append(buf, rangeType) + + var err error + + if src.LowerType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + if src.UpperType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + 
return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Daterange) Scan(src interface{}) error { + if src == nil { + *dst = Daterange{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Daterange) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/daterange_test.go b/pgtype/daterange_test.go new file mode 100644 index 00000000..d2af5986 --- /dev/null +++ b/pgtype/daterange_test.go @@ -0,0 +1,67 @@ +package pgtype_test + +import ( + "testing" + "time" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestDaterangeTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "daterange", []interface{}{ + &pgtype.Daterange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + &pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + &pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + &pgtype.Daterange{Status: pgtype.Null}, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.Daterange) + b 
:= bb.(pgtype.Daterange) + + return a.Status == b.Status && + a.Lower.Time.Equal(b.Lower.Time) && + a.Lower.Status == b.Lower.Status && + a.Lower.InfinityModifier == b.Lower.InfinityModifier && + a.Upper.Time.Equal(b.Upper.Time) && + a.Upper.Status == b.Upper.Status && + a.Upper.InfinityModifier == b.Upper.InfinityModifier + }) +} + +func TestDaterangeNormalize(t *testing.T) { + testutil.TestSuccessfulNormalizeEqFunc(t, []testutil.NormalizeTest{ + { + SQL: "select daterange('2010-01-01', '2010-01-11', '(]')", + Value: pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(2010, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Date{Time: time.Date(2010, 1, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + }, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.Daterange) + b := bb.(pgtype.Daterange) + + return a.Status == b.Status && + a.Lower.Time.Equal(b.Lower.Time) && + a.Lower.Status == b.Lower.Status && + a.Lower.InfinityModifier == b.Lower.InfinityModifier && + a.Upper.Time.Equal(b.Upper.Time) && + a.Upper.Status == b.Upper.Status && + a.Upper.InfinityModifier == b.Upper.InfinityModifier + }) +} diff --git a/pgtype/decimal.go b/pgtype/decimal.go new file mode 100644 index 00000000..79653cf3 --- /dev/null +++ b/pgtype/decimal.go @@ -0,0 +1,31 @@ +package pgtype + +type Decimal Numeric + +func (dst *Decimal) Set(src interface{}) error { + return (*Numeric)(dst).Set(src) +} + +func (dst *Decimal) Get() interface{} { + return (*Numeric)(dst).Get() +} + +func (src *Decimal) AssignTo(dst interface{}) error { + return (*Numeric)(src).AssignTo(dst) +} + +func (dst *Decimal) DecodeText(ci *ConnInfo, src []byte) error { + return (*Numeric)(dst).DecodeText(ci, src) +} + +func (dst *Decimal) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Numeric)(dst).DecodeBinary(ci, src) +} + +func (src *Decimal) EncodeText(ci *ConnInfo, buf []byte) ([]byte, 
error) { + return (*Numeric)(src).EncodeText(ci, buf) +} + +func (src *Decimal) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Numeric)(src).EncodeBinary(ci, buf) +} diff --git a/pgtype/ext/satori-uuid/uuid.go b/pgtype/ext/satori-uuid/uuid.go new file mode 100644 index 00000000..78a90035 --- /dev/null +++ b/pgtype/ext/satori-uuid/uuid.go @@ -0,0 +1,161 @@ +package uuid + +import ( + "database/sql/driver" + + "github.com/pkg/errors" + + "github.com/jackc/pgx/pgtype" + uuid "github.com/satori/go.uuid" +) + +var errUndefined = errors.New("cannot encode status undefined") + +type UUID struct { + UUID uuid.UUID + Status pgtype.Status +} + +func (dst *UUID) Set(src interface{}) error { + switch value := src.(type) { + case uuid.UUID: + *dst = UUID{UUID: value, Status: pgtype.Present} + case [16]byte: + *dst = UUID{UUID: uuid.UUID(value), Status: pgtype.Present} + case []byte: + if len(value) != 16 { + return errors.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) + } + *dst = UUID{Status: pgtype.Present} + copy(dst.UUID[:], value) + case string: + uuid, err := uuid.FromString(value) + if err != nil { + return err + } + *dst = UUID{UUID: uuid, Status: pgtype.Present} + default: + // If all else fails see if pgtype.UUID can handle it. If so, translate through that. 
+ pgUUID := &pgtype.UUID{} + if err := pgUUID.Set(value); err != nil { + return errors.Errorf("cannot convert %v to UUID", value) + } + + *dst = UUID{UUID: uuid.UUID(pgUUID.Bytes), Status: pgUUID.Status} + } + + return nil +} + +func (dst *UUID) Get() interface{} { + switch dst.Status { + case pgtype.Present: + return dst.UUID + case pgtype.Null: + return nil + default: + return dst.Status + } +} + +func (src *UUID) AssignTo(dst interface{}) error { + switch src.Status { + case pgtype.Present: + switch v := dst.(type) { + case *uuid.UUID: + *v = src.UUID + case *[16]byte: + *v = [16]byte(src.UUID) + return nil + case *[]byte: + *v = make([]byte, 16) + copy(*v, src.UUID[:]) + return nil + case *string: + *v = src.UUID.String() + return nil + default: + if nextDst, retry := pgtype.GetAssignToDstType(v); retry { + return src.AssignTo(nextDst) + } + } + case pgtype.Null: + return pgtype.NullAssignTo(dst) + } + + return errors.Errorf("cannot assign %v into %T", src, dst) +} + +func (dst *UUID) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + if src == nil { + *dst = UUID{Status: pgtype.Null} + return nil + } + + u, err := uuid.FromString(string(src)) + if err != nil { + return err + } + + *dst = UUID{UUID: u, Status: pgtype.Present} + return nil +} + +func (dst *UUID) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + if src == nil { + *dst = UUID{Status: pgtype.Null} + return nil + } + + if len(src) != 16 { + return errors.Errorf("invalid length for UUID: %v", len(src)) + } + + *dst = UUID{Status: pgtype.Present} + copy(dst.UUID[:], src) + return nil +} + +func (src *UUID) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case pgtype.Null: + return nil, nil + case pgtype.Undefined: + return nil, errUndefined + } + + return append(buf, src.UUID.String()...), nil +} + +func (src *UUID) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case pgtype.Null: + return nil, nil + case 
pgtype.Undefined: + return nil, errUndefined + } + + return append(buf, src.UUID[:]...), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *UUID) Scan(src interface{}) error { + if src == nil { + *dst = UUID{Status: pgtype.Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *UUID) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/pgtype/ext/satori-uuid/uuid_test.go b/pgtype/ext/satori-uuid/uuid_test.go new file mode 100644 index 00000000..02ebb770 --- /dev/null +++ b/pgtype/ext/satori-uuid/uuid_test.go @@ -0,0 +1,97 @@ +package uuid_test + +import ( + "bytes" + "testing" + + "github.com/jackc/pgx/pgtype" + satori "github.com/jackc/pgx/pgtype/ext/satori-uuid" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestUUIDTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ + &satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + &satori.UUID{Status: pgtype.Null}, + }) +} + +func TestUUIDSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result satori.UUID + }{ + { + source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + { + source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + { + source: "00010203-0405-0607-0809-0a0b0c0d0e0f", + result: satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + } + + for i, tt := range 
successfulTests { + var r satori.UUID + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestUUIDAssignTo(t *testing.T) { + { + src := satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst [16]byte + expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst []byte + expected := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if bytes.Compare(dst, expected) != 0 { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst string + expected := "00010203-0405-0607-0809-0a0b0c0d0e0f" + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + +} diff --git a/pgtype/ext/shopspring-numeric/decimal.go b/pgtype/ext/shopspring-numeric/decimal.go new file mode 100644 index 00000000..507a93dc --- /dev/null +++ b/pgtype/ext/shopspring-numeric/decimal.go @@ -0,0 +1,317 @@ +package numeric + +import ( + "database/sql/driver" + "strconv" + + "github.com/pkg/errors" + + "github.com/jackc/pgx/pgtype" + "github.com/shopspring/decimal" +) + +var errUndefined = errors.New("cannot encode status undefined") + +type Numeric struct { + Decimal decimal.Decimal + Status pgtype.Status +} + +func (dst 
*Numeric) Set(src interface{}) error { + if src == nil { + *dst = Numeric{Status: pgtype.Null} + return nil + } + + switch value := src.(type) { + case decimal.Decimal: + *dst = Numeric{Decimal: value, Status: pgtype.Present} + case float32: + *dst = Numeric{Decimal: decimal.NewFromFloat(float64(value)), Status: pgtype.Present} + case float64: + *dst = Numeric{Decimal: decimal.NewFromFloat(value), Status: pgtype.Present} + case int8: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case uint8: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case int16: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case uint16: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case int32: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case uint32: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case int64: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case uint64: + // uint64 could be greater than int64 so convert to string then to decimal + dec, err := decimal.NewFromString(strconv.FormatUint(value, 10)) + if err != nil { + return err + } + *dst = Numeric{Decimal: dec, Status: pgtype.Present} + case int: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case uint: + // uint could be greater than int64 so convert to string then to decimal + dec, err := decimal.NewFromString(strconv.FormatUint(uint64(value), 10)) + if err != nil { + return err + } + *dst = Numeric{Decimal: dec, Status: pgtype.Present} + case string: + dec, err := decimal.NewFromString(value) + if err != nil { + return err + } + *dst = Numeric{Decimal: dec, Status: pgtype.Present} + default: + // If all else fails see if pgtype.Numeric can handle it. If so, translate through that. 
+ num := &pgtype.Numeric{} + if err := num.Set(value); err != nil { + return errors.Errorf("cannot convert %v to Numeric", value) + } + + buf, err := num.EncodeText(nil, nil) + if err != nil { + return errors.Errorf("cannot convert %v to Numeric", value) + } + + dec, err := decimal.NewFromString(string(buf)) + if err != nil { + return errors.Errorf("cannot convert %v to Numeric", value) + } + *dst = Numeric{Decimal: dec, Status: pgtype.Present} + } + + return nil +} + +func (dst *Numeric) Get() interface{} { + switch dst.Status { + case pgtype.Present: + return dst.Decimal + case pgtype.Null: + return nil + default: + return dst.Status + } +} + +func (src *Numeric) AssignTo(dst interface{}) error { + switch src.Status { + case pgtype.Present: + switch v := dst.(type) { + case *decimal.Decimal: + *v = src.Decimal + case *float32: + f, _ := src.Decimal.Float64() + *v = float32(f) + case *float64: + f, _ := src.Decimal.Float64() + *v = f + case *int: + if src.Decimal.Exponent() < 0 { + return errors.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseInt(src.Decimal.String(), 10, strconv.IntSize) + if err != nil { + return errors.Errorf("cannot convert %v to %T", dst, *v) + } + *v = int(n) + case *int8: + if src.Decimal.Exponent() < 0 { + return errors.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseInt(src.Decimal.String(), 10, 8) + if err != nil { + return errors.Errorf("cannot convert %v to %T", dst, *v) + } + *v = int8(n) + case *int16: + if src.Decimal.Exponent() < 0 { + return errors.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseInt(src.Decimal.String(), 10, 16) + if err != nil { + return errors.Errorf("cannot convert %v to %T", dst, *v) + } + *v = int16(n) + case *int32: + if src.Decimal.Exponent() < 0 { + return errors.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseInt(src.Decimal.String(), 10, 32) + if err != nil { + return errors.Errorf("cannot convert %v to %T", 
dst, *v) + } + *v = int32(n) + case *int64: + if src.Decimal.Exponent() < 0 { + return errors.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseInt(src.Decimal.String(), 10, 64) + if err != nil { + return errors.Errorf("cannot convert %v to %T", dst, *v) + } + *v = int64(n) + case *uint: + if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { + return errors.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseUint(src.Decimal.String(), 10, strconv.IntSize) + if err != nil { + return errors.Errorf("cannot convert %v to %T", dst, *v) + } + *v = uint(n) + case *uint8: + if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { + return errors.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseUint(src.Decimal.String(), 10, 8) + if err != nil { + return errors.Errorf("cannot convert %v to %T", dst, *v) + } + *v = uint8(n) + case *uint16: + if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { + return errors.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseUint(src.Decimal.String(), 10, 16) + if err != nil { + return errors.Errorf("cannot convert %v to %T", dst, *v) + } + *v = uint16(n) + case *uint32: + if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { + return errors.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseUint(src.Decimal.String(), 10, 32) + if err != nil { + return errors.Errorf("cannot convert %v to %T", dst, *v) + } + *v = uint32(n) + case *uint64: + if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { + return errors.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseUint(src.Decimal.String(), 10, 64) + if err != nil { + return errors.Errorf("cannot convert %v to %T", dst, *v) + } + *v = uint64(n) + default: + if nextDst, retry := pgtype.GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case pgtype.Null: + return pgtype.NullAssignTo(dst) + } + + return nil +} + +func (dst *Numeric) 
DecodeText(ci *pgtype.ConnInfo, src []byte) error { + if src == nil { + *dst = Numeric{Status: pgtype.Null} + return nil + } + + dec, err := decimal.NewFromString(string(src)) + if err != nil { + return err + } + + *dst = Numeric{Decimal: dec, Status: pgtype.Present} + return nil +} + +func (dst *Numeric) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + if src == nil { + *dst = Numeric{Status: pgtype.Null} + return nil + } + + // For now at least, implement this in terms of pgtype.Numeric + + num := &pgtype.Numeric{} + if err := num.DecodeBinary(ci, src); err != nil { + return err + } + + buf, err := num.EncodeText(ci, nil) + if err != nil { + return err + } + + dec, err := decimal.NewFromString(string(buf)) + if err != nil { + return err + } + + *dst = Numeric{Decimal: dec, Status: pgtype.Present} + + return nil +} + +func (src *Numeric) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case pgtype.Null: + return nil, nil + case pgtype.Undefined: + return nil, errUndefined + } + + return append(buf, src.Decimal.String()...), nil +} + +func (src *Numeric) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case pgtype.Null: + return nil, nil + case pgtype.Undefined: + return nil, errUndefined + } + + // For now at least, implement this in terms of pgtype.Numeric + num := &pgtype.Numeric{} + if err := num.DecodeText(ci, []byte(src.Decimal.String())); err != nil { + return nil, err + } + + return num.EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. 
+func (dst *Numeric) Scan(src interface{}) error { + if src == nil { + *dst = Numeric{Status: pgtype.Null} + return nil + } + + switch src := src.(type) { + case float64: + *dst = Numeric{Decimal: decimal.NewFromFloat(src), Status: pgtype.Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Numeric) Value() (driver.Value, error) { + switch src.Status { + case pgtype.Present: + return src.Decimal.Value() + case pgtype.Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/ext/shopspring-numeric/decimal_test.go b/pgtype/ext/shopspring-numeric/decimal_test.go new file mode 100644 index 00000000..79121ef3 --- /dev/null +++ b/pgtype/ext/shopspring-numeric/decimal_test.go @@ -0,0 +1,286 @@ +package numeric_test + +import ( + "fmt" + "math/big" + "math/rand" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + shopspring "github.com/jackc/pgx/pgtype/ext/shopspring-numeric" + "github.com/jackc/pgx/pgtype/testutil" + "github.com/shopspring/decimal" +) + +func mustParseDecimal(t *testing.T, src string) decimal.Decimal { + dec, err := decimal.NewFromString(src) + if err != nil { + t.Fatal(err) + } + return dec +} + +func TestNumericNormalize(t *testing.T) { + testutil.TestSuccessfulNormalizeEqFunc(t, []testutil.NormalizeTest{ + { + SQL: "select '0'::numeric", + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present}, + }, + { + SQL: "select '1'::numeric", + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}, + }, + { + SQL: "select '10.00'::numeric", + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "10.00"), Status: pgtype.Present}, + }, + { + SQL: "select '1e-3'::numeric", + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Status: 
pgtype.Present}, + }, + { + SQL: "select '-1'::numeric", + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, + }, + { + SQL: "select '10000'::numeric", + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "10000"), Status: pgtype.Present}, + }, + { + SQL: "select '3.14'::numeric", + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Status: pgtype.Present}, + }, + { + SQL: "select '1.1'::numeric", + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1.1"), Status: pgtype.Present}, + }, + { + SQL: "select '100010001'::numeric", + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001"), Status: pgtype.Present}, + }, + { + SQL: "select '100010001.0001'::numeric", + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001.0001"), Status: pgtype.Present}, + }, + { + SQL: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric", + Value: &shopspring.Numeric{ + Decimal: mustParseDecimal(t, "4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981"), + Status: pgtype.Present, + }, + }, + { + SQL: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric", + Value: &shopspring.Numeric{ + Decimal: mustParseDecimal(t, "0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), + Status: pgtype.Present, + }, + }, + { + SQL: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric", + Value: &shopspring.Numeric{ + 
Decimal: mustParseDecimal(t, "0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123"), + Status: pgtype.Present, + }, + }, + }, func(aa, bb interface{}) bool { + a := aa.(shopspring.Numeric) + b := bb.(shopspring.Numeric) + + return a.Status == b.Status && a.Decimal.Equal(b.Decimal) + }) +} + +func TestNumericTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "100000"), Status: pgtype.Present}, + + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.1"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.01"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.0001"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00001"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.000001"), Status: pgtype.Present}, + + &shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000123"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.000000123"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.0000000123"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000000123"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, 
"0.00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001234567890123456789"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "4309132809320932980457137401234890237489238912983572189348951289375283573984571892758234678903467889512893489128589347891272139.8489235871258912789347891235879148795891238915678189467128957812395781238579189025891238901583915890128973578957912385798125789012378905238905471598123758923478294374327894237892234"), Status: pgtype.Present}, + &shopspring.Numeric{Status: pgtype.Null}, + }, func(aa, bb interface{}) bool { + a := aa.(shopspring.Numeric) + b := bb.(shopspring.Numeric) + + return a.Status == b.Status && a.Decimal.Equal(b.Decimal) + }) + +} + +func TestNumericTranscodeFuzz(t *testing.T) { + r := rand.New(rand.NewSource(0)) + max := &big.Int{} + max.SetString("9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10) + + values := make([]interface{}, 0, 2000) + for i := 0; i < 500; i++ { + num := fmt.Sprintf("%s.%s", (&big.Int{}).Rand(r, max).String(), (&big.Int{}).Rand(r, max).String()) + negNum := "-" + num + values = append(values, &shopspring.Numeric{Decimal: mustParseDecimal(t, num), Status: pgtype.Present}) + values = append(values, &shopspring.Numeric{Decimal: mustParseDecimal(t, negNum), Status: pgtype.Present}) + } + + testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", values, + func(aa, bb interface{}) bool { + a := aa.(shopspring.Numeric) + b := bb.(shopspring.Numeric) + + return a.Status == b.Status && a.Decimal.Equal(b.Decimal) + }) +} + +func TestNumericSet(t *testing.T) { + type _int8 int8 + + successfulTests := []struct { + source interface{} + result *shopspring.Numeric + }{ + {source: float32(1), result: 
&shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: float64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: int16(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: int32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: int64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: int8(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, + {source: int16(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, + {source: int32(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, + {source: int64(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, + {source: uint8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: uint16(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: uint32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: uint64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: "1", result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: _int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: float64(1000), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1000"), Status: pgtype.Present}}, + {source: float64(1234), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1234"), Status: pgtype.Present}}, + {source: 
float64(12345678900), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "12345678900"), Status: pgtype.Present}}, + {source: float64(12345.678901), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "12345.678901"), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + r := &shopspring.Numeric{} + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !(r.Status == tt.result.Status && r.Decimal.Equal(tt.result.Decimal)) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestNumericAssignTo(t *testing.T) { + type _int8 int8 + + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + var f32 float32 + var f64 float64 + var pf32 *float32 + var pf64 *float64 + + simpleTests := []struct { + src *shopspring.Numeric + dst interface{} + expected interface{} + }{ + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &f32, expected: float32(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &f64, expected: float64(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "4.2"), Status: pgtype.Present}, dst: &f32, expected: float32(4.2)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "4.2"), Status: pgtype.Present}, dst: &f64, expected: float64(4.2)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42000"), Status: pgtype.Present}, dst: 
&i64, expected: int64(42000)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src *shopspring.Numeric + dst interface{} + expected interface{} + }{ + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was 
%v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src *shopspring.Numeric + dst interface{} + }{ + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "150"), Status: pgtype.Present}, dst: &i8}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "40000"), Status: pgtype.Present}, dst: &i16}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui8}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui16}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui32}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui64}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui}, + {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &i32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/float4.go b/pgtype/float4.go new file mode 100644 index 00000000..2207594a --- /dev/null +++ b/pgtype/float4.go @@ -0,0 +1,197 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "math" + "strconv" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type Float4 struct { + Float float32 + Status Status +} + +func (dst *Float4) Set(src interface{}) error { + if src == nil { + *dst = Float4{Status: Null} + return nil + } + + switch value := src.(type) { + case float32: + *dst = Float4{Float: value, Status: Present} + case float64: + *dst = Float4{Float: float32(value), Status: Present} + case int8: + *dst = Float4{Float: float32(value), Status: Present} + case uint8: + *dst = Float4{Float: float32(value), Status: Present} + case int16: + *dst = Float4{Float: float32(value), Status: Present} + case uint16: + *dst = Float4{Float: float32(value), Status: 
Present} + case int32: + f32 := float32(value) + if int32(f32) == value { + *dst = Float4{Float: f32, Status: Present} + } else { + return errors.Errorf("%v cannot be exactly represented as float32", value) + } + case uint32: + f32 := float32(value) + if uint32(f32) == value { + *dst = Float4{Float: f32, Status: Present} + } else { + return errors.Errorf("%v cannot be exactly represented as float32", value) + } + case int64: + f32 := float32(value) + if int64(f32) == value { + *dst = Float4{Float: f32, Status: Present} + } else { + return errors.Errorf("%v cannot be exactly represented as float32", value) + } + case uint64: + f32 := float32(value) + if uint64(f32) == value { + *dst = Float4{Float: f32, Status: Present} + } else { + return errors.Errorf("%v cannot be exactly represented as float32", value) + } + case int: + f32 := float32(value) + if int(f32) == value { + *dst = Float4{Float: f32, Status: Present} + } else { + return errors.Errorf("%v cannot be exactly represented as float32", value) + } + case uint: + f32 := float32(value) + if uint(f32) == value { + *dst = Float4{Float: f32, Status: Present} + } else { + return errors.Errorf("%v cannot be exactly represented as float32", value) + } + case string: + num, err := strconv.ParseFloat(value, 32) + if err != nil { + return err + } + *dst = Float4{Float: float32(num), Status: Present} + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Float4", value) + } + + return nil +} + +func (dst *Float4) Get() interface{} { + switch dst.Status { + case Present: + return dst.Float + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Float4) AssignTo(dst interface{}) error { + return float64AssignTo(float64(src.Float), src.Status, dst) +} + +func (dst *Float4) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Float4{Status: Null} + return nil + } + + n, err := 
strconv.ParseFloat(string(src), 32) + if err != nil { + return err + } + + *dst = Float4{Float: float32(n), Status: Present} + return nil +} + +func (dst *Float4) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Float4{Status: Null} + return nil + } + + if len(src) != 4 { + return errors.Errorf("invalid length for float4: %v", len(src)) + } + + n := int32(binary.BigEndian.Uint32(src)) + + *dst = Float4{Float: math.Float32frombits(uint32(n)), Status: Present} + return nil +} + +func (src *Float4) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = append(buf, strconv.FormatFloat(float64(src.Float), 'f', -1, 32)...) + return buf, nil +} + +func (src *Float4) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = pgio.AppendUint32(buf, math.Float32bits(src.Float)) + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Float4) Scan(src interface{}) error { + if src == nil { + *dst = Float4{Status: Null} + return nil + } + + switch src := src.(type) { + case float64: + *dst = Float4{Float: float32(src), Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src *Float4) Value() (driver.Value, error) { + switch src.Status { + case Present: + return float64(src.Float), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/float4_array.go b/pgtype/float4_array.go new file mode 100644 index 00000000..6499064b --- /dev/null +++ b/pgtype/float4_array.go @@ -0,0 +1,294 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type Float4Array struct { + Elements []Float4 + Dimensions []ArrayDimension + Status Status +} + +func (dst *Float4Array) Set(src interface{}) error { + switch value := src.(type) { + + case []float32: + if value == nil { + *dst = Float4Array{Status: Null} + } else if len(value) == 0 { + *dst = Float4Array{Status: Present} + } else { + elements := make([]Float4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Float4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Float4", value) + } + + return nil +} + +func (dst *Float4Array) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Float4Array) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + + case *[]float32: + *v = make([]float32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", 
src, dst) +} + +func (dst *Float4Array) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Float4Array{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Float4 + + if len(uta.Elements) > 0 { + elements = make([]Float4, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Float4 + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = Float4Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *Float4Array) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Float4Array{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = Float4Array{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Float4, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = Float4Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *Float4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements 
that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src *Float4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("float4"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, errors.Errorf("unable to find oid for type name %v", "float4") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} 
+ +// Scan implements the database/sql Scanner interface. +func (dst *Float4Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Float4Array) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/float4_array_test.go b/pgtype/float4_array_test.go new file mode 100644 index 00000000..6d6a4f30 --- /dev/null +++ b/pgtype/float4_array_test.go @@ -0,0 +1,152 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestFloat4ArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "float4[]", []interface{}{ + &pgtype.Float4Array{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.Float4Array{ + Elements: []pgtype.Float4{ + pgtype.Float4{Float: 1, Status: pgtype.Present}, + pgtype.Float4{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Float4Array{Status: pgtype.Null}, + &pgtype.Float4Array{ + Elements: []pgtype.Float4{ + pgtype.Float4{Float: 1, Status: pgtype.Present}, + pgtype.Float4{Float: 2, Status: pgtype.Present}, + pgtype.Float4{Float: 3, Status: pgtype.Present}, + pgtype.Float4{Float: 4, Status: pgtype.Present}, + pgtype.Float4{Status: pgtype.Null}, + pgtype.Float4{Float: 6, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + 
&pgtype.Float4Array{ + Elements: []pgtype.Float4{ + pgtype.Float4{Float: 1, Status: pgtype.Present}, + pgtype.Float4{Float: 2, Status: pgtype.Present}, + pgtype.Float4{Float: 3, Status: pgtype.Present}, + pgtype.Float4{Float: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestFloat4ArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Float4Array + }{ + { + source: []float32{1}, + result: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]float32)(nil)), + result: pgtype.Float4Array{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Float4Array + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestFloat4ArrayAssignTo(t *testing.T) { + var float32Slice []float32 + var namedFloat32Slice _float32Slice + + simpleTests := []struct { + src pgtype.Float4Array + dst interface{} + expected interface{} + }{ + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1.23, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float32Slice, + expected: []float32{1.23}, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1.23, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedFloat32Slice, + expected: _float32Slice{1.23}, + }, + { + src: pgtype.Float4Array{Status: pgtype.Null}, + dst: &float32Slice, + expected: (([]float32)(nil)), + }, + } + + for i, tt := 
range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Float4Array + dst interface{} + }{ + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float32Slice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/float4_test.go b/pgtype/float4_test.go new file mode 100644 index 00000000..2ed8d05d --- /dev/null +++ b/pgtype/float4_test.go @@ -0,0 +1,149 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestFloat4Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "float4", []interface{}{ + &pgtype.Float4{Float: -1, Status: pgtype.Present}, + &pgtype.Float4{Float: 0, Status: pgtype.Present}, + &pgtype.Float4{Float: 0.00001, Status: pgtype.Present}, + &pgtype.Float4{Float: 1, Status: pgtype.Present}, + &pgtype.Float4{Float: 9999.99, Status: pgtype.Present}, + &pgtype.Float4{Float: 0, Status: pgtype.Null}, + }) +} + +func TestFloat4Set(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Float4 + }{ + {source: float32(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: float64(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: int8(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.Float4{Float: 
1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.Float4{Float: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.Float4{Float: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.Float4{Float: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.Float4{Float: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Float4 + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestFloat4AssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + var f32 float32 + var f64 float64 + var pf32 *float32 + var pf64 *float64 + + simpleTests := []struct { + src pgtype.Float4 + dst interface{} + expected interface{} + }{ + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &f32, expected: float32(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &f64, expected: float64(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: pgtype.Float4{Float: 42, 
Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Float4{Float: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Float4{Float: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Float4 + dst interface{} + expected interface{} + }{ + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Float4 + dst interface{} + }{ + {src: pgtype.Float4{Float: 150, Status: pgtype.Present}, dst: &i8}, + {src: pgtype.Float4{Float: 40000, Status: pgtype.Present}, dst: 
&i16}, + {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui8}, + {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui16}, + {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui32}, + {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui64}, + {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui}, + {src: pgtype.Float4{Float: 0, Status: pgtype.Null}, dst: &i32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/float8.go b/pgtype/float8.go new file mode 100644 index 00000000..dd34f541 --- /dev/null +++ b/pgtype/float8.go @@ -0,0 +1,187 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "math" + "strconv" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type Float8 struct { + Float float64 + Status Status +} + +func (dst *Float8) Set(src interface{}) error { + if src == nil { + *dst = Float8{Status: Null} + return nil + } + + switch value := src.(type) { + case float32: + *dst = Float8{Float: float64(value), Status: Present} + case float64: + *dst = Float8{Float: value, Status: Present} + case int8: + *dst = Float8{Float: float64(value), Status: Present} + case uint8: + *dst = Float8{Float: float64(value), Status: Present} + case int16: + *dst = Float8{Float: float64(value), Status: Present} + case uint16: + *dst = Float8{Float: float64(value), Status: Present} + case int32: + *dst = Float8{Float: float64(value), Status: Present} + case uint32: + *dst = Float8{Float: float64(value), Status: Present} + case int64: + f64 := float64(value) + if int64(f64) == value { + *dst = Float8{Float: f64, Status: Present} + } else { + return errors.Errorf("%v cannot be exactly represented as float64", value) + } + case uint64: + f64 := float64(value) + if uint64(f64) == value { + *dst = Float8{Float: f64, Status: Present} + } 
else { + return errors.Errorf("%v cannot be exactly represented as float64", value) + } + case int: + f64 := float64(value) + if int(f64) == value { + *dst = Float8{Float: f64, Status: Present} + } else { + return errors.Errorf("%v cannot be exactly represented as float64", value) + } + case uint: + f64 := float64(value) + if uint(f64) == value { + *dst = Float8{Float: f64, Status: Present} + } else { + return errors.Errorf("%v cannot be exactly represented as float64", value) + } + case string: + num, err := strconv.ParseFloat(value, 64) + if err != nil { + return err + } + *dst = Float8{Float: float64(num), Status: Present} + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Float8", value) + } + + return nil +} + +func (dst *Float8) Get() interface{} { + switch dst.Status { + case Present: + return dst.Float + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Float8) AssignTo(dst interface{}) error { + return float64AssignTo(src.Float, src.Status, dst) +} + +func (dst *Float8) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Float8{Status: Null} + return nil + } + + n, err := strconv.ParseFloat(string(src), 64) + if err != nil { + return err + } + + *dst = Float8{Float: n, Status: Present} + return nil +} + +func (dst *Float8) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Float8{Status: Null} + return nil + } + + if len(src) != 8 { + return errors.Errorf("invalid length for float8: %v", len(src)) + } + + n := int64(binary.BigEndian.Uint64(src)) + + *dst = Float8{Float: math.Float64frombits(uint64(n)), Status: Present} + return nil +} + +func (src *Float8) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = append(buf, strconv.FormatFloat(float64(src.Float), 'f', -1, 64)...) 
+ return buf, nil +} + +func (src *Float8) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = pgio.AppendUint64(buf, math.Float64bits(src.Float)) + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Float8) Scan(src interface{}) error { + if src == nil { + *dst = Float8{Status: Null} + return nil + } + + switch src := src.(type) { + case float64: + *dst = Float8{Float: src, Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Float8) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.Float, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/float8_array.go b/pgtype/float8_array.go new file mode 100644 index 00000000..27b24836 --- /dev/null +++ b/pgtype/float8_array.go @@ -0,0 +1,294 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type Float8Array struct { + Elements []Float8 + Dimensions []ArrayDimension + Status Status +} + +func (dst *Float8Array) Set(src interface{}) error { + switch value := src.(type) { + + case []float64: + if value == nil { + *dst = Float8Array{Status: Null} + } else if len(value) == 0 { + *dst = Float8Array{Status: Present} + } else { + elements := make([]Float8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Float8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, 
ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Float8", value) + } + + return nil +} + +func (dst *Float8Array) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Float8Array) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + + case *[]float64: + *v = make([]float64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *Float8Array) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Float8Array{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Float8 + + if len(uta.Elements) > 0 { + elements = make([]Float8, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Float8 + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = Float8Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *Float8Array) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Float8Array{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = Float8Array{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range 
arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Float8, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = Float8Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *Float8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) 
+ } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src *Float8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("float8"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, errors.Errorf("unable to find oid for type name %v", "float8") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Float8Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src *Float8Array) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/float8_array_test.go b/pgtype/float8_array_test.go new file mode 100644 index 00000000..56801e80 --- /dev/null +++ b/pgtype/float8_array_test.go @@ -0,0 +1,152 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestFloat8ArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "float8[]", []interface{}{ + &pgtype.Float8Array{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.Float8Array{ + Elements: []pgtype.Float8{ + pgtype.Float8{Float: 1, Status: pgtype.Present}, + pgtype.Float8{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Float8Array{Status: pgtype.Null}, + &pgtype.Float8Array{ + Elements: []pgtype.Float8{ + pgtype.Float8{Float: 1, Status: pgtype.Present}, + pgtype.Float8{Float: 2, Status: pgtype.Present}, + pgtype.Float8{Float: 3, Status: pgtype.Present}, + pgtype.Float8{Float: 4, Status: pgtype.Present}, + pgtype.Float8{Status: pgtype.Null}, + pgtype.Float8{Float: 6, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Float8Array{ + Elements: []pgtype.Float8{ + pgtype.Float8{Float: 1, Status: pgtype.Present}, + pgtype.Float8{Float: 2, Status: pgtype.Present}, + pgtype.Float8{Float: 3, Status: pgtype.Present}, + pgtype.Float8{Float: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestFloat8ArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + 
result pgtype.Float8Array + }{ + { + source: []float64{1}, + result: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]float64)(nil)), + result: pgtype.Float8Array{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Float8Array + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestFloat8ArrayAssignTo(t *testing.T) { + var float64Slice []float64 + var namedFloat64Slice _float64Slice + + simpleTests := []struct { + src pgtype.Float8Array + dst interface{} + expected interface{} + }{ + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1.23, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float64Slice, + expected: []float64{1.23}, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1.23, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedFloat64Slice, + expected: _float64Slice{1.23}, + }, + { + src: pgtype.Float8Array{Status: pgtype.Null}, + dst: &float64Slice, + expected: (([]float64)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Float8Array + dst interface{} + }{ + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Status: pgtype.Null}}, + Dimensions: 
[]pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float64Slice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/float8_test.go b/pgtype/float8_test.go new file mode 100644 index 00000000..46fc8d5d --- /dev/null +++ b/pgtype/float8_test.go @@ -0,0 +1,149 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestFloat8Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "float8", []interface{}{ + &pgtype.Float8{Float: -1, Status: pgtype.Present}, + &pgtype.Float8{Float: 0, Status: pgtype.Present}, + &pgtype.Float8{Float: 0.00001, Status: pgtype.Present}, + &pgtype.Float8{Float: 1, Status: pgtype.Present}, + &pgtype.Float8{Float: 9999.99, Status: pgtype.Present}, + &pgtype.Float8{Float: 0, Status: pgtype.Null}, + }) +} + +func TestFloat8Set(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Float8 + }{ + {source: float32(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: float64(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: int8(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.Float8{Float: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.Float8{Float: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.Float8{Float: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.Float8{Float: -1, Status: pgtype.Present}}, + {source: uint8(1), result: 
pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Float8 + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestFloat8AssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + var f32 float32 + var f64 float64 + var pf32 *float32 + var pf64 *float64 + + simpleTests := []struct { + src pgtype.Float8 + dst interface{} + expected interface{} + }{ + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &f32, expected: float32(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &f64, expected: float64(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Float8{Float: 42, Status: 
pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Float8{Float: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Float8{Float: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Float8 + dst interface{} + expected interface{} + }{ + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Float8 + dst interface{} + }{ + {src: pgtype.Float8{Float: 150, Status: pgtype.Present}, dst: &i8}, + {src: pgtype.Float8{Float: 40000, Status: pgtype.Present}, dst: &i16}, + {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui8}, + {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui16}, + {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui32}, + {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui64}, + {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui}, + {src: pgtype.Float8{Float: 0, Status: pgtype.Null}, dst: &i32}, + } + + for i, tt := range errorTests { 
+ err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/generic_binary.go b/pgtype/generic_binary.go new file mode 100644 index 00000000..2596ecae --- /dev/null +++ b/pgtype/generic_binary.go @@ -0,0 +1,39 @@ +package pgtype + +import ( + "database/sql/driver" +) + +// GenericBinary is a placeholder for binary format values that no other type exists +// to handle. +type GenericBinary Bytea + +func (dst *GenericBinary) Set(src interface{}) error { + return (*Bytea)(dst).Set(src) +} + +func (dst *GenericBinary) Get() interface{} { + return (*Bytea)(dst).Get() +} + +func (src *GenericBinary) AssignTo(dst interface{}) error { + return (*Bytea)(src).AssignTo(dst) +} + +func (dst *GenericBinary) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Bytea)(dst).DecodeBinary(ci, src) +} + +func (src *GenericBinary) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Bytea)(src).EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *GenericBinary) Scan(src interface{}) error { + return (*Bytea)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *GenericBinary) Value() (driver.Value, error) { + return (*Bytea)(src).Value() +} diff --git a/pgtype/generic_text.go b/pgtype/generic_text.go new file mode 100644 index 00000000..0e3db9de --- /dev/null +++ b/pgtype/generic_text.go @@ -0,0 +1,39 @@ +package pgtype + +import ( + "database/sql/driver" +) + +// GenericText is a placeholder for text format values that no other type exists +// to handle. 
+type GenericText Text + +func (dst *GenericText) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +func (dst *GenericText) Get() interface{} { + return (*Text)(dst).Get() +} + +func (src *GenericText) AssignTo(dst interface{}) error { + return (*Text)(src).AssignTo(dst) +} + +func (dst *GenericText) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) +} + +func (src *GenericText) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Text)(src).EncodeText(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *GenericText) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *GenericText) Value() (driver.Value, error) { + return (*Text)(src).Value() +} diff --git a/pgtype/hstore.go b/pgtype/hstore.go new file mode 100644 index 00000000..347446ae --- /dev/null +++ b/pgtype/hstore.go @@ -0,0 +1,434 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "encoding/binary" + "strings" + "unicode" + "unicode/utf8" + + "github.com/pkg/errors" + + "github.com/jackc/pgx/pgio" +) + +// Hstore represents an hstore column that can be null or have null values +// associated with its keys. 
+type Hstore struct { + Map map[string]Text + Status Status +} + +func (dst *Hstore) Set(src interface{}) error { + if src == nil { + *dst = Hstore{Status: Null} + return nil + } + + switch value := src.(type) { + case map[string]string: + m := make(map[string]Text, len(value)) + for k, v := range value { + m[k] = Text{String: v, Status: Present} + } + *dst = Hstore{Map: m, Status: Present} + default: + return errors.Errorf("cannot convert %v to Hstore", src) + } + + return nil +} + +func (dst *Hstore) Get() interface{} { + switch dst.Status { + case Present: + return dst.Map + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Hstore) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *map[string]string: + *v = make(map[string]string, len(src.Map)) + for k, val := range src.Map { + if val.Status != Present { + return errors.Errorf("cannot decode %v into %T", src, dst) + } + (*v)[k] = val.String + } + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *Hstore) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Hstore{Status: Null} + return nil + } + + keys, values, err := parseHstore(string(src)) + if err != nil { + return err + } + + m := make(map[string]Text, len(keys)) + for i := range keys { + m[keys[i]] = values[i] + } + + *dst = Hstore{Map: m, Status: Present} + return nil +} + +func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Hstore{Status: Null} + return nil + } + + rp := 0 + + if len(src[rp:]) < 4 { + return errors.Errorf("hstore incomplete %v", src) + } + pairCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + m := make(map[string]Text, pairCount) + + for i := 0; i < pairCount; i++ { + if len(src[rp:]) < 4 { + 
return errors.Errorf("hstore incomplete %v", src) + } + keyLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + if len(src[rp:]) < keyLen { + return errors.Errorf("hstore incomplete %v", src) + } + key := string(src[rp : rp+keyLen]) + rp += keyLen + + if len(src[rp:]) < 4 { + return errors.Errorf("hstore incomplete %v", src) + } + valueLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + var valueBuf []byte + if valueLen >= 0 { + valueBuf = src[rp : rp+valueLen] + } + rp += valueLen + + var value Text + err := value.DecodeBinary(ci, valueBuf) + if err != nil { + return err + } + m[key] = value + } + + *dst = Hstore{Map: m, Status: Present} + + return nil +} + +func (src *Hstore) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + firstPair := true + + for k, v := range src.Map { + if firstPair { + firstPair = false + } else { + buf = append(buf, ',') + } + + buf = append(buf, quoteHstoreElementIfNeeded(k)...) + buf = append(buf, "=>"...) + + elemBuf, err := v.EncodeText(ci, nil) + if err != nil { + return nil, err + } + + if elemBuf == nil { + buf = append(buf, "NULL"...) + } else { + buf = append(buf, quoteHstoreElementIfNeeded(string(elemBuf))...) + } + } + + return buf, nil +} + +func (src *Hstore) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = pgio.AppendInt32(buf, int32(len(src.Map))) + + var err error + for k, v := range src.Map { + buf = pgio.AppendInt32(buf, int32(len(k))) + buf = append(buf, k...) 
+ + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := v.EncodeText(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, err +} + +var quoteHstoreReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) + +func quoteHstoreElement(src string) string { + return `"` + quoteArrayReplacer.Replace(src) + `"` +} + +func quoteHstoreElementIfNeeded(src string) string { + if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || strings.ContainsAny(src, ` {},"\=>`) { + return quoteArrayElement(src) + } + return src +} + +const ( + hsPre = iota + hsKey + hsSep + hsVal + hsNul + hsNext +) + +type hstoreParser struct { + str string + pos int +} + +func newHSP(in string) *hstoreParser { + return &hstoreParser{ + pos: 0, + str: in, + } +} + +func (p *hstoreParser) Consume() (r rune, end bool) { + if p.pos >= len(p.str) { + end = true + return + } + r, w := utf8.DecodeRuneInString(p.str[p.pos:]) + p.pos += w + return +} + +func (p *hstoreParser) Peek() (r rune, end bool) { + if p.pos >= len(p.str) { + end = true + return + } + r, _ = utf8.DecodeRuneInString(p.str[p.pos:]) + return +} + +// parseHstore parses the string representation of an hstore column (the same +// you would get from an ordinary SELECT) into two slices of keys and values. it +// is used internally in the default parsing of hstores. 
+func parseHstore(s string) (k []string, v []Text, err error) { + if s == "" { + return + } + + buf := bytes.Buffer{} + keys := []string{} + values := []Text{} + p := newHSP(s) + + r, end := p.Consume() + state := hsPre + + for !end { + switch state { + case hsPre: + if r == '"' { + state = hsKey + } else { + err = errors.New("String does not begin with \"") + } + case hsKey: + switch r { + case '"': //End of the key + if buf.Len() == 0 { + err = errors.New("Empty Key is invalid") + } else { + keys = append(keys, buf.String()) + buf = bytes.Buffer{} + state = hsSep + } + case '\\': //Potential escaped character + n, end := p.Consume() + switch { + case end: + err = errors.New("Found EOS in key, expecting character or \"") + case n == '"', n == '\\': + buf.WriteRune(n) + default: + buf.WriteRune(r) + buf.WriteRune(n) + } + default: //Any other character + buf.WriteRune(r) + } + case hsSep: + if r == '=' { + r, end = p.Consume() + switch { + case end: + err = errors.New("Found EOS after '=', expecting '>'") + case r == '>': + r, end = p.Consume() + switch { + case end: + err = errors.New("Found EOS after '=>', expecting '\"' or 'NULL'") + case r == '"': + state = hsVal + case r == 'N': + state = hsNul + default: + err = errors.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r) + } + default: + err = errors.Errorf("Invalid character after '=', expecting '>'") + } + } else { + err = errors.Errorf("Invalid character '%c' after value, expecting '='", r) + } + case hsVal: + switch r { + case '"': //End of the value + values = append(values, Text{String: buf.String(), Status: Present}) + buf = bytes.Buffer{} + state = hsNext + case '\\': //Potential escaped character + n, end := p.Consume() + switch { + case end: + err = errors.New("Found EOS in key, expecting character or \"") + case n == '"', n == '\\': + buf.WriteRune(n) + default: + buf.WriteRune(r) + buf.WriteRune(n) + } + default: //Any other character + buf.WriteRune(r) + } + case hsNul: + 
nulBuf := make([]rune, 3) + nulBuf[0] = r + for i := 1; i < 3; i++ { + r, end = p.Consume() + if end { + err = errors.New("Found EOS in NULL value") + return + } + nulBuf[i] = r + } + if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' { + values = append(values, Text{Status: Null}) + state = hsNext + } else { + err = errors.Errorf("Invalid NULL value: 'N%s'", string(nulBuf)) + } + case hsNext: + if r == ',' { + r, end = p.Consume() + switch { + case end: + err = errors.New("Found EOS after ',', expcting space") + case (unicode.IsSpace(r)): + r, end = p.Consume() + state = hsKey + default: + err = errors.Errorf("Invalid character '%c' after ', ', expecting \"", r) + } + } else { + err = errors.Errorf("Invalid character '%c' after value, expecting ','", r) + } + } + + if err != nil { + return + } + r, end = p.Consume() + } + if state != hsNext { + err = errors.New("Improperly formatted hstore") + return + } + k = keys + v = values + return +} + +// Scan implements the database/sql Scanner interface. +func (dst *Hstore) Scan(src interface{}) error { + if src == nil { + *dst = Hstore{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src *Hstore) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/hstore_array.go b/pgtype/hstore_array.go new file mode 100644 index 00000000..38ce457b --- /dev/null +++ b/pgtype/hstore_array.go @@ -0,0 +1,294 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type HstoreArray struct { + Elements []Hstore + Dimensions []ArrayDimension + Status Status +} + +func (dst *HstoreArray) Set(src interface{}) error { + switch value := src.(type) { + + case []map[string]string: + if value == nil { + *dst = HstoreArray{Status: Null} + } else if len(value) == 0 { + *dst = HstoreArray{Status: Present} + } else { + elements := make([]Hstore, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = HstoreArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Hstore", value) + } + + return nil +} + +func (dst *HstoreArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *HstoreArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + + case *[]map[string]string: + *v = make([]map[string]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *HstoreArray) DecodeText(ci *ConnInfo, src []byte) error { 
+ if src == nil { + *dst = HstoreArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Hstore + + if len(uta.Elements) > 0 { + elements = make([]Hstore, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Hstore + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = HstoreArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *HstoreArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = HstoreArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = HstoreArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Hstore, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = HstoreArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *HstoreArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. 
For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src *HstoreArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("hstore"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, errors.Errorf("unable to find oid for type name %v", "hstore") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements 
the database/sql Scanner interface. +func (dst *HstoreArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *HstoreArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/hstore_array_test.go b/pgtype/hstore_array_test.go new file mode 100644 index 00000000..fcf08c49 --- /dev/null +++ b/pgtype/hstore_array_test.go @@ -0,0 +1,184 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestHstoreArrayTranscode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustClose(t, conn) + + text := func(s string) pgtype.Text { + return pgtype.Text{String: s, Status: pgtype.Present} + } + + values := []pgtype.Hstore{ + pgtype.Hstore{Map: map[string]pgtype.Text{}, Status: pgtype.Present}, + pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, + pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, + pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, + pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, + pgtype.Hstore{Status: pgtype.Null}, + } + + specialStrings := []string{ + `"`, + `'`, + `\`, + `\\`, + `=>`, + ` `, + `\ / / \\ => " ' " '`, + } + for _, s := range specialStrings { + // Special key values + values = append(values, pgtype.Hstore{Map: 
map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key + + // Special value values + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key + } + + src := &pgtype.HstoreArray{ + Elements: values, + Dimensions: []pgtype.ArrayDimension{{Length: int32(len(values)), LowerBound: 1}}, + Status: pgtype.Present, + } + + ps, err := conn.Prepare("test", "select $1::hstore[]") + if err != nil { + t.Fatal(err) + } + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for _, fc := range formats { + ps.FieldDescriptions[0].FormatCode = fc.formatCode + vEncoder := testutil.ForceEncoder(src, fc.formatCode) + if vEncoder == nil { + t.Logf("%#v does not implement %v", src, fc.name) + continue + } + + var result pgtype.HstoreArray + err := conn.QueryRow("test", vEncoder).Scan(&result) + if err != nil { + t.Errorf("%v: %v", fc.name, err) + continue + } + + if result.Status != src.Status { + t.Errorf("%v: expected Status %v, got %v", fc.formatCode, src.Status, result.Status) + continue 
+ } + + if len(result.Elements) != len(src.Elements) { + t.Errorf("%v: expected %v elements, got %v", fc.formatCode, len(src.Elements), len(result.Elements)) + continue + } + + for i := range result.Elements { + a := src.Elements[i] + b := result.Elements[i] + + if a.Status != b.Status { + t.Errorf("%v element idx %d: expected status %v, got %v", fc.formatCode, i, a.Status, b.Status) + } + + if len(a.Map) != len(b.Map) { + t.Errorf("%v element idx %d: expected %v pairs, got %v", fc.formatCode, i, len(a.Map), len(b.Map)) + } + + for k := range a.Map { + if a.Map[k] != b.Map[k] { + t.Errorf("%v element idx %d: expected key %v to be %v, got %v", fc.formatCode, i, k, a.Map[k], b.Map[k]) + } + } + } + } +} + +func TestHstoreArraySet(t *testing.T) { + successfulTests := []struct { + src []map[string]string + result pgtype.HstoreArray + }{ + { + src: []map[string]string{map[string]string{"foo": "bar"}}, + result: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": pgtype.Text{String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + }, + } + + for i, tt := range successfulTests { + var dst pgtype.HstoreArray + err := dst.Set(tt.src) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(dst, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.src, tt.result, dst) + } + } +} + +func TestHstoreArrayAssignTo(t *testing.T) { + var m []map[string]string + + simpleTests := []struct { + src pgtype.HstoreArray + dst *[]map[string]string + expected []map[string]string + }{ + { + src: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": pgtype.Text{String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &m, + 
expected: []map[string]string{{"foo": "bar"}}}, + {src: pgtype.HstoreArray{Status: pgtype.Null}, dst: &m, expected: (([]map[string]string)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(*tt.dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } +} diff --git a/pgtype/hstore_test.go b/pgtype/hstore_test.go new file mode 100644 index 00000000..dc2439fc --- /dev/null +++ b/pgtype/hstore_test.go @@ -0,0 +1,109 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestHstoreTranscode(t *testing.T) { + text := func(s string) pgtype.Text { + return pgtype.Text{String: s, Status: pgtype.Present} + } + + values := []interface{}{ + &pgtype.Hstore{Map: map[string]pgtype.Text{}, Status: pgtype.Present}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, + &pgtype.Hstore{Status: pgtype.Null}, + } + + specialStrings := []string{ + `"`, + `'`, + `\`, + `\\`, + `=>`, + ` `, + `\ / / \\ => " ' " '`, + } + for _, s := range specialStrings { + // Special key values + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end + values = append(values, 
&pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key + + // Special value values + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key + } + + testutil.TestSuccessfulTranscodeEqFunc(t, "hstore", values, func(ai, bi interface{}) bool { + a := ai.(pgtype.Hstore) + b := bi.(pgtype.Hstore) + + if len(a.Map) != len(b.Map) || a.Status != b.Status { + return false + } + + for k := range a.Map { + if a.Map[k] != b.Map[k] { + return false + } + } + + return true + }) +} + +func TestHstoreSet(t *testing.T) { + successfulTests := []struct { + src map[string]string + result pgtype.Hstore + }{ + {src: map[string]string{"foo": "bar"}, result: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": pgtype.Text{String: "bar", Status: pgtype.Present}}, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var dst pgtype.Hstore + err := dst.Set(tt.src) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(dst, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.src, tt.result, dst) + } + } +} + +func TestHstoreAssignTo(t *testing.T) { + var m map[string]string + + simpleTests := []struct { + src pgtype.Hstore + dst *map[string]string + expected map[string]string + }{ + {src: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": pgtype.Text{String: "bar", Status: pgtype.Present}}, Status: pgtype.Present}, dst: &m, expected: map[string]string{"foo": "bar"}}, + {src: pgtype.Hstore{Status: pgtype.Null}, dst: &m, expected: 
((map[string]string)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(*tt.dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } +} diff --git a/pgtype/inet.go b/pgtype/inet.go new file mode 100644 index 00000000..01fc0e5b --- /dev/null +++ b/pgtype/inet.go @@ -0,0 +1,215 @@ +package pgtype + +import ( + "database/sql/driver" + "net" + + "github.com/pkg/errors" +) + +// Network address family is dependent on server socket.h value for AF_INET. +// In practice, all platforms appear to have the same value. See +// src/include/utils/inet.h for more information. +const ( + defaultAFInet = 2 + defaultAFInet6 = 3 +) + +// Inet represents both inet and cidr PostgreSQL types. +type Inet struct { + IPNet *net.IPNet + Status Status +} + +func (dst *Inet) Set(src interface{}) error { + if src == nil { + *dst = Inet{Status: Null} + return nil + } + + switch value := src.(type) { + case net.IPNet: + *dst = Inet{IPNet: &value, Status: Present} + case *net.IPNet: + *dst = Inet{IPNet: value, Status: Present} + case net.IP: + bitCount := len(value) * 8 + mask := net.CIDRMask(bitCount, bitCount) + *dst = Inet{IPNet: &net.IPNet{Mask: mask, IP: value}, Status: Present} + case string: + _, ipnet, err := net.ParseCIDR(value) + if err != nil { + return err + } + *dst = Inet{IPNet: ipnet, Status: Present} + default: + if originalSrc, ok := underlyingPtrType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Inet", value) + } + + return nil +} + +func (dst *Inet) Get() interface{} { + switch dst.Status { + case Present: + return dst.IPNet + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Inet) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *net.IPNet: + *v = net.IPNet{ + IP: make(net.IP, 
len(src.IPNet.IP)), + Mask: make(net.IPMask, len(src.IPNet.Mask)), + } + copy(v.IP, src.IPNet.IP) + copy(v.Mask, src.IPNet.Mask) + return nil + case *net.IP: + if oneCount, bitCount := src.IPNet.Mask.Size(); oneCount != bitCount { + return errors.Errorf("cannot assign %v to %T", src, dst) + } + *v = make(net.IP, len(src.IPNet.IP)) + copy(*v, src.IPNet.IP) + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Inet{Status: Null} + return nil + } + + var ipnet *net.IPNet + var err error + + if ip := net.ParseIP(string(src)); ip != nil { + ipv4 := ip.To4() + if ipv4 != nil { + ip = ipv4 + } + bitCount := len(ip) * 8 + mask := net.CIDRMask(bitCount, bitCount) + ipnet = &net.IPNet{Mask: mask, IP: ip} + } else { + _, ipnet, err = net.ParseCIDR(string(src)) + if err != nil { + return err + } + } + + *dst = Inet{IPNet: ipnet, Status: Present} + return nil +} + +func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Inet{Status: Null} + return nil + } + + if len(src) != 8 && len(src) != 20 { + return errors.Errorf("Received an invalid size for a inet: %d", len(src)) + } + + // ignore family + bits := src[1] + // ignore is_cidr + addressLength := src[3] + + var ipnet net.IPNet + ipnet.IP = make(net.IP, int(addressLength)) + copy(ipnet.IP, src[4:]) + ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8) + + *dst = Inet{IPNet: &ipnet, Status: Present} + + return nil +} + +func (src *Inet) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, src.IPNet.String()...), nil +} + +// EncodeBinary encodes src into w. 
+func (src *Inet) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var family byte + switch len(src.IPNet.IP) { + case net.IPv4len: + family = defaultAFInet + case net.IPv6len: + family = defaultAFInet6 + default: + return nil, errors.Errorf("Unexpected IP length: %v", len(src.IPNet.IP)) + } + + buf = append(buf, family) + + ones, _ := src.IPNet.Mask.Size() + buf = append(buf, byte(ones)) + + // is_cidr is ignored on server + buf = append(buf, 0) + + buf = append(buf, byte(len(src.IPNet.IP))) + + return append(buf, src.IPNet.IP...), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Inet) Scan(src interface{}) error { + if src == nil { + *dst = Inet{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src *Inet) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/inet_array.go b/pgtype/inet_array.go new file mode 100644 index 00000000..3ece23eb --- /dev/null +++ b/pgtype/inet_array.go @@ -0,0 +1,323 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "net" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type InetArray struct { + Elements []Inet + Dimensions []ArrayDimension + Status Status +} + +func (dst *InetArray) Set(src interface{}) error { + switch value := src.(type) { + + case []*net.IPNet: + if value == nil { + *dst = InetArray{Status: Null} + } else if len(value) == 0 { + *dst = InetArray{Status: Present} + } else { + elements := make([]Inet, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = InetArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []net.IP: + if value == nil { + *dst = InetArray{Status: Null} + } else if len(value) == 0 { + *dst = InetArray{Status: Present} + } else { + elements := make([]Inet, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = InetArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Inet", value) + } + + return nil +} + +func (dst *InetArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *InetArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + + case *[]*net.IPNet: + *v = make([]*net.IPNet, len(src.Elements)) + for i := range 
src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]net.IP: + *v = make([]net.IP, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *InetArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = InetArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Inet + + if len(uta.Elements) > 0 { + elements = make([]Inet, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Inet + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = InetArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *InetArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = InetArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = InetArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Inet, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return 
err + } + } + + *dst = InetArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *InetArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) 
+ } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src *InetArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("inet"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, errors.Errorf("unable to find oid for type name %v", "inet") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *InetArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src *InetArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/inet_array_test.go b/pgtype/inet_array_test.go new file mode 100644 index 00000000..3e2b6a3c --- /dev/null +++ b/pgtype/inet_array_test.go @@ -0,0 +1,165 @@ +package pgtype_test + +import ( + "net" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestInetArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "inet[]", []interface{}{ + &pgtype.InetArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.InetArray{ + Elements: []pgtype.Inet{ + pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.InetArray{Status: pgtype.Null}, + &pgtype.InetArray{ + Elements: []pgtype.Inet{ + pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + pgtype.Inet{Status: pgtype.Null}, + pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.InetArray{ + Elements: []pgtype.Inet{ + pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, 
"2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestInetArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.InetArray + }{ + { + source: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]*net.IPNet)(nil)), + result: pgtype.InetArray{Status: pgtype.Null}, + }, + { + source: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]net.IP)(nil)), + result: pgtype.InetArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.InetArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInetArrayAssignTo(t *testing.T) { + var ipnetSlice []*net.IPNet + var ipSlice []net.IP + + simpleTests := []struct { + src pgtype.InetArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipnetSlice, + expected: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 
1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipnetSlice, + expected: []*net.IPNet{nil}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipSlice, + expected: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipSlice, + expected: []net.IP{nil}, + }, + { + src: pgtype.InetArray{Status: pgtype.Null}, + dst: &ipnetSlice, + expected: (([]*net.IPNet)(nil)), + }, + { + src: pgtype.InetArray{Status: pgtype.Null}, + dst: &ipSlice, + expected: (([]net.IP)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/inet_test.go b/pgtype/inet_test.go new file mode 100644 index 00000000..32d66999 --- /dev/null +++ b/pgtype/inet_test.go @@ -0,0 +1,115 @@ +package pgtype_test + +import ( + "net" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestInetTranscode(t *testing.T) { + for _, pgTypeName := range []string{"inet", "cidr"} { + testutil.TestSuccessfulTranscode(t, pgTypeName, []interface{}{ + &pgtype.Inet{IPNet: mustParseCIDR(t, "0.0.0.0/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.1.0/24"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, 
"255.0.0.0/8"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "255.255.255.255/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "::/128"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "::/0"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "::1/128"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + &pgtype.Inet{Status: pgtype.Null}, + }) + } +} + +func TestInetSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Inet + }{ + {source: mustParseCIDR(t, "127.0.0.1/32"), result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + {source: mustParseCIDR(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Inet + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInetAssignTo(t *testing.T) { + var ipnet net.IPNet + var pipnet *net.IPNet + var ip net.IP + var pip *net.IP + + simpleTests := []struct { + src pgtype.Inet + dst interface{} + expected interface{} + }{ + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, + {src: pgtype.Inet{Status: pgtype.Null}, dst: &pipnet, expected: ((*net.IPNet)(nil))}, + {src: pgtype.Inet{Status: pgtype.Null}, dst: &pip, expected: ((*net.IP)(nil))}, + } + + for i, tt := range simpleTests { + err 
:= tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %#v, but result was %#v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Inet + dst interface{} + expected interface{} + }{ + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Inet + dst interface{} + }{ + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.0/16"), Status: pgtype.Present}, dst: &ip}, + {src: pgtype.Inet{Status: pgtype.Null}, dst: &ipnet}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/int2.go b/pgtype/int2.go new file mode 100644 index 00000000..45bce93c --- /dev/null +++ b/pgtype/int2.go @@ -0,0 +1,209 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "math" + "strconv" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type Int2 struct { + Int int16 + Status Status +} + +func (dst *Int2) Set(src interface{}) error { + if src == nil { + *dst = Int2{Status: Null} + return nil + } + + switch value := src.(type) { + case int8: + *dst = Int2{Int: int16(value), Status: Present} + case uint8: + *dst = 
Int2{Int: int16(value), Status: Present} + case int16: + *dst = Int2{Int: int16(value), Status: Present} + case uint16: + if value > math.MaxInt16 { + return errors.Errorf("%d is greater than maximum value for Int2", value) + } + *dst = Int2{Int: int16(value), Status: Present} + case int32: + if value < math.MinInt16 { + return errors.Errorf("%d is greater than maximum value for Int2", value) + } + if value > math.MaxInt16 { + return errors.Errorf("%d is greater than maximum value for Int2", value) + } + *dst = Int2{Int: int16(value), Status: Present} + case uint32: + if value > math.MaxInt16 { + return errors.Errorf("%d is greater than maximum value for Int2", value) + } + *dst = Int2{Int: int16(value), Status: Present} + case int64: + if value < math.MinInt16 { + return errors.Errorf("%d is greater than maximum value for Int2", value) + } + if value > math.MaxInt16 { + return errors.Errorf("%d is greater than maximum value for Int2", value) + } + *dst = Int2{Int: int16(value), Status: Present} + case uint64: + if value > math.MaxInt16 { + return errors.Errorf("%d is greater than maximum value for Int2", value) + } + *dst = Int2{Int: int16(value), Status: Present} + case int: + if value < math.MinInt16 { + return errors.Errorf("%d is greater than maximum value for Int2", value) + } + if value > math.MaxInt16 { + return errors.Errorf("%d is greater than maximum value for Int2", value) + } + *dst = Int2{Int: int16(value), Status: Present} + case uint: + if value > math.MaxInt16 { + return errors.Errorf("%d is greater than maximum value for Int2", value) + } + *dst = Int2{Int: int16(value), Status: Present} + case string: + num, err := strconv.ParseInt(value, 10, 16) + if err != nil { + return err + } + *dst = Int2{Int: int16(num), Status: Present} + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Int2", value) + } + + return nil +} + +func (dst *Int2) Get() interface{} { 
+ switch dst.Status { + case Present: + return dst.Int + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Int2) AssignTo(dst interface{}) error { + return int64AssignTo(int64(src.Int), src.Status, dst) +} + +func (dst *Int2) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int2{Status: Null} + return nil + } + + n, err := strconv.ParseInt(string(src), 10, 16) + if err != nil { + return err + } + + *dst = Int2{Int: int16(n), Status: Present} + return nil +} + +func (dst *Int2) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int2{Status: Null} + return nil + } + + if len(src) != 2 { + return errors.Errorf("invalid length for int2: %v", len(src)) + } + + n := int16(binary.BigEndian.Uint16(src)) + *dst = Int2{Int: n, Status: Present} + return nil +} + +func (src *Int2) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, strconv.FormatInt(int64(src.Int), 10)...), nil +} + +func (src *Int2) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return pgio.AppendInt16(buf, src.Int), nil +} + +// Scan implements the database/sql Scanner interface. 
+func (dst *Int2) Scan(src interface{}) error { + if src == nil { + *dst = Int2{Status: Null} + return nil + } + + switch src := src.(type) { + case int64: + if src < math.MinInt16 { + return errors.Errorf("%d is greater than maximum value for Int2", src) + } + if src > math.MaxInt16 { + return errors.Errorf("%d is greater than maximum value for Int2", src) + } + *dst = Int2{Int: int16(src), Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Int2) Value() (driver.Value, error) { + switch src.Status { + case Present: + return int64(src.Int), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} + +func (src *Int2) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + return []byte(strconv.FormatInt(int64(src.Int), 10)), nil + case Null: + return []byte("null"), nil + case Undefined: + return []byte("undefined"), nil + } + + return nil, errBadStatus +} diff --git a/pgtype/int2_array.go b/pgtype/int2_array.go new file mode 100644 index 00000000..e939411b --- /dev/null +++ b/pgtype/int2_array.go @@ -0,0 +1,322 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type Int2Array struct { + Elements []Int2 + Dimensions []ArrayDimension + Status Status +} + +func (dst *Int2Array) Set(src interface{}) error { + switch value := src.(type) { + + case []int16: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: 
[]ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint16: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Int2", value) + } + + return nil +} + +func (dst *Int2Array) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Int2Array) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + + case *[]int16: + *v = make([]int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint16: + *v = make([]uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *Int2Array) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int2Array{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Int2 + + if len(uta.Elements) > 0 { + elements = make([]Int2, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Int2 + var elemSrc []byte + if s != 
"NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = Int2Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *Int2Array) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int2Array{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = Int2Array{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Int2, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = Int2Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *Int2Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. 
+ dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src *Int2Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("int2"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, errors.Errorf("unable to find oid for type name %v", "int2") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. 
+func (dst *Int2Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Int2Array) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/int2_array_test.go b/pgtype/int2_array_test.go new file mode 100644 index 00000000..0adc1aef --- /dev/null +++ b/pgtype/int2_array_test.go @@ -0,0 +1,177 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestInt2ArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int2[]", []interface{}{ + &pgtype.Int2Array{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.Int2Array{ + Elements: []pgtype.Int2{ + pgtype.Int2{Int: 1, Status: pgtype.Present}, + pgtype.Int2{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int2Array{Status: pgtype.Null}, + &pgtype.Int2Array{ + Elements: []pgtype.Int2{ + pgtype.Int2{Int: 1, Status: pgtype.Present}, + pgtype.Int2{Int: 2, Status: pgtype.Present}, + pgtype.Int2{Int: 3, Status: pgtype.Present}, + pgtype.Int2{Int: 4, Status: pgtype.Present}, + pgtype.Int2{Status: pgtype.Null}, + pgtype.Int2{Int: 6, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int2Array{ + Elements: []pgtype.Int2{ + pgtype.Int2{Int: 1, Status: pgtype.Present}, + pgtype.Int2{Int: 2, Status: 
pgtype.Present}, + pgtype.Int2{Int: 3, Status: pgtype.Present}, + pgtype.Int2{Int: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestInt2ArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Int2Array + }{ + { + source: []int16{1}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []uint16{1}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]int16)(nil)), + result: pgtype.Int2Array{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Int2Array + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInt2ArrayAssignTo(t *testing.T) { + var int16Slice []int16 + var uint16Slice []uint16 + var namedInt16Slice _int16Slice + + simpleTests := []struct { + src pgtype.Int2Array + dst interface{} + expected interface{} + }{ + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &int16Slice, + expected: []int16{1}, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &uint16Slice, + expected: []uint16{1}, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: 
[]pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedInt16Slice, + expected: _int16Slice{1}, + }, + { + src: pgtype.Int2Array{Status: pgtype.Null}, + dst: &int16Slice, + expected: (([]int16)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Int2Array + dst interface{} + }{ + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &int16Slice, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: -1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &uint16Slice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/int2_test.go b/pgtype/int2_test.go new file mode 100644 index 00000000..d20bf0ed --- /dev/null +++ b/pgtype/int2_test.go @@ -0,0 +1,142 @@ +package pgtype_test + +import ( + "math" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestInt2Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int2", []interface{}{ + &pgtype.Int2{Int: math.MinInt16, Status: pgtype.Present}, + &pgtype.Int2{Int: -1, Status: pgtype.Present}, + &pgtype.Int2{Int: 0, Status: pgtype.Present}, + &pgtype.Int2{Int: 1, Status: pgtype.Present}, + &pgtype.Int2{Int: math.MaxInt16, Status: pgtype.Present}, + &pgtype.Int2{Int: 0, Status: pgtype.Null}, + }) +} + +func TestInt2Set(t 
*testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Int2 + }{ + {source: int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Int2 + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInt2AssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + + simpleTests := []struct { + src pgtype.Int2 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i32, 
expected: int32(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Int2{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Int2{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Int2 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Int2 + dst interface{} + }{ + {src: pgtype.Int2{Int: 150, Status: pgtype.Present}, dst: &i8}, + {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui8}, + {src: 
pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui16}, + {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui32}, + {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui64}, + {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui}, + {src: pgtype.Int2{Int: 0, Status: pgtype.Null}, dst: &i16}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/int4.go b/pgtype/int4.go new file mode 100644 index 00000000..a3499fef --- /dev/null +++ b/pgtype/int4.go @@ -0,0 +1,200 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "math" + "strconv" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type Int4 struct { + Int int32 + Status Status +} + +func (dst *Int4) Set(src interface{}) error { + if src == nil { + *dst = Int4{Status: Null} + return nil + } + + switch value := src.(type) { + case int8: + *dst = Int4{Int: int32(value), Status: Present} + case uint8: + *dst = Int4{Int: int32(value), Status: Present} + case int16: + *dst = Int4{Int: int32(value), Status: Present} + case uint16: + *dst = Int4{Int: int32(value), Status: Present} + case int32: + *dst = Int4{Int: int32(value), Status: Present} + case uint32: + if value > math.MaxInt32 { + return errors.Errorf("%d is greater than maximum value for Int4", value) + } + *dst = Int4{Int: int32(value), Status: Present} + case int64: + if value < math.MinInt32 { + return errors.Errorf("%d is greater than maximum value for Int4", value) + } + if value > math.MaxInt32 { + return errors.Errorf("%d is greater than maximum value for Int4", value) + } + *dst = Int4{Int: int32(value), Status: Present} + case uint64: + if value > math.MaxInt32 { + return errors.Errorf("%d is greater than maximum value for Int4", value) + } + *dst = Int4{Int: int32(value), Status: Present} + case int: + if value < math.MinInt32 { + return 
errors.Errorf("%d is greater than maximum value for Int4", value) + } + if value > math.MaxInt32 { + return errors.Errorf("%d is greater than maximum value for Int4", value) + } + *dst = Int4{Int: int32(value), Status: Present} + case uint: + if value > math.MaxInt32 { + return errors.Errorf("%d is greater than maximum value for Int4", value) + } + *dst = Int4{Int: int32(value), Status: Present} + case string: + num, err := strconv.ParseInt(value, 10, 32) + if err != nil { + return err + } + *dst = Int4{Int: int32(num), Status: Present} + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Int4", value) + } + + return nil +} + +func (dst *Int4) Get() interface{} { + switch dst.Status { + case Present: + return dst.Int + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Int4) AssignTo(dst interface{}) error { + return int64AssignTo(int64(src.Int), src.Status, dst) +} + +func (dst *Int4) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int4{Status: Null} + return nil + } + + n, err := strconv.ParseInt(string(src), 10, 32) + if err != nil { + return err + } + + *dst = Int4{Int: int32(n), Status: Present} + return nil +} + +func (dst *Int4) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int4{Status: Null} + return nil + } + + if len(src) != 4 { + return errors.Errorf("invalid length for int4: %v", len(src)) + } + + n := int32(binary.BigEndian.Uint32(src)) + *dst = Int4{Int: n, Status: Present} + return nil +} + +func (src *Int4) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, strconv.FormatInt(int64(src.Int), 10)...), nil +} + +func (src *Int4) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + 
return nil, errUndefined + } + + return pgio.AppendInt32(buf, src.Int), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int4) Scan(src interface{}) error { + if src == nil { + *dst = Int4{Status: Null} + return nil + } + + switch src := src.(type) { + case int64: + if src < math.MinInt32 { + return errors.Errorf("%d is greater than maximum value for Int4", src) + } + if src > math.MaxInt32 { + return errors.Errorf("%d is greater than maximum value for Int4", src) + } + *dst = Int4{Int: int32(src), Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Int4) Value() (driver.Value, error) { + switch src.Status { + case Present: + return int64(src.Int), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} + +func (src *Int4) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + return []byte(strconv.FormatInt(int64(src.Int), 10)), nil + case Null: + return []byte("null"), nil + case Undefined: + return []byte("undefined"), nil + } + + return nil, errBadStatus +} diff --git a/pgtype/int4_array.go b/pgtype/int4_array.go new file mode 100644 index 00000000..1a907d2e --- /dev/null +++ b/pgtype/int4_array.go @@ -0,0 +1,322 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type Int4Array struct { + Elements []Int4 + Dimensions []ArrayDimension + Status Status +} + +func (dst *Int4Array) Set(src interface{}) error { + switch value := src.(type) { + + case []int32: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range 
value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint32: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Int4", value) + } + + return nil +} + +func (dst *Int4Array) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Int4Array) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + + case *[]int32: + *v = make([]int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint32: + *v = make([]uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *Int4Array) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int4Array{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Int4 + + if len(uta.Elements) 
> 0 { + elements = make([]Int4, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Int4 + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = Int4Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *Int4Array) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int4Array{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = Int4Array{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Int4, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = Int4Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *Int4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. 
+ dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src *Int4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("int4"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, errors.Errorf("unable to find oid for type name %v", "int4") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. 
+func (dst *Int4Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Int4Array) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/int4_array_test.go b/pgtype/int4_array_test.go new file mode 100644 index 00000000..6fad18bb --- /dev/null +++ b/pgtype/int4_array_test.go @@ -0,0 +1,177 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestInt4ArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int4[]", []interface{}{ + &pgtype.Int4Array{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.Int4Array{ + Elements: []pgtype.Int4{ + pgtype.Int4{Int: 1, Status: pgtype.Present}, + pgtype.Int4{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int4Array{Status: pgtype.Null}, + &pgtype.Int4Array{ + Elements: []pgtype.Int4{ + pgtype.Int4{Int: 1, Status: pgtype.Present}, + pgtype.Int4{Int: 2, Status: pgtype.Present}, + pgtype.Int4{Int: 3, Status: pgtype.Present}, + pgtype.Int4{Int: 4, Status: pgtype.Present}, + pgtype.Int4{Status: pgtype.Null}, + pgtype.Int4{Int: 6, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int4Array{ + Elements: []pgtype.Int4{ + pgtype.Int4{Int: 1, Status: pgtype.Present}, + pgtype.Int4{Int: 2, Status: 
pgtype.Present}, + pgtype.Int4{Int: 3, Status: pgtype.Present}, + pgtype.Int4{Int: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestInt4ArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Int4Array + }{ + { + source: []int32{1}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []uint32{1}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]int32)(nil)), + result: pgtype.Int4Array{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Int4Array + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInt4ArrayAssignTo(t *testing.T) { + var int32Slice []int32 + var uint32Slice []uint32 + var namedInt32Slice _int32Slice + + simpleTests := []struct { + src pgtype.Int4Array + dst interface{} + expected interface{} + }{ + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &int32Slice, + expected: []int32{1}, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &uint32Slice, + expected: []uint32{1}, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: 
[]pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedInt32Slice, + expected: _int32Slice{1}, + }, + { + src: pgtype.Int4Array{Status: pgtype.Null}, + dst: &int32Slice, + expected: (([]int32)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Int4Array + dst interface{} + }{ + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &int32Slice, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: -1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &uint32Slice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/int4_test.go b/pgtype/int4_test.go new file mode 100644 index 00000000..02f5409f --- /dev/null +++ b/pgtype/int4_test.go @@ -0,0 +1,143 @@ +package pgtype_test + +import ( + "math" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestInt4Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int4", []interface{}{ + &pgtype.Int4{Int: math.MinInt32, Status: pgtype.Present}, + &pgtype.Int4{Int: -1, Status: pgtype.Present}, + &pgtype.Int4{Int: 0, Status: pgtype.Present}, + &pgtype.Int4{Int: 1, Status: pgtype.Present}, + &pgtype.Int4{Int: math.MaxInt32, Status: pgtype.Present}, + &pgtype.Int4{Int: 0, Status: pgtype.Null}, + }) +} + +func TestInt4Set(t 
*testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Int4 + }{ + {source: int8(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Int4 + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInt4AssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + + simpleTests := []struct { + src pgtype.Int4 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i32, 
expected: int32(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Int4{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Int4{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Int4 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Int4 + dst interface{} + }{ + {src: pgtype.Int4{Int: 150, Status: pgtype.Present}, dst: &i8}, + {src: pgtype.Int4{Int: 40000, Status: pgtype.Present}, dst: &i16}, + 
{src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui8}, + {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui16}, + {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui32}, + {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui64}, + {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui}, + {src: pgtype.Int4{Int: 0, Status: pgtype.Null}, dst: &i32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/int4range.go b/pgtype/int4range.go new file mode 100644 index 00000000..95ad1521 --- /dev/null +++ b/pgtype/int4range.go @@ -0,0 +1,250 @@ +package pgtype + +import ( + "database/sql/driver" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type Int4range struct { + Lower Int4 + Upper Int4 + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *Int4range) Set(src interface{}) error { + return errors.Errorf("cannot convert %v to Int4range", src) +} + +func (dst *Int4range) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Int4range) AssignTo(dst interface{}) error { + return errors.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Int4range) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int4range{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Int4range{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := 
dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Int4range) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int4range{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Int4range{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Int4range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + buf = append(buf, '(') + case Inclusive: + buf = append(buf, '[') + case Empty: + return append(buf, "empty"...), nil + default: + return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) + } + + var err error + + if src.LowerType != Unbounded { + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + buf = append(buf, ',') + + if src.UpperType != Unbounded { + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + buf = append(buf, ')') + case Inclusive: + buf = append(buf, ']') + default: + return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) + } + + return buf, nil +} + +func (src 
Int4range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + return append(buf, emptyMask), nil + default: + return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) + } + + buf = append(buf, rangeType) + + var err error + + if src.LowerType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + if src.UpperType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int4range) Scan(src interface{}) error { + if src == nil { + *dst = Int4range{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src Int4range) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/int4range_test.go b/pgtype/int4range_test.go new file mode 100644 index 00000000..088097d8 --- /dev/null +++ b/pgtype/int4range_test.go @@ -0,0 +1,26 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestInt4rangeTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int4range", []interface{}{ + &pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + &pgtype.Int4range{Lower: pgtype.Int4{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int4{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + &pgtype.Int4range{Lower: pgtype.Int4{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int4{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + &pgtype.Int4range{Status: pgtype.Null}, + }) +} + +func TestInt4rangeNormalize(t *testing.T) { + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + { + SQL: "select int4range(1, 10, '(]')", + Value: pgtype.Int4range{Lower: pgtype.Int4{Int: 2, Status: pgtype.Present}, Upper: pgtype.Int4{Int: 11, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + }, + }) +} diff --git a/pgtype/int8.go b/pgtype/int8.go new file mode 100644 index 00000000..d671eda7 --- /dev/null +++ b/pgtype/int8.go @@ -0,0 +1,186 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "math" + "strconv" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type Int8 struct { + Int int64 + Status Status +} + +func (dst *Int8) Set(src interface{}) error { + if src == nil { + *dst = Int8{Status: Null} + return nil + } + + switch value := src.(type) { + case int8: + *dst = Int8{Int: int64(value), Status: Present} 
+ case uint8: + *dst = Int8{Int: int64(value), Status: Present} + case int16: + *dst = Int8{Int: int64(value), Status: Present} + case uint16: + *dst = Int8{Int: int64(value), Status: Present} + case int32: + *dst = Int8{Int: int64(value), Status: Present} + case uint32: + *dst = Int8{Int: int64(value), Status: Present} + case int64: + *dst = Int8{Int: int64(value), Status: Present} + case uint64: + if value > math.MaxInt64 { + return errors.Errorf("%d is greater than maximum value for Int8", value) + } + *dst = Int8{Int: int64(value), Status: Present} + case int: + if int64(value) < math.MinInt64 { + return errors.Errorf("%d is greater than maximum value for Int8", value) + } + if int64(value) > math.MaxInt64 { + return errors.Errorf("%d is greater than maximum value for Int8", value) + } + *dst = Int8{Int: int64(value), Status: Present} + case uint: + if uint64(value) > math.MaxInt64 { + return errors.Errorf("%d is greater than maximum value for Int8", value) + } + *dst = Int8{Int: int64(value), Status: Present} + case string: + num, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return err + } + *dst = Int8{Int: num, Status: Present} + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Int8", value) + } + + return nil +} + +func (dst *Int8) Get() interface{} { + switch dst.Status { + case Present: + return dst.Int + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Int8) AssignTo(dst interface{}) error { + return int64AssignTo(int64(src.Int), src.Status, dst) +} + +func (dst *Int8) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int8{Status: Null} + return nil + } + + n, err := strconv.ParseInt(string(src), 10, 64) + if err != nil { + return err + } + + *dst = Int8{Int: n, Status: Present} + return nil +} + +func (dst *Int8) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int8{Status: 
Null} + return nil + } + + if len(src) != 8 { + return errors.Errorf("invalid length for int8: %v", len(src)) + } + + n := int64(binary.BigEndian.Uint64(src)) + + *dst = Int8{Int: n, Status: Present} + return nil +} + +func (src *Int8) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, strconv.FormatInt(src.Int, 10)...), nil +} + +func (src *Int8) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return pgio.AppendInt64(buf, src.Int), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int8) Scan(src interface{}) error { + if src == nil { + *dst = Int8{Status: Null} + return nil + } + + switch src := src.(type) { + case int64: + *dst = Int8{Int: src, Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src *Int8) Value() (driver.Value, error) { + switch src.Status { + case Present: + return int64(src.Int), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} + +func (src *Int8) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + return []byte(strconv.FormatInt(src.Int, 10)), nil + case Null: + return []byte("null"), nil + case Undefined: + return []byte("undefined"), nil + } + + return nil, errBadStatus +} diff --git a/pgtype/int8_array.go b/pgtype/int8_array.go new file mode 100644 index 00000000..4f3ab4dc --- /dev/null +++ b/pgtype/int8_array.go @@ -0,0 +1,322 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type Int8Array struct { + Elements []Int8 + Dimensions []ArrayDimension + Status Status +} + +func (dst *Int8Array) Set(src interface{}) error { + switch value := src.(type) { + + case []int64: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint64: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Int8", value) + } + + return nil +} + +func 
(dst *Int8Array) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Int8Array) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + + case *[]int64: + *v = make([]int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint64: + *v = make([]uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *Int8Array) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int8Array{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Int8 + + if len(uta.Elements) > 0 { + elements = make([]Int8, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Int8 + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = Int8Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *Int8Array) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int8Array{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = Int8Array{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range 
arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Int8, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = Int8Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *Int8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) 
+ } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src *Int8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("int8"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, errors.Errorf("unable to find oid for type name %v", "int8") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int8Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src *Int8Array) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/int8_array_test.go b/pgtype/int8_array_test.go new file mode 100644 index 00000000..4f5c4f9a --- /dev/null +++ b/pgtype/int8_array_test.go @@ -0,0 +1,177 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestInt8ArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int8[]", []interface{}{ + &pgtype.Int8Array{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.Int8Array{ + Elements: []pgtype.Int8{ + pgtype.Int8{Int: 1, Status: pgtype.Present}, + pgtype.Int8{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int8Array{Status: pgtype.Null}, + &pgtype.Int8Array{ + Elements: []pgtype.Int8{ + pgtype.Int8{Int: 1, Status: pgtype.Present}, + pgtype.Int8{Int: 2, Status: pgtype.Present}, + pgtype.Int8{Int: 3, Status: pgtype.Present}, + pgtype.Int8{Int: 4, Status: pgtype.Present}, + pgtype.Int8{Status: pgtype.Null}, + pgtype.Int8{Int: 6, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int8Array{ + Elements: []pgtype.Int8{ + pgtype.Int8{Int: 1, Status: pgtype.Present}, + pgtype.Int8{Int: 2, Status: pgtype.Present}, + pgtype.Int8{Int: 3, Status: pgtype.Present}, + pgtype.Int8{Int: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestInt8ArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Int8Array + }{ + { + source: []int64{1}, + result: 
pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []uint64{1}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]int64)(nil)), + result: pgtype.Int8Array{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Int8Array + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInt8ArrayAssignTo(t *testing.T) { + var int64Slice []int64 + var uint64Slice []uint64 + var namedInt64Slice _int64Slice + + simpleTests := []struct { + src pgtype.Int8Array + dst interface{} + expected interface{} + }{ + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &int64Slice, + expected: []int64{1}, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &uint64Slice, + expected: []uint64{1}, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedInt64Slice, + expected: _int64Slice{1}, + }, + { + src: pgtype.Int8Array{Status: pgtype.Null}, + dst: &int64Slice, + expected: (([]int64)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := 
reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Int8Array + dst interface{} + }{ + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &int64Slice, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: -1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &uint64Slice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/int8_test.go b/pgtype/int8_test.go new file mode 100644 index 00000000..0b3bb3eb --- /dev/null +++ b/pgtype/int8_test.go @@ -0,0 +1,144 @@ +package pgtype_test + +import ( + "math" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestInt8Transcode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "int8", []interface{}{ + &pgtype.Int8{Int: math.MinInt64, Status: pgtype.Present}, + &pgtype.Int8{Int: -1, Status: pgtype.Present}, + &pgtype.Int8{Int: 0, Status: pgtype.Present}, + &pgtype.Int8{Int: 1, Status: pgtype.Present}, + &pgtype.Int8{Int: math.MaxInt64, Status: pgtype.Present}, + &pgtype.Int8{Int: 0, Status: pgtype.Null}, + }) +} + +func TestInt8Set(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Int8 + }{ + {source: int8(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.Int8{Int: 1, Status: 
pgtype.Present}}, + {source: int8(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Int8 + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInt8AssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + + simpleTests := []struct { + src pgtype.Int8 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: 
pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Int8{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Int8{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Int8 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Int8 + dst interface{} + }{ + {src: pgtype.Int8{Int: 150, Status: pgtype.Present}, dst: &i8}, + {src: pgtype.Int8{Int: 40000, Status: pgtype.Present}, dst: &i16}, + {src: pgtype.Int8{Int: 5000000000, Status: pgtype.Present}, dst: &i32}, + {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui8}, + {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui16}, + {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui32}, + {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui64}, + {src: pgtype.Int8{Int: -1, Status: 
pgtype.Present}, dst: &ui}, + {src: pgtype.Int8{Int: 0, Status: pgtype.Null}, dst: &i64}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/int8range.go b/pgtype/int8range.go new file mode 100644 index 00000000..61d860d3 --- /dev/null +++ b/pgtype/int8range.go @@ -0,0 +1,250 @@ +package pgtype + +import ( + "database/sql/driver" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type Int8range struct { + Lower Int8 + Upper Int8 + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *Int8range) Set(src interface{}) error { + return errors.Errorf("cannot convert %v to Int8range", src) +} + +func (dst *Int8range) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Int8range) AssignTo(dst interface{}) error { + return errors.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Int8range) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int8range{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Int8range{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Int8range) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int8range{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + 
*dst = Int8range{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Int8range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + buf = append(buf, '(') + case Inclusive: + buf = append(buf, '[') + case Empty: + return append(buf, "empty"...), nil + default: + return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) + } + + var err error + + if src.LowerType != Unbounded { + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + buf = append(buf, ',') + + if src.UpperType != Unbounded { + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + buf = append(buf, ')') + case Inclusive: + buf = append(buf, ']') + default: + return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) + } + + return buf, nil +} + +func (src Int8range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case 
Exclusive: + case Empty: + return append(buf, emptyMask), nil + default: + return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) + } + + buf = append(buf, rangeType) + + var err error + + if src.LowerType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + if src.UpperType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int8range) Scan(src interface{}) error { + if src == nil { + *dst = Int8range{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src Int8range) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/int8range_test.go b/pgtype/int8range_test.go new file mode 100644 index 00000000..c039ec65 --- /dev/null +++ b/pgtype/int8range_test.go @@ -0,0 +1,26 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestInt8rangeTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "Int8range", []interface{}{ + &pgtype.Int8range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + &pgtype.Int8range{Lower: pgtype.Int8{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int8{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + &pgtype.Int8range{Lower: pgtype.Int8{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int8{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + &pgtype.Int8range{Status: pgtype.Null}, + }) +} + +func TestInt8rangeNormalize(t *testing.T) { + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + { + SQL: "select Int8range(1, 10, '(]')", + Value: pgtype.Int8range{Lower: pgtype.Int8{Int: 2, Status: pgtype.Present}, Upper: pgtype.Int8{Int: 11, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + }, + }) +} diff --git a/pgtype/interval.go b/pgtype/interval.go new file mode 100644 index 00000000..799ce53a --- /dev/null +++ b/pgtype/interval.go @@ -0,0 +1,250 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "strconv" + "strings" + "time" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +const ( + microsecondsPerSecond = 1000000 + microsecondsPerMinute = 60 * microsecondsPerSecond + microsecondsPerHour = 60 * microsecondsPerMinute +) + +type Interval struct { + Microseconds int64 + Days int32 + Months int32 + 
Status Status +} + +func (dst *Interval) Set(src interface{}) error { + if src == nil { + *dst = Interval{Status: Null} + return nil + } + + switch value := src.(type) { + case time.Duration: + *dst = Interval{Microseconds: int64(value) / 1000, Status: Present} + default: + if originalSrc, ok := underlyingPtrType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Interval", value) + } + + return nil +} + +func (dst *Interval) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Interval) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *time.Duration: + if src.Days > 0 || src.Months > 0 { + return errors.Errorf("interval with months or days cannot be decoded into %T", dst) + } + *v = time.Duration(src.Microseconds) * time.Microsecond + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Interval{Status: Null} + return nil + } + + var microseconds int64 + var days int32 + var months int32 + + parts := strings.Split(string(src), " ") + + for i := 0; i < len(parts)-1; i += 2 { + scalar, err := strconv.ParseInt(parts[i], 10, 64) + if err != nil { + return errors.Errorf("bad interval format") + } + + switch parts[i+1] { + case "year", "years": + months += int32(scalar * 12) + case "mon", "mons": + months += int32(scalar) + case "day", "days": + days = int32(scalar) + } + } + + if len(parts)%2 == 1 { + timeParts := strings.SplitN(parts[len(parts)-1], ":", 3) + if len(timeParts) != 3 { + return errors.Errorf("bad interval format") + } + + var negative bool + if timeParts[0][0] == '-' { + negative = true + timeParts[0] 
= timeParts[0][1:] + } + + hours, err := strconv.ParseInt(timeParts[0], 10, 64) + if err != nil { + return errors.Errorf("bad interval hour format: %s", timeParts[0]) + } + + minutes, err := strconv.ParseInt(timeParts[1], 10, 64) + if err != nil { + return errors.Errorf("bad interval minute format: %s", timeParts[1]) + } + + secondParts := strings.SplitN(timeParts[2], ".", 2) + + seconds, err := strconv.ParseInt(secondParts[0], 10, 64) + if err != nil { + return errors.Errorf("bad interval second format: %s", secondParts[0]) + } + + var uSeconds int64 + if len(secondParts) == 2 { + uSeconds, err = strconv.ParseInt(secondParts[1], 10, 64) + if err != nil { + return errors.Errorf("bad interval decimal format: %s", secondParts[1]) + } + + for i := 0; i < 6-len(secondParts[1]); i++ { + uSeconds *= 10 + } + } + + microseconds = hours * microsecondsPerHour + microseconds += minutes * microsecondsPerMinute + microseconds += seconds * microsecondsPerSecond + microseconds += uSeconds + + if negative { + microseconds = -microseconds + } + } + + *dst = Interval{Months: months, Days: days, Microseconds: microseconds, Status: Present} + return nil +} + +func (dst *Interval) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Interval{Status: Null} + return nil + } + + if len(src) != 16 { + return errors.Errorf("Received an invalid size for a interval: %d", len(src)) + } + + microseconds := int64(binary.BigEndian.Uint64(src)) + days := int32(binary.BigEndian.Uint32(src[8:])) + months := int32(binary.BigEndian.Uint32(src[12:])) + + *dst = Interval{Microseconds: microseconds, Days: days, Months: months, Status: Present} + return nil +} + +func (src *Interval) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if src.Months != 0 { + buf = append(buf, strconv.FormatInt(int64(src.Months), 10)...) + buf = append(buf, " mon "...) 
+ } + + if src.Days != 0 { + buf = append(buf, strconv.FormatInt(int64(src.Days), 10)...) + buf = append(buf, " day "...) + } + + absMicroseconds := src.Microseconds + if absMicroseconds < 0 { + absMicroseconds = -absMicroseconds + buf = append(buf, '-') + } + + hours := absMicroseconds / microsecondsPerHour + minutes := (absMicroseconds % microsecondsPerHour) / microsecondsPerMinute + seconds := (absMicroseconds % microsecondsPerMinute) / microsecondsPerSecond + microseconds := absMicroseconds % microsecondsPerSecond + + timeStr := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, microseconds) + return append(buf, timeStr...), nil +} + +// EncodeBinary appends src to buf in the PostgreSQL binary wire format. +func (src *Interval) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = pgio.AppendInt64(buf, src.Microseconds) + buf = pgio.AppendInt32(buf, src.Days) + return pgio.AppendInt32(buf, src.Months), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Interval) Scan(src interface{}) error { + if src == nil { + *dst = Interval{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src *Interval) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/interval_test.go b/pgtype/interval_test.go new file mode 100644 index 00000000..76ea3240 --- /dev/null +++ b/pgtype/interval_test.go @@ -0,0 +1,63 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestIntervalTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "interval", []interface{}{ + &pgtype.Interval{Microseconds: 1, Status: pgtype.Present}, + &pgtype.Interval{Microseconds: 1000000, Status: pgtype.Present}, + &pgtype.Interval{Microseconds: 1000001, Status: pgtype.Present}, + &pgtype.Interval{Microseconds: 123202800000000, Status: pgtype.Present}, + &pgtype.Interval{Days: 1, Status: pgtype.Present}, + &pgtype.Interval{Months: 1, Status: pgtype.Present}, + &pgtype.Interval{Months: 12, Status: pgtype.Present}, + &pgtype.Interval{Months: 13, Days: 15, Microseconds: 1000001, Status: pgtype.Present}, + &pgtype.Interval{Microseconds: -1, Status: pgtype.Present}, + &pgtype.Interval{Microseconds: -1000000, Status: pgtype.Present}, + &pgtype.Interval{Microseconds: -1000001, Status: pgtype.Present}, + &pgtype.Interval{Microseconds: -123202800000000, Status: pgtype.Present}, + &pgtype.Interval{Days: -1, Status: pgtype.Present}, + &pgtype.Interval{Months: -1, Status: pgtype.Present}, + &pgtype.Interval{Months: -12, Status: pgtype.Present}, + &pgtype.Interval{Months: -13, Days: -15, Microseconds: -1000001, Status: pgtype.Present}, + &pgtype.Interval{Status: pgtype.Null}, + }) +} + +func TestIntervalNormalize(t *testing.T) { + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + { + SQL: "select '1 second'::interval", + Value: &pgtype.Interval{Microseconds: 1000000, Status: pgtype.Present}, + }, + { + SQL: "select '1.000001 second'::interval", + Value: &pgtype.Interval{Microseconds: 1000001, Status: pgtype.Present}, + }, + { + SQL: "select '34223 hours'::interval", 
+ Value: &pgtype.Interval{Microseconds: 123202800000000, Status: pgtype.Present}, + }, + { + SQL: "select '1 day'::interval", + Value: &pgtype.Interval{Days: 1, Status: pgtype.Present}, + }, + { + SQL: "select '1 month'::interval", + Value: &pgtype.Interval{Months: 1, Status: pgtype.Present}, + }, + { + SQL: "select '1 year'::interval", + Value: &pgtype.Interval{Months: 12, Status: pgtype.Present}, + }, + { + SQL: "select '-13 mon'::interval", + Value: &pgtype.Interval{Months: -13, Status: pgtype.Present}, + }, + }) +} diff --git a/pgtype/json.go b/pgtype/json.go new file mode 100644 index 00000000..562722aa --- /dev/null +++ b/pgtype/json.go @@ -0,0 +1,152 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/json" + + "github.com/pkg/errors" +) + +type JSON struct { + Bytes []byte + Status Status +} + +func (dst *JSON) Set(src interface{}) error { + if src == nil { + *dst = JSON{Status: Null} + return nil + } + + switch value := src.(type) { + case string: + *dst = JSON{Bytes: []byte(value), Status: Present} + case *string: + if value == nil { + *dst = JSON{Status: Null} + } else { + *dst = JSON{Bytes: []byte(*value), Status: Present} + } + case []byte: + if value == nil { + *dst = JSON{Status: Null} + } else { + *dst = JSON{Bytes: value, Status: Present} + } + default: + buf, err := json.Marshal(value) + if err != nil { + return err + } + *dst = JSON{Bytes: buf, Status: Present} + } + + return nil +} + +func (dst *JSON) Get() interface{} { + switch dst.Status { + case Present: + var i interface{} + err := json.Unmarshal(dst.Bytes, &i) + if err != nil { + return dst + } + return i + case Null: + return nil + default: + return dst.Status + } +} + +func (src *JSON) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *string: + if src.Status != Present { + v = nil + } else { + *v = string(src.Bytes) + } + case **string: + *v = new(string) + return src.AssignTo(*v) + case *[]byte: + if src.Status != Present { + *v = nil + } else { + buf 
:= make([]byte, len(src.Bytes)) + copy(buf, src.Bytes) + *v = buf + } + default: + data := src.Bytes + if data == nil || src.Status != Present { + data = []byte("null") + } + + return json.Unmarshal(data, dst) + } + + return nil +} + +func (dst *JSON) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = JSON{Status: Null} + return nil + } + + *dst = JSON{Bytes: src, Status: Present} + return nil +} + +func (dst *JSON) DecodeBinary(ci *ConnInfo, src []byte) error { + return dst.DecodeText(ci, src) +} + +func (src *JSON) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, src.Bytes...), nil +} + +func (src *JSON) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return src.EncodeText(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *JSON) Scan(src interface{}) error { + if src == nil { + *dst = JSON{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src *JSON) Value() (driver.Value, error) { + switch src.Status { + case Present: + return string(src.Bytes), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/json_test.go b/pgtype/json_test.go new file mode 100644 index 00000000..82c02539 --- /dev/null +++ b/pgtype/json_test.go @@ -0,0 +1,136 @@ +package pgtype_test + +import ( + "bytes" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestJSONTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "json", []interface{}{ + &pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}, + &pgtype.JSON{Bytes: []byte("null"), Status: pgtype.Present}, + &pgtype.JSON{Bytes: []byte("42"), Status: pgtype.Present}, + &pgtype.JSON{Bytes: []byte(`"hello"`), Status: pgtype.Present}, + &pgtype.JSON{Status: pgtype.Null}, + }) +} + +func TestJSONSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.JSON + }{ + {source: "{}", result: pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}}, + {source: []byte("{}"), result: pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}}, + {source: ([]byte)(nil), result: pgtype.JSON{Status: pgtype.Null}}, + {source: (*string)(nil), result: pgtype.JSON{Status: pgtype.Null}}, + {source: []int{1, 2, 3}, result: pgtype.JSON{Bytes: []byte("[1,2,3]"), Status: pgtype.Present}}, + {source: map[string]interface{}{"foo": "bar"}, result: pgtype.JSON{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var d pgtype.JSON + err := d.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(d, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestJSONAssignTo(t *testing.T) { + var s string + var ps *string + var b []byte + + rawStringTests := []struct { + src pgtype.JSON + dst *string + 
expected string + }{ + {src: pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &s, expected: "{}"}, + } + + for i, tt := range rawStringTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if *tt.dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } + + rawBytesTests := []struct { + src pgtype.JSON + dst *[]byte + expected []byte + }{ + {src: pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &b, expected: []byte("{}")}, + {src: pgtype.JSON{Status: pgtype.Null}, dst: &b, expected: (([]byte)(nil))}, + } + + for i, tt := range rawBytesTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if bytes.Compare(tt.expected, *tt.dst) != 0 { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } + + var mapDst map[string]interface{} + type structDst struct { + Name string `json:"name"` + Age int `json:"age"` + } + var strDst structDst + + unmarshalTests := []struct { + src pgtype.JSON + dst interface{} + expected interface{} + }{ + {src: pgtype.JSON{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, + {src: pgtype.JSON{Bytes: []byte(`{"name":"John","age":42}`), Status: pgtype.Present}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, + } + for i, tt := range unmarshalTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.JSON + dst **string + expected *string + }{ + {src: pgtype.JSON{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + } + + for i, tt := range pointerAllocTests { + err := 
tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if *tt.dst == tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } +} diff --git a/pgtype/jsonb.go b/pgtype/jsonb.go new file mode 100644 index 00000000..c315c588 --- /dev/null +++ b/pgtype/jsonb.go @@ -0,0 +1,70 @@ +package pgtype + +import ( + "database/sql/driver" + + "github.com/pkg/errors" +) + +type JSONB JSON + +func (dst *JSONB) Set(src interface{}) error { + return (*JSON)(dst).Set(src) +} + +func (dst *JSONB) Get() interface{} { + return (*JSON)(dst).Get() +} + +func (src *JSONB) AssignTo(dst interface{}) error { + return (*JSON)(src).AssignTo(dst) +} + +func (dst *JSONB) DecodeText(ci *ConnInfo, src []byte) error { + return (*JSON)(dst).DecodeText(ci, src) +} + +func (dst *JSONB) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = JSONB{Status: Null} + return nil + } + + if len(src) == 0 { + return errors.Errorf("jsonb too short") + } + + if src[0] != 1 { + return errors.Errorf("unknown jsonb version number %d", src[0]) + } + + *dst = JSONB{Bytes: src[1:], Status: Present} + return nil + +} + +func (src *JSONB) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*JSON)(src).EncodeText(ci, buf) +} + +func (src *JSONB) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = append(buf, 1) + return append(buf, src.Bytes...), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *JSONB) Scan(src interface{}) error { + return (*JSON)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src *JSONB) Value() (driver.Value, error) { + return (*JSON)(src).Value() +} diff --git a/pgtype/jsonb_test.go b/pgtype/jsonb_test.go new file mode 100644 index 00000000..1a9a3056 --- /dev/null +++ b/pgtype/jsonb_test.go @@ -0,0 +1,142 @@ +package pgtype_test + +import ( + "bytes" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestJSONBTranscode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustClose(t, conn) + if _, ok := conn.ConnInfo.DataTypeForName("jsonb"); !ok { + t.Skip("Skipping due to no jsonb type") + } + + testutil.TestSuccessfulTranscode(t, "jsonb", []interface{}{ + &pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}, + &pgtype.JSONB{Bytes: []byte("null"), Status: pgtype.Present}, + &pgtype.JSONB{Bytes: []byte("42"), Status: pgtype.Present}, + &pgtype.JSONB{Bytes: []byte(`"hello"`), Status: pgtype.Present}, + &pgtype.JSONB{Status: pgtype.Null}, + }) +} + +func TestJSONBSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.JSONB + }{ + {source: "{}", result: pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}}, + {source: []byte("{}"), result: pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}}, + {source: ([]byte)(nil), result: pgtype.JSONB{Status: pgtype.Null}}, + {source: (*string)(nil), result: pgtype.JSONB{Status: pgtype.Null}}, + {source: []int{1, 2, 3}, result: pgtype.JSONB{Bytes: []byte("[1,2,3]"), Status: pgtype.Present}}, + {source: map[string]interface{}{"foo": "bar"}, result: pgtype.JSONB{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var d pgtype.JSONB + err := d.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(d, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestJSONBAssignTo(t *testing.T) { + var s string + var ps 
*string + var b []byte + + rawStringTests := []struct { + src pgtype.JSONB + dst *string + expected string + }{ + {src: pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &s, expected: "{}"}, + } + + for i, tt := range rawStringTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if *tt.dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } + + rawBytesTests := []struct { + src pgtype.JSONB + dst *[]byte + expected []byte + }{ + {src: pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &b, expected: []byte("{}")}, + {src: pgtype.JSONB{Status: pgtype.Null}, dst: &b, expected: (([]byte)(nil))}, + } + + for i, tt := range rawBytesTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if bytes.Compare(tt.expected, *tt.dst) != 0 { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } + + var mapDst map[string]interface{} + type structDst struct { + Name string `json:"name"` + Age int `json:"age"` + } + var strDst structDst + + unmarshalTests := []struct { + src pgtype.JSONB + dst interface{} + expected interface{} + }{ + {src: pgtype.JSONB{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, + {src: pgtype.JSONB{Bytes: []byte(`{"name":"John","age":42}`), Status: pgtype.Present}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, + } + for i, tt := range unmarshalTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.JSONB + dst **string + expected *string + }{ + {src: pgtype.JSONB{Status: 
pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if *tt.dst == tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } +} diff --git a/pgtype/line.go b/pgtype/line.go new file mode 100644 index 00000000..f6eadf0e --- /dev/null +++ b/pgtype/line.go @@ -0,0 +1,143 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type Line struct { + A, B, C float64 + Status Status +} + +func (dst *Line) Set(src interface{}) error { + return errors.Errorf("cannot convert %v to Line", src) +} + +func (dst *Line) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Line) AssignTo(dst interface{}) error { + return errors.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Line) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Line{Status: Null} + return nil + } + + if len(src) < 7 { + return errors.Errorf("invalid length for Line: %v", len(src)) + } + + parts := strings.SplitN(string(src[1:len(src)-1]), ",", 3) + if len(parts) < 3 { + return errors.Errorf("invalid format for line") + } + + a, err := strconv.ParseFloat(parts[0], 64) + if err != nil { + return err + } + + b, err := strconv.ParseFloat(parts[1], 64) + if err != nil { + return err + } + + c, err := strconv.ParseFloat(parts[2], 64) + if err != nil { + return err + } + + *dst = Line{A: a, B: b, C: c, Status: Present} + return nil +} + +func (dst *Line) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Line{Status: Null} + return nil + } + + if len(src) != 24 { + return errors.Errorf("invalid length for Line: %v", len(src)) + } + + a := 
binary.BigEndian.Uint64(src) + b := binary.BigEndian.Uint64(src[8:]) + c := binary.BigEndian.Uint64(src[16:]) + + *dst = Line{ + A: math.Float64frombits(a), + B: math.Float64frombits(b), + C: math.Float64frombits(c), + Status: Present, + } + return nil +} + +func (src *Line) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, fmt.Sprintf(`{%f,%f,%f}`, src.A, src.B, src.C)...), nil +} + +func (src *Line) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = pgio.AppendUint64(buf, math.Float64bits(src.A)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.B)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.C)) + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Line) Scan(src interface{}) error { + if src == nil { + *dst = Line{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src *Line) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/line_test.go b/pgtype/line_test.go new file mode 100644 index 00000000..09e48019 --- /dev/null +++ b/pgtype/line_test.go @@ -0,0 +1,36 @@ +package pgtype_test + +import ( + "testing" + + version "github.com/hashicorp/go-version" + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestLineTranscode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + serverVersion, err := version.NewVersion(conn.RuntimeParams["server_version"]) + if err != nil { + t.Fatalf("cannot get server version: %v", err) + } + testutil.MustClose(t, conn) + + minVersion := version.Must(version.NewVersion("9.4")) + + if serverVersion.LessThan(minVersion) { + t.Skipf("Skipping line test for server version %v", serverVersion) + } + + testutil.TestSuccessfulTranscode(t, "line", []interface{}{ + &pgtype.Line{ + A: 1.23, B: 4.56, C: 7.89, + Status: pgtype.Present, + }, + &pgtype.Line{ + A: -1.23, B: -4.56, C: -7.89, + Status: pgtype.Present, + }, + &pgtype.Line{Status: pgtype.Null}, + }) +} diff --git a/pgtype/lseg.go b/pgtype/lseg.go new file mode 100644 index 00000000..a9d740cf --- /dev/null +++ b/pgtype/lseg.go @@ -0,0 +1,161 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type Lseg struct { + P [2]Vec2 + Status Status +} + +func (dst *Lseg) Set(src interface{}) error { + return errors.Errorf("cannot convert %v to Lseg", src) +} + +func (dst *Lseg) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Lseg) AssignTo(dst interface{}) error { + return errors.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Lseg) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Lseg{Status: Null} + return nil + } + + if 
len(src) < 11 { + return errors.Errorf("invalid length for Lseg: %v", len(src)) + } + + str := string(src[2:]) + + var end int + end = strings.IndexByte(str, ',') + + x1, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y1, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+3:] + end = strings.IndexByte(str, ',') + + x2, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1 : len(str)-2] + + y2, err := strconv.ParseFloat(str, 64) + if err != nil { + return err + } + + *dst = Lseg{P: [2]Vec2{{x1, y1}, {x2, y2}}, Status: Present} + return nil +} + +func (dst *Lseg) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Lseg{Status: Null} + return nil + } + + if len(src) != 32 { + return errors.Errorf("invalid length for Lseg: %v", len(src)) + } + + x1 := binary.BigEndian.Uint64(src) + y1 := binary.BigEndian.Uint64(src[8:]) + x2 := binary.BigEndian.Uint64(src[16:]) + y2 := binary.BigEndian.Uint64(src[24:]) + + *dst = Lseg{ + P: [2]Vec2{ + {math.Float64frombits(x1), math.Float64frombits(y1)}, + {math.Float64frombits(x2), math.Float64frombits(y2)}, + }, + Status: Present, + } + return nil +} + +func (src *Lseg) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = append(buf, fmt.Sprintf(`(%f,%f),(%f,%f)`, + src.P[0].X, src.P[0].Y, src.P[1].X, src.P[1].Y)...) 
+ return buf, nil +} + +func (src *Lseg) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].Y)) + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Lseg) Scan(src interface{}) error { + if src == nil { + *dst = Lseg{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Lseg) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/lseg_test.go b/pgtype/lseg_test.go new file mode 100644 index 00000000..bd394e3c --- /dev/null +++ b/pgtype/lseg_test.go @@ -0,0 +1,22 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestLsegTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "lseg", []interface{}{ + &pgtype.Lseg{ + P: [2]pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}}, + Status: pgtype.Present, + }, + &pgtype.Lseg{ + P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Status: pgtype.Present, + }, + &pgtype.Lseg{Status: pgtype.Null}, + }) +} diff --git a/pgtype/macaddr.go b/pgtype/macaddr.go new file mode 100644 index 00000000..4c6e2212 --- /dev/null +++ b/pgtype/macaddr.go @@ -0,0 +1,154 @@ +package pgtype + +import ( + "database/sql/driver" + "net" + + "github.com/pkg/errors" +) + +type Macaddr struct { + Addr net.HardwareAddr + Status Status +} + +func (dst *Macaddr) 
Set(src interface{}) error { + if src == nil { + *dst = Macaddr{Status: Null} + return nil + } + + switch value := src.(type) { + case net.HardwareAddr: + addr := make(net.HardwareAddr, len(value)) + copy(addr, value) + *dst = Macaddr{Addr: addr, Status: Present} + case string: + addr, err := net.ParseMAC(value) + if err != nil { + return err + } + *dst = Macaddr{Addr: addr, Status: Present} + default: + if originalSrc, ok := underlyingPtrType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Macaddr", value) + } + + return nil +} + +func (dst *Macaddr) Get() interface{} { + switch dst.Status { + case Present: + return dst.Addr + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Macaddr) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *net.HardwareAddr: + *v = make(net.HardwareAddr, len(src.Addr)) + copy(*v, src.Addr) + return nil + case *string: + *v = src.Addr.String() + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *Macaddr) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Macaddr{Status: Null} + return nil + } + + addr, err := net.ParseMAC(string(src)) + if err != nil { + return err + } + + *dst = Macaddr{Addr: addr, Status: Present} + return nil +} + +func (dst *Macaddr) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Macaddr{Status: Null} + return nil + } + + if len(src) != 6 { + return errors.Errorf("Received an invalid size for a macaddr: %d", len(src)) + } + + addr := make(net.HardwareAddr, 6) + copy(addr, src) + + *dst = Macaddr{Addr: addr, Status: Present} + + return nil +} + +func (src *Macaddr) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return 
nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, src.Addr.String()...), nil +} + +// EncodeBinary encodes src into w. +func (src *Macaddr) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, src.Addr...), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Macaddr) Scan(src interface{}) error { + if src == nil { + *dst = Macaddr{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Macaddr) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/macaddr_test.go b/pgtype/macaddr_test.go new file mode 100644 index 00000000..5d329249 --- /dev/null +++ b/pgtype/macaddr_test.go @@ -0,0 +1,78 @@ +package pgtype_test + +import ( + "bytes" + "net" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestMacaddrTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "macaddr", []interface{}{ + &pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + &pgtype.Macaddr{Status: pgtype.Null}, + }) +} + +func TestMacaddrSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Macaddr + }{ + { + source: mustParseMacaddr(t, "01:23:45:67:89:ab"), + result: pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + }, + { + source: "01:23:45:67:89:ab", + result: pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r 
pgtype.Macaddr + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestMacaddrAssignTo(t *testing.T) { + { + src := pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present} + var dst net.HardwareAddr + expected := mustParseMacaddr(t, "01:23:45:67:89:ab") + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if bytes.Compare([]byte(dst), []byte(expected)) != 0 { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present} + var dst string + expected := "01:23:45:67:89:ab" + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } +} diff --git a/pgtype/name.go b/pgtype/name.go new file mode 100644 index 00000000..af064a82 --- /dev/null +++ b/pgtype/name.go @@ -0,0 +1,58 @@ +package pgtype + +import ( + "database/sql/driver" +) + +// Name is a type used for PostgreSQL's special 63-byte +// name data type, used for identifiers like table names. +// The pg_class.relname column is a good example of where the +// name data type is used. +// +// Note that the underlying Go data type of pgx.Name is string, +// so there is no way to enforce the 63-byte length. Inputting +// a longer name into PostgreSQL will result in silent truncation +// to 63 bytes. +// +// Also, if you have custom-compiled PostgreSQL and set +// NAMEDATALEN to a different value, obviously that number of +// bytes applies, rather than the default 63. 
+type Name Text + +func (dst *Name) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +func (dst *Name) Get() interface{} { + return (*Text)(dst).Get() +} + +func (src *Name) AssignTo(dst interface{}) error { + return (*Text)(src).AssignTo(dst) +} + +func (dst *Name) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) +} + +func (dst *Name) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeBinary(ci, src) +} + +func (src *Name) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Text)(src).EncodeText(ci, buf) +} + +func (src *Name) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Text)(src).EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Name) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Name) Value() (driver.Value, error) { + return (*Text)(src).Value() +} diff --git a/pgtype/name_test.go b/pgtype/name_test.go new file mode 100644 index 00000000..ec0820c4 --- /dev/null +++ b/pgtype/name_test.go @@ -0,0 +1,98 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestNameTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "name", []interface{}{ + &pgtype.Name{String: "", Status: pgtype.Present}, + &pgtype.Name{String: "foo", Status: pgtype.Present}, + &pgtype.Name{Status: pgtype.Null}, + }) +} + +func TestNameSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Name + }{ + {source: "foo", result: pgtype.Name{String: "foo", Status: pgtype.Present}}, + {source: _string("bar"), result: pgtype.Name{String: "bar", Status: pgtype.Present}}, + {source: (*string)(nil), result: pgtype.Name{Status: pgtype.Null}}, + } + + for i, tt := range successfulTests { + var d pgtype.Name + err := 
d.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if d != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestNameAssignTo(t *testing.T) { + var s string + var ps *string + + simpleTests := []struct { + src pgtype.Name + dst interface{} + expected interface{} + }{ + {src: pgtype.Name{String: "foo", Status: pgtype.Present}, dst: &s, expected: "foo"}, + {src: pgtype.Name{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Name + dst interface{} + expected interface{} + }{ + {src: pgtype.Name{String: "foo", Status: pgtype.Present}, dst: &ps, expected: "foo"}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Name + dst interface{} + }{ + {src: pgtype.Name{Status: pgtype.Null}, dst: &s}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/numeric.go b/pgtype/numeric.go new file mode 100644 index 00000000..fded6359 --- /dev/null +++ b/pgtype/numeric.go @@ -0,0 +1,596 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "math" + "math/big" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +// PostgreSQL 
internal numeric storage uses 16-bit "digits" with base of 10,000 +const nbase = 10000 + +var big0 *big.Int = big.NewInt(0) +var big10 *big.Int = big.NewInt(10) +var big100 *big.Int = big.NewInt(100) +var big1000 *big.Int = big.NewInt(1000) + +var bigMaxInt8 *big.Int = big.NewInt(math.MaxInt8) +var bigMinInt8 *big.Int = big.NewInt(math.MinInt8) +var bigMaxInt16 *big.Int = big.NewInt(math.MaxInt16) +var bigMinInt16 *big.Int = big.NewInt(math.MinInt16) +var bigMaxInt32 *big.Int = big.NewInt(math.MaxInt32) +var bigMinInt32 *big.Int = big.NewInt(math.MinInt32) +var bigMaxInt64 *big.Int = big.NewInt(math.MaxInt64) +var bigMinInt64 *big.Int = big.NewInt(math.MinInt64) +var bigMaxInt *big.Int = big.NewInt(int64(maxInt)) +var bigMinInt *big.Int = big.NewInt(int64(minInt)) + +var bigMaxUint8 *big.Int = big.NewInt(math.MaxUint8) +var bigMaxUint16 *big.Int = big.NewInt(math.MaxUint16) +var bigMaxUint32 *big.Int = big.NewInt(math.MaxUint32) +var bigMaxUint64 *big.Int = (&big.Int{}).SetUint64(uint64(math.MaxUint64)) +var bigMaxUint *big.Int = (&big.Int{}).SetUint64(uint64(maxUint)) + +var bigNBase *big.Int = big.NewInt(nbase) +var bigNBaseX2 *big.Int = big.NewInt(nbase * nbase) +var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase) +var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase) + +type Numeric struct { + Int *big.Int + Exp int32 + Status Status +} + +func (dst *Numeric) Set(src interface{}) error { + if src == nil { + *dst = Numeric{Status: Null} + return nil + } + + switch value := src.(type) { + case float32: + num, exp, err := parseNumericString(strconv.FormatFloat(float64(value), 'f', -1, 64)) + if err != nil { + return err + } + *dst = Numeric{Int: num, Exp: exp, Status: Present} + case float64: + num, exp, err := parseNumericString(strconv.FormatFloat(value, 'f', -1, 64)) + if err != nil { + return err + } + *dst = Numeric{Int: num, Exp: exp, Status: Present} + case int8: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case 
uint8: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case int16: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case uint16: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case int32: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case uint32: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case int64: + *dst = Numeric{Int: big.NewInt(value), Status: Present} + case uint64: + *dst = Numeric{Int: (&big.Int{}).SetUint64(value), Status: Present} + case int: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case uint: + *dst = Numeric{Int: (&big.Int{}).SetUint64(uint64(value)), Status: Present} + case string: + num, exp, err := parseNumericString(value) + if err != nil { + return err + } + *dst = Numeric{Int: num, Exp: exp, Status: Present} + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Numeric", value) + } + + return nil +} + +func (dst *Numeric) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Numeric) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *float32: + f, err := src.toFloat64() + if err != nil { + return err + } + return float64AssignTo(f, src.Status, dst) + case *float64: + f, err := src.toFloat64() + if err != nil { + return err + } + return float64AssignTo(f, src.Status, dst) + case *int: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt) > 0 { + return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt) < 0 { + return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = int(normalizedInt.Int64()) + case *int8: + normalizedInt, err := 
src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt8) > 0 { + return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt8) < 0 { + return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = int8(normalizedInt.Int64()) + case *int16: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt16) > 0 { + return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt16) < 0 { + return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = int16(normalizedInt.Int64()) + case *int32: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt32) > 0 { + return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt32) < 0 { + return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = int32(normalizedInt.Int64()) + case *int64: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt64) > 0 { + return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt64) < 0 { + return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = normalizedInt.Int64() + case *uint: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint) > 0 { + return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = uint(normalizedInt.Uint64()) + case *uint8: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return 
errors.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint8) > 0 { + return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = uint8(normalizedInt.Uint64()) + case *uint16: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint16) > 0 { + return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = uint16(normalizedInt.Uint64()) + case *uint32: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint32) > 0 { + return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = uint32(normalizedInt.Uint64()) + case *uint64: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint64) > 0 { + return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = normalizedInt.Uint64() + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return nil +} + +func (dst *Numeric) toBigInt() (*big.Int, error) { + if dst.Exp == 0 { + return dst.Int, nil + } + + num := &big.Int{} + num.Set(dst.Int) + if dst.Exp > 0 { + mul := &big.Int{} + mul.Exp(big10, big.NewInt(int64(dst.Exp)), nil) + num.Mul(num, mul) + return num, nil + } + + div := &big.Int{} + div.Exp(big10, big.NewInt(int64(-dst.Exp)), nil) + remainder := &big.Int{} + num.DivMod(num, div, remainder) + if remainder.Cmp(big0) != 0 { + return nil, errors.Errorf("cannot 
convert %v to integer", dst) + } + return num, nil +} + +func (src *Numeric) toFloat64() (float64, error) { + f, err := strconv.ParseFloat(src.Int.String(), 64) + if err != nil { + return 0, err + } + if src.Exp > 0 { + for i := 0; i < int(src.Exp); i++ { + f *= 10 + } + } else if src.Exp < 0 { + for i := 0; i > int(src.Exp); i-- { + f /= 10 + } + } + return f, nil +} + +func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Numeric{Status: Null} + return nil + } + + num, exp, err := parseNumericString(string(src)) + if err != nil { + return err + } + + *dst = Numeric{Int: num, Exp: exp, Status: Present} + return nil +} + +func parseNumericString(str string) (n *big.Int, exp int32, err error) { + parts := strings.SplitN(str, ".", 2) + digits := strings.Join(parts, "") + + if len(parts) > 1 { + exp = int32(-len(parts[1])) + } else { + for len(digits) > 1 && digits[len(digits)-1] == '0' { + digits = digits[:len(digits)-1] + exp++ + } + } + + accum := &big.Int{} + if _, ok := accum.SetString(digits, 10); !ok { + return nil, 0, errors.Errorf("%s is not a number", str) + } + + return accum, exp, nil +} + +func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Numeric{Status: Null} + return nil + } + + if len(src) < 8 { + return errors.Errorf("numeric incomplete %v", src) + } + + rp := 0 + ndigits := int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + if ndigits == 0 { + *dst = Numeric{Int: big.NewInt(0), Status: Present} + return nil + } + + weight := int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + sign := int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + dscale := int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + if len(src[rp:]) < int(ndigits)*2 { + return errors.Errorf("numeric incomplete %v", src) + } + + accum := &big.Int{} + + for i := 0; i < int(ndigits+3)/4; i++ { + int64accum, bytesRead, digitsRead := nbaseDigitsToInt64(src[rp:]) + rp += bytesRead + + if i > 0 { + var 
mul *big.Int + switch digitsRead { + case 1: + mul = bigNBase + case 2: + mul = bigNBaseX2 + case 3: + mul = bigNBaseX3 + case 4: + mul = bigNBaseX4 + default: + return errors.Errorf("invalid digitsRead: %d (this can't happen)", digitsRead) + } + accum.Mul(accum, mul) + } + + accum.Add(accum, big.NewInt(int64accum)) + } + + exp := (int32(weight) - int32(ndigits) + 1) * 4 + + if dscale > 0 { + fracNBaseDigits := ndigits - weight - 1 + fracDecimalDigits := fracNBaseDigits * 4 + + if dscale > fracDecimalDigits { + multCount := int(dscale - fracDecimalDigits) + for i := 0; i < multCount; i++ { + accum.Mul(accum, big10) + exp-- + } + } else if dscale < fracDecimalDigits { + divCount := int(fracDecimalDigits - dscale) + for i := 0; i < divCount; i++ { + accum.Div(accum, big10) + exp++ + } + } + } + + reduced := &big.Int{} + remainder := &big.Int{} + if exp >= 0 { + for { + reduced.DivMod(accum, big10, remainder) + if remainder.Cmp(big0) != 0 { + break + } + accum.Set(reduced) + exp++ + } + } + + if sign != 0 { + accum.Neg(accum) + } + + *dst = Numeric{Int: accum, Exp: exp, Status: Present} + + return nil + +} + +func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) { + digits := len(src) / 2 + if digits > 4 { + digits = 4 + } + + rp := 0 + + for i := 0; i < digits; i++ { + if i > 0 { + accum *= nbase + } + accum += int64(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + } + + return accum, rp, digits +} + +func (src *Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = append(buf, src.Int.String()...) + buf = append(buf, 'e') + buf = append(buf, strconv.FormatInt(int64(src.Exp), 10)...) 
+ return buf, nil +} + +func (src *Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var sign int16 + if src.Int.Cmp(big0) < 0 { + sign = 16384 + } + + absInt := &big.Int{} + wholePart := &big.Int{} + fracPart := &big.Int{} + remainder := &big.Int{} + absInt.Abs(src.Int) + + // Normalize absInt and exp to where exp is always a multiple of 4. This makes + // converting to 16-bit base 10,000 digits easier. + var exp int32 + switch src.Exp % 4 { + case 1, -3: + exp = src.Exp - 1 + absInt.Mul(absInt, big10) + case 2, -2: + exp = src.Exp - 2 + absInt.Mul(absInt, big100) + case 3, -1: + exp = src.Exp - 3 + absInt.Mul(absInt, big1000) + default: + exp = src.Exp + } + + if exp < 0 { + divisor := &big.Int{} + divisor.Exp(big10, big.NewInt(int64(-exp)), nil) + wholePart.DivMod(absInt, divisor, fracPart) + } else { + wholePart = absInt + } + + var wholeDigits, fracDigits []int16 + + for wholePart.Cmp(big0) != 0 { + wholePart.DivMod(wholePart, bigNBase, remainder) + wholeDigits = append(wholeDigits, int16(remainder.Int64())) + } + + for fracPart.Cmp(big0) != 0 { + fracPart.DivMod(fracPart, bigNBase, remainder) + fracDigits = append(fracDigits, int16(remainder.Int64())) + } + + buf = pgio.AppendInt16(buf, int16(len(wholeDigits)+len(fracDigits))) + + var weight int16 + if len(wholeDigits) > 0 { + weight = int16(len(wholeDigits) - 1) + if exp > 0 { + weight += int16(exp / 4) + } + } else { + weight = int16(exp/4) - 1 + int16(len(fracDigits)) + } + buf = pgio.AppendInt16(buf, weight) + + buf = pgio.AppendInt16(buf, sign) + + var dscale int16 + if src.Exp < 0 { + dscale = int16(-src.Exp) + } + buf = pgio.AppendInt16(buf, dscale) + + for i := len(wholeDigits) - 1; i >= 0; i-- { + buf = pgio.AppendInt16(buf, wholeDigits[i]) + } + + for i := len(fracDigits) - 1; i >= 0; i-- { + buf = pgio.AppendInt16(buf, fracDigits[i]) + } + + return buf, nil +} + +// Scan 
implements the database/sql Scanner interface. +func (dst *Numeric) Scan(src interface{}) error { + if src == nil { + *dst = Numeric{Status: Null} + return nil + } + + switch src := src.(type) { + case float64: + // TODO + // *dst = Numeric{Float: src, Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Numeric) Value() (driver.Value, error) { + switch src.Status { + case Present: + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + + return string(buf), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/numeric_array.go b/pgtype/numeric_array.go new file mode 100644 index 00000000..6dfbe5e3 --- /dev/null +++ b/pgtype/numeric_array.go @@ -0,0 +1,322 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type NumericArray struct { + Elements []Numeric + Dimensions []ArrayDimension + Status Status +} + +func (dst *NumericArray) Set(src interface{}) error { + switch value := src.(type) { + + case []float32: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []float64: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range 
value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Numeric", value) + } + + return nil +} + +func (dst *NumericArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *NumericArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + + case *[]float32: + *v = make([]float32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]float64: + *v = make([]float64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *NumericArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = NumericArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Numeric + + if len(uta.Elements) > 0 { + elements = make([]Numeric, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Numeric + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = NumericArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst 
*NumericArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = NumericArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = NumericArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Numeric, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = NumericArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *NumericArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. 
+ dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src *NumericArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("numeric"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, errors.Errorf("unable to find oid for type name %v", "numeric") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. 
+func (dst *NumericArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *NumericArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/numeric_array_test.go b/pgtype/numeric_array_test.go new file mode 100644 index 00000000..25531840 --- /dev/null +++ b/pgtype/numeric_array_test.go @@ -0,0 +1,160 @@ +package pgtype_test + +import ( + "math/big" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestNumericArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "numeric[]", []interface{}{ + &pgtype.NumericArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}, + pgtype.Numeric{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.NumericArray{Status: pgtype.Null}, + &pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}, + pgtype.Numeric{Int: big.NewInt(2), Status: pgtype.Present}, + pgtype.Numeric{Int: big.NewInt(3), Status: pgtype.Present}, + pgtype.Numeric{Int: big.NewInt(4), Status: pgtype.Present}, + pgtype.Numeric{Status: pgtype.Null}, + pgtype.Numeric{Int: big.NewInt(6), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: 
pgtype.Present, + }, + &pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}, + pgtype.Numeric{Int: big.NewInt(2), Status: pgtype.Present}, + pgtype.Numeric{Int: big.NewInt(3), Status: pgtype.Present}, + pgtype.Numeric{Int: big.NewInt(4), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestNumericArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.NumericArray + }{ + { + source: []float32{1}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []float64{1}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]float32)(nil)), + result: pgtype.NumericArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.NumericArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestNumericArrayAssignTo(t *testing.T) { + var float32Slice []float32 + var float64Slice []float64 + + simpleTests := []struct { + src pgtype.NumericArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float32Slice, + expected: []float32{1}, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: 
big.NewInt(1), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float64Slice, + expected: []float64{1}, + }, + { + src: pgtype.NumericArray{Status: pgtype.Null}, + dst: &float32Slice, + expected: (([]float32)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.NumericArray + dst interface{} + }{ + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float32Slice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go new file mode 100644 index 00000000..5f3a3416 --- /dev/null +++ b/pgtype/numeric_test.go @@ -0,0 +1,319 @@ +package pgtype_test + +import ( + "math/big" + "math/rand" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +// For test purposes only. Note that it does not normalize values. e.g. (Int: 1, Exp: 3) will not equal (Int: 1000, Exp: 0) +func numericEqual(left, right *pgtype.Numeric) bool { + return left.Status == right.Status && + left.Exp == right.Exp && + ((left.Int == nil && right.Int == nil) || (left.Int != nil && right.Int != nil && left.Int.Cmp(right.Int) == 0)) +} + +// For test purposes only. 
+func numericNormalizedEqual(left, right *pgtype.Numeric) bool { + if left.Status != right.Status { + return false + } + + normLeft := &pgtype.Numeric{Int: (&big.Int{}).Set(left.Int), Status: left.Status} + normRight := &pgtype.Numeric{Int: (&big.Int{}).Set(right.Int), Status: right.Status} + + if left.Exp < right.Exp { + mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(right.Exp-left.Exp)), nil) + normRight.Int.Mul(normRight.Int, mul) + } else if left.Exp > right.Exp { + mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(left.Exp-right.Exp)), nil) + normLeft.Int.Mul(normLeft.Int, mul) + } + + return normLeft.Int.Cmp(normRight.Int) == 0 +} + +func mustParseBigInt(t *testing.T, src string) *big.Int { + i := &big.Int{} + if _, ok := i.SetString(src, 10); !ok { + t.Fatalf("could not parse big.Int: %s", src) + } + return i +} + +func TestNumericNormalize(t *testing.T) { + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + { + SQL: "select '0'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present}, + }, + { + SQL: "select '1'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present}, + }, + { + SQL: "select '10.00'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(1000), Exp: -2, Status: pgtype.Present}, + }, + { + SQL: "select '1e-3'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: -3, Status: pgtype.Present}, + }, + { + SQL: "select '-1'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present}, + }, + { + SQL: "select '10000'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: 4, Status: pgtype.Present}, + }, + { + SQL: "select '3.14'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Status: pgtype.Present}, + }, + { + SQL: "select '1.1'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(11), Exp: -1, Status: pgtype.Present}, + }, + { + SQL: "select '100010001'::numeric", + Value: &pgtype.Numeric{Int: 
big.NewInt(100010001), Exp: 0, Status: pgtype.Present}, + }, + { + SQL: "select '100010001.0001'::numeric", + Value: &pgtype.Numeric{Int: big.NewInt(1000100010001), Exp: -4, Status: pgtype.Present}, + }, + { + SQL: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric", + Value: &pgtype.Numeric{ + Int: mustParseBigInt(t, "423723478923478928934789237432487213832189417894318904389012483210893443219085471578891547854892438945012347981"), + Exp: -41, + Status: pgtype.Present, + }, + }, + { + SQL: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric", + Value: &pgtype.Numeric{ + Int: mustParseBigInt(t, "8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), + Exp: -196, + Status: pgtype.Present, + }, + }, + { + SQL: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric", + Value: &pgtype.Numeric{ + Int: mustParseBigInt(t, "123"), + Exp: -186, + Status: pgtype.Present, + }, + }, + }) +} + +func TestNumericTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ + &pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(1), Exp: 6, Status: pgtype.Present}, + + // preserves significant zeroes + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -1, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(10000000), 
Exp: -2, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -3, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -4, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -5, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -6, Status: pgtype.Present}, + + &pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(123), Exp: -7, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(123), Exp: -8, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(123), Exp: -9, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(123), Exp: -1500, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "2437"), Exp: 23790, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "243723409723490243842378942378901237502734019231380123"), Exp: 23790, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 80, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "3723409723490243842378942378901237502734019231380123"), Exp: 81, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "723409723490243842378942378901237502734019231380123"), Exp: 82, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "23409723490243842378942378901237502734019231380123"), Exp: 83, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "3409723490243842378942378901237502734019231380123"), Exp: 84, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "913423409823409243892349028349023482934092340892390101"), Exp: -14021, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -90, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "3423409823409243892349028349023482934092340892390101"), Exp: -91, 
Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "423409823409243892349028349023482934092340892390101"), Exp: -92, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "23409823409243892349028349023482934092340892390101"), Exp: -93, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "3409823409243892349028349023482934092340892390101"), Exp: -94, Status: pgtype.Present}, + &pgtype.Numeric{Status: pgtype.Null}, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.Numeric) + b := bb.(pgtype.Numeric) + + return numericEqual(&a, &b) + }) + +} + +func TestNumericTranscodeFuzz(t *testing.T) { + r := rand.New(rand.NewSource(0)) + max := &big.Int{} + max.SetString("9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10) + + values := make([]interface{}, 0, 2000) + for i := 0; i < 10; i++ { + for j := -50; j < 50; j++ { + num := (&big.Int{}).Rand(r, max) + negNum := &big.Int{} + negNum.Neg(num) + values = append(values, &pgtype.Numeric{Int: num, Exp: int32(j), Status: pgtype.Present}) + values = append(values, &pgtype.Numeric{Int: negNum, Exp: int32(j), Status: pgtype.Present}) + } + } + + testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", values, + func(aa, bb interface{}) bool { + a := aa.(pgtype.Numeric) + b := bb.(pgtype.Numeric) + + return numericNormalizedEqual(&a, &b) + }) +} + +func TestNumericSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result *pgtype.Numeric + }{ + {source: float32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: float64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: int8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: int16(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: int32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: int64(1), 
result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: int8(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, + {source: int16(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, + {source: int32(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, + {source: int64(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, + {source: uint8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: uint16(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: uint32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: uint64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: "1", result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: _int8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: float64(1000), result: &pgtype.Numeric{Int: big.NewInt(1), Exp: 3, Status: pgtype.Present}}, + {source: float64(1234), result: &pgtype.Numeric{Int: big.NewInt(1234), Exp: 0, Status: pgtype.Present}}, + {source: float64(12345678900), result: &pgtype.Numeric{Int: big.NewInt(123456789), Exp: 2, Status: pgtype.Present}}, + {source: float64(12345.678901), result: &pgtype.Numeric{Int: big.NewInt(12345678901), Exp: -6, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + r := &pgtype.Numeric{} + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !numericEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestNumericAssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + var f32 float32 + var 
f64 float64 + var pf32 *float32 + var pf64 *float64 + + simpleTests := []struct { + src *pgtype.Numeric + dst interface{} + expected interface{} + }{ + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &f32, expected: float32(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &f64, expected: float64(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: -1, Status: pgtype.Present}, dst: &f32, expected: float32(4.2)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: -1, Status: pgtype.Present}, dst: &f64, expected: float64(4.2)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: 3, Status: pgtype.Present}, dst: &i64, expected: int64(42000)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := 
tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src *pgtype.Numeric + dst interface{} + expected interface{} + }{ + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src *pgtype.Numeric + dst interface{} + }{ + {src: &pgtype.Numeric{Int: big.NewInt(150), Status: pgtype.Present}, dst: &i8}, + {src: &pgtype.Numeric{Int: big.NewInt(40000), Status: pgtype.Present}, dst: &i16}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui8}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui16}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui32}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui64}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui}, + {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &i32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/numrange.go b/pgtype/numrange.go new file mode 100644 index 00000000..aaed62ce --- /dev/null +++ b/pgtype/numrange.go @@ -0,0 +1,250 @@ +package pgtype + +import ( + 
"database/sql/driver" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type Numrange struct { + Lower Numeric + Upper Numeric + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *Numrange) Set(src interface{}) error { + return errors.Errorf("cannot convert %v to Numrange", src) +} + +func (dst *Numrange) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Numrange) AssignTo(dst interface{}) error { + return errors.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Numrange) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Numrange{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Numrange{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Numrange) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Numrange{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Numrange{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + 
+func (src Numrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + buf = append(buf, '(') + case Inclusive: + buf = append(buf, '[') + case Empty: + return append(buf, "empty"...), nil + default: + return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) + } + + var err error + + if src.LowerType != Unbounded { + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + buf = append(buf, ',') + + if src.UpperType != Unbounded { + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + buf = append(buf, ')') + case Inclusive: + buf = append(buf, ']') + default: + return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) + } + + return buf, nil +} + +func (src Numrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + return append(buf, emptyMask), nil + default: + return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) + } + + buf = append(buf, rangeType) + + var err error + + if src.LowerType != Unbounded { + sp := 
len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + if src.UpperType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Numrange) Scan(src interface{}) error { + if src == nil { + *dst = Numrange{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src Numrange) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/numrange_test.go b/pgtype/numrange_test.go new file mode 100644 index 00000000..32267c86 --- /dev/null +++ b/pgtype/numrange_test.go @@ -0,0 +1,34 @@ +package pgtype_test + +import ( + "math/big" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestNumrangeTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "numrange", []interface{}{ + &pgtype.Numrange{ + LowerType: pgtype.Empty, + UpperType: pgtype.Empty, + Status: pgtype.Present, + }, + &pgtype.Numrange{ + Lower: pgtype.Numeric{Int: big.NewInt(-543), Exp: 3, Status: pgtype.Present}, + Upper: pgtype.Numeric{Int: big.NewInt(342), Exp: 1, Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + &pgtype.Numrange{ + Lower: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Status: pgtype.Present}, + Upper: pgtype.Numeric{Int: big.NewInt(-5), Exp: 0, Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + &pgtype.Numrange{Status: pgtype.Null}, + }) +} diff --git a/pgtype/oid.go b/pgtype/oid.go new file mode 100644 index 00000000..59370d66 --- /dev/null +++ b/pgtype/oid.go @@ -0,0 +1,81 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "strconv" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +// OID (Object Identifier Type) is, according to +// https://www.postgresql.org/docs/current/static/datatype-oid.html, used +// internally by PostgreSQL as a primary key for various system tables. It is +// currently implemented as an unsigned four-byte integer. Its definition can be +// found in src/include/postgres_ext.h in the PostgreSQL sources. Because it is +// so frequently required to be in a NOT NULL condition OID cannot be NULL. To +// allow for NULL OIDs use OIDValue. 
+type OID uint32 + +func (dst *OID) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + return errors.Errorf("cannot decode nil into OID") + } + + n, err := strconv.ParseUint(string(src), 10, 32) + if err != nil { + return err + } + + *dst = OID(n) + return nil +} + +func (dst *OID) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + return errors.Errorf("cannot decode nil into OID") + } + + if len(src) != 4 { + return errors.Errorf("invalid length: %v", len(src)) + } + + n := binary.BigEndian.Uint32(src) + *dst = OID(n) + return nil +} + +func (src OID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return append(buf, strconv.FormatUint(uint64(src), 10)...), nil +} + +func (src OID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return pgio.AppendUint32(buf, uint32(src)), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *OID) Scan(src interface{}) error { + if src == nil { + return errors.Errorf("cannot scan NULL into %T", src) + } + + switch src := src.(type) { + case int64: + *dst = OID(src) + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src OID) Value() (driver.Value, error) { + return int64(src), nil +} diff --git a/pgtype/oid_value.go b/pgtype/oid_value.go new file mode 100644 index 00000000..7eae4bf1 --- /dev/null +++ b/pgtype/oid_value.go @@ -0,0 +1,55 @@ +package pgtype + +import ( + "database/sql/driver" +) + +// OIDValue (Object Identifier Type) is, according to +// https://www.postgresql.org/docs/current/static/datatype-OIDValue.html, used +// internally by PostgreSQL as a primary key for various system tables. It is +// currently implemented as an unsigned four-byte integer. 
Its definition can be +// found in src/include/postgres_ext.h in the PostgreSQL sources. +type OIDValue pguint32 + +// Set converts from src to dst. Note that as OIDValue is not a general +// number type Set does not do automatic type conversion as other number +// types do. +func (dst *OIDValue) Set(src interface{}) error { + return (*pguint32)(dst).Set(src) +} + +func (dst *OIDValue) Get() interface{} { + return (*pguint32)(dst).Get() +} + +// AssignTo assigns from src to dst. Note that as OIDValue is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *OIDValue) AssignTo(dst interface{}) error { + return (*pguint32)(src).AssignTo(dst) +} + +func (dst *OIDValue) DecodeText(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeText(ci, src) +} + +func (dst *OIDValue) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeBinary(ci, src) +} + +func (src *OIDValue) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*pguint32)(src).EncodeText(ci, buf) +} + +func (src *OIDValue) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*pguint32)(src).EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *OIDValue) Scan(src interface{}) error { + return (*pguint32)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src *OIDValue) Value() (driver.Value, error) { + return (*pguint32)(src).Value() +} diff --git a/pgtype/oid_value_test.go b/pgtype/oid_value_test.go new file mode 100644 index 00000000..f5ff16cf --- /dev/null +++ b/pgtype/oid_value_test.go @@ -0,0 +1,95 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestOIDValueTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "oid", []interface{}{ + &pgtype.OIDValue{Uint: 42, Status: pgtype.Present}, + &pgtype.OIDValue{Status: pgtype.Null}, + }) +} + +func TestOIDValueSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.OIDValue + }{ + {source: uint32(1), result: pgtype.OIDValue{Uint: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.OIDValue + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestOIDValueAssignTo(t *testing.T) { + var ui32 uint32 + var pui32 *uint32 + + simpleTests := []struct { + src pgtype.OIDValue + dst interface{} + expected interface{} + }{ + {src: pgtype.OIDValue{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.OIDValue{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.OIDValue + dst interface{} + expected interface{} + }{ + {src: pgtype.OIDValue{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + } + + for i, tt := range pointerAllocTests { + 
err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.OIDValue + dst interface{} + }{ + {src: pgtype.OIDValue{Status: pgtype.Null}, dst: &ui32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/path.go b/pgtype/path.go new file mode 100644 index 00000000..aa0cee8e --- /dev/null +++ b/pgtype/path.go @@ -0,0 +1,193 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type Path struct { + P []Vec2 + Closed bool + Status Status +} + +func (dst *Path) Set(src interface{}) error { + return errors.Errorf("cannot convert %v to Path", src) +} + +func (dst *Path) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Path) AssignTo(dst interface{}) error { + return errors.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Path) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Path{Status: Null} + return nil + } + + if len(src) < 7 { + return errors.Errorf("invalid length for Path: %v", len(src)) + } + + closed := src[0] == '(' + points := make([]Vec2, 0) + + str := string(src[2:]) + + for { + end := strings.IndexByte(str, ',') + x, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + points = append(points, Vec2{x, y}) + + if end+3 < len(str) { + str = 
str[end+3:] + } else { + break + } + } + + *dst = Path{P: points, Closed: closed, Status: Present} + return nil +} + +func (dst *Path) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Path{Status: Null} + return nil + } + + if len(src) < 5 { + return errors.Errorf("invalid length for Path: %v", len(src)) + } + + closed := src[0] == 1 + pointCount := int(binary.BigEndian.Uint32(src[1:])) + + rp := 5 + + if 5+pointCount*16 != len(src) { + return errors.Errorf("invalid length for Path with %d points: %v", pointCount, len(src)) + } + + points := make([]Vec2, pointCount) + for i := 0; i < len(points); i++ { + x := binary.BigEndian.Uint64(src[rp:]) + rp += 8 + y := binary.BigEndian.Uint64(src[rp:]) + rp += 8 + points[i] = Vec2{math.Float64frombits(x), math.Float64frombits(y)} + } + + *dst = Path{ + P: points, + Closed: closed, + Status: Present, + } + return nil +} + +func (src *Path) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var startByte, endByte byte + if src.Closed { + startByte = '(' + endByte = ')' + } else { + startByte = '[' + endByte = ']' + } + buf = append(buf, startByte) + + for i, p := range src.P { + if i > 0 { + buf = append(buf, ',') + } + buf = append(buf, fmt.Sprintf(`(%f,%f)`, p.X, p.Y)...) + } + + return append(buf, endByte), nil +} + +func (src *Path) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var closeByte byte + if src.Closed { + closeByte = 1 + } + buf = append(buf, closeByte) + + buf = pgio.AppendInt32(buf, int32(len(src.P))) + + for _, p := range src.P { + buf = pgio.AppendUint64(buf, math.Float64bits(p.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(p.Y)) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. 
// Scan implements the database/sql Scanner interface.
func (dst *Path) Scan(src interface{}) error {
	if src == nil {
		*dst = Path{Status: Null}
		return nil
	}

	switch src := src.(type) {
	case string:
		return dst.DecodeText(nil, []byte(src))
	case []byte:
		// Copy src because DecodeText takes ownership of its argument.
		srcCopy := make([]byte, len(src))
		copy(srcCopy, src)
		return dst.DecodeText(nil, srcCopy)
	}

	return errors.Errorf("cannot scan %T", src)
}

// Value implements the database/sql/driver Valuer interface.
func (src *Path) Value() (driver.Value, error) {
	return EncodeValueText(src)
}
diff --git a/pgtype/path_test.go b/pgtype/path_test.go
new file mode 100644
index 00000000..d213a1b4
--- /dev/null
+++ b/pgtype/path_test.go
@@ -0,0 +1,29 @@
package pgtype_test

import (
	"testing"

	"github.com/jackc/pgx/pgtype"
	"github.com/jackc/pgx/pgtype/testutil"
)

// TestPathTranscode round-trips open, closed, and NULL paths through the
// database.
func TestPathTranscode(t *testing.T) {
	testutil.TestSuccessfulTranscode(t, "path", []interface{}{
		&pgtype.Path{
			P:      []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}},
			Closed: false,
			Status: pgtype.Present,
		},
		&pgtype.Path{
			P:      []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}, {23.1, 9.34}},
			Closed: true,
			Status: pgtype.Present,
		},
		&pgtype.Path{
			P:      []pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}},
			Closed: true,
			Status: pgtype.Present,
		},
		&pgtype.Path{Status: pgtype.Null},
	})
}
diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go
new file mode 100644
index 00000000..6f8e7986
--- /dev/null
+++ b/pgtype/pgtype.go
@@ -0,0 +1,271 @@
package pgtype

import (
	"reflect"

	"github.com/pkg/errors"
)

// PostgreSQL oids for common types
const (
	BoolOID             = 16
	ByteaOID            = 17
	CharOID             = 18
	NameOID             = 19
	Int8OID             = 20
	Int2OID             = 21
	Int4OID             = 23
	TextOID             = 25
	OIDOID              = 26
	TIDOID              = 27
	XIDOID              = 28
	CIDOID              = 29
	JSONOID             = 114
	CIDROID             = 650
	CIDRArrayOID        = 651
	Float4OID           = 700
	Float8OID           = 701
	UnknownOID          = 705
	InetOID             = 869
	BoolArrayOID        = 1000
	Int2ArrayOID        = 1005
	Int4ArrayOID        = 1007
	TextArrayOID        = 1009
	ByteaArrayOID       = 1001
	VarcharArrayOID     = 1015
	Int8ArrayOID        = 1016
	Float4ArrayOID      = 1021
	Float8ArrayOID      = 1022
	ACLItemOID          = 1033
	ACLItemArrayOID     = 1034
	InetArrayOID        = 1041
	VarcharOID          = 1043
	DateOID             = 1082
	TimestampOID        = 1114
	TimestampArrayOID   = 1115
	DateArrayOID        = 1182
	TimestamptzOID      = 1184
	TimestamptzArrayOID = 1185
	RecordOID           = 2249
	UUIDOID             = 2950
	JSONBOID            = 3802
)

// Status represents the state of a Value: Undefined (the zero value), Null,
// or Present.
type Status byte

const (
	Undefined Status = iota
	Null
	Present
)

// InfinityModifier marks date/timestamp values that are +/- infinity rather
// than a concrete instant.
type InfinityModifier int8

const (
	Infinity         InfinityModifier = 1
	None             InfinityModifier = 0
	NegativeInfinity InfinityModifier = -Infinity
)

// String renders the modifier in the same spelling PostgreSQL uses.
func (im InfinityModifier) String() string {
	switch im {
	case None:
		return "none"
	case Infinity:
		return "infinity"
	case NegativeInfinity:
		return "-infinity"
	default:
		return "invalid"
	}
}

// Value is the interface implemented by all pgtype types.
type Value interface {
	// Set converts and assigns src to itself.
	Set(src interface{}) error

	// Get returns the simplest representation of Value. If the Value is Null or
	// Undefined that is the return value. If no simpler representation is
	// possible, then Get() returns Value.
	Get() interface{}

	// AssignTo converts and assigns the Value to dst. It MUST make a deep copy of
	// any reference types.
	AssignTo(dst interface{}) error
}

type BinaryDecoder interface {
	// DecodeBinary decodes src into BinaryDecoder. If src is nil then the
	// original SQL value is NULL. BinaryDecoder takes ownership of src. The
	// caller MUST not use it again.
	DecodeBinary(ci *ConnInfo, src []byte) error
}

type TextDecoder interface {
	// DecodeText decodes src into TextDecoder. If src is nil then the original
	// SQL value is NULL. TextDecoder takes ownership of src. The caller MUST not
	// use it again.
	DecodeText(ci *ConnInfo, src []byte) error
}

// BinaryEncoder is implemented by types that can encode themselves into the
// PostgreSQL binary wire format.
type BinaryEncoder interface {
	// EncodeBinary should append the binary format of self to buf. If self is the
	// SQL value NULL then append nothing and return (nil, nil). The caller of
	// EncodeBinary is responsible for writing the correct NULL value or the
	// length of the data written.
	EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error)
}

// TextEncoder is implemented by types that can encode themselves into the
// PostgreSQL text wire format.
type TextEncoder interface {
	// EncodeText should append the text format of self to buf. If self is the
	// SQL value NULL then append nothing and return (nil, nil). The caller of
	// EncodeText is responsible for writing the correct NULL value or the
	// length of the data written.
	EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, err error)
}

var errUndefined = errors.New("cannot encode status undefined")
var errBadStatus = errors.New("invalid status")

// DataType pairs a Value implementation with its PostgreSQL name and OID.
type DataType struct {
	Value Value
	Name  string
	OID   OID
}

// ConnInfo maps between PostgreSQL type OIDs/names and the Value
// implementations that handle them.
type ConnInfo struct {
	oidToDataType         map[OID]*DataType
	nameToDataType        map[string]*DataType
	reflectTypeToDataType map[reflect.Type]*DataType
}

func NewConnInfo() *ConnInfo {
	return &ConnInfo{
		oidToDataType:         make(map[OID]*DataType, 256),
		nameToDataType:        make(map[string]*DataType, 256),
		reflectTypeToDataType: make(map[reflect.Type]*DataType, 256),
	}
}

// InitializeDataTypes registers a fresh Value for each name/OID pair, falling
// back to GenericText for names without a dedicated implementation.
func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]OID) {
	for name, oid := range nameOIDs {
		var value Value
		if t, ok := nameValues[name]; ok {
			// Allocate a new instance of the registered prototype's type.
			value = reflect.New(reflect.ValueOf(t).Elem().Type()).Interface().(Value)
		} else {
			value = &GenericText{}
		}
		ci.RegisterDataType(DataType{Value: value, Name: name, OID: oid})
	}
}

// RegisterDataType indexes t by OID, by name, and by the concrete Go type of
// its Value.
func (ci *ConnInfo) RegisterDataType(t DataType) {
	ci.oidToDataType[t.OID] = &t
	ci.nameToDataType[t.Name] = &t
	ci.reflectTypeToDataType[reflect.ValueOf(t.Value).Type()] = &t
}

func (ci *ConnInfo) DataTypeForOID(oid OID) (*DataType, bool) {
	dt, ok := ci.oidToDataType[oid]
	return dt, ok
}

func (ci *ConnInfo) DataTypeForName(name string) (*DataType, bool) {
	dt, ok := ci.nameToDataType[name]
	return dt, ok
}

func (ci *ConnInfo) DataTypeForValue(v Value) (*DataType, bool) {
	dt, ok := ci.reflectTypeToDataType[reflect.ValueOf(v).Type()]
	return dt, ok
}

// DeepCopy makes a deep copy of the ConnInfo.
func (ci *ConnInfo) DeepCopy() *ConnInfo {
	ci2 := &ConnInfo{
		oidToDataType:         make(map[OID]*DataType, len(ci.oidToDataType)),
		nameToDataType:        make(map[string]*DataType, len(ci.nameToDataType)),
		reflectTypeToDataType: make(map[reflect.Type]*DataType, len(ci.reflectTypeToDataType)),
	}

	for _, dt := range ci.oidToDataType {
		ci2.RegisterDataType(DataType{
			Value: reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(Value),
			Name:  dt.Name,
			OID:   dt.OID,
		})
	}

	return ci2
}

// nameValues holds one prototype Value per PostgreSQL type name; see
// InitializeDataTypes.
var nameValues map[string]Value

func init() {
	nameValues = map[string]Value{
		"_aclitem":     &ACLItemArray{},
		"_bool":        &BoolArray{},
		"_bytea":       &ByteaArray{},
		"_cidr":        &CIDRArray{},
		"_date":        &DateArray{},
		"_float4":      &Float4Array{},
		"_float8":      &Float8Array{},
		"_inet":        &InetArray{},
		"_int2":        &Int2Array{},
		"_int4":        &Int4Array{},
		"_int8":        &Int8Array{},
		"_numeric":     &NumericArray{},
		"_text":        &TextArray{},
		"_timestamp":   &TimestampArray{},
		"_timestamptz": &TimestamptzArray{},
		"_varchar":     &VarcharArray{},
		"aclitem":      &ACLItem{},
		"bool":         &Bool{},
		"box":          &Box{},
		"bytea":        &Bytea{},
		"char":         &QChar{},
		"cid":          &CID{},
		"cidr":         &CIDR{},
		"circle":       &Circle{},
		"date":         &Date{},
		"daterange":    &Daterange{},
		"decimal":      &Decimal{},
		"float4":       &Float4{},
		"float8":       &Float8{},
		"hstore":       &Hstore{},
		"inet":         &Inet{},
		"int2":         &Int2{},
		"int4":         &Int4{},
		"int4range":    &Int4range{},
		"int8":         &Int8{},
		"int8range":    &Int8range{},
		"json":         &JSON{},
		"jsonb":        &JSONB{},
		"line":         &Line{},
		"lseg":         &Lseg{},
		"macaddr":      &Macaddr{},
		"name":         &Name{},
		"numeric":      &Numeric{},
		"numrange":     &Numrange{},
		"oid":          &OIDValue{},
		"path":         &Path{},
		"point":        &Point{},
		"polygon":      &Polygon{},
		"record":       &Record{},
		"text":         &Text{},
		"tid":          &TID{},
		"timestamp":    &Timestamp{},
		"timestamptz":  &Timestamptz{},
		"tsrange":      &Tsrange{},
		"tstzrange":    &Tstzrange{},
		"unknown":      &Unknown{},
		"uuid":         &UUID{},
		"varbit":       &Varbit{},
		"varchar":      &Varchar{},
		"xid":          &XID{},
	}
}
diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go
new file mode 100644
index 00000000..f7e743b2
--- /dev/null
+++ b/pgtype/pgtype_test.go
@@ -0,0 +1,39 @@
package pgtype_test

import (
	"net"
	"testing"

	_ "github.com/jackc/pgx/stdlib"
	_ "github.com/lib/pq"
)

// Test for renamed types
type _string string
type _bool bool
type _int8 int8
type _int16 int16
type _int16Slice []int16
type _int32Slice []int32
type _int64Slice []int64
type _float32Slice []float32
type _float64Slice []float64
type _byteSlice []byte

// mustParseCIDR parses s or fails the test.
func mustParseCIDR(t testing.TB, s string) *net.IPNet {
	_, ipnet, err := net.ParseCIDR(s)
	if err != nil {
		t.Fatal(err)
	}

	return ipnet
}

// mustParseMacaddr parses s or fails the test.
func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr {
	addr, err := net.ParseMAC(s)
	if err != nil {
		t.Fatal(err)
	}

	return addr
}
diff --git a/pgtype/pguint32.go b/pgtype/pguint32.go
new file mode 100644
index 00000000..e441a690
--- /dev/null
+++ b/pgtype/pguint32.go
@@ -0,0 +1,162 @@
package pgtype

import (
	"database/sql/driver"
	"encoding/binary"
	"math"
	"strconv"

	"github.com/jackc/pgx/pgio"
	"github.com/pkg/errors"
)

// pguint32 is the core type that is used to implement PostgreSQL types such as
// CID and XID.
type pguint32 struct {
	Uint   uint32
	Status Status
}

// Set converts from src to dst. Note that as pguint32 is not a general
// number type Set does not do automatic type conversion as other number
// types do.
+func (dst *pguint32) Set(src interface{}) error { + switch value := src.(type) { + case int64: + if value < 0 { + return errors.Errorf("%d is less than minimum value for pguint32", value) + } + if value > math.MaxUint32 { + return errors.Errorf("%d is greater than maximum value for pguint32", value) + } + *dst = pguint32{Uint: uint32(value), Status: Present} + case uint32: + *dst = pguint32{Uint: value, Status: Present} + default: + return errors.Errorf("cannot convert %v to pguint32", value) + } + + return nil +} + +func (dst *pguint32) Get() interface{} { + switch dst.Status { + case Present: + return dst.Uint + case Null: + return nil + default: + return dst.Status + } +} + +// AssignTo assigns from src to dst. Note that as pguint32 is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *pguint32) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *uint32: + if src.Status == Present { + *v = src.Uint + } else { + return errors.Errorf("cannot assign %v into %T", src, dst) + } + case **uint32: + if src.Status == Present { + n := src.Uint + *v = &n + } else { + *v = nil + } + } + + return nil +} + +func (dst *pguint32) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = pguint32{Status: Null} + return nil + } + + n, err := strconv.ParseUint(string(src), 10, 32) + if err != nil { + return err + } + + *dst = pguint32{Uint: uint32(n), Status: Present} + return nil +} + +func (dst *pguint32) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = pguint32{Status: Null} + return nil + } + + if len(src) != 4 { + return errors.Errorf("invalid length: %v", len(src)) + } + + n := binary.BigEndian.Uint32(src) + *dst = pguint32{Uint: n, Status: Present} + return nil +} + +func (src *pguint32) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return 
append(buf, strconv.FormatUint(uint64(src.Uint), 10)...), nil +} + +func (src *pguint32) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return pgio.AppendUint32(buf, src.Uint), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *pguint32) Scan(src interface{}) error { + if src == nil { + *dst = pguint32{Status: Null} + return nil + } + + switch src := src.(type) { + case uint32: + *dst = pguint32{Uint: src, Status: Present} + return nil + case int64: + *dst = pguint32{Uint: uint32(src), Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *pguint32) Value() (driver.Value, error) { + switch src.Status { + case Present: + return int64(src.Uint), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/point.go b/pgtype/point.go new file mode 100644 index 00000000..3132a939 --- /dev/null +++ b/pgtype/point.go @@ -0,0 +1,139 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type Vec2 struct { + X float64 + Y float64 +} + +type Point struct { + P Vec2 + Status Status +} + +func (dst *Point) Set(src interface{}) error { + return errors.Errorf("cannot convert %v to Point", src) +} + +func (dst *Point) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Point) AssignTo(dst interface{}) error { + return errors.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Point) DecodeText(ci *ConnInfo, 
src []byte) error { + if src == nil { + *dst = Point{Status: Null} + return nil + } + + if len(src) < 5 { + return errors.Errorf("invalid length for point: %v", len(src)) + } + + parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) + if len(parts) < 2 { + return errors.Errorf("invalid format for point") + } + + x, err := strconv.ParseFloat(parts[0], 64) + if err != nil { + return err + } + + y, err := strconv.ParseFloat(parts[1], 64) + if err != nil { + return err + } + + *dst = Point{P: Vec2{x, y}, Status: Present} + return nil +} + +func (dst *Point) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Point{Status: Null} + return nil + } + + if len(src) != 16 { + return errors.Errorf("invalid length for point: %v", len(src)) + } + + x := binary.BigEndian.Uint64(src) + y := binary.BigEndian.Uint64(src[8:]) + + *dst = Point{ + P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, + Status: Present, + } + return nil +} + +func (src *Point) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, fmt.Sprintf(`(%f,%f)`, src.P.X, src.P.Y)...), nil +} + +func (src *Point) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = pgio.AppendUint64(buf, math.Float64bits(src.P.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P.Y)) + return buf, nil +} + +// Scan implements the database/sql Scanner interface. 
+func (dst *Point) Scan(src interface{}) error { + if src == nil { + *dst = Point{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Point) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/point_test.go b/pgtype/point_test.go new file mode 100644 index 00000000..f46b342d --- /dev/null +++ b/pgtype/point_test.go @@ -0,0 +1,16 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestPointTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "point", []interface{}{ + &pgtype.Point{P: pgtype.Vec2{1.234, 5.6789}, Status: pgtype.Present}, + &pgtype.Point{P: pgtype.Vec2{-1.234, -5.6789}, Status: pgtype.Present}, + &pgtype.Point{Status: pgtype.Null}, + }) +} diff --git a/pgtype/polygon.go b/pgtype/polygon.go new file mode 100644 index 00000000..3f3d9f53 --- /dev/null +++ b/pgtype/polygon.go @@ -0,0 +1,174 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type Polygon struct { + P []Vec2 + Status Status +} + +func (dst *Polygon) Set(src interface{}) error { + return errors.Errorf("cannot convert %v to Polygon", src) +} + +func (dst *Polygon) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Polygon) AssignTo(dst interface{}) error { + return errors.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Polygon) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Polygon{Status: Null} + return nil 
+ } + + if len(src) < 7 { + return errors.Errorf("invalid length for Polygon: %v", len(src)) + } + + points := make([]Vec2, 0) + + str := string(src[2:]) + + for { + end := strings.IndexByte(str, ',') + x, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + points = append(points, Vec2{x, y}) + + if end+3 < len(str) { + str = str[end+3:] + } else { + break + } + } + + *dst = Polygon{P: points, Status: Present} + return nil +} + +func (dst *Polygon) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Polygon{Status: Null} + return nil + } + + if len(src) < 5 { + return errors.Errorf("invalid length for Polygon: %v", len(src)) + } + + pointCount := int(binary.BigEndian.Uint32(src)) + rp := 4 + + if 4+pointCount*16 != len(src) { + return errors.Errorf("invalid length for Polygon with %d points: %v", pointCount, len(src)) + } + + points := make([]Vec2, pointCount) + for i := 0; i < len(points); i++ { + x := binary.BigEndian.Uint64(src[rp:]) + rp += 8 + y := binary.BigEndian.Uint64(src[rp:]) + rp += 8 + points[i] = Vec2{math.Float64frombits(x), math.Float64frombits(y)} + } + + *dst = Polygon{ + P: points, + Status: Present, + } + return nil +} + +func (src *Polygon) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = append(buf, '(') + + for i, p := range src.P { + if i > 0 { + buf = append(buf, ',') + } + buf = append(buf, fmt.Sprintf(`(%f,%f)`, p.X, p.Y)...) 
+ } + + return append(buf, ')'), nil +} + +func (src *Polygon) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = pgio.AppendInt32(buf, int32(len(src.P))) + + for _, p := range src.P { + buf = pgio.AppendUint64(buf, math.Float64bits(p.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(p.Y)) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Polygon) Scan(src interface{}) error { + if src == nil { + *dst = Polygon{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Polygon) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/polygon_test.go b/pgtype/polygon_test.go new file mode 100644 index 00000000..48481dc5 --- /dev/null +++ b/pgtype/polygon_test.go @@ -0,0 +1,22 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestPolygonTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "polygon", []interface{}{ + &pgtype.Polygon{ + P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}, {5.0, 3.234}}, + Status: pgtype.Present, + }, + &pgtype.Polygon{ + P: []pgtype.Vec2{{3.14, -1.678}, {7.1, -5.234}, {23.1, 9.34}}, + Status: pgtype.Present, + }, + &pgtype.Polygon{Status: pgtype.Null}, + }) +} diff --git a/pgtype/qchar.go b/pgtype/qchar.go new file mode 100644 index 00000000..064dab1e --- /dev/null +++ b/pgtype/qchar.go @@ -0,0 +1,146 @@ +package pgtype + +import ( + "math" + "strconv" + + "github.com/pkg/errors" +) + +// QChar is for PostgreSQL's special 8-bit-only "char" type more akin to the 
C +// language's char type, or Go's byte type. (Note that the name in PostgreSQL +// itself is "char", in double-quotes, and not char.) It gets used a lot in +// PostgreSQL's system tables to hold a single ASCII character value (eg +// pg_class.relkind). It is named Qchar for quoted char to disambiguate from SQL +// standard type char. +// +// Not all possible values of QChar are representable in the text format. +// Therefore, QChar does not implement TextEncoder and TextDecoder. In +// addition, database/sql Scanner and database/sql/driver Value are not +// implemented. +type QChar struct { + Int int8 + Status Status +} + +func (dst *QChar) Set(src interface{}) error { + if src == nil { + *dst = QChar{Status: Null} + return nil + } + + switch value := src.(type) { + case int8: + *dst = QChar{Int: value, Status: Present} + case uint8: + if value > math.MaxInt8 { + return errors.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case int16: + if value < math.MinInt8 { + return errors.Errorf("%d is greater than maximum value for QChar", value) + } + if value > math.MaxInt8 { + return errors.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case uint16: + if value > math.MaxInt8 { + return errors.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case int32: + if value < math.MinInt8 { + return errors.Errorf("%d is greater than maximum value for QChar", value) + } + if value > math.MaxInt8 { + return errors.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case uint32: + if value > math.MaxInt8 { + return errors.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case int64: + if value < math.MinInt8 { + return errors.Errorf("%d is greater than 
maximum value for QChar", value) + } + if value > math.MaxInt8 { + return errors.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case uint64: + if value > math.MaxInt8 { + return errors.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case int: + if value < math.MinInt8 { + return errors.Errorf("%d is greater than maximum value for QChar", value) + } + if value > math.MaxInt8 { + return errors.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case uint: + if value > math.MaxInt8 { + return errors.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case string: + num, err := strconv.ParseInt(value, 10, 8) + if err != nil { + return err + } + *dst = QChar{Int: int8(num), Status: Present} + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to QChar", value) + } + + return nil +} + +func (dst *QChar) Get() interface{} { + switch dst.Status { + case Present: + return dst.Int + case Null: + return nil + default: + return dst.Status + } +} + +func (src *QChar) AssignTo(dst interface{}) error { + return int64AssignTo(int64(src.Int), src.Status, dst) +} + +func (dst *QChar) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = QChar{Status: Null} + return nil + } + + if len(src) != 1 { + return errors.Errorf(`invalid length for "char": %v`, len(src)) + } + + *dst = QChar{Int: int8(src[0]), Status: Present} + return nil +} + +func (src *QChar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, byte(src.Int)), nil +} diff --git a/pgtype/qchar_test.go b/pgtype/qchar_test.go new 
file mode 100644 index 00000000..057a557f --- /dev/null +++ b/pgtype/qchar_test.go @@ -0,0 +1,143 @@ +package pgtype_test + +import ( + "math" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestQCharTranscode(t *testing.T) { + testutil.TestPgxSuccessfulTranscodeEqFunc(t, `"char"`, []interface{}{ + &pgtype.QChar{Int: math.MinInt8, Status: pgtype.Present}, + &pgtype.QChar{Int: -1, Status: pgtype.Present}, + &pgtype.QChar{Int: 0, Status: pgtype.Present}, + &pgtype.QChar{Int: 1, Status: pgtype.Present}, + &pgtype.QChar{Int: math.MaxInt8, Status: pgtype.Present}, + &pgtype.QChar{Int: 0, Status: pgtype.Null}, + }, func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + }) +} + +func TestQCharSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.QChar + }{ + {source: int8(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r 
pgtype.QChar + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestQCharAssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + + simpleTests := []struct { + src pgtype.QChar + dst interface{} + expected interface{} + }{ + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.QChar{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.QChar{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, 
dst) + } + } + + pointerAllocTests := []struct { + src pgtype.QChar + dst interface{} + expected interface{} + }{ + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.QChar + dst interface{} + }{ + {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui8}, + {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui16}, + {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui32}, + {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui64}, + {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui}, + {src: pgtype.QChar{Int: 0, Status: pgtype.Null}, dst: &i16}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/range.go b/pgtype/range.go new file mode 100644 index 00000000..d870834f --- /dev/null +++ b/pgtype/range.go @@ -0,0 +1,274 @@ +package pgtype + +import ( + "bytes" + "encoding/binary" + + "github.com/pkg/errors" +) + +type BoundType byte + +const ( + Inclusive = BoundType('i') + Exclusive = BoundType('e') + Unbounded = BoundType('U') + Empty = BoundType('E') +) + +type UntypedTextRange struct { + Lower string + Upper string + LowerType BoundType + UpperType BoundType +} + +func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { + utr := &UntypedTextRange{} + if src == "empty" { + utr.LowerType = 'E' + utr.UpperType = 'E' + return utr, nil + } + + buf := bytes.NewBufferString(src) + + 
skipWhitespace(buf) + + r, _, err := buf.ReadRune() + if err != nil { + return nil, errors.Errorf("invalid lower bound: %v", err) + } + switch r { + case '(': + utr.LowerType = Exclusive + case '[': + utr.LowerType = Inclusive + default: + return nil, errors.Errorf("missing lower bound, instead got: %v", string(r)) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, errors.Errorf("invalid lower value: %v", err) + } + buf.UnreadRune() + + if r == ',' { + utr.LowerType = Unbounded + } else { + utr.Lower, err = rangeParseValue(buf) + if err != nil { + return nil, errors.Errorf("invalid lower value: %v", err) + } + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, errors.Errorf("missing range separator: %v", err) + } + if r != ',' { + return nil, errors.Errorf("missing range separator: %v", r) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, errors.Errorf("invalid upper value: %v", err) + } + buf.UnreadRune() + + if r == ')' || r == ']' { + utr.UpperType = Unbounded + } else { + utr.Upper, err = rangeParseValue(buf) + if err != nil { + return nil, errors.Errorf("invalid upper value: %v", err) + } + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, errors.Errorf("missing upper bound: %v", err) + } + switch r { + case ')': + utr.UpperType = Exclusive + case ']': + utr.UpperType = Inclusive + default: + return nil, errors.Errorf("missing upper bound, instead got: %v", string(r)) + } + + skipWhitespace(buf) + + if buf.Len() > 0 { + return nil, errors.Errorf("unexpected trailing data: %v", buf.String()) + } + + return utr, nil +} + +func rangeParseValue(buf *bytes.Buffer) (string, error) { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + if r == '"' { + return rangeParseQuotedValue(buf) + } + buf.UnreadRune() + + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case '\\': + r, _, err = buf.ReadRune() + if err != nil { + 
return "", err + } + case ',', '[', ']', '(', ')': + buf.UnreadRune() + return s.String(), nil + } + + s.WriteRune(r) + } +} + +func rangeParseQuotedValue(buf *bytes.Buffer) (string, error) { + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case '\\': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + case '"': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + if r != '"' { + buf.UnreadRune() + return s.String(), nil + } + } + s.WriteRune(r) + } +} + +type UntypedBinaryRange struct { + Lower []byte + Upper []byte + LowerType BoundType + UpperType BoundType +} + +// 0 = () = 00000 +// 1 = empty = 00001 +// 2 = [) = 00010 +// 4 = (] = 00100 +// 6 = [] = 00110 +// 8 = ) = 01000 +// 12 = ] = 01100 +// 16 = ( = 10000 +// 18 = [ = 10010 +// 24 = = 11000 + +const emptyMask = 1 +const lowerInclusiveMask = 2 +const upperInclusiveMask = 4 +const lowerUnboundedMask = 8 +const upperUnboundedMask = 16 + +func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { + ubr := &UntypedBinaryRange{} + + if len(src) == 0 { + return nil, errors.Errorf("range too short: %v", len(src)) + } + + rangeType := src[0] + rp := 1 + + if rangeType&emptyMask > 0 { + if len(src[rp:]) > 0 { + return nil, errors.Errorf("unexpected trailing bytes parsing empty range: %v", len(src[rp:])) + } + ubr.LowerType = Empty + ubr.UpperType = Empty + return ubr, nil + } + + if rangeType&lowerInclusiveMask > 0 { + ubr.LowerType = Inclusive + } else if rangeType&lowerUnboundedMask > 0 { + ubr.LowerType = Unbounded + } else { + ubr.LowerType = Exclusive + } + + if rangeType&upperInclusiveMask > 0 { + ubr.UpperType = Inclusive + } else if rangeType&upperUnboundedMask > 0 { + ubr.UpperType = Unbounded + } else { + ubr.UpperType = Exclusive + } + + if ubr.LowerType == Unbounded && ubr.UpperType == Unbounded { + if len(src[rp:]) > 0 { + return nil, errors.Errorf("unexpected trailing bytes parsing 
unbounded range: %v", len(src[rp:])) + } + return ubr, nil + } + + if len(src[rp:]) < 4 { + return nil, errors.Errorf("too few bytes for size: %v", src[rp:]) + } + valueLen := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + val := src[rp : rp+valueLen] + rp += valueLen + + if ubr.LowerType != Unbounded { + ubr.Lower = val + } else { + ubr.Upper = val + if len(src[rp:]) > 0 { + return nil, errors.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) + } + return ubr, nil + } + + if ubr.UpperType != Unbounded { + if len(src[rp:]) < 4 { + return nil, errors.Errorf("too few bytes for size: %v", src[rp:]) + } + valueLen := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + ubr.Upper = src[rp : rp+valueLen] + rp += valueLen + } + + if len(src[rp:]) > 0 { + return nil, errors.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) + } + + return ubr, nil + +} diff --git a/pgtype/range_test.go b/pgtype/range_test.go new file mode 100644 index 00000000..9e16df59 --- /dev/null +++ b/pgtype/range_test.go @@ -0,0 +1,177 @@ +package pgtype + +import ( + "bytes" + "testing" +) + +func TestParseUntypedTextRange(t *testing.T) { + tests := []struct { + src string + result UntypedTextRange + err error + }{ + { + src: `[1,2)`, + result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `[1,2]`, + result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Inclusive}, + err: nil, + }, + { + src: `(1,3)`, + result: UntypedTextRange{Lower: "1", Upper: "3", LowerType: Exclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: ` [1,2) `, + result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `[ foo , bar )`, + result: UntypedTextRange{Lower: " foo ", Upper: " bar ", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["foo","bar")`, + result: UntypedTextRange{Lower: "foo", Upper: 
"bar", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["f""oo","b""ar")`, + result: UntypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["f""oo","b""ar")`, + result: UntypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["","bar")`, + result: UntypedTextRange{Lower: ``, Upper: `bar`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `[f\"oo\,,b\\ar\))`, + result: UntypedTextRange{Lower: `f"oo,`, Upper: `b\ar)`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `empty`, + result: UntypedTextRange{Lower: "", Upper: "", LowerType: Empty, UpperType: Empty}, + err: nil, + }, + } + + for i, tt := range tests { + r, err := ParseUntypedTextRange(tt.src) + if err != tt.err { + t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err) + continue + } + + if r.LowerType != tt.result.LowerType { + t.Errorf("%d. `%v`: expected result lower type %v, got %v", i, tt.src, string(tt.result.LowerType), string(r.LowerType)) + } + + if r.UpperType != tt.result.UpperType { + t.Errorf("%d. `%v`: expected result upper type %v, got %v", i, tt.src, string(tt.result.UpperType), string(r.UpperType)) + } + + if r.Lower != tt.result.Lower { + t.Errorf("%d. `%v`: expected result lower %v, got %v", i, tt.src, tt.result.Lower, r.Lower) + } + + if r.Upper != tt.result.Upper { + t.Errorf("%d. 
`%v`: expected result upper %v, got %v", i, tt.src, tt.result.Upper, r.Upper) + } + } +} + +func TestParseUntypedBinaryRange(t *testing.T) { + tests := []struct { + src []byte + result UntypedBinaryRange + err error + }{ + { + src: []byte{0, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: []byte{1}, + result: UntypedBinaryRange{Lower: nil, Upper: nil, LowerType: Empty, UpperType: Empty}, + err: nil, + }, + { + src: []byte{2, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: []byte{4, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Inclusive}, + err: nil, + }, + { + src: []byte{6, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Inclusive}, + err: nil, + }, + { + src: []byte{8, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Exclusive}, + err: nil, + }, + { + src: []byte{12, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Inclusive}, + err: nil, + }, + { + src: []byte{16, 0, 0, 0, 2, 0, 4}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Exclusive, UpperType: Unbounded}, + err: nil, + }, + { + src: []byte{18, 0, 0, 0, 2, 0, 4}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Inclusive, UpperType: Unbounded}, + err: nil, + }, + { + src: []byte{24}, + result: UntypedBinaryRange{Lower: nil, Upper: nil, LowerType: Unbounded, UpperType: Unbounded}, + err: nil, + }, + } + + for i, tt := range tests { + r, err := ParseUntypedBinaryRange(tt.src) + if err 
!= tt.err { + t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err) + continue + } + + if r.LowerType != tt.result.LowerType { + t.Errorf("%d. `%v`: expected result lower type %v, got %v", i, tt.src, string(tt.result.LowerType), string(r.LowerType)) + } + + if r.UpperType != tt.result.UpperType { + t.Errorf("%d. `%v`: expected result upper type %v, got %v", i, tt.src, string(tt.result.UpperType), string(r.UpperType)) + } + + if bytes.Compare(r.Lower, tt.result.Lower) != 0 { + t.Errorf("%d. `%v`: expected result lower %v, got %v", i, tt.src, tt.result.Lower, r.Lower) + } + + if bytes.Compare(r.Upper, tt.result.Upper) != 0 { + t.Errorf("%d. `%v`: expected result upper %v, got %v", i, tt.src, tt.result.Upper, r.Upper) + } + } +} diff --git a/pgtype/record.go b/pgtype/record.go new file mode 100644 index 00000000..14b415c3 --- /dev/null +++ b/pgtype/record.go @@ -0,0 +1,124 @@ +package pgtype + +import ( + "encoding/binary" + + "github.com/pkg/errors" +) + +// Record is the generic PostgreSQL record type such as is created with the +// "row" function. Record only implements BinaryEncoder and Value. The text +// format output format from PostgreSQL does not include type information and is +// therefore impossible to decode. No encoders are implemented because +// PostgreSQL does not support input of generic records. 
+type Record struct { + Fields []Value + Status Status +} + +func (dst *Record) Set(src interface{}) error { + if src == nil { + *dst = Record{Status: Null} + return nil + } + + switch value := src.(type) { + case []Value: + *dst = Record{Fields: value, Status: Present} + default: + return errors.Errorf("cannot convert %v to Record", src) + } + + return nil +} + +func (dst *Record) Get() interface{} { + switch dst.Status { + case Present: + return dst.Fields + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Record) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *[]Value: + *v = make([]Value, len(src.Fields)) + copy(*v, src.Fields) + return nil + case *[]interface{}: + *v = make([]interface{}, len(src.Fields)) + for i := range *v { + (*v)[i] = src.Fields[i].Get() + } + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Record{Status: Null} + return nil + } + + rp := 0 + + if len(src[rp:]) < 4 { + return errors.Errorf("Record incomplete %v", src) + } + fieldCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + fields := make([]Value, fieldCount) + + for i := 0; i < fieldCount; i++ { + if len(src[rp:]) < 8 { + return errors.Errorf("Record incomplete %v", src) + } + fieldOID := OID(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + fieldLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + var binaryDecoder BinaryDecoder + if dt, ok := ci.DataTypeForOID(fieldOID); ok { + if binaryDecoder, ok = dt.Value.(BinaryDecoder); !ok { + return errors.Errorf("unknown oid while decoding record: %v", fieldOID) + } + } + + var fieldBytes []byte + if fieldLen >= 0 { + if len(src[rp:]) < fieldLen { + return 
errors.Errorf("Record incomplete %v", src) + } + fieldBytes = src[rp : rp+fieldLen] + rp += fieldLen + } + + if err := binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil { + return err + } + + fields[i] = binaryDecoder.(Value) + } + + *dst = Record{Fields: fields, Status: Present} + + return nil +} diff --git a/pgtype/record_test.go b/pgtype/record_test.go new file mode 100644 index 00000000..df17501f --- /dev/null +++ b/pgtype/record_test.go @@ -0,0 +1,151 @@ +package pgtype_test + +import ( + "fmt" + "reflect" + "testing" + + "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestRecordTranscode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustClose(t, conn) + + tests := []struct { + sql string + expected pgtype.Record + }{ + { + sql: `select row()`, + expected: pgtype.Record{ + Fields: []pgtype.Value{}, + Status: pgtype.Present, + }, + }, + { + sql: `select row('foo'::text, 42::int4)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4Array{ + Elements: []pgtype.Int4{ + pgtype.Int4{Int: 1, Status: pgtype.Present}, + pgtype.Int4{Int: 2, Status: pgtype.Present}, + pgtype.Int4{Status: pgtype.Null}, + pgtype.Int4{Int: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 4, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select row(null)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Unknown{Status: pgtype.Null}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select null::record`, + expected: 
pgtype.Record{ + Status: pgtype.Null, + }, + }, + } + + for i, tt := range tests { + psName := fmt.Sprintf("test%d", i) + ps, err := conn.Prepare(psName, tt.sql) + if err != nil { + t.Fatal(err) + } + ps.FieldDescriptions[0].FormatCode = pgx.BinaryFormatCode + + var result pgtype.Record + if err := conn.QueryRow(psName).Scan(&result); err != nil { + t.Errorf("%d: %v", i, err) + continue + } + + if !reflect.DeepEqual(tt.expected, result) { + t.Errorf("%d: expected %v, got %v", i, tt.expected, result) + } + } +} + +func TestRecordAssignTo(t *testing.T) { + var valueSlice []pgtype.Value + var interfaceSlice []interface{} + + simpleTests := []struct { + src pgtype.Record + dst interface{} + expected interface{} + }{ + { + src: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + dst: &valueSlice, + expected: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + }, + { + src: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + dst: &interfaceSlice, + expected: []interface{}{"foo", int32(42)}, + }, + { + src: pgtype.Record{Status: pgtype.Null}, + dst: &valueSlice, + expected: (([]pgtype.Value)(nil)), + }, + { + src: pgtype.Record{Status: pgtype.Null}, + dst: &interfaceSlice, + expected: (([]interface{})(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/testutil/testutil.go b/pgtype/testutil/testutil.go new file mode 100644 index 
00000000..0effb42d --- /dev/null +++ b/pgtype/testutil/testutil.go @@ -0,0 +1,297 @@ +package testutil + +import ( + "context" + "database/sql" + "fmt" + "os" + "reflect" + "testing" + + "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" + _ "github.com/jackc/pgx/stdlib" + _ "github.com/lib/pq" +) + +func MustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { + var sqlDriverName string + switch driverName { + case "github.com/lib/pq": + sqlDriverName = "postgres" + case "github.com/jackc/pgx/stdlib": + sqlDriverName = "pgx" + default: + t.Fatalf("Unknown driver %v", driverName) + } + + db, err := sql.Open(sqlDriverName, os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + t.Fatal(err) + } + + return db +} + +func MustConnectPgx(t testing.TB) *pgx.Conn { + config, err := pgx.ParseConnectionString(os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + t.Fatal(err) + } + + conn, err := pgx.Connect(config) + if err != nil { + t.Fatal(err) + } + + return conn +} + +func MustClose(t testing.TB, conn interface { + Close() error +}) { + err := conn.Close() + if err != nil { + t.Fatal(err) + } +} + +type forceTextEncoder struct { + e pgtype.TextEncoder +} + +func (f forceTextEncoder) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + return f.e.EncodeText(ci, buf) +} + +type forceBinaryEncoder struct { + e pgtype.BinaryEncoder +} + +func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + return f.e.EncodeBinary(ci, buf) +} + +func ForceEncoder(e interface{}, formatCode int16) interface{} { + switch formatCode { + case pgx.TextFormatCode: + if e, ok := e.(pgtype.TextEncoder); ok { + return forceTextEncoder{e: e} + } + case pgx.BinaryFormatCode: + if e, ok := e.(pgtype.BinaryEncoder); ok { + return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)} + } + } + return nil +} + +func TestSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface{}) { + TestSuccessfulTranscodeEqFunc(t, pgTypeName, values, 
func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + }) +} + +func TestSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) + } +} + +func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + conn := MustConnectPgx(t) + defer MustClose(t, conn) + + ps, err := conn.Prepare("test", fmt.Sprintf("select $1::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for i, v := range values { + for _, fc := range formats { + ps.FieldDescriptions[0].FormatCode = fc.formatCode + vEncoder := ForceEncoder(v, fc.formatCode) + if vEncoder == nil { + t.Logf("Skipping: %#v does not implement %v", v, fc.name) + continue + } + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err := conn.QueryRow("test", ForceEncoder(v, fc.formatCode)).Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", fc.name, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface()) + } + } + } +} + +func TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + conn := MustConnectPgx(t) + defer MustClose(t, conn) 
+ + for i, v := range values { + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err := conn.QueryRowEx( + context.Background(), + fmt.Sprintf("select ($1)::%s", pgTypeName), + &pgx.QueryExOptions{SimpleProtocol: true}, + v, + ).Scan(result.Interface()) + if err != nil { + t.Errorf("Simple protocol %d: %v", i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("Simple protocol %d: expected %v, got %v", i, derefV, result.Elem().Interface()) + } + } +} + +func TestDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + conn := MustConnectDatabaseSQL(t, driverName) + defer MustClose(t, conn) + + ps, err := conn.Prepare(fmt.Sprintf("select $1::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + for i, v := range values { + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err := ps.QueryRow(v).Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", driverName, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) + } + } +} + +type NormalizeTest struct { + SQL string + Value interface{} +} + +func TestSuccessfulNormalize(t testing.TB, tests []NormalizeTest) { + TestSuccessfulNormalizeEqFunc(t, tests, func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + }) +} + +func TestSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { + TestPgxSuccessfulNormalizeEqFunc(t, tests, eqFunc) + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + 
TestDatabaseSQLSuccessfulNormalizeEqFunc(t, driverName, tests, eqFunc) + } +} + +func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { + conn := MustConnectPgx(t) + defer MustClose(t, conn) + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for i, tt := range tests { + for _, fc := range formats { + psName := fmt.Sprintf("test%d", i) + ps, err := conn.Prepare(psName, tt.SQL) + if err != nil { + t.Fatal(err) + } + + ps.FieldDescriptions[0].FormatCode = fc.formatCode + if ForceEncoder(tt.Value, fc.formatCode) == nil { + t.Logf("Skipping: %#v does not implement %v", tt.Value, fc.name) + continue + } + // Derefence value if it is a pointer + derefV := tt.Value + refVal := reflect.ValueOf(tt.Value) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err = conn.QueryRow(psName).Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", fc.name, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface()) + } + } + } +} + +func TestDatabaseSQLSuccessfulNormalizeEqFunc(t testing.TB, driverName string, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { + conn := MustConnectDatabaseSQL(t, driverName) + defer MustClose(t, conn) + + for i, tt := range tests { + ps, err := conn.Prepare(tt.SQL) + if err != nil { + t.Errorf("%d. 
%v", i, err) + continue + } + + // Derefence value if it is a pointer + derefV := tt.Value + refVal := reflect.ValueOf(tt.Value) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err = ps.QueryRow().Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", driverName, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) + } + } +} diff --git a/pgtype/text.go b/pgtype/text.go new file mode 100644 index 00000000..f05e1e89 --- /dev/null +++ b/pgtype/text.go @@ -0,0 +1,151 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/json" + + "github.com/pkg/errors" +) + +type Text struct { + String string + Status Status +} + +func (dst *Text) Set(src interface{}) error { + if src == nil { + *dst = Text{Status: Null} + return nil + } + + switch value := src.(type) { + case string: + *dst = Text{String: value, Status: Present} + case *string: + if value == nil { + *dst = Text{Status: Null} + } else { + *dst = Text{String: *value, Status: Present} + } + case []byte: + if value == nil { + *dst = Text{Status: Null} + } else { + *dst = Text{String: string(value), Status: Present} + } + default: + if originalSrc, ok := underlyingStringType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Text", value) + } + + return nil +} + +func (dst *Text) Get() interface{} { + switch dst.Status { + case Present: + return dst.String + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Text) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *string: + *v = src.String + return nil + case *[]byte: + *v = make([]byte, len(src.String)) + copy(*v, src.String) + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + 
case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *Text) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Text{Status: Null} + return nil + } + + *dst = Text{String: string(src), Status: Present} + return nil +} + +func (dst *Text) DecodeBinary(ci *ConnInfo, src []byte) error { + return dst.DecodeText(ci, src) +} + +func (src *Text) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, src.String...), nil +} + +func (src *Text) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return src.EncodeText(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Text) Scan(src interface{}) error { + if src == nil { + *dst = Text{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src *Text) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.String, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} + +func (src *Text) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + return json.Marshal(src.String) + case Null: + return []byte("null"), nil + case Undefined: + return []byte("undefined"), nil + } + + return nil, errBadStatus +} diff --git a/pgtype/text_array.go b/pgtype/text_array.go new file mode 100644 index 00000000..2609a2cc --- /dev/null +++ b/pgtype/text_array.go @@ -0,0 +1,294 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type TextArray struct { + Elements []Text + Dimensions []ArrayDimension + Status Status +} + +func (dst *TextArray) Set(src interface{}) error { + switch value := src.(type) { + + case []string: + if value == nil { + *dst = TextArray{Status: Null} + } else if len(value) == 0 { + *dst = TextArray{Status: Present} + } else { + elements := make([]Text, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = TextArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Text", value) + } + + return nil +} + +func (dst *TextArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *TextArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + 
return nil + + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *TextArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TextArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Text + + if len(uta.Elements) > 0 { + elements = make([]Text, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Text + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = TextArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *TextArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TextArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = TextArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Text, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = TextArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *TextArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, 
errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `"NULL"`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) 
+ } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src *TextArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("text"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, errors.Errorf("unable to find oid for type name %v", "text") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *TextArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src *TextArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/text_array_test.go b/pgtype/text_array_test.go new file mode 100644 index 00000000..35ebef96 --- /dev/null +++ b/pgtype/text_array_test.go @@ -0,0 +1,152 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestTextArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "text[]", []interface{}{ + &pgtype.TextArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.TextArray{ + Elements: []pgtype.Text{ + pgtype.Text{String: "foo", Status: pgtype.Present}, + pgtype.Text{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.TextArray{Status: pgtype.Null}, + &pgtype.TextArray{ + Elements: []pgtype.Text{ + pgtype.Text{String: "bar ", Status: pgtype.Present}, + pgtype.Text{String: "NuLL", Status: pgtype.Present}, + pgtype.Text{String: `wow"quz\`, Status: pgtype.Present}, + pgtype.Text{String: "", Status: pgtype.Present}, + pgtype.Text{Status: pgtype.Null}, + pgtype.Text{String: "null", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.TextArray{ + Elements: []pgtype.Text{ + pgtype.Text{String: "bar", Status: pgtype.Present}, + pgtype.Text{String: "baz", Status: pgtype.Present}, + pgtype.Text{String: "quz", Status: pgtype.Present}, + pgtype.Text{String: "foo", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestTextArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + 
result pgtype.TextArray + }{ + { + source: []string{"foo"}, + result: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]string)(nil)), + result: pgtype.TextArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.TextArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestTextArrayAssignTo(t *testing.T) { + var stringSlice []string + type _stringSlice []string + var namedStringSlice _stringSlice + + simpleTests := []struct { + src pgtype.TextArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + expected: []string{"foo"}, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedStringSlice, + expected: _stringSlice{"bar"}, + }, + { + src: pgtype.TextArray{Status: pgtype.Null}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.TextArray + dst interface{} + }{ + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{Status: pgtype.Null}}, + Dimensions: 
[]pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/text_test.go b/pgtype/text_test.go new file mode 100644 index 00000000..bd971807 --- /dev/null +++ b/pgtype/text_test.go @@ -0,0 +1,123 @@ +package pgtype_test + +import ( + "bytes" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestTextTranscode(t *testing.T) { + for _, pgTypeName := range []string{"text", "varchar"} { + testutil.TestSuccessfulTranscode(t, pgTypeName, []interface{}{ + &pgtype.Text{String: "", Status: pgtype.Present}, + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Text{Status: pgtype.Null}, + }) + } +} + +func TestTextSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Text + }{ + {source: "foo", result: pgtype.Text{String: "foo", Status: pgtype.Present}}, + {source: _string("bar"), result: pgtype.Text{String: "bar", Status: pgtype.Present}}, + {source: (*string)(nil), result: pgtype.Text{Status: pgtype.Null}}, + } + + for i, tt := range successfulTests { + var d pgtype.Text + err := d.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if d != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestTextAssignTo(t *testing.T) { + var s string + var ps *string + + stringTests := []struct { + src pgtype.Text + dst interface{} + expected interface{} + }{ + {src: pgtype.Text{String: "foo", Status: pgtype.Present}, dst: &s, expected: "foo"}, + {src: pgtype.Text{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + } + + for i, tt := range stringTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + 
} + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + var buf []byte + + bytesTests := []struct { + src pgtype.Text + dst *[]byte + expected []byte + }{ + {src: pgtype.Text{String: "foo", Status: pgtype.Present}, dst: &buf, expected: []byte("foo")}, + {src: pgtype.Text{Status: pgtype.Null}, dst: &buf, expected: nil}, + } + + for i, tt := range bytesTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if bytes.Compare(*tt.dst, tt.expected) != 0 { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, tt.dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Text + dst interface{} + expected interface{} + }{ + {src: pgtype.Text{String: "foo", Status: pgtype.Present}, dst: &ps, expected: "foo"}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Text + dst interface{} + }{ + {src: pgtype.Text{Status: pgtype.Null}, dst: &s}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/tid.go b/pgtype/tid.go new file mode 100644 index 00000000..21852a14 --- /dev/null +++ b/pgtype/tid.go @@ -0,0 +1,144 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +// TID is PostgreSQL's Tuple Identifier type. 
+// +// When one does +// +// select ctid, * from some_table; +// +// it is the data type of the ctid hidden system column. +// +// It is currently implemented as a pair unsigned two byte integers. +// Its conversion functions can be found in src/backend/utils/adt/tid.c +// in the PostgreSQL sources. +type TID struct { + BlockNumber uint32 + OffsetNumber uint16 + Status Status +} + +func (dst *TID) Set(src interface{}) error { + return errors.Errorf("cannot convert %v to TID", src) +} + +func (dst *TID) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *TID) AssignTo(dst interface{}) error { + return errors.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *TID) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TID{Status: Null} + return nil + } + + if len(src) < 5 { + return errors.Errorf("invalid length for tid: %v", len(src)) + } + + parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) + if len(parts) < 2 { + return errors.Errorf("invalid format for tid") + } + + blockNumber, err := strconv.ParseUint(parts[0], 10, 32) + if err != nil { + return err + } + + offsetNumber, err := strconv.ParseUint(parts[1], 10, 16) + if err != nil { + return err + } + + *dst = TID{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber), Status: Present} + return nil +} + +func (dst *TID) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TID{Status: Null} + return nil + } + + if len(src) != 6 { + return errors.Errorf("invalid length for tid: %v", len(src)) + } + + *dst = TID{ + BlockNumber: binary.BigEndian.Uint32(src), + OffsetNumber: binary.BigEndian.Uint16(src[4:]), + Status: Present, + } + return nil +} + +func (src *TID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = append(buf, 
fmt.Sprintf(`(%d,%d)`, src.BlockNumber, src.OffsetNumber)...) + return buf, nil +} + +func (src *TID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = pgio.AppendUint32(buf, src.BlockNumber) + buf = pgio.AppendUint16(buf, src.OffsetNumber) + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *TID) Scan(src interface{}) error { + if src == nil { + *dst = TID{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *TID) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/tid_test.go b/pgtype/tid_test.go new file mode 100644 index 00000000..9185cb31 --- /dev/null +++ b/pgtype/tid_test.go @@ -0,0 +1,16 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestTIDTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "tid", []interface{}{ + &pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, + &pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, + &pgtype.TID{Status: pgtype.Null}, + }) +} diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go new file mode 100644 index 00000000..d906f467 --- /dev/null +++ b/pgtype/timestamp.go @@ -0,0 +1,225 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "time" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +const pgTimestampFormat = "2006-01-02 15:04:05.999999999" + +// Timestamp represents the PostgreSQL timestamp type. 
The PostgreSQL +// timestamp does not have a time zone. This presents a problem when +// translating to and from time.Time which requires a time zone. It is highly +// recommended to use timestamptz whenever possible. Timestamp methods either +// convert to UTC or return an error on non-UTC times. +type Timestamp struct { + Time time.Time // Time must always be in UTC. + Status Status + InfinityModifier InfinityModifier +} + +// Set converts src into a Timestamp and stores in dst. If src is a +// time.Time in a non-UTC time zone, the time zone is discarded. +func (dst *Timestamp) Set(src interface{}) error { + if src == nil { + *dst = Timestamp{Status: Null} + return nil + } + + switch value := src.(type) { + case time.Time: + *dst = Timestamp{Time: time.Date(value.Year(), value.Month(), value.Day(), value.Hour(), value.Minute(), value.Second(), value.Nanosecond(), time.UTC), Status: Present} + default: + if originalSrc, ok := underlyingTimeType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Timestamp", value) + } + + return nil +} + +func (dst *Timestamp) Get() interface{} { + switch dst.Status { + case Present: + if dst.InfinityModifier != None { + return dst.InfinityModifier + } + return dst.Time + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Timestamp) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *time.Time: + if src.InfinityModifier != None { + return errors.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.Time + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +// DecodeText decodes from src into dst. The decoded time is considered to +// be in UTC. 
+func (dst *Timestamp) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Timestamp{Status: Null} + return nil + } + + sbuf := string(src) + switch sbuf { + case "infinity": + *dst = Timestamp{Status: Present, InfinityModifier: Infinity} + case "-infinity": + *dst = Timestamp{Status: Present, InfinityModifier: -Infinity} + default: + tim, err := time.Parse(pgTimestampFormat, sbuf) + if err != nil { + return err + } + + *dst = Timestamp{Time: tim, Status: Present} + } + + return nil +} + +// DecodeBinary decodes from src into dst. The decoded time is considered to +// be in UTC. +func (dst *Timestamp) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Timestamp{Status: Null} + return nil + } + + if len(src) != 8 { + return errors.Errorf("invalid length for timestamp: %v", len(src)) + } + + microsecSinceY2K := int64(binary.BigEndian.Uint64(src)) + + switch microsecSinceY2K { + case infinityMicrosecondOffset: + *dst = Timestamp{Status: Present, InfinityModifier: Infinity} + case negativeInfinityMicrosecondOffset: + *dst = Timestamp{Status: Present, InfinityModifier: -Infinity} + default: + microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K + tim := time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000).UTC() + *dst = Timestamp{Time: tim, Status: Present} + } + + return nil +} + +// EncodeText writes the text encoding of src into w. If src.Time is not in +// the UTC time zone it returns an error. 
+func (src *Timestamp) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + if src.Time.Location() != time.UTC { + return nil, errors.Errorf("cannot encode non-UTC time into timestamp") + } + + var s string + + switch src.InfinityModifier { + case None: + s = src.Time.Format(pgTimestampFormat) + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + return append(buf, s...), nil +} + +// EncodeBinary writes the binary encoding of src into w. If src.Time is not in +// the UTC time zone it returns an error. +func (src *Timestamp) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + if src.Time.Location() != time.UTC { + return nil, errors.Errorf("cannot encode non-UTC time into timestamp") + } + + var microsecSinceY2K int64 + switch src.InfinityModifier { + case None: + microsecSinceUnixEpoch := src.Time.Unix()*1000000 + int64(src.Time.Nanosecond())/1000 + microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K + case Infinity: + microsecSinceY2K = infinityMicrosecondOffset + case NegativeInfinity: + microsecSinceY2K = negativeInfinityMicrosecondOffset + } + + return pgio.AppendInt64(buf, microsecSinceY2K), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Timestamp) Scan(src interface{}) error { + if src == nil { + *dst = Timestamp{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + case time.Time: + *dst = Timestamp{Time: src, Status: Present} + return nil + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src *Timestamp) Value() (driver.Value, error) { + switch src.Status { + case Present: + if src.InfinityModifier != None { + return src.InfinityModifier.String(), nil + } + return src.Time, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/timestamp_array.go b/pgtype/timestamp_array.go new file mode 100644 index 00000000..be281f2e --- /dev/null +++ b/pgtype/timestamp_array.go @@ -0,0 +1,295 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "time" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type TimestampArray struct { + Elements []Timestamp + Dimensions []ArrayDimension + Status Status +} + +func (dst *TimestampArray) Set(src interface{}) error { + switch value := src.(type) { + + case []time.Time: + if value == nil { + *dst = TimestampArray{Status: Null} + } else if len(value) == 0 { + *dst = TimestampArray{Status: Present} + } else { + elements := make([]Timestamp, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = TimestampArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Timestamp", value) + } + + return nil +} + +func (dst *TimestampArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *TimestampArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + + case *[]time.Time: + *v = make([]time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + 
return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *TimestampArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TimestampArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Timestamp + + if len(uta.Elements) > 0 { + elements = make([]Timestamp, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Timestamp + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = TimestampArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *TimestampArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TimestampArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = TimestampArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Timestamp, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = TimestampArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *TimestampArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if 
len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) 
+ } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src *TimestampArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("timestamp"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, errors.Errorf("unable to find oid for type name %v", "timestamp") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *TimestampArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src *TimestampArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/timestamp_array_test.go b/pgtype/timestamp_array_test.go new file mode 100644 index 00000000..c75d101f --- /dev/null +++ b/pgtype/timestamp_array_test.go @@ -0,0 +1,159 @@ +package pgtype_test + +import ( + "reflect" + "testing" + "time" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestTimestampArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "timestamp[]", []interface{}{ + &pgtype.TimestampArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + pgtype.Timestamp{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.TimestampArray{Status: pgtype.Null}, + &pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + pgtype.Timestamp{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Status: pgtype.Null}, + pgtype.Timestamp{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + pgtype.Timestamp{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: 
time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }, func(a, b interface{}) bool { + ata := a.(pgtype.TimestampArray) + bta := b.(pgtype.TimestampArray) + + if len(ata.Elements) != len(bta.Elements) || ata.Status != bta.Status { + return false + } + + for i := range ata.Elements { + ae, be := ata.Elements[i], bta.Elements[i] + if !(ae.Time.Equal(be.Time) && ae.Status == be.Status && ae.InfinityModifier == be.InfinityModifier) { + return false + } + } + + return true + }) +} + +func TestTimestampArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.TimestampArray + }{ + { + source: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + result: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]time.Time)(nil)), + result: pgtype.TimestampArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.TimestampArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestTimestampArrayAssignTo(t *testing.T) { + var timeSlice []time.Time + + simpleTests := []struct { + src pgtype.TimestampArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + 
Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &timeSlice, + expected: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + }, + { + src: pgtype.TimestampArray{Status: pgtype.Null}, + dst: &timeSlice, + expected: (([]time.Time)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.TimestampArray + dst interface{} + }{ + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &timeSlice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/timestamp_test.go b/pgtype/timestamp_test.go new file mode 100644 index 00000000..267f1a7e --- /dev/null +++ b/pgtype/timestamp_test.go @@ -0,0 +1,123 @@ +package pgtype_test + +import ( + "reflect" + "testing" + "time" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestTimestampTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "timestamp", []interface{}{ + &pgtype.Timestamp{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.UTC), 
Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Status: pgtype.Null}, + &pgtype.Timestamp{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, + &pgtype.Timestamp{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + }, func(a, b interface{}) bool { + at := a.(pgtype.Timestamp) + bt := b.(pgtype.Timestamp) + + return at.Time.Equal(bt.Time) && at.Status == bt.Status && at.InfinityModifier == bt.InfinityModifier + }) +} + +func TestTimestampSet(t *testing.T) { + type _time time.Time + + successfulTests := []struct { + source interface{} + result pgtype.Timestamp + }{ + {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1999, 12, 31, 12, 59, 59, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), 
Status: pgtype.Present}}, + {source: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)), result: pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Timestamp + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestTimestampAssignTo(t *testing.T) { + var tim time.Time + var ptim *time.Time + + simpleTests := []struct { + src pgtype.Timestamp + dst interface{} + expected interface{} + }{ + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC)}, + {src: pgtype.Timestamp{Time: time.Time{}, Status: pgtype.Null}, dst: &ptim, expected: ((*time.Time)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Timestamp + dst interface{} + expected interface{} + }{ + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests 
:= []struct { + src pgtype.Timestamp + dst interface{} + }{ + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, dst: &tim}, + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, dst: &tim}, + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Null}, dst: &tim}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go new file mode 100644 index 00000000..74fe4954 --- /dev/null +++ b/pgtype/timestamptz.go @@ -0,0 +1,221 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "time" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +const pgTimestamptzHourFormat = "2006-01-02 15:04:05.999999999Z07" +const pgTimestamptzMinuteFormat = "2006-01-02 15:04:05.999999999Z07:00" +const pgTimestamptzSecondFormat = "2006-01-02 15:04:05.999999999Z07:00:00" +const microsecFromUnixEpochToY2K = 946684800 * 1000000 + +const ( + negativeInfinityMicrosecondOffset = -9223372036854775808 + infinityMicrosecondOffset = 9223372036854775807 +) + +type Timestamptz struct { + Time time.Time + Status Status + InfinityModifier InfinityModifier +} + +func (dst *Timestamptz) Set(src interface{}) error { + if src == nil { + *dst = Timestamptz{Status: Null} + return nil + } + + switch value := src.(type) { + case time.Time: + *dst = Timestamptz{Time: value, Status: Present} + default: + if originalSrc, ok := underlyingTimeType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Timestamptz", value) + } + + return nil +} + +func (dst *Timestamptz) Get() interface{} { + switch dst.Status { + case Present: + if dst.InfinityModifier 
!= None { + return dst.InfinityModifier + } + return dst.Time + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Timestamptz) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *time.Time: + if src.InfinityModifier != None { + return errors.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.Time + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *Timestamptz) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Timestamptz{Status: Null} + return nil + } + + sbuf := string(src) + switch sbuf { + case "infinity": + *dst = Timestamptz{Status: Present, InfinityModifier: Infinity} + case "-infinity": + *dst = Timestamptz{Status: Present, InfinityModifier: -Infinity} + default: + var format string + if sbuf[len(sbuf)-9] == '-' || sbuf[len(sbuf)-9] == '+' { + format = pgTimestamptzSecondFormat + } else if sbuf[len(sbuf)-6] == '-' || sbuf[len(sbuf)-6] == '+' { + format = pgTimestamptzMinuteFormat + } else { + format = pgTimestamptzHourFormat + } + + tim, err := time.Parse(format, sbuf) + if err != nil { + return err + } + + *dst = Timestamptz{Time: tim, Status: Present} + } + + return nil +} + +func (dst *Timestamptz) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Timestamptz{Status: Null} + return nil + } + + if len(src) != 8 { + return errors.Errorf("invalid length for timestamptz: %v", len(src)) + } + + microsecSinceY2K := int64(binary.BigEndian.Uint64(src)) + + switch microsecSinceY2K { + case infinityMicrosecondOffset: + *dst = Timestamptz{Status: Present, InfinityModifier: Infinity} + case negativeInfinityMicrosecondOffset: + *dst = Timestamptz{Status: Present, InfinityModifier: -Infinity} + default: + microsecSinceUnixEpoch := 
microsecFromUnixEpochToY2K + microsecSinceY2K + tim := time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) + *dst = Timestamptz{Time: tim, Status: Present} + } + + return nil +} + +func (src *Timestamptz) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var s string + + switch src.InfinityModifier { + case None: + s = src.Time.UTC().Format(pgTimestamptzSecondFormat) + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + return append(buf, s...), nil +} + +func (src *Timestamptz) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var microsecSinceY2K int64 + switch src.InfinityModifier { + case None: + microsecSinceUnixEpoch := src.Time.Unix()*1000000 + int64(src.Time.Nanosecond())/1000 + microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K + case Infinity: + microsecSinceY2K = infinityMicrosecondOffset + case NegativeInfinity: + microsecSinceY2K = negativeInfinityMicrosecondOffset + } + + return pgio.AppendInt64(buf, microsecSinceY2K), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Timestamptz) Scan(src interface{}) error { + if src == nil { + *dst = Timestamptz{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + case time.Time: + *dst = Timestamptz{Time: src, Status: Present} + return nil + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src *Timestamptz) Value() (driver.Value, error) { + switch src.Status { + case Present: + if src.InfinityModifier != None { + return src.InfinityModifier.String(), nil + } + return src.Time, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/timestamptz_array.go b/pgtype/timestamptz_array.go new file mode 100644 index 00000000..086a4ef0 --- /dev/null +++ b/pgtype/timestamptz_array.go @@ -0,0 +1,295 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "time" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type TimestamptzArray struct { + Elements []Timestamptz + Dimensions []ArrayDimension + Status Status +} + +func (dst *TimestamptzArray) Set(src interface{}) error { + switch value := src.(type) { + + case []time.Time: + if value == nil { + *dst = TimestamptzArray{Status: Null} + } else if len(value) == 0 { + *dst = TimestamptzArray{Status: Present} + } else { + elements := make([]Timestamptz, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = TimestamptzArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Timestamptz", value) + } + + return nil +} + +func (dst *TimestamptzArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *TimestamptzArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + + case *[]time.Time: + *v = make([]time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + default: + if nextDst, retry := 
GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *TimestamptzArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TimestamptzArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Timestamptz + + if len(uta.Elements) > 0 { + elements = make([]Timestamptz, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Timestamptz + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = TimestamptzArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *TimestamptzArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TimestamptzArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = TimestamptzArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Timestamptz, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = TimestamptzArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *TimestamptzArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + 
case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) 
+ } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("timestamptz"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, errors.Errorf("unable to find oid for type name %v", "timestamptz") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *TimestamptzArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src *TimestamptzArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/timestamptz_array_test.go b/pgtype/timestamptz_array_test.go new file mode 100644 index 00000000..50ee65d0 --- /dev/null +++ b/pgtype/timestamptz_array_test.go @@ -0,0 +1,159 @@ +package pgtype_test + +import ( + "reflect" + "testing" + "time" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestTimestamptzArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "timestamptz[]", []interface{}{ + &pgtype.TimestamptzArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + pgtype.Timestamptz{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamptz{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.TimestamptzArray{Status: pgtype.Null}, + &pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + pgtype.Timestamptz{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamptz{Status: pgtype.Null}, + pgtype.Timestamptz{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + pgtype.Timestamptz{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: 
pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }, func(a, b interface{}) bool { + ata := a.(pgtype.TimestamptzArray) + bta := b.(pgtype.TimestamptzArray) + + if len(ata.Elements) != len(bta.Elements) || ata.Status != bta.Status { + return false + } + + for i := range ata.Elements { + ae, be := ata.Elements[i], bta.Elements[i] + if !(ae.Time.Equal(be.Time) && ae.Status == be.Status && ae.InfinityModifier == be.InfinityModifier) { + return false + } + } + + return true + }) +} + +func TestTimestamptzArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.TimestamptzArray + }{ + { + source: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + result: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]time.Time)(nil)), + result: pgtype.TimestamptzArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.TimestamptzArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestTimestamptzArrayAssignTo(t *testing.T) { + var timeSlice []time.Time + + simpleTests := []struct { + src pgtype.TimestamptzArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{{Time: 
time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &timeSlice, + expected: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + }, + { + src: pgtype.TimestamptzArray{Status: pgtype.Null}, + dst: &timeSlice, + expected: (([]time.Time)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.TimestamptzArray + dst interface{} + }{ + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &timeSlice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/timestamptz_test.go b/pgtype/timestamptz_test.go new file mode 100644 index 00000000..c326802d --- /dev/null +++ b/pgtype/timestamptz_test.go @@ -0,0 +1,122 @@ +package pgtype_test + +import ( + "reflect" + "testing" + "time" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestTimestamptzTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "timestamptz", []interface{}{ + &pgtype.Timestamptz{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, 
time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Status: pgtype.Null}, + &pgtype.Timestamptz{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, + &pgtype.Timestamptz{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + }, func(a, b interface{}) bool { + at := a.(pgtype.Timestamptz) + bt := b.(pgtype.Timestamptz) + + return at.Time.Equal(bt.Time) && at.Status == bt.Status && at.InfinityModifier == bt.InfinityModifier + }) +} + +func TestTimestamptzSet(t *testing.T) { + type _time time.Time + + successfulTests := []struct { + source interface{} + result pgtype.Timestamptz + }{ + {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 1, 0, 0, 1, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 
1, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local)), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Timestamptz + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestTimestamptzAssignTo(t *testing.T) { + var tim time.Time + var ptim *time.Time + + simpleTests := []struct { + src pgtype.Timestamptz + dst interface{} + expected interface{} + }{ + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + {src: pgtype.Timestamptz{Time: time.Time{}, Status: pgtype.Null}, dst: &ptim, expected: ((*time.Time)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Timestamptz + dst interface{} + expected interface{} + }{ + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, 
tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Timestamptz + dst interface{} + }{ + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, dst: &tim}, + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, dst: &tim}, + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Null}, dst: &tim}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/tsrange.go b/pgtype/tsrange.go new file mode 100644 index 00000000..8a67d65e --- /dev/null +++ b/pgtype/tsrange.go @@ -0,0 +1,250 @@ +package pgtype + +import ( + "database/sql/driver" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type Tsrange struct { + Lower Timestamp + Upper Timestamp + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *Tsrange) Set(src interface{}) error { + return errors.Errorf("cannot convert %v to Tsrange", src) +} + +func (dst *Tsrange) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Tsrange) AssignTo(dst interface{}) error { + return errors.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Tsrange) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Tsrange{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Tsrange{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); 
err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Tsrange) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Tsrange{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Tsrange{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Tsrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + buf = append(buf, '(') + case Inclusive: + buf = append(buf, '[') + case Empty: + return append(buf, "empty"...), nil + default: + return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) + } + + var err error + + if src.LowerType != Unbounded { + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + buf = append(buf, ',') + + if src.UpperType != Unbounded { + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + buf = append(buf, ')') + case Inclusive: + buf = append(buf, ']') + default: + return nil, 
errors.Errorf("unknown upper bound type %v", src.UpperType) + } + + return buf, nil +} + +func (src Tsrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + return append(buf, emptyMask), nil + default: + return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) + } + + buf = append(buf, rangeType) + + var err error + + if src.LowerType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + if src.UpperType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. 
+func (dst *Tsrange) Scan(src interface{}) error { + if src == nil { + *dst = Tsrange{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Tsrange) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/tsrange_test.go b/pgtype/tsrange_test.go new file mode 100644 index 00000000..78eb1cd3 --- /dev/null +++ b/pgtype/tsrange_test.go @@ -0,0 +1,41 @@ +package pgtype_test + +import ( + "testing" + "time" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestTsrangeTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "tsrange", []interface{}{ + &pgtype.Tsrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + &pgtype.Tsrange{ + Lower: pgtype.Timestamp{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Timestamp{Time: time.Date(2028, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + &pgtype.Tsrange{ + Lower: pgtype.Timestamp{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + &pgtype.Tsrange{Status: pgtype.Null}, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.Tsrange) + b := bb.(pgtype.Tsrange) + + return a.Status == b.Status && + a.Lower.Time.Equal(b.Lower.Time) && + a.Lower.Status == b.Lower.Status && + a.Lower.InfinityModifier == b.Lower.InfinityModifier && + 
a.Upper.Time.Equal(b.Upper.Time) && + a.Upper.Status == b.Upper.Status && + a.Upper.InfinityModifier == b.Upper.InfinityModifier + }) +} diff --git a/pgtype/tstzrange.go b/pgtype/tstzrange.go new file mode 100644 index 00000000..b5129093 --- /dev/null +++ b/pgtype/tstzrange.go @@ -0,0 +1,250 @@ +package pgtype + +import ( + "database/sql/driver" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type Tstzrange struct { + Lower Timestamptz + Upper Timestamptz + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *Tstzrange) Set(src interface{}) error { + return errors.Errorf("cannot convert %v to Tstzrange", src) +} + +func (dst *Tstzrange) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Tstzrange) AssignTo(dst interface{}) error { + return errors.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Tstzrange) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Tstzrange{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Tstzrange{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Tstzrange) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Tstzrange{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Tstzrange{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == 
Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Tstzrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + buf = append(buf, '(') + case Inclusive: + buf = append(buf, '[') + case Empty: + return append(buf, "empty"...), nil + default: + return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) + } + + var err error + + if src.LowerType != Unbounded { + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + buf = append(buf, ',') + + if src.UpperType != Unbounded { + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + buf = append(buf, ')') + case Inclusive: + buf = append(buf, ']') + default: + return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) + } + + return buf, nil +} + +func (src Tstzrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + return append(buf, emptyMask), nil + default: + return nil, errors.Errorf("unknown LowerType: %v", 
src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) + } + + buf = append(buf, rangeType) + + var err error + + if src.LowerType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + if src.UpperType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Tstzrange) Scan(src interface{}) error { + if src == nil { + *dst = Tstzrange{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src Tstzrange) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/tstzrange_test.go b/pgtype/tstzrange_test.go new file mode 100644 index 00000000..a27ddd3a --- /dev/null +++ b/pgtype/tstzrange_test.go @@ -0,0 +1,41 @@ +package pgtype_test + +import ( + "testing" + "time" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestTstzrangeTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "tstzrange", []interface{}{ + &pgtype.Tstzrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + &pgtype.Tstzrange{ + Lower: pgtype.Timestamptz{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Timestamptz{Time: time.Date(2028, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + &pgtype.Tstzrange{ + Lower: pgtype.Timestamptz{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + &pgtype.Tstzrange{Status: pgtype.Null}, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.Tstzrange) + b := bb.(pgtype.Tstzrange) + + return a.Status == b.Status && + a.Lower.Time.Equal(b.Lower.Time) && + a.Lower.Status == b.Lower.Status && + a.Lower.InfinityModifier == b.Lower.InfinityModifier && + a.Upper.Time.Equal(b.Upper.Time) && + a.Upper.Status == b.Upper.Status && + a.Upper.InfinityModifier == b.Upper.InfinityModifier + }) +} diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb new file mode 100644 index 00000000..7a69d0ab --- /dev/null +++ b/pgtype/typed_array.go.erb @@ -0,0 +1,298 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type <%= pgtype_array_type 
%> struct { + Elements []<%= pgtype_element_type %> + Dimensions []ArrayDimension + Status Status +} + +func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { + switch value := src.(type) { + <% go_array_types.split(",").each do |t| %> + case <%= t %>: + if value == nil { + *dst = <%= pgtype_array_type %>{Status: Null} + } else if len(value) == 0 { + *dst = <%= pgtype_array_type %>{Status: Present} + } else { + elements := make([]<%= pgtype_element_type %>, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = <%= pgtype_array_type %>{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + <% end %> + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to <%= pgtype_element_type %>", value) + } + + return nil +} + +func (dst *<%= pgtype_array_type %>) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + <% go_array_types.split(",").each do |t| %> + case *<%= t %>: + *v = make(<%= t %>, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + <% end %> + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = <%= pgtype_array_type %>{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var 
elements []<%= pgtype_element_type %> + + if len(uta.Elements) > 0 { + elements = make([]<%= pgtype_element_type %>, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem <%= pgtype_element_type %> + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = <%= pgtype_array_type %>{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +<% if binary_format == "true" %> +func (dst *<%= pgtype_array_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = <%= pgtype_array_type %>{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = <%= pgtype_array_type %>{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]<%= pgtype_element_type %>, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp:rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = <%= pgtype_array_type %>{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} +<% end %> + +func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. 
For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `<%= text_null %>`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +<% if binary_format == "true" %> + func (src *<%= pgtype_array_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("<%= element_type_name %>"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, errors.Errorf("unable to find oid for type name %v", "<%= element_type_name %>") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + 
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil + } +<% end %> + +// Scan implements the database/sql Scanner interface. +func (dst *<%= pgtype_array_type %>) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *<%= pgtype_array_type %>) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh new file mode 100644 index 00000000..1aa6c354 --- /dev/null +++ b/pgtype/typed_array_gen.sh @@ -0,0 +1,18 @@ +erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_type_name=int2 text_null=NULL binary_format=true typed_array.go.erb > int2_array.go +erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_type_name=int4 text_null=NULL binary_format=true typed_array.go.erb > int4_array.go +erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_type_name=int8 text_null=NULL binary_format=true typed_array.go.erb > int8_array.go +erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_type_name=bool text_null=NULL binary_format=true typed_array.go.erb > bool_array.go +erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_type_name=date text_null=NULL binary_format=true typed_array.go.erb > date_array.go +erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time 
element_type_name=timestamptz text_null=NULL binary_format=true typed_array.go.erb > timestamptz_array.go +erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_type_name=timestamp text_null=NULL binary_format=true typed_array.go.erb > timestamp_array.go +erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_type_name=float4 text_null=NULL binary_format=true typed_array.go.erb > float4_array.go +erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_type_name=float8 text_null=NULL binary_format=true typed_array.go.erb > float8_array.go +erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_type_name=inet text_null=NULL binary_format=true typed_array.go.erb > inet_array.go +erb pgtype_array_type=CIDRArray pgtype_element_type=CIDR go_array_types=[]*net.IPNet,[]net.IP element_type_name=cidr text_null=NULL binary_format=true typed_array.go.erb > cidr_array.go +erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_type_name=text text_null='"NULL"' binary_format=true typed_array.go.erb > text_array.go +erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string element_type_name=varchar text_null='"NULL"' binary_format=true typed_array.go.erb > varchar_array.go +erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go +erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go +erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go +erb 
pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]float64 element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go +goimports -w *_array.go diff --git a/pgtype/typed_range.go.erb b/pgtype/typed_range.go.erb new file mode 100644 index 00000000..91a5cb97 --- /dev/null +++ b/pgtype/typed_range.go.erb @@ -0,0 +1,252 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type <%= range_type %> struct { + Lower <%= element_type %> + Upper <%= element_type %> + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *<%= range_type %>) Set(src interface{}) error { + return errors.Errorf("cannot convert %v to <%= range_type %>", src) +} + +func (dst *<%= range_type %>) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *<%= range_type %>) AssignTo(dst interface{}) error { + return errors.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *<%= range_type %>) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = <%= range_type %>{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = <%= range_type %>{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *<%= range_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = <%= range_type %>{Status: Null} + return nil + } + + ubr, err := 
ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = <%= range_type %>{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src <%= range_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + buf = append(buf, '(') + case Inclusive: + buf = append(buf, '[') + case Empty: + return append(buf, "empty"...), nil + default: + return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) + } + + var err error + + if src.LowerType != Unbounded { + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + buf = append(buf, ',') + + if src.UpperType != Unbounded { + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + buf = append(buf, ')') + case Inclusive: + buf = append(buf, ']') + default: + return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) + } + + return buf, nil +} + +func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + 
rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + return append(buf, emptyMask), nil + default: + return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) + } + + buf = append(buf, rangeType) + + var err error + + if src.LowerType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + if src.UpperType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *<%= range_type %>) Scan(src interface{}) error { + if src == nil { + *dst = <%= range_type %>{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src <%= range_type %>) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/typed_range_gen.sh b/pgtype/typed_range_gen.sh new file mode 100644 index 00000000..bedda292 --- /dev/null +++ b/pgtype/typed_range_gen.sh @@ -0,0 +1,7 @@ +erb range_type=Int4range element_type=Int4 typed_range.go.erb > int4range.go +erb range_type=Int8range element_type=Int8 typed_range.go.erb > int8range.go +erb range_type=Tsrange element_type=Timestamp typed_range.go.erb > tsrange.go +erb range_type=Tstzrange element_type=Timestamptz typed_range.go.erb > tstzrange.go +erb range_type=Daterange element_type=Date typed_range.go.erb > daterange.go +erb range_type=Numrange element_type=Numeric typed_range.go.erb > numrange.go +goimports -w *range.go diff --git a/pgtype/unknown.go b/pgtype/unknown.go new file mode 100644 index 00000000..567831d7 --- /dev/null +++ b/pgtype/unknown.go @@ -0,0 +1,44 @@ +package pgtype + +import "database/sql/driver" + +// Unknown represents the PostgreSQL unknown type. It is either a string literal +// or NULL. It is used when PostgreSQL does not know the type of a value. In +// general, this will only be used in pgx when selecting a null value without +// type information. e.g. SELECT NULL; +type Unknown struct { + String string + Status Status +} + +func (dst *Unknown) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +func (dst *Unknown) Get() interface{} { + return (*Text)(dst).Get() +} + +// AssignTo assigns from src to dst. Note that as Unknown is not a general number +// type AssignTo does not do automatic type conversion as other number types do. 
+func (src *Unknown) AssignTo(dst interface{}) error { + return (*Text)(src).AssignTo(dst) +} + +func (dst *Unknown) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) +} + +func (dst *Unknown) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeBinary(ci, src) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Unknown) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Unknown) Value() (driver.Value, error) { + return (*Text)(src).Value() +} diff --git a/pgtype/uuid.go b/pgtype/uuid.go new file mode 100644 index 00000000..33e79536 --- /dev/null +++ b/pgtype/uuid.go @@ -0,0 +1,174 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/hex" + "fmt" + + "github.com/pkg/errors" +) + +type UUID struct { + Bytes [16]byte + Status Status +} + +func (dst *UUID) Set(src interface{}) error { + switch value := src.(type) { + case [16]byte: + *dst = UUID{Bytes: value, Status: Present} + case []byte: + if len(value) != 16 { + return errors.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) + } + *dst = UUID{Status: Present} + copy(dst.Bytes[:], value) + case string: + uuid, err := parseUUID(value) + if err != nil { + return err + } + *dst = UUID{Bytes: uuid, Status: Present} + default: + if originalSrc, ok := underlyingPtrType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to UUID", value) + } + + return nil +} + +func (dst *UUID) Get() interface{} { + switch dst.Status { + case Present: + return dst.Bytes + case Null: + return nil + default: + return dst.Status + } +} + +func (src *UUID) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *[16]byte: + *v = src.Bytes + return nil + case *[]byte: + *v = make([]byte, 16) + copy(*v, src.Bytes[:]) + return nil + case *string: + *v = 
encodeUUID(src.Bytes) + return nil + default: + if nextDst, retry := GetAssignToDstType(v); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot assign %v into %T", src, dst) +} + +// parseUUID converts a string UUID in standard form to a byte array. +func parseUUID(src string) (dst [16]byte, err error) { + src = src[0:8] + src[9:13] + src[14:18] + src[19:23] + src[24:] + buf, err := hex.DecodeString(src) + if err != nil { + return dst, err + } + + copy(dst[:], buf) + return dst, err +} + +// encodeUUID converts a uuid byte array to UUID standard string form. +func encodeUUID(src [16]byte) string { + return fmt.Sprintf("%x-%x-%x-%x-%x", src[0:4], src[4:6], src[6:8], src[8:10], src[10:16]) +} + +func (dst *UUID) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = UUID{Status: Null} + return nil + } + + if len(src) != 36 { + return errors.Errorf("invalid length for UUID: %v", len(src)) + } + + buf, err := parseUUID(string(src)) + if err != nil { + return err + } + + *dst = UUID{Bytes: buf, Status: Present} + return nil +} + +func (dst *UUID) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = UUID{Status: Null} + return nil + } + + if len(src) != 16 { + return errors.Errorf("invalid length for UUID: %v", len(src)) + } + + *dst = UUID{Status: Present} + copy(dst.Bytes[:], src) + return nil +} + +func (src *UUID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, encodeUUID(src.Bytes)...), nil +} + +func (src *UUID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, src.Bytes[:]...), nil +} + +// Scan implements the database/sql Scanner interface. 
+func (dst *UUID) Scan(src interface{}) error { + if src == nil { + *dst = UUID{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *UUID) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/uuid_test.go b/pgtype/uuid_test.go new file mode 100644 index 00000000..5ab52b35 --- /dev/null +++ b/pgtype/uuid_test.go @@ -0,0 +1,96 @@ +package pgtype_test + +import ( + "bytes" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestUUIDTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ + &pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + &pgtype.UUID{Status: pgtype.Null}, + }) +} + +func TestUUIDSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.UUID + }{ + { + source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + { + source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + { + source: "00010203-0405-0607-0809-0a0b0c0d0e0f", + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.UUID + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, 
tt.result, r) + } + } +} + +func TestUUIDAssignTo(t *testing.T) { + { + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst [16]byte + expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst []byte + expected := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if bytes.Compare(dst, expected) != 0 { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst string + expected := "00010203-0405-0607-0809-0a0b0c0d0e0f" + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + +} diff --git a/pgtype/varbit.go b/pgtype/varbit.go new file mode 100644 index 00000000..dfa194d2 --- /dev/null +++ b/pgtype/varbit.go @@ -0,0 +1,133 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type Varbit struct { + Bytes []byte + Len int32 // Number of bits + Status Status +} + +func (dst *Varbit) Set(src interface{}) error { + return errors.Errorf("cannot convert %v to Varbit", src) +} + +func (dst *Varbit) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Varbit) AssignTo(dst interface{}) error { + return errors.Errorf("cannot assign %v to %T", src, 
dst) +} + +func (dst *Varbit) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Varbit{Status: Null} + return nil + } + + bitLen := len(src) + byteLen := bitLen / 8 + if bitLen%8 > 0 { + byteLen++ + } + buf := make([]byte, byteLen) + + for i, b := range src { + if b == '1' { + byteIdx := i / 8 + bitIdx := uint(i % 8) + buf[byteIdx] = buf[byteIdx] | (128 >> bitIdx) + } + } + + *dst = Varbit{Bytes: buf, Len: int32(bitLen), Status: Present} + return nil +} + +func (dst *Varbit) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Varbit{Status: Null} + return nil + } + + if len(src) < 4 { + return errors.Errorf("invalid length for varbit: %v", len(src)) + } + + bitLen := int32(binary.BigEndian.Uint32(src)) + rp := 4 + + *dst = Varbit{Bytes: src[rp:], Len: bitLen, Status: Present} + return nil +} + +func (src *Varbit) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + for i := int32(0); i < src.Len; i++ { + byteIdx := i / 8 + bitMask := byte(128 >> byte(i%8)) + char := byte('0') + if src.Bytes[byteIdx]&bitMask > 0 { + char = '1' + } + buf = append(buf, char) + } + + return buf, nil +} + +func (src *Varbit) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = pgio.AppendInt32(buf, src.Len) + return append(buf, src.Bytes...), nil +} + +// Scan implements the database/sql Scanner interface. 
+func (dst *Varbit) Scan(src interface{}) error { + if src == nil { + *dst = Varbit{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Varbit) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/pgtype/varbit_test.go b/pgtype/varbit_test.go new file mode 100644 index 00000000..6c813aae --- /dev/null +++ b/pgtype/varbit_test.go @@ -0,0 +1,26 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestVarbitTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "varbit", []interface{}{ + &pgtype.Varbit{Bytes: []byte{}, Len: 0, Status: pgtype.Present}, + &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Status: pgtype.Present}, + &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 128}, Len: 33, Status: pgtype.Present}, + &pgtype.Varbit{Status: pgtype.Null}, + }) +} + +func TestVarbitNormalize(t *testing.T) { + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + { + SQL: "select B'111111111'", + Value: &pgtype.Varbit{Bytes: []byte{255, 128}, Len: 9, Status: pgtype.Present}, + }, + }) +} diff --git a/pgtype/varchar.go b/pgtype/varchar.go new file mode 100644 index 00000000..371efd7e --- /dev/null +++ b/pgtype/varchar.go @@ -0,0 +1,54 @@ +package pgtype + +import ( + "database/sql/driver" +) + +type Varchar Text + +// Set converts from src to dst. Note that as Varchar is not a general +// number type Set does not do automatic type conversion as other number +// types do. 
+func (dst *Varchar) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +func (dst *Varchar) Get() interface{} { + return (*Text)(dst).Get() +} + +// AssignTo assigns from src to dst. Note that as Varchar is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *Varchar) AssignTo(dst interface{}) error { + return (*Text)(src).AssignTo(dst) +} + +func (dst *Varchar) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) +} + +func (dst *Varchar) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeBinary(ci, src) +} + +func (src *Varchar) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Text)(src).EncodeText(ci, buf) +} + +func (src *Varchar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Text)(src).EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Varchar) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. 
+func (src *Varchar) Value() (driver.Value, error) { + return (*Text)(src).Value() +} + +func (src *Varchar) MarshalJSON() ([]byte, error) { + return (*Text)(src).MarshalJSON() +} diff --git a/pgtype/varchar_array.go b/pgtype/varchar_array.go new file mode 100644 index 00000000..fecbb2e5 --- /dev/null +++ b/pgtype/varchar_array.go @@ -0,0 +1,294 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type VarcharArray struct { + Elements []Varchar + Dimensions []ArrayDimension + Status Status +} + +func (dst *VarcharArray) Set(src interface{}) error { + switch value := src.(type) { + + case []string: + if value == nil { + *dst = VarcharArray{Status: Null} + } else if len(value) == 0 { + *dst = VarcharArray{Status: Present} + } else { + elements := make([]Varchar, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = VarcharArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to Varchar", value) + } + + return nil +} + +func (dst *VarcharArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *VarcharArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) 
+} + +func (dst *VarcharArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = VarcharArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Varchar + + if len(uta.Elements) > 0 { + elements = make([]Varchar, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Varchar + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = VarcharArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *VarcharArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = VarcharArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = VarcharArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Varchar, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = VarcharArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *VarcharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of 
elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `"NULL"`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src *VarcharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("varchar"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, errors.Errorf("unable to find oid for type name %v", "varchar") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + 
return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *VarcharArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *VarcharArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/varchar_array_test.go b/pgtype/varchar_array_test.go new file mode 100644 index 00000000..7d6fb39b --- /dev/null +++ b/pgtype/varchar_array_test.go @@ -0,0 +1,152 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestVarcharArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "varchar[]", []interface{}{ + &pgtype.VarcharArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + pgtype.Varchar{String: "foo", Status: pgtype.Present}, + pgtype.Varchar{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.VarcharArray{Status: pgtype.Null}, + &pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + pgtype.Varchar{String: "bar ", Status: pgtype.Present}, + pgtype.Varchar{String: "NuLL", Status: pgtype.Present}, + pgtype.Varchar{String: `wow"quz\`, Status: pgtype.Present}, + pgtype.Varchar{String: "", Status: pgtype.Present}, + pgtype.Varchar{Status: pgtype.Null}, + pgtype.Varchar{String: "null", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, 
{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + pgtype.Varchar{String: "bar", Status: pgtype.Present}, + pgtype.Varchar{String: "baz", Status: pgtype.Present}, + pgtype.Varchar{String: "quz", Status: pgtype.Present}, + pgtype.Varchar{String: "foo", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestVarcharArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.VarcharArray + }{ + { + source: []string{"foo"}, + result: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]string)(nil)), + result: pgtype.VarcharArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.VarcharArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestVarcharArrayAssignTo(t *testing.T) { + var stringSlice []string + type _stringSlice []string + var namedStringSlice _stringSlice + + simpleTests := []struct { + src pgtype.VarcharArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + expected: []string{"foo"}, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedStringSlice, + expected: _stringSlice{"bar"}, + }, 
+ { + src: pgtype.VarcharArray{Status: pgtype.Null}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.VarcharArray + dst interface{} + }{ + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/xid.go b/pgtype/xid.go new file mode 100644 index 00000000..f66f5367 --- /dev/null +++ b/pgtype/xid.go @@ -0,0 +1,64 @@ +package pgtype + +import ( + "database/sql/driver" +) + +// XID is PostgreSQL's Transaction ID type. +// +// In later versions of PostgreSQL, it is the type used for the backend_xid +// and backend_xmin columns of the pg_stat_activity system view. +// +// Also, when one does +// +// select xmin, xmax, * from some_table; +// +// it is the data type of the xmin and xmax hidden system columns. +// +// It is currently implemented as an unsigned four byte integer. +// Its definition can be found in src/include/postgres_ext.h as TransactionId +// in the PostgreSQL sources. +type XID pguint32 + +// Set converts from src to dst. Note that as XID is not a general +// number type Set does not do automatic type conversion as other number +// types do. +func (dst *XID) Set(src interface{}) error { + return (*pguint32)(dst).Set(src) +} + +func (dst *XID) Get() interface{} { + return (*pguint32)(dst).Get() +} + +// AssignTo assigns from src to dst. 
Note that as XID is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *XID) AssignTo(dst interface{}) error { + return (*pguint32)(src).AssignTo(dst) +} + +func (dst *XID) DecodeText(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeText(ci, src) +} + +func (dst *XID) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeBinary(ci, src) +} + +func (src *XID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*pguint32)(src).EncodeText(ci, buf) +} + +func (src *XID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*pguint32)(src).EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *XID) Scan(src interface{}) error { + return (*pguint32)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *XID) Value() (driver.Value, error) { + return (*pguint32)(src).Value() +} diff --git a/pgtype/xid_test.go b/pgtype/xid_test.go new file mode 100644 index 00000000..d0f3f0ab --- /dev/null +++ b/pgtype/xid_test.go @@ -0,0 +1,108 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestXIDTranscode(t *testing.T) { + pgTypeName := "xid" + values := []interface{}{ + &pgtype.XID{Uint: 42, Status: pgtype.Present}, + &pgtype.XID{Status: pgtype.Null}, + } + eqFunc := func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + } + + testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + + // No direct conversion from int to xid, convert through text + testutil.TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc) + + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) + } +} + +func TestXIDSet(t 
*testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.XID + }{ + {source: uint32(1), result: pgtype.XID{Uint: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.XID + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestXIDAssignTo(t *testing.T) { + var ui32 uint32 + var pui32 *uint32 + + simpleTests := []struct { + src pgtype.XID + dst interface{} + expected interface{} + }{ + {src: pgtype.XID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.XID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.XID + dst interface{} + expected interface{} + }{ + {src: pgtype.XID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.XID + dst interface{} + }{ + {src: pgtype.XID{Status: pgtype.Null}, dst: &ui32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/private_test.go b/private_test.go new file mode 100644 index 00000000..df732a72 --- 
/dev/null +++ b/private_test.go @@ -0,0 +1,7 @@ +package pgx + +// This file contains methods that expose internal pgx state to tests. + +func (c *Conn) TxStatus() byte { + return c.txStatus +} diff --git a/query.go b/query.go index 19b867e2..811e95b1 100644 --- a/query.go +++ b/query.go @@ -1,10 +1,16 @@ package pgx import ( + "context" "database/sql" - "errors" "fmt" "time" + + "github.com/pkg/errors" + + "github.com/jackc/pgx/internal/sanitize" + "github.com/jackc/pgx/pgproto3" + "github.com/jackc/pgx/pgtype" ) // Row is a convenience wrapper over Rows that is returned by QueryRow. @@ -37,16 +43,16 @@ func (r *Row) Scan(dest ...interface{}) (err error) { // calling Next() until it returns false, or when a fatal error occurs. type Rows struct { conn *Conn - mr *msgReader + connPool *ConnPool + batch *Batch + values [][]byte fields []FieldDescription - vr ValueReader rowCount int columnIdx int err error startTime time.Time sql string args []interface{} - afterClose func(*Rows) unlockConn bool closed bool } @@ -55,7 +61,9 @@ func (rows *Rows) FieldDescriptions() []FieldDescription { return rows.fields } -func (rows *Rows) close() { +// Close closes the rows, making the connection ready for use again. It is safe +// to call Close after rows is already closed. 
+func (rows *Rows) Close() { if rows.closed { return } @@ -67,80 +75,33 @@ func (rows *Rows) close() { rows.closed = true + rows.err = rows.conn.termContext(rows.err) + if rows.err == nil { if rows.conn.shouldLog(LogLevelInfo) { endTime := time.Now() - rows.conn.log(LogLevelInfo, "Query", "sql", rows.sql, "args", logQueryArgs(rows.args), "time", endTime.Sub(rows.startTime), "rowCount", rows.rowCount) + rows.conn.log(LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount}) } } else if rows.conn.shouldLog(LogLevelError) { - rows.conn.log(LogLevelError, "Query", "sql", rows.sql, "args", logQueryArgs(rows.args)) + rows.conn.log(LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args)}) } - if rows.afterClose != nil { - rows.afterClose(rows) + if rows.batch != nil && rows.err != nil { + rows.batch.die(rows.err) } -} -func (rows *Rows) readUntilReadyForQuery() { - for { - t, r, err := rows.conn.rxMsg() - if err != nil { - rows.close() - return - } - - switch t { - case readyForQuery: - rows.conn.rxReadyForQuery(r) - rows.close() - return - case rowDescription: - case dataRow: - case commandComplete: - case bindComplete: - case errorResponse: - err = rows.conn.rxErrorResponse(r) - if rows.err == nil { - rows.err = err - } - default: - err = rows.conn.processContextFreeMsg(t, r) - if err != nil { - rows.close() - return - } - } + if rows.connPool != nil { + rows.connPool.Release(rows.conn) } } -// Close closes the rows, making the connection ready for use again. It is safe -// to call Close after rows is already closed. -func (rows *Rows) Close() { - if rows.closed { - return - } - rows.readUntilReadyForQuery() - rows.close() -} - func (rows *Rows) Err() error { return rows.err } -// abort signals that the query was not successfully sent to the server. 
-// This differs from Fatal in that it is not necessary to readUntilReadyForQuery -func (rows *Rows) abort(err error) { - if rows.err != nil { - return - } - - rows.err = err - rows.close() -} - -// Fatal signals an error occurred after the query was sent to the server. It +// fatal signals an error occurred after the query was sent to the server. It // closes the rows automatically. -func (rows *Rows) Fatal(err error) { +func (rows *Rows) fatal(err error) { if rows.err != nil { return } @@ -159,64 +120,61 @@ func (rows *Rows) Next() bool { rows.rowCount++ rows.columnIdx = 0 - rows.vr = ValueReader{} for { - t, r, err := rows.conn.rxMsg() + msg, err := rows.conn.rxMsg() if err != nil { - rows.Fatal(err) + rows.fatal(err) return false } - switch t { - case readyForQuery: - rows.conn.rxReadyForQuery(r) - rows.close() - return false - case dataRow: - fieldCount := r.readInt16() - if int(fieldCount) != len(rows.fields) { - rows.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), fieldCount))) + switch msg := msg.(type) { + case *pgproto3.RowDescription: + rows.fields = rows.conn.rxRowDescription(msg) + for i := range rows.fields { + if dt, ok := rows.conn.ConnInfo.DataTypeForOID(rows.fields[i].DataType); ok { + rows.fields[i].DataTypeName = dt.Name + rows.fields[i].FormatCode = TextFormatCode + } else { + rows.fatal(errors.Errorf("unknown oid: %d", rows.fields[i].DataType)) + return false + } + } + case *pgproto3.DataRow: + if len(msg.Values) != len(rows.fields) { + rows.fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), len(msg.Values)))) return false } - rows.mr = r + rows.values = msg.Values return true - case commandComplete: - case bindComplete: + case *pgproto3.CommandComplete: + rows.Close() + return false + default: - err = rows.conn.processContextFreeMsg(t, r) + err = rows.conn.processContextFreeMsg(msg) if 
err != nil { - rows.Fatal(err) + rows.fatal(err) return false } } } } -// Conn returns the *Conn this *Rows is using. -func (rows *Rows) Conn() *Conn { - return rows.conn -} - -func (rows *Rows) nextColumn() (*ValueReader, bool) { +func (rows *Rows) nextColumn() ([]byte, *FieldDescription, bool) { if rows.closed { - return nil, false + return nil, nil, false } if len(rows.fields) <= rows.columnIdx { - rows.Fatal(ProtocolError("No next column available")) - return nil, false - } - - if rows.vr.Len() > 0 { - rows.mr.readBytes(rows.vr.Len()) + rows.fatal(ProtocolError("No next column available")) + return nil, nil, false } + buf := rows.values[rows.columnIdx] fd := &rows.fields[rows.columnIdx] rows.columnIdx++ - size := rows.mr.readInt32() - rows.vr = ValueReader{mr: rows.mr, fd: fd, valueBytesRemaining: size} - return &rows.vr, true + return buf, fd, true } type scanArgError struct { @@ -234,94 +192,72 @@ func (e scanArgError) Error() string { // copy the raw bytes received from PostgreSQL. nil will skip the value entirely. func (rows *Rows) Scan(dest ...interface{}) (err error) { if len(rows.fields) != len(dest) { - err = fmt.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.fields)) - rows.Fatal(err) + err = errors.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.fields)) + rows.fatal(err) return err } for i, d := range dest { - vr, _ := rows.nextColumn() + buf, fd, _ := rows.nextColumn() if d == nil { continue } - // Check for []byte first as we allow sidestepping the decoding process and retrieving the raw bytes - if b, ok := d.(*[]byte); ok { - // If it actually is a bytea then pass it through decodeBytea (so it can be decoded if it is in text format) - // Otherwise read the bytes directly regardless of what the actual type is. 
- if vr.Type().DataType == ByteaOid { - *b = decodeBytea(vr) - } else { - if vr.Len() != -1 { - *b = vr.ReadBytes(vr.Len()) - } else { - *b = nil - } - } - } else if s, ok := d.(Scanner); ok { - err = s.Scan(vr) + if s, ok := d.(pgtype.BinaryDecoder); ok && fd.FormatCode == BinaryFormatCode { + err = s.DecodeBinary(rows.conn.ConnInfo, buf) if err != nil { - rows.Fatal(scanArgError{col: i, err: err}) + rows.fatal(scanArgError{col: i, err: err}) } - } else if s, ok := d.(PgxScanner); ok { - err = s.ScanPgx(vr) + } else if s, ok := d.(pgtype.TextDecoder); ok && fd.FormatCode == TextFormatCode { + err = s.DecodeText(rows.conn.ConnInfo, buf) if err != nil { - rows.Fatal(scanArgError{col: i, err: err}) + rows.fatal(scanArgError{col: i, err: err}) } - } else if s, ok := d.(sql.Scanner); ok { - var val interface{} - if 0 <= vr.Len() { - switch vr.Type().DataType { - case BoolOid: - val = decodeBool(vr) - case Int8Oid: - val = int64(decodeInt8(vr)) - case Int2Oid: - val = int64(decodeInt2(vr)) - case Int4Oid: - val = int64(decodeInt4(vr)) - case TextOid, VarcharOid: - val = decodeText(vr) - case OidOid: - val = int64(decodeOid(vr)) - case Float4Oid: - val = float64(decodeFloat4(vr)) - case Float8Oid: - val = decodeFloat8(vr) - case DateOid: - val = decodeDate(vr) - case TimestampOid: - val = decodeTimestamp(vr) - case TimestampTzOid: - val = decodeTimestampTz(vr) - default: - val = vr.ReadBytes(vr.Len()) - } - } - err = s.Scan(val) - if err != nil { - rows.Fatal(scanArgError{col: i, err: err}) - } - } else if vr.Type().DataType == JsonOid { - // Because the argument passed to decodeJSON will escape the heap. - // This allows d to be stack allocated and only copied to the heap when - // we actually are decoding JSON. This saves one memory allocation per - // row. 
- d2 := d - decodeJSON(vr, &d2) - } else if vr.Type().DataType == JsonbOid { - // Same trick as above for getting stack allocation - d2 := d - decodeJSONB(vr, &d2) } else { - if err := Decode(vr, d); err != nil { - rows.Fatal(scanArgError{col: i, err: err}) + if dt, ok := rows.conn.ConnInfo.DataTypeForOID(fd.DataType); ok { + value := dt.Value + switch fd.FormatCode { + case TextFormatCode: + if textDecoder, ok := value.(pgtype.TextDecoder); ok { + err = textDecoder.DecodeText(rows.conn.ConnInfo, buf) + if err != nil { + rows.fatal(scanArgError{col: i, err: err}) + } + } else { + rows.fatal(scanArgError{col: i, err: errors.Errorf("%T is not a pgtype.TextDecoder", value)}) + } + case BinaryFormatCode: + if binaryDecoder, ok := value.(pgtype.BinaryDecoder); ok { + err = binaryDecoder.DecodeBinary(rows.conn.ConnInfo, buf) + if err != nil { + rows.fatal(scanArgError{col: i, err: err}) + } + } else { + rows.fatal(scanArgError{col: i, err: errors.Errorf("%T is not a pgtype.BinaryDecoder", value)}) + } + default: + rows.fatal(scanArgError{col: i, err: errors.Errorf("unknown format code: %v", fd.FormatCode)}) + } + + if rows.Err() == nil { + if scanner, ok := d.(sql.Scanner); ok { + sqlSrc, err := pgtype.DatabaseSQLValue(rows.conn.ConnInfo, value) + if err != nil { + rows.fatal(err) + } + err = scanner.Scan(sqlSrc) + if err != nil { + rows.fatal(scanArgError{col: i, err: err}) + } + } else if err := value.AssignTo(d); err != nil { + rows.fatal(scanArgError{col: i, err: err}) + } + } + } else { + rows.fatal(scanArgError{col: i, err: errors.Errorf("unknown oid: %v", fd.DataType)}) } } - if vr.Err() != nil { - rows.Fatal(scanArgError{col: i, err: vr.Err()}) - } if rows.Err() != nil { return rows.Err() @@ -340,79 +276,42 @@ func (rows *Rows) Values() ([]interface{}, error) { values := make([]interface{}, 0, len(rows.fields)) for range rows.fields { - vr, _ := rows.nextColumn() + buf, fd, _ := rows.nextColumn() - if vr.Len() == -1 { + if buf == nil { values = append(values, 
nil) continue } - switch vr.Type().FormatCode { - // All intrinsic types (except string) are encoded with binary - // encoding so anything else should be treated as a string - case TextFormatCode: - values = append(values, vr.ReadString(vr.Len())) - case BinaryFormatCode: - switch vr.Type().DataType { - case TextOid, VarcharOid: - values = append(values, decodeText(vr)) - case BoolOid: - values = append(values, decodeBool(vr)) - case ByteaOid: - values = append(values, decodeBytea(vr)) - case Int8Oid: - values = append(values, decodeInt8(vr)) - case Int2Oid: - values = append(values, decodeInt2(vr)) - case Int4Oid: - values = append(values, decodeInt4(vr)) - case OidOid: - values = append(values, decodeOid(vr)) - case Float4Oid: - values = append(values, decodeFloat4(vr)) - case Float8Oid: - values = append(values, decodeFloat8(vr)) - case BoolArrayOid: - values = append(values, decodeBoolArray(vr)) - case Int2ArrayOid: - values = append(values, decodeInt2Array(vr)) - case Int4ArrayOid: - values = append(values, decodeInt4Array(vr)) - case Int8ArrayOid: - values = append(values, decodeInt8Array(vr)) - case Float4ArrayOid: - values = append(values, decodeFloat4Array(vr)) - case Float8ArrayOid: - values = append(values, decodeFloat8Array(vr)) - case TextArrayOid, VarcharArrayOid: - values = append(values, decodeTextArray(vr)) - case TimestampArrayOid, TimestampTzArrayOid: - values = append(values, decodeTimestampArray(vr)) - case DateOid: - values = append(values, decodeDate(vr)) - case TimestampTzOid: - values = append(values, decodeTimestampTz(vr)) - case TimestampOid: - values = append(values, decodeTimestamp(vr)) - case InetOid, CidrOid: - values = append(values, decodeInet(vr)) - case JsonOid: - var d interface{} - decodeJSON(vr, &d) - values = append(values, d) - case JsonbOid: - var d interface{} - decodeJSONB(vr, &d) - values = append(values, d) - default: - rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types")) - } - default: - 
rows.Fatal(errors.New("Unknown format code")) - } + if dt, ok := rows.conn.ConnInfo.DataTypeForOID(fd.DataType); ok { + value := dt.Value - if vr.Err() != nil { - rows.Fatal(vr.Err()) + switch fd.FormatCode { + case TextFormatCode: + decoder := value.(pgtype.TextDecoder) + if decoder == nil { + decoder = &pgtype.GenericText{} + } + err := decoder.DecodeText(rows.conn.ConnInfo, buf) + if err != nil { + rows.fatal(err) + } + values = append(values, decoder.(pgtype.Value).Get()) + case BinaryFormatCode: + decoder := value.(pgtype.BinaryDecoder) + if decoder == nil { + decoder = &pgtype.GenericBinary{} + } + err := decoder.DecodeBinary(rows.conn.ConnInfo, buf) + if err != nil { + rows.fatal(err) + } + values = append(values, value.Get()) + default: + rows.fatal(errors.New("Unknown format code")) + } + } else { + rows.fatal(errors.New("Unknown type")) } if rows.Err() != nil { @@ -423,50 +322,11 @@ func (rows *Rows) Values() ([]interface{}, error) { return values, rows.Err() } -// AfterClose adds f to a LILO queue of functions that will be called when -// rows is closed. -func (rows *Rows) AfterClose(f func(*Rows)) { - if rows.afterClose == nil { - rows.afterClose = f - } else { - prevFn := rows.afterClose - rows.afterClose = func(rows *Rows) { - f(rows) - prevFn(rows) - } - } -} - // Query executes sql with args. If there is an error the returned *Rows will // be returned in an error state. So it is allowed to ignore the error returned // from Query and handle it in *Rows. func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) { - c.lastActivityTime = time.Now() - - rows := c.getRows(sql, args) - - if err := c.lock(); err != nil { - rows.abort(err) - return rows, err - } - rows.unlockConn = true - - ps, ok := c.preparedStatements[sql] - if !ok { - var err error - ps, err = c.Prepare("", sql) - if err != nil { - rows.abort(err) - return rows, rows.err - } - } - rows.sql = ps.SQL - rows.fields = ps.FieldDescriptions - err := c.sendPreparedQuery(ps, args...) 
- if err != nil { - rows.abort(err) - } - return rows, rows.err + return c.QueryEx(context.Background(), sql, nil, args...) } func (c *Conn) getRows(sql string, args []interface{}) *Rows { @@ -492,3 +352,190 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { rows, _ := c.Query(sql, args...) return (*Row)(rows) } + +type QueryExOptions struct { + // When ParameterOIDs are present and the query is not a prepared statement, + // then ParameterOIDs and ResultFormatCodes will be used to avoid an extra + // network round-trip. + ParameterOIDs []pgtype.OID + ResultFormatCodes []int16 + + SimpleProtocol bool +} + +func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (rows *Rows, err error) { + err = c.waitForPreviousCancelQuery(ctx) + if err != nil { + return nil, err + } + + if err := c.ensureConnectionReadyForQuery(); err != nil { + return nil, err + } + + c.lastActivityTime = time.Now() + + rows = c.getRows(sql, args) + + if err := c.lock(); err != nil { + rows.fatal(err) + return rows, err + } + rows.unlockConn = true + + err = c.initContext(ctx) + if err != nil { + rows.fatal(err) + return rows, rows.err + } + + if options != nil && options.SimpleProtocol { + err = c.sanitizeAndSendSimpleQuery(sql, args...) 
+ if err != nil { + rows.fatal(err) + return rows, err + } + + return rows, nil + } + + if options != nil && len(options.ParameterOIDs) > 0 { + + buf, err := c.buildOneRoundTripQueryEx(c.wbuf, sql, options, args) + if err != nil { + rows.fatal(err) + return rows, err + } + + buf = appendSync(buf) + + n, err := c.conn.Write(buf) + if err != nil && fatalWriteErr(n, err) { + rows.fatal(err) + c.die(err) + return nil, err + } + c.pendingReadyForQueryCount++ + + fieldDescriptions, err := c.readUntilRowDescription() + if err != nil { + rows.fatal(err) + return nil, err + } + + if len(options.ResultFormatCodes) == 0 { + for i := range fieldDescriptions { + fieldDescriptions[i].FormatCode = TextFormatCode + } + } else if len(options.ResultFormatCodes) == 1 { + fc := options.ResultFormatCodes[0] + for i := range fieldDescriptions { + fieldDescriptions[i].FormatCode = fc + } + } else { + for i := range options.ResultFormatCodes { + fieldDescriptions[i].FormatCode = options.ResultFormatCodes[i] + } + } + + rows.sql = sql + rows.fields = fieldDescriptions + return rows, nil + } + + ps, ok := c.preparedStatements[sql] + if !ok { + var err error + ps, err = c.prepareEx("", sql, nil) + if err != nil { + rows.fatal(err) + return rows, rows.err + } + } + rows.sql = ps.SQL + rows.fields = ps.FieldDescriptions + + err = c.sendPreparedQuery(ps, args...) 
+ if err != nil { + rows.fatal(err) + } + + return rows, rows.err +} + +func (c *Conn) buildOneRoundTripQueryEx(buf []byte, sql string, options *QueryExOptions, arguments []interface{}) ([]byte, error) { + if len(arguments) != len(options.ParameterOIDs) { + return nil, errors.Errorf("mismatched number of arguments (%d) and options.ParameterOIDs (%d)", len(arguments), len(options.ParameterOIDs)) + } + + if len(options.ParameterOIDs) > 65535 { + return nil, errors.Errorf("Number of QueryExOptions ParameterOIDs must be between 0 and 65535, received %d", len(options.ParameterOIDs)) + } + + buf = appendParse(buf, "", sql, options.ParameterOIDs) + buf = appendDescribe(buf, 'S', "") + buf, err := appendBind(buf, "", "", c.ConnInfo, options.ParameterOIDs, arguments, options.ResultFormatCodes) + if err != nil { + return nil, err + } + buf = appendExecute(buf, "", 0) + + return buf, nil +} + +func (c *Conn) readUntilRowDescription() ([]FieldDescription, error) { + for { + msg, err := c.rxMsg() + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *pgproto3.ParameterDescription: + case *pgproto3.RowDescription: + fieldDescriptions := c.rxRowDescription(msg) + for i := range fieldDescriptions { + if dt, ok := c.ConnInfo.DataTypeForOID(fieldDescriptions[i].DataType); ok { + fieldDescriptions[i].DataTypeName = dt.Name + } else { + return nil, errors.Errorf("unknown oid: %d", fieldDescriptions[i].DataType) + } + } + return fieldDescriptions, nil + default: + if err := c.processContextFreeMsg(msg); err != nil { + return nil, err + } + } + } +} + +func (c *Conn) sanitizeAndSendSimpleQuery(sql string, args ...interface{}) (err error) { + if c.RuntimeParams["standard_conforming_strings"] != "on" { + return errors.New("simple protocol queries must be run with standard_conforming_strings=on") + } + + if c.RuntimeParams["client_encoding"] != "UTF8" { + return errors.New("simple protocol queries must be run with client_encoding=UTF8") + } + + valueArgs := 
make([]interface{}, len(args)) + for i, a := range args { + valueArgs[i], err = convertSimpleArgument(c.ConnInfo, a) + if err != nil { + return err + } + } + + sql, err = sanitize.SanitizeSQL(sql, valueArgs...) + if err != nil { + return err + } + + return c.sendSimpleQuery(sql) +} + +func (c *Conn) QueryRowEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) *Row { + rows, _ := c.QueryEx(ctx, sql, options, args...) + return (*Row)(rows) +} diff --git a/query_test.go b/query_test.go index f08887b5..9379bd23 100644 --- a/query_test.go +++ b/query_test.go @@ -2,6 +2,7 @@ package pgx_test import ( "bytes" + "context" "database/sql" "fmt" "strings" @@ -9,7 +10,7 @@ import ( "time" "github.com/jackc/pgx" - + "github.com/jackc/pgx/pgtype" "github.com/shopspring/decimal" ) @@ -46,6 +47,58 @@ func TestConnQueryScan(t *testing.T) { } } +func TestConnQueryScanWithManyColumns(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + columnCount := 1000 + sql := "select " + for i := 0; i < columnCount; i++ { + if i > 0 { + sql += "," + } + sql += fmt.Sprintf(" %d", i) + } + sql += " from generate_series(1,5)" + + dest := make([]int, columnCount) + + var rowCount int + + rows, err := conn.Query(sql) + if err != nil { + t.Fatalf("conn.Query failed: %v", err) + } + defer rows.Close() + + for rows.Next() { + destPtrs := make([]interface{}, columnCount) + for i := range destPtrs { + destPtrs[i] = &dest[i] + } + if err := rows.Scan(destPtrs...); err != nil { + t.Fatalf("rows.Scan failed: %v", err) + } + rowCount++ + + for i := range dest { + if dest[i] != i { + t.Errorf("dest[%d] => %d, want %d", i, dest[i], i) + } + } + } + + if rows.Err() != nil { + t.Fatalf("conn.Query failed: %v", err) + } + + if rowCount != 5 { + t.Errorf("rowCount => %d, want %d", rowCount, 5) + } +} + func TestConnQueryValues(t *testing.T) { t.Parallel() @@ -54,7 +107,7 @@ func TestConnQueryValues(t *testing.T) { var rowCount 
int32 - rows, err := conn.Query("select 'foo'::text, 'bar'::varchar, n, null, n::oid from generate_series(1,$1) n", 10) + rows, err := conn.Query("select 'foo'::text, 'bar'::varchar, n, null, n from generate_series(1,$1) n", 10) if err != nil { t.Fatalf("conn.Query failed: %v", err) } @@ -85,7 +138,7 @@ func TestConnQueryValues(t *testing.T) { t.Errorf(`Expected values[3] to be %v, but it was %d`, nil, values[3]) } - if values[4] != pgx.Oid(rowCount) { + if values[4] != rowCount { t.Errorf(`Expected values[4] to be %d, but it was %d`, rowCount, values[4]) } } @@ -99,6 +152,23 @@ func TestConnQueryValues(t *testing.T) { } } +// https://github.com/jackc/pgx/issues/228 +func TestRowsScanDoesNotAllowScanningBinaryFormatValuesIntoString(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var s string + + err := conn.QueryRow("select 1").Scan(&s) + if err == nil || !(strings.Contains(err.Error(), "cannot decode binary value into string") || strings.Contains(err.Error(), "cannot assign")) { + t.Fatalf("Expected Scan to fail to encode binary value into string but: %v", err) + } + + ensureConnValid(t, conn) +} + // Test that a connection stays valid when query results are closed early func TestConnQueryCloseEarly(t *testing.T) { t.Parallel() @@ -181,7 +251,7 @@ func TestConnQueryReadWrongTypeError(t *testing.T) { t.Fatal("Expected Rows to have an error after an improper read but it didn't") } - if rows.Err().Error() != "can't scan into dest[0]: Can't convert OID 23 to time.Time" { + if rows.Err().Error() != "can't scan into dest[0]: Can't convert OID 23 to time.Time" && !strings.Contains(rows.Err().Error(), "cannot assign") { t.Fatalf("Expected different Rows.Err(): %v", rows.Err()) } @@ -254,105 +324,6 @@ func TestConnQueryScanIgnoreColumn(t *testing.T) { ensureConnValid(t, conn) } -func TestConnQueryScanner(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - 
rows, err := conn.Query("select null::int8, 1::int8") - if err != nil { - t.Fatalf("conn.Query failed: %v", err) - } - - ok := rows.Next() - if !ok { - t.Fatal("rows.Next terminated early") - } - - var n, m pgx.NullInt64 - err = rows.Scan(&n, &m) - if err != nil { - t.Fatalf("rows.Scan failed: %v", err) - } - rows.Close() - - if n.Valid { - t.Error("Null should not be valid, but it was") - } - - if !m.Valid { - t.Error("1 should be valid, but it wasn't") - } - - if m.Int64 != 1 { - t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64) - } - - ensureConnValid(t, conn) -} - -type pgxNullInt64 struct { - Int64 int64 - Valid bool // Valid is true if Int64 is not NULL -} - -func (n *pgxNullInt64) ScanPgx(vr *pgx.ValueReader) error { - if vr.Type().DataType != pgx.Int8Oid { - return pgx.SerializationError(fmt.Sprintf("pgxNullInt64.Scan cannot decode OID %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Int64, n.Valid = 0, false - return nil - } - n.Valid = true - - err := pgx.Decode(vr, &n.Int64) - if err != nil { - return err - } - return vr.Err() -} - -func TestConnQueryPgxScanner(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - rows, err := conn.Query("select null::int8, 1::int8") - if err != nil { - t.Fatalf("conn.Query failed: %v", err) - } - - ok := rows.Next() - if !ok { - t.Fatal("rows.Next terminated early") - } - - var n, m pgxNullInt64 - err = rows.Scan(&n, &m) - if err != nil { - t.Fatalf("rows.Scan failed: %v", err) - } - rows.Close() - - if n.Valid { - t.Error("Null should not be valid, but it was") - } - - if !m.Valid { - t.Error("1 should be valid, but it wasn't") - } - - if m.Int64 != 1 { - t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64) - } - - ensureConnValid(t, conn) -} - func TestConnQueryErrorWhileReturningRows(t *testing.T) { t.Parallel() @@ -384,42 +355,6 @@ func TestConnQueryErrorWhileReturningRows(t *testing.T) { } -func TestConnQueryEncoder(t *testing.T) 
{ - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - n := pgx.NullInt64{Int64: 1, Valid: true} - - rows, err := conn.Query("select $1::int8", &n) - if err != nil { - t.Fatalf("conn.Query failed: %v", err) - } - - ok := rows.Next() - if !ok { - t.Fatal("rows.Next terminated early") - } - - var m pgx.NullInt64 - err = rows.Scan(&m) - if err != nil { - t.Fatalf("rows.Scan failed: %v", err) - } - rows.Close() - - if !m.Valid { - t.Error("m should be valid, but it wasn't") - } - - if m.Int64 != 1 { - t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64) - } - - ensureConnValid(t, conn) -} - func TestQueryEncodeError(t *testing.T) { t.Parallel() @@ -442,35 +377,6 @@ func TestQueryEncodeError(t *testing.T) { } } -// Ensure that an argument that implements Encoder works when the parameter type -// is a core type. -type coreEncoder struct{} - -func (n coreEncoder) FormatCode() int16 { return pgx.TextFormatCode } - -func (n *coreEncoder) Encode(w *pgx.WriteBuf, oid pgx.Oid) error { - w.WriteInt32(int32(2)) - w.WriteBytes([]byte("42")) - return nil -} - -func TestQueryEncodeCoreTextFormatError(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - var n int32 - err := conn.QueryRow("select $1::integer", &coreEncoder{}).Scan(&n) - if err != nil { - t.Fatalf("Unexpected conn.QueryRow error: %v", err) - } - - if n != 42 { - t.Errorf("Expected 42, got %v", n) - } -} - func TestQueryRowCoreTypes(t *testing.T) { t.Parallel() @@ -483,7 +389,7 @@ func TestQueryRowCoreTypes(t *testing.T) { f64 float64 b bool t time.Time - oid pgx.Oid + oid pgtype.OID } var actual, zero allTypes @@ -499,9 +405,9 @@ func TestQueryRowCoreTypes(t *testing.T) { {"select $1::float8", []interface{}{float64(1.23)}, []interface{}{&actual.f64}, allTypes{f64: 1.23}}, {"select $1::bool", []interface{}{true}, []interface{}{&actual.b}, allTypes{b: true}}, {"select $1::timestamptz", []interface{}{time.Unix(123, 
5000)}, []interface{}{&actual.t}, allTypes{t: time.Unix(123, 5000)}}, - {"select $1::timestamp", []interface{}{time.Date(2010, 1, 2, 3, 4, 5, 0, time.Local)}, []interface{}{&actual.t}, allTypes{t: time.Date(2010, 1, 2, 3, 4, 5, 0, time.Local)}}, - {"select $1::date", []interface{}{time.Date(1987, 1, 2, 0, 0, 0, 0, time.Local)}, []interface{}{&actual.t}, allTypes{t: time.Date(1987, 1, 2, 0, 0, 0, 0, time.Local)}}, - {"select $1::oid", []interface{}{pgx.Oid(42)}, []interface{}{&actual.oid}, allTypes{oid: 42}}, + {"select $1::timestamp", []interface{}{time.Date(2010, 1, 2, 3, 4, 5, 0, time.UTC)}, []interface{}{&actual.t}, allTypes{t: time.Date(2010, 1, 2, 3, 4, 5, 0, time.UTC)}}, + {"select $1::date", []interface{}{time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}, []interface{}{&actual.t}, allTypes{t: time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}}, + {"select $1::oid", []interface{}{pgtype.OID(42)}, []interface{}{&actual.oid}, allTypes{oid: 42}}, } for i, tt := range tests { @@ -523,9 +429,6 @@ func TestQueryRowCoreTypes(t *testing.T) { if err == nil { t.Errorf("%d. Expected null to cause error, but it didn't (sql -> %v)", i, tt.sql) } - if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { - t.Errorf(`%d. Expected null to cause error "Cannot decode null..." 
but it was %v (sql -> %v)`, i, err, tt.sql) - } ensureConnValid(t, conn) } @@ -834,7 +737,6 @@ func TestQueryRowCoreByteSlice(t *testing.T) { }{ {"select $1::text", "Jack", []byte("Jack")}, {"select $1::text", []byte("Jack"), []byte("Jack")}, - {"select $1::int4", int32(239023409), []byte{14, 63, 53, 49}}, {"select $1::varchar", []byte("Jack"), []byte("Jack")}, {"select $1::bytea", []byte{0, 15, 255, 17}, []byte{0, 15, 255, 17}}, } @@ -855,36 +757,25 @@ func TestQueryRowCoreByteSlice(t *testing.T) { } } -func TestQueryRowByteSliceArgument(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - sql := "select $1::int4" - queryArg := []byte{14, 63, 53, 49} - expected := int32(239023409) - - var actual int32 - - err := conn.QueryRow(sql, queryArg).Scan(&actual) - if err != nil { - t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql) - } - - if expected != actual { - t.Errorf("Expected %v, got %v (sql -> %v)", expected, actual, sql) - } - - ensureConnValid(t, conn) -} - func TestQueryRowUnknownType(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) + // Clear existing type mappings + conn.ConnInfo = pgtype.NewConnInfo() + conn.ConnInfo.RegisterDataType(pgtype.DataType{ + Value: &pgtype.GenericText{}, + Name: "point", + OID: 600, + }) + conn.ConnInfo.RegisterDataType(pgtype.DataType{ + Value: &pgtype.Int4{}, + Name: "int4", + OID: pgtype.Int4OID, + }) + sql := "select $1::point" expected := "(1,0)" var actual string @@ -925,8 +816,8 @@ func TestQueryRowErrors(t *testing.T) { {"select $1", []interface{}{"Jack"}, []interface{}{&actual.i16}, "could not determine data type of parameter $1 (SQLSTATE 42P18)"}, {"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`}, {"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"}, - {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "Cannot 
decode oid 25 into any integer type"}, - {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "cannot encode int8 into oid 600"}, + {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "cannot decode"}, + {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "cannot convert 705 to Point"}, } for i, tt := range tests { @@ -959,290 +850,6 @@ func TestQueryRowNoResults(t *testing.T) { ensureConnValid(t, conn) } -func TestQueryRowCoreInt16Slice(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - var actual []int16 - - tests := []struct { - sql string - expected []int16 - }{ - {"select $1::int2[]", []int16{1, 2, 3, 4, 5}}, - {"select $1::int2[]", []int16{}}, - } - - for i, tt := range tests { - err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual) - if err != nil { - t.Errorf("%d. Unexpected failure: %v", i, err) - } - - if len(actual) != len(tt.expected) { - t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual) - } - - for j := 0; j < len(actual); j++ { - if actual[j] != tt.expected[j] { - t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j]) - } - } - - ensureConnValid(t, conn) - } - - // Check that Scan errors when an array with a null is scanned into a core slice type - err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int2[];").Scan(&actual) - if err == nil { - t.Error("Expected null to cause error when scanned into slice, but it didn't") - } - if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { - t.Errorf(`Expected null to cause error "Cannot decode null..." 
but it was %v`, err) - } - - ensureConnValid(t, conn) -} - -func TestQueryRowCoreInt32Slice(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - var actual []int32 - - tests := []struct { - sql string - expected []int32 - }{ - {"select $1::int4[]", []int32{1, 2, 3, 4, 5}}, - {"select $1::int4[]", []int32{}}, - } - - for i, tt := range tests { - err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual) - if err != nil { - t.Errorf("%d. Unexpected failure: %v", i, err) - } - - if len(actual) != len(tt.expected) { - t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual) - } - - for j := 0; j < len(actual); j++ { - if actual[j] != tt.expected[j] { - t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j]) - } - } - - ensureConnValid(t, conn) - } - - // Check that Scan errors when an array with a null is scanned into a core slice type - err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int4[];").Scan(&actual) - if err == nil { - t.Error("Expected null to cause error when scanned into slice, but it didn't") - } - if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { - t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err) - } - - ensureConnValid(t, conn) -} - -func TestQueryRowCoreInt64Slice(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - var actual []int64 - - tests := []struct { - sql string - expected []int64 - }{ - {"select $1::int8[]", []int64{1, 2, 3, 4, 5}}, - {"select $1::int8[]", []int64{}}, - } - - for i, tt := range tests { - err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual) - if err != nil { - t.Errorf("%d. Unexpected failure: %v", i, err) - } - - if len(actual) != len(tt.expected) { - t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual) - } - - for j := 0; j < len(actual); j++ { - if actual[j] != tt.expected[j] { - t.Errorf("%d. 
Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j]) - } - } - - ensureConnValid(t, conn) - } - - // Check that Scan errors when an array with a null is scanned into a core slice type - err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int8[];").Scan(&actual) - if err == nil { - t.Error("Expected null to cause error when scanned into slice, but it didn't") - } - if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { - t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err) - } - - ensureConnValid(t, conn) -} - -func TestQueryRowCoreFloat32Slice(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - var actual []float32 - - tests := []struct { - sql string - expected []float32 - }{ - {"select $1::float4[]", []float32{1.5, 2.0, 3.5}}, - {"select $1::float4[]", []float32{}}, - } - - for i, tt := range tests { - err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual) - if err != nil { - t.Errorf("%d. Unexpected failure: %v", i, err) - } - - if len(actual) != len(tt.expected) { - t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual) - } - - for j := 0; j < len(actual); j++ { - if actual[j] != tt.expected[j] { - t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j]) - } - } - - ensureConnValid(t, conn) - } - - // Check that Scan errors when an array with a null is scanned into a core slice type - err := conn.QueryRow("select '{1.5, 2.0, 3.5, null}'::float4[];").Scan(&actual) - if err == nil { - t.Error("Expected null to cause error when scanned into slice, but it didn't") - } - if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { - t.Errorf(`Expected null to cause error "Cannot decode null..." 
but it was %v`, err) - } - - ensureConnValid(t, conn) -} - -func TestQueryRowCoreFloat64Slice(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - var actual []float64 - - tests := []struct { - sql string - expected []float64 - }{ - {"select $1::float8[]", []float64{1.5, 2.0, 3.5}}, - {"select $1::float8[]", []float64{}}, - } - - for i, tt := range tests { - err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual) - if err != nil { - t.Errorf("%d. Unexpected failure: %v", i, err) - } - - if len(actual) != len(tt.expected) { - t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual) - } - - for j := 0; j < len(actual); j++ { - if actual[j] != tt.expected[j] { - t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j]) - } - } - - ensureConnValid(t, conn) - } - - // Check that Scan errors when an array with a null is scanned into a core slice type - err := conn.QueryRow("select '{1.5, 2.0, 3.5, null}'::float8[];").Scan(&actual) - if err == nil { - t.Error("Expected null to cause error when scanned into slice, but it didn't") - } - if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { - t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err) - } - - ensureConnValid(t, conn) -} - -func TestQueryRowCoreStringSlice(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - var actual []string - - tests := []struct { - sql string - expected []string - }{ - {"select $1::text[]", []string{"Adam", "Eve", "UTF-8 Characters ƅ Ɔ Ƌ Ļæ"}}, - {"select $1::text[]", []string{}}, - {"select $1::varchar[]", []string{"Adam", "Eve", "UTF-8 Characters ƅ Ɔ Ƌ Ļæ"}}, - {"select $1::varchar[]", []string{}}, - } - - for i, tt := range tests { - err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual) - if err != nil { - t.Errorf("%d. 
Unexpected failure: %v", i, err) - } - - if len(actual) != len(tt.expected) { - t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual) - } - - for j := 0; j < len(actual); j++ { - if actual[j] != tt.expected[j] { - t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j]) - } - } - - ensureConnValid(t, conn) - } - - // Check that Scan errors when an array with a null is scanned into a core slice type - err := conn.QueryRow("select '{Adam,Eve,NULL}'::text[];").Scan(&actual) - if err == nil { - t.Error("Expected null to cause error when scanned into slice, but it didn't") - } - if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { - t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err) - } - - ensureConnValid(t, conn) -} - func TestReadingValueAfterEmptyArray(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) @@ -1412,3 +1019,408 @@ func TestConnQueryDatabaseSQLNullX(t *testing.T) { ensureConnValid(t, conn) } + +func TestQueryExContextSuccess(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + rows, err := conn.QueryEx(ctx, "select 42::integer", nil) + if err != nil { + t.Fatal(err) + } + + var result, rowCount int + for rows.Next() { + err = rows.Scan(&result) + if err != nil { + t.Fatal(err) + } + rowCount++ + } + + if rows.Err() != nil { + t.Fatal(rows.Err()) + } + + if rowCount != 1 { + t.Fatalf("Expected 1 row, got %d", rowCount) + } + if result != 42 { + t.Fatalf("Expected result 42, got %d", result) + } + + ensureConnValid(t, conn) +} + +func TestQueryExContextErrorWhileReceivingRows(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + rows, err := conn.QueryEx(ctx, 
"select 10/(10-n) from generate_series(1, 100) n", nil) + if err != nil { + t.Fatal(err) + } + + var result, rowCount int + for rows.Next() { + err = rows.Scan(&result) + if err != nil { + t.Fatal(err) + } + rowCount++ + } + + if rows.Err() == nil || rows.Err().Error() != "ERROR: division by zero (SQLSTATE 22012)" { + t.Fatalf("Expected division by zero error, but got %v", rows.Err()) + } + + if rowCount != 9 { + t.Fatalf("Expected 9 rows, got %d", rowCount) + } + if result != 10 { + t.Fatalf("Expected result 10, got %d", result) + } + + ensureConnValid(t, conn) +} + +func TestQueryExContextCancelationCancelsQuery(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + time.Sleep(500 * time.Millisecond) + cancelFunc() + }() + + rows, err := conn.QueryEx(ctx, "select pg_sleep(5)", nil) + if err != nil { + t.Fatal(err) + } + + for rows.Next() { + t.Fatal("No rows should ever be ready -- context cancel apparently did not happen") + } + + if rows.Err() != context.Canceled { + t.Fatalf("Expected context.Canceled error, got %v", rows.Err()) + } + + ensureConnValid(t, conn) +} + +func TestQueryRowExContextSuccess(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + var result int + err := conn.QueryRowEx(ctx, "select 42::integer", nil).Scan(&result) + if err != nil { + t.Fatal(err) + } + if result != 42 { + t.Fatalf("Expected result 42, got %d", result) + } + + ensureConnValid(t, conn) +} + +func TestQueryRowExContextErrorWhileReceivingRow(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + var result int + err := conn.QueryRowEx(ctx, "select 10/0", nil).Scan(&result) 
+ if err == nil || err.Error() != "ERROR: division by zero (SQLSTATE 22012)" { + t.Fatalf("Expected division by zero error, but got %v", err) + } + + ensureConnValid(t, conn) +} + +func TestQueryRowExContextCancelationCancelsQuery(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + time.Sleep(500 * time.Millisecond) + cancelFunc() + }() + + var result []byte + err := conn.QueryRowEx(ctx, "select pg_sleep(5)", nil).Scan(&result) + if err != context.Canceled { + t.Fatalf("Expected context.Canceled error, got %v", err) + } + + ensureConnValid(t, conn) +} + +func TestConnQueryRowExSingleRoundTrip(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var result int32 + err := conn.QueryRowEx( + context.Background(), + "select $1 + $2", + &pgx.QueryExOptions{ + ParameterOIDs: []pgtype.OID{pgtype.Int4OID, pgtype.Int4OID}, + ResultFormatCodes: []int16{pgx.BinaryFormatCode}, + }, + 1, 2, + ).Scan(&result) + if err != nil { + t.Fatal(err) + } + if result != 3 { + t.Fatal("result => %d, want %d", result, 3) + } + + ensureConnValid(t, conn) +} + +func TestConnSimpleProtocol(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + // Test all supported low-level types + + { + expected := int64(42) + var actual int64 + err := conn.QueryRowEx( + context.Background(), + "select $1::int8", + &pgx.QueryExOptions{SimpleProtocol: true}, + expected, + ).Scan(&actual) + if err != nil { + t.Error(err) + } + if expected != actual { + t.Errorf("expected %v got %v", expected, actual) + } + } + + { + expected := float64(1.23) + var actual float64 + err := conn.QueryRowEx( + context.Background(), + "select $1::float8", + &pgx.QueryExOptions{SimpleProtocol: true}, + expected, + ).Scan(&actual) + if err != nil { + t.Error(err) + } + if expected != actual { + 
t.Errorf("expected %v got %v", expected, actual) + } + } + + { + expected := true + var actual bool + err := conn.QueryRowEx( + context.Background(), + "select $1", + &pgx.QueryExOptions{SimpleProtocol: true}, + expected, + ).Scan(&actual) + if err != nil { + t.Error(err) + } + if expected != actual { + t.Errorf("expected %v got %v", expected, actual) + } + } + + { + expected := []byte{0, 1, 20, 35, 64, 80, 120, 3, 255, 240, 128, 95} + var actual []byte + err := conn.QueryRowEx( + context.Background(), + "select $1::bytea", + &pgx.QueryExOptions{SimpleProtocol: true}, + expected, + ).Scan(&actual) + if err != nil { + t.Error(err) + } + if bytes.Compare(actual, expected) != 0 { + t.Errorf("expected %v got %v", expected, actual) + } + } + + { + expected := "test" + var actual string + err := conn.QueryRowEx( + context.Background(), + "select $1::text", + &pgx.QueryExOptions{SimpleProtocol: true}, + expected, + ).Scan(&actual) + if err != nil { + t.Error(err) + } + if expected != actual { + t.Errorf("expected %v got %v", expected, actual) + } + } + + // Test high-level type + + { + expected := pgtype.Circle{P: pgtype.Vec2{1, 2}, R: 1.5, Status: pgtype.Present} + actual := expected + err := conn.QueryRowEx( + context.Background(), + "select $1::circle", + &pgx.QueryExOptions{SimpleProtocol: true}, + &expected, + ).Scan(&actual) + if err != nil { + t.Error(err) + } + if expected != actual { + t.Errorf("expected %v got %v", expected, actual) + } + } + + // Test multiple args in single query + + { + expectedInt64 := int64(234423) + expectedFloat64 := float64(-0.2312) + expectedBool := true + expectedBytes := []byte{255, 0, 23, 16, 87, 45, 9, 23, 45, 223} + expectedString := "test" + var actualInt64 int64 + var actualFloat64 float64 + var actualBool bool + var actualBytes []byte + var actualString string + err := conn.QueryRowEx( + context.Background(), + "select $1::int8, $2::float8, $3, $4::bytea, $5::text", + &pgx.QueryExOptions{SimpleProtocol: true}, + expectedInt64, 
expectedFloat64, expectedBool, expectedBytes, expectedString, + ).Scan(&actualInt64, &actualFloat64, &actualBool, &actualBytes, &actualString) + if err != nil { + t.Error(err) + } + if expectedInt64 != actualInt64 { + t.Errorf("expected %v got %v", expectedInt64, actualInt64) + } + if expectedFloat64 != actualFloat64 { + t.Errorf("expected %v got %v", expectedFloat64, actualFloat64) + } + if expectedBool != actualBool { + t.Errorf("expected %v got %v", expectedBool, actualBool) + } + if bytes.Compare(expectedBytes, actualBytes) != 0 { + t.Errorf("expected %v got %v", expectedBytes, actualBytes) + } + if expectedString != actualString { + t.Errorf("expected %v got %v", expectedString, actualString) + } + } + + // Test dangerous cases + + { + expected := "foo';drop table users;" + var actual string + err := conn.QueryRowEx( + context.Background(), + "select $1", + &pgx.QueryExOptions{SimpleProtocol: true}, + expected, + ).Scan(&actual) + if err != nil { + t.Error(err) + } + if expected != actual { + t.Errorf("expected %v got %v", expected, actual) + } + } + + ensureConnValid(t, conn) +} + +func TestConnSimpleProtocolRefusesNonUTF8ClientEncoding(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, "set client_encoding to 'SQL_ASCII'") + + var expected string + err := conn.QueryRowEx( + context.Background(), + "select $1", + &pgx.QueryExOptions{SimpleProtocol: true}, + "test", + ).Scan(&expected) + if err == nil { + t.Error("expected error when client_encoding not UTF8, but no error occurred") + } + + ensureConnValid(t, conn) +} + +func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, "set standard_conforming_strings to off") + + var expected string + err := conn.QueryRowEx( + context.Background(), + "select $1", + &pgx.QueryExOptions{SimpleProtocol: true}, + 
`\'; drop table users; --`, + ).Scan(&expected) + if err == nil { + t.Error("expected error when standard_conforming_strings is off, but no error occurred") + } + + ensureConnValid(t, conn) +} diff --git a/replication.go b/replication.go index 7b28d6b6..bfa81e54 100644 --- a/replication.go +++ b/replication.go @@ -1,10 +1,15 @@ package pgx import ( - "errors" + "context" + "encoding/binary" "fmt" - "net" "time" + + "github.com/pkg/errors" + + "github.com/jackc/pgx/pgio" + "github.com/jackc/pgx/pgproto3" ) const ( @@ -172,17 +177,21 @@ type ReplicationConn struct { // message to the server, as well as carries the WAL position of the // client, which then updates the server's replication slot position. func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) { - writeBuf := newWriteBuf(rc.c, copyData) - writeBuf.WriteByte(standbyStatusUpdate) - writeBuf.WriteInt64(int64(k.WalWritePosition)) - writeBuf.WriteInt64(int64(k.WalFlushPosition)) - writeBuf.WriteInt64(int64(k.WalApplyPosition)) - writeBuf.WriteInt64(int64(k.ClientTime)) - writeBuf.WriteByte(k.ReplyRequested) + buf := rc.c.wbuf + buf = append(buf, copyData) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - writeBuf.closeMsg() + buf = append(buf, standbyStatusUpdate) + buf = pgio.AppendInt64(buf, int64(k.WalWritePosition)) + buf = pgio.AppendInt64(buf, int64(k.WalFlushPosition)) + buf = pgio.AppendInt64(buf, int64(k.WalApplyPosition)) + buf = pgio.AppendInt64(buf, int64(k.ClientTime)) + buf = append(buf, k.ReplyRequested) - _, err = rc.c.conn.Write(writeBuf.buf) + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + _, err = rc.c.conn.Write(buf) if err != nil { rc.c.die(err) } @@ -203,107 +212,115 @@ func (rc *ReplicationConn) CauseOfDeath() error { } func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err error) { - var t byte - var reader *msgReader - t, reader, err = rc.c.rxMsg() + msg, err := rc.c.rxMsg() if err != nil { return } - switch t { - case noticeResponse: 
- pgError := rc.c.rxErrorResponse(reader) + switch msg := msg.(type) { + case *pgproto3.NoticeResponse: + pgError := rc.c.rxErrorResponse((*pgproto3.ErrorResponse)(msg)) if rc.c.shouldLog(LogLevelInfo) { - rc.c.log(LogLevelInfo, pgError.Error()) + rc.c.log(LogLevelInfo, pgError.Error(), nil) } - case errorResponse: - err = rc.c.rxErrorResponse(reader) + case *pgproto3.ErrorResponse: + err = rc.c.rxErrorResponse(msg) if rc.c.shouldLog(LogLevelError) { - rc.c.log(LogLevelError, err.Error()) + rc.c.log(LogLevelError, err.Error(), nil) } return - case copyBothResponse: + case *pgproto3.CopyBothResponse: // This is the tail end of the replication process start, // and can be safely ignored return - case copyData: - var msgType byte - msgType = reader.readByte() + case *pgproto3.CopyData: + msgType := msg.Data[0] + rp := 1 + switch msgType { case walData: - walStart := reader.readInt64() - serverWalEnd := reader.readInt64() - serverTime := reader.readInt64() - walData := reader.readBytes(reader.msgBytesRemaining) - walMessage := WalMessage{WalStart: uint64(walStart), - ServerWalEnd: uint64(serverWalEnd), - ServerTime: uint64(serverTime), + walStart := binary.BigEndian.Uint64(msg.Data[rp:]) + rp += 8 + serverWalEnd := binary.BigEndian.Uint64(msg.Data[rp:]) + rp += 8 + serverTime := binary.BigEndian.Uint64(msg.Data[rp:]) + rp += 8 + walData := msg.Data[rp:] + walMessage := WalMessage{WalStart: walStart, + ServerWalEnd: serverWalEnd, + ServerTime: serverTime, WalData: walData, } return &ReplicationMessage{WalMessage: &walMessage}, nil case senderKeepalive: - serverWalEnd := reader.readInt64() - serverTime := reader.readInt64() - replyNow := reader.readByte() - h := &ServerHeartbeat{ServerWalEnd: uint64(serverWalEnd), ServerTime: uint64(serverTime), ReplyRequested: replyNow} + serverWalEnd := binary.BigEndian.Uint64(msg.Data[rp:]) + rp += 8 + serverTime := binary.BigEndian.Uint64(msg.Data[rp:]) + rp += 8 + replyNow := msg.Data[rp] + rp += 1 + h := 
&ServerHeartbeat{ServerWalEnd: serverWalEnd, ServerTime: serverTime, ReplyRequested: replyNow}
+ return &ReplicationMessage{ServerHeartbeat: h}, nil
 default:
 if rc.c.shouldLog(LogLevelError) {
- rc.c.log(LogLevelError, "Unexpected data playload message type %v", t)
+ rc.c.log(LogLevelError, "Unexpected data payload message type", map[string]interface{}{"type": msgType})
+ }
 default:
 if rc.c.shouldLog(LogLevelError) {
- rc.c.log(LogLevelError, "Unexpected replication message type %v", t)
+ rc.c.log(LogLevelError, "Unexpected replication message type", map[string]interface{}{"type": msg})
+ }
 }
 return
}

-// Wait for a single replication message up to timeout time.
+// Wait for a single replication message.
 //
 // Properly using this requires some knowledge of the postgres replication mechanisms,
 // as the client can receive both WAL data (the ultimate payload) and server heartbeat
 // updates. The caller also must send standby status updates in order to keep the connection
 // alive and working.
 //
-// This returns pgx.ErrNotificationTimeout when there is no replication message by the specified
-// duration.
-func (rc *ReplicationConn) WaitForReplicationMessage(timeout time.Duration) (r *ReplicationMessage, err error) {
- var zeroTime time.Time
-
- deadline := time.Now().Add(timeout)
-
- // Use SetReadDeadline to implement the timeout. SetReadDeadline will
- // cause operations to fail with a *net.OpError that has a Timeout()
- // of true. Because the normal pgx rxMsg path considers any error to
- // have potentially corrupted the state of the connection, it dies
- // on any errors. So to avoid timeout errors in rxMsg we set the
- // deadline and peek into the reader. If a timeout error occurs there
- // we don't break the pgx connection. If the Peek returns that data
- is available then we turn off the read deadline before the rxMsg. 
- err = rc.c.conn.SetReadDeadline(deadline) - if err != nil { - return nil, err +// This returns the context error when there is no replication message before +// the context is canceled. +func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*ReplicationMessage, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: } - // Wait until there is a byte available before continuing onto the normal msg reading path - _, err = rc.c.reader.Peek(1) - if err != nil { - rc.c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline - if err, ok := err.(*net.OpError); ok && err.Timeout() { - return nil, ErrNotificationTimeout + go func() { + select { + case <-ctx.Done(): + if err := rc.c.conn.SetDeadline(time.Now()); err != nil { + rc.Close() // Close connection if unable to set deadline + return + } + rc.c.closedChan <- ctx.Err() + case <-rc.c.doneChan: } - return nil, err + }() + + r, opErr := rc.readReplicationMessage() + + var err error + select { + case err = <-rc.c.closedChan: + if err := rc.c.conn.SetDeadline(time.Time{}); err != nil { + rc.Close() // Close connection if unable to disable deadline + return nil, err + } + + if opErr == nil { + err = nil + } + case rc.c.doneChan <- struct{}{}: + err = opErr } - err = rc.c.conn.SetReadDeadline(zeroTime) - if err != nil { - return nil, err - } - - return rc.readReplicationMessage() + return r, err } func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) { @@ -312,32 +329,30 @@ func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) { rows := rc.c.getRows(sql, nil) if err := rc.c.lock(); err != nil { - rows.abort(err) + rows.fatal(err) return rows, err } rows.unlockConn = true err := rc.c.sendSimpleQuery(sql) if err != nil { - rows.abort(err) + rows.fatal(err) } - var t byte - var r *msgReader - t, r, err = rc.c.rxMsg() + msg, err := rc.c.rxMsg() if err != nil 
{ return nil, err } - switch t { - case rowDescription: - rows.fields = rc.c.rxRowDescription(r) + switch msg := msg.(type) { + case *pgproto3.RowDescription: + rows.fields = rc.c.rxRowDescription(msg) // We don't have c.PgTypes here because we're a replication // connection. This means the field descriptions will have - // only Oids. Not much we can do about this. + // only OIDs. Not much we can do about this. default: - if e := rc.c.processContextFreeMsg(t, r); e != nil { - rows.abort(e) + if e := rc.c.processContextFreeMsg(msg); e != nil { + rows.fatal(e) return rows, e } } @@ -354,7 +369,7 @@ func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) { // // NOTE: Because this is a replication mode connection, we don't have // type names, so the field descriptions in the result will have only -// Oids and no DataTypeName values +// OIDs and no DataTypeName values func (rc *ReplicationConn) IdentifySystem() (r *Rows, err error) { return rc.sendReplicationModeQuery("IDENTIFY_SYSTEM") } @@ -369,7 +384,7 @@ func (rc *ReplicationConn) IdentifySystem() (r *Rows, err error) { // // NOTE: Because this is a replication mode connection, we don't have // type names, so the field descriptions in the result will have only -// Oids and no DataTypeName values +// OIDs and no DataTypeName values func (rc *ReplicationConn) TimelineHistory(timeline int) (r *Rows, err error) { return rc.sendReplicationModeQuery(fmt.Sprintf("TIMELINE_HISTORY %d", timeline)) } @@ -401,15 +416,18 @@ func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, ti return } + ctx, cancelFn := context.WithTimeout(context.Background(), initialReplicationResponseTimeout) + defer cancelFn() + // The first replication message that comes back here will be (in a success case) // a empty CopyBoth that is (apparently) sent as the confirmation that the replication has // started. 
This call will either return nil, nil or if it returns an error // that indicates the start replication command failed var r *ReplicationMessage - r, err = rc.WaitForReplicationMessage(initialReplicationResponseTimeout) + r, err = rc.WaitForReplicationMessage(ctx) if err != nil && r != nil { if rc.c.shouldLog(LogLevelError) { - rc.c.log(LogLevelError, "Unxpected replication message %v", r) + rc.c.log(LogLevelError, "Unexpected replication message", map[string]interface{}{"msg": r, "err": err}) } } diff --git a/replication_test.go b/replication_test.go index 4f810c78..d75233c1 100644 --- a/replication_test.go +++ b/replication_test.go @@ -1,13 +1,15 @@ package pgx_test import ( + "context" "fmt" - "github.com/jackc/pgx" "reflect" "strconv" "strings" "testing" "time" + + "github.com/jackc/pgx" ) // This function uses a postgresql 9.6 specific column @@ -37,8 +39,6 @@ func getConfirmedFlushLsnFor(t *testing.T, conn *pgx.Conn, slot string) string { // - Checks the wal position of the slot on the server to make sure // the update succeeded func TestSimpleReplicationConnection(t *testing.T) { - t.Parallel() - var err error if replicationConnConfig == nil { @@ -46,14 +46,19 @@ func TestSimpleReplicationConnection(t *testing.T) { } conn := mustConnect(t, *replicationConnConfig) - defer closeConn(t, conn) + defer func() { + // Ensure replication slot is destroyed, but don't check for errors as it + // should have already been destroyed. 
+ conn.Exec("select pg_drop_replication_slot('pgx_test')") + closeConn(t, conn) + }() replicationConn := mustReplicationConnect(t, *replicationConnConfig) defer closeReplicationConn(t, replicationConn) err = replicationConn.CreateReplicationSlot("pgx_test", "test_decoding") if err != nil { - t.Logf("replication slot create failed: %v", err) + t.Fatalf("replication slot create failed: %v", err) } // Do a simple change so we can get some wal data @@ -67,71 +72,63 @@ func TestSimpleReplicationConnection(t *testing.T) { t.Fatalf("Failed to start replication: %v", err) } - var i int32 var insertedTimes []int64 - for i < 5 { + currentTime := time.Now().Unix() + + for i := 0; i < 5; i++ { var ct pgx.CommandTag - currentTime := time.Now().Unix() insertedTimes = append(insertedTimes, currentTime) ct, err = conn.Exec("insert into replication_test(a) values($1)", currentTime) if err != nil { t.Fatalf("Insert failed: %v", err) } t.Logf("Inserted %d rows", ct.RowsAffected()) - i++ + currentTime++ } - i = 0 var foundTimes []int64 var foundCount int var maxWal uint64 + + ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelFn() + for { var message *pgx.ReplicationMessage - message, err = replicationConn.WaitForReplicationMessage(time.Duration(1 * time.Second)) + message, err = replicationConn.WaitForReplicationMessage(ctx) if err != nil { - if err != pgx.ErrNotificationTimeout { - t.Fatalf("Replication failed: %v %s", err, reflect.TypeOf(err)) - } + t.Fatalf("Replication failed: %v %s", err, reflect.TypeOf(err)) } - if message != nil { - if message.WalMessage != nil { - // The waldata payload with the test_decoding plugin looks like: - // public.replication_test: INSERT: a[integer]:2 - // What we wanna do here is check that once we find one of our inserted times, - // that they occur in the wal stream in the order we executed them. 
- walString := string(message.WalMessage.WalData) - if strings.Contains(walString, "public.replication_test: INSERT") { - stringParts := strings.Split(walString, ":") - offset, err := strconv.ParseInt(stringParts[len(stringParts)-1], 10, 64) - if err != nil { - t.Fatalf("Failed to parse walString %s", walString) - } - if foundCount > 0 || offset == insertedTimes[0] { - foundTimes = append(foundTimes, offset) - foundCount++ - } - } - if message.WalMessage.WalStart > maxWal { - maxWal = message.WalMessage.WalStart - } + if message.WalMessage != nil { + // The waldata payload with the test_decoding plugin looks like: + // public.replication_test: INSERT: a[integer]:2 + // What we wanna do here is check that once we find one of our inserted times, + // that they occur in the wal stream in the order we executed them. + walString := string(message.WalMessage.WalData) + if strings.Contains(walString, "public.replication_test: INSERT") { + stringParts := strings.Split(walString, ":") + offset, err := strconv.ParseInt(stringParts[len(stringParts)-1], 10, 64) + if err != nil { + t.Fatalf("Failed to parse walString %s", walString) + } + if foundCount > 0 || offset == insertedTimes[0] { + foundTimes = append(foundTimes, offset) + foundCount++ + } + if foundCount == len(insertedTimes) { + break + } } - if message.ServerHeartbeat != nil { - t.Logf("Got heartbeat: %s", message.ServerHeartbeat) + if message.WalMessage.WalStart > maxWal { + maxWal = message.WalMessage.WalStart } - } else { - t.Log("Timed out waiting for wal message") - i++ - } - if i > 3 { - t.Log("Actual timeout") - break - } - } - if foundCount != len(insertedTimes) { - t.Fatalf("Failed to find all inserted time values in WAL stream (found %d expected %d)", foundCount, len(insertedTimes)) + } + if message.ServerHeartbeat != nil { + t.Logf("Got heartbeat: %s", message.ServerHeartbeat) + } } for i := range insertedTimes { @@ -249,11 +246,7 @@ func getCurrentTimeline(t *testing.T, rc *pgx.ReplicationConn) int { if e 
!= nil { t.Error(e) } - timeline, e := strconv.Atoi(values[1].(string)) - if e != nil { - t.Error(e) - } - return timeline + return int(values[1].(int32)) } t.Fatal("Failed to read timeline") return -1 diff --git a/stdlib/sql.go b/stdlib/sql.go index 8c78cd39..0c140343 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -14,154 +14,206 @@ // return err // } // -// Or a normal pgx connection pool can be established and the database/sql -// connection can be created through stdlib.OpenFromConnPool(). This allows -// more control over the connection process (such as TLS), more control -// over the connection pool, setting an AfterConnect hook, and using both -// database/sql and pgx interfaces as needed. +// A DriverConfig can be used to further configure the connection process. This +// allows configuring TLS configuration, setting a custom dialer, logging, and +// setting an AfterConnect hook. // -// connConfig := pgx.ConnConfig{ -// Host: "localhost", -// User: "pgx_md5", -// Password: "secret", -// Database: "pgx_test", -// } -// -// config := pgx.ConnPoolConfig{ConnConfig: connConfig} -// pool, err := pgx.NewConnPool(config) -// if err != nil { -// return err -// } -// -// db, err := stdlib.OpenFromConnPool(pool) -// if err != nil { -// t.Fatalf("Unable to create connection pool: %v", err) +// driverConfig := stdlib.DriverConfig{ +// ConnConfig: ConnConfig: pgx.ConnConfig{ +// Logger: logger, +// }, +// AfterConnect: func(c *pgx.Conn) error { +// // Ensure all connections have this temp table available +// _, err := c.Exec("create temporary table foo(...)") +// return err +// }, // } // -// If the database/sql connection is established through -// stdlib.OpenFromConnPool then access to a pgx *ConnPool can be regained -// through db.Driver(). 
This allows writing a fast path for pgx while -// preserving compatibility with other drivers and database +// stdlib.RegisterDriverConfig(&driverConfig) // -// if driver, ok := db.Driver().(*stdlib.Driver); ok && driver.Pool != nil { +// db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test")) +// if err != nil { +// return err +// } +// +// AcquireConn and ReleaseConn acquire and release a *pgx.Conn from the standard +// database/sql.DB connection pool. This allows operations that must be +// performed on a single connection, but should not be run in a transaction or +// to use pgx specific functionality. +// +// conn, err := stdlib.AcquireConn(db) +// if err != nil { +// return err +// } +// defer stdlib.ReleaseConn(db, conn) +// +// // do stuff with pgx.Conn +// +// It also can be used to enable a fast path for pgx while preserving +// compatibility with other drivers and database. +// +// conn, err := stdlib.AcquireConn(db) +// if err == nil { // // fast path with pgx +// // ... +// // release conn when done +// stdlib.ReleaseConn(db, conn) // } else { // // normal path for other drivers and databases // } package stdlib import ( + "context" "database/sql" "database/sql/driver" - "errors" + "encoding/binary" "fmt" "io" + "strings" "sync" - "github.com/jackc/pgx" -) + "github.com/pkg/errors" -var ( - openFromConnPoolCountMu sync.Mutex - openFromConnPoolCount int + "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" ) // oids that map to intrinsic database/sql types. 
These will be allowed to be // binary, anything else will be forced to text format -var databaseSqlOids map[pgx.Oid]bool +var databaseSqlOIDs map[pgtype.OID]bool + +var pgxDriver *Driver + +type ctxKey int + +var ctxKeyFakeTx ctxKey = 0 + +var ErrNotPgx = errors.New("not pgx *sql.DB") func init() { - d := &Driver{} - sql.Register("pgx", d) + pgxDriver = &Driver{ + configs: make(map[int64]*DriverConfig), + fakeTxConns: make(map[*pgx.Conn]*sql.Tx), + } + sql.Register("pgx", pgxDriver) - databaseSqlOids = make(map[pgx.Oid]bool) - databaseSqlOids[pgx.BoolOid] = true - databaseSqlOids[pgx.ByteaOid] = true - databaseSqlOids[pgx.Int2Oid] = true - databaseSqlOids[pgx.Int4Oid] = true - databaseSqlOids[pgx.Int8Oid] = true - databaseSqlOids[pgx.Float4Oid] = true - databaseSqlOids[pgx.Float8Oid] = true - databaseSqlOids[pgx.DateOid] = true - databaseSqlOids[pgx.TimestampTzOid] = true - databaseSqlOids[pgx.TimestampOid] = true + databaseSqlOIDs = make(map[pgtype.OID]bool) + databaseSqlOIDs[pgtype.BoolOID] = true + databaseSqlOIDs[pgtype.ByteaOID] = true + databaseSqlOIDs[pgtype.CIDOID] = true + databaseSqlOIDs[pgtype.DateOID] = true + databaseSqlOIDs[pgtype.Float4OID] = true + databaseSqlOIDs[pgtype.Float8OID] = true + databaseSqlOIDs[pgtype.Int2OID] = true + databaseSqlOIDs[pgtype.Int4OID] = true + databaseSqlOIDs[pgtype.Int8OID] = true + databaseSqlOIDs[pgtype.OIDOID] = true + databaseSqlOIDs[pgtype.TimestampOID] = true + databaseSqlOIDs[pgtype.TimestamptzOID] = true + databaseSqlOIDs[pgtype.XIDOID] = true } type Driver struct { - Pool *pgx.ConnPool + configMutex sync.Mutex + configCount int64 + configs map[int64]*DriverConfig + + fakeTxMutex sync.Mutex + fakeTxConns map[*pgx.Conn]*sql.Tx } func (d *Driver) Open(name string) (driver.Conn, error) { - if d.Pool != nil { - conn, err := d.Pool.Acquire() - if err != nil { - return nil, err - } - - return &Conn{conn: conn, pool: d.Pool}, nil + var connConfig pgx.ConnConfig + var afterConnect func(*pgx.Conn) error + if len(name) >= 
9 && name[0] == 0 { + idBuf := []byte(name)[1:9] + id := int64(binary.BigEndian.Uint64(idBuf)) + connConfig = d.configs[id].ConnConfig + afterConnect = d.configs[id].AfterConnect + name = name[9:] } - connConfig, err := pgx.ParseConnectionString(name) + parsedConfig, err := pgx.ParseConnectionString(name) if err != nil { return nil, err } + connConfig = connConfig.Merge(parsedConfig) conn, err := pgx.Connect(connConfig) if err != nil { return nil, err } - c := &Conn{conn: conn} + if afterConnect != nil { + err = afterConnect(conn) + if err != nil { + return nil, err + } + } + + c := &Conn{conn: conn, driver: d} return c, nil } -// OpenFromConnPool takes the existing *pgx.ConnPool pool and returns a *sql.DB -// with pool as the backend. This enables full control over the connection -// process and configuration while maintaining compatibility with the -// database/sql interface. In addition, by calling Driver() on the returned -// *sql.DB and typecasting to *stdlib.Driver a reference to the pgx.ConnPool can -// be reaquired later. This allows fast paths targeting pgx to be used while -// still maintaining compatibility with other databases and drivers. -// -// pool connection size must be at least 2. -func OpenFromConnPool(pool *pgx.ConnPool) (*sql.DB, error) { - d := &Driver{Pool: pool} +type DriverConfig struct { + pgx.ConnConfig + AfterConnect func(*pgx.Conn) error // function to call on every new connection + driver *Driver + id int64 +} - openFromConnPoolCountMu.Lock() - name := fmt.Sprintf("pgx-%d", openFromConnPoolCount) - openFromConnPoolCount++ - openFromConnPoolCountMu.Unlock() - - sql.Register(name, d) - db, err := sql.Open(name, "") - if err != nil { - return nil, err +// ConnectionString encodes the DriverConfig into the original connection +// string. DriverConfig must be registered before calling ConnectionString. 
+func (c *DriverConfig) ConnectionString(original string) string { + if c.driver == nil { + panic("DriverConfig must be registered before calling ConnectionString") } - // Presumably OpenFromConnPool is being used because the user wants to use - // database/sql most of the time, but fast path with pgx some of the time. - // Allow database/sql to use all the connections, but release 2 idle ones. - // Don't have database/sql immediately release all idle connections because - // that would mean that prepared statements would be lost (which kills - // performance if the prepared statements constantly have to be reprepared) - stat := pool.Stat() + buf := make([]byte, 9) + binary.BigEndian.PutUint64(buf[1:], uint64(c.id)) + buf = append(buf, original...) + return string(buf) +} - if stat.MaxConnections <= 2 { - return nil, errors.New("pool connection size must be at least 3") - } - db.SetMaxIdleConns(stat.MaxConnections - 2) - db.SetMaxOpenConns(stat.MaxConnections) +func (d *Driver) registerDriverConfig(c *DriverConfig) { + d.configMutex.Lock() - return db, nil + c.driver = d + c.id = d.configCount + d.configs[d.configCount] = c + d.configCount++ + + d.configMutex.Unlock() +} + +func (d *Driver) unregisterDriverConfig(c *DriverConfig) { + d.configMutex.Lock() + delete(d.configs, c.id) + d.configMutex.Unlock() +} + +// RegisterDriverConfig registers a DriverConfig for use with Open. +func RegisterDriverConfig(c *DriverConfig) { + pgxDriver.registerDriverConfig(c) +} + +// UnregisterDriverConfig removes a DriverConfig registration. 
+func UnregisterDriverConfig(c *DriverConfig) { + pgxDriver.unregisterDriverConfig(c) } type Conn struct { conn *pgx.Conn - pool *pgx.ConnPool psCount int64 // Counter used for creating unique prepared statement names + driver *Driver } func (c *Conn) Prepare(query string) (driver.Stmt, error) { + return c.PrepareContext(context.Background(), query) +} + +func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { if !c.conn.IsAlive() { return nil, driver.ErrBadConn } @@ -169,7 +221,7 @@ func (c *Conn) Prepare(query string) (driver.Stmt, error) { name := fmt.Sprintf("pgx_%d", c.psCount) c.psCount++ - ps, err := c.conn.Prepare(name, query) + ps, err := c.conn.PrepareEx(ctx, name, query, nil) if err != nil { return nil, err } @@ -180,25 +232,43 @@ func (c *Conn) Prepare(query string) (driver.Stmt, error) { } func (c *Conn) Close() error { - err := c.conn.Close() - if c.pool != nil { - c.pool.Release(c.conn) - } - - return err + return c.conn.Close() } func (c *Conn) Begin() (driver.Tx, error) { + return c.BeginTx(context.Background(), driver.TxOptions{}) +} + +func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { if !c.conn.IsAlive() { return nil, driver.ErrBadConn } - _, err := c.conn.Exec("begin") - if err != nil { - return nil, err + if pconn, ok := ctx.Value(ctxKeyFakeTx).(**pgx.Conn); ok { + *pconn = c.conn + return fakeTx{}, nil } - return &Tx{conn: c.conn}, nil + var pgxOpts pgx.TxOptions + switch sql.IsolationLevel(opts.Isolation) { + case sql.LevelDefault: + case sql.LevelReadUncommitted: + pgxOpts.IsoLevel = pgx.ReadUncommitted + case sql.LevelReadCommitted: + pgxOpts.IsoLevel = pgx.ReadCommitted + case sql.LevelSnapshot: + pgxOpts.IsoLevel = pgx.RepeatableRead + case sql.LevelSerializable: + pgxOpts.IsoLevel = pgx.Serializable + default: + return nil, errors.Errorf("unsupported isolation: %v", opts.Isolation) + } + + if opts.ReadOnly { + pgxOpts.AccessMode = pgx.ReadOnly + } + + return 
c.conn.BeginEx(ctx, &pgxOpts) } func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) { @@ -211,6 +281,17 @@ func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) { return driver.RowsAffected(commandTag.RowsAffected()), err } +func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Result, error) { + if !c.conn.IsAlive() { + return nil, driver.ErrBadConn + } + + args := namedValueToInterface(argsV) + + commandTag, err := c.conn.ExecEx(ctx, query, nil, args...) + return driver.RowsAffected(commandTag.RowsAffected()), err +} + func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) { if !c.conn.IsAlive() { return nil, driver.ErrBadConn @@ -226,6 +307,21 @@ func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) { return c.queryPrepared("", argsV) } +func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) { + if !c.conn.IsAlive() { + return nil, driver.ErrBadConn + } + + ps, err := c.conn.PrepareEx(ctx, "", query, nil) + if err != nil { + return nil, err + } + + restrictBinaryToDatabaseSqlTypes(ps) + + return c.queryPreparedContext(ctx, "", argsV) +} + func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) { if !c.conn.IsAlive() { return nil, driver.ErrBadConn @@ -241,12 +337,35 @@ func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, er return &Rows{rows: rows}, nil } +func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []driver.NamedValue) (driver.Rows, error) { + if !c.conn.IsAlive() { + return nil, driver.ErrBadConn + } + + args := namedValueToInterface(argsV) + + rows, err := c.conn.QueryEx(ctx, name, nil, args...) 
+ if err != nil { + return nil, err + } + + return &Rows{rows: rows}, nil +} + +func (c *Conn) Ping(ctx context.Context) error { + if !c.conn.IsAlive() { + return driver.ErrBadConn + } + + return c.conn.Ping(ctx) +} + // Anything that isn't a database/sql compatible type needs to be forced to // text format so that pgx.Rows.Values doesn't decode it into a native type // (e.g. []int32) func restrictBinaryToDatabaseSqlTypes(ps *pgx.PreparedStatement) { - for i := range ps.FieldDescriptions { - intrinsic, _ := databaseSqlOids[ps.FieldDescriptions[i].DataType] + for i, _ := range ps.FieldDescriptions { + intrinsic, _ := databaseSqlOIDs[ps.FieldDescriptions[i].DataType] if !intrinsic { ps.FieldDescriptions[i].FormatCode = pgx.TextFormatCode } @@ -263,20 +382,28 @@ func (s *Stmt) Close() error { } func (s *Stmt) NumInput() int { - return len(s.ps.ParameterOids) + return len(s.ps.ParameterOIDs) } func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) { return s.conn.Exec(s.ps.Name, argsV) } +func (s *Stmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driver.Result, error) { + return s.conn.ExecContext(ctx, s.ps.Name, argsV) +} + func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) { return s.conn.queryPrepared(s.ps.Name, argsV) } -// TODO - rename to avoid alloc +func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (driver.Rows, error) { + return s.conn.queryPreparedContext(ctx, s.ps.Name, argsV) +} + type Rows struct { - rows *pgx.Rows + rows *pgx.Rows + values []interface{} } func (r *Rows) Columns() []string { @@ -288,12 +415,52 @@ func (r *Rows) Columns() []string { return names } +func (r *Rows) ColumnTypeDatabaseTypeName(index int) string { + return strings.ToUpper(r.rows.FieldDescriptions()[index].DataTypeName) +} + func (r *Rows) Close() error { r.rows.Close() return nil } func (r *Rows) Next(dest []driver.Value) error { + if r.values == nil { + r.values = make([]interface{}, 
len(r.rows.FieldDescriptions())) + for i, fd := range r.rows.FieldDescriptions() { + switch fd.DataType { + case pgtype.BoolOID: + r.values[i] = &pgtype.Bool{} + case pgtype.ByteaOID: + r.values[i] = &pgtype.Bytea{} + case pgtype.CIDOID: + r.values[i] = &pgtype.CID{} + case pgtype.DateOID: + r.values[i] = &pgtype.Date{} + case pgtype.Float4OID: + r.values[i] = &pgtype.Float4{} + case pgtype.Float8OID: + r.values[i] = &pgtype.Float8{} + case pgtype.Int2OID: + r.values[i] = &pgtype.Int2{} + case pgtype.Int4OID: + r.values[i] = &pgtype.Int4{} + case pgtype.Int8OID: + r.values[i] = &pgtype.Int8{} + case pgtype.OIDOID: + r.values[i] = &pgtype.OIDValue{} + case pgtype.TimestampOID: + r.values[i] = &pgtype.Timestamp{} + case pgtype.TimestamptzOID: + r.values[i] = &pgtype.Timestamptz{} + case pgtype.XIDOID: + r.values[i] = &pgtype.XID{} + default: + r.values[i] = &pgtype.GenericText{} + } + } + } + more := r.rows.Next() if !more { if r.rows.Err() == nil { @@ -303,19 +470,16 @@ func (r *Rows) Next(dest []driver.Value) error { } } - values, err := r.rows.Values() + err := r.rows.Scan(r.values...) 
if err != nil { return err } - if len(dest) < len(values) { - fmt.Printf("%d: %#v\n", len(dest), dest) - fmt.Printf("%d: %#v\n", len(values), values) - return errors.New("expected more values than were received") - } - - for i, v := range values { - dest[i] = driver.Value(v) + for i, v := range r.values { + dest[i], err = v.(driver.Valuer).Value() + if err != nil { + return err + } } return nil @@ -333,16 +497,58 @@ func valueToInterface(argsV []driver.Value) []interface{} { return args } -type Tx struct { - conn *pgx.Conn +func namedValueToInterface(argsV []driver.NamedValue) []interface{} { + args := make([]interface{}, 0, len(argsV)) + for _, v := range argsV { + if v.Value != nil { + args = append(args, v.Value.(interface{})) + } else { + args = append(args, nil) + } + } + return args } -func (t *Tx) Commit() error { - _, err := t.conn.Exec("commit") - return err +type fakeTx struct{} + +func (fakeTx) Commit() error { return nil } + +func (fakeTx) Rollback() error { return nil } + +func AcquireConn(db *sql.DB) (*pgx.Conn, error) { + driver, ok := db.Driver().(*Driver) + if !ok { + return nil, ErrNotPgx + } + + var conn *pgx.Conn + ctx := context.WithValue(context.Background(), ctxKeyFakeTx, &conn) + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + + driver.fakeTxMutex.Lock() + driver.fakeTxConns[conn] = tx + driver.fakeTxMutex.Unlock() + + return conn, nil } -func (t *Tx) Rollback() error { - _, err := t.conn.Exec("rollback") - return err +func ReleaseConn(db *sql.DB, conn *pgx.Conn) error { + var tx *sql.Tx + var ok bool + + driver := db.Driver().(*Driver) + driver.fakeTxMutex.Lock() + tx, ok = driver.fakeTxConns[conn] + if ok { + delete(driver.fakeTxConns, conn) + driver.fakeTxMutex.Unlock() + } else { + driver.fakeTxMutex.Unlock() + return errors.Errorf("can't release conn that is not acquired") + } + + return tx.Rollback() } diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 1455ca1d..65f80ac4 100644 --- a/stdlib/sql_test.go 
+++ b/stdlib/sql_test.go @@ -2,11 +2,16 @@ package stdlib_test import ( "bytes" + "context" "database/sql" - "github.com/jackc/pgx" - "github.com/jackc/pgx/stdlib" - "sync" + "fmt" "testing" + "time" + + "github.com/jackc/pgx" + "github.com/jackc/pgx/pgmock" + "github.com/jackc/pgx/pgproto3" + "github.com/jackc/pgx/stdlib" ) func openDB(t *testing.T) *sql.DB { @@ -119,86 +124,31 @@ func TestNormalLifeCycle(t *testing.T) { ensureConnValid(t, db) } -func TestSqlOpenDoesNotHavePool(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) - - driver := db.Driver().(*stdlib.Driver) - if driver.Pool != nil { - t.Fatal("Did not expect driver opened through database/sql to have Pool, but it did") - } -} - -func TestOpenFromConnPool(t *testing.T) { - connConfig := pgx.ConnConfig{ - Host: "127.0.0.1", - User: "pgx_md5", - Password: "secret", - Database: "pgx_test", +func TestOpenWithDriverConfigAfterConnect(t *testing.T) { + driverConfig := stdlib.DriverConfig{ + AfterConnect: func(c *pgx.Conn) error { + _, err := c.Exec("create temporary sequence pgx") + return err + }, } - config := pgx.ConnPoolConfig{ConnConfig: connConfig} - pool, err := pgx.NewConnPool(config) + stdlib.RegisterDriverConfig(&driverConfig) + defer stdlib.UnregisterDriverConfig(&driverConfig) + + db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test")) if err != nil { - t.Fatalf("Unable to create connection pool: %v", err) - } - defer pool.Close() - - db, err := stdlib.OpenFromConnPool(pool) - if err != nil { - t.Fatalf("Unable to create connection pool: %v", err) + t.Fatalf("sql.Open failed: %v", err) } defer closeDB(t, db) - // Can get pgx.ConnPool from driver - driver := db.Driver().(*stdlib.Driver) - if driver.Pool == nil { - t.Fatal("Expected driver opened through OpenFromConnPool to have Pool, but it did not") - } - - // Normal sql/database still works var n int64 - err = db.QueryRow("select 1").Scan(&n) + err = db.QueryRow("select 
nextval('pgx')").Scan(&n) if err != nil { t.Fatalf("db.QueryRow unexpectedly failed: %v", err) } -} - -func TestOpenFromConnPoolRace(t *testing.T) { - wg := &sync.WaitGroup{} - connConfig := pgx.ConnConfig{ - Host: "127.0.0.1", - User: "pgx_md5", - Password: "secret", - Database: "pgx_test", + if n != 1 { + t.Fatalf("n => %d, want %d", n, 1) } - - config := pgx.ConnPoolConfig{ConnConfig: connConfig} - pool, err := pgx.NewConnPool(config) - if err != nil { - t.Fatalf("Unable to create connection pool: %v", err) - } - defer pool.Close() - - wg.Add(10) - for i := 0; i < 10; i++ { - go func() { - defer wg.Done() - db, err := stdlib.OpenFromConnPool(pool) - if err != nil { - t.Fatalf("Unable to create connection pool: %v", err) - } - defer closeDB(t, db) - - // Can get pgx.ConnPool from driver - driver := db.Driver().(*stdlib.Driver) - if driver.Pool == nil { - t.Fatal("Expected driver opened through OpenFromConnPool to have Pool, but it did not") - } - }() - } - - wg.Wait() } func TestStmtExec(t *testing.T) { @@ -364,67 +314,53 @@ func TestConnQuery(t *testing.T) { } type testLog struct { - lvl int - msg string - ctx []interface{} + lvl pgx.LogLevel + msg string + data map[string]interface{} } type testLogger struct { logs []testLog } -func (l *testLogger) Debug(msg string, ctx ...interface{}) { - l.logs = append(l.logs, testLog{lvl: pgx.LogLevelDebug, msg: msg, ctx: ctx}) -} -func (l *testLogger) Info(msg string, ctx ...interface{}) { - l.logs = append(l.logs, testLog{lvl: pgx.LogLevelInfo, msg: msg, ctx: ctx}) -} -func (l *testLogger) Warn(msg string, ctx ...interface{}) { - l.logs = append(l.logs, testLog{lvl: pgx.LogLevelWarn, msg: msg, ctx: ctx}) -} -func (l *testLogger) Error(msg string, ctx ...interface{}) { - l.logs = append(l.logs, testLog{lvl: pgx.LogLevelError, msg: msg, ctx: ctx}) +func (l *testLogger) Log(lvl pgx.LogLevel, msg string, data map[string]interface{}) { + l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data}) } func TestConnQueryLog(t 
*testing.T) { logger := &testLogger{} - connConfig := pgx.ConnConfig{ - Host: "127.0.0.1", - User: "pgx_md5", - Password: "secret", - Database: "pgx_test", - Logger: logger, + driverConfig := stdlib.DriverConfig{ + ConnConfig: pgx.ConnConfig{ + Host: "127.0.0.1", + User: "pgx_md5", + Password: "secret", + Database: "pgx_test", + Logger: logger, + }, } - config := pgx.ConnPoolConfig{ConnConfig: connConfig} - pool, err := pgx.NewConnPool(config) - if err != nil { - t.Fatalf("Unable to create connection pool: %v", err) - } - defer pool.Close() + stdlib.RegisterDriverConfig(&driverConfig) + defer stdlib.UnregisterDriverConfig(&driverConfig) - db, err := stdlib.OpenFromConnPool(pool) + db, err := sql.Open("pgx", driverConfig.ConnectionString("")) if err != nil { - t.Fatalf("Unable to create connection pool: %v", err) + t.Fatalf("sql.Open failed: %v", err) } defer closeDB(t, db) - // clear logs from initial connection - logger.logs = []testLog{} - var n int64 err = db.QueryRow("select 1").Scan(&n) if err != nil { t.Fatalf("db.QueryRow unexpectedly failed: %v", err) } - l := logger.logs[0] + l := logger.logs[len(logger.logs)-1] if l.msg != "Query" { t.Errorf("Expected to log Query, but got %v", l) } - if !(l.ctx[0] == "sql" && l.ctx[1] == "select 1") { + if l.data["sql"] != "select 1" { t.Errorf("Expected to log Query with sql 'select 1', but got %v", l) } } @@ -544,10 +480,6 @@ func TestConnQueryJSONIntoByteSlice(t *testing.T) { db := openDB(t) defer closeDB(t, db) - if !serverHasJSON(t, db) { - t.Skip("Skipping due to server's lack of JSON type") - } - _, err := db.Exec(` create temporary table docs( body json not null @@ -584,10 +516,6 @@ func TestConnExecInsertByteSliceIntoJSON(t *testing.T) { db := openDB(t) defer closeDB(t, db) - if !serverHasJSON(t, db) { - t.Skip("Skipping due to server's lack of JSON type") - } - _, err := db.Exec(` create temporary table docs( body json not null @@ -622,15 +550,6 @@ func TestConnExecInsertByteSliceIntoJSON(t *testing.T) { 
ensureConnValid(t, db) } -func serverHasJSON(t *testing.T, db *sql.DB) bool { - var hasJSON bool - err := db.QueryRow(`select exists(select 1 from pg_type where typname='json')`).Scan(&hasJSON) - if err != nil { - t.Fatalf("db.QueryRow unexpectedly failed: %v", err) - } - return hasJSON -} - func TestTransactionLifeCycle(t *testing.T) { db := openDB(t) defer closeDB(t, db) @@ -689,3 +608,653 @@ func TestTransactionLifeCycle(t *testing.T) { ensureConnValid(t, db) } + +func TestConnBeginTxIsolation(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + var defaultIsoLevel string + err := db.QueryRow("show transaction_isolation").Scan(&defaultIsoLevel) + if err != nil { + t.Fatalf("QueryRow failed: %v", err) + } + + supportedTests := []struct { + sqlIso sql.IsolationLevel + pgIso string + }{ + {sqlIso: sql.LevelDefault, pgIso: defaultIsoLevel}, + {sqlIso: sql.LevelReadUncommitted, pgIso: "read uncommitted"}, + {sqlIso: sql.LevelReadCommitted, pgIso: "read committed"}, + {sqlIso: sql.LevelSnapshot, pgIso: "repeatable read"}, + {sqlIso: sql.LevelSerializable, pgIso: "serializable"}, + } + for i, tt := range supportedTests { + func() { + tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso}) + if err != nil { + t.Errorf("%d. BeginTx failed: %v", i, err) + return + } + defer tx.Rollback() + + var pgIso string + err = tx.QueryRow("show transaction_isolation").Scan(&pgIso) + if err != nil { + t.Errorf("%d. QueryRow failed: %v", i, err) + } + + if pgIso != tt.pgIso { + t.Errorf("%d. pgIso => %s, want %s", i, pgIso, tt.pgIso) + } + }() + } + + unsupportedTests := []struct { + sqlIso sql.IsolationLevel + }{ + {sqlIso: sql.LevelWriteCommitted}, + {sqlIso: sql.LevelLinearizable}, + } + for i, tt := range unsupportedTests { + tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso}) + if err == nil { + t.Errorf("%d. 
BeginTx should have failed", i) + tx.Rollback() + } + } + + ensureConnValid(t, db) +} + +func TestConnBeginTxReadOnly(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}) + if err != nil { + t.Fatalf("BeginTx failed: %v", err) + } + defer tx.Rollback() + + var pgReadOnly string + err = tx.QueryRow("show transaction_read_only").Scan(&pgReadOnly) + if err != nil { + t.Errorf("QueryRow failed: %v", err) + } + + if pgReadOnly != "on" { + t.Errorf("pgReadOnly => %s, want %s", pgReadOnly, "on") + } + + ensureConnValid(t, db) +} + +func TestBeginTxContextCancel(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + _, err := db.Exec("drop table if exists t") + if err != nil { + t.Fatalf("db.Exec failed: %v", err) + } + + ctx, cancelFn := context.WithCancel(context.Background()) + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + t.Fatalf("BeginTx failed: %v", err) + } + + _, err = tx.Exec("create table t(id serial)") + if err != nil { + t.Fatalf("tx.Exec failed: %v", err) + } + + cancelFn() + + err = tx.Commit() + if err != context.Canceled && err != sql.ErrTxDone { + t.Fatalf("err => %v, want %v or %v", err, context.Canceled, sql.ErrTxDone) + } + + var n int + err = db.QueryRow("select count(*) from t").Scan(&n) + if pgErr, ok := err.(pgx.PgError); !ok || pgErr.Code != "42P01" { + t.Fatalf(`err => %v, want PgError{Code: "42P01"}`, err) + } + + ensureConnValid(t, db) +} + +func acceptStandardPgxConn(backend *pgproto3.Backend) error { + script := pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + + err := script.Run(backend) + if err != nil { + return err + } + + typeScript := pgmock.Script{ + Steps: pgmock.PgxInitSteps(), + } + + return typeScript.Run(backend) +} + +func TestBeginTxContextCancelWithDeadConn(t *testing.T) { + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + script.Steps = append(script.Steps, 
pgmock.PgxInitSteps()...) + script.Steps = append(script.Steps, + pgmock.ExpectMessage(&pgproto3.Query{String: "begin"}), + pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: "BEGIN"}), + pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'T'}), + ) + + server, err := pgmock.NewServer(script) + if err != nil { + t.Fatal(err) + } + + errChan := make(chan error) + go func() { + errChan <- server.ServeOne() + }() + + db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + defer closeDB(t, db) + + ctx, cancelFn := context.WithCancel(context.Background()) + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + t.Fatalf("BeginTx failed: %v", err) + } + + cancelFn() + + err = tx.Commit() + if err != context.Canceled && err != sql.ErrTxDone { + t.Fatalf("err => %v, want %v or %v", err, context.Canceled, sql.ErrTxDone) + } + + if err := <-errChan; err != nil { + t.Fatalf("mock server err: %v", err) + } +} + +func TestAcquireConn(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + var conns []*pgx.Conn + + for i := 1; i < 6; i++ { + conn, err := stdlib.AcquireConn(db) + if err != nil { + t.Errorf("%d. AcquireConn failed: %v", i, err) + continue + } + + var n int32 + err = conn.QueryRow("select 1").Scan(&n) + if err != nil { + t.Errorf("%d. QueryRow failed: %v", i, err) + } + if n != 1 { + t.Errorf("%d. n => %d, want %d", i, n, 1) + } + + stats := db.Stats() + if stats.OpenConnections != i { + t.Errorf("%d. stats.OpenConnections => %d, want %d", i, stats.OpenConnections, i) + } + + conns = append(conns, conn) + } + + for i, conn := range conns { + if err := stdlib.ReleaseConn(db, conn); err != nil { + t.Errorf("%d. 
stdlib.ReleaseConn failed: %v", i, err) + } + } + + ensureConnValid(t, db) +} + +func TestConnPingContextSuccess(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + if err := db.PingContext(context.Background()); err != nil { + t.Fatalf("db.PingContext failed: %v", err) + } + + ensureConnValid(t, db) +} + +func TestConnPingContextCancel(t *testing.T) { + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) + script.Steps = append(script.Steps, + pgmock.ExpectMessage(&pgproto3.Query{String: ";"}), + pgmock.WaitForClose(), + ) + + server, err := pgmock.NewServer(script) + if err != nil { + t.Fatal(err) + } + defer server.Close() + + errChan := make(chan error, 1) + go func() { + errChan <- server.ServeOne() + }() + + db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + defer closeDB(t, db) + + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + + err = db.PingContext(ctx) + if err != context.DeadlineExceeded { + t.Errorf("err => %v, want %v", err, context.DeadlineExceeded) + } + + if err := <-errChan; err != nil { + t.Errorf("mock server err: %v", err) + } +} + +func TestConnPrepareContextSuccess(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + stmt, err := db.PrepareContext(context.Background(), "select now()") + if err != nil { + t.Fatalf("db.PrepareContext failed: %v", err) + } + stmt.Close() + + ensureConnValid(t, db) +} + +func TestConnPrepareContextCancel(t *testing.T) { + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) 
+ script.Steps = append(script.Steps, + pgmock.ExpectMessage(&pgproto3.Parse{Name: "pgx_0", Query: "select now()"}), + pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S', Name: "pgx_0"}), + pgmock.ExpectMessage(&pgproto3.Sync{}), + pgmock.WaitForClose(), + ) + + server, err := pgmock.NewServer(script) + if err != nil { + t.Fatal(err) + } + defer server.Close() + + errChan := make(chan error) + go func() { + errChan <- server.ServeOne() + }() + + db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + defer closeDB(t, db) + + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + + _, err = db.PrepareContext(ctx, "select now()") + if err != context.DeadlineExceeded { + t.Errorf("err => %v, want %v", err, context.DeadlineExceeded) + } + + if err := <-errChan; err != nil { + t.Errorf("mock server err: %v", err) + } +} + +func TestConnExecContextSuccess(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + _, err := db.ExecContext(context.Background(), "create temporary table exec_context_test(id serial primary key)") + if err != nil { + t.Fatalf("db.ExecContext failed: %v", err) + } + + ensureConnValid(t, db) +} + +func TestConnExecContextCancel(t *testing.T) { + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) 
+ script.Steps = append(script.Steps, + pgmock.ExpectMessage(&pgproto3.Query{String: "create temporary table exec_context_test(id serial primary key)"}), + pgmock.WaitForClose(), + ) + + server, err := pgmock.NewServer(script) + if err != nil { + t.Fatal(err) + } + defer server.Close() + + errChan := make(chan error) + go func() { + errChan <- server.ServeOne() + }() + + db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + defer closeDB(t, db) + + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + + _, err = db.ExecContext(ctx, "create temporary table exec_context_test(id serial primary key)") + if err != context.DeadlineExceeded { + t.Errorf("err => %v, want %v", err, context.DeadlineExceeded) + } + + if err := <-errChan; err != nil { + t.Errorf("mock server err: %v", err) + } +} + +func TestConnQueryContextSuccess(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + rows, err := db.QueryContext(context.Background(), "select * from generate_series(1,10) n") + if err != nil { + t.Fatalf("db.QueryContext failed: %v", err) + } + + for rows.Next() { + var n int64 + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + } + + if rows.Err() != nil { + t.Error(rows.Err()) + } + + ensureConnValid(t, db) +} + +func TestConnQueryContextCancel(t *testing.T) { + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) 
+ script.Steps = append(script.Steps, + pgmock.ExpectMessage(&pgproto3.Parse{Query: "select * from generate_series(1,10) n"}), + pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S'}), + pgmock.ExpectMessage(&pgproto3.Sync{}), + + pgmock.SendMessage(&pgproto3.ParseComplete{}), + pgmock.SendMessage(&pgproto3.ParameterDescription{}), + pgmock.SendMessage(&pgproto3.RowDescription{ + Fields: []pgproto3.FieldDescription{ + { + Name: "n", + DataTypeOID: 23, + DataTypeSize: 4, + TypeModifier: 4294967295, + }, + }, + }), + pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + + pgmock.ExpectMessage(&pgproto3.Bind{ResultFormatCodes: []int16{1}}), + pgmock.ExpectMessage(&pgproto3.Execute{}), + pgmock.ExpectMessage(&pgproto3.Sync{}), + + pgmock.SendMessage(&pgproto3.BindComplete{}), + pgmock.WaitForClose(), + ) + + server, err := pgmock.NewServer(script) + if err != nil { + t.Fatal(err) + } + defer server.Close() + + errChan := make(chan error) + go func() { + errChan <- server.ServeOne() + }() + + db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + defer db.Close() + + ctx, cancelFn := context.WithCancel(context.Background()) + + rows, err := db.QueryContext(ctx, "select * from generate_series(1,10) n") + if err != nil { + t.Fatalf("db.QueryContext failed: %v", err) + } + + cancelFn() + + for rows.Next() { + t.Fatalf("no rows should ever be received") + } + + if rows.Err() != context.Canceled { + t.Errorf("rows.Err() => %v, want %v", rows.Err(), context.Canceled) + } + + if err := <-errChan; err != nil { + t.Errorf("mock server err: %v", err) + } +} + +func TestRowsColumnTypeDatabaseTypeName(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + rows, err := db.Query("select * from generate_series(1,10) n") + if err != nil { + t.Fatalf("db.Query failed: %v", err) + } + + columnTypes, err := rows.ColumnTypes() + if err != nil { + 
t.Fatalf("rows.ColumnTypes failed: %v", err) + } + + if len(columnTypes) != 1 { + t.Fatalf("len(columnTypes) => %v, want %v", len(columnTypes), 1) + } + + if columnTypes[0].DatabaseTypeName() != "INT4" { + t.Errorf("columnTypes[0].DatabaseTypeName() => %v, want %v", columnTypes[0].DatabaseTypeName(), "INT4") + } + + rows.Close() + + ensureConnValid(t, db) +} + +func TestStmtExecContextSuccess(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + _, err := db.Exec("create temporary table t(id int primary key)") + if err != nil { + t.Fatalf("db.Exec failed: %v", err) + } + + stmt, err := db.Prepare("insert into t(id) values ($1::int4)") + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + _, err = stmt.ExecContext(context.Background(), 42) + if err != nil { + t.Fatal(err) + } + + ensureConnValid(t, db) +} + +func TestStmtExecContextCancel(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + _, err := db.Exec("create temporary table t(id int primary key)") + if err != nil { + t.Fatalf("db.Exec failed: %v", err) + } + + stmt, err := db.Prepare("insert into t(id) select $1::int4 from pg_sleep(5)") + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + + _, err = stmt.ExecContext(ctx, 42) + if err != context.DeadlineExceeded { + t.Errorf("err => %v, want %v", err, context.DeadlineExceeded) + } + + ensureConnValid(t, db) +} + +func TestStmtQueryContextSuccess(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + stmt, err := db.Prepare("select * from generate_series(1,$1::int4) n") + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + rows, err := stmt.QueryContext(context.Background(), 5) + if err != nil { + t.Fatalf("stmt.QueryContext failed: %v", err) + } + + for rows.Next() { + var n int64 + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + } + + if rows.Err() != nil { + t.Error(rows.Err()) + } + + ensureConnValid(t, db) +} + +func 
TestStmtQueryContextCancel(t *testing.T) { + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) + script.Steps = append(script.Steps, + pgmock.ExpectMessage(&pgproto3.Parse{Name: "pgx_0", Query: "select * from generate_series(1, $1::int4) n"}), + pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S', Name: "pgx_0"}), + pgmock.ExpectMessage(&pgproto3.Sync{}), + + pgmock.SendMessage(&pgproto3.ParseComplete{}), + pgmock.SendMessage(&pgproto3.ParameterDescription{ParameterOIDs: []uint32{23}}), + pgmock.SendMessage(&pgproto3.RowDescription{ + Fields: []pgproto3.FieldDescription{ + { + Name: "n", + DataTypeOID: 23, + DataTypeSize: 4, + TypeModifier: 4294967295, + }, + }, + }), + pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + + pgmock.ExpectMessage(&pgproto3.Bind{PreparedStatement: "pgx_0", ParameterFormatCodes: []int16{1}, Parameters: [][]uint8{[]uint8{0x0, 0x0, 0x0, 0x2a}}, ResultFormatCodes: []int16{1}}), + pgmock.ExpectMessage(&pgproto3.Execute{}), + pgmock.ExpectMessage(&pgproto3.Sync{}), + + pgmock.SendMessage(&pgproto3.BindComplete{}), + pgmock.WaitForClose(), + ) + + server, err := pgmock.NewServer(script) + if err != nil { + t.Fatal(err) + } + defer server.Close() + + errChan := make(chan error) + go func() { + errChan <- server.ServeOne() + }() + + db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + // defer closeDB(t, db) // mock DB doesn't close correctly yet + + stmt, err := db.Prepare("select * from generate_series(1, $1::int4) n") + if err != nil { + t.Fatal(err) + } + // defer stmt.Close() + + ctx, cancelFn := context.WithCancel(context.Background()) + + rows, err := stmt.QueryContext(ctx, 42) + if err != nil { + t.Fatalf("stmt.QueryContext failed: %v", err) + } + + cancelFn() + + for rows.Next() { + t.Fatalf("no rows 
should ever be received") + } + + if rows.Err() != context.Canceled { + t.Errorf("rows.Err() => %v, want %v", rows.Err(), context.Canceled) + } + + if err := <-errChan; err != nil { + t.Errorf("mock server err: %v", err) + } +} diff --git a/stress_test.go b/stress_test.go index 150d13c8..114bec81 100644 --- a/stress_test.go +++ b/stress_test.go @@ -1,12 +1,16 @@ package pgx_test import ( - "errors" + "context" "fmt" "math/rand" + "os" + "strconv" "testing" "time" + "github.com/pkg/errors" + "github.com/jackc/fake" "github.com/jackc/pgx" ) @@ -22,6 +26,8 @@ type queryRower interface { } func TestStressConnPool(t *testing.T) { + t.Parallel() + maxConnections := 8 pool := createConnPool(t, maxConnections) defer pool.Close() @@ -44,14 +50,19 @@ func TestStressConnPool(t *testing.T) { {"listenAndPoolUnlistens", listenAndPoolUnlistens}, {"reset", func(p *pgx.ConnPool, n int) error { p.Reset(); return nil }}, {"poolPrepareUseAndDeallocate", poolPrepareUseAndDeallocate}, + {"canceledQueryExContext", canceledQueryExContext}, + {"canceledExecExContext", canceledExecExContext}, } - var timer *time.Timer - if testing.Short() { - timer = time.NewTimer(5 * time.Second) - } else { - timer = time.NewTimer(60 * time.Second) + actionCount := 1000 + if s := os.Getenv("STRESS_FACTOR"); s != "" { + stressFactor, err := strconv.ParseInt(s, 10, 64) + if err != nil { + t.Fatalf("failed to parse STRESS_FACTOR: %v", s) + } + actionCount *= int(stressFactor) } + workerCount := 16 workChan := make(chan int) @@ -63,7 +74,7 @@ func TestStressConnPool(t *testing.T) { action := actions[rand.Intn(len(actions))] err := action.fn(pool, n) if err != nil { - errChan <- err + errChan <- errors.Errorf("%s: %v", action.name, err) break } } @@ -74,11 +85,8 @@ func TestStressConnPool(t *testing.T) { go work() } - var stop bool - for i := 0; !stop; i++ { + for i := 0; i < actionCount; i++ { select { - case <-timer.C: - stop = true case workChan <- i: case err := <-errChan: close(workChan) @@ -92,42 +100,6 
@@ func TestStressConnPool(t *testing.T) { } } -func TestStressTLSConnection(t *testing.T) { - t.Parallel() - - if tlsConnConfig == nil { - t.Skip("Skipping due to undefined tlsConnConfig") - } - - if testing.Short() { - t.Skip("Skipping due to testing -short") - } - - conn, err := pgx.Connect(*tlsConnConfig) - if err != nil { - t.Fatalf("Unable to establish connection: %v", err) - } - defer conn.Close() - - for i := 0; i < 50; i++ { - sql := `select * from generate_series(1, $1)` - - rows, err := conn.Query(sql, 2000000) - if err != nil { - t.Fatal(err) - } - - var n int32 - for rows.Next() { - rows.Scan(&n) - } - - if rows.Err() != nil { - t.Fatalf("queryCount: %d, Row number: %d. %v", i, n, rows.Err()) - } - } -} - func setupStressDB(t *testing.T, pool *pgx.ConnPool) { _, err := pool.Exec(` drop table if exists widgets; @@ -241,8 +213,9 @@ func listenAndPoolUnlistens(pool *pgx.ConnPool, actionNum int) error { return err } - _, err = conn.WaitForNotification(100 * time.Millisecond) - if err == pgx.ErrNotificationTimeout { + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + _, err = conn.WaitForNotification(ctx) + if err == context.DeadlineExceeded { return nil } return err @@ -263,7 +236,7 @@ func poolPrepareUseAndDeallocate(pool *pgx.ConnPool, actionNum int) error { } if s != "hello" { - return fmt.Errorf("Prepared statement did not return expected value: %v", s) + return errors.Errorf("Prepared statement did not return expected value: %v", s) } return pool.Deallocate(psName) @@ -344,3 +317,43 @@ func txMultipleQueries(pool *pgx.ConnPool, actionNum int) error { return tx.Commit() } + +func canceledQueryExContext(pool *pgx.ConnPool, actionNum int) error { + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond) + cancelFunc() + }() + + rows, err := pool.QueryEx(ctx, "select pg_sleep(2)", nil) + if err == context.Canceled { + return nil + } else if err != nil 
{ + return errors.Errorf("Only allowed error is context.Canceled, got %v", err) + } + + for rows.Next() { + return errors.New("should never receive row") + } + + if rows.Err() != context.Canceled { + return errors.Errorf("Expected context.Canceled error, got %v", rows.Err()) + } + + return nil +} + +func canceledExecExContext(pool *pgx.ConnPool, actionNum int) error { + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond) + cancelFunc() + }() + + _, err := pool.ExecEx(ctx, "select pg_sleep(2)", nil) + if err != context.Canceled { + return errors.Errorf("Expected context.Canceled error, got %v", err) + } + + return nil +} diff --git a/tx.go b/tx.go index deb6c01c..f9607f70 100644 --- a/tx.go +++ b/tx.go @@ -1,16 +1,38 @@ package pgx import ( - "errors" + "bytes" + "context" "fmt" + "time" + + "github.com/pkg/errors" ) +type TxIsoLevel string + // Transaction isolation levels const ( - Serializable = "serializable" - RepeatableRead = "repeatable read" - ReadCommitted = "read committed" - ReadUncommitted = "read uncommitted" + Serializable = TxIsoLevel("serializable") + RepeatableRead = TxIsoLevel("repeatable read") + ReadCommitted = TxIsoLevel("read committed") + ReadUncommitted = TxIsoLevel("read uncommitted") +) + +type TxAccessMode string + +// Transaction access modes +const ( + ReadWrite = TxAccessMode("read write") + ReadOnly = TxAccessMode("read only") +) + +type TxDeferrableMode string + +// Transaction deferrable modes +const ( + Deferrable = TxDeferrableMode("deferrable") + NotDeferrable = TxDeferrableMode("not deferrable") ) const ( @@ -21,6 +43,32 @@ const ( TxStatusRollbackSuccess = 2 ) +type TxOptions struct { + IsoLevel TxIsoLevel + AccessMode TxAccessMode + DeferrableMode TxDeferrableMode +} + +func (txOptions *TxOptions) beginSQL() string { + if txOptions == nil { + return "begin" + } + + buf := &bytes.Buffer{} + buf.WriteString("begin") + if txOptions.IsoLevel != "" { 
+ fmt.Fprintf(buf, " isolation level %s", txOptions.IsoLevel) + } + if txOptions.AccessMode != "" { + fmt.Fprintf(buf, " %s", txOptions.AccessMode) + } + if txOptions.DeferrableMode != "" { + fmt.Fprintf(buf, " %s", txOptions.DeferrableMode) + } + + return buf.String() +} + var ErrTxClosed = errors.New("tx is closed") // ErrTxCommitRollback occurs when an error has occurred in a transaction and @@ -28,34 +76,21 @@ var ErrTxClosed = errors.New("tx is closed") // it is treated as ROLLBACK. var ErrTxCommitRollback = errors.New("commit unexpectedly resulted in rollback") -// Begin starts a transaction with the default isolation level for the current -// connection. To use a specific isolation level see BeginIso. +// Begin starts a transaction with the default transaction mode for the +// current connection. To use a specific transaction mode see BeginEx. func (c *Conn) Begin() (*Tx, error) { - return c.begin("") + return c.BeginEx(context.Background(), nil) } -// BeginIso starts a transaction with isoLevel as the transaction isolation -// level. -// -// Valid isolation levels (and their constants) are: -// serializable (pgx.Serializable) -// repeatable read (pgx.RepeatableRead) -// read committed (pgx.ReadCommitted) -// read uncommitted (pgx.ReadUncommitted) -func (c *Conn) BeginIso(isoLevel string) (*Tx, error) { - return c.begin(isoLevel) -} - -func (c *Conn) begin(isoLevel string) (*Tx, error) { - var beginSQL string - if isoLevel == "" { - beginSQL = "begin" - } else { - beginSQL = fmt.Sprintf("begin isolation level %s", isoLevel) - } - - _, err := c.Exec(beginSQL) +// BeginEx starts a transaction with txOptions determining the transaction +// mode. Unlike database/sql, the context only affects the begin command. i.e. +// there is no auto-rollback on context cancelation. 
+func (c *Conn) BeginEx(ctx context.Context, txOptions *TxOptions) (*Tx, error) { + _, err := c.ExecEx(ctx, txOptions.beginSQL(), nil) if err != nil { + // begin should never fail unless there is an underlying connection issue or + // a context timeout. In either case, the connection is possibly broken. + c.die(errors.New("failed to begin transaction")) return nil, err } @@ -67,19 +102,24 @@ func (c *Conn) begin(isoLevel string) (*Tx, error) { // All Tx methods return ErrTxClosed if Commit or Rollback has already been // called on the Tx. type Tx struct { - conn *Conn - afterClose func(*Tx) - err error - status int8 + conn *Conn + connPool *ConnPool + err error + status int8 } // Commit commits the transaction func (tx *Tx) Commit() error { + return tx.CommitEx(context.Background()) +} + +// CommitEx commits the transaction with a context. +func (tx *Tx) CommitEx(ctx context.Context) error { if tx.status != TxStatusInProgress { return ErrTxClosed } - commandTag, err := tx.conn.Exec("commit") + commandTag, err := tx.conn.ExecEx(ctx, "commit", nil) if err == nil && commandTag == "COMMIT" { tx.status = TxStatusCommitSuccess } else if err == nil && commandTag == "ROLLBACK" { @@ -88,11 +128,14 @@ func (tx *Tx) Commit() error { } else { tx.status = TxStatusCommitFailure tx.err = err + // A commit failure leaves the connection in an undefined state + tx.conn.die(errors.New("commit failed")) } - if tx.afterClose != nil { - tx.afterClose(tx) + if tx.connPool != nil { + tx.connPool.Release(tx.conn) } + return tx.err } @@ -105,16 +148,20 @@ func (tx *Tx) Rollback() error { return ErrTxClosed } - _, tx.err = tx.conn.Exec("rollback") + ctx, _ := context.WithTimeout(context.Background(), 15*time.Second) + _, tx.err = tx.conn.ExecEx(ctx, "rollback", nil) if tx.err == nil { tx.status = TxStatusRollbackSuccess } else { tx.status = TxStatusRollbackFailure + // A rollback failure leaves the connection in an undefined state + tx.conn.die(errors.New("rollback failed")) } - if 
tx.afterClose != nil { - tx.afterClose(tx) + if tx.connPool != nil { + tx.connPool.Release(tx.conn) } + return tx.err } @@ -129,16 +176,16 @@ func (tx *Tx) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, // Prepare delegates to the underlying *Conn func (tx *Tx) Prepare(name, sql string) (*PreparedStatement, error) { - return tx.PrepareEx(name, sql, nil) + return tx.PrepareEx(context.Background(), name, sql, nil) } // PrepareEx delegates to the underlying *Conn -func (tx *Tx) PrepareEx(name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) { +func (tx *Tx) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) { if tx.status != TxStatusInProgress { return nil, ErrTxClosed } - return tx.conn.PrepareEx(name, sql, opts) + return tx.conn.PrepareEx(ctx, name, sql, opts) } // Query delegates to the underlying *Conn @@ -158,15 +205,6 @@ func (tx *Tx) QueryRow(sql string, args ...interface{}) *Row { return (*Row)(rows) } -// Deprecated. Use CopyFrom instead. CopyTo delegates to the underlying *Conn -func (tx *Tx) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) { - if tx.status != TxStatusInProgress { - return 0, ErrTxClosed - } - - return tx.conn.CopyTo(tableName, columnNames, rowSrc) -} - // CopyFrom delegates to the underlying *Conn func (tx *Tx) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) { if tx.status != TxStatusInProgress { @@ -176,11 +214,6 @@ func (tx *Tx) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFr return tx.conn.CopyFrom(tableName, columnNames, rowSrc) } -// Conn returns the *Conn this transaction is using. -func (tx *Tx) Conn() *Conn { - return tx.conn -} - // Status returns the status of the transaction from the set of // pgx.TxStatus* constants. 
func (tx *Tx) Status() int8 { @@ -191,17 +224,3 @@ func (tx *Tx) Status() int8 { func (tx *Tx) Err() error { return tx.err } - -// AfterClose adds f to a LILO queue of functions that will be called when -// the transaction is closed (either Commit or Rollback). -func (tx *Tx) AfterClose(f func(*Tx)) { - if tx.afterClose == nil { - tx.afterClose = f - } else { - prevFn := tx.afterClose - tx.afterClose = func(tx *Tx) { - f(tx) - prevFn(tx) - } - } -} diff --git a/tx_test.go b/tx_test.go index 435521a3..b25e1c9f 100644 --- a/tx_test.go +++ b/tx_test.go @@ -1,9 +1,14 @@ package pgx_test import ( - "github.com/jackc/pgx" + "context" + "fmt" "testing" "time" + + "github.com/jackc/pgx" + "github.com/jackc/pgx/pgmock" + "github.com/jackc/pgx/pgproto3" ) func TestTransactionSuccessfulCommit(t *testing.T) { @@ -107,15 +112,15 @@ func TestTxCommitSerializationFailure(t *testing.T) { } defer pool.Exec(`drop table tx_serializable_sums`) - tx1, err := pool.BeginIso(pgx.Serializable) + tx1, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable}) if err != nil { - t.Fatalf("BeginIso failed: %v", err) + t.Fatalf("BeginEx failed: %v", err) } defer tx1.Rollback() - tx2, err := pool.BeginIso(pgx.Serializable) + tx2, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable}) if err != nil { - t.Fatalf("BeginIso failed: %v", err) + t.Fatalf("BeginEx failed: %v", err) } defer tx2.Rollback() @@ -182,20 +187,20 @@ func TestTransactionSuccessfulRollback(t *testing.T) { } } -func TestBeginIso(t *testing.T) { +func TestBeginExIsoLevels(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - isoLevels := []string{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted} + isoLevels := []pgx.TxIsoLevel{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted} for _, iso := range isoLevels { - tx, err := conn.BeginIso(iso) + tx, err := 
conn.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: iso}) if err != nil { - t.Fatalf("conn.BeginIso failed: %v", err) + t.Fatalf("conn.BeginEx failed: %v", err) } - var level string + var level pgx.TxIsoLevel conn.QueryRow("select current_setting('transaction_isolation')").Scan(&level) if level != iso { t.Errorf("Expected to be in isolation level %v but was %v", iso, level) @@ -208,38 +213,120 @@ func TestBeginIso(t *testing.T) { } } -func TestTxAfterClose(t *testing.T) { +func TestBeginExReadOnly(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) + tx, err := conn.BeginEx(context.Background(), &pgx.TxOptions{AccessMode: pgx.ReadOnly}) + if err != nil { + t.Fatalf("conn.BeginEx failed: %v", err) + } + defer tx.Rollback() + + _, err = conn.Exec("create table foo(id serial primary key)") + if pgErr, ok := err.(pgx.PgError); !ok || pgErr.Code != "25006" { + t.Errorf("Expected error SQLSTATE 25006, but got %#v", err) + } +} + +func TestConnBeginExContextCancel(t *testing.T) { + t.Parallel() + + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) 
+ script.Steps = append(script.Steps, + pgmock.ExpectMessage(&pgproto3.Query{String: "begin"}), + pgmock.WaitForClose(), + ) + + server, err := pgmock.NewServer(script) + if err != nil { + t.Fatal(err) + } + defer server.Close() + + errChan := make(chan error, 1) + go func() { + errChan <- server.ServeOne() + }() + + mockConfig, err := pgx.ParseURI(fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) + if err != nil { + t.Fatal(err) + } + + conn := mustConnect(t, mockConfig) + + ctx, _ := context.WithTimeout(context.Background(), 50*time.Millisecond) + + _, err = conn.BeginEx(ctx, nil) + if err != context.DeadlineExceeded { + t.Errorf("err => %v, want %v", err, context.DeadlineExceeded) + } + + if conn.IsAlive() { + t.Error("expected conn to be dead after BeginEx failure") + } + + if err := <-errChan; err != nil { + t.Errorf("mock server err: %v", err) + } +} + +func TestTxCommitExCancel(t *testing.T) { + t.Parallel() + + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) 
+ script.Steps = append(script.Steps, + pgmock.ExpectMessage(&pgproto3.Query{String: "begin"}), + pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: "BEGIN"}), + pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'T'}), + pgmock.WaitForClose(), + ) + + server, err := pgmock.NewServer(script) + if err != nil { + t.Fatal(err) + } + defer server.Close() + + errChan := make(chan error, 1) + go func() { + errChan <- server.ServeOne() + }() + + mockConfig, err := pgx.ParseURI(fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) + if err != nil { + t.Fatal(err) + } + + conn := mustConnect(t, mockConfig) + defer conn.Close() + tx, err := conn.Begin() if err != nil { t.Fatal(err) } - var zeroTime, t1, t2 time.Time - tx.AfterClose(func(tx *pgx.Tx) { - t1 = time.Now() - }) - - tx.AfterClose(func(tx *pgx.Tx) { - t2 = time.Now() - }) - - tx.Rollback() - - if t1 == zeroTime { - t.Error("First Tx.AfterClose callback not called") + ctx, _ := context.WithTimeout(context.Background(), 50*time.Millisecond) + err = tx.CommitEx(ctx) + if err != context.DeadlineExceeded { + t.Errorf("err => %v, want %v", err, context.DeadlineExceeded) } - if t2 == zeroTime { - t.Error("Second Tx.AfterClose callback not called") + if conn.IsAlive() { + t.Error("expected conn to be dead after CommitEx failure") } - if t1.Before(t2) { - t.Errorf("AfterClose callbacks called out of order: %v, %v", t1, t2) + if err := <-errChan; err != nil { + t.Errorf("mock server err: %v", err) } } diff --git a/value_reader.go b/value_reader.go deleted file mode 100644 index a4897543..00000000 --- a/value_reader.go +++ /dev/null @@ -1,156 +0,0 @@ -package pgx - -import ( - "errors" -) - -// ValueReader is used by the Scanner interface to decode values. 
-type ValueReader struct { - mr *msgReader - fd *FieldDescription - valueBytesRemaining int32 - err error -} - -// Err returns any error that the ValueReader has experienced -func (r *ValueReader) Err() error { - return r.err -} - -// Fatal tells r that a Fatal error has occurred -func (r *ValueReader) Fatal(err error) { - r.err = err -} - -// Len returns the number of unread bytes -func (r *ValueReader) Len() int32 { - return r.valueBytesRemaining -} - -// Type returns the *FieldDescription of the value -func (r *ValueReader) Type() *FieldDescription { - return r.fd -} - -func (r *ValueReader) ReadByte() byte { - if r.err != nil { - return 0 - } - - r.valueBytesRemaining-- - if r.valueBytesRemaining < 0 { - r.Fatal(errors.New("read past end of value")) - return 0 - } - - return r.mr.readByte() -} - -func (r *ValueReader) ReadInt16() int16 { - if r.err != nil { - return 0 - } - - r.valueBytesRemaining -= 2 - if r.valueBytesRemaining < 0 { - r.Fatal(errors.New("read past end of value")) - return 0 - } - - return r.mr.readInt16() -} - -func (r *ValueReader) ReadUint16() uint16 { - if r.err != nil { - return 0 - } - - r.valueBytesRemaining -= 2 - if r.valueBytesRemaining < 0 { - r.Fatal(errors.New("read past end of value")) - return 0 - } - - return r.mr.readUint16() -} - -func (r *ValueReader) ReadInt32() int32 { - if r.err != nil { - return 0 - } - - r.valueBytesRemaining -= 4 - if r.valueBytesRemaining < 0 { - r.Fatal(errors.New("read past end of value")) - return 0 - } - - return r.mr.readInt32() -} - -func (r *ValueReader) ReadUint32() uint32 { - if r.err != nil { - return 0 - } - - r.valueBytesRemaining -= 4 - if r.valueBytesRemaining < 0 { - r.Fatal(errors.New("read past end of value")) - return 0 - } - - return r.mr.readUint32() -} - -func (r *ValueReader) ReadInt64() int64 { - if r.err != nil { - return 0 - } - - r.valueBytesRemaining -= 8 - if r.valueBytesRemaining < 0 { - r.Fatal(errors.New("read past end of value")) - return 0 - } - - return 
r.mr.readInt64() -} - -func (r *ValueReader) ReadOid() Oid { - return Oid(r.ReadUint32()) -} - -// ReadString reads count bytes and returns as string -func (r *ValueReader) ReadString(count int32) string { - if r.err != nil { - return "" - } - - r.valueBytesRemaining -= count - if r.valueBytesRemaining < 0 { - r.Fatal(errors.New("read past end of value")) - return "" - } - - return r.mr.readString(count) -} - -// ReadBytes reads count bytes and returns as []byte -func (r *ValueReader) ReadBytes(count int32) []byte { - if r.err != nil { - return nil - } - - if count < 0 { - r.Fatal(errors.New("count must not be negative")) - return nil - } - - r.valueBytesRemaining -= count - if r.valueBytesRemaining < 0 { - r.Fatal(errors.New("read past end of value")) - return nil - } - - return r.mr.readBytes(count) -} diff --git a/values.go b/values.go index a189e180..86ae3afe 100644 --- a/values.go +++ b/values.go @@ -1,62 +1,15 @@ package pgx import ( - "bytes" "database/sql/driver" - "encoding/json" "fmt" - "io" "math" - "net" "reflect" - "regexp" - "strconv" - "strings" "time" -) -// PostgreSQL oids for common types -const ( - BoolOid = 16 - ByteaOid = 17 - CharOid = 18 - NameOid = 19 - Int8Oid = 20 - Int2Oid = 21 - Int4Oid = 23 - TextOid = 25 - OidOid = 26 - TidOid = 27 - XidOid = 28 - CidOid = 29 - JsonOid = 114 - CidrOid = 650 - CidrArrayOid = 651 - Float4Oid = 700 - Float8Oid = 701 - UnknownOid = 705 - InetOid = 869 - BoolArrayOid = 1000 - Int2ArrayOid = 1005 - Int4ArrayOid = 1007 - TextArrayOid = 1009 - ByteaArrayOid = 1001 - VarcharArrayOid = 1015 - Int8ArrayOid = 1016 - Float4ArrayOid = 1021 - Float8ArrayOid = 1022 - AclItemOid = 1033 - AclItemArrayOid = 1034 - InetArrayOid = 1041 - VarcharOid = 1043 - DateOid = 1082 - TimestampOid = 1114 - TimestampArrayOid = 1115 - TimestampTzOid = 1184 - TimestampTzArrayOid = 1185 - RecordOid = 2249 - UuidOid = 2950 - JsonbOid = 3802 + "github.com/jackc/pgx/pgio" + "github.com/jackc/pgx/pgtype" + "github.com/pkg/errors" ) // 
PostgreSQL format codes @@ -65,61 +18,6 @@ const ( BinaryFormatCode = 1 ) -const maxUint = ^uint(0) -const maxInt = int(maxUint >> 1) -const minInt = -maxInt - 1 - -// DefaultTypeFormats maps type names to their default requested format (text -// or binary). In theory the Scanner interface should be the one to determine -// the format of the returned values. However, the query has already been -// executed by the time Scan is called so it has no chance to set the format. -// So for types that should always be returned in binary the format should be -// set here. -var DefaultTypeFormats map[string]int16 - -func init() { - DefaultTypeFormats = map[string]int16{ - "_aclitem": TextFormatCode, // Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin) - "_bool": BinaryFormatCode, - "_bytea": BinaryFormatCode, - "_cidr": BinaryFormatCode, - "_float4": BinaryFormatCode, - "_float8": BinaryFormatCode, - "_inet": BinaryFormatCode, - "_int2": BinaryFormatCode, - "_int4": BinaryFormatCode, - "_int8": BinaryFormatCode, - "_text": BinaryFormatCode, - "_timestamp": BinaryFormatCode, - "_timestamptz": BinaryFormatCode, - "_varchar": BinaryFormatCode, - "aclitem": TextFormatCode, // Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin) - "bool": BinaryFormatCode, - "bytea": BinaryFormatCode, - "char": BinaryFormatCode, - "cid": BinaryFormatCode, - "cidr": BinaryFormatCode, - "date": BinaryFormatCode, - "float4": BinaryFormatCode, - "float8": BinaryFormatCode, - "json": BinaryFormatCode, - "jsonb": BinaryFormatCode, - "inet": BinaryFormatCode, - "int2": BinaryFormatCode, - "int4": BinaryFormatCode, - "int8": BinaryFormatCode, - "name": BinaryFormatCode, - "oid": BinaryFormatCode, - "record": BinaryFormatCode, - "text": BinaryFormatCode, - "tid": BinaryFormatCode, - "timestamp": BinaryFormatCode, - "timestamptz": BinaryFormatCode, - "varchar": BinaryFormatCode, - "xid": BinaryFormatCode, - } -} - // SerializationError occurs on failure to 
encode or decode a value type SerializationError string @@ -127,3313 +25,223 @@ func (e SerializationError) Error() string { return string(e) } -// Deprecated: Scanner is an interface used to decode values from the PostgreSQL -// server. To allow types to support pgx and database/sql.Scan this interface -// has been deprecated in favor of PgxScanner. -type Scanner interface { - // Scan MUST check r.Type().DataType (to check by OID) or - // r.Type().DataTypeName (to check by name) to ensure that it is scanning an - // expected column type. It also MUST check r.Type().FormatCode before - // decoding. It should not assume that it was called on a data type or format - // that it understands. - Scan(r *ValueReader) error -} - -// PgxScanner is an interface used to decode values from the PostgreSQL server. -// It is used exactly the same as the Scanner interface. It simply has renamed -// the method. -type PgxScanner interface { - // ScanPgx MUST check r.Type().DataType (to check by OID) or - // r.Type().DataTypeName (to check by name) to ensure that it is scanning an - // expected column type. It also MUST check r.Type().FormatCode before - // decoding. It should not assume that it was called on a data type or format - // that it understands. - ScanPgx(r *ValueReader) error -} - -// Encoder is an interface used to encode values for transmission to the -// PostgreSQL server. -type Encoder interface { - // Encode writes the value to w. - // - // If the value is NULL an int32(-1) should be written. - // - // Encode MUST check oid to see if the parameter data type is compatible. If - // this is not done, the PostgreSQL server may detect the error if the - // expected data size or format of the encoded data does not match. But if - // the encoded data is a valid representation of the data type PostgreSQL - // expects such as date and int4, incorrect data may be stored. 
- Encode(w *WriteBuf, oid Oid) error - - // FormatCode returns the format that the encoder writes the value. It must be - // either pgx.TextFormatCode or pgx.BinaryFormatCode. - FormatCode() int16 -} - -// NullFloat32 represents an float4 that may be null. NullFloat32 implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. -// -// If Valid is false then the value is NULL. -type NullFloat32 struct { - Float32 float32 - Valid bool // Valid is true if Float32 is not NULL -} - -func (n *NullFloat32) Scan(vr *ValueReader) error { - if vr.Type().DataType != Float4Oid { - return SerializationError(fmt.Sprintf("NullFloat32.Scan cannot decode OID %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Float32, n.Valid = 0, false - return nil - } - n.Valid = true - n.Float32 = decodeFloat4(vr) - return vr.Err() -} - -func (n NullFloat32) FormatCode() int16 { return BinaryFormatCode } - -func (n NullFloat32) Encode(w *WriteBuf, oid Oid) error { - if oid != Float4Oid { - return SerializationError(fmt.Sprintf("NullFloat32.Encode cannot encode into OID %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeFloat32(w, oid, n.Float32) -} - -// NullFloat64 represents an float8 that may be null. NullFloat64 implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. -// -// If Valid is false then the value is NULL. 
-type NullFloat64 struct { - Float64 float64 - Valid bool // Valid is true if Float64 is not NULL -} - -func (n *NullFloat64) Scan(vr *ValueReader) error { - if vr.Type().DataType != Float8Oid { - return SerializationError(fmt.Sprintf("NullFloat64.Scan cannot decode OID %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Float64, n.Valid = 0, false - return nil - } - n.Valid = true - n.Float64 = decodeFloat8(vr) - return vr.Err() -} - -func (n NullFloat64) FormatCode() int16 { return BinaryFormatCode } - -func (n NullFloat64) Encode(w *WriteBuf, oid Oid) error { - if oid != Float8Oid { - return SerializationError(fmt.Sprintf("NullFloat64.Encode cannot encode into OID %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeFloat64(w, oid, n.Float64) -} - -// NullString represents an string that may be null. NullString implements the -// Scanner Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. -// -// If Valid is false then the value is NULL. -type NullString struct { - String string - Valid bool // Valid is true if String is not NULL -} - -func (n *NullString) Scan(vr *ValueReader) error { - // Not checking oid as so we can scan anything into into a NullString - may revisit this decision later - - if vr.Len() == -1 { - n.String, n.Valid = "", false - return nil - } - - n.Valid = true - n.String = decodeText(vr) - return vr.Err() -} - -func (n NullString) FormatCode() int16 { return TextFormatCode } - -func (n NullString) Encode(w *WriteBuf, oid Oid) error { - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeString(w, oid, n.String) -} - -// AclItem is used for PostgreSQL's aclitem data type. 
A sample aclitem -// might look like this: -// -// postgres=arwdDxt/postgres -// -// Note, however, that because the user/role name part of an aclitem is -// an identifier, it follows all the usual formatting rules for SQL -// identifiers: if it contains spaces and other special characters, -// it should appear in double-quotes: -// -// postgres=arwdDxt/"role with spaces" -// -type AclItem string - -// NullAclItem represents a pgx.AclItem that may be null. NullAclItem implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan for prepared and unprepared queries. -// -// If Valid is false then the value is NULL. -type NullAclItem struct { - AclItem AclItem - Valid bool // Valid is true if AclItem is not NULL -} - -func (n *NullAclItem) Scan(vr *ValueReader) error { - if vr.Type().DataType != AclItemOid { - return SerializationError(fmt.Sprintf("NullAclItem.Scan cannot decode OID %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.AclItem, n.Valid = "", false - return nil - } - - n.Valid = true - n.AclItem = AclItem(decodeText(vr)) - return vr.Err() -} - -// Particularly important to return TextFormatCode, seeing as Postgres -// only ever sends aclitem as text, not binary. -func (n NullAclItem) FormatCode() int16 { return TextFormatCode } - -func (n NullAclItem) Encode(w *WriteBuf, oid Oid) error { - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeString(w, oid, string(n.AclItem)) -} - -// Name is a type used for PostgreSQL's special 63-byte -// name data type, used for identifiers like table names. -// The pg_class.relname column is a good example of where the -// name data type is used. -// -// Note that the underlying Go data type of pgx.Name is string, -// so there is no way to enforce the 63-byte length. Inputting -// a longer name into PostgreSQL will result in silent truncation -// to 63 bytes. 
-// -// Also, if you have custom-compiled PostgreSQL and set -// NAMEDATALEN to a different value, obviously that number of -// bytes applies, rather than the default 63. -type Name string - -// NullName represents a pgx.Name that may be null. NullName implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan for prepared and unprepared queries. -// -// If Valid is false then the value is NULL. -type NullName struct { - Name Name - Valid bool // Valid is true if Name is not NULL -} - -func (n *NullName) Scan(vr *ValueReader) error { - if vr.Type().DataType != NameOid { - return SerializationError(fmt.Sprintf("NullName.Scan cannot decode OID %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Name, n.Valid = "", false - return nil - } - - n.Valid = true - n.Name = Name(decodeText(vr)) - return vr.Err() -} - -func (n NullName) FormatCode() int16 { return TextFormatCode } - -func (n NullName) Encode(w *WriteBuf, oid Oid) error { - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeString(w, oid, string(n.Name)) -} - -// The pgx.Char type is for PostgreSQL's special 8-bit-only -// "char" type more akin to the C language's char type, or Go's byte type. -// (Note that the name in PostgreSQL itself is "char", in double-quotes, -// and not char.) It gets used a lot in PostgreSQL's system tables to hold -// a single ASCII character value (eg pg_class.relkind). -type Char byte - -// NullChar represents a pgx.Char that may be null. NullChar implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan for prepared and unprepared queries. -// -// If Valid is false then the value is NULL. 
-type NullChar struct { - Char Char - Valid bool // Valid is true if Char is not NULL -} - -func (n *NullChar) Scan(vr *ValueReader) error { - if vr.Type().DataType != CharOid { - return SerializationError(fmt.Sprintf("NullChar.Scan cannot decode OID %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Char, n.Valid = 0, false - return nil - } - n.Valid = true - n.Char = decodeChar(vr) - return vr.Err() -} - -func (n NullChar) FormatCode() int16 { return BinaryFormatCode } - -func (n NullChar) Encode(w *WriteBuf, oid Oid) error { - if oid != CharOid { - return SerializationError(fmt.Sprintf("NullChar.Encode cannot encode into OID %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeChar(w, oid, n.Char) -} - -// NullInt16 represents a smallint that may be null. NullInt16 implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan for prepared and unprepared queries. -// -// If Valid is false then the value is NULL. -type NullInt16 struct { - Int16 int16 - Valid bool // Valid is true if Int16 is not NULL -} - -func (n *NullInt16) Scan(vr *ValueReader) error { - if vr.Type().DataType != Int2Oid { - return SerializationError(fmt.Sprintf("NullInt16.Scan cannot decode OID %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Int16, n.Valid = 0, false - return nil - } - n.Valid = true - n.Int16 = decodeInt2(vr) - return vr.Err() -} - -func (n NullInt16) FormatCode() int16 { return BinaryFormatCode } - -func (n NullInt16) Encode(w *WriteBuf, oid Oid) error { - if oid != Int2Oid { - return SerializationError(fmt.Sprintf("NullInt16.Encode cannot encode into OID %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeInt16(w, oid, n.Int16) -} - -// NullInt32 represents an integer that may be null. 
NullInt32 implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. -// -// If Valid is false then the value is NULL. -type NullInt32 struct { - Int32 int32 - Valid bool // Valid is true if Int32 is not NULL -} - -func (n *NullInt32) Scan(vr *ValueReader) error { - if vr.Type().DataType != Int4Oid { - return SerializationError(fmt.Sprintf("NullInt32.Scan cannot decode OID %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Int32, n.Valid = 0, false - return nil - } - n.Valid = true - n.Int32 = decodeInt4(vr) - return vr.Err() -} - -func (n NullInt32) FormatCode() int16 { return BinaryFormatCode } - -func (n NullInt32) Encode(w *WriteBuf, oid Oid) error { - if oid != Int4Oid { - return SerializationError(fmt.Sprintf("NullInt32.Encode cannot encode into OID %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeInt32(w, oid, n.Int32) -} - -// Oid (Object Identifier Type) is, according to https://www.postgresql.org/docs/current/static/datatype-oid.html, -// used internally by PostgreSQL as a primary key for various system tables. It is currently implemented -// as an unsigned four-byte integer. Its definition can be found in src/include/postgres_ext.h -// in the PostgreSQL sources. -type Oid uint32 - -// NullOid represents a Command Identifier (Oid) that may be null. NullOid implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. -// -// If Valid is false then the value is NULL. 
-type NullOid struct { - Oid Oid - Valid bool // Valid is true if Oid is not NULL -} - -func (n *NullOid) Scan(vr *ValueReader) error { - if vr.Type().DataType != OidOid { - return SerializationError(fmt.Sprintf("NullOid.Scan cannot decode OID %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Oid, n.Valid = 0, false - return nil - } - n.Valid = true - n.Oid = decodeOid(vr) - return vr.Err() -} - -func (n NullOid) FormatCode() int16 { return BinaryFormatCode } - -func (n NullOid) Encode(w *WriteBuf, oid Oid) error { - if oid != OidOid { - return SerializationError(fmt.Sprintf("NullOid.Encode cannot encode into OID %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeOid(w, oid, n.Oid) -} - -// Xid is PostgreSQL's Transaction ID type. -// -// In later versions of PostgreSQL, it is the type used for the backend_xid -// and backend_xmin columns of the pg_stat_activity system view. -// -// Also, when one does -// -// select xmin, xmax, * from some_table; -// -// it is the data type of the xmin and xmax hidden system columns. -// -// It is currently implemented as an unsigned four byte integer. -// Its definition can be found in src/include/postgres_ext.h as TransactionId -// in the PostgreSQL sources. -type Xid uint32 - -// NullXid represents a Transaction ID (Xid) that may be null. NullXid implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. -// -// If Valid is false then the value is NULL. 
-type NullXid struct { - Xid Xid - Valid bool // Valid is true if Xid is not NULL -} - -func (n *NullXid) Scan(vr *ValueReader) error { - if vr.Type().DataType != XidOid { - return SerializationError(fmt.Sprintf("NullXid.Scan cannot decode OID %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Xid, n.Valid = 0, false - return nil - } - n.Valid = true - n.Xid = decodeXid(vr) - return vr.Err() -} - -func (n NullXid) FormatCode() int16 { return BinaryFormatCode } - -func (n NullXid) Encode(w *WriteBuf, oid Oid) error { - if oid != XidOid { - return SerializationError(fmt.Sprintf("NullXid.Encode cannot encode into OID %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeXid(w, oid, n.Xid) -} - -// Cid is PostgreSQL's Command Identifier type. -// -// When one does -// -// select cmin, cmax, * from some_table; -// -// it is the data type of the cmin and cmax hidden system columns. -// -// It is currently implemented as an unsigned four byte integer. -// Its definition can be found in src/include/c.h as CommandId -// in the PostgreSQL sources. -type Cid uint32 - -// NullCid represents a Command Identifier (Cid) that may be null. NullCid implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. -// -// If Valid is false then the value is NULL. 
-type NullCid struct { - Cid Cid - Valid bool // Valid is true if Cid is not NULL -} - -func (n *NullCid) Scan(vr *ValueReader) error { - if vr.Type().DataType != CidOid { - return SerializationError(fmt.Sprintf("NullCid.Scan cannot decode OID %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Cid, n.Valid = 0, false - return nil - } - n.Valid = true - n.Cid = decodeCid(vr) - return vr.Err() -} - -func (n NullCid) FormatCode() int16 { return BinaryFormatCode } - -func (n NullCid) Encode(w *WriteBuf, oid Oid) error { - if oid != CidOid { - return SerializationError(fmt.Sprintf("NullCid.Encode cannot encode into OID %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeCid(w, oid, n.Cid) -} - -// Tid is PostgreSQL's Tuple Identifier type. -// -// When one does -// -// select ctid, * from some_table; -// -// it is the data type of the ctid hidden system column. -// -// It is currently implemented as a pair unsigned two byte integers. -// Its conversion functions can be found in src/backend/utils/adt/tid.c -// in the PostgreSQL sources. -type Tid struct { - BlockNumber uint32 - OffsetNumber uint16 -} - -// NullTid represents a Tuple Identifier (Tid) that may be null. NullTid implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. -// -// If Valid is false then the value is NULL. 
-type NullTid struct { - Tid Tid - Valid bool // Valid is true if Tid is not NULL -} - -func (n *NullTid) Scan(vr *ValueReader) error { - if vr.Type().DataType != TidOid { - return SerializationError(fmt.Sprintf("NullTid.Scan cannot decode OID %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Tid, n.Valid = Tid{BlockNumber: 0, OffsetNumber: 0}, false - return nil - } - n.Valid = true - n.Tid = decodeTid(vr) - return vr.Err() -} - -func (n NullTid) FormatCode() int16 { return BinaryFormatCode } - -func (n NullTid) Encode(w *WriteBuf, oid Oid) error { - if oid != TidOid { - return SerializationError(fmt.Sprintf("NullTid.Encode cannot encode into OID %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeTid(w, oid, n.Tid) -} - -// NullInt64 represents an bigint that may be null. NullInt64 implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. -// -// If Valid is false then the value is NULL. -type NullInt64 struct { - Int64 int64 - Valid bool // Valid is true if Int64 is not NULL -} - -func (n *NullInt64) Scan(vr *ValueReader) error { - if vr.Type().DataType != Int8Oid { - return SerializationError(fmt.Sprintf("NullInt64.Scan cannot decode OID %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Int64, n.Valid = 0, false - return nil - } - n.Valid = true - n.Int64 = decodeInt8(vr) - return vr.Err() -} - -func (n NullInt64) FormatCode() int16 { return BinaryFormatCode } - -func (n NullInt64) Encode(w *WriteBuf, oid Oid) error { - if oid != Int8Oid { - return SerializationError(fmt.Sprintf("NullInt64.Encode cannot encode into OID %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeInt64(w, oid, n.Int64) -} - -// NullBool represents an bool that may be null. NullBool implements the Scanner -// and Encoder interfaces so it may be used both as an argument to Query[Row] -// and a destination for Scan. 
-// -// If Valid is false then the value is NULL. -type NullBool struct { - Bool bool - Valid bool // Valid is true if Bool is not NULL -} - -func (n *NullBool) Scan(vr *ValueReader) error { - if vr.Type().DataType != BoolOid { - return SerializationError(fmt.Sprintf("NullBool.Scan cannot decode OID %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Bool, n.Valid = false, false - return nil - } - n.Valid = true - n.Bool = decodeBool(vr) - return vr.Err() -} - -func (n NullBool) FormatCode() int16 { return BinaryFormatCode } - -func (n NullBool) Encode(w *WriteBuf, oid Oid) error { - if oid != BoolOid { - return SerializationError(fmt.Sprintf("NullBool.Encode cannot encode into OID %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeBool(w, oid, n.Bool) -} - -// NullTime represents an time.Time that may be null. NullTime implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. It corresponds with the PostgreSQL -// types timestamptz, timestamp, and date. -// -// If Valid is false then the value is NULL. 
-type NullTime struct { - Time time.Time - Valid bool // Valid is true if Time is not NULL -} - -func (n *NullTime) Scan(vr *ValueReader) error { - oid := vr.Type().DataType - if oid != TimestampTzOid && oid != TimestampOid && oid != DateOid { - return SerializationError(fmt.Sprintf("NullTime.Scan cannot decode OID %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Time, n.Valid = time.Time{}, false - return nil - } - - n.Valid = true - switch oid { - case TimestampTzOid: - n.Time = decodeTimestampTz(vr) - case TimestampOid: - n.Time = decodeTimestamp(vr) - case DateOid: - n.Time = decodeDate(vr) - } - - return vr.Err() -} - -func (n NullTime) FormatCode() int16 { return BinaryFormatCode } - -func (n NullTime) Encode(w *WriteBuf, oid Oid) error { - if oid != TimestampTzOid && oid != TimestampOid && oid != DateOid { - return SerializationError(fmt.Sprintf("NullTime.Encode cannot encode into OID %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeTime(w, oid, n.Time) -} - -// Hstore represents an hstore column. It does not support a null column or null -// key values (use NullHstore for this). Hstore implements the Scanner and -// Encoder interfaces so it may be used both as an argument to Query[Row] and a -// destination for Scan. 
-type Hstore map[string]string - -func (h *Hstore) Scan(vr *ValueReader) error { - //oid for hstore not standardized, so we check its type name - if vr.Type().DataTypeName != "hstore" { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode type %s into Hstore", vr.Type().DataTypeName))) - return nil - } - - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null column into Hstore")) - return nil - } - - switch vr.Type().FormatCode { - case TextFormatCode: - m, err := parseHstoreToMap(vr.ReadString(vr.Len())) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode hstore column: %v", err))) - return nil - } - hm := Hstore(m) - *h = hm - return nil - case BinaryFormatCode: - vr.Fatal(ProtocolError("Can't decode binary hstore")) - return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } -} - -func (h Hstore) FormatCode() int16 { return TextFormatCode } - -func (h Hstore) Encode(w *WriteBuf, oid Oid) error { - var buf bytes.Buffer - - i := 0 - for k, v := range h { - i++ - ks := strings.Replace(k, `\`, `\\`, -1) - ks = strings.Replace(ks, `"`, `\"`, -1) - vs := strings.Replace(v, `\`, `\\`, -1) - vs = strings.Replace(vs, `"`, `\"`, -1) - buf.WriteString(`"`) - buf.WriteString(ks) - buf.WriteString(`"=>"`) - buf.WriteString(vs) - buf.WriteString(`"`) - if i < len(h) { - buf.WriteString(", ") - } - } - w.WriteInt32(int32(buf.Len())) - w.WriteBytes(buf.Bytes()) - return nil -} - -// NullHstore represents an hstore column that can be null or have null values -// associated with its keys. NullHstore implements the Scanner and Encoder -// interfaces so it may be used both as an argument to Query[Row] and a -// destination for Scan. -// -// If Valid is false, then the value of the entire hstore column is NULL -// If any of the NullString values in Store has Valid set to false, the key -// appears in the hstore column, but its value is explicitly set to NULL. 
-type NullHstore struct { - Hstore map[string]NullString - Valid bool -} - -func (h *NullHstore) Scan(vr *ValueReader) error { - //oid for hstore not standardized, so we check its type name - if vr.Type().DataTypeName != "hstore" { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode type %s into NullHstore", vr.Type().DataTypeName))) - return nil - } - - if vr.Len() == -1 { - h.Valid = false - return nil - } - - switch vr.Type().FormatCode { - case TextFormatCode: - store, err := parseHstoreToNullHstore(vr.ReadString(vr.Len())) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode hstore column: %v", err))) - return nil - } - h.Valid = true - h.Hstore = store - return nil - case BinaryFormatCode: - vr.Fatal(ProtocolError("Can't decode binary hstore")) - return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } -} - -func (h NullHstore) FormatCode() int16 { return TextFormatCode } - -func (h NullHstore) Encode(w *WriteBuf, oid Oid) error { - var buf bytes.Buffer - - if !h.Valid { - w.WriteInt32(-1) - return nil - } - - i := 0 - for k, v := range h.Hstore { - i++ - ks := strings.Replace(k, `\`, `\\`, -1) - ks = strings.Replace(ks, `"`, `\"`, -1) - if v.Valid { - vs := strings.Replace(v.String, `\`, `\\`, -1) - vs = strings.Replace(vs, `"`, `\"`, -1) - buf.WriteString(fmt.Sprintf(`"%s"=>"%s"`, ks, vs)) - } else { - buf.WriteString(fmt.Sprintf(`"%s"=>NULL`, ks)) - } - if i < len(h.Hstore) { - buf.WriteString(", ") - } - } - w.WriteInt32(int32(buf.Len())) - w.WriteBytes(buf.Bytes()) - return nil -} - -// Encode encodes arg into wbuf as the type oid. This allows implementations -// of the Encoder interface to delegate the actual work of encoding to the -// built-in functionality. 
-func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { +func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, error) { if arg == nil { - wbuf.WriteInt32(-1) - return nil + return nil, nil } switch arg := arg.(type) { - case Encoder: - return arg.Encode(wbuf, oid) case driver.Valuer: - v, err := arg.Value() + return arg.Value() + case pgtype.TextEncoder: + buf, err := arg.EncodeText(ci, nil) if err != nil { - return err + return nil, err } - return Encode(wbuf, oid, v) + if buf == nil { + return nil, nil + } + return string(buf), nil + case int64: + return arg, nil + case float64: + return arg, nil + case bool: + return arg, nil + case time.Time: + return arg, nil case string: - return encodeString(wbuf, oid, arg) - case []AclItem: - return encodeAclItemSlice(wbuf, oid, arg) + return arg, nil case []byte: - return encodeByteSlice(wbuf, oid, arg) - case [][]byte: - return encodeByteSliceSlice(wbuf, oid, arg) + return arg, nil + case int8: + return int64(arg), nil + case int16: + return int64(arg), nil + case int32: + return int64(arg), nil + case int: + return int64(arg), nil + case uint8: + return int64(arg), nil + case uint16: + return int64(arg), nil + case uint32: + return int64(arg), nil + case uint64: + if arg > math.MaxInt64 { + return nil, errors.Errorf("arg too big for int64: %v", arg) + } + return int64(arg), nil + case uint: + if arg > math.MaxInt64 { + return nil, errors.Errorf("arg too big for int64: %v", arg) + } + return int64(arg), nil + case float32: + return float64(arg), nil } refVal := reflect.ValueOf(arg) if refVal.Kind() == reflect.Ptr { if refVal.IsNil() { - wbuf.WriteInt32(-1) - return nil + return nil, nil } arg = refVal.Elem().Interface() - return Encode(wbuf, oid, arg) + return convertSimpleArgument(ci, arg) } - if oid == JsonOid { - return encodeJSON(wbuf, oid, arg) + if strippedArg, ok := stripNamedType(&refVal); ok { + return convertSimpleArgument(ci, strippedArg) } - if oid == JsonbOid { - return 
encodeJSONB(wbuf, oid, arg) + return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg)) +} + +func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid pgtype.OID, arg interface{}) ([]byte, error) { + if arg == nil { + return pgio.AppendInt32(buf, -1), nil } switch arg := arg.(type) { - case []string: - return encodeStringSlice(wbuf, oid, arg) - case bool: - return encodeBool(wbuf, oid, arg) - case []bool: - return encodeBoolSlice(wbuf, oid, arg) - case int: - return encodeInt(wbuf, oid, arg) - case uint: - return encodeUInt(wbuf, oid, arg) - case Char: - return encodeChar(wbuf, oid, arg) - case AclItem: - // The aclitem data type goes over the wire using the same format as string, - // so just cast to string and use encodeString - return encodeString(wbuf, oid, string(arg)) - case Name: - // The name data type goes over the wire using the same format as string, - // so just cast to string and use encodeString - return encodeString(wbuf, oid, string(arg)) - case int8: - return encodeInt8(wbuf, oid, arg) - case uint8: - return encodeUInt8(wbuf, oid, arg) - case int16: - return encodeInt16(wbuf, oid, arg) - case []int16: - return encodeInt16Slice(wbuf, oid, arg) - case uint16: - return encodeUInt16(wbuf, oid, arg) - case []uint16: - return encodeUInt16Slice(wbuf, oid, arg) - case int32: - return encodeInt32(wbuf, oid, arg) - case []int32: - return encodeInt32Slice(wbuf, oid, arg) - case uint32: - return encodeUInt32(wbuf, oid, arg) - case []uint32: - return encodeUInt32Slice(wbuf, oid, arg) - case int64: - return encodeInt64(wbuf, oid, arg) - case []int64: - return encodeInt64Slice(wbuf, oid, arg) - case uint64: - return encodeUInt64(wbuf, oid, arg) - case []uint64: - return encodeUInt64Slice(wbuf, oid, arg) - case float32: - return encodeFloat32(wbuf, oid, arg) - case []float32: - return encodeFloat32Slice(wbuf, oid, arg) - case float64: - 
return encodeFloat64(wbuf, oid, arg) - case []float64: - return encodeFloat64Slice(wbuf, oid, arg) - case time.Time: - return encodeTime(wbuf, oid, arg) - case []time.Time: - return encodeTimeSlice(wbuf, oid, arg) - case net.IP: - return encodeIP(wbuf, oid, arg) - case []net.IP: - return encodeIPSlice(wbuf, oid, arg) - case net.IPNet: - return encodeIPNet(wbuf, oid, arg) - case []net.IPNet: - return encodeIPNetSlice(wbuf, oid, arg) - case Oid: - return encodeOid(wbuf, oid, arg) - case Xid: - return encodeXid(wbuf, oid, arg) - case Cid: - return encodeCid(wbuf, oid, arg) - default: - if strippedArg, ok := stripNamedType(&refVal); ok { - return Encode(wbuf, oid, strippedArg) + case pgtype.BinaryEncoder: + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + argBuf, err := arg.EncodeBinary(ci, buf) + if err != nil { + return nil, err } - return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) + if argBuf != nil { + buf = argBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + return buf, nil + case pgtype.TextEncoder: + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + argBuf, err := arg.EncodeText(ci, buf) + if err != nil { + return nil, err + } + if argBuf != nil { + buf = argBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + return buf, nil + case driver.Valuer: + v, err := arg.Value() + if err != nil { + return nil, err + } + return encodePreparedStatementArgument(ci, buf, oid, v) + case string: + buf = pgio.AppendInt32(buf, int32(len(arg))) + buf = append(buf, arg...) 
+ return buf, nil } + + refVal := reflect.ValueOf(arg) + + if refVal.Kind() == reflect.Ptr { + if refVal.IsNil() { + return pgio.AppendInt32(buf, -1), nil + } + arg = refVal.Elem().Interface() + return encodePreparedStatementArgument(ci, buf, oid, arg) + } + + if dt, ok := ci.DataTypeForOID(oid); ok { + value := dt.Value + err := value.Set(arg) + if err != nil { + return nil, err + } + + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if argBuf != nil { + buf = argBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + return buf, nil + } + + if strippedArg, ok := stripNamedType(&refVal); ok { + return encodePreparedStatementArgument(ci, buf, oid, strippedArg) + } + return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) +} + +// chooseParameterFormatCode determines the correct format code for an +// argument to a prepared statement. It defaults to TextFormatCode if no +// determination can be made. 
+func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid pgtype.OID, arg interface{}) int16 { + switch arg.(type) { + case pgtype.BinaryEncoder: + return BinaryFormatCode + case string, *string, pgtype.TextEncoder: + return TextFormatCode + } + + if dt, ok := ci.DataTypeForOID(oid); ok { + if _, ok := dt.Value.(pgtype.BinaryEncoder); ok { + if arg, ok := arg.(driver.Valuer); ok { + if err := dt.Value.Set(arg); err != nil { + if value, err := arg.Value(); err == nil { + if _, ok := value.(string); ok { + return TextFormatCode + } + } + } + } + + return BinaryFormatCode + } + } + + return TextFormatCode } func stripNamedType(val *reflect.Value) (interface{}, bool) { switch val.Kind() { case reflect.Int: - return int(val.Int()), true + convVal := int(val.Int()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Int8: - return int8(val.Int()), true + convVal := int8(val.Int()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Int16: - return int16(val.Int()), true + convVal := int16(val.Int()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Int32: - return int32(val.Int()), true + convVal := int32(val.Int()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Int64: - return int64(val.Int()), true + convVal := int64(val.Int()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Uint: - return uint(val.Uint()), true + convVal := uint(val.Uint()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Uint8: - return uint8(val.Uint()), true + convVal := uint8(val.Uint()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Uint16: - return uint16(val.Uint()), true + convVal := uint16(val.Uint()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Uint32: - return uint32(val.Uint()), true + convVal := uint32(val.Uint()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Uint64: - return uint64(val.Uint()), true + 
convVal := uint64(val.Uint()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.String: - return val.String(), true + convVal := val.String() + return convVal, reflect.TypeOf(convVal) != val.Type() } return nil, false } - -// Decode decodes from vr into d. d must be a pointer. This allows -// implementations of the Decoder interface to delegate the actual work of -// decoding to the built-in functionality. -func Decode(vr *ValueReader, d interface{}) error { - switch v := d.(type) { - case *bool: - *v = decodeBool(vr) - case *int: - n := decodeInt(vr) - if n < int64(minInt) { - return fmt.Errorf("%d is less than minimum value for int", n) - } else if n > int64(maxInt) { - return fmt.Errorf("%d is greater than maximum value for int", n) - } - *v = int(n) - case *int8: - n := decodeInt(vr) - if n < math.MinInt8 { - return fmt.Errorf("%d is less than minimum value for int8", n) - } else if n > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for int8", n) - } - *v = int8(n) - case *int16: - n := decodeInt(vr) - if n < math.MinInt16 { - return fmt.Errorf("%d is less than minimum value for int16", n) - } else if n > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for int16", n) - } - *v = int16(n) - case *int32: - n := decodeInt(vr) - if n < math.MinInt32 { - return fmt.Errorf("%d is less than minimum value for int32", n) - } else if n > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for int32", n) - } - *v = int32(n) - case *int64: - n := decodeInt(vr) - if n < math.MinInt64 { - return fmt.Errorf("%d is less than minimum value for int64", n) - } else if n > math.MaxInt64 { - return fmt.Errorf("%d is greater than maximum value for int64", n) - } - *v = int64(n) - case *uint: - n := decodeInt(vr) - if n < 0 { - return fmt.Errorf("%d is less than zero for uint8", n) - } else if maxInt == math.MaxInt32 && n > math.MaxUint32 { - return fmt.Errorf("%d is greater than maximum value for uint", 
n) - } - *v = uint(n) - case *uint8: - n := decodeInt(vr) - if n < 0 { - return fmt.Errorf("%d is less than zero for uint8", n) - } else if n > math.MaxUint8 { - return fmt.Errorf("%d is greater than maximum value for uint8", n) - } - *v = uint8(n) - case *uint16: - n := decodeInt(vr) - if n < 0 { - return fmt.Errorf("%d is less than zero for uint16", n) - } else if n > math.MaxUint16 { - return fmt.Errorf("%d is greater than maximum value for uint16", n) - } - *v = uint16(n) - case *uint32: - n := decodeInt(vr) - if n < 0 { - return fmt.Errorf("%d is less than zero for uint32", n) - } else if n > math.MaxUint32 { - return fmt.Errorf("%d is greater than maximum value for uint32", n) - } - *v = uint32(n) - case *uint64: - n := decodeInt(vr) - if n < 0 { - return fmt.Errorf("%d is less than zero for uint64", n) - } - *v = uint64(n) - case *Char: - *v = decodeChar(vr) - case *AclItem: - // aclitem goes over the wire just like text - *v = AclItem(decodeText(vr)) - case *Name: - // name goes over the wire just like text - *v = Name(decodeText(vr)) - case *Oid: - *v = decodeOid(vr) - case *Xid: - *v = decodeXid(vr) - case *Tid: - *v = decodeTid(vr) - case *Cid: - *v = decodeCid(vr) - case *string: - *v = decodeText(vr) - case *float32: - *v = decodeFloat4(vr) - case *float64: - *v = decodeFloat8(vr) - case *[]AclItem: - *v = decodeAclItemArray(vr) - case *[]bool: - *v = decodeBoolArray(vr) - case *[]int16: - *v = decodeInt2Array(vr) - case *[]uint16: - *v = decodeInt2ArrayToUInt(vr) - case *[]int32: - *v = decodeInt4Array(vr) - case *[]uint32: - *v = decodeInt4ArrayToUInt(vr) - case *[]int64: - *v = decodeInt8Array(vr) - case *[]uint64: - *v = decodeInt8ArrayToUInt(vr) - case *[]float32: - *v = decodeFloat4Array(vr) - case *[]float64: - *v = decodeFloat8Array(vr) - case *[]string: - *v = decodeTextArray(vr) - case *[]time.Time: - *v = decodeTimestampArray(vr) - case *[][]byte: - *v = decodeByteaArray(vr) - case *[]interface{}: - *v = decodeRecord(vr) - case *time.Time: - 
switch vr.Type().DataType { - case DateOid: - *v = decodeDate(vr) - case TimestampTzOid: - *v = decodeTimestampTz(vr) - case TimestampOid: - *v = decodeTimestamp(vr) - default: - return fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType) - } - case *net.IP: - ipnet := decodeInet(vr) - if oneCount, bitCount := ipnet.Mask.Size(); oneCount != bitCount { - return fmt.Errorf("Cannot decode netmask into *net.IP") - } - *v = ipnet.IP - case *[]net.IP: - ipnets := decodeInetArray(vr) - ips := make([]net.IP, len(ipnets)) - for i, ipnet := range ipnets { - if oneCount, bitCount := ipnet.Mask.Size(); oneCount != bitCount { - return fmt.Errorf("Cannot decode netmask into *net.IP") - } - ips[i] = ipnet.IP - } - *v = ips - case *net.IPNet: - *v = decodeInet(vr) - case *[]net.IPNet: - *v = decodeInetArray(vr) - default: - if v := reflect.ValueOf(d); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if d is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - // -1 is a null value - if vr.Len() == -1 { - if !el.IsNil() { - // if the destination pointer is not nil, nil it out - el.Set(reflect.Zero(el.Type())) - } - return nil - } - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - d = el.Interface() - return Decode(vr, d) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - n := decodeInt(vr) - if el.OverflowInt(n) { - return fmt.Errorf("Scan cannot decode %d into %T", n, d) - } - el.SetInt(n) - return nil - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - n := decodeInt(vr) - if n < 0 { - return fmt.Errorf("%d is less than zero for %T", n, d) - } - if el.OverflowUint(uint64(n)) { - return fmt.Errorf("Scan cannot decode %d into %T", n, d) - } - el.SetUint(uint64(n)) - return nil - case reflect.String: - el.SetString(decodeText(vr)) - return nil - } - } - return fmt.Errorf("Scan cannot decode into %T", d) - } - - return nil -} 
- -func decodeBool(vr *ValueReader) bool { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into bool")) - return false - } - - if vr.Type().DataType != BoolOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into bool", vr.Type().DataType))) - return false - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return false - } - - if vr.Len() != 1 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool: %d", vr.Len()))) - return false - } - - b := vr.ReadByte() - return b != 0 -} - -func encodeBool(w *WriteBuf, oid Oid, value bool) error { - if oid != BoolOid { - return fmt.Errorf("cannot encode Go %s into oid %d", "bool", oid) - } - - w.WriteInt32(1) - - var n byte - if value { - n = 1 - } - - w.WriteByte(n) - - return nil -} - -func decodeInt(vr *ValueReader) int64 { - switch vr.Type().DataType { - case Int2Oid: - return int64(decodeInt2(vr)) - case Int4Oid: - return int64(decodeInt4(vr)) - case Int8Oid: - return int64(decodeInt8(vr)) - } - - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into any integer type", vr.Type().DataType))) - return 0 -} - -func decodeInt8(vr *ValueReader) int64 { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into int64")) - return 0 - } - - if vr.Type().DataType != Int8Oid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int8", vr.Type().DataType))) - return 0 - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return 0 - } - - if vr.Len() != 8 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8: %d", vr.Len()))) - return 0 - } - - return vr.ReadInt64() -} - -func decodeChar(vr *ValueReader) Char { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into char")) - return Char(0) - } - 
- if vr.Type().DataType != CharOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into char", vr.Type().DataType))) - return Char(0) - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return Char(0) - } - - if vr.Len() != 1 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for a char: %d", vr.Len()))) - return Char(0) - } - - return Char(vr.ReadByte()) -} - -func decodeInt2(vr *ValueReader) int16 { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into int16")) - return 0 - } - - if vr.Type().DataType != Int2Oid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int16", vr.Type().DataType))) - return 0 - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return 0 - } - - if vr.Len() != 2 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2: %d", vr.Len()))) - return 0 - } - - return vr.ReadInt16() -} - -func encodeInt(w *WriteBuf, oid Oid, value int) error { - switch oid { - case Int2Oid: - if value < math.MinInt16 { - return fmt.Errorf("%d is less than min pg:int2", value) - } else if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than max pg:int2", value) - } - w.WriteInt32(2) - w.WriteInt16(int16(value)) - case Int4Oid: - if value < math.MinInt32 { - return fmt.Errorf("%d is less than min pg:int4", value) - } else if value > math.MaxInt32 { - return fmt.Errorf("%d is greater than max pg:int4", value) - } - w.WriteInt32(4) - w.WriteInt32(int32(value)) - case Int8Oid: - if int64(value) <= int64(math.MaxInt64) { - w.WriteInt32(8) - w.WriteInt64(int64(value)) - } else { - return fmt.Errorf("%d is larger than max int64 %d", value, int64(math.MaxInt64)) - } - default: - return fmt.Errorf("cannot encode %s into oid %v", "int8", oid) - } - - return nil -} - 
-func encodeUInt(w *WriteBuf, oid Oid, value uint) error { - switch oid { - case Int2Oid: - if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than max pg:int2", value) - } - w.WriteInt32(2) - w.WriteInt16(int16(value)) - case Int4Oid: - if value > math.MaxInt32 { - return fmt.Errorf("%d is greater than max pg:int4", value) - } - w.WriteInt32(4) - w.WriteInt32(int32(value)) - case Int8Oid: - //****** Changed value to int64(value) and math.MaxInt64 to int64(math.MaxInt64) - if int64(value) > int64(math.MaxInt64) { - return fmt.Errorf("%d is greater than max pg:int8", value) - } - w.WriteInt32(8) - w.WriteInt64(int64(value)) - - default: - return fmt.Errorf("cannot encode %s into oid %v", "uint8", oid) - } - - return nil -} - -func encodeChar(w *WriteBuf, oid Oid, value Char) error { - w.WriteInt32(1) - w.WriteByte(byte(value)) - return nil -} - -func encodeInt8(w *WriteBuf, oid Oid, value int8) error { - switch oid { - case Int2Oid: - w.WriteInt32(2) - w.WriteInt16(int16(value)) - case Int4Oid: - w.WriteInt32(4) - w.WriteInt32(int32(value)) - case Int8Oid: - w.WriteInt32(8) - w.WriteInt64(int64(value)) - default: - return fmt.Errorf("cannot encode %s into oid %v", "int8", oid) - } - - return nil -} - -func encodeUInt8(w *WriteBuf, oid Oid, value uint8) error { - switch oid { - case Int2Oid: - w.WriteInt32(2) - w.WriteInt16(int16(value)) - case Int4Oid: - w.WriteInt32(4) - w.WriteInt32(int32(value)) - case Int8Oid: - w.WriteInt32(8) - w.WriteInt64(int64(value)) - default: - return fmt.Errorf("cannot encode %s into oid %v", "uint8", oid) - } - - return nil -} - -func encodeInt16(w *WriteBuf, oid Oid, value int16) error { - switch oid { - case Int2Oid: - w.WriteInt32(2) - w.WriteInt16(value) - case Int4Oid: - w.WriteInt32(4) - w.WriteInt32(int32(value)) - case Int8Oid: - w.WriteInt32(8) - w.WriteInt64(int64(value)) - default: - return fmt.Errorf("cannot encode %s into oid %v", "int16", oid) - } - - return nil -} - -func encodeUInt16(w *WriteBuf, oid Oid, 
value uint16) error { - switch oid { - case Int2Oid: - if value <= math.MaxInt16 { - w.WriteInt32(2) - w.WriteInt16(int16(value)) - } else { - return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16) - } - case Int4Oid: - w.WriteInt32(4) - w.WriteInt32(int32(value)) - case Int8Oid: - w.WriteInt32(8) - w.WriteInt64(int64(value)) - default: - return fmt.Errorf("cannot encode %s into oid %v", "int16", oid) - } - - return nil -} - -func encodeInt32(w *WriteBuf, oid Oid, value int32) error { - switch oid { - case Int2Oid: - if value <= math.MaxInt16 { - w.WriteInt32(2) - w.WriteInt16(int16(value)) - } else { - return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16) - } - case Int4Oid: - w.WriteInt32(4) - w.WriteInt32(value) - case Int8Oid: - w.WriteInt32(8) - w.WriteInt64(int64(value)) - default: - return fmt.Errorf("cannot encode %s into oid %v", "int32", oid) - } - - return nil -} - -func encodeUInt32(w *WriteBuf, oid Oid, value uint32) error { - switch oid { - case Int2Oid: - if value <= math.MaxInt16 { - w.WriteInt32(2) - w.WriteInt16(int16(value)) - } else { - return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16) - } - case Int4Oid: - if value <= math.MaxInt32 { - w.WriteInt32(4) - w.WriteInt32(int32(value)) - } else { - return fmt.Errorf("%d is greater than max int32 %d", value, math.MaxInt32) - } - case Int8Oid: - w.WriteInt32(8) - w.WriteInt64(int64(value)) - default: - return fmt.Errorf("cannot encode %s into oid %v", "uint32", oid) - } - - return nil -} - -func encodeInt64(w *WriteBuf, oid Oid, value int64) error { - switch oid { - case Int2Oid: - if value <= math.MaxInt16 { - w.WriteInt32(2) - w.WriteInt16(int16(value)) - } else { - return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16) - } - case Int4Oid: - if value <= math.MaxInt32 { - w.WriteInt32(4) - w.WriteInt32(int32(value)) - } else { - return fmt.Errorf("%d is greater than max int32 %d", value, math.MaxInt32) - } - case 
Int8Oid: - w.WriteInt32(8) - w.WriteInt64(value) - default: - return fmt.Errorf("cannot encode %s into oid %v", "int64", oid) - } - - return nil -} - -func encodeUInt64(w *WriteBuf, oid Oid, value uint64) error { - switch oid { - case Int2Oid: - if value <= math.MaxInt16 { - w.WriteInt32(2) - w.WriteInt16(int16(value)) - } else { - return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16) - } - case Int4Oid: - if value <= math.MaxInt32 { - w.WriteInt32(4) - w.WriteInt32(int32(value)) - } else { - return fmt.Errorf("%d is greater than max int32 %d", value, math.MaxInt32) - } - case Int8Oid: - - if value <= math.MaxInt64 { - w.WriteInt32(8) - w.WriteInt64(int64(value)) - } else { - return fmt.Errorf("%d is greater than max int64 %d", value, int64(math.MaxInt64)) - } - default: - return fmt.Errorf("cannot encode %s into oid %v", "uint64", oid) - } - - return nil -} - -func decodeInt4(vr *ValueReader) int32 { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into int32")) - return 0 - } - - if vr.Type().DataType != Int4Oid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int32", vr.Type().DataType))) - return 0 - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return 0 - } - - if vr.Len() != 4 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4: %d", vr.Len()))) - return 0 - } - - return vr.ReadInt32() -} - -func decodeOid(vr *ValueReader) Oid { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into Oid")) - return Oid(0) - } - - if vr.Type().DataType != OidOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Oid", vr.Type().DataType))) - return Oid(0) - } - - // Oid needs to decode text format because it is used in loadPgTypes - switch vr.Type().FormatCode { - case TextFormatCode: - s := vr.ReadString(vr.Len()) - n, err := strconv.ParseUint(s, 
10, 32) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid Oid: %v", s))) - } - return Oid(n) - case BinaryFormatCode: - if vr.Len() != 4 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an Oid: %d", vr.Len()))) - return Oid(0) - } - return Oid(vr.ReadInt32()) - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return Oid(0) - } -} - -func encodeOid(w *WriteBuf, oid Oid, value Oid) error { - if oid != OidOid { - return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Oid", oid) - } - - w.WriteInt32(4) - w.WriteUint32(uint32(value)) - - return nil -} - -func decodeXid(vr *ValueReader) Xid { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into Xid")) - return Xid(0) - } - - if vr.Type().DataType != XidOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Xid", vr.Type().DataType))) - return Xid(0) - } - - // Unlikely Xid will ever go over the wire as text format, but who knows? 
- switch vr.Type().FormatCode { - case TextFormatCode: - s := vr.ReadString(vr.Len()) - n, err := strconv.ParseUint(s, 10, 32) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid Oid: %v", s))) - } - return Xid(n) - case BinaryFormatCode: - if vr.Len() != 4 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an Oid: %d", vr.Len()))) - return Xid(0) - } - return Xid(vr.ReadUint32()) - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return Xid(0) - } -} - -func encodeXid(w *WriteBuf, oid Oid, value Xid) error { - if oid != XidOid { - return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Xid", oid) - } - - w.WriteInt32(4) - w.WriteUint32(uint32(value)) - - return nil -} - -func decodeCid(vr *ValueReader) Cid { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into Cid")) - return Cid(0) - } - - if vr.Type().DataType != CidOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Cid", vr.Type().DataType))) - return Cid(0) - } - - // Unlikely Cid will ever go over the wire as text format, but who knows? 
- switch vr.Type().FormatCode { - case TextFormatCode: - s := vr.ReadString(vr.Len()) - n, err := strconv.ParseUint(s, 10, 32) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid Oid: %v", s))) - } - return Cid(n) - case BinaryFormatCode: - if vr.Len() != 4 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an Oid: %d", vr.Len()))) - return Cid(0) - } - return Cid(vr.ReadUint32()) - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return Cid(0) - } -} - -func encodeCid(w *WriteBuf, oid Oid, value Cid) error { - if oid != CidOid { - return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Cid", oid) - } - - w.WriteInt32(4) - w.WriteUint32(uint32(value)) - - return nil -} - -// Note that we do not match negative numbers, because neither the -// BlockNumber nor OffsetNumber of a Tid can be negative. -var tidRegexp *regexp.Regexp = regexp.MustCompile(`^\((\d*),(\d*)\)$`) - -func decodeTid(vr *ValueReader) Tid { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into Tid")) - return Tid{BlockNumber: 0, OffsetNumber: 0} - } - - if vr.Type().DataType != TidOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Tid", vr.Type().DataType))) - return Tid{BlockNumber: 0, OffsetNumber: 0} - } - - // Unlikely Tid will ever go over the wire as text format, but who knows? 
- switch vr.Type().FormatCode { - case TextFormatCode: - s := vr.ReadString(vr.Len()) - - match := tidRegexp.FindStringSubmatch(s) - if match == nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid Oid: %v", s))) - return Tid{BlockNumber: 0, OffsetNumber: 0} - } - - blockNumber, err := strconv.ParseUint(s, 10, 16) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid BlockNumber part of a Tid: %v", s))) - } - - offsetNumber, err := strconv.ParseUint(s, 10, 16) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid offsetNumber part of a Tid: %v", s))) - } - return Tid{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber)} - case BinaryFormatCode: - if vr.Len() != 6 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an Oid: %d", vr.Len()))) - return Tid{BlockNumber: 0, OffsetNumber: 0} - } - return Tid{BlockNumber: vr.ReadUint32(), OffsetNumber: vr.ReadUint16()} - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return Tid{BlockNumber: 0, OffsetNumber: 0} - } -} - -func encodeTid(w *WriteBuf, oid Oid, value Tid) error { - if oid != TidOid { - return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Tid", oid) - } - - w.WriteInt32(6) - w.WriteUint32(value.BlockNumber) - w.WriteUint16(value.OffsetNumber) - - return nil -} - -func decodeFloat4(vr *ValueReader) float32 { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into float32")) - return 0 - } - - if vr.Type().DataType != Float4Oid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into float32", vr.Type().DataType))) - return 0 - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return 0 - } - - if vr.Len() != 4 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", vr.Len()))) - return 0 - } - - i := 
vr.ReadInt32() - return math.Float32frombits(uint32(i)) -} - -func encodeFloat32(w *WriteBuf, oid Oid, value float32) error { - switch oid { - case Float4Oid: - w.WriteInt32(4) - w.WriteInt32(int32(math.Float32bits(value))) - case Float8Oid: - w.WriteInt32(8) - w.WriteInt64(int64(math.Float64bits(float64(value)))) - default: - return fmt.Errorf("cannot encode %s into oid %v", "float32", oid) - } - - return nil -} - -func decodeFloat8(vr *ValueReader) float64 { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into float64")) - return 0 - } - - if vr.Type().DataType != Float8Oid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into float64", vr.Type().DataType))) - return 0 - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return 0 - } - - if vr.Len() != 8 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float8: %d", vr.Len()))) - return 0 - } - - i := vr.ReadInt64() - return math.Float64frombits(uint64(i)) -} - -func encodeFloat64(w *WriteBuf, oid Oid, value float64) error { - switch oid { - case Float8Oid: - w.WriteInt32(8) - w.WriteInt64(int64(math.Float64bits(value))) - default: - return fmt.Errorf("cannot encode %s into oid %v", "float64", oid) - } - - return nil -} - -func decodeText(vr *ValueReader) string { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into string")) - return "" - } - - return vr.ReadString(vr.Len()) -} - -func encodeString(w *WriteBuf, oid Oid, value string) error { - w.WriteInt32(int32(len(value))) - w.WriteBytes([]byte(value)) - return nil -} - -func decodeBytea(vr *ValueReader) []byte { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != ByteaOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []byte", vr.Type().DataType))) - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - 
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - return vr.ReadBytes(vr.Len()) -} - -func encodeByteSlice(w *WriteBuf, oid Oid, value []byte) error { - w.WriteInt32(int32(len(value))) - w.WriteBytes(value) - - return nil -} - -func decodeJSON(vr *ValueReader, d interface{}) error { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != JsonOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into json", vr.Type().DataType))) - } - - bytes := vr.ReadBytes(vr.Len()) - err := json.Unmarshal(bytes, d) - if err != nil { - vr.Fatal(err) - } - return err -} - -func encodeJSON(w *WriteBuf, oid Oid, value interface{}) error { - if oid != JsonOid { - return fmt.Errorf("cannot encode JSON into oid %v", oid) - } - - s, err := json.Marshal(value) - if err != nil { - return fmt.Errorf("Failed to encode json from type: %T", value) - } - - w.WriteInt32(int32(len(s))) - w.WriteBytes(s) - - return nil -} - -func decodeJSONB(vr *ValueReader, d interface{}) error { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != JsonbOid { - err := ProtocolError(fmt.Sprintf("Cannot decode oid %v into jsonb", vr.Type().DataType)) - vr.Fatal(err) - return err - } - - bytes := vr.ReadBytes(vr.Len()) - if vr.Type().FormatCode == BinaryFormatCode { - if bytes[0] != 1 { - err := ProtocolError(fmt.Sprintf("Unknown jsonb format byte: %x", bytes[0])) - vr.Fatal(err) - return err - } - bytes = bytes[1:] - } - - err := json.Unmarshal(bytes, d) - if err != nil { - vr.Fatal(err) - } - return err -} - -func encodeJSONB(w *WriteBuf, oid Oid, value interface{}) error { - if oid != JsonbOid { - return fmt.Errorf("cannot encode JSON into oid %v", oid) - } - - s, err := json.Marshal(value) - if err != nil { - return fmt.Errorf("Failed to encode json from type: %T", value) - } - - w.WriteInt32(int32(len(s) + 1)) - w.WriteByte(1) // JSONB format header - w.WriteBytes(s) - - return nil -} - -func 
decodeDate(vr *ValueReader) time.Time { - var zeroTime time.Time - - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into time.Time")) - return zeroTime - } - - if vr.Type().DataType != DateOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType))) - return zeroTime - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return zeroTime - } - - if vr.Len() != 4 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an date: %d", vr.Len()))) - } - dayOffset := vr.ReadInt32() - return time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.Local) -} - -func encodeTime(w *WriteBuf, oid Oid, value time.Time) error { - switch oid { - case DateOid: - tUnix := time.Date(value.Year(), value.Month(), value.Day(), 0, 0, 0, 0, time.UTC).Unix() - dateEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).Unix() - - secSinceDateEpoch := tUnix - dateEpoch - daysSinceDateEpoch := secSinceDateEpoch / 86400 - - w.WriteInt32(4) - w.WriteInt32(int32(daysSinceDateEpoch)) - - return nil - case TimestampTzOid, TimestampOid: - microsecSinceUnixEpoch := value.Unix()*1000000 + int64(value.Nanosecond())/1000 - microsecSinceY2K := microsecSinceUnixEpoch - microsecFromUnixEpochToY2K - - w.WriteInt32(8) - w.WriteInt64(microsecSinceY2K) - - return nil - default: - return fmt.Errorf("cannot encode %s into oid %v", "time.Time", oid) - } -} - -const microsecFromUnixEpochToY2K = 946684800 * 1000000 - -func decodeTimestampTz(vr *ValueReader) time.Time { - var zeroTime time.Time - - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into time.Time")) - return zeroTime - } - - if vr.Type().DataType != TimestampTzOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType))) - return zeroTime - } - - if vr.Type().FormatCode != BinaryFormatCode { - 
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return zeroTime - } - - if vr.Len() != 8 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamptz: %d", vr.Len()))) - return zeroTime - } - - microsecSinceY2K := vr.ReadInt64() - microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K - return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) -} - -func decodeTimestamp(vr *ValueReader) time.Time { - var zeroTime time.Time - - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into timestamp")) - return zeroTime - } - - if vr.Type().DataType != TimestampOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType))) - return zeroTime - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return zeroTime - } - - if vr.Len() != 8 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamp: %d", vr.Len()))) - return zeroTime - } - - microsecSinceY2K := vr.ReadInt64() - microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K - return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) -} - -func decodeInet(vr *ValueReader) net.IPNet { - var zero net.IPNet - - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into net.IPNet")) - return zero - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return zero - } - - pgType := vr.Type() - if pgType.DataType != InetOid && pgType.DataType != CidrOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into %s", pgType.DataType, pgType.Name))) - return zero - } - if vr.Len() != 8 && vr.Len() != 20 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an 
invalid size for a %s: %d", pgType.Name, vr.Len()))) - return zero - } - - vr.ReadByte() // ignore family - bits := vr.ReadByte() - vr.ReadByte() // ignore is_cidr - addressLength := vr.ReadByte() - - var ipnet net.IPNet - ipnet.IP = vr.ReadBytes(int32(addressLength)) - ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8) - - return ipnet -} - -func encodeIPNet(w *WriteBuf, oid Oid, value net.IPNet) error { - if oid != InetOid && oid != CidrOid { - return fmt.Errorf("cannot encode %s into oid %v", "net.IPNet", oid) - } - - var size int32 - var family byte - switch len(value.IP) { - case net.IPv4len: - size = 8 - family = *w.conn.pgsqlAfInet - case net.IPv6len: - size = 20 - family = *w.conn.pgsqlAfInet6 - default: - return fmt.Errorf("Unexpected IP length: %v", len(value.IP)) - } - - w.WriteInt32(size) - w.WriteByte(family) - ones, _ := value.Mask.Size() - w.WriteByte(byte(ones)) - w.WriteByte(0) // is_cidr is ignored on server - w.WriteByte(byte(len(value.IP))) - w.WriteBytes(value.IP) - - return nil -} - -func encodeIP(w *WriteBuf, oid Oid, value net.IP) error { - if oid != InetOid && oid != CidrOid { - return fmt.Errorf("cannot encode %s into oid %v", "net.IP", oid) - } - - var ipnet net.IPNet - ipnet.IP = value - bitCount := len(value) * 8 - ipnet.Mask = net.CIDRMask(bitCount, bitCount) - return encodeIPNet(w, oid, ipnet) -} - -func decodeRecord(vr *ValueReader) []interface{} { - if vr.Len() == -1 { - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - if vr.Type().DataType != RecordOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []interface{}", vr.Type().DataType))) - return nil - } - - valueCount := vr.ReadInt32() - record := make([]interface{}, 0, int(valueCount)) - - for i := int32(0); i < valueCount; i++ { - fd := FieldDescription{FormatCode: BinaryFormatCode} - fieldVR := ValueReader{mr: 
vr.mr, fd: &fd} - fd.DataType = vr.ReadOid() - fieldVR.valueBytesRemaining = vr.ReadInt32() - vr.valueBytesRemaining -= fieldVR.valueBytesRemaining - - switch fd.DataType { - case BoolOid: - record = append(record, decodeBool(&fieldVR)) - case ByteaOid: - record = append(record, decodeBytea(&fieldVR)) - case Int8Oid: - record = append(record, decodeInt8(&fieldVR)) - case Int2Oid: - record = append(record, decodeInt2(&fieldVR)) - case Int4Oid: - record = append(record, decodeInt4(&fieldVR)) - case OidOid: - record = append(record, decodeOid(&fieldVR)) - case Float4Oid: - record = append(record, decodeFloat4(&fieldVR)) - case Float8Oid: - record = append(record, decodeFloat8(&fieldVR)) - case DateOid: - record = append(record, decodeDate(&fieldVR)) - case TimestampTzOid: - record = append(record, decodeTimestampTz(&fieldVR)) - case TimestampOid: - record = append(record, decodeTimestamp(&fieldVR)) - case InetOid, CidrOid: - record = append(record, decodeInet(&fieldVR)) - case TextOid, VarcharOid, UnknownOid: - record = append(record, decodeText(&fieldVR)) - default: - vr.Fatal(fmt.Errorf("decodeRecord cannot decode oid %d", fd.DataType)) - return nil - } - - // Consume any remaining data - if fieldVR.Len() > 0 { - fieldVR.ReadBytes(fieldVR.Len()) - } - - if fieldVR.Err() != nil { - vr.Fatal(fieldVR.Err()) - return nil - } - } - - return record -} - -func decode1dArrayHeader(vr *ValueReader) (length int32, err error) { - numDims := vr.ReadInt32() - if numDims > 1 { - return 0, ProtocolError(fmt.Sprintf("Expected array to have 0 or 1 dimension, but it had %v", numDims)) - } - - vr.ReadInt32() // 0 if no nulls / 1 if there is one or more nulls -- but we don't care - vr.ReadInt32() // element oid - - if numDims == 0 { - return 0, nil - } - - length = vr.ReadInt32() - - idxFirstElem := vr.ReadInt32() - if idxFirstElem != 1 { - return 0, ProtocolError(fmt.Sprintf("Expected array's first element to start a index 1, but it is %d", idxFirstElem)) - } - - return length, nil -} 
- -func decodeBoolArray(vr *ValueReader) []bool { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != BoolArrayOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []bool", vr.Type().DataType))) - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - numElems, err := decode1dArrayHeader(vr) - if err != nil { - vr.Fatal(err) - return nil - } - - a := make([]bool, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - switch elSize { - case 1: - if vr.ReadByte() == 1 { - a[i] = true - } - case -1: - vr.Fatal(ProtocolError("Cannot decode null element")) - return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool element: %d", elSize))) - return nil - } - } - - return a -} - -func encodeBoolSlice(w *WriteBuf, oid Oid, slice []bool) error { - if oid != BoolArrayOid { - return fmt.Errorf("cannot encode Go %s into oid %d", "[]bool", oid) - } - - encodeArrayHeader(w, BoolOid, len(slice), 5) - for _, v := range slice { - w.WriteInt32(1) - var b byte - if v { - b = 1 - } - w.WriteByte(b) - } - - return nil -} - -func decodeByteaArray(vr *ValueReader) [][]byte { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != ByteaArrayOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into [][]byte", vr.Type().DataType))) - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - numElems, err := decode1dArrayHeader(vr) - if err != nil { - vr.Fatal(err) - return nil - } - - a := make([][]byte, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - switch elSize { - case -1: - vr.Fatal(ProtocolError("Cannot decode null element")) - return nil - default: - a[i] = 
vr.ReadBytes(elSize) - } - } - - return a -} - -func encodeByteSliceSlice(w *WriteBuf, oid Oid, value [][]byte) error { - if oid != ByteaArrayOid { - return fmt.Errorf("cannot encode Go %s into oid %d", "[][]byte", oid) - } - - size := 20 // array header size - for _, el := range value { - size += 4 + len(el) - } - - w.WriteInt32(int32(size)) - - w.WriteInt32(1) // number of dimensions - w.WriteInt32(0) // no nulls - w.WriteInt32(int32(ByteaOid)) // type of elements - w.WriteInt32(int32(len(value))) // number of elements - w.WriteInt32(1) // index of first element - - for _, el := range value { - encodeByteSlice(w, ByteaOid, el) - } - - return nil -} - -func decodeInt2Array(vr *ValueReader) []int16 { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != Int2ArrayOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int16", vr.Type().DataType))) - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - numElems, err := decode1dArrayHeader(vr) - if err != nil { - vr.Fatal(err) - return nil - } - - a := make([]int16, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - switch elSize { - case 2: - a[i] = vr.ReadInt16() - case -1: - vr.Fatal(ProtocolError("Cannot decode null element")) - return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2 element: %d", elSize))) - return nil - } - } - - return a -} - -func decodeInt2ArrayToUInt(vr *ValueReader) []uint16 { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != Int2ArrayOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []uint16", vr.Type().DataType))) - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - numElems, err := 
decode1dArrayHeader(vr) - if err != nil { - vr.Fatal(err) - return nil - } - - a := make([]uint16, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - switch elSize { - case 2: - tmp := vr.ReadInt16() - if tmp < 0 { - vr.Fatal(ProtocolError(fmt.Sprintf("%d is less than zero for uint16", tmp))) - return nil - } - a[i] = uint16(tmp) - case -1: - vr.Fatal(ProtocolError("Cannot decode null element")) - return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2 element: %d", elSize))) - return nil - } - } - - return a -} - -func encodeInt16Slice(w *WriteBuf, oid Oid, slice []int16) error { - if oid != Int2ArrayOid { - return fmt.Errorf("cannot encode Go %s into oid %d", "[]int16", oid) - } - - encodeArrayHeader(w, Int2Oid, len(slice), 6) - for _, v := range slice { - w.WriteInt32(2) - w.WriteInt16(v) - } - - return nil -} - -func encodeUInt16Slice(w *WriteBuf, oid Oid, slice []uint16) error { - if oid != Int2ArrayOid { - return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint16", oid) - } - - encodeArrayHeader(w, Int2Oid, len(slice), 6) - for _, v := range slice { - if v <= math.MaxInt16 { - w.WriteInt32(2) - w.WriteInt16(int16(v)) - } else { - return fmt.Errorf("%d is greater than max smallint %d", v, math.MaxInt16) - } - } - - return nil -} - -func decodeInt4Array(vr *ValueReader) []int32 { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != Int4ArrayOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int32", vr.Type().DataType))) - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - numElems, err := decode1dArrayHeader(vr) - if err != nil { - vr.Fatal(err) - return nil - } - - a := make([]int32, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - switch elSize { - case 4: - a[i] = vr.ReadInt32() - case 
-1: - vr.Fatal(ProtocolError("Cannot decode null element")) - return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4 element: %d", elSize))) - return nil - } - } - - return a -} - -func decodeInt4ArrayToUInt(vr *ValueReader) []uint32 { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != Int4ArrayOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []uint32", vr.Type().DataType))) - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - numElems, err := decode1dArrayHeader(vr) - if err != nil { - vr.Fatal(err) - return nil - } - - a := make([]uint32, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - switch elSize { - case 4: - tmp := vr.ReadInt32() - if tmp < 0 { - vr.Fatal(ProtocolError(fmt.Sprintf("%d is less than zero for uint32", tmp))) - return nil - } - a[i] = uint32(tmp) - case -1: - vr.Fatal(ProtocolError("Cannot decode null element")) - return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4 element: %d", elSize))) - return nil - } - } - - return a -} - -func encodeInt32Slice(w *WriteBuf, oid Oid, slice []int32) error { - if oid != Int4ArrayOid { - return fmt.Errorf("cannot encode Go %s into oid %d", "[]int32", oid) - } - - encodeArrayHeader(w, Int4Oid, len(slice), 8) - for _, v := range slice { - w.WriteInt32(4) - w.WriteInt32(v) - } - - return nil -} - -func encodeUInt32Slice(w *WriteBuf, oid Oid, slice []uint32) error { - if oid != Int4ArrayOid { - return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint32", oid) - } - - encodeArrayHeader(w, Int4Oid, len(slice), 8) - for _, v := range slice { - if v <= math.MaxInt32 { - w.WriteInt32(4) - w.WriteInt32(int32(v)) - } else { - return fmt.Errorf("%d is greater than max integer %d", v, math.MaxInt32) - } - } - - return nil -} - 
-func decodeInt8Array(vr *ValueReader) []int64 { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != Int8ArrayOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int64", vr.Type().DataType))) - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - numElems, err := decode1dArrayHeader(vr) - if err != nil { - vr.Fatal(err) - return nil - } - - a := make([]int64, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - switch elSize { - case 8: - a[i] = vr.ReadInt64() - case -1: - vr.Fatal(ProtocolError("Cannot decode null element")) - return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8 element: %d", elSize))) - return nil - } - } - - return a -} - -func decodeInt8ArrayToUInt(vr *ValueReader) []uint64 { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != Int8ArrayOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []uint64", vr.Type().DataType))) - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - numElems, err := decode1dArrayHeader(vr) - if err != nil { - vr.Fatal(err) - return nil - } - - a := make([]uint64, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - switch elSize { - case 8: - tmp := vr.ReadInt64() - if tmp < 0 { - vr.Fatal(ProtocolError(fmt.Sprintf("%d is less than zero for uint64", tmp))) - return nil - } - a[i] = uint64(tmp) - case -1: - vr.Fatal(ProtocolError("Cannot decode null element")) - return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8 element: %d", elSize))) - return nil - } - } - - return a -} - -func encodeInt64Slice(w *WriteBuf, oid Oid, slice []int64) error { - if 
oid != Int8ArrayOid { - return fmt.Errorf("cannot encode Go %s into oid %d", "[]int64", oid) - } - - encodeArrayHeader(w, Int8Oid, len(slice), 12) - for _, v := range slice { - w.WriteInt32(8) - w.WriteInt64(v) - } - - return nil -} - -func encodeUInt64Slice(w *WriteBuf, oid Oid, slice []uint64) error { - if oid != Int8ArrayOid { - return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint64", oid) - } - - encodeArrayHeader(w, Int8Oid, len(slice), 12) - for _, v := range slice { - if v <= math.MaxInt64 { - w.WriteInt32(8) - w.WriteInt64(int64(v)) - } else { - return fmt.Errorf("%d is greater than max bigint %d", v, int64(math.MaxInt64)) - } - } - - return nil -} - -func decodeFloat4Array(vr *ValueReader) []float32 { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != Float4ArrayOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []float32", vr.Type().DataType))) - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - numElems, err := decode1dArrayHeader(vr) - if err != nil { - vr.Fatal(err) - return nil - } - - a := make([]float32, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - switch elSize { - case 4: - n := vr.ReadInt32() - a[i] = math.Float32frombits(uint32(n)) - case -1: - vr.Fatal(ProtocolError("Cannot decode null element")) - return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4 element: %d", elSize))) - return nil - } - } - - return a -} - -func encodeFloat32Slice(w *WriteBuf, oid Oid, slice []float32) error { - if oid != Float4ArrayOid { - return fmt.Errorf("cannot encode Go %s into oid %d", "[]float32", oid) - } - - encodeArrayHeader(w, Float4Oid, len(slice), 8) - for _, v := range slice { - w.WriteInt32(4) - w.WriteInt32(int32(math.Float32bits(v))) - } - - return nil -} - -func decodeFloat8Array(vr 
*ValueReader) []float64 { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != Float8ArrayOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []float64", vr.Type().DataType))) - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - numElems, err := decode1dArrayHeader(vr) - if err != nil { - vr.Fatal(err) - return nil - } - - a := make([]float64, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - switch elSize { - case 8: - n := vr.ReadInt64() - a[i] = math.Float64frombits(uint64(n)) - case -1: - vr.Fatal(ProtocolError("Cannot decode null element")) - return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4 element: %d", elSize))) - return nil - } - } - - return a -} - -func encodeFloat64Slice(w *WriteBuf, oid Oid, slice []float64) error { - if oid != Float8ArrayOid { - return fmt.Errorf("cannot encode Go %s into oid %d", "[]float64", oid) - } - - encodeArrayHeader(w, Float8Oid, len(slice), 12) - for _, v := range slice { - w.WriteInt32(8) - w.WriteInt64(int64(math.Float64bits(v))) - } - - return nil -} - -func decodeTextArray(vr *ValueReader) []string { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != TextArrayOid && vr.Type().DataType != VarcharArrayOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []string", vr.Type().DataType))) - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - numElems, err := decode1dArrayHeader(vr) - if err != nil { - vr.Fatal(err) - return nil - } - - a := make([]string, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - if elSize == -1 { - vr.Fatal(ProtocolError("Cannot decode null element")) - return 
nil - } - - a[i] = vr.ReadString(elSize) - } - - return a -} - -// escapeAclItem escapes an AclItem before it is added to -// its aclitem[] string representation. The PostgreSQL aclitem -// datatype itself can need escapes because it follows the -// formatting rules of SQL identifiers. Think of this function -// as escaping the escapes, so that PostgreSQL's array parser -// will do the right thing. -func escapeAclItem(acl string) (string, error) { - var escapedAclItem bytes.Buffer - reader := strings.NewReader(acl) - for { - rn, _, err := reader.ReadRune() - if err != nil { - if err == io.EOF { - // Here, EOF is an expected end state, not an error. - return escapedAclItem.String(), nil - } - // This error was not expected - return "", err - } - if needsEscape(rn) { - escapedAclItem.WriteRune('\\') - } - escapedAclItem.WriteRune(rn) - } -} - -// needsEscape determines whether or not a rune needs escaping -// before being placed in the textual representation of an -// aclitem[] array. -func needsEscape(rn rune) bool { - return rn == '\\' || rn == ',' || rn == '"' || rn == '}' -} - -// encodeAclItemSlice encodes a slice of AclItems in -// their textual represention for PostgreSQL. -func encodeAclItemSlice(w *WriteBuf, oid Oid, aclitems []AclItem) error { - strs := make([]string, len(aclitems)) - var escapedAclItem string - var err error - for i := range strs { - escapedAclItem, err = escapeAclItem(string(aclitems[i])) - if err != nil { - return err - } - strs[i] = string(escapedAclItem) - } - - var buf bytes.Buffer - buf.WriteRune('{') - buf.WriteString(strings.Join(strs, ",")) - buf.WriteRune('}') - str := buf.String() - w.WriteInt32(int32(len(str))) - w.WriteBytes([]byte(str)) - return nil -} - -// parseAclItemArray parses the textual representation -// of the aclitem[] type. The textual representation is chosen because -// Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin). 
-// See https://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO -// for formatting notes. -func parseAclItemArray(arr string) ([]AclItem, error) { - reader := strings.NewReader(arr) - // Difficult to guess a performant initial capacity for a slice of - // aclitems, but let's go with 5. - aclItems := make([]AclItem, 0, 5) - // A single value - aclItem := AclItem("") - for { - // Grab the first/next/last rune to see if we are dealing with a - // quoted value, an unquoted value, or the end of the string. - rn, _, err := reader.ReadRune() - if err != nil { - if err == io.EOF { - // Here, EOF is an expected end state, not an error. - return aclItems, nil - } - // This error was not expected - return nil, err - } - - if rn == '"' { - // Discard the opening quote of the quoted value. - aclItem, err = parseQuotedAclItem(reader) - } else { - // We have just read the first rune of an unquoted (bare) value; - // put it back so that ParseBareValue can read it. - err := reader.UnreadRune() - if err != nil { - return nil, err - } - aclItem, err = parseBareAclItem(reader) - } - - if err != nil { - if err == io.EOF { - // Here, EOF is an expected end state, not an error.. - aclItems = append(aclItems, aclItem) - return aclItems, nil - } - // This error was not expected. - return nil, err - } - aclItems = append(aclItems, aclItem) - } -} - -// parseBareAclItem parses a bare (unquoted) aclitem from reader -func parseBareAclItem(reader *strings.Reader) (AclItem, error) { - var aclItem bytes.Buffer - for { - rn, _, err := reader.ReadRune() - if err != nil { - // Return the read value in case the error is a harmless io.EOF. - // (io.EOF marks the end of a bare aclitem at the end of a string) - return AclItem(aclItem.String()), err - } - if rn == ',' { - // A comma marks the end of a bare aclitem. 
- return AclItem(aclItem.String()), nil - } else { - aclItem.WriteRune(rn) - } - } -} - -// parseQuotedAclItem parses an aclitem which is in double quotes from reader -func parseQuotedAclItem(reader *strings.Reader) (AclItem, error) { - var aclItem bytes.Buffer - for { - rn, escaped, err := readPossiblyEscapedRune(reader) - if err != nil { - if err == io.EOF { - // Even when it is the last value, the final rune of - // a quoted aclitem should be the final closing quote, not io.EOF. - return AclItem(""), fmt.Errorf("unexpected end of quoted value") - } - // Return the read aclitem in case the error is a harmless io.EOF, - // which will be determined by the caller. - return AclItem(aclItem.String()), err - } - if !escaped && rn == '"' { - // An unescaped double quote marks the end of a quoted value. - // The next rune should either be a comma or the end of the string. - rn, _, err := reader.ReadRune() - if err != nil { - // Return the read value in case the error is a harmless io.EOF, - // which will be determined by the caller. - return AclItem(aclItem.String()), err - } - if rn != ',' { - return AclItem(""), fmt.Errorf("unexpected rune after quoted value") - } - return AclItem(aclItem.String()), nil - } - aclItem.WriteRune(rn) - } -} - -// Returns the next rune from r, unless it is a backslash; -// in that case, it returns the rune after the backslash. The second -// return value tells us whether or not the rune was -// preceeded by a backslash (escaped). -func readPossiblyEscapedRune(reader *strings.Reader) (rune, bool, error) { - rn, _, err := reader.ReadRune() - if err != nil { - return 0, false, err - } - if rn == '\\' { - // Discard the backslash and read the next rune. 
- rn, _, err = reader.ReadRune() - if err != nil { - return 0, false, err - } - return rn, true, nil - } - return rn, false, nil -} - -func decodeAclItemArray(vr *ValueReader) []AclItem { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into []AclItem")) - return nil - } - - str := vr.ReadString(vr.Len()) - - // Short-circuit empty array. - if str == "{}" { - return []AclItem{} - } - - // Remove the '{' at the front and the '}' at the end, - // so that parseAclItemArray doesn't have to deal with them. - str = str[1 : len(str)-1] - aclItems, err := parseAclItemArray(str) - if err != nil { - vr.Fatal(ProtocolError(err.Error())) - return nil - } - return aclItems -} - -func encodeStringSlice(w *WriteBuf, oid Oid, slice []string) error { - var elOid Oid - switch oid { - case VarcharArrayOid: - elOid = VarcharOid - case TextArrayOid: - elOid = TextOid - default: - return fmt.Errorf("cannot encode Go %s into oid %d", "[]string", oid) - } - - var totalStringSize int - for _, v := range slice { - totalStringSize += len(v) - } - - size := 20 + len(slice)*4 + totalStringSize - w.WriteInt32(int32(size)) - - w.WriteInt32(1) // number of dimensions - w.WriteInt32(0) // no nulls - w.WriteInt32(int32(elOid)) // type of elements - w.WriteInt32(int32(len(slice))) // number of elements - w.WriteInt32(1) // index of first element - - for _, v := range slice { - w.WriteInt32(int32(len(v))) - w.WriteBytes([]byte(v)) - } - - return nil -} - -func decodeTimestampArray(vr *ValueReader) []time.Time { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != TimestampArrayOid && vr.Type().DataType != TimestampTzArrayOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []time.Time", vr.Type().DataType))) - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - numElems, err := decode1dArrayHeader(vr) - if err != 
nil { - vr.Fatal(err) - return nil - } - - a := make([]time.Time, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - switch elSize { - case 8: - microsecSinceY2K := vr.ReadInt64() - microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K - a[i] = time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) - case -1: - vr.Fatal(ProtocolError("Cannot decode null element")) - return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an time.Time element: %d", elSize))) - return nil - } - } - - return a -} - -func encodeTimeSlice(w *WriteBuf, oid Oid, slice []time.Time) error { - var elOid Oid - switch oid { - case TimestampArrayOid: - elOid = TimestampOid - case TimestampTzArrayOid: - elOid = TimestampTzOid - default: - return fmt.Errorf("cannot encode Go %s into oid %d", "[]time.Time", oid) - } - - encodeArrayHeader(w, int(elOid), len(slice), 12) - for _, t := range slice { - w.WriteInt32(8) - microsecSinceUnixEpoch := t.Unix()*1000000 + int64(t.Nanosecond())/1000 - microsecSinceY2K := microsecSinceUnixEpoch - microsecFromUnixEpochToY2K - w.WriteInt64(microsecSinceY2K) - } - - return nil -} - -func decodeInetArray(vr *ValueReader) []net.IPNet { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != InetArrayOid && vr.Type().DataType != CidrArrayOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []net.IP", vr.Type().DataType))) - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - numElems, err := decode1dArrayHeader(vr) - if err != nil { - vr.Fatal(err) - return nil - } - - a := make([]net.IPNet, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - if elSize == -1 { - vr.Fatal(ProtocolError("Cannot decode null element")) - return nil - } - - vr.ReadByte() // ignore family - bits := 
vr.ReadByte() - vr.ReadByte() // ignore is_cidr - addressLength := vr.ReadByte() - - var ipnet net.IPNet - ipnet.IP = vr.ReadBytes(int32(addressLength)) - ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8) - - a[i] = ipnet - } - - return a -} - -func encodeIPNetSlice(w *WriteBuf, oid Oid, slice []net.IPNet) error { - var elOid Oid - switch oid { - case InetArrayOid: - elOid = InetOid - case CidrArrayOid: - elOid = CidrOid - default: - return fmt.Errorf("cannot encode Go %s into oid %d", "[]net.IPNet", oid) - } - - size := int32(20) // array header size - for _, ipnet := range slice { - size += 4 + 4 + int32(len(ipnet.IP)) // size of element + inet/cidr metadata + IP bytes - } - w.WriteInt32(int32(size)) - - w.WriteInt32(1) // number of dimensions - w.WriteInt32(0) // no nulls - w.WriteInt32(int32(elOid)) // type of elements - w.WriteInt32(int32(len(slice))) // number of elements - w.WriteInt32(1) // index of first element - - for _, ipnet := range slice { - encodeIPNet(w, elOid, ipnet) - } - - return nil -} - -func encodeIPSlice(w *WriteBuf, oid Oid, slice []net.IP) error { - var elOid Oid - switch oid { - case InetArrayOid: - elOid = InetOid - case CidrArrayOid: - elOid = CidrOid - default: - return fmt.Errorf("cannot encode Go %s into oid %d", "[]net.IPNet", oid) - } - - size := int32(20) // array header size - for _, ip := range slice { - size += 4 + 4 + int32(len(ip)) // size of element + inet/cidr metadata + IP bytes - } - w.WriteInt32(int32(size)) - - w.WriteInt32(1) // number of dimensions - w.WriteInt32(0) // no nulls - w.WriteInt32(int32(elOid)) // type of elements - w.WriteInt32(int32(len(slice))) // number of elements - w.WriteInt32(1) // index of first element - - for _, ip := range slice { - encodeIP(w, elOid, ip) - } - - return nil -} - -func encodeArrayHeader(w *WriteBuf, oid, length, sizePerItem int) { - w.WriteInt32(int32(20 + length*sizePerItem)) - w.WriteInt32(1) // number of dimensions - w.WriteInt32(0) // no nulls - 
w.WriteInt32(int32(oid)) // type of elements - w.WriteInt32(int32(length)) // number of elements - w.WriteInt32(1) // index of first element -} diff --git a/values_test.go b/values_test.go index 42d5bd3d..b8aec46a 100644 --- a/values_test.go +++ b/values_test.go @@ -4,7 +4,6 @@ import ( "bytes" "net" "reflect" - "strings" "testing" "time" @@ -18,24 +17,24 @@ func TestDateTranscode(t *testing.T) { defer closeConn(t, conn) dates := []time.Time{ - time.Date(1, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(1000, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(1600, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(1700, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(1800, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(1990, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local), - time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(2001, 1, 2, 0, 0, 0, 0, time.Local), - time.Date(2004, 2, 29, 0, 0, 0, 0, time.Local), - time.Date(2013, 7, 4, 0, 0, 0, 0, time.Local), - time.Date(2013, 12, 25, 0, 0, 0, 0, time.Local), - time.Date(2029, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(2081, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(2096, 2, 29, 0, 0, 0, 0, time.Local), - time.Date(2550, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(9999, 12, 31, 0, 0, 0, 0, time.Local), + time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1000, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1600, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1700, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), + time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(2001, 1, 2, 0, 0, 0, 0, time.UTC), + time.Date(2004, 2, 29, 0, 0, 0, 0, time.UTC), + time.Date(2013, 7, 4, 0, 0, 0, 0, time.UTC), + time.Date(2013, 12, 25, 0, 0, 0, 0, time.UTC), + time.Date(2029, 1, 1, 0, 0, 0, 0, time.UTC), + 
time.Date(2081, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(2096, 2, 29, 0, 0, 0, 0, time.UTC), + time.Date(2550, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(9999, 12, 31, 0, 0, 0, 0, time.UTC), } for _, actualDate := range dates { @@ -78,84 +77,78 @@ func TestTimestampTzTranscode(t *testing.T) { } } -func TestJsonAndJsonbTranscode(t *testing.T) { +// TODO - move these tests to pgtype + +func TestJSONAndJSONBTranscode(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - for _, oid := range []pgx.Oid{pgx.JsonOid, pgx.JsonbOid} { - if _, ok := conn.PgTypes[oid]; !ok { - return // No JSON/JSONB type -- must be running against old PostgreSQL + for _, typename := range []string{"json", "jsonb"} { + if _, ok := conn.ConnInfo.DataTypeForName(typename); !ok { + continue // No JSON/JSONB type -- must be running against old PostgreSQL } - for _, format := range []int16{pgx.TextFormatCode, pgx.BinaryFormatCode} { - pgtype := conn.PgTypes[oid] - pgtype.DefaultFormat = format - conn.PgTypes[oid] = pgtype - - typename := conn.PgTypes[oid].Name - - testJsonString(t, conn, typename, format) - testJsonStringPointer(t, conn, typename, format) - testJsonSingleLevelStringMap(t, conn, typename, format) - testJsonNestedMap(t, conn, typename, format) - testJsonStringArray(t, conn, typename, format) - testJsonInt64Array(t, conn, typename, format) - testJsonInt16ArrayFailureDueToOverflow(t, conn, typename, format) - testJsonStruct(t, conn, typename, format) - } + testJSONString(t, conn, typename) + testJSONStringPointer(t, conn, typename) + testJSONSingleLevelStringMap(t, conn, typename) + testJSONNestedMap(t, conn, typename) + testJSONStringArray(t, conn, typename) + testJSONInt64Array(t, conn, typename) + testJSONInt16ArrayFailureDueToOverflow(t, conn, typename) + testJSONStruct(t, conn, typename) } } -func testJsonString(t *testing.T, conn *pgx.Conn, typename string, format int16) { +func testJSONString(t *testing.T, conn *pgx.Conn, typename 
string) { input := `{"key": "value"}` expectedOutput := map[string]string{"key": "value"} var output map[string]string err := conn.QueryRow("select $1::"+typename, input).Scan(&output) if err != nil { - t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) + t.Errorf("%s: QueryRow Scan failed: %v", typename, err) return } if !reflect.DeepEqual(expectedOutput, output) { - t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, expectedOutput, output) + t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, expectedOutput, output) return } } -func testJsonStringPointer(t *testing.T, conn *pgx.Conn, typename string, format int16) { +func testJSONStringPointer(t *testing.T, conn *pgx.Conn, typename string) { input := `{"key": "value"}` expectedOutput := map[string]string{"key": "value"} var output map[string]string err := conn.QueryRow("select $1::"+typename, &input).Scan(&output) if err != nil { - t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) + t.Errorf("%s: QueryRow Scan failed: %v", typename, err) return } if !reflect.DeepEqual(expectedOutput, output) { - t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, expectedOutput, output) + t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, expectedOutput, output) return } } -func testJsonSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string, format int16) { +func testJSONSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string) { input := map[string]string{"key": "value"} var output map[string]string err := conn.QueryRow("select $1::"+typename, input).Scan(&output) if err != nil { - t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) + t.Errorf("%s: QueryRow Scan failed: %v", typename, err) return } if !reflect.DeepEqual(input, output) { - t.Errorf("%s %d: Did not transcode map[string]string 
successfully: %v is not %v", typename, format, input, output) + t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, input, output) return } } -func testJsonNestedMap(t *testing.T, conn *pgx.Conn, typename string, format int16) { +func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string) { input := map[string]interface{}{ "name": "Uncanny", "stats": map[string]interface{}{"hp": float64(107), "maxhp": float64(150)}, @@ -164,52 +157,52 @@ func testJsonNestedMap(t *testing.T, conn *pgx.Conn, typename string, format int var output map[string]interface{} err := conn.QueryRow("select $1::"+typename, input).Scan(&output) if err != nil { - t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) + t.Errorf("%s: QueryRow Scan failed: %v", typename, err) return } if !reflect.DeepEqual(input, output) { - t.Errorf("%s %d: Did not transcode map[string]interface{} successfully: %v is not %v", typename, format, input, output) + t.Errorf("%s: Did not transcode map[string]interface{} successfully: %v is not %v", typename, input, output) return } } -func testJsonStringArray(t *testing.T, conn *pgx.Conn, typename string, format int16) { +func testJSONStringArray(t *testing.T, conn *pgx.Conn, typename string) { input := []string{"foo", "bar", "baz"} var output []string err := conn.QueryRow("select $1::"+typename, input).Scan(&output) if err != nil { - t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) + t.Errorf("%s: QueryRow Scan failed: %v", typename, err) } if !reflect.DeepEqual(input, output) { - t.Errorf("%s %d: Did not transcode []string successfully: %v is not %v", typename, format, input, output) + t.Errorf("%s: Did not transcode []string successfully: %v is not %v", typename, input, output) } } -func testJsonInt64Array(t *testing.T, conn *pgx.Conn, typename string, format int16) { +func testJSONInt64Array(t *testing.T, conn *pgx.Conn, typename string) { input := []int64{1, 2, 234432} var output []int64 
err := conn.QueryRow("select $1::"+typename, input).Scan(&output) if err != nil { - t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) + t.Errorf("%s: QueryRow Scan failed: %v", typename, err) } if !reflect.DeepEqual(input, output) { - t.Errorf("%s %d: Did not transcode []int64 successfully: %v is not %v", typename, format, input, output) + t.Errorf("%s: Did not transcode []int64 successfully: %v is not %v", typename, input, output) } } -func testJsonInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string, format int16) { +func testJSONInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string) { input := []int{1, 2, 234432} var output []int16 err := conn.QueryRow("select $1::"+typename, input).Scan(&output) if err == nil || err.Error() != "can't scan into dest[0]: json: cannot unmarshal number 234432 into Go value of type int16" { - t.Errorf("%s %d: Expected *json.UnmarkalTypeError, but got %v", typename, format, err) + t.Errorf("%s: Expected *json.UnmarkalTypeError, but got %v", typename, err) } } -func testJsonStruct(t *testing.T, conn *pgx.Conn, typename string, format int16) { +func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string) { type person struct { Name string `json:"name"` Age int `json:"age"` @@ -224,21 +217,21 @@ func testJsonStruct(t *testing.T, conn *pgx.Conn, typename string, format int16) err := conn.QueryRow("select $1::"+typename, input).Scan(&output) if err != nil { - t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) + t.Errorf("%s: QueryRow Scan failed: %v", typename, err) } if !reflect.DeepEqual(input, output) { - t.Errorf("%s %d: Did not transcode struct successfully: %v is not %v", typename, format, input, output) + t.Errorf("%s: Did not transcode struct successfully: %v is not %v", typename, input, output) } } -func mustParseCIDR(t *testing.T, s string) net.IPNet { +func mustParseCIDR(t *testing.T, s string) *net.IPNet { _, ipnet, err := net.ParseCIDR(s) if 
err != nil { t.Fatal(err) } - return *ipnet + return ipnet } func TestStringToNotTextTypeTranscode(t *testing.T) { @@ -267,7 +260,7 @@ func TestStringToNotTextTypeTranscode(t *testing.T) { } } -func TestInetCidrTranscodeIPNet(t *testing.T) { +func TestInetCIDRTranscodeIPNet(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -275,7 +268,7 @@ func TestInetCidrTranscodeIPNet(t *testing.T) { tests := []struct { sql string - value net.IPNet + value *net.IPNet }{ {"select $1::inet", mustParseCIDR(t, "0.0.0.0/32")}, {"select $1::inet", mustParseCIDR(t, "127.0.0.1/32")}, @@ -316,7 +309,7 @@ func TestInetCidrTranscodeIPNet(t *testing.T) { } } -func TestInetCidrTranscodeIP(t *testing.T) { +func TestInetCIDRTranscodeIP(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -358,7 +351,7 @@ func TestInetCidrTranscodeIP(t *testing.T) { failTests := []struct { sql string - value net.IPNet + value *net.IPNet }{ {"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")}, {"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")}, @@ -367,8 +360,8 @@ func TestInetCidrTranscodeIP(t *testing.T) { var actual net.IP err := conn.QueryRow(tt.sql, tt.value).Scan(&actual) - if !strings.Contains(err.Error(), "Cannot decode netmask") { - t.Errorf("%d. Expected failure cannot decode netmask, but got: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + if err == nil { + t.Errorf("%d. 
Expected failure but got none", i) continue } @@ -376,7 +369,7 @@ func TestInetCidrTranscodeIP(t *testing.T) { } } -func TestInetCidrArrayTranscodeIPNet(t *testing.T) { +func TestInetCIDRArrayTranscodeIPNet(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -384,11 +377,11 @@ func TestInetCidrArrayTranscodeIPNet(t *testing.T) { tests := []struct { sql string - value []net.IPNet + value []*net.IPNet }{ { "select $1::inet[]", - []net.IPNet{ + []*net.IPNet{ mustParseCIDR(t, "0.0.0.0/32"), mustParseCIDR(t, "127.0.0.1/32"), mustParseCIDR(t, "12.34.56.0/32"), @@ -403,7 +396,7 @@ func TestInetCidrArrayTranscodeIPNet(t *testing.T) { }, { "select $1::cidr[]", - []net.IPNet{ + []*net.IPNet{ mustParseCIDR(t, "0.0.0.0/32"), mustParseCIDR(t, "127.0.0.1/32"), mustParseCIDR(t, "12.34.56.0/32"), @@ -419,7 +412,7 @@ func TestInetCidrArrayTranscodeIPNet(t *testing.T) { } for i, tt := range tests { - var actual []net.IPNet + var actual []*net.IPNet err := conn.QueryRow(tt.sql, tt.value).Scan(&actual) if err != nil { @@ -435,7 +428,7 @@ func TestInetCidrArrayTranscodeIPNet(t *testing.T) { } } -func TestInetCidrArrayTranscodeIP(t *testing.T) { +func TestInetCIDRArrayTranscodeIP(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -485,18 +478,18 @@ func TestInetCidrArrayTranscodeIP(t *testing.T) { failTests := []struct { sql string - value []net.IPNet + value []*net.IPNet }{ { "select $1::inet[]", - []net.IPNet{ + []*net.IPNet{ mustParseCIDR(t, "12.34.56.0/32"), mustParseCIDR(t, "192.168.1.0/24"), }, }, { "select $1::cidr[]", - []net.IPNet{ + []*net.IPNet{ mustParseCIDR(t, "12.34.56.0/32"), mustParseCIDR(t, "192.168.1.0/24"), }, @@ -507,8 +500,8 @@ func TestInetCidrArrayTranscodeIP(t *testing.T) { var actual []net.IP err := conn.QueryRow(tt.sql, tt.value).Scan(&actual) - if err == nil || !strings.Contains(err.Error(), "Cannot decode netmask") { - t.Errorf("%d. 
Expected failure cannot decode netmask, but got: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + if err == nil { + t.Errorf("%d. Expected failure but got none", i) continue } @@ -516,7 +509,7 @@ func TestInetCidrArrayTranscodeIP(t *testing.T) { } } -func TestInetCidrTranscodeWithJustIP(t *testing.T) { +func TestInetCIDRTranscodeWithJustIP(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -558,144 +551,6 @@ func TestInetCidrTranscodeWithJustIP(t *testing.T) { } } -func TestNullX(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - type allTypes struct { - s pgx.NullString - i16 pgx.NullInt16 - i32 pgx.NullInt32 - c pgx.NullChar - a pgx.NullAclItem - n pgx.NullName - oid pgx.NullOid - xid pgx.NullXid - cid pgx.NullCid - tid pgx.NullTid - i64 pgx.NullInt64 - f32 pgx.NullFloat32 - f64 pgx.NullFloat64 - b pgx.NullBool - t pgx.NullTime - } - - var actual, zero allTypes - - tests := []struct { - sql string - queryArgs []interface{} - scanArgs []interface{} - expected allTypes - }{ - {"select $1::text", []interface{}{pgx.NullString{String: "foo", Valid: true}}, []interface{}{&actual.s}, allTypes{s: pgx.NullString{String: "foo", Valid: true}}}, - {"select $1::text", []interface{}{pgx.NullString{String: "foo", Valid: false}}, []interface{}{&actual.s}, allTypes{s: pgx.NullString{String: "", Valid: false}}}, - {"select $1::int2", []interface{}{pgx.NullInt16{Int16: 1, Valid: true}}, []interface{}{&actual.i16}, allTypes{i16: pgx.NullInt16{Int16: 1, Valid: true}}}, - {"select $1::int2", []interface{}{pgx.NullInt16{Int16: 1, Valid: false}}, []interface{}{&actual.i16}, allTypes{i16: pgx.NullInt16{Int16: 0, Valid: false}}}, - {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 1, Valid: true}}}, - {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: false}}, []interface{}{&actual.i32}, allTypes{i32: 
pgx.NullInt32{Int32: 0, Valid: false}}}, - {"select $1::oid", []interface{}{pgx.NullOid{Oid: 1, Valid: true}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOid{Oid: 1, Valid: true}}}, - {"select $1::oid", []interface{}{pgx.NullOid{Oid: 1, Valid: false}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOid{Oid: 0, Valid: false}}}, - {"select $1::oid", []interface{}{pgx.NullOid{Oid: 4294967295, Valid: true}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOid{Oid: 4294967295, Valid: true}}}, - {"select $1::xid", []interface{}{pgx.NullXid{Xid: 1, Valid: true}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 1, Valid: true}}}, - {"select $1::xid", []interface{}{pgx.NullXid{Xid: 1, Valid: false}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 0, Valid: false}}}, - {"select $1::xid", []interface{}{pgx.NullXid{Xid: 4294967295, Valid: true}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 4294967295, Valid: true}}}, - {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 1, Valid: true}}}, - {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: false}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 0, Valid: false}}}, - {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 255, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 255, Valid: true}}}, - {"select $1::name", []interface{}{pgx.NullName{Name: "foo", Valid: true}}, []interface{}{&actual.n}, allTypes{n: pgx.NullName{Name: "foo", Valid: true}}}, - {"select $1::name", []interface{}{pgx.NullName{Name: "foo", Valid: false}}, []interface{}{&actual.n}, allTypes{n: pgx.NullName{Name: "", Valid: false}}}, - {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}}}, - {"select $1::aclitem", 
[]interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: false}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "", Valid: false}}}, - // A tricky (and valid) aclitem can still be used, especially with Go's useful backticks - {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}}}, - {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 1, Valid: true}}}, - {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: false}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 0, Valid: false}}}, - {"select $1::cid", []interface{}{pgx.NullCid{Cid: 4294967295, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 4294967295, Valid: true}}}, - {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: true}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: true}}}, - {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: false}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 0, OffsetNumber: 0}, Valid: false}}}, - {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 4294967295, OffsetNumber: 65535}, Valid: true}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 4294967295, OffsetNumber: 65535}, Valid: true}}}, - {"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 1, Valid: true}}}, - {"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: false}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 0, Valid: 
false}}}, - {"select $1::float4", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: true}}, []interface{}{&actual.f32}, allTypes{f32: pgx.NullFloat32{Float32: 1.23, Valid: true}}}, - {"select $1::float4", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: false}}, []interface{}{&actual.f32}, allTypes{f32: pgx.NullFloat32{Float32: 0, Valid: false}}}, - {"select $1::float8", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.f64}, allTypes{f64: pgx.NullFloat64{Float64: 1.23, Valid: true}}}, - {"select $1::float8", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: false}}, []interface{}{&actual.f64}, allTypes{f64: pgx.NullFloat64{Float64: 0, Valid: false}}}, - {"select $1::bool", []interface{}{pgx.NullBool{Bool: true, Valid: true}}, []interface{}{&actual.b}, allTypes{b: pgx.NullBool{Bool: true, Valid: true}}}, - {"select $1::bool", []interface{}{pgx.NullBool{Bool: true, Valid: false}}, []interface{}{&actual.b}, allTypes{b: pgx.NullBool{Bool: false, Valid: false}}}, - {"select $1::timestamptz", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}}, - {"select $1::timestamptz", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}}, - {"select $1::timestamp", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}}, - {"select $1::timestamp", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}}, - {"select $1::date", []interface{}{pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.Local), Valid: 
true}}}, - {"select $1::date", []interface{}{pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.Local), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}}, - {"select 42::int4, $1::float8", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.i32, &actual.f64}, allTypes{i32: pgx.NullInt32{Int32: 42, Valid: true}, f64: pgx.NullFloat64{Float64: 1.23, Valid: true}}}, - } - - for i, tt := range tests { - actual = zero - - err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...) - if err != nil { - t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, tt.sql, tt.queryArgs) - } - - if actual != tt.expected { - t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", i, tt.expected, actual, tt.sql, tt.queryArgs) - } - - ensureConnValid(t, conn) - } -} - -func assertAclItemSlicesEqual(t *testing.T, query, scan []pgx.AclItem) { - if !reflect.DeepEqual(query, scan) { - t.Errorf("failed to encode aclitem[]\n EXPECTED: %d %v\n ACTUAL: %d %v", len(query), query, len(scan), scan) - } -} - -func TestAclArrayDecoding(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - sql := "select $1::aclitem[]" - var scan []pgx.AclItem - - tests := []struct { - query []pgx.AclItem - }{ - { - []pgx.AclItem{}, - }, - { - []pgx.AclItem{"=r/postgres"}, - }, - { - []pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres"}, - }, - { - []pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres", `postgres=arwdDxt/" tricky, ' } "" \ test user "`}, - }, - } - for i, tt := range tests { - err := conn.QueryRow(sql, tt.query).Scan(&scan) - if err != nil { - // t.Errorf(`%d. error reading array: %v`, i, err) - t.Errorf(`%d. error reading array: %v query: %s`, i, err, tt.query) - if pgerr, ok := err.(pgx.PgError); ok { - t.Errorf(`%d. 
error reading array (detail): %s`, i, pgerr.Detail) - } - continue - } - assertAclItemSlicesEqual(t, tt.query, scan) - ensureConnValid(t, conn) - } -} - func TestArrayDecoding(t *testing.T) { t.Parallel() @@ -772,14 +627,6 @@ func TestArrayDecoding(t *testing.T) { } }, }, - { - "select $1::timestamp[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{}, - func(t *testing.T, query, scan interface{}) { - if !reflect.DeepEqual(query, *(scan.(*[]time.Time))) { - t.Errorf("failed to encode time.Time[] to timestamp[]") - } - }, - }, { "select $1::timestamptz[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{}, func(t *testing.T, query, scan interface{}) { @@ -818,36 +665,6 @@ func TestArrayDecoding(t *testing.T) { } } -type shortScanner struct{} - -func (*shortScanner) Scan(r *pgx.ValueReader) error { - r.ReadByte() - return nil -} - -func TestShortScanner(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - rows, err := conn.Query("select 'ab', 'cd' union select 'cd', 'ef'") - if err != nil { - t.Error(err) - } - defer rows.Close() - - for rows.Next() { - var s1, s2 shortScanner - err = rows.Scan(&s1, &s2) - if err != nil { - t.Error(err) - } - } - - ensureConnValid(t, conn) -} - func TestEmptyArrayDecoding(t *testing.T) { t.Parallel() @@ -896,53 +713,6 @@ func TestEmptyArrayDecoding(t *testing.T) { ensureConnValid(t, conn) } -func TestNullXMismatch(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - type allTypes struct { - s pgx.NullString - i16 pgx.NullInt16 - i32 pgx.NullInt32 - i64 pgx.NullInt64 - f32 pgx.NullFloat32 - f64 pgx.NullFloat64 - b pgx.NullBool - t pgx.NullTime - } - - var actual, zero allTypes - - tests := []struct { - sql string - queryArgs []interface{} - scanArgs []interface{} - err string - }{ - {"select $1::date", []interface{}{pgx.NullString{String: "foo", Valid: true}}, 
[]interface{}{&actual.s}, "invalid input syntax for type date"}, - {"select $1::date", []interface{}{pgx.NullInt16{Int16: 1, Valid: true}}, []interface{}{&actual.i16}, "cannot encode into OID 1082"}, - {"select $1::date", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, "cannot encode into OID 1082"}, - {"select $1::date", []interface{}{pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, "cannot encode into OID 1082"}, - {"select $1::date", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: true}}, []interface{}{&actual.f32}, "cannot encode into OID 1082"}, - {"select $1::date", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.f64}, "cannot encode into OID 1082"}, - {"select $1::date", []interface{}{pgx.NullBool{Bool: true, Valid: true}}, []interface{}{&actual.b}, "cannot encode into OID 1082"}, - {"select $1::int4", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, "cannot encode into OID 23"}, - } - - for i, tt := range tests { - actual = zero - - err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...) - if err == nil || !strings.Contains(err.Error(), tt.err) { - t.Errorf(`%d. 
Expected error to contain "%s", but it didn't: %v`, i, tt.err, err) - } - - ensureConnValid(t, conn) - } -} - func TestPointerPointer(t *testing.T) { t.Parallel() @@ -1003,8 +773,6 @@ func TestPointerPointer(t *testing.T) { {"select $1::bool", []interface{}{zero.b}, []interface{}{&actual.b}, allTypes{}}, {"select $1::timestamptz", []interface{}{expected.t}, []interface{}{&actual.t}, allTypes{t: expected.t}}, {"select $1::timestamptz", []interface{}{zero.t}, []interface{}{&actual.t}, allTypes{}}, - {"select $1::timestamp", []interface{}{expected.t}, []interface{}{&actual.t}, allTypes{t: expected.t}}, - {"select $1::timestamp", []interface{}{zero.t}, []interface{}{&actual.t}, allTypes{}}, } for i, tt := range tests { @@ -1048,11 +816,11 @@ func TestEncodeTypeRename(t *testing.T) { defer closeConn(t, conn) type _int int - inInt := _int(3) + inInt := _int(1) var outInt _int type _int8 int8 - inInt8 := _int8(3) + inInt8 := _int8(2) var outInt8 _int8 type _int16 int16