mirror of https://github.com/jackc/pgx.git
parent
2508faa9ce
commit
9f6b99e332
36
conn.go
36
conn.go
|
@ -63,8 +63,8 @@ type Conn struct {
|
||||||
logLevel int
|
logLevel int
|
||||||
mr msgReader
|
mr msgReader
|
||||||
fp *fastpath
|
fp *fastpath
|
||||||
pgsql_af_inet *byte
|
pgsqlAfInet *byte
|
||||||
pgsql_af_inet6 *byte
|
pgsqlAfInet6 *byte
|
||||||
busy bool
|
busy bool
|
||||||
poolResetCount int
|
poolResetCount int
|
||||||
preallocatedRows []Rows
|
preallocatedRows []Rows
|
||||||
|
@ -145,7 +145,7 @@ func Connect(config ConnConfig) (c *Conn, err error) {
|
||||||
return connect(config, nil, nil, nil)
|
return connect(config, nil, nil, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func connect(config ConnConfig, pgTypes map[Oid]PgType, pgsql_af_inet *byte, pgsql_af_inet6 *byte) (c *Conn, err error) {
|
func connect(config ConnConfig, pgTypes map[Oid]PgType, pgsqlAfInet *byte, pgsqlAfInet6 *byte) (c *Conn, err error) {
|
||||||
c = new(Conn)
|
c = new(Conn)
|
||||||
|
|
||||||
c.config = config
|
c.config = config
|
||||||
|
@ -157,13 +157,13 @@ func connect(config ConnConfig, pgTypes map[Oid]PgType, pgsql_af_inet *byte, pgs
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if pgsql_af_inet != nil {
|
if pgsqlAfInet != nil {
|
||||||
c.pgsql_af_inet = new(byte)
|
c.pgsqlAfInet = new(byte)
|
||||||
*c.pgsql_af_inet = *pgsql_af_inet
|
*c.pgsqlAfInet = *pgsqlAfInet
|
||||||
}
|
}
|
||||||
if pgsql_af_inet6 != nil {
|
if pgsqlAfInet6 != nil {
|
||||||
c.pgsql_af_inet6 = new(byte)
|
c.pgsqlAfInet6 = new(byte)
|
||||||
*c.pgsql_af_inet6 = *pgsql_af_inet6
|
*c.pgsqlAfInet6 = *pgsqlAfInet6
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.config.LogLevel != 0 {
|
if c.config.LogLevel != 0 {
|
||||||
|
@ -315,7 +315,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.pgsql_af_inet == nil || c.pgsql_af_inet6 == nil {
|
if c.pgsqlAfInet == nil || c.pgsqlAfInet6 == nil {
|
||||||
err = c.loadInetConstants()
|
err = c.loadInetConstants()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -372,8 +372,8 @@ func (c *Conn) loadInetConstants() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.pgsql_af_inet = &ipv4[0]
|
c.pgsqlAfInet = &ipv4[0]
|
||||||
c.pgsql_af_inet6 = &ipv6[0]
|
c.pgsqlAfInet6 = &ipv6[0]
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -446,7 +446,7 @@ func ParseURI(uri string) (ConnConfig, error) {
|
||||||
return cp, nil
|
return cp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var dsn_regexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`)
|
var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`)
|
||||||
|
|
||||||
// ParseDSN parses a database DSN (data source name) into a ConnConfig
|
// ParseDSN parses a database DSN (data source name) into a ConnConfig
|
||||||
//
|
//
|
||||||
|
@ -462,7 +462,7 @@ var dsn_regexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`)
|
||||||
func ParseDSN(s string) (ConnConfig, error) {
|
func ParseDSN(s string) (ConnConfig, error) {
|
||||||
var cp ConnConfig
|
var cp ConnConfig
|
||||||
|
|
||||||
m := dsn_regexp.FindAllStringSubmatch(s, -1)
|
m := dsnRegexp.FindAllStringSubmatch(s, -1)
|
||||||
|
|
||||||
var sslmode string
|
var sslmode string
|
||||||
|
|
||||||
|
@ -477,11 +477,11 @@ func ParseDSN(s string) (ConnConfig, error) {
|
||||||
case "host":
|
case "host":
|
||||||
cp.Host = b[2]
|
cp.Host = b[2]
|
||||||
case "port":
|
case "port":
|
||||||
if p, err := strconv.ParseUint(b[2], 10, 16); err != nil {
|
p, err := strconv.ParseUint(b[2], 10, 16)
|
||||||
|
if err != nil {
|
||||||
return cp, err
|
return cp, err
|
||||||
} else {
|
|
||||||
cp.Port = uint16(p)
|
|
||||||
}
|
}
|
||||||
|
cp.Port = uint16(p)
|
||||||
case "dbname":
|
case "dbname":
|
||||||
cp.Database = b[2]
|
cp.Database = b[2]
|
||||||
case "sslmode":
|
case "sslmode":
|
||||||
|
@ -627,7 +627,7 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
|
||||||
|
|
||||||
if opts != nil {
|
if opts != nil {
|
||||||
if len(opts.ParameterOids) > 65535 {
|
if len(opts.ParameterOids) > 65535 {
|
||||||
return nil, errors.New(fmt.Sprintf("Number of PrepareExOptions ParameterOids must be between 0 and 65535, received %d", len(opts.ParameterOids)))
|
return nil, fmt.Errorf("Number of PrepareExOptions ParameterOids must be between 0 and 65535, received %d", len(opts.ParameterOids))
|
||||||
}
|
}
|
||||||
wbuf.WriteInt16(int16(len(opts.ParameterOids)))
|
wbuf.WriteInt16(int16(len(opts.ParameterOids)))
|
||||||
for _, oid := range opts.ParameterOids {
|
for _, oid := range opts.ParameterOids {
|
||||||
|
|
|
@ -11,7 +11,6 @@ var tcpConnConfig *pgx.ConnConfig = nil
|
||||||
var unixSocketConnConfig *pgx.ConnConfig = nil
|
var unixSocketConnConfig *pgx.ConnConfig = nil
|
||||||
var md5ConnConfig *pgx.ConnConfig = nil
|
var md5ConnConfig *pgx.ConnConfig = nil
|
||||||
var plainPasswordConnConfig *pgx.ConnConfig = nil
|
var plainPasswordConnConfig *pgx.ConnConfig = nil
|
||||||
var noPasswordConnConfig *pgx.ConnConfig = nil
|
|
||||||
var invalidUserConnConfig *pgx.ConnConfig = nil
|
var invalidUserConnConfig *pgx.ConnConfig = nil
|
||||||
var tlsConnConfig *pgx.ConnConfig = nil
|
var tlsConnConfig *pgx.ConnConfig = nil
|
||||||
var customDialerConnConfig *pgx.ConnConfig = nil
|
var customDialerConnConfig *pgx.ConnConfig = nil
|
||||||
|
@ -20,7 +19,6 @@ var customDialerConnConfig *pgx.ConnConfig = nil
|
||||||
// var unixSocketConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "/private/tmp", User: "pgx_none", Database: "pgx_test"}
|
// var unixSocketConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "/private/tmp", User: "pgx_none", Database: "pgx_test"}
|
||||||
// var md5ConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
// var md5ConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
||||||
// var plainPasswordConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_pw", Password: "secret", Database: "pgx_test"}
|
// var plainPasswordConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_pw", Password: "secret", Database: "pgx_test"}
|
||||||
// var noPasswordConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_none", Database: "pgx_test"}
|
|
||||||
// var invalidUserConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"}
|
// var invalidUserConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"}
|
||||||
// var tlsConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}}
|
// var tlsConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||||
// var customDialerConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
// var customDialerConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
||||||
|
|
41
conn_pool.go
41
conn_pool.go
|
@ -28,8 +28,8 @@ type ConnPool struct {
|
||||||
preparedStatements map[string]*PreparedStatement
|
preparedStatements map[string]*PreparedStatement
|
||||||
acquireTimeout time.Duration
|
acquireTimeout time.Duration
|
||||||
pgTypes map[Oid]PgType
|
pgTypes map[Oid]PgType
|
||||||
pgsql_af_inet *byte
|
pgsqlAfInet *byte
|
||||||
pgsql_af_inet6 *byte
|
pgsqlAfInet6 *byte
|
||||||
txAfterClose func(tx *Tx)
|
txAfterClose func(tx *Tx)
|
||||||
rowsAfterClose func(rows *Rows)
|
rowsAfterClose func(rows *Rows)
|
||||||
}
|
}
|
||||||
|
@ -148,26 +148,25 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) {
|
||||||
// Create a new connection.
|
// Create a new connection.
|
||||||
// Careful here: createConnectionUnlocked() removes the current lock,
|
// Careful here: createConnectionUnlocked() removes the current lock,
|
||||||
// creates a connection and then locks it back.
|
// creates a connection and then locks it back.
|
||||||
if c, err := p.createConnectionUnlocked(); err == nil {
|
c, err := p.createConnectionUnlocked()
|
||||||
c.poolResetCount = p.resetCount
|
if err != nil {
|
||||||
p.allConnections = append(p.allConnections, c)
|
|
||||||
return c, nil
|
|
||||||
} else {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
c.poolResetCount = p.resetCount
|
||||||
// All connections are in use and we cannot create more
|
p.allConnections = append(p.allConnections, c)
|
||||||
if p.logLevel >= LogLevelWarn {
|
return c, nil
|
||||||
p.logger.Warn("All connections in pool are busy - waiting...")
|
}
|
||||||
}
|
// All connections are in use and we cannot create more
|
||||||
|
if p.logLevel >= LogLevelWarn {
|
||||||
|
p.logger.Warn("All connections in pool are busy - waiting...")
|
||||||
|
}
|
||||||
|
|
||||||
// Wait until there is an available connection OR room to create a new connection
|
// Wait until there is an available connection OR room to create a new connection
|
||||||
for len(p.availableConnections) == 0 && len(p.allConnections)+p.inProgressConnects == p.maxConnections {
|
for len(p.availableConnections) == 0 && len(p.allConnections)+p.inProgressConnects == p.maxConnections {
|
||||||
if p.deadlinePassed(deadline) {
|
if p.deadlinePassed(deadline) {
|
||||||
return nil, errors.New("Timeout: All connections in pool are busy")
|
return nil, errors.New("Timeout: All connections in pool are busy")
|
||||||
}
|
|
||||||
p.cond.Wait()
|
|
||||||
}
|
}
|
||||||
|
p.cond.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the timer so that we do not spawn it on every acquire call.
|
// Stop the timer so that we do not spawn it on every acquire call.
|
||||||
|
@ -282,7 +281,7 @@ func (p *ConnPool) Stat() (s ConnPoolStat) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ConnPool) createConnection() (*Conn, error) {
|
func (p *ConnPool) createConnection() (*Conn, error) {
|
||||||
c, err := connect(p.config, p.pgTypes, p.pgsql_af_inet, p.pgsql_af_inet6)
|
c, err := connect(p.config, p.pgTypes, p.pgsqlAfInet, p.pgsqlAfInet6)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -318,8 +317,8 @@ func (p *ConnPool) createConnectionUnlocked() (*Conn, error) {
|
||||||
// all the known statements for the new connection.
|
// all the known statements for the new connection.
|
||||||
func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) {
|
func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) {
|
||||||
p.pgTypes = c.PgTypes
|
p.pgTypes = c.PgTypes
|
||||||
p.pgsql_af_inet = c.pgsql_af_inet
|
p.pgsqlAfInet = c.pgsqlAfInet
|
||||||
p.pgsql_af_inet6 = c.pgsql_af_inet6
|
p.pgsqlAfInet6 = c.pgsqlAfInet6
|
||||||
|
|
||||||
if p.afterConnect != nil {
|
if p.afterConnect != nil {
|
||||||
err := p.afterConnect(c)
|
err := p.afterConnect(c)
|
||||||
|
|
|
@ -40,7 +40,7 @@ func releaseAllConnections(pool *pgx.ConnPool, connections []*pgx.Conn) {
|
||||||
func acquireWithTimeTaken(pool *pgx.ConnPool) (*pgx.Conn, time.Duration, error) {
|
func acquireWithTimeTaken(pool *pgx.ConnPool) (*pgx.Conn, time.Duration, error) {
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
c, err := pool.Acquire()
|
c, err := pool.Acquire()
|
||||||
return c, time.Now().Sub(startTime), err
|
return c, time.Since(startTime), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewConnPool(t *testing.T) {
|
func TestNewConnPool(t *testing.T) {
|
||||||
|
@ -215,7 +215,7 @@ func TestPoolNonBlockingConnections(t *testing.T) {
|
||||||
// Prior to createConnectionUnlocked() use the test took
|
// Prior to createConnectionUnlocked() use the test took
|
||||||
// maxConnections * openTimeout seconds to complete.
|
// maxConnections * openTimeout seconds to complete.
|
||||||
// With createConnectionUnlocked() it takes ~ 1 * openTimeout seconds.
|
// With createConnectionUnlocked() it takes ~ 1 * openTimeout seconds.
|
||||||
timeTaken := time.Now().Sub(startedAt)
|
timeTaken := time.Since(startedAt)
|
||||||
if timeTaken > openTimeout+1*time.Second {
|
if timeTaken > openTimeout+1*time.Second {
|
||||||
t.Fatalf("Expected all Acquire() to run in parallel and take about %v, instead it took '%v'", openTimeout, timeTaken)
|
t.Fatalf("Expected all Acquire() to run in parallel and take about %v, instead it took '%v'", openTimeout, timeTaken)
|
||||||
}
|
}
|
||||||
|
|
|
@ -914,7 +914,7 @@ func TestPrepareQueryManyParameters(t *testing.T) {
|
||||||
args := make([]interface{}, 0, tt.count)
|
args := make([]interface{}, 0, tt.count)
|
||||||
for j := 0; j < tt.count; j++ {
|
for j := 0; j < tt.count; j++ {
|
||||||
params = append(params, fmt.Sprintf("($%d::text)", j+1))
|
params = append(params, fmt.Sprintf("($%d::text)", j+1))
|
||||||
args = append(args, strconv.FormatInt(int64(j), 10))
|
args = append(args, strconv.Itoa(j))
|
||||||
}
|
}
|
||||||
|
|
||||||
sql := "values" + strings.Join(params, ", ")
|
sql := "values" + strings.Join(params, ", ")
|
||||||
|
|
|
@ -4,8 +4,6 @@ import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
)
|
)
|
||||||
|
|
||||||
type fastpathArg []byte
|
|
||||||
|
|
||||||
func newFastpath(cn *Conn) *fastpath {
|
func newFastpath(cn *Conn) *fastpath {
|
||||||
return &fastpath{cn: cn, fns: make(map[string]Oid)}
|
return &fastpath{cn: cn, fns: make(map[string]Oid)}
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,6 @@ const (
|
||||||
hsVal
|
hsVal
|
||||||
hsNul
|
hsNul
|
||||||
hsNext
|
hsNext
|
||||||
hsEnd
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type hstoreParser struct {
|
type hstoreParser struct {
|
||||||
|
|
10
messages.go
10
messages.go
|
@ -39,10 +39,10 @@ func newStartupMessage() *startupMessage {
|
||||||
return &startupMessage{map[string]string{}}
|
return &startupMessage{map[string]string{}}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *startupMessage) Bytes() (buf []byte) {
|
func (s *startupMessage) Bytes() (buf []byte) {
|
||||||
buf = make([]byte, 8, 128)
|
buf = make([]byte, 8, 128)
|
||||||
binary.BigEndian.PutUint32(buf[4:8], uint32(protocolVersionNumber))
|
binary.BigEndian.PutUint32(buf[4:8], uint32(protocolVersionNumber))
|
||||||
for key, value := range self.options {
|
for key, value := range s.options {
|
||||||
buf = append(buf, key...)
|
buf = append(buf, key...)
|
||||||
buf = append(buf, 0)
|
buf = append(buf, 0)
|
||||||
buf = append(buf, value...)
|
buf = append(buf, value...)
|
||||||
|
@ -89,8 +89,8 @@ type PgError struct {
|
||||||
Routine string
|
Routine string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self PgError) Error() string {
|
func (pe PgError) Error() string {
|
||||||
return self.Severity + ": " + self.Message + " (SQLSTATE " + self.Code + ")"
|
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
|
||||||
}
|
}
|
||||||
|
|
||||||
func newWriteBuf(c *Conn, t byte) *WriteBuf {
|
func newWriteBuf(c *Conn, t byte) *WriteBuf {
|
||||||
|
@ -99,7 +99,7 @@ func newWriteBuf(c *Conn, t byte) *WriteBuf {
|
||||||
return &c.writeBuf
|
return &c.writeBuf
|
||||||
}
|
}
|
||||||
|
|
||||||
// WrifeBuf is used build messages to send to the PostgreSQL server. It is used
|
// WriteBuf is used build messages to send to the PostgreSQL server. It is used
|
||||||
// by the Encoder interface when implementing custom encoders.
|
// by the Encoder interface when implementing custom encoders.
|
||||||
type WriteBuf struct {
|
type WriteBuf struct {
|
||||||
buf []byte
|
buf []byte
|
||||||
|
|
|
@ -62,7 +62,7 @@ func (r *msgReader) readByte() byte {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
r.msgBytesRemaining -= 1
|
r.msgBytesRemaining--
|
||||||
if r.msgBytesRemaining < 0 {
|
if r.msgBytesRemaining < 0 {
|
||||||
r.fatal(errors.New("read past end of message"))
|
r.fatal(errors.New("read past end of message"))
|
||||||
return 0
|
return 0
|
||||||
|
@ -216,7 +216,7 @@ func (r *msgReader) readString(countI32 int32) string {
|
||||||
s = string(buf)
|
s = string(buf)
|
||||||
r.reader.Discard(count)
|
r.reader.Discard(count)
|
||||||
} else {
|
} else {
|
||||||
buf := make([]byte, int(count))
|
buf := make([]byte, count)
|
||||||
_, err := io.ReadFull(r.reader, buf)
|
_, err := io.ReadFull(r.reader, buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.fatal(err)
|
r.fatal(err)
|
||||||
|
|
|
@ -3,11 +3,12 @@ package pgx_test
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"github.com/jackc/pgx"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx"
|
||||||
|
|
||||||
"github.com/shopspring/decimal"
|
"github.com/shopspring/decimal"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -784,7 +785,7 @@ func TestQueryRowCoreByteSlice(t *testing.T) {
|
||||||
t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql)
|
t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql)
|
||||||
}
|
}
|
||||||
|
|
||||||
if bytes.Compare(actual, tt.expected) != 0 {
|
if !bytes.Equal(actual, tt.expected) {
|
||||||
t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.expected, actual, tt.sql)
|
t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.expected, actual, tt.sql)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
4
sql.go
4
sql.go
|
@ -14,7 +14,7 @@ func init() {
|
||||||
placeholders = make([]string, 64)
|
placeholders = make([]string, 64)
|
||||||
|
|
||||||
for i := 1; i < 64; i++ {
|
for i := 1; i < 64; i++ {
|
||||||
placeholders[i] = "$" + strconv.FormatInt(int64(i), 10)
|
placeholders[i] = "$" + strconv.Itoa(i)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -25,5 +25,5 @@ func (qa *QueryArgs) Append(v interface{}) string {
|
||||||
if len(*qa) < len(placeholders) {
|
if len(*qa) < len(placeholders) {
|
||||||
return placeholders[len(*qa)]
|
return placeholders[len(*qa)]
|
||||||
}
|
}
|
||||||
return "$" + strconv.FormatInt(int64(len(*qa)), 10)
|
return "$" + strconv.Itoa(len(*qa))
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,16 +1,17 @@
|
||||||
package pgx_test
|
package pgx_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/jackc/pgx"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestQueryArgs(t *testing.T) {
|
func TestQueryArgs(t *testing.T) {
|
||||||
var qa pgx.QueryArgs
|
var qa pgx.QueryArgs
|
||||||
|
|
||||||
for i := 1; i < 512; i++ {
|
for i := 1; i < 512; i++ {
|
||||||
expectedPlaceholder := "$" + strconv.FormatInt(int64(i), 10)
|
expectedPlaceholder := "$" + strconv.Itoa(i)
|
||||||
placeholder := qa.Append(i)
|
placeholder := qa.Append(i)
|
||||||
if placeholder != expectedPlaceholder {
|
if placeholder != expectedPlaceholder {
|
||||||
t.Errorf(`Expected qa.Append to return "%s", but it returned "%s"`, expectedPlaceholder, placeholder)
|
t.Errorf(`Expected qa.Append to return "%s", but it returned "%s"`, expectedPlaceholder, placeholder)
|
||||||
|
|
36
values.go
36
values.go
|
@ -225,28 +225,28 @@ type NullString struct {
|
||||||
Valid bool // Valid is true if String is not NULL
|
Valid bool // Valid is true if String is not NULL
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *NullString) Scan(vr *ValueReader) error {
|
func (n *NullString) Scan(vr *ValueReader) error {
|
||||||
// Not checking oid as so we can scan anything into into a NullString - may revisit this decision later
|
// Not checking oid as so we can scan anything into into a NullString - may revisit this decision later
|
||||||
|
|
||||||
if vr.Len() == -1 {
|
if vr.Len() == -1 {
|
||||||
s.String, s.Valid = "", false
|
n.String, n.Valid = "", false
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
s.Valid = true
|
n.Valid = true
|
||||||
s.String = decodeText(vr)
|
n.String = decodeText(vr)
|
||||||
return vr.Err()
|
return vr.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n NullString) FormatCode() int16 { return TextFormatCode }
|
func (n NullString) FormatCode() int16 { return TextFormatCode }
|
||||||
|
|
||||||
func (s NullString) Encode(w *WriteBuf, oid Oid) error {
|
func (n NullString) Encode(w *WriteBuf, oid Oid) error {
|
||||||
if !s.Valid {
|
if !n.Valid {
|
||||||
w.WriteInt32(-1)
|
w.WriteInt32(-1)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return encodeString(w, oid, s.String)
|
return encodeString(w, oid, n.String)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NullInt16 represents an smallint that may be null. NullInt16 implements the
|
// NullInt16 represents an smallint that may be null. NullInt16 implements the
|
||||||
|
@ -621,10 +621,9 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error {
|
||||||
if refVal.IsNil() {
|
if refVal.IsNil() {
|
||||||
wbuf.WriteInt32(-1)
|
wbuf.WriteInt32(-1)
|
||||||
return nil
|
return nil
|
||||||
} else {
|
|
||||||
arg = refVal.Elem().Interface()
|
|
||||||
return Encode(wbuf, oid, arg)
|
|
||||||
}
|
}
|
||||||
|
arg = refVal.Elem().Interface()
|
||||||
|
return Encode(wbuf, oid, arg)
|
||||||
}
|
}
|
||||||
|
|
||||||
if oid == JsonOid || oid == JsonbOid {
|
if oid == JsonOid || oid == JsonbOid {
|
||||||
|
@ -892,14 +891,13 @@ func Decode(vr *ValueReader, d interface{}) error {
|
||||||
el.Set(reflect.Zero(el.Type()))
|
el.Set(reflect.Zero(el.Type()))
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
} else {
|
|
||||||
if el.IsNil() {
|
|
||||||
// allocate destination
|
|
||||||
el.Set(reflect.New(el.Type().Elem()))
|
|
||||||
}
|
|
||||||
d = el.Interface()
|
|
||||||
return Decode(vr, d)
|
|
||||||
}
|
}
|
||||||
|
if el.IsNil() {
|
||||||
|
// allocate destination
|
||||||
|
el.Set(reflect.New(el.Type().Elem()))
|
||||||
|
}
|
||||||
|
d = el.Interface()
|
||||||
|
return Decode(vr, d)
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
n := decodeInt(vr)
|
n := decodeInt(vr)
|
||||||
if el.OverflowInt(n) {
|
if el.OverflowInt(n) {
|
||||||
|
@ -1645,10 +1643,10 @@ func encodeIPNet(w *WriteBuf, oid Oid, value net.IPNet) error {
|
||||||
switch len(value.IP) {
|
switch len(value.IP) {
|
||||||
case net.IPv4len:
|
case net.IPv4len:
|
||||||
size = 8
|
size = 8
|
||||||
family = *w.conn.pgsql_af_inet
|
family = *w.conn.pgsqlAfInet
|
||||||
case net.IPv6len:
|
case net.IPv6len:
|
||||||
size = 20
|
size = 20
|
||||||
family = *w.conn.pgsql_af_inet6
|
family = *w.conn.pgsqlAfInet6
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("Unexpected IP length: %v", len(value.IP))
|
return fmt.Errorf("Unexpected IP length: %v", len(value.IP))
|
||||||
}
|
}
|
||||||
|
|
|
@ -630,7 +630,7 @@ func TestArrayDecoding(t *testing.T) {
|
||||||
{
|
{
|
||||||
"select $1::bool[]", []bool{true, false, true}, &[]bool{},
|
"select $1::bool[]", []bool{true, false, true}, &[]bool{},
|
||||||
func(t *testing.T, query, scan interface{}) {
|
func(t *testing.T, query, scan interface{}) {
|
||||||
if reflect.DeepEqual(query, *(scan.(*[]bool))) == false {
|
if !reflect.DeepEqual(query, *(scan.(*[]bool))) {
|
||||||
t.Errorf("failed to encode bool[]")
|
t.Errorf("failed to encode bool[]")
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -638,7 +638,7 @@ func TestArrayDecoding(t *testing.T) {
|
||||||
{
|
{
|
||||||
"select $1::smallint[]", []int16{2, 4, 484, 32767}, &[]int16{},
|
"select $1::smallint[]", []int16{2, 4, 484, 32767}, &[]int16{},
|
||||||
func(t *testing.T, query, scan interface{}) {
|
func(t *testing.T, query, scan interface{}) {
|
||||||
if reflect.DeepEqual(query, *(scan.(*[]int16))) == false {
|
if !reflect.DeepEqual(query, *(scan.(*[]int16))) {
|
||||||
t.Errorf("failed to encode smallint[]")
|
t.Errorf("failed to encode smallint[]")
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -646,7 +646,7 @@ func TestArrayDecoding(t *testing.T) {
|
||||||
{
|
{
|
||||||
"select $1::smallint[]", []uint16{2, 4, 484, 32767}, &[]uint16{},
|
"select $1::smallint[]", []uint16{2, 4, 484, 32767}, &[]uint16{},
|
||||||
func(t *testing.T, query, scan interface{}) {
|
func(t *testing.T, query, scan interface{}) {
|
||||||
if reflect.DeepEqual(query, *(scan.(*[]uint16))) == false {
|
if !reflect.DeepEqual(query, *(scan.(*[]uint16))) {
|
||||||
t.Errorf("failed to encode smallint[]")
|
t.Errorf("failed to encode smallint[]")
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -654,7 +654,7 @@ func TestArrayDecoding(t *testing.T) {
|
||||||
{
|
{
|
||||||
"select $1::int[]", []int32{2, 4, 484}, &[]int32{},
|
"select $1::int[]", []int32{2, 4, 484}, &[]int32{},
|
||||||
func(t *testing.T, query, scan interface{}) {
|
func(t *testing.T, query, scan interface{}) {
|
||||||
if reflect.DeepEqual(query, *(scan.(*[]int32))) == false {
|
if !reflect.DeepEqual(query, *(scan.(*[]int32))) {
|
||||||
t.Errorf("failed to encode int[]")
|
t.Errorf("failed to encode int[]")
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -662,7 +662,7 @@ func TestArrayDecoding(t *testing.T) {
|
||||||
{
|
{
|
||||||
"select $1::int[]", []uint32{2, 4, 484, 2147483647}, &[]uint32{},
|
"select $1::int[]", []uint32{2, 4, 484, 2147483647}, &[]uint32{},
|
||||||
func(t *testing.T, query, scan interface{}) {
|
func(t *testing.T, query, scan interface{}) {
|
||||||
if reflect.DeepEqual(query, *(scan.(*[]uint32))) == false {
|
if !reflect.DeepEqual(query, *(scan.(*[]uint32))) {
|
||||||
t.Errorf("failed to encode int[]")
|
t.Errorf("failed to encode int[]")
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -670,7 +670,7 @@ func TestArrayDecoding(t *testing.T) {
|
||||||
{
|
{
|
||||||
"select $1::bigint[]", []int64{2, 4, 484, 9223372036854775807}, &[]int64{},
|
"select $1::bigint[]", []int64{2, 4, 484, 9223372036854775807}, &[]int64{},
|
||||||
func(t *testing.T, query, scan interface{}) {
|
func(t *testing.T, query, scan interface{}) {
|
||||||
if reflect.DeepEqual(query, *(scan.(*[]int64))) == false {
|
if !reflect.DeepEqual(query, *(scan.(*[]int64))) {
|
||||||
t.Errorf("failed to encode bigint[]")
|
t.Errorf("failed to encode bigint[]")
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -678,7 +678,7 @@ func TestArrayDecoding(t *testing.T) {
|
||||||
{
|
{
|
||||||
"select $1::bigint[]", []uint64{2, 4, 484, 9223372036854775807}, &[]uint64{},
|
"select $1::bigint[]", []uint64{2, 4, 484, 9223372036854775807}, &[]uint64{},
|
||||||
func(t *testing.T, query, scan interface{}) {
|
func(t *testing.T, query, scan interface{}) {
|
||||||
if reflect.DeepEqual(query, *(scan.(*[]uint64))) == false {
|
if !reflect.DeepEqual(query, *(scan.(*[]uint64))) {
|
||||||
t.Errorf("failed to encode bigint[]")
|
t.Errorf("failed to encode bigint[]")
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -686,7 +686,7 @@ func TestArrayDecoding(t *testing.T) {
|
||||||
{
|
{
|
||||||
"select $1::text[]", []string{"it's", "over", "9000!"}, &[]string{},
|
"select $1::text[]", []string{"it's", "over", "9000!"}, &[]string{},
|
||||||
func(t *testing.T, query, scan interface{}) {
|
func(t *testing.T, query, scan interface{}) {
|
||||||
if reflect.DeepEqual(query, *(scan.(*[]string))) == false {
|
if !reflect.DeepEqual(query, *(scan.(*[]string))) {
|
||||||
t.Errorf("failed to encode text[]")
|
t.Errorf("failed to encode text[]")
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -694,7 +694,7 @@ func TestArrayDecoding(t *testing.T) {
|
||||||
{
|
{
|
||||||
"select $1::timestamp[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{},
|
"select $1::timestamp[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{},
|
||||||
func(t *testing.T, query, scan interface{}) {
|
func(t *testing.T, query, scan interface{}) {
|
||||||
if reflect.DeepEqual(query, *(scan.(*[]time.Time))) == false {
|
if !reflect.DeepEqual(query, *(scan.(*[]time.Time))) {
|
||||||
t.Errorf("failed to encode time.Time[] to timestamp[]")
|
t.Errorf("failed to encode time.Time[] to timestamp[]")
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -702,7 +702,7 @@ func TestArrayDecoding(t *testing.T) {
|
||||||
{
|
{
|
||||||
"select $1::timestamptz[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{},
|
"select $1::timestamptz[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{},
|
||||||
func(t *testing.T, query, scan interface{}) {
|
func(t *testing.T, query, scan interface{}) {
|
||||||
if reflect.DeepEqual(query, *(scan.(*[]time.Time))) == false {
|
if !reflect.DeepEqual(query, *(scan.(*[]time.Time))) {
|
||||||
t.Errorf("failed to encode time.Time[] to timestamptz[]")
|
t.Errorf("failed to encode time.Time[] to timestamptz[]")
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -718,7 +718,7 @@ func TestArrayDecoding(t *testing.T) {
|
||||||
for i := range queryBytesSliceSlice {
|
for i := range queryBytesSliceSlice {
|
||||||
qb := queryBytesSliceSlice[i]
|
qb := queryBytesSliceSlice[i]
|
||||||
sb := scanBytesSliceSlice[i]
|
sb := scanBytesSliceSlice[i]
|
||||||
if bytes.Compare(qb, sb) != 0 {
|
if !bytes.Equal(qb, sb) {
|
||||||
t.Errorf("failed to encode byte[][] to bytea[]: expected %v to equal %v", qb, sb)
|
t.Errorf("failed to encode byte[][] to bytea[]: expected %v to equal %v", qb, sb)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue