mirror of https://github.com/jackc/pgx.git
minimal changes for pgbouncer
parent
4adfd1ca2f
commit
7471e7f9eb
14
conn.go
14
conn.go
|
@ -469,10 +469,12 @@ where t.typtype = 'b'
|
||||||
}
|
}
|
||||||
|
|
||||||
for name, oid := range nameOIDs {
|
for name, oid := range nameOIDs {
|
||||||
|
v := &pgtype.EnumArray{}
|
||||||
c.ConnInfo.RegisterDataType(pgtype.DataType{
|
c.ConnInfo.RegisterDataType(pgtype.DataType{
|
||||||
&pgtype.EnumArray{},
|
Value: v,
|
||||||
name,
|
Name: name,
|
||||||
oid,
|
OID: oid,
|
||||||
|
FormatCode: pgtype.DetermineFormatCode(v),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -942,11 +944,7 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
|
||||||
for i := range ps.FieldDescriptions {
|
for i := range ps.FieldDescriptions {
|
||||||
if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok {
|
if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok {
|
||||||
ps.FieldDescriptions[i].DataTypeName = dt.Name
|
ps.FieldDescriptions[i].DataTypeName = dt.Name
|
||||||
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
|
ps.FieldDescriptions[i].FormatCode = dt.FormatCode
|
||||||
ps.FieldDescriptions[i].FormatCode = BinaryFormatCode
|
|
||||||
} else {
|
|
||||||
ps.FieldDescriptions[i].FormatCode = TextFormatCode
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
return nil, errors.Errorf("unknown oid: %d", ps.FieldDescriptions[i].DataType)
|
return nil, errors.Errorf("unknown oid: %d", ps.FieldDescriptions[i].DataType)
|
||||||
}
|
}
|
||||||
|
|
|
@ -53,6 +53,12 @@ const (
|
||||||
JSONBOID = 3802
|
JSONBOID = 3802
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// PostgreSQL format codes
|
||||||
|
const (
|
||||||
|
textFormatCode int16 = 0
|
||||||
|
binaryFormatCode = 1
|
||||||
|
)
|
||||||
|
|
||||||
type Status byte
|
type Status byte
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -134,9 +140,10 @@ var errUndefined = errors.New("cannot encode status undefined")
|
||||||
var errBadStatus = errors.New("invalid status")
|
var errBadStatus = errors.New("invalid status")
|
||||||
|
|
||||||
type DataType struct {
|
type DataType struct {
|
||||||
Value Value
|
Value Value
|
||||||
Name string
|
Name string
|
||||||
OID OID
|
OID OID
|
||||||
|
FormatCode int16
|
||||||
}
|
}
|
||||||
|
|
||||||
type ConnInfo struct {
|
type ConnInfo struct {
|
||||||
|
@ -153,6 +160,16 @@ func NewConnInfo() *ConnInfo {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ci *ConnInfo) DataTypes() map[OID]DataType {
|
||||||
|
out := make(map[OID]DataType, len(ci.oidToDataType))
|
||||||
|
for _, dt := range ci.oidToDataType {
|
||||||
|
tmp := *dt
|
||||||
|
out[dt.OID] = tmp
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]OID) {
|
func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]OID) {
|
||||||
for name, oid := range nameOIDs {
|
for name, oid := range nameOIDs {
|
||||||
var value Value
|
var value Value
|
||||||
|
@ -161,7 +178,7 @@ func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]OID) {
|
||||||
} else {
|
} else {
|
||||||
value = &GenericText{}
|
value = &GenericText{}
|
||||||
}
|
}
|
||||||
ci.RegisterDataType(DataType{Value: value, Name: name, OID: oid})
|
ci.RegisterDataType(DataType{Value: value, Name: name, OID: oid, FormatCode: DetermineFormatCode(value)})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -196,9 +213,10 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo {
|
||||||
|
|
||||||
for _, dt := range ci.oidToDataType {
|
for _, dt := range ci.oidToDataType {
|
||||||
ci2.RegisterDataType(DataType{
|
ci2.RegisterDataType(DataType{
|
||||||
Value: reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(Value),
|
Value: reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(Value),
|
||||||
Name: dt.Name,
|
Name: dt.Name,
|
||||||
OID: dt.OID,
|
OID: dt.OID,
|
||||||
|
FormatCode: dt.FormatCode,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -274,3 +292,13 @@ func init() {
|
||||||
"xid": &XID{},
|
"xid": &XID{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DetermineFormatCode determines the default format code to use
|
||||||
|
// for the given value.
|
||||||
|
func DetermineFormatCode(v Value) int16 {
|
||||||
|
if _, ok := v.(BinaryDecoder); ok {
|
||||||
|
return binaryFormatCode
|
||||||
|
}
|
||||||
|
|
||||||
|
return textFormatCode
|
||||||
|
}
|
||||||
|
|
62
query.go
62
query.go
|
@ -134,7 +134,7 @@ func (rows *Rows) Next() bool {
|
||||||
for i := range rows.fields {
|
for i := range rows.fields {
|
||||||
if dt, ok := rows.conn.ConnInfo.DataTypeForOID(rows.fields[i].DataType); ok {
|
if dt, ok := rows.conn.ConnInfo.DataTypeForOID(rows.fields[i].DataType); ok {
|
||||||
rows.fields[i].DataTypeName = dt.Name
|
rows.fields[i].DataTypeName = dt.Name
|
||||||
rows.fields[i].FormatCode = TextFormatCode
|
rows.fields[i].FormatCode = dt.FormatCode
|
||||||
} else {
|
} else {
|
||||||
rows.fatal(errors.Errorf("unknown oid: %d", rows.fields[i].DataType))
|
rows.fatal(errors.Errorf("unknown oid: %d", rows.fields[i].DataType))
|
||||||
return false
|
return false
|
||||||
|
@ -367,6 +367,9 @@ type QueryExOptions struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (rows *Rows, err error) {
|
func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (rows *Rows, err error) {
|
||||||
|
var (
|
||||||
|
fieldDescriptions []FieldDescription
|
||||||
|
)
|
||||||
c.lastActivityTime = time.Now()
|
c.lastActivityTime = time.Now()
|
||||||
rows = c.getRows(sql, args)
|
rows = c.getRows(sql, args)
|
||||||
|
|
||||||
|
@ -376,12 +379,12 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions,
|
||||||
return rows, err
|
return rows, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.ensureConnectionReadyForQuery(); err != nil {
|
if err = c.ensureConnectionReadyForQuery(); err != nil {
|
||||||
rows.fatal(err)
|
rows.fatal(err)
|
||||||
return rows, err
|
return rows, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.lock(); err != nil {
|
if err = c.lock(); err != nil {
|
||||||
rows.fatal(err)
|
rows.fatal(err)
|
||||||
return rows, err
|
return rows, err
|
||||||
}
|
}
|
||||||
|
@ -394,12 +397,18 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions,
|
||||||
}
|
}
|
||||||
|
|
||||||
if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) {
|
if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) {
|
||||||
err = c.sanitizeAndSendSimpleQuery(sql, args...)
|
if err = c.sanitizeAndSendSimpleQuery(sql, args...); err != nil {
|
||||||
if err != nil {
|
|
||||||
rows.fatal(err)
|
rows.fatal(err)
|
||||||
return rows, err
|
return rows, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if fieldDescriptions, err = c.readFieldDescriptions(QueryExOptions{}); err != nil {
|
||||||
|
rows.fatal(err)
|
||||||
|
return rows, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rows.sql = sql
|
||||||
|
rows.fields = fieldDescriptions
|
||||||
return rows, nil
|
return rows, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -421,27 +430,11 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions,
|
||||||
}
|
}
|
||||||
c.pendingReadyForQueryCount++
|
c.pendingReadyForQueryCount++
|
||||||
|
|
||||||
fieldDescriptions, err := c.readUntilRowDescription()
|
if fieldDescriptions, err = c.readFieldDescriptions(*options); err != nil {
|
||||||
if err != nil {
|
|
||||||
rows.fatal(err)
|
rows.fatal(err)
|
||||||
return rows, err
|
return rows, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(options.ResultFormatCodes) == 0 {
|
|
||||||
for i := range fieldDescriptions {
|
|
||||||
fieldDescriptions[i].FormatCode = TextFormatCode
|
|
||||||
}
|
|
||||||
} else if len(options.ResultFormatCodes) == 1 {
|
|
||||||
fc := options.ResultFormatCodes[0]
|
|
||||||
for i := range fieldDescriptions {
|
|
||||||
fieldDescriptions[i].FormatCode = fc
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for i := range options.ResultFormatCodes {
|
|
||||||
fieldDescriptions[i].FormatCode = options.ResultFormatCodes[i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
rows.sql = sql
|
rows.sql = sql
|
||||||
rows.fields = fieldDescriptions
|
rows.fields = fieldDescriptions
|
||||||
return rows, nil
|
return rows, nil
|
||||||
|
@ -449,7 +442,6 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions,
|
||||||
|
|
||||||
ps, ok := c.preparedStatements[sql]
|
ps, ok := c.preparedStatements[sql]
|
||||||
if !ok {
|
if !ok {
|
||||||
var err error
|
|
||||||
ps, err = c.prepareEx("", sql, nil)
|
ps, err = c.prepareEx("", sql, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rows.fatal(err)
|
rows.fatal(err)
|
||||||
|
@ -543,3 +535,27 @@ func (c *Conn) QueryRowEx(ctx context.Context, sql string, options *QueryExOptio
|
||||||
rows, _ := c.QueryEx(ctx, sql, options, args...)
|
rows, _ := c.QueryEx(ctx, sql, options, args...)
|
||||||
return (*Row)(rows)
|
return (*Row)(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Conn) readFieldDescriptions(options QueryExOptions) (fieldDescriptions []FieldDescription, err error) {
|
||||||
|
if fieldDescriptions, err = c.readUntilRowDescription(); err != nil {
|
||||||
|
return fieldDescriptions, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch len(options.ResultFormatCodes) {
|
||||||
|
case 0:
|
||||||
|
for i := range fieldDescriptions {
|
||||||
|
fieldDescriptions[i].FormatCode = TextFormatCode
|
||||||
|
}
|
||||||
|
case 1:
|
||||||
|
fc := options.ResultFormatCodes[0]
|
||||||
|
for i := range fieldDescriptions {
|
||||||
|
fieldDescriptions[i].FormatCode = fc
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
for i := range options.ResultFormatCodes {
|
||||||
|
fieldDescriptions[i].FormatCode = options.ResultFormatCodes[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fieldDescriptions, err
|
||||||
|
}
|
||||||
|
|
100
stdlib/sql.go
100
stdlib/sql.go
|
@ -85,10 +85,6 @@ import (
|
||||||
"github.com/jackc/pgx/pgtype"
|
"github.com/jackc/pgx/pgtype"
|
||||||
)
|
)
|
||||||
|
|
||||||
// oids that map to intrinsic database/sql types. These will be allowed to be
|
|
||||||
// binary, anything else will be forced to text format
|
|
||||||
var databaseSqlOIDs map[pgtype.OID]bool
|
|
||||||
|
|
||||||
var pgxDriver *Driver
|
var pgxDriver *Driver
|
||||||
|
|
||||||
type ctxKey int
|
type ctxKey int
|
||||||
|
@ -97,27 +93,30 @@ var ctxKeyFakeTx ctxKey = 0
|
||||||
|
|
||||||
var ErrNotPgx = errors.New("not pgx *sql.DB")
|
var ErrNotPgx = errors.New("not pgx *sql.DB")
|
||||||
|
|
||||||
|
// oids that map to intrinsic database/sql types. These will be allowed to be
|
||||||
|
// binary, anything else will be forced to text format
|
||||||
|
var allowedBinaryOID = []pgtype.OID{
|
||||||
|
pgtype.BoolOID,
|
||||||
|
pgtype.ByteaOID,
|
||||||
|
pgtype.CIDOID,
|
||||||
|
pgtype.DateOID,
|
||||||
|
pgtype.Float4OID,
|
||||||
|
pgtype.Float8OID,
|
||||||
|
pgtype.Int2OID,
|
||||||
|
pgtype.Int4OID,
|
||||||
|
pgtype.Int8OID,
|
||||||
|
pgtype.OIDOID,
|
||||||
|
pgtype.TimestampOID,
|
||||||
|
pgtype.TimestamptzOID,
|
||||||
|
pgtype.XIDOID,
|
||||||
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
pgxDriver = &Driver{
|
pgxDriver = &Driver{
|
||||||
configs: make(map[int64]*DriverConfig),
|
configs: make(map[int64]*DriverConfig),
|
||||||
fakeTxConns: make(map[*pgx.Conn]*sql.Tx),
|
fakeTxConns: make(map[*pgx.Conn]*sql.Tx),
|
||||||
}
|
}
|
||||||
sql.Register("pgx", pgxDriver)
|
sql.Register("pgx", pgxDriver)
|
||||||
|
|
||||||
databaseSqlOIDs = make(map[pgtype.OID]bool)
|
|
||||||
databaseSqlOIDs[pgtype.BoolOID] = true
|
|
||||||
databaseSqlOIDs[pgtype.ByteaOID] = true
|
|
||||||
databaseSqlOIDs[pgtype.CIDOID] = true
|
|
||||||
databaseSqlOIDs[pgtype.DateOID] = true
|
|
||||||
databaseSqlOIDs[pgtype.Float4OID] = true
|
|
||||||
databaseSqlOIDs[pgtype.Float8OID] = true
|
|
||||||
databaseSqlOIDs[pgtype.Int2OID] = true
|
|
||||||
databaseSqlOIDs[pgtype.Int4OID] = true
|
|
||||||
databaseSqlOIDs[pgtype.Int8OID] = true
|
|
||||||
databaseSqlOIDs[pgtype.OIDOID] = true
|
|
||||||
databaseSqlOIDs[pgtype.TimestampOID] = true
|
|
||||||
databaseSqlOIDs[pgtype.TimestamptzOID] = true
|
|
||||||
databaseSqlOIDs[pgtype.XIDOID] = true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Driver struct {
|
type Driver struct {
|
||||||
|
@ -130,8 +129,11 @@ type Driver struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Driver) Open(name string) (driver.Conn, error) {
|
func (d *Driver) Open(name string) (driver.Conn, error) {
|
||||||
var connConfig pgx.ConnConfig
|
var (
|
||||||
var afterConnect func(*pgx.Conn) error
|
afterConnect func(*pgx.Conn) error
|
||||||
|
connConfig pgx.ConnConfig
|
||||||
|
)
|
||||||
|
|
||||||
if len(name) >= 9 && name[0] == 0 {
|
if len(name) >= 9 && name[0] == 0 {
|
||||||
idBuf := []byte(name)[1:9]
|
idBuf := []byte(name)[1:9]
|
||||||
id := int64(binary.BigEndian.Uint64(idBuf))
|
id := int64(binary.BigEndian.Uint64(idBuf))
|
||||||
|
@ -151,6 +153,8 @@ func (d *Driver) Open(name string) (driver.Conn, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
conn.ConnInfo = restrictBinary(conn.ConnInfo)
|
||||||
|
|
||||||
if afterConnect != nil {
|
if afterConnect != nil {
|
||||||
err = afterConnect(conn)
|
err = afterConnect(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -232,8 +236,6 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
restrictBinaryToDatabaseSqlTypes(ps)
|
|
||||||
|
|
||||||
return &Stmt{ps: ps, conn: c}, nil
|
return &Stmt{ps: ps, conn: c}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -299,33 +301,28 @@ func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.Nam
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) {
|
func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) {
|
||||||
if !c.conn.IsAlive() {
|
return c.query(context.Background(), query, valueToInterface(argsV))
|
||||||
return nil, driver.ErrBadConn
|
|
||||||
}
|
|
||||||
|
|
||||||
ps, err := c.conn.Prepare("", query)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
restrictBinaryToDatabaseSqlTypes(ps)
|
|
||||||
|
|
||||||
return c.queryPrepared("", argsV)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) {
|
func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) {
|
||||||
|
return c.query(ctx, query, namedValueToInterface(argsV))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) query(ctx context.Context, query string, args []interface{}) (driver.Rows, error) {
|
||||||
|
var (
|
||||||
|
err error
|
||||||
|
rows *pgx.Rows
|
||||||
|
)
|
||||||
|
|
||||||
if !c.conn.IsAlive() {
|
if !c.conn.IsAlive() {
|
||||||
return nil, driver.ErrBadConn
|
return nil, driver.ErrBadConn
|
||||||
}
|
}
|
||||||
|
|
||||||
ps, err := c.conn.PrepareEx(ctx, "", query, nil)
|
if rows, err = c.conn.QueryEx(ctx, query, nil, args...); err != nil {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
restrictBinaryToDatabaseSqlTypes(ps)
|
return &Rows{rows: rows}, nil
|
||||||
|
|
||||||
return c.queryPreparedContext(ctx, "", argsV)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) {
|
func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) {
|
||||||
|
@ -369,13 +366,26 @@ func (c *Conn) Ping(ctx context.Context) error {
|
||||||
// Anything that isn't a database/sql compatible type needs to be forced to
|
// Anything that isn't a database/sql compatible type needs to be forced to
|
||||||
// text format so that pgx.Rows.Values doesn't decode it into a native type
|
// text format so that pgx.Rows.Values doesn't decode it into a native type
|
||||||
// (e.g. []int32)
|
// (e.g. []int32)
|
||||||
func restrictBinaryToDatabaseSqlTypes(ps *pgx.PreparedStatement) {
|
func restrictBinary(in *pgtype.ConnInfo) (out *pgtype.ConnInfo) {
|
||||||
for i, _ := range ps.FieldDescriptions {
|
out = in.DeepCopy()
|
||||||
intrinsic, _ := databaseSqlOIDs[ps.FieldDescriptions[i].DataType]
|
for oid, dt := range out.DataTypes() {
|
||||||
if !intrinsic {
|
if textOID(oid) {
|
||||||
ps.FieldDescriptions[i].FormatCode = pgx.TextFormatCode
|
dt.FormatCode = pgx.TextFormatCode
|
||||||
|
out.RegisterDataType(dt)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func textOID(oid pgtype.OID) bool {
|
||||||
|
for _, roid := range allowedBinaryOID {
|
||||||
|
if roid == oid {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
type Stmt struct {
|
type Stmt struct {
|
||||||
|
|
|
@ -81,6 +81,63 @@ func closeStmt(t *testing.T, stmt *sql.Stmt) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSimpleQueryLifeCycle(t *testing.T) {
|
||||||
|
driverConfig := stdlib.DriverConfig{
|
||||||
|
ConnConfig: pgx.ConnConfig{PreferSimpleProtocol: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
stdlib.RegisterDriverConfig(&driverConfig)
|
||||||
|
defer stdlib.UnregisterDriverConfig(&driverConfig)
|
||||||
|
|
||||||
|
db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("sql.Open failed: %v", err)
|
||||||
|
}
|
||||||
|
defer closeDB(t, db)
|
||||||
|
|
||||||
|
rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("stmt.Query unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rowCount := int64(0)
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
rowCount++
|
||||||
|
var (
|
||||||
|
s string
|
||||||
|
n int64
|
||||||
|
)
|
||||||
|
|
||||||
|
if err := rows.Scan(&s, &n); err != nil {
|
||||||
|
t.Fatalf("rows.Scan unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s != "foo" {
|
||||||
|
t.Errorf(`Expected "foo", received "%v"`, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n != rowCount {
|
||||||
|
t.Errorf("Expected %d, received %d", rowCount, n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = rows.Err(); err != nil {
|
||||||
|
t.Fatalf("rows.Err unexpectedly is: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rowCount != 10 {
|
||||||
|
t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = rows.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("rows.Close unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, db)
|
||||||
|
}
|
||||||
|
|
||||||
func TestNormalLifeCycle(t *testing.T) {
|
func TestNormalLifeCycle(t *testing.T) {
|
||||||
db := openDB(t)
|
db := openDB(t)
|
||||||
defer closeDB(t, db)
|
defer closeDB(t, db)
|
||||||
|
|
Loading…
Reference in New Issue