From 43dcd47a92e74baef888010c3d78c46e8fe028c8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Jul 2014 18:23:19 -0500 Subject: [PATCH] Move to Scan interface Remove SelectValue --- bench_test.go | 42 ----- conn.go | 222 ++++++++++++-------------- conn_pool.go | 16 +- conn_pool_test.go | 44 ++++-- conn_test.go | 258 +++++++++++-------------------- example_value_transcoder_test.go | 90 ----------- examples/url_shortener/main.go | 5 +- helper_test.go | 8 - stdlib/sql.go | 7 +- value_transcoder_test.go | 235 ++++++++++++++-------------- 10 files changed, 351 insertions(+), 576 deletions(-) delete mode 100644 example_value_transcoder_test.go diff --git a/bench_test.go b/bench_test.go index bf1e4627..d2e0f62c 100644 --- a/bench_test.go +++ b/bench_test.go @@ -2,51 +2,9 @@ package pgx_test import ( "github.com/jackc/pgx" - "math/rand" "testing" ) -func createNarrowTestData(b *testing.B, conn *pgx.Conn) { - mustExec(b, conn, ` - drop table if exists narrow; - - create table narrow( - id serial primary key, - a int not null, - b int not null, - c int not null, - d int not null - ); - - insert into narrow(a, b, c, d) - select (random()*1000000)::int, (random()*1000000)::int, (random()*1000000)::int, (random()*1000000)::int - from generate_series(1, 10000); - - analyze narrow; - `) - - mustPrepare(b, conn, "getNarrowById", "select * from narrow where id=$1") - mustPrepare(b, conn, "getMultipleNarrowById", "select * from narrow where id between $1 and $2") - mustPrepare(b, conn, "getMultipleNarrowByIdAsJSON", "select json_agg(row_to_json(narrow)) from narrow where id between $1 and $2") -} - -func BenchmarkSelectValuePreparedNarrow(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - createNarrowTestData(b, conn) - - // Get random ids outside of timing - ids := make([]int32, b.N) - for i := 0; i < b.N; i++ { - ids[i] = 1 + rand.Int31n(9999) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - mustSelectValue(b, conn, "getMultipleNarrowByIdAsJSON", ids[i], ids[i]+10) - } -} - func BenchmarkConnPool(b *testing.B) { config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: 5} pool, err := pgx.NewConnPool(config) diff --git a/conn.go b/conn.go index ffd011f7..78a810a2 100644 --- a/conn.go +++ b/conn.go @@ -271,46 +271,6 @@ func ParseURI(uri string) (ConnConfig, error) { return cp, nil } -// SelectValue executes sql and returns a single value. sql can be either a prepared -// statement name or an SQL string. arguments will be sanitized before being -// interpolated into sql strings. arguments should be referenced positionally from -// the sql string as $1, $2, etc. -// -// Returns a UnexpectedColumnCountError if exactly one column is not found -// Returns a NotSingleRowError if exactly one row is not found -func (c *Conn) SelectValue(sql string, arguments ...interface{}) (interface{}, error) { - startTime := time.Now() - - var numRowsFound int64 - var v interface{} - - qr, _ := c.Query(sql, arguments...) - defer qr.Close() - - for qr.NextRow() { - if len(qr.fields) != 1 { - qr.Close() - return nil, UnexpectedColumnCountError{ExpectedCount: 1, ActualCount: int16(len(qr.fields))} - } - - numRowsFound++ - var rr RowReader - v = rr.ReadValue(qr) - } - if qr.Err() != nil { - return nil, qr.Err() - } - - if numRowsFound != 1 { - return nil, NotSingleRowError{RowCount: numRowsFound} - } - - endTime := time.Now() - c.logger.Info("SelectValue", "sql", sql, "args", arguments, "rowsFound", numRowsFound, "time", endTime.Sub(startTime)) - - return v, nil -} - // Prepare creates a prepared statement with name and sql. sql can contain placeholders // for bound parameters. These placeholders are referenced positional as $1, $2, etc. func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { @@ -459,99 +419,26 @@ func (c *Conn) CauseOfDeath() error { return c.causeOfDeath } -type RowReader struct{} +type Row QueryResult -// TODO - Read*... +func (r *Row) Scan(dest ...interface{}) (err error) { + qr := (*QueryResult)(r) -func (rr *RowReader) ReadInt32(qr *QueryResult) int32 { - fd, size, ok := qr.NextColumn() - if !ok { - return 0 + if qr.Err() != nil { + return qr.Err() } - if size == -1 { - qr.Fatal(errors.New("Unexpected null")) - return 0 - } - - return decodeInt4(qr, fd, size) -} - -func (rr *RowReader) ReadInt64(qr *QueryResult) int64 { - fd, size, ok := qr.NextColumn() - if !ok { - return 0 - } - - if size == -1 { - qr.Fatal(errors.New("Unexpected null")) - return 0 - } - - return decodeInt8(qr, fd, size) -} - -func (rr *RowReader) ReadTime(qr *QueryResult) time.Time { - var zeroTime time.Time - - fd, size, ok := qr.NextColumn() - if !ok { - return zeroTime - } - - if size == -1 { - qr.Fatal(errors.New("Unexpected null")) - return zeroTime - } - - return decodeTimestampTz(qr, fd, size) -} - -func (rr *RowReader) ReadDate(qr *QueryResult) time.Time { - var zeroTime time.Time - - fd, size, ok := qr.NextColumn() - if !ok { - return zeroTime - } - - if size == -1 { - qr.Fatal(errors.New("Unexpected null")) - return zeroTime - } - - return decodeDate(qr, fd, size) -} - -func (rr *RowReader) ReadString(qr *QueryResult) string { - fd, size, ok := qr.NextColumn() - if !ok { - return "" - } - - if size == -1 { - qr.Fatal(errors.New("Unexpected null")) - return "" - } - - return decodeText(qr, fd, size) -} - -func (rr *RowReader) ReadValue(qr *QueryResult) interface{} { - fd, size, ok := qr.NextColumn() - if !ok { - return nil - } - - if size > -1 { - if vt, present := ValueTranscoders[fd.DataType]; present && vt.Decode != nil { - return vt.Decode(qr, fd, size) + if !qr.NextRow() { + if qr.Err() == nil { + return errors.New("No rows") } else { - return decodeText(qr, fd, size) + return qr.Err() } - } else { - return nil } + + qr.Scan(dest...) + qr.Close() + return qr.Err() } type QueryResult struct { @@ -682,6 +569,84 @@ func (qr *QueryResult) NextColumn() (*FieldDescription, int32, bool) { return fd, size, true } +func (qr *QueryResult) Scan(dest ...interface{}) (err error) { + if len(qr.fields) != len(dest) { + err = errors.New("Scan received wrong number of arguments") + qr.Fatal(err) + return err + } + + for _, d := range dest { + fd, size, _ := qr.NextColumn() + switch d := d.(type) { + case *bool: + *d = decodeBool(qr, fd, size) + case *[]byte: + *d = decodeBytea(qr, fd, size) + case *int64: + *d = decodeInt8(qr, fd, size) + case *int16: + *d = decodeInt2(qr, fd, size) + case *int32: + *d = decodeInt4(qr, fd, size) + case *string: + *d = decodeText(qr, fd, size) + case *float32: + *d = decodeFloat4(qr, fd, size) + case *float64: + *d = decodeFloat8(qr, fd, size) + case *time.Time: + if fd.DataType == DateOid { + *d = decodeDate(qr, fd, size) + } else { + *d = decodeTimestampTz(qr, fd, size) + } + } + } + + return nil +} + +func (qr *QueryResult) ReadValue() (v interface{}, err error) { + fd, size, _ := qr.NextColumn() + if qr.Err() != nil { + return nil, qr.Err() + } + + switch fd.DataType { + case BoolOid: + return decodeBool(qr, fd, size), qr.Err() + case ByteaOid: + return decodeBytea(qr, fd, size), qr.Err() + case Int8Oid: + return decodeInt8(qr, fd, size), qr.Err() + case Int2Oid: + return decodeInt2(qr, fd, size), qr.Err() + case Int4Oid: + return decodeInt4(qr, fd, size), qr.Err() + case VarcharOid, TextOid: + return decodeText(qr, fd, size), qr.Err() + case Float4Oid: + return decodeFloat4(qr, fd, size), qr.Err() + case Float8Oid: + return decodeFloat8(qr, fd, size), qr.Err() + case DateOid: + return decodeDate(qr, fd, size), qr.Err() + case TimestampTzOid: + return decodeTimestampTz(qr, fd, size), qr.Err() + } + + // if it is not an intrinsic type then return the text + switch fd.FormatCode { + case TextFormatCode: + return qr.MsgReader().ReadString(size), qr.Err() + // TODO + //case BinaryFormatCode: + default: + return nil, errors.New("Unknown format code") + } +} + // TODO - document func (c *Conn) Query(sql string, args ...interface{}) (*QueryResult, error) { c.qr = QueryResult{conn: c} @@ -728,6 +693,11 @@ func (c *Conn) Query(sql string, args ...interface{}) (*QueryResult, error) { } } +func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { + qr, _ := c.Query(sql, args...) + return (*Row)(qr) +} + func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) { if ps, present := c.preparedStatements[sql]; present { return c.sendPreparedQuery(ps, arguments...) diff --git a/conn_pool.go b/conn_pool.go index 3bac3de9..a19fed7f 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -165,17 +165,6 @@ func (p *ConnPool) createConnection() (c *Conn, err error) { return } -// SelectValue acquires a connection, delegates the call to that connection, and releases the connection -func (p *ConnPool) SelectValue(sql string, arguments ...interface{}) (v interface{}, err error) { - var c *Conn - if c, err = p.Acquire(); err != nil { - return - } - defer p.Release(c) - - return c.SelectValue(sql, arguments...) -} - // Exec acquires a connection, delegates the call to that connection, and releases the connection func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) { var c *Conn @@ -204,6 +193,11 @@ func (p *ConnPool) Query(sql string, args ...interface{}) (*QueryResult, error) return qr, nil } +func (p *ConnPool) QueryRow(sql string, args ...interface{}) *Row { + qr, _ := p.Query(sql, args...) + return (*Row)(qr) +} + // Transaction acquires a connection, delegates the call to that connection, // and releases the connection. The call signature differs slightly from the // underlying Transaction in that the callback function accepts a *Conn diff --git a/conn_pool_test.go b/conn_pool_test.go index 92b66450..1d364253 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -138,8 +138,8 @@ func TestPoolAcquireAndReleaseCycle(t *testing.T) { allConnections = acquireAll() for _, c := range allConnections { - v := mustSelectValue(t, c, "select counter from t") - n := v.(int32) + var n int32 + c.QueryRow("select counter from t").Scan(&n) if n == 0 { t.Error("A connection was never used") } @@ -209,7 +209,8 @@ func TestPoolAcquireAndReleaseCycleAutoConnect(t *testing.T) { if err != nil { t.Fatalf("Unable to Acquire: %v", err) } - c.SelectValue("select 1") + qr, _ := c.Query("select 1") + qr.Close() pool.Release(c) } @@ -272,7 +273,9 @@ func TestPoolReleaseDiscardsDeadConnections(t *testing.T) { } // do something with the connection so it knows it's dead - if _, err = c1.SelectValue("select 1"); err == nil { + qr, _ := c1.Query("select 1") + qr.Close() + if qr.Err() == nil { t.Fatal("Expected error but none occurred") } @@ -318,15 +321,22 @@ func TestPoolTransaction(t *testing.T) { } committed, err = pool.Transaction(func(conn *pgx.Conn) bool { - n := mustSelectValue(t, conn, "select count(*) from foo") - if n.(int64) != 0 { + var n int64 + err := conn.QueryRow("select count(*) from foo").Scan(&n) + if err != nil { + t.Fatalf("QueryRow.Scan failed: %v", err) + } + if n != 0 { t.Fatalf("Did not receive expected value: %v", n) } mustExec(t, conn, "insert into foo(id) values(default)") - n = mustSelectValue(t, conn, "select count(*) from foo") - if n.(int64) != 1 { + err = conn.QueryRow("select count(*) from foo").Scan(&n) + if err != nil { + t.Fatalf("QueryRow.Scan failed: %v", err) + } + if n != 1 { t.Fatalf("Did not receive expected value: %v", n) } @@ -340,8 +350,12 @@ func TestPoolTransaction(t *testing.T) { } committed, err = pool.Transaction(func(conn *pgx.Conn) bool { - n := mustSelectValue(t, conn, "select count(*) from foo") - if n.(int64) != 0 { + var n int64 + err := conn.QueryRow("select count(*) from foo").Scan(&n) + if err != nil { + t.Fatalf("QueryRow.Scan failed: %v", err) + } + if n != 0 { t.Fatalf("Did not receive expected value: %v", n) } return true @@ -362,7 +376,10 @@ func TestPoolTransactionIso(t *testing.T) { defer pool.Close() committed, err := pool.TransactionIso("serializable", func(conn *pgx.Conn) bool { - if level := mustSelectValue(t, conn, "select current_setting('transaction_isolation')"); level != "serializable" { + var level string + conn.QueryRow("select current_setting('transaction_isolation')").Scan(&level) + + if level != "serializable" { t.Errorf("Expected to be in isolation level %v but was %v", "serializable", level) } return true @@ -394,8 +411,9 @@ func TestConnPoolQuery(t *testing.T) { } for qr.NextRow() { - var rr pgx.RowReader - sum += rr.ReadInt32(qr) + var n int32 + qr.Scan(&n) + sum += n rowCount++ } diff --git a/conn_test.go b/conn_test.go index 731f5b35..bddf49c7 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1,8 +1,6 @@ package pgx_test import ( - "bytes" - "fmt" "github.com/jackc/pgx" "strings" "sync" @@ -30,12 +28,21 @@ func TestConnect(t *testing.T) { t.Error("Backend secret key not stored") } - currentDB, err := conn.SelectValue("select current_database()") - if err != nil || currentDB != defaultConnConfig.Database { + var currentDB string + err = conn.QueryRow("select current_database()").Scan(¤tDB) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if currentDB != defaultConnConfig.Database { t.Errorf("Did not connect to specified database (%v)", defaultConnConfig.Database) } - if user := mustSelectValue(t, conn, "select current_user"); user != defaultConnConfig.User { + var user string + err = conn.QueryRow("select current_user").Scan(&user) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if user != defaultConnConfig.User { t.Errorf("Did not connect as specified user (%v)", defaultConnConfig.User) } @@ -272,8 +279,10 @@ func TestExecFailure(t *testing.T) { t.Fatal("Expected SQL syntax error") } - if _, err := conn.SelectValue("select 1"); err != nil { - t.Fatalf("Exec failure appears to have broken connection: %v", err) + qr, _ := conn.Query("select 1") + qr.Close() + if qr.Err() != nil { + t.Fatalf("Exec failure appears to have broken connection: %v", qr.Err()) } } @@ -292,8 +301,9 @@ func TestConnQuery(t *testing.T) { defer rows.Close() for rows.NextRow() { - var rr pgx.RowReader - sum += rr.ReadInt32(rows) + var n int32 + rows.Scan(&n) + sum += n rowCount++ } @@ -320,8 +330,9 @@ func ensureConnValid(t *testing.T, conn *pgx.Conn) { defer qr.Close() for qr.NextRow() { - var rr pgx.RowReader - sum += rr.ReadInt32(qr) + var n int32 + qr.Scan(&n) + sum += n rowCount++ } @@ -364,8 +375,9 @@ func TestConnQueryCloseEarly(t *testing.T) { t.Fatal("qr.NextRow terminated early") } - var rr pgx.RowReader - if n := rr.ReadInt32(qr); n != 1 { + var n int32 + qr.Scan(&n) + if n != 1 { t.Fatalf("Expected 1 from first row, but got %v", n) } @@ -390,31 +402,8 @@ func TestConnQueryReadWrongTypeError(t *testing.T) { rowsRead := 0 for qr.NextRow() { - var rr pgx.RowReader - rr.ReadDate(qr) - rowsRead++ - } - - if rowsRead != 1 { - t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead) - } - - if qr.Err() == nil { - t.Fatal("Expected QueryResult to have an error after an improper read but it didn't") - } - - // Read too many values - qr, err = conn.Query("select generate_series(1,$1)", 10) - if err != nil { - t.Fatalf("conn.Query failed: ", err) - } - - rowsRead = 0 - - for qr.NextRow() { - var rr pgx.RowReader - rr.ReadInt32(qr) - rr.ReadInt32(qr) + var t time.Time + qr.Scan(&t) rowsRead++ } @@ -445,9 +434,8 @@ func TestConnQueryReadTooManyValues(t *testing.T) { rowsRead := 0 for qr.NextRow() { - var rr pgx.RowReader - rr.ReadInt32(qr) - rr.ReadInt32(qr) + var n, m int32 + qr.Scan(&n, &m) rowsRead++ } @@ -462,126 +450,26 @@ func TestConnQueryReadTooManyValues(t *testing.T) { ensureConnValid(t, conn) } -func TestConnectionSelectValue(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - test := func(sql string, expected interface{}, arguments ...interface{}) { - v, err := conn.SelectValue(sql, arguments...) - if err != nil { - t.Errorf("%v while running %v", err, sql) - } else { - if v != expected { - t.Errorf("Expected: %#v Received: %#v", expected, v) - } - } - } - - fmt.Println("Starting test") - test("select $1", "foo", "foo") - test("select 'foo'", "foo") - test("select true", true) - test("select false", false) - test("select 1::int2", int16(1)) - test("select 1::int4", int32(1)) - test("select 1::int8", int64(1)) - test("select 1.23::float4", float32(1.23)) - test("select 1.23::float8", float64(1.23)) - - _, err := conn.SelectValue("select 'Jack' as name where 1=2") - if _, ok := err.(pgx.NotSingleRowError); !ok { - t.Error("No matching row should have returned NoRowsFoundError") - } - - _, err = conn.SelectValue("select * from (values ('Matthew'), ('Mark')) t") - if _, ok := err.(pgx.NotSingleRowError); !ok { - t.Error("Multiple matching rows should have returned NotSingleRowError") - } - - _, err = conn.SelectValue("select 'Matthew', 'Mark'") - if _, ok := err.(pgx.UnexpectedColumnCountError); !ok { - t.Error("Multiple columns should have returned UnexpectedColumnCountError") - } -} - func TestPrepare(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - testTranscode := func(sql string, value interface{}) { - if _, err := conn.Prepare("testTranscode", sql); err != nil { - t.Errorf("Unable to prepare statement: %v", err) - return - } - defer func() { - err := conn.Deallocate("testTranscode") - if err != nil { - t.Errorf("Deallocate failed: %v", err) - } - }() - - result, err := conn.SelectValue("testTranscode", value) - if err != nil { - t.Errorf("%v while running %v", err, "testTranscode") - } else { - if result != value { - t.Errorf("Expected: %#v Received: %#v", value, result) - } - } - } - - // Test parameter encoding and decoding for simple supported data types - testTranscode("select $1::varchar", "foo") - testTranscode("select $1::text", "foo") - testTranscode("select $1::int2", int16(1)) - testTranscode("select $1::int4", int32(1)) - testTranscode("select $1::int8", int64(1)) - testTranscode("select $1::float4", float32(1.23)) - testTranscode("select $1::float8", float64(1.23)) - testTranscode("select $1::boolean", true) - - // Ensure that unknown types are just treated as strings - testTranscode("select $1::point", "(0,0)") - - if _, err := conn.Prepare("testByteSliceTranscode", "select $1::bytea"); err != nil { + _, err := conn.Prepare("test", "select $1::varchar") + if err != nil { t.Errorf("Unable to prepare statement: %v", err) return } - defer func() { - err := conn.Deallocate("testByteSliceTranscode") - if err != nil { - t.Errorf("Deallocate failed: %v", err) - } - }() - bytea := make([]byte, 4) - bytea[0] = 0 // 0x00 - bytea[1] = 15 // 0x0F - bytea[2] = 255 // 0xFF - bytea[3] = 17 // 0x11 - - if sql, err := conn.SanitizeSql("select $1", bytea); err != nil { - t.Errorf("Error sanitizing []byte: %v", err) - } else if sql != `select E'\\x000fff11'` { - t.Error("Failed to sanitize []byte") - } - - result, err := conn.SelectValue("testByteSliceTranscode", bytea) + var s string + err = conn.QueryRow("test", "hello").Scan(&s) if err != nil { - t.Errorf("%v while running %v", err, "testByteSliceTranscode") - } else { - if bytes.Compare(result.([]byte), bytea) != 0 { - t.Errorf("Expected: %#v Received: %#v", bytea, result) - } + t.Errorf("Executing prepared statement failed: %v", err) } - mustExec(t, conn, "create temporary table foo(id serial)") - if _, err = conn.Prepare("deleteFoo", "delete from foo"); err != nil { - t.Fatalf("Unable to prepare delete: %v", err) + if s != "hello" { + t.Errorf("Prepared statement did not return expected value: %v", s) } } @@ -595,9 +483,7 @@ func TestPrepareFailure(t *testing.T) { t.Fatal("Prepare should have failed with syntax error") } - if _, err := conn.SelectValue("select 1"); err != nil { - t.Fatalf("Prepare failure appears to have broken connection: %v", err) - } + ensureConnValid(t, conn) } func TestTransaction(t *testing.T) { @@ -629,9 +515,12 @@ func TestTransaction(t *testing.T) { t.Fatal("Transaction was not committed") } - var n interface{} - n = mustSelectValue(t, conn, "select count(*) from foo") - if n.(int64) != 1 { + var n int64 + err = conn.QueryRow("select count(*) from foo").Scan(&n) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } + if n != 1 { t.Fatalf("Did not receive correct number of rows: %v", n) } @@ -648,8 +537,11 @@ func TestTransaction(t *testing.T) { if committed { t.Fatal("Transaction should not have been committed") } - n = mustSelectValue(t, conn, "select count(*) from foo") - if n.(int64) != 0 { + err = conn.QueryRow("select count(*) from foo").Scan(&n) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } + if n != 0 { t.Fatalf("Did not receive correct number of rows: %v", n) } @@ -667,8 +559,11 @@ func TestTransaction(t *testing.T) { if committed { t.Fatal("Transaction was committed when it shouldn't have been") } - n = mustSelectValue(t, conn, "select count(*) from foo") - if n.(int64) != 0 { + err = conn.QueryRow("select count(*) from foo").Scan(&n) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } + if n != 0 { t.Fatalf("Did not receive correct number of rows: %v", n) } @@ -685,8 +580,11 @@ func TestTransaction(t *testing.T) { t.Fatal("Transaction was committed when it should have failed") } - n = mustSelectValue(t, conn, "select count(*) from foo") - if n.(int64) != 0 { + err = conn.QueryRow("select count(*) from foo").Scan(&n) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } + if n != 0 { t.Fatalf("Did not receive correct number of rows: %v", n) } @@ -701,8 +599,11 @@ func TestTransaction(t *testing.T) { panic("stop!") }) - n = mustSelectValue(t, conn, "select count(*) from foo") - if n.(int64) != 0 { + err = conn.QueryRow("select count(*) from foo").Scan(&n) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } + if n != 0 { t.Fatalf("Did not receive correct number of rows: %v", n) } }() @@ -717,7 +618,9 @@ func TestTransactionIso(t *testing.T) { isoLevels := []string{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted} for _, iso := range isoLevels { _, err := conn.TransactionIso(iso, func() bool { - if level := mustSelectValue(t, conn, "select current_setting('transaction_isolation')"); level != iso { + var level string + conn.QueryRow("select current_setting('transaction_isolation')").Scan(&level) + if level != iso { t.Errorf("Expected to be in isolation level %v but was %v", iso, level) } return true @@ -754,7 +657,11 @@ func TestListenNotify(t *testing.T) { // when notification has already been read during previous query mustExec(t, notifier, "notify chat") - mustSelectValue(t, listener, "select 1") + qr, _ := listener.Query("select 1") + qr.Close() + if qr.Err() != nil { + t.Fatalf("Unexpected error on Query: %v", qr.Err()) + } notification, err = listener.WaitForNotification(0) if err != nil { t.Fatalf("Unexpected error on WaitForNotification: %v", err) @@ -793,9 +700,11 @@ func TestFatalRxError(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - _, err := conn.SelectValue("select 1, pg_sleep(10)") - if err == nil { - t.Fatal("Expected error but none occurred") + var n int32 + var s string + err := conn.QueryRow("select 1::int4, pg_sleep(10)::varchar").Scan(&n, &s) + if err, ok := err.(pgx.PgError); !ok || err.Severity != "FATAL" { + t.Fatalf("Expected QueryRow Scan to return fatal PgError, but instead received %v", err) } }() @@ -833,7 +742,7 @@ func TestFatalTxError(t *testing.T) { t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) } - _, err = conn.SelectValue("select 1") + _, err = conn.Query("select 1") if err == nil { t.Fatal("Expected error but none occurred") } @@ -867,3 +776,16 @@ func TestCommandTag(t *testing.T) { } } } + +func TestQueryRowError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var n int32 + err := conn.QueryRow("SYNTAX ERROR").Scan(&n) + if _, ok := err.(pgx.PgError); !ok { + t.Fatalf("Expected to receive PgError, but instead received: %v", err) + } +} diff --git a/example_value_transcoder_test.go b/example_value_transcoder_test.go deleted file mode 100644 index fe5cb632..00000000 --- a/example_value_transcoder_test.go +++ /dev/null @@ -1,90 +0,0 @@ -package pgx_test - -import ( - "fmt" - "github.com/jackc/pgx" - "regexp" - "strconv" -) - -const ( - pointOid = 600 -) - -var pointRegexp *regexp.Regexp = regexp.MustCompile(`^\((.*),(.*)\)$`) - -type Point struct { - x float64 - y float64 -} - -func (p Point) String() string { - return fmt.Sprintf("%v, %v", p.x, p.y) -} - -func Example_customValueTranscoder() { - pgx.ValueTranscoders[pointOid] = &pgx.ValueTranscoder{ - Decode: func(qr *pgx.QueryResult, fd *pgx.FieldDescription, size int32) interface{} { - return decodePoint(qr, fd, size) - }, - EncodeTo: encodePoint} - - conn, err := pgx.Connect(*defaultConnConfig) - if err != nil { - fmt.Printf("Unable to establish connection: %v", err) - return - } - - v, _ := conn.SelectValue("select point(1.5,2.5)") - fmt.Println(v) - // Output: - // 1.5, 2.5 -} - -func decodePoint(qr *pgx.QueryResult, fd *pgx.FieldDescription, size int32) Point { - var p Point - - if fd.DataType != pointOid { - qr.Fatal(pgx.ProtocolError(fmt.Sprintf("Tried to read point but received: %v", fd.DataType))) - return p - } - - switch fd.FormatCode { - case pgx.TextFormatCode: - s := qr.MsgReader().ReadString(size) - match := pointRegexp.FindStringSubmatch(s) - if match == nil { - qr.Fatal(pgx.ProtocolError(fmt.Sprintf("Received invalid point: %v", s))) - return p - } - - var err error - p.x, err = strconv.ParseFloat(match[1], 64) - if err != nil { - qr.Fatal(pgx.ProtocolError(fmt.Sprintf("Received invalid point: %v", s))) - return p - } - p.y, err = strconv.ParseFloat(match[2], 64) - if err != nil { - qr.Fatal(pgx.ProtocolError(fmt.Sprintf("Received invalid point: %v", s))) - return p - } - return p - default: - qr.Fatal(pgx.ProtocolError(fmt.Sprintf("Unknown field description format code: %v", fd.FormatCode))) - return p - } -} - -func encodePoint(w *pgx.WriteBuf, value interface{}) error { - p, ok := value.(Point) - if !ok { - return fmt.Errorf("Expected Point, received %T", value) - } - - s := fmt.Sprintf("point(%v,%v)", p.x, p.y) - w.WriteInt32(int32(len(s))) - w.WriteBytes([]byte(s)) - - return nil -} diff --git a/examples/url_shortener/main.go b/examples/url_shortener/main.go index 0daf92f5..d2a1a20b 100644 --- a/examples/url_shortener/main.go +++ b/examples/url_shortener/main.go @@ -44,8 +44,9 @@ func afterConnect(conn *pgx.Conn) (err error) { } func getUrlHandler(w http.ResponseWriter, req *http.Request) { - if url, err := pool.SelectValue("getUrl", req.URL.Path); err == nil { - http.Redirect(w, req, url.(string), http.StatusSeeOther) + var url string + if err := pool.QueryRow("getUrl", req.URL.Path).Scan(&url); err == nil { + http.Redirect(w, req, url, http.StatusSeeOther) } else if _, ok := err.(pgx.NotSingleRowError); ok { http.NotFound(w, req) } else { diff --git a/helper_test.go b/helper_test.go index f51e1240..7eb5062a 100644 --- a/helper_test.go +++ b/helper_test.go @@ -33,11 +33,3 @@ func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{} } return } - -func mustSelectValue(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (value interface{}) { - var err error - if value, err = conn.SelectValue(sql, arguments...); err != nil { - t.Fatalf("SelectValue unexpectedly failed with %v: %v", sql, err) - } - return -} diff --git a/stdlib/sql.go b/stdlib/sql.go index 58f99d58..8c9025a5 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -191,9 +191,12 @@ func (r *Rows) Next(dest []driver.Value) error { } } - var rr pgx.RowReader for i, _ := range r.qr.FieldDescriptions() { - dest[i] = driver.Value(rr.ReadValue(r.qr)) + v, err := r.qr.ReadValue() + if err != nil { + return err + } + dest[i] = driver.Value(v) } return nil diff --git a/value_transcoder_test.go b/value_transcoder_test.go index 229839fd..b86139c4 100644 --- a/value_transcoder_test.go +++ b/value_transcoder_test.go @@ -5,7 +5,7 @@ import ( "time" ) -func TestTranscodeError(t *testing.T) { +func TestEncodeError(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -18,7 +18,7 @@ func TestTranscodeError(t *testing.T) { } }() - _, err := conn.SelectValue("testTranscode", "wrong") + _, err := conn.Query("testTranscode", "wrong") switch { case err == nil: t.Error("Expected transcode error to return error, but it didn't") @@ -29,31 +29,32 @@ func TestTranscodeError(t *testing.T) { } } +// TODO func TestNilTranscode(t *testing.T) { - t.Parallel() + // t.Parallel() - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) + // conn := mustConnect(t, *defaultConnConfig) + // defer closeConn(t, conn) - var inputNil interface{} - inputNil = nil + // var inputNil interface{} + // inputNil = nil - result := mustSelectValue(t, conn, "select $1::integer", inputNil) - if result != nil { - t.Errorf("Did not transcode nil successfully for normal query: %v", result) - } + // result := mustSelectValue(t, conn, "select $1::integer", inputNil) + // if result != nil { + // t.Errorf("Did not transcode nil successfully for normal query: %v", result) + // } - mustPrepare(t, conn, "testTranscode", "select $1::integer") - defer func() { - if err := conn.Deallocate("testTranscode"); err != nil { - t.Fatalf("Unable to deallocate prepared statement: %v", err) - } - }() + // mustPrepare(t, conn, "testTranscode", "select $1::integer") + // defer func() { + // if err := conn.Deallocate("testTranscode"); err != nil { + // t.Fatalf("Unable to deallocate prepared statement: %v", err) + // } + // }() - result = mustSelectValue(t, conn, "testTranscode", inputNil) - if result != nil { - t.Errorf("Did not transcode nil successfully for prepared query: %v", result) - } + // result = mustSelectValue(t, conn, "testTranscode", inputNil) + // if result != nil { + // t.Errorf("Did not transcode nil successfully for prepared query: %v", result) + // } } func TestDateTranscode(t *testing.T) { @@ -80,21 +81,24 @@ func TestDateTranscode(t *testing.T) { } for _, actualDate := range dates { - var v interface{} var d time.Time // Test text format - v = mustSelectValue(t, conn, "select $1::date", actualDate) - d = v.(time.Time) + err := conn.QueryRow("select $1::date", actualDate).Scan(&d) + if err != nil { + t.Fatalf("Unexpected failure on QueryRow Scan: %v", err) + } if !actualDate.Equal(d) { - t.Errorf("Did not transcode date successfully: %v is not %v", v, actualDate) + t.Errorf("Did not transcode date successfully: %v is not %v", d, actualDate) } // Test binary format - v = mustSelectValue(t, conn, "testTranscode", actualDate) - d = v.(time.Time) + err = conn.QueryRow("testTranscode", actualDate).Scan(&d) + if err != nil { + t.Fatalf("Unexpected failure on QueryRow Scan: %v", err) + } if !actualDate.Equal(d) { - t.Errorf("Did not transcode date successfully: %v is not %v", v, actualDate) + t.Errorf("Did not transcode date successfully: %v is not %v", d, actualDate) } } } @@ -107,11 +111,12 @@ func TestTimestampTzTranscode(t *testing.T) { inputTime := time.Date(2013, 1, 2, 3, 4, 5, 6000, time.Local) - var v interface{} var outputTime time.Time - v = mustSelectValue(t, conn, "select $1::timestamptz", inputTime) - outputTime = v.(time.Time) + err := conn.QueryRow("select $1::timestamptz", inputTime).Scan(&outputTime) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } if !inputTime.Equal(outputTime) { t.Errorf("Did not transcode time successfully: %v is not %v", outputTime, inputTime) } @@ -123,111 +128,113 @@ func TestTimestampTzTranscode(t *testing.T) { } }() - v = mustSelectValue(t, conn, "testTranscode", inputTime) - outputTime = v.(time.Time) + err = conn.QueryRow("testTranscode", inputTime).Scan(&outputTime) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } if !inputTime.Equal(outputTime) { t.Errorf("Did not transcode time successfully: %v is not %v", outputTime, inputTime) } } -func TestInt2SliceTranscode(t *testing.T) { - t.Parallel() +// func TestInt2SliceTranscode(t *testing.T) { +// t.Parallel() - testEqual := func(a, b []int16) { - if len(a) != len(b) { - t.Errorf("Did not transcode []int16 successfully: %v is not %v", a, b) - } - for i := range a { - if a[i] != b[i] { - t.Errorf("Did not transcode []int16 successfully: %v is not %v", a, b) - } - } - } +// testEqual := func(a, b []int16) { +// if len(a) != len(b) { +// t.Errorf("Did not transcode []int16 successfully: %v is not %v", a, b) +// } +// for i := range a { +// if a[i] != b[i] { +// t.Errorf("Did not transcode []int16 successfully: %v is not %v", a, b) +// } +// } +// } - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) +// conn := mustConnect(t, *defaultConnConfig) +// defer closeConn(t, conn) - inputNumbers := []int16{1, 2, 3, 4, 5, 6, 7, 8} - var outputNumbers []int16 +// inputNumbers := []int16{1, 2, 3, 4, 5, 6, 7, 8} +// var outputNumbers []int16 - outputNumbers = mustSelectValue(t, conn, "select $1::int2[]", inputNumbers).([]int16) - testEqual(inputNumbers, outputNumbers) +// outputNumbers = mustSelectValue(t, conn, "select $1::int2[]", inputNumbers).([]int16) +// testEqual(inputNumbers, outputNumbers) - mustPrepare(t, conn, "testTranscode", "select $1::int2[]") - defer func() { - if err := conn.Deallocate("testTranscode"); err != nil { - t.Fatalf("Unable to deallocate prepared statement: %v", err) - } - }() +// mustPrepare(t, conn, "testTranscode", "select $1::int2[]") +// defer func() { +// if err := conn.Deallocate("testTranscode"); err != nil { +// t.Fatalf("Unable to deallocate prepared statement: %v", err) +// } +// }() - outputNumbers = mustSelectValue(t, conn, "testTranscode", inputNumbers).([]int16) - testEqual(inputNumbers, outputNumbers) -} +// outputNumbers = mustSelectValue(t, conn, "testTranscode", inputNumbers).([]int16) +// testEqual(inputNumbers, outputNumbers) +// } -func TestInt4SliceTranscode(t *testing.T) { - t.Parallel() +// func TestInt4SliceTranscode(t *testing.T) { +// t.Parallel() - testEqual := func(a, b []int32) { - if len(a) != len(b) { - t.Errorf("Did not transcode []int32 successfully: %v is not %v", a, b) - } - for i := range a { - if a[i] != b[i] { - t.Errorf("Did not transcode []int32 successfully: %v is not %v", a, b) - } - } - } +// testEqual := func(a, b []int32) { +// if len(a) != len(b) { +// t.Errorf("Did not transcode []int32 successfully: %v is not %v", a, b) +// } +// for i := range a { +// if a[i] != b[i] { +// t.Errorf("Did not transcode []int32 successfully: %v is not %v", a, b) +// } +// } +// } - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) +// conn := mustConnect(t, *defaultConnConfig) +// defer closeConn(t, conn) - inputNumbers := []int32{1, 2, 3, 4, 5, 6, 7, 8} - var outputNumbers []int32 +// inputNumbers := []int32{1, 2, 3, 4, 5, 6, 7, 8} +// var outputNumbers []int32 - outputNumbers = mustSelectValue(t, conn, "select $1::int4[]", inputNumbers).([]int32) - testEqual(inputNumbers, outputNumbers) +// outputNumbers = mustSelectValue(t, conn, "select $1::int4[]", inputNumbers).([]int32) +// testEqual(inputNumbers, outputNumbers) - mustPrepare(t, conn, "testTranscode", "select $1::int4[]") - defer func() { - if err := conn.Deallocate("testTranscode"); err != nil { - t.Fatalf("Unable to deallocate prepared statement: %v", err) - } - }() +// mustPrepare(t, conn, "testTranscode", "select $1::int4[]") +// defer func() { +// if err := conn.Deallocate("testTranscode"); err != nil { +// t.Fatalf("Unable to deallocate prepared statement: %v", err) +// } +// }() - outputNumbers = mustSelectValue(t, conn, "testTranscode", inputNumbers).([]int32) - testEqual(inputNumbers, outputNumbers) -} +// outputNumbers = mustSelectValue(t, conn, "testTranscode", inputNumbers).([]int32) +// testEqual(inputNumbers, outputNumbers) +// } -func TestInt8SliceTranscode(t *testing.T) { - t.Parallel() +// func TestInt8SliceTranscode(t *testing.T) { +// t.Parallel() - testEqual := func(a, b []int64) { - if len(a) != len(b) { - t.Errorf("Did not transcode []int64 successfully: %v is not %v", a, b) - } - for i := range a { - if a[i] != b[i] { - t.Errorf("Did not transcode []int64 successfully: %v is not %v", a, b) - } - } - } +// testEqual := func(a, b []int64) { +// if len(a) != len(b) { +// t.Errorf("Did not transcode []int64 successfully: %v is not %v", a, b) +// } +// for i := range a { +// if a[i] != b[i] { +// t.Errorf("Did not transcode []int64 successfully: %v is not %v", a, b) +// } +// } +// } - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) +// conn := mustConnect(t, *defaultConnConfig) +// defer closeConn(t, conn) - inputNumbers := []int64{1, 2, 3, 4, 5, 6, 7, 8} - var outputNumbers []int64 +// inputNumbers := []int64{1, 2, 3, 4, 5, 6, 7, 8} +// var outputNumbers []int64 - outputNumbers = mustSelectValue(t, conn, "select $1::int8[]", inputNumbers).([]int64) - testEqual(inputNumbers, outputNumbers) +// outputNumbers = mustSelectValue(t, conn, "select $1::int8[]", inputNumbers).([]int64) +// testEqual(inputNumbers, outputNumbers) - mustPrepare(t, conn, "testTranscode", "select $1::int8[]") - defer func() { - if err := conn.Deallocate("testTranscode"); err != nil { - t.Fatalf("Unable to deallocate prepared statement: %v", err) - } - }() +// mustPrepare(t, conn, "testTranscode", "select $1::int8[]") +// defer func() { +// if err := conn.Deallocate("testTranscode"); err != nil { +// t.Fatalf("Unable to deallocate prepared statement: %v", err) +// } +// }() - outputNumbers = mustSelectValue(t, conn, "testTranscode", inputNumbers).([]int64) - testEqual(inputNumbers, outputNumbers) -} +// outputNumbers = mustSelectValue(t, conn, "testTranscode", inputNumbers).([]int64) +// testEqual(inputNumbers, outputNumbers) +// }