minimal changes for pgbouncer

pull/365/head
James Lawrence 2017-12-16 17:30:55 -05:00
parent 4adfd1ca2f
commit 7471e7f9eb
5 changed files with 192 additions and 83 deletions

14
conn.go
View File

@ -469,10 +469,12 @@ where t.typtype = 'b'
}
for name, oid := range nameOIDs {
v := &pgtype.EnumArray{}
c.ConnInfo.RegisterDataType(pgtype.DataType{
&pgtype.EnumArray{},
name,
oid,
Value: v,
Name: name,
OID: oid,
FormatCode: pgtype.DetermineFormatCode(v),
})
}
@ -942,11 +944,7 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
for i := range ps.FieldDescriptions {
if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok {
ps.FieldDescriptions[i].DataTypeName = dt.Name
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
ps.FieldDescriptions[i].FormatCode = BinaryFormatCode
} else {
ps.FieldDescriptions[i].FormatCode = TextFormatCode
}
ps.FieldDescriptions[i].FormatCode = dt.FormatCode
} else {
return nil, errors.Errorf("unknown oid: %d", ps.FieldDescriptions[i].DataType)
}

View File

@ -53,6 +53,12 @@ const (
JSONBOID = 3802
)
// PostgreSQL format codes
const (
textFormatCode int16 = 0
binaryFormatCode = 1
)
type Status byte
const (
@ -134,9 +140,10 @@ var errUndefined = errors.New("cannot encode status undefined")
var errBadStatus = errors.New("invalid status")
type DataType struct {
Value Value
Name string
OID OID
Value Value
Name string
OID OID
FormatCode int16
}
type ConnInfo struct {
@ -153,6 +160,16 @@ func NewConnInfo() *ConnInfo {
}
}
func (ci *ConnInfo) DataTypes() map[OID]DataType {
out := make(map[OID]DataType, len(ci.oidToDataType))
for _, dt := range ci.oidToDataType {
tmp := *dt
out[dt.OID] = tmp
}
return out
}
func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]OID) {
for name, oid := range nameOIDs {
var value Value
@ -161,7 +178,7 @@ func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]OID) {
} else {
value = &GenericText{}
}
ci.RegisterDataType(DataType{Value: value, Name: name, OID: oid})
ci.RegisterDataType(DataType{Value: value, Name: name, OID: oid, FormatCode: DetermineFormatCode(value)})
}
}
@ -196,9 +213,10 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo {
for _, dt := range ci.oidToDataType {
ci2.RegisterDataType(DataType{
Value: reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(Value),
Name: dt.Name,
OID: dt.OID,
Value: reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(Value),
Name: dt.Name,
OID: dt.OID,
FormatCode: dt.FormatCode,
})
}
@ -274,3 +292,13 @@ func init() {
"xid": &XID{},
}
}
// DetermineFormatCode determines the default format code to use
// for the given value.
func DetermineFormatCode(v Value) int16 {
if _, ok := v.(BinaryDecoder); ok {
return binaryFormatCode
}
return textFormatCode
}

View File

