diff --git a/conn.go b/conn.go index c2519003..c928ed98 100644 --- a/conn.go +++ b/conn.go @@ -63,8 +63,8 @@ type Conn struct { logLevel int mr msgReader fp *fastpath - pgsql_af_inet *byte - pgsql_af_inet6 *byte + pgsqlAfInet *byte + pgsqlAfInet6 *byte busy bool poolResetCount int preallocatedRows []Rows @@ -145,7 +145,7 @@ func Connect(config ConnConfig) (c *Conn, err error) { return connect(config, nil, nil, nil) } -func connect(config ConnConfig, pgTypes map[Oid]PgType, pgsql_af_inet *byte, pgsql_af_inet6 *byte) (c *Conn, err error) { +func connect(config ConnConfig, pgTypes map[Oid]PgType, pgsqlAfInet *byte, pgsqlAfInet6 *byte) (c *Conn, err error) { c = new(Conn) c.config = config @@ -157,13 +157,13 @@ func connect(config ConnConfig, pgTypes map[Oid]PgType, pgsql_af_inet *byte, pgs } } - if pgsql_af_inet != nil { - c.pgsql_af_inet = new(byte) - *c.pgsql_af_inet = *pgsql_af_inet + if pgsqlAfInet != nil { + c.pgsqlAfInet = new(byte) + *c.pgsqlAfInet = *pgsqlAfInet } - if pgsql_af_inet6 != nil { - c.pgsql_af_inet6 = new(byte) - *c.pgsql_af_inet6 = *pgsql_af_inet6 + if pgsqlAfInet6 != nil { + c.pgsqlAfInet6 = new(byte) + *c.pgsqlAfInet6 = *pgsqlAfInet6 } if c.config.LogLevel != 0 { @@ -315,7 +315,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl } } - if c.pgsql_af_inet == nil || c.pgsql_af_inet6 == nil { + if c.pgsqlAfInet == nil || c.pgsqlAfInet6 == nil { err = c.loadInetConstants() if err != nil { return err @@ -372,8 +372,8 @@ func (c *Conn) loadInetConstants() error { return err } - c.pgsql_af_inet = &ipv4[0] - c.pgsql_af_inet6 = &ipv6[0] + c.pgsqlAfInet = &ipv4[0] + c.pgsqlAfInet6 = &ipv6[0] return nil } @@ -446,7 +446,7 @@ func ParseURI(uri string) (ConnConfig, error) { return cp, nil } -var dsn_regexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`) +var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`) // ParseDSN parses a database DSN (data source name) into a ConnConfig // @@ -462,7 +462,7 @@ var dsn_regexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`) func ParseDSN(s string) (ConnConfig, error) { var cp ConnConfig - m := dsn_regexp.FindAllStringSubmatch(s, -1) + m := dsnRegexp.FindAllStringSubmatch(s, -1) var sslmode string @@ -477,11 +477,11 @@ func ParseDSN(s string) (ConnConfig, error) { case "host": cp.Host = b[2] case "port": - if p, err := strconv.ParseUint(b[2], 10, 16); err != nil { + p, err := strconv.ParseUint(b[2], 10, 16) + if err != nil { return cp, err - } else { - cp.Port = uint16(p) } + cp.Port = uint16(p) case "dbname": cp.Database = b[2] case "sslmode": @@ -627,7 +627,7 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared if opts != nil { if len(opts.ParameterOids) > 65535 { - return nil, errors.New(fmt.Sprintf("Number of PrepareExOptions ParameterOids must be between 0 and 65535, received %d", len(opts.ParameterOids))) + return nil, fmt.Errorf("Number of PrepareExOptions ParameterOids must be between 0 and 65535, received %d", len(opts.ParameterOids)) } wbuf.WriteInt16(int16(len(opts.ParameterOids))) for _, oid := range opts.ParameterOids { diff --git a/conn_config_test.go.example b/conn_config_test.go.example index 358e0247..0b80d490 100644 --- a/conn_config_test.go.example +++ b/conn_config_test.go.example @@ -11,7 +11,6 @@ var tcpConnConfig *pgx.ConnConfig = nil var unixSocketConnConfig *pgx.ConnConfig = nil var md5ConnConfig *pgx.ConnConfig = nil var plainPasswordConnConfig *pgx.ConnConfig = nil -var noPasswordConnConfig *pgx.ConnConfig = nil var invalidUserConnConfig *pgx.ConnConfig = nil var tlsConnConfig *pgx.ConnConfig = nil var customDialerConnConfig *pgx.ConnConfig = nil @@ -20,7 +19,6 @@ var customDialerConnConfig *pgx.ConnConfig = nil // var unixSocketConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "/private/tmp", User: "pgx_none", Database: "pgx_test"} // var md5ConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} // var plainPasswordConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_pw", Password: "secret", Database: "pgx_test"} -// var noPasswordConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_none", Database: "pgx_test"} // var invalidUserConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"} // var tlsConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}} // var customDialerConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} diff --git a/conn_pool.go b/conn_pool.go index fdd54114..6fbe143a 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -28,8 +28,8 @@ type ConnPool struct { preparedStatements map[string]*PreparedStatement acquireTimeout time.Duration pgTypes map[Oid]PgType - pgsql_af_inet *byte - pgsql_af_inet6 *byte + pgsqlAfInet *byte + pgsqlAfInet6 *byte txAfterClose func(tx *Tx) rowsAfterClose func(rows *Rows) } @@ -148,26 +148,25 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) { // Create a new connection. // Careful here: createConnectionUnlocked() removes the current lock, // creates a connection and then locks it back. - if c, err := p.createConnectionUnlocked(); err == nil { - c.poolResetCount = p.resetCount - p.allConnections = append(p.allConnections, c) - return c, nil - } else { + c, err := p.createConnectionUnlocked() + if err != nil { return nil, err } - } else { - // All connections are in use and we cannot create more - if p.logLevel >= LogLevelWarn { - p.logger.Warn("All connections in pool are busy - waiting...") - } + c.poolResetCount = p.resetCount + p.allConnections = append(p.allConnections, c) + return c, nil + } + // All connections are in use and we cannot create more + if p.logLevel >= LogLevelWarn { + p.logger.Warn("All connections in pool are busy - waiting...") + } - // Wait until there is an available connection OR room to create a new connection - for len(p.availableConnections) == 0 && len(p.allConnections)+p.inProgressConnects == p.maxConnections { - if p.deadlinePassed(deadline) { - return nil, errors.New("Timeout: All connections in pool are busy") - } - p.cond.Wait() + // Wait until there is an available connection OR room to create a new connection + for len(p.availableConnections) == 0 && len(p.allConnections)+p.inProgressConnects == p.maxConnections { + if p.deadlinePassed(deadline) { + return nil, errors.New("Timeout: All connections in pool are busy") } + p.cond.Wait() } // Stop the timer so that we do not spawn it on every acquire call. @@ -282,7 +281,7 @@ func (p *ConnPool) Stat() (s ConnPoolStat) { } func (p *ConnPool) createConnection() (*Conn, error) { - c, err := connect(p.config, p.pgTypes, p.pgsql_af_inet, p.pgsql_af_inet6) + c, err := connect(p.config, p.pgTypes, p.pgsqlAfInet, p.pgsqlAfInet6) if err != nil { return nil, err } @@ -318,8 +317,8 @@ func (p *ConnPool) createConnectionUnlocked() (*Conn, error) { // all the known statements for the new connection. func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) { p.pgTypes = c.PgTypes - p.pgsql_af_inet = c.pgsql_af_inet - p.pgsql_af_inet6 = c.pgsql_af_inet6 + p.pgsqlAfInet = c.pgsqlAfInet + p.pgsqlAfInet6 = c.pgsqlAfInet6 if p.afterConnect != nil { err := p.afterConnect(c) diff --git a/conn_pool_test.go b/conn_pool_test.go index 9aa31758..e3ae0036 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -40,7 +40,7 @@ func releaseAllConnections(pool *pgx.ConnPool, connections []*pgx.Conn) { func acquireWithTimeTaken(pool *pgx.ConnPool) (*pgx.Conn, time.Duration, error) { startTime := time.Now() c, err := pool.Acquire() - return c, time.Now().Sub(startTime), err + return c, time.Since(startTime), err } func TestNewConnPool(t *testing.T) { @@ -215,7 +215,7 @@ func TestPoolNonBlockingConnections(t *testing.T) { // Prior to createConnectionUnlocked() use the test took // maxConnections * openTimeout seconds to complete. // With createConnectionUnlocked() it takes ~ 1 * openTimeout seconds. - timeTaken := time.Now().Sub(startedAt) + timeTaken := time.Since(startedAt) if timeTaken > openTimeout+1*time.Second { t.Fatalf("Expected all Acquire() to run in parallel and take about %v, instead it took '%v'", openTimeout, timeTaken) } diff --git a/conn_test.go b/conn_test.go index 181a3ed2..9ed073ce 100644 --- a/conn_test.go +++ b/conn_test.go @@ -914,7 +914,7 @@ func TestPrepareQueryManyParameters(t *testing.T) { args := make([]interface{}, 0, tt.count) for j := 0; j < tt.count; j++ { params = append(params, fmt.Sprintf("($%d::text)", j+1)) - args = append(args, strconv.FormatInt(int64(j), 10)) + args = append(args, strconv.Itoa(j)) } sql := "values" + strings.Join(params, ", ") diff --git a/fastpath.go b/fastpath.go index 8814e559..19b98784 100644 --- a/fastpath.go +++ b/fastpath.go @@ -4,8 +4,6 @@ import ( "encoding/binary" ) -type fastpathArg []byte - func newFastpath(cn *Conn) *fastpath { return &fastpath{cn: cn, fns: make(map[string]Oid)} } diff --git a/hstore.go b/hstore.go index a5d40cce..0ab9f779 100644 --- a/hstore.go +++ b/hstore.go @@ -15,7 +15,6 @@ const ( hsVal hsNul hsNext - hsEnd ) type hstoreParser struct { diff --git a/messages.go b/messages.go index 7f04f1f2..db0258de 100644 --- a/messages.go +++ b/messages.go @@ -39,10 +39,10 @@ func newStartupMessage() *startupMessage { return &startupMessage{map[string]string{}} } -func (self *startupMessage) Bytes() (buf []byte) { +func (s *startupMessage) Bytes() (buf []byte) { buf = make([]byte, 8, 128) binary.BigEndian.PutUint32(buf[4:8], uint32(protocolVersionNumber)) - for key, value := range self.options { + for key, value := range s.options { buf = append(buf, key...) buf = append(buf, 0) buf = append(buf, value...) @@ -89,8 +89,8 @@ type PgError struct { Routine string } -func (self PgError) Error() string { - return self.Severity + ": " + self.Message + " (SQLSTATE " + self.Code + ")" +func (pe PgError) Error() string { + return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" } func newWriteBuf(c *Conn, t byte) *WriteBuf { @@ -99,7 +99,7 @@ func newWriteBuf(c *Conn, t byte) *WriteBuf { return &c.writeBuf } -// WrifeBuf is used build messages to send to the PostgreSQL server. It is used +// WriteBuf is used build messages to send to the PostgreSQL server. It is used // by the Encoder interface when implementing custom encoders. type WriteBuf struct { buf []byte diff --git a/msg_reader.go b/msg_reader.go index 069094cd..c8869bdd 100644 --- a/msg_reader.go +++ b/msg_reader.go @@ -62,7 +62,7 @@ func (r *msgReader) readByte() byte { return 0 } - r.msgBytesRemaining -= 1 + r.msgBytesRemaining-- if r.msgBytesRemaining < 0 { r.fatal(errors.New("read past end of message")) return 0 @@ -216,7 +216,7 @@ func (r *msgReader) readString(countI32 int32) string { s = string(buf) r.reader.Discard(count) } else { - buf := make([]byte, int(count)) + buf := make([]byte, count) _, err := io.ReadFull(r.reader, buf) if err != nil { r.fatal(err) diff --git a/query_test.go b/query_test.go index 2cf8b3cd..06a18ffe 100644 --- a/query_test.go +++ b/query_test.go @@ -3,11 +3,12 @@ package pgx_test import ( "bytes" "database/sql" - "github.com/jackc/pgx" "strings" "testing" "time" + "github.com/jackc/pgx" + "github.com/shopspring/decimal" ) @@ -784,7 +785,7 @@ func TestQueryRowCoreByteSlice(t *testing.T) { t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql) } - if bytes.Compare(actual, tt.expected) != 0 { + if !bytes.Equal(actual, tt.expected) { t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.expected, actual, tt.sql) } diff --git a/sql.go b/sql.go index 9445c263..7ee0f2a0 100644 --- a/sql.go +++ b/sql.go @@ -14,7 +14,7 @@ func init() { placeholders = make([]string, 64) for i := 1; i < 64; i++ { - placeholders[i] = "$" + strconv.FormatInt(int64(i), 10) + placeholders[i] = "$" + strconv.Itoa(i) } } @@ -25,5 +25,5 @@ func (qa *QueryArgs) Append(v interface{}) string { if len(*qa) < len(placeholders) { return placeholders[len(*qa)] } - return "$" + strconv.FormatInt(int64(len(*qa)), 10) + return "$" + strconv.Itoa(len(*qa)) } diff --git a/sql_test.go b/sql_test.go index eafd92fa..dd036035 100644 --- a/sql_test.go +++ b/sql_test.go @@ -1,16 +1,17 @@ package pgx_test import ( - "github.com/jackc/pgx" "strconv" "testing" + + "github.com/jackc/pgx" ) func TestQueryArgs(t *testing.T) { var qa pgx.QueryArgs for i := 1; i < 512; i++ { - expectedPlaceholder := "$" + strconv.FormatInt(int64(i), 10) + expectedPlaceholder := "$" + strconv.Itoa(i) placeholder := qa.Append(i) if placeholder != expectedPlaceholder { t.Errorf(`Expected qa.Append to return "%s", but it returned "%s"`, expectedPlaceholder, placeholder) diff --git a/values.go b/values.go index b6e0a84b..ee43813b 100644 --- a/values.go +++ b/values.go @@ -225,28 +225,28 @@ type NullString struct { Valid bool // Valid is true if String is not NULL } -func (s *NullString) Scan(vr *ValueReader) error { +func (n *NullString) Scan(vr *ValueReader) error { // Not checking oid as so we can scan anything into into a NullString - may revisit this decision later if vr.Len() == -1 { - s.String, s.Valid = "", false + n.String, n.Valid = "", false return nil } - s.Valid = true - s.String = decodeText(vr) + n.Valid = true + n.String = decodeText(vr) return vr.Err() } func (n NullString) FormatCode() int16 { return TextFormatCode } -func (s NullString) Encode(w *WriteBuf, oid Oid) error { - if !s.Valid { +func (n NullString) Encode(w *WriteBuf, oid Oid) error { + if !n.Valid { w.WriteInt32(-1) return nil } - return encodeString(w, oid, s.String) + return encodeString(w, oid, n.String) } // NullInt16 represents an smallint that may be null. NullInt16 implements the @@ -621,10 +621,9 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { if refVal.IsNil() { wbuf.WriteInt32(-1) return nil - } else { - arg = refVal.Elem().Interface() - return Encode(wbuf, oid, arg) } + arg = refVal.Elem().Interface() + return Encode(wbuf, oid, arg) } if oid == JsonOid || oid == JsonbOid { @@ -892,14 +891,13 @@ func Decode(vr *ValueReader, d interface{}) error { el.Set(reflect.Zero(el.Type())) } return nil - } else { - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - d = el.Interface() - return Decode(vr, d) } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + d = el.Interface() + return Decode(vr, d) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: n := decodeInt(vr) if el.OverflowInt(n) { @@ -1645,10 +1643,10 @@ func encodeIPNet(w *WriteBuf, oid Oid, value net.IPNet) error { switch len(value.IP) { case net.IPv4len: size = 8 - family = *w.conn.pgsql_af_inet + family = *w.conn.pgsqlAfInet case net.IPv6len: size = 20 - family = *w.conn.pgsql_af_inet6 + family = *w.conn.pgsqlAfInet6 default: return fmt.Errorf("Unexpected IP length: %v", len(value.IP)) } diff --git a/values_test.go b/values_test.go index 063598d9..7a690055 100644 --- a/values_test.go +++ b/values_test.go @@ -630,7 +630,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::bool[]", []bool{true, false, true}, &[]bool{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]bool))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]bool))) { t.Errorf("failed to encode bool[]") } }, @@ -638,7 +638,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::smallint[]", []int16{2, 4, 484, 32767}, &[]int16{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]int16))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]int16))) { t.Errorf("failed to encode smallint[]") } }, @@ -646,7 +646,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::smallint[]", []uint16{2, 4, 484, 32767}, &[]uint16{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]uint16))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]uint16))) { t.Errorf("failed to encode smallint[]") } }, @@ -654,7 +654,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::int[]", []int32{2, 4, 484}, &[]int32{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]int32))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]int32))) { t.Errorf("failed to encode int[]") } }, @@ -662,7 +662,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::int[]", []uint32{2, 4, 484, 2147483647}, &[]uint32{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]uint32))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]uint32))) { t.Errorf("failed to encode int[]") } }, @@ -670,7 +670,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::bigint[]", []int64{2, 4, 484, 9223372036854775807}, &[]int64{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]int64))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]int64))) { t.Errorf("failed to encode bigint[]") } }, @@ -678,7 +678,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::bigint[]", []uint64{2, 4, 484, 9223372036854775807}, &[]uint64{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]uint64))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]uint64))) { t.Errorf("failed to encode bigint[]") } }, @@ -686,7 +686,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::text[]", []string{"it's", "over", "9000!"}, &[]string{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]string))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]string))) { t.Errorf("failed to encode text[]") } }, @@ -694,7 +694,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::timestamp[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]time.Time))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]time.Time))) { t.Errorf("failed to encode time.Time[] to timestamp[]") } }, @@ -702,7 +702,7 @@ func TestArrayDecoding(t *testing.T) { { "select $1::timestamptz[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{}, func(t *testing.T, query, scan interface{}) { - if reflect.DeepEqual(query, *(scan.(*[]time.Time))) == false { + if !reflect.DeepEqual(query, *(scan.(*[]time.Time))) { t.Errorf("failed to encode time.Time[] to timestamptz[]") } }, @@ -718,7 +718,7 @@ func TestArrayDecoding(t *testing.T) { for i := range queryBytesSliceSlice { qb := queryBytesSliceSlice[i] sb := scanBytesSliceSlice[i] - if bytes.Compare(qb, sb) != 0 { + if !bytes.Equal(qb, sb) { t.Errorf("failed to encode byte[][] to bytea[]: expected %v to equal %v", qb, sb) } }