From 76348773bdf5d028abb958ba1e61450a94693f14 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 10 Sep 2019 18:09:21 -0500 Subject: [PATCH] Make Conn.ConnInfo private --- conn.go | 23 +++++++++++---------- conn_test.go | 6 +++--- copy_from.go | 2 +- copy_from_test.go | 2 +- example_custom_type_test.go | 2 +- query_test.go | 40 ++----------------------------------- rows.go | 4 ++-- stdlib/sql.go | 2 +- values_test.go | 2 +- 9 files changed, 25 insertions(+), 58 deletions(-) diff --git a/conn.go b/conn.go index cde42a8f..d0450da8 100644 --- a/conn.go +++ b/conn.go @@ -62,7 +62,7 @@ type Conn struct { doneChan chan struct{} closedChan chan error - ConnInfo *pgtype.ConnInfo + connInfo *pgtype.ConnInfo wbuf []byte preallocatedRows []connRows @@ -171,7 +171,7 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { c = &Conn{ config: config, - ConnInfo: pgtype.NewConnInfo(), + connInfo: pgtype.NewConnInfo(), logLevel: config.LogLevel, logger: config.Logger, } @@ -408,6 +408,9 @@ func (c *Conn) PgConn() *pgconn.PgConn { return c.pgConn } // StatementCache returns the statement cache used for this connection. func (c *Conn) StatementCache() stmtcache.Cache { return c.stmtcache } +// ConnInfo returns the connection info used for this connection. +func (c *Conn) ConnInfo() *pgtype.ConnInfo { return c.connInfo } + // Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced // positionally from the sql string as $1, $2, etc. func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { @@ -499,14 +502,14 @@ func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, argu } for i := range args { - err = c.eqb.AppendParam(c.ConnInfo, sd.ParamOIDs[i], args[i]) + err = c.eqb.AppendParam(c.connInfo, sd.ParamOIDs[i], args[i]) if err != nil { return err } } for i := range sd.Fields { - c.eqb.AppendResultFormat(c.ConnInfo.ResultFormatCodeForOID(sd.Fields[i].DataTypeOID)) + c.eqb.AppendResultFormat(c.ConnInfo().ResultFormatCodeForOID(sd.Fields[i].DataTypeOID)) } return nil @@ -542,7 +545,7 @@ func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *con r.ctx = ctx r.logger = c - r.connInfo = c.ConnInfo + r.connInfo = c.connInfo r.startTime = time.Now() r.sql = sql r.args = args @@ -642,7 +645,7 @@ optionLoop: } for i := range args { - err = c.eqb.AppendParam(c.ConnInfo, sd.ParamOIDs[i], args[i]) + err = c.eqb.AppendParam(c.connInfo, sd.ParamOIDs[i], args[i]) if err != nil { rows.fatal(err) return rows, rows.err @@ -658,7 +661,7 @@ optionLoop: if resultFormats == nil { for i := range sd.Fields { - c.eqb.AppendResultFormat(c.ConnInfo.ResultFormatCodeForOID(sd.Fields[i].DataTypeOID)) + c.eqb.AppendResultFormat(c.ConnInfo().ResultFormatCodeForOID(sd.Fields[i].DataTypeOID)) } resultFormats = c.eqb.resultFormats @@ -732,14 +735,14 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { } for i := range args { - err = c.eqb.AppendParam(c.ConnInfo, sd.ParamOIDs[i], args[i]) + err = c.eqb.AppendParam(c.connInfo, sd.ParamOIDs[i], args[i]) if err != nil { return &batchResults{ctx: ctx, conn: c, err: err} } } for i := range sd.Fields { - c.eqb.AppendResultFormat(c.ConnInfo.ResultFormatCodeForOID(sd.Fields[i].DataTypeOID)) + c.eqb.AppendResultFormat(c.ConnInfo().ResultFormatCodeForOID(sd.Fields[i].DataTypeOID)) } if sd.Name == "" { @@ -770,7 +773,7 @@ func (c *Conn) sanitizeForSimpleQuery(sql string, args ...interface{}) (string, var err error valueArgs := make([]interface{}, len(args)) for i, a := range args { - valueArgs[i], err = convertSimpleArgument(c.ConnInfo, a) + valueArgs[i], err = convertSimpleArgument(c.connInfo, a) if err != nil { return "", err } diff --git a/conn_test.go b/conn_test.go index 673980ab..f2103775 100644 --- a/conn_test.go +++ b/conn_test.go @@ -803,11 +803,11 @@ func TestConnInitConnInfo(t *testing.T) { "text": pgtype.TextOID, } for name, oid := range nameOIDs { - dtByName, ok := conn.ConnInfo.DataTypeForName(name) + dtByName, ok := conn.ConnInfo().DataTypeForName(name) if !ok { t.Fatalf("Expected type named %v to be present", name) } - dtByOID, ok := conn.ConnInfo.DataTypeForOID(oid) + dtByOID, ok := conn.ConnInfo().DataTypeForOID(oid) if !ok { t.Fatalf("Expected type OID %v to be present", oid) } @@ -866,7 +866,7 @@ func TestDomainType(t *testing.T) { if err != nil { t.Fatalf("did not find uint64 OID, %v", err) } - conn.ConnInfo.RegisterDataType(pgtype.DataType{Value: &pgtype.Numeric{}, Name: "uint64", OID: uint64OID}) + conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: &pgtype.Numeric{}, Name: "uint64", OID: uint64OID}) // String is still an acceptable argument after registration err = conn.QueryRow(context.Background(), "select $1::uint64", "7").Scan(&n) diff --git a/copy_from.go b/copy_from.go index 81c13976..b924412d 100644 --- a/copy_from.go +++ b/copy_from.go @@ -130,7 +130,7 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (b buf = pgio.AppendInt16(buf, int16(len(ct.columnNames))) for i, val := range values { - buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, sd.Fields[i].DataTypeOID, val) + buf, err = encodePreparedStatementArgument(ct.conn.connInfo, buf, sd.Fields[i].DataTypeOID, val) if err != nil { return false, nil, err } diff --git a/copy_from_test.go b/copy_from_test.go index cb45debe..5b1612ec 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -133,7 +133,7 @@ func TestConnCopyFromJSON(t *testing.T) { defer closeConn(t, conn) for _, typeName := range []string{"json", "jsonb"} { - if _, ok := conn.ConnInfo.DataTypeForName(typeName); !ok { + if _, ok := conn.ConnInfo().DataTypeForName(typeName); !ok { return // No JSON/JSONB type -- must be running against old PostgreSQL } } diff --git a/example_custom_type_test.go b/example_custom_type_test.go index 1cc16401..c35e999a 100644 --- a/example_custom_type_test.go +++ b/example_custom_type_test.go @@ -81,7 +81,7 @@ func Example_CustomType() { } // Override registered handler for point - conn.ConnInfo.RegisterDataType(pgtype.DataType{ + conn.ConnInfo().RegisterDataType(pgtype.DataType{ Value: &Point{}, Name: "point", OID: 600, diff --git a/query_test.go b/query_test.go index 431135b5..1e075fa6 100644 --- a/query_test.go +++ b/query_test.go @@ -965,42 +965,6 @@ func TestQueryRowCoreByteSlice(t *testing.T) { } } -func TestQueryRowUnknownType(t *testing.T) { - t.Parallel() - - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) - - // Clear existing type mappings - conn.ConnInfo = pgtype.NewConnInfo() - conn.ConnInfo.RegisterDataType(pgtype.DataType{ - Value: &pgtype.GenericText{}, - Name: "point", - OID: 600, - }) - conn.ConnInfo.RegisterDataType(pgtype.DataType{ - Value: &pgtype.Int4{}, - Name: "int4", - OID: pgtype.Int4OID, - }) - - sql := "select $1::point" - expected := "(1,0)" - var actual string - - err := conn.QueryRow(context.Background(), sql, expected).Scan(&actual) - if err != nil { - t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql) - } - - if actual != expected { - t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, expected, actual, sql) - - } - - ensureConnValid(t, conn) -} - func TestQueryRowErrors(t *testing.T) { t.Parallel() @@ -1197,7 +1161,7 @@ func TestConnQueryDatabaseSQLDriverValuerWithBinaryPgTypeThatAcceptsSameType(t * conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - conn.ConnInfo.RegisterDataType(pgtype.DataType{ + conn.ConnInfo().RegisterDataType(pgtype.DataType{ Value: &satori.UUID{}, Name: "uuid", OID: 2950, @@ -1413,7 +1377,7 @@ func TestRowsFromResultReader(t *testing.T) { var sum, rowCount int32 - rows := pgx.RowsFromResultReader(conn.ConnInfo, resultReader) + rows := pgx.RowsFromResultReader(conn.ConnInfo(), resultReader) defer rows.Close() for rows.Next() { diff --git a/rows.go b/rows.go index 379bdc7a..7389c56b 100644 --- a/rows.go +++ b/rows.go @@ -188,7 +188,7 @@ func (rows *connRows) Scan(dest ...interface{}) error { continue } - err := rows.connInfo.Scan(uint32(fd.DataTypeOID), fd.Format, buf, d) + err := rows.connInfo.Scan(fd.DataTypeOID, fd.Format, buf, d) if err != nil { rows.fatal(scanArgError{col: i, err: err}) return err @@ -214,7 +214,7 @@ func (rows *connRows) Values() ([]interface{}, error) { continue } - if dt, ok := rows.connInfo.DataTypeForOID(uint32(fd.DataTypeOID)); ok { + if dt, ok := rows.connInfo.DataTypeForOID(fd.DataTypeOID); ok { value := reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(pgtype.Value) switch fd.Format { diff --git a/stdlib/sql.go b/stdlib/sql.go index f4c8fb8c..1218d5f2 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -310,7 +310,7 @@ func (r *Rows) Columns() []string { // ColumnTypeDatabaseTypeName return the database system type name. func (r *Rows) ColumnTypeDatabaseTypeName(index int) string { - if dt, ok := r.conn.conn.ConnInfo.DataTypeForOID(r.rows.FieldDescriptions()[index].DataTypeOID); ok { + if dt, ok := r.conn.conn.ConnInfo().DataTypeForOID(r.rows.FieldDescriptions()[index].DataTypeOID); ok { return strings.ToUpper(dt.Name) } diff --git a/values_test.go b/values_test.go index 48c36628..282045d9 100644 --- a/values_test.go +++ b/values_test.go @@ -80,7 +80,7 @@ func TestJSONAndJSONBTranscode(t *testing.T) { defer closeConn(t, conn) for _, typename := range []string{"json", "jsonb"} { - if _, ok := conn.ConnInfo.DataTypeForName(typename); !ok { + if _, ok := conn.ConnInfo().DataTypeForName(typename); !ok { continue // No JSON/JSONB type -- must be running against old PostgreSQL }