@ -134,7 +134,7 @@ func (rows *Rows) Next() bool {
for i := range rows.fields {
if dt, ok := rows.conn.ConnInfo.DataTypeForOID(rows.fields[i].DataType); ok {
rows.fields[i].DataTypeName = dt.Name
rows.fields[i].FormatCode = TextFormatCode
rows.fields[i].FormatCode = dt.FormatCode
} else {
rows.fatal(errors.Errorf("unknown oid: %d", rows.fields[i].DataType))
return false
@ -367,6 +367,9 @@ type QueryExOptions struct {
}
func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (rows *Rows, err error) {
var (
fieldDescriptions []FieldDescription
)
c.lastActivityTime = time.Now()
rows = c.getRows(sql, args)
@ -376,12 +379,12 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions,
return rows, err
}
if err := c.ensureConnectionReadyForQuery(); err != nil {
if err = c.ensureConnectionReadyForQuery(); err != nil {
rows.fatal(err)
return rows, err
}
if err := c.lock(); err != nil {
if err = c.lock(); err != nil {
rows.fatal(err)
return rows, err
}
@ -394,12 +397,18 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions,
}
if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) {
err = c.sanitizeAndSendSimpleQuery(sql, args...)
if err != nil {
if err = c.sanitizeAndSendSimpleQuery(sql, args...); err != nil {
rows.fatal(err)
return rows, err
}
if fieldDescriptions, err = c.readFieldDescriptions(QueryExOptions{}); err != nil {
rows.fatal(err)
return rows, err
}
rows.sql = sql
rows.fields = fieldDescriptions
return rows, nil
}
@ -421,27 +430,11 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions,
}
c.pendingReadyForQueryCount++
fieldDescriptions, err := c.readUntilRowDescription()
if err != nil {
if fieldDescriptions, err = c.readFieldDescriptions(*options); err != nil {
rows.fatal(err)
return rows, err
}
if len(options.ResultFormatCodes) == 0 {
for i := range fieldDescriptions {
fieldDescriptions[i].FormatCode = TextFormatCode
}
} else if len(options.ResultFormatCodes) == 1 {
fc := options.ResultFormatCodes[0]
for i := range fieldDescriptions {
fieldDescriptions[i].FormatCode = fc
}
} else {
for i := range options.ResultFormatCodes {
fieldDescriptions[i].FormatCode = options.ResultFormatCodes[i]
}
}
rows.sql = sql
rows.fields = fieldDescriptions
return rows, nil
@ -449,7 +442,6 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions,
ps, ok := c.preparedStatements[sql]
if !ok {
var err error
ps, err = c.prepareEx("", sql, nil)
if err != nil {
rows.fatal(err)
@ -543,3 +535,27 @@ func (c *Conn) QueryRowEx(ctx context.Context, sql string, options *QueryExOptio
rows, _ := c.QueryEx(ctx, sql, options, args...)
return (*Row)(rows)
}
func (c *Conn) readFieldDescriptions(options QueryExOptions) (fieldDescriptions []FieldDescription, err error) {
if fieldDescriptions, err = c.readUntilRowDescription(); err != nil {
return fieldDescriptions, err
}
switch len(options.ResultFormatCodes) {
case 0:
for i := range fieldDescriptions {
fieldDescriptions[i].FormatCode = TextFormatCode
}
case 1:
fc := options.ResultFormatCodes[0]
for i := range fieldDescriptions {
fieldDescriptions[i].FormatCode = fc
}
default:
for i := range options.ResultFormatCodes {
fieldDescriptions[i].FormatCode = options.ResultFormatCodes[i]
}
}
return fieldDescriptions, err
}

View File

@ -85,10 +85,6 @@ import (
"github.com/jackc/pgx/pgtype"
)
// oids that map to intrinsic database/sql types. These will be allowed to be
// binary, anything else will be forced to text format
var databaseSqlOIDs map[pgtype.OID]bool
var pgxDriver *Driver
type ctxKey int
@ -97,27 +93,30 @@ var ctxKeyFakeTx ctxKey = 0
var ErrNotPgx = errors.New("not pgx *sql.DB")
// oids that map to intrinsic database/sql types. These will be allowed to be
// binary, anything else will be forced to text format
var allowedBinaryOID = []pgtype.OID{
pgtype.BoolOID,
pgtype.ByteaOID,
pgtype.CIDOID,
pgtype.DateOID,
pgtype.Float4OID,
pgtype.Float8OID,
pgtype.Int2OID,
pgtype.Int4OID,
pgtype.Int8OID,
pgtype.OIDOID,
pgtype.TimestampOID,
pgtype.TimestamptzOID,
pgtype.XIDOID,
}
func init() {
pgxDriver = &Driver{
configs: make(map[int64]*DriverConfig),
fakeTxConns: make(map[*pgx.Conn]*sql.Tx),
}
sql.Register("pgx", pgxDriver)
databaseSqlOIDs = make(map[pgtype.OID]bool)
databaseSqlOIDs[pgtype.BoolOID] = true
databaseSqlOIDs[pgtype.ByteaOID] = true
databaseSqlOIDs[pgtype.CIDOID] = true
databaseSqlOIDs[pgtype.DateOID] = true
databaseSqlOIDs[pgtype.Float4OID] = true
databaseSqlOIDs[pgtype.Float8OID] = true
databaseSqlOIDs[pgtype.Int2OID] = true
databaseSqlOIDs[pgtype.Int4OID] = true
databaseSqlOIDs[pgtype.Int8OID] = true
databaseSqlOIDs[pgtype.OIDOID] = true
databaseSqlOIDs[pgtype.TimestampOID] = true
databaseSqlOIDs[pgtype.TimestamptzOID] = true
databaseSqlOIDs[pgtype.XIDOID] = true
}
type Driver struct {
@ -130,8 +129,11 @@ type Driver struct {
}
func (d *Driver) Open(name string) (driver.Conn, error) {
var connConfig pgx.ConnConfig
var afterConnect func(*pgx.Conn) error
var (
afterConnect func(*pgx.Conn) error
connConfig pgx.ConnConfig
)
if len(name) >= 9 && name[0] == 0 {
idBuf := []byte(name)[1:9]
id := int64(binary.BigEndian.Uint64(idBuf))
@ -151,6 +153,8 @@ func (d *Driver) Open(name string) (driver.Conn, error) {
return nil, err
}
conn.ConnInfo = restrictBinary(conn.ConnInfo)
if afterConnect != nil {
err = afterConnect(conn)
if err != nil {
@ -232,8 +236,6 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
return nil, err
}
restrictBinaryToDatabaseSqlTypes(ps)
return &Stmt{ps: ps, conn: c}, nil
}
@ -299,33 +301,28 @@ func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.Nam
}
func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) {
if !c.conn.IsAlive() {
return nil, driver.ErrBadConn
}
ps, err := c.conn.Prepare("", query)
if err != nil {
return nil, err
}
restrictBinaryToDatabaseSqlTypes(ps)
return c.queryPrepared("", argsV)
return c.query(context.Background(), query, valueToInterface(argsV))
}
func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) {
return c.query(ctx, query, namedValueToInterface(argsV))
}
func (c *Conn) query(ctx context.Context, query string, args []interface{}) (driver.Rows, error) {
var (
err error
rows *pgx.Rows
)
if !c.conn.IsAlive() {
return nil, driver.ErrBadConn
}
ps, err := c.conn.PrepareEx(ctx, "", query, nil)
if err != nil {
if rows, err = c.conn.QueryEx(ctx, query, nil, args...); err != nil {
return nil, err
}
restrictBinaryToDatabaseSqlTypes(ps)
return c.queryPreparedContext(ctx, "", argsV)
return &Rows{rows: rows}, nil
}
func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) {
@ -369,13 +366,26 @@ func (c *Conn) Ping(ctx context.Context) error {
// Anything that isn't a database/sql compatible type needs to be forced to
// text format so that pgx.Rows.Values doesn't decode it into a native type
// (e.g. []int32)
func restrictBinaryToDatabaseSqlTypes(ps *pgx.PreparedStatement) {
for i, _ := range ps.FieldDescriptions {
intrinsic, _ := databaseSqlOIDs[ps.FieldDescriptions[i].DataType]
if !intrinsic {
ps.FieldDescriptions[i].FormatCode = pgx.TextFormatCode
func restrictBinary(in *pgtype.ConnInfo) (out *pgtype.ConnInfo) {
out = in.DeepCopy()
for oid, dt := range out.DataTypes() {
if textOID(oid) {
dt.FormatCode = pgx.TextFormatCode
out.RegisterDataType(dt)
}
}
return out
}
func textOID(oid pgtype.OID) bool {
for _, roid := range allowedBinaryOID {
if roid == oid {
return false
}
}
return true
}
type Stmt struct {

View File

@ -81,6 +81,63 @@ func closeStmt(t *testing.T, stmt *sql.Stmt) {
}
}
func TestSimpleQueryLifeCycle(t *testing.T) {
driverConfig := stdlib.DriverConfig{
ConnConfig: pgx.ConnConfig{PreferSimpleProtocol: true},
}
stdlib.RegisterDriverConfig(&driverConfig)
defer stdlib.UnregisterDriverConfig(&driverConfig)
db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test"))
if err != nil {
t.Fatalf("sql.Open failed: %v", err)
}
defer closeDB(t, db)
rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3)
if err != nil {
t.Fatalf("stmt.Query unexpectedly failed: %v", err)
}
rowCount := int64(0)
for rows.Next() {
rowCount++
var (
s string
n int64
)
if err := rows.Scan(&s, &n); err != nil {
t.Fatalf("rows.Scan unexpectedly failed: %v", err)
}
if s != "foo" {
t.Errorf(`Expected "foo", received "%v"`, s)
}
if n != rowCount {
t.Errorf("Expected %d, received %d", rowCount, n)
}
}
if err = rows.Err(); err != nil {
t.Fatalf("rows.Err unexpectedly is: %v", err)
}
if rowCount != 10 {
t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
}
err = rows.Close()
if err != nil {
t.Fatalf("rows.Close unexpectedly failed: %v", err)
}
ensureConnValid(t, db)
}
func TestNormalLifeCycle(t *testing.T) {
db := openDB(t)
defer closeDB(t, db)