mirror of https://github.com/jackc/pgx.git
minimal changes for pgbouncer
parent
4adfd1ca2f
commit
7471e7f9eb
14
conn.go
14
conn.go
|
@ -469,10 +469,12 @@ where t.typtype = 'b'
|
|||
}
|
||||
|
||||
for name, oid := range nameOIDs {
|
||||
v := &pgtype.EnumArray{}
|
||||
c.ConnInfo.RegisterDataType(pgtype.DataType{
|
||||
&pgtype.EnumArray{},
|
||||
name,
|
||||
oid,
|
||||
Value: v,
|
||||
Name: name,
|
||||
OID: oid,
|
||||
FormatCode: pgtype.DetermineFormatCode(v),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -942,11 +944,7 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
|
|||
for i := range ps.FieldDescriptions {
|
||||
if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok {
|
||||
ps.FieldDescriptions[i].DataTypeName = dt.Name
|
||||
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
|
||||
ps.FieldDescriptions[i].FormatCode = BinaryFormatCode
|
||||
} else {
|
||||
ps.FieldDescriptions[i].FormatCode = TextFormatCode
|
||||
}
|
||||
ps.FieldDescriptions[i].FormatCode = dt.FormatCode
|
||||
} else {
|
||||
return nil, errors.Errorf("unknown oid: %d", ps.FieldDescriptions[i].DataType)
|
||||
}
|
||||
|
|
|
@ -53,6 +53,12 @@ const (
|
|||
JSONBOID = 3802
|
||||
)
|
||||
|
||||
// PostgreSQL format codes
|
||||
const (
|
||||
textFormatCode int16 = 0
|
||||
binaryFormatCode = 1
|
||||
)
|
||||
|
||||
type Status byte
|
||||
|
||||
const (
|
||||
|
@ -134,9 +140,10 @@ var errUndefined = errors.New("cannot encode status undefined")
|
|||
var errBadStatus = errors.New("invalid status")
|
||||
|
||||
type DataType struct {
|
||||
Value Value
|
||||
Name string
|
||||
OID OID
|
||||
Value Value
|
||||
Name string
|
||||
OID OID
|
||||
FormatCode int16
|
||||
}
|
||||
|
||||
type ConnInfo struct {
|
||||
|
@ -153,6 +160,16 @@ func NewConnInfo() *ConnInfo {
|
|||
}
|
||||
}
|
||||
|
||||
func (ci *ConnInfo) DataTypes() map[OID]DataType {
|
||||
out := make(map[OID]DataType, len(ci.oidToDataType))
|
||||
for _, dt := range ci.oidToDataType {
|
||||
tmp := *dt
|
||||
out[dt.OID] = tmp
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]OID) {
|
||||
for name, oid := range nameOIDs {
|
||||
var value Value
|
||||
|
@ -161,7 +178,7 @@ func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]OID) {
|
|||
} else {
|
||||
value = &GenericText{}
|
||||
}
|
||||
ci.RegisterDataType(DataType{Value: value, Name: name, OID: oid})
|
||||
ci.RegisterDataType(DataType{Value: value, Name: name, OID: oid, FormatCode: DetermineFormatCode(value)})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -196,9 +213,10 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo {
|
|||
|
||||
for _, dt := range ci.oidToDataType {
|
||||
ci2.RegisterDataType(DataType{
|
||||
Value: reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(Value),
|
||||
Name: dt.Name,
|
||||
OID: dt.OID,
|
||||
Value: reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(Value),
|
||||
Name: dt.Name,
|
||||
OID: dt.OID,
|
||||
FormatCode: dt.FormatCode,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -274,3 +292,13 @@ func init() {
|
|||
"xid": &XID{},
|
||||
}
|
||||
}
|
||||
|
||||
// DetermineFormatCode determines the default format code to use
|
||||
// for the given value.
|
||||
func DetermineFormatCode(v Value) int16 {
|
||||
if _, ok := v.(BinaryDecoder); ok {
|
||||
return binaryFormatCode
|
||||
}
|
||||
|
||||
return textFormatCode
|
||||
}
|
||||
|
|
62
query.go
62
query.go
|
@ -134,7 +134,7 @@ func (rows *Rows) Next() bool {
|
|||
for i := range rows.fields {
|
||||
if dt, ok := rows.conn.ConnInfo.DataTypeForOID(rows.fields[i].DataType); ok {
|
||||
rows.fields[i].DataTypeName = dt.Name
|
||||
rows.fields[i].FormatCode = TextFormatCode
|
||||
rows.fields[i].FormatCode = dt.FormatCode
|
||||
} else {
|
||||
rows.fatal(errors.Errorf("unknown oid: %d", rows.fields[i].DataType))
|
||||
return false
|
||||
|
@ -367,6 +367,9 @@ type QueryExOptions struct {
|
|||
}
|
||||
|
||||
func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (rows *Rows, err error) {
|
||||
var (
|
||||
fieldDescriptions []FieldDescription
|
||||
)
|
||||
c.lastActivityTime = time.Now()
|
||||
rows = c.getRows(sql, args)
|
||||
|
||||
|
@ -376,12 +379,12 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions,
|
|||
return rows, err
|
||||
}
|
||||
|
||||
if err := c.ensureConnectionReadyForQuery(); err != nil {
|
||||
if err = c.ensureConnectionReadyForQuery(); err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
if err := c.lock(); err != nil {
|
||||
if err = c.lock(); err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
|
@ -394,12 +397,18 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions,
|
|||
}
|
||||
|
||||
if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) {
|
||||
err = c.sanitizeAndSendSimpleQuery(sql, args...)
|
||||
if err != nil {
|
||||
if err = c.sanitizeAndSendSimpleQuery(sql, args...); err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
if fieldDescriptions, err = c.readFieldDescriptions(QueryExOptions{}); err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
rows.sql = sql
|
||||
rows.fields = fieldDescriptions
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
|
@ -421,27 +430,11 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions,
|
|||
}
|
||||
c.pendingReadyForQueryCount++
|
||||
|
||||
fieldDescriptions, err := c.readUntilRowDescription()
|
||||
if err != nil {
|
||||
if fieldDescriptions, err = c.readFieldDescriptions(*options); err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
if len(options.ResultFormatCodes) == 0 {
|
||||
for i := range fieldDescriptions {
|
||||
fieldDescriptions[i].FormatCode = TextFormatCode
|
||||
}
|
||||
} else if len(options.ResultFormatCodes) == 1 {
|
||||
fc := options.ResultFormatCodes[0]
|
||||
for i := range fieldDescriptions {
|
||||
fieldDescriptions[i].FormatCode = fc
|
||||
}
|
||||
} else {
|
||||
for i := range options.ResultFormatCodes {
|
||||
fieldDescriptions[i].FormatCode = options.ResultFormatCodes[i]
|
||||
}
|
||||
}
|
||||
|
||||
rows.sql = sql
|
||||
rows.fields = fieldDescriptions
|
||||
return rows, nil
|
||||
|
@ -449,7 +442,6 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions,
|
|||
|
||||
ps, ok := c.preparedStatements[sql]
|
||||
if !ok {
|
||||
var err error
|
||||
ps, err = c.prepareEx("", sql, nil)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
|
@ -543,3 +535,27 @@ func (c *Conn) QueryRowEx(ctx context.Context, sql string, options *QueryExOptio
|
|||
rows, _ := c.QueryEx(ctx, sql, options, args...)
|
||||
return (*Row)(rows)
|
||||
}
|
||||
|
||||
func (c *Conn) readFieldDescriptions(options QueryExOptions) (fieldDescriptions []FieldDescription, err error) {
|
||||
if fieldDescriptions, err = c.readUntilRowDescription(); err != nil {
|
||||
return fieldDescriptions, err
|
||||
}
|
||||
|
||||
switch len(options.ResultFormatCodes) {
|
||||
case 0:
|
||||
for i := range fieldDescriptions {
|
||||
fieldDescriptions[i].FormatCode = TextFormatCode
|
||||
}
|
||||
case 1:
|
||||
fc := options.ResultFormatCodes[0]
|
||||
for i := range fieldDescriptions {
|
||||
fieldDescriptions[i].FormatCode = fc
|
||||
}
|
||||
default:
|
||||
for i := range options.ResultFormatCodes {
|
||||
fieldDescriptions[i].FormatCode = options.ResultFormatCodes[i]
|
||||
}
|
||||
}
|
||||
|
||||
return fieldDescriptions, err
|
||||
}
|
||||
|
|
100
stdlib/sql.go
100
stdlib/sql.go
|
@ -85,10 +85,6 @@ import (
|
|||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
// oids that map to intrinsic database/sql types. These will be allowed to be
|
||||
// binary, anything else will be forced to text format
|
||||
var databaseSqlOIDs map[pgtype.OID]bool
|
||||
|
||||
var pgxDriver *Driver
|
||||
|
||||
type ctxKey int
|
||||
|
@ -97,27 +93,30 @@ var ctxKeyFakeTx ctxKey = 0
|
|||
|
||||
var ErrNotPgx = errors.New("not pgx *sql.DB")
|
||||
|
||||
// oids that map to intrinsic database/sql types. These will be allowed to be
|
||||
// binary, anything else will be forced to text format
|
||||
var allowedBinaryOID = []pgtype.OID{
|
||||
pgtype.BoolOID,
|
||||
pgtype.ByteaOID,
|
||||
pgtype.CIDOID,
|
||||
pgtype.DateOID,
|
||||
pgtype.Float4OID,
|
||||
pgtype.Float8OID,
|
||||
pgtype.Int2OID,
|
||||
pgtype.Int4OID,
|
||||
pgtype.Int8OID,
|
||||
pgtype.OIDOID,
|
||||
pgtype.TimestampOID,
|
||||
pgtype.TimestamptzOID,
|
||||
pgtype.XIDOID,
|
||||
}
|
||||
|
||||
func init() {
|
||||
pgxDriver = &Driver{
|
||||
configs: make(map[int64]*DriverConfig),
|
||||
fakeTxConns: make(map[*pgx.Conn]*sql.Tx),
|
||||
}
|
||||
sql.Register("pgx", pgxDriver)
|
||||
|
||||
databaseSqlOIDs = make(map[pgtype.OID]bool)
|
||||
databaseSqlOIDs[pgtype.BoolOID] = true
|
||||
databaseSqlOIDs[pgtype.ByteaOID] = true
|
||||
databaseSqlOIDs[pgtype.CIDOID] = true
|
||||
databaseSqlOIDs[pgtype.DateOID] = true
|
||||
databaseSqlOIDs[pgtype.Float4OID] = true
|
||||
databaseSqlOIDs[pgtype.Float8OID] = true
|
||||
databaseSqlOIDs[pgtype.Int2OID] = true
|
||||
databaseSqlOIDs[pgtype.Int4OID] = true
|
||||
databaseSqlOIDs[pgtype.Int8OID] = true
|
||||
databaseSqlOIDs[pgtype.OIDOID] = true
|
||||
databaseSqlOIDs[pgtype.TimestampOID] = true
|
||||
databaseSqlOIDs[pgtype.TimestamptzOID] = true
|
||||
databaseSqlOIDs[pgtype.XIDOID] = true
|
||||
}
|
||||
|
||||
type Driver struct {
|
||||
|
@ -130,8 +129,11 @@ type Driver struct {
|
|||
}
|
||||
|
||||
func (d *Driver) Open(name string) (driver.Conn, error) {
|
||||
var connConfig pgx.ConnConfig
|
||||
var afterConnect func(*pgx.Conn) error
|
||||
var (
|
||||
afterConnect func(*pgx.Conn) error
|
||||
connConfig pgx.ConnConfig
|
||||
)
|
||||
|
||||
if len(name) >= 9 && name[0] == 0 {
|
||||
idBuf := []byte(name)[1:9]
|
||||
id := int64(binary.BigEndian.Uint64(idBuf))
|
||||
|
@ -151,6 +153,8 @@ func (d *Driver) Open(name string) (driver.Conn, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
conn.ConnInfo = restrictBinary(conn.ConnInfo)
|
||||
|
||||
if afterConnect != nil {
|
||||
err = afterConnect(conn)
|
||||
if err != nil {
|
||||
|
@ -232,8 +236,6 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
|
|||
return nil, err
|
||||
}
|
||||
|
||||
restrictBinaryToDatabaseSqlTypes(ps)
|
||||
|
||||
return &Stmt{ps: ps, conn: c}, nil
|
||||
}
|
||||
|
||||
|
@ -299,33 +301,28 @@ func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.Nam
|
|||
}
|
||||
|
||||
func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) {
|
||||
if !c.conn.IsAlive() {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
ps, err := c.conn.Prepare("", query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
restrictBinaryToDatabaseSqlTypes(ps)
|
||||
|
||||
return c.queryPrepared("", argsV)
|
||||
return c.query(context.Background(), query, valueToInterface(argsV))
|
||||
}
|
||||
|
||||
func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) {
|
||||
return c.query(ctx, query, namedValueToInterface(argsV))
|
||||
}
|
||||
|
||||
func (c *Conn) query(ctx context.Context, query string, args []interface{}) (driver.Rows, error) {
|
||||
var (
|
||||
err error
|
||||
rows *pgx.Rows
|
||||
)
|
||||
|
||||
if !c.conn.IsAlive() {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
ps, err := c.conn.PrepareEx(ctx, "", query, nil)
|
||||
if err != nil {
|
||||
if rows, err = c.conn.QueryEx(ctx, query, nil, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
restrictBinaryToDatabaseSqlTypes(ps)
|
||||
|
||||
return c.queryPreparedContext(ctx, "", argsV)
|
||||
return &Rows{rows: rows}, nil
|
||||
}
|
||||
|
||||
func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) {
|
||||
|
@ -369,13 +366,26 @@ func (c *Conn) Ping(ctx context.Context) error {
|
|||
// Anything that isn't a database/sql compatible type needs to be forced to
|
||||
// text format so that pgx.Rows.Values doesn't decode it into a native type
|
||||
// (e.g. []int32)
|
||||
func restrictBinaryToDatabaseSqlTypes(ps *pgx.PreparedStatement) {
|
||||
for i, _ := range ps.FieldDescriptions {
|
||||
intrinsic, _ := databaseSqlOIDs[ps.FieldDescriptions[i].DataType]
|
||||
if !intrinsic {
|
||||
ps.FieldDescriptions[i].FormatCode = pgx.TextFormatCode
|
||||
func restrictBinary(in *pgtype.ConnInfo) (out *pgtype.ConnInfo) {
|
||||
out = in.DeepCopy()
|
||||
for oid, dt := range out.DataTypes() {
|
||||
if textOID(oid) {
|
||||
dt.FormatCode = pgx.TextFormatCode
|
||||
out.RegisterDataType(dt)
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func textOID(oid pgtype.OID) bool {
|
||||
for _, roid := range allowedBinaryOID {
|
||||
if roid == oid {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
type Stmt struct {
|
||||
|
|
|
@ -81,6 +81,63 @@ func closeStmt(t *testing.T, stmt *sql.Stmt) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestSimpleQueryLifeCycle(t *testing.T) {
|
||||
driverConfig := stdlib.DriverConfig{
|
||||
ConnConfig: pgx.ConnConfig{PreferSimpleProtocol: true},
|
||||
}
|
||||
|
||||
stdlib.RegisterDriverConfig(&driverConfig)
|
||||
defer stdlib.UnregisterDriverConfig(&driverConfig)
|
||||
|
||||
db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test"))
|
||||
if err != nil {
|
||||
t.Fatalf("sql.Open failed: %v", err)
|
||||
}
|
||||
defer closeDB(t, db)
|
||||
|
||||
rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3)
|
||||
if err != nil {
|
||||
t.Fatalf("stmt.Query unexpectedly failed: %v", err)
|
||||
}
|
||||
|
||||
rowCount := int64(0)
|
||||
|
||||
for rows.Next() {
|
||||
rowCount++
|
||||
var (
|
||||
s string
|
||||
n int64
|
||||
)
|
||||
|
||||
if err := rows.Scan(&s, &n); err != nil {
|
||||
t.Fatalf("rows.Scan unexpectedly failed: %v", err)
|
||||
}
|
||||
|
||||
if s != "foo" {
|
||||
t.Errorf(`Expected "foo", received "%v"`, s)
|
||||
}
|
||||
|
||||
if n != rowCount {
|
||||
t.Errorf("Expected %d, received %d", rowCount, n)
|
||||
}
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
t.Fatalf("rows.Err unexpectedly is: %v", err)
|
||||
}
|
||||
|
||||
if rowCount != 10 {
|
||||
t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
|
||||
}
|
||||
|
||||
err = rows.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("rows.Close unexpectedly failed: %v", err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, db)
|
||||
}
|
||||
|
||||
func TestNormalLifeCycle(t *testing.T) {
|
||||
db := openDB(t)
|
||||
defer closeDB(t, db)
|
||||
|
|
Loading…
Reference in New Issue