diff --git a/conn.go b/conn.go index 0d1c5582..ee30d15f 100644 --- a/conn.go +++ b/conn.go @@ -389,318 +389,6 @@ func (c *Conn) CauseOfDeath() error { return c.causeOfDeath } -type Row Rows - -func (r *Row) Scan(dest ...interface{}) (err error) { - rows := (*Rows)(r) - - if rows.Err() != nil { - return rows.Err() - } - - if !rows.Next() { - if rows.Err() == nil { - return ErrNoRows - } else { - return rows.Err() - } - } - - rows.Scan(dest...) - rows.Close() - return rows.Err() -} - -type Rows struct { - pool *ConnPool - conn *Conn - mr *MsgReader - fields []FieldDescription - rowCount int - columnIdx int - err error - closed bool -} - -func (rows *Rows) FieldDescriptions() []FieldDescription { - return rows.fields -} - -func (rows *Rows) MsgReader() *MsgReader { - return rows.mr -} - -func (rows *Rows) close() { - if rows.pool != nil { - rows.pool.Release(rows.conn) - rows.pool = nil - } - - rows.closed = true -} - -func (rows *Rows) readUntilReadyForQuery() { - for { - t, r, err := rows.conn.rxMsg() - if err != nil { - rows.close() - return - } - - switch t { - case readyForQuery: - rows.conn.rxReadyForQuery(r) - rows.close() - return - case rowDescription: - case dataRow: - case commandComplete: - case bindComplete: - default: - err = rows.conn.processContextFreeMsg(t, r) - if err != nil { - rows.close() - return - } - } - } -} - -func (rows *Rows) Close() { - if rows.closed { - return - } - rows.readUntilReadyForQuery() - rows.close() -} - -func (rows *Rows) Err() error { - return rows.err -} - -// abort signals that the query was not successfully sent to the server. -// This differs from Fatal in that it is not necessary to readUntilReadyForQuery -func (rows *Rows) abort(err error) { - if rows.err != nil { - return - } - - rows.err = err - rows.close() -} - -// Fatal signals an error occurred after the query was sent to the server -func (rows *Rows) Fatal(err error) { - if rows.err != nil { - return - } - - rows.err = err - rows.Close() -} - -func (rows *Rows) Next() bool { - if rows.closed { - return false - } - - rows.rowCount++ - rows.columnIdx = 0 - - for { - t, r, err := rows.conn.rxMsg() - if err != nil { - rows.Fatal(err) - return false - } - - switch t { - case readyForQuery: - rows.conn.rxReadyForQuery(r) - rows.close() - return false - case dataRow: - fieldCount := r.ReadInt16() - if int(fieldCount) != len(rows.fields) { - rows.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), fieldCount))) - return false - } - - rows.mr = r - return true - case commandComplete: - case bindComplete: - default: - err = rows.conn.processContextFreeMsg(t, r) - if err != nil { - rows.Fatal(err) - return false - } - } - } -} - -func (rows *Rows) nextColumn() (*FieldDescription, int32, bool) { - if rows.closed { - return nil, 0, false - } - if len(rows.fields) <= rows.columnIdx { - rows.Fatal(ProtocolError("No next column available")) - return nil, 0, false - } - - fd := &rows.fields[rows.columnIdx] - rows.columnIdx++ - size := rows.mr.ReadInt32() - return fd, size, true -} - -func (rows *Rows) Scan(dest ...interface{}) (err error) { - if len(rows.fields) != len(dest) { - err = errors.New("Scan received wrong number of arguments") - rows.Fatal(err) - return err - } - - for _, d := range dest { - fd, size, _ := rows.nextColumn() - switch d := d.(type) { - case *bool: - *d = decodeBool(rows, fd, size) - case *[]byte: - *d = decodeBytea(rows, fd, size) - case *int64: - *d = decodeInt8(rows, fd, size) - case *int16: - *d = decodeInt2(rows, fd, size) - case *int32: - *d = decodeInt4(rows, fd, size) - case *string: - *d = decodeText(rows, fd, size) - case *float32: - *d = decodeFloat4(rows, fd, size) - case *float64: - *d = decodeFloat8(rows, fd, size) - case *time.Time: - if fd.DataType == DateOid { - *d = decodeDate(rows, fd, size) - } else { - *d = decodeTimestampTz(rows, fd, size) - } - - case Scanner: - err = d.Scan(rows, fd, size) - if err != nil { - return err - } - default: - return errors.New("Unknown type") - } - } - - return nil -} - -// Values returns an array of the row values -func (rows *Rows) Values() ([]interface{}, error) { - if rows.closed { - return nil, errors.New("rows is closed") - } - - values := make([]interface{}, 0, len(rows.fields)) - - for _, _ = range rows.fields { - if rows.Err() != nil { - return nil, rows.Err() - } - - fd, size, _ := rows.nextColumn() - - switch fd.DataType { - case BoolOid: - values = append(values, decodeBool(rows, fd, size)) - case ByteaOid: - values = append(values, decodeBytea(rows, fd, size)) - case Int8Oid: - values = append(values, decodeInt8(rows, fd, size)) - case Int2Oid: - values = append(values, decodeInt2(rows, fd, size)) - case Int4Oid: - values = append(values, decodeInt4(rows, fd, size)) - case VarcharOid, TextOid: - values = append(values, decodeText(rows, fd, size)) - case Float4Oid: - values = append(values, decodeFloat4(rows, fd, size)) - case Float8Oid: - values = append(values, decodeFloat8(rows, fd, size)) - case DateOid: - values = append(values, decodeDate(rows, fd, size)) - case TimestampTzOid: - values = append(values, decodeTimestampTz(rows, fd, size)) - default: - // if it is not an intrinsic type then return the text - switch fd.FormatCode { - case TextFormatCode: - values = append(values, rows.MsgReader().ReadString(size)) - case BinaryFormatCode: - return nil, errors.New("Values cannot handle binary format non-intrinsic types") - default: - return nil, errors.New("Unknown format code") - } - } - } - - return values, rows.Err() -} - -// TODO - document -func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) { - c.rows = Rows{conn: c} - rows := &c.rows - - if ps, present := c.preparedStatements[sql]; present { - rows.fields = ps.FieldDescriptions - err := c.sendPreparedQuery(ps, args...) - if err != nil { - rows.abort(err) - } - return rows, rows.err - } - - err := c.sendSimpleQuery(sql, args...) - if err != nil { - rows.abort(err) - return rows, rows.err - } - - // Simple queries don't know the field descriptions of the result. - // Read until that is known before returning - for { - t, r, err := c.rxMsg() - if err != nil { - rows.Fatal(err) - return rows, rows.err - } - - switch t { - case rowDescription: - rows.fields = rows.conn.rxRowDescription(r) - return rows, nil - default: - err = rows.conn.processContextFreeMsg(t, r) - if err != nil { - rows.Fatal(err) - return rows, rows.err - } - } - } -} - -func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { - rows, _ := c.Query(sql, args...) - return (*Row)(rows) -} - func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) { if ps, present := c.preparedStatements[sql]; present { return c.sendPreparedQuery(ps, arguments...) diff --git a/conn_test.go b/conn_test.go index 04e418e7..a2524dec 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1,8 +1,6 @@ package pgx_test import ( - "bytes" - "fmt" "github.com/jackc/pgx" "strings" "sync" @@ -302,380 +300,6 @@ func TestExecFailure(t *testing.T) { } } -func TestConnQueryScan(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - var sum, rowCount int32 - - rows, err := conn.Query("select generate_series(1,$1)", 10) - if err != nil { - t.Fatalf("conn.Query failed: ", err) - } - defer rows.Close() - - for rows.Next() { - var n int32 - rows.Scan(&n) - sum += n - rowCount++ - } - - if rows.Err() != nil { - t.Fatalf("conn.Query failed: ", err) - } - - if rowCount != 10 { - t.Error("Select called onDataRow wrong number of times") - } - if sum != 55 { - t.Error("Wrong values returned") - } -} - -func TestConnQueryValues(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - var rowCount int32 - - rows, err := conn.Query("select 'foo', n from generate_series(1,$1) n", 10) - if err != nil { - t.Fatalf("conn.Query failed: ", err) - } - defer rows.Close() - - for rows.Next() { - rowCount++ - - values, err := rows.Values() - if err != nil { - t.Fatalf("rows.Values failed: %v", err) - } - if len(values) != 2 { - t.Errorf("Expected rows.Values to return 2 values, but it returned %d", len(values)) - } - if values[0] != "foo" { - t.Errorf(`Expected values[0] to be "foo", but it was %v`, values[0]) - } - if values[0] != "foo" { - t.Errorf(`Expected values[0] to be "foo", but it was %v`, values[0]) - } - - if values[1] != rowCount { - t.Errorf(`Expected values[1] to be %d, but it was %d`, rowCount, values[1]) - } - } - - if rows.Err() != nil { - t.Fatalf("conn.Query failed: ", err) - } - - if rowCount != 10 { - t.Error("Select called onDataRow wrong number of times") - } -} - -// Do a simple query to ensure the connection is still usable -func ensureConnValid(t *testing.T, conn *pgx.Conn) { - var sum, rowCount int32 - - rows, err := conn.Query("select generate_series(1,$1)", 10) - if err != nil { - t.Fatalf("conn.Query failed: ", err) - } - defer rows.Close() - - for rows.Next() { - var n int32 - rows.Scan(&n) - sum += n - rowCount++ - } - - if rows.Err() != nil { - t.Fatalf("conn.Query failed: ", err) - } - - if rowCount != 10 { - t.Error("Select called onDataRow wrong number of times") - } - if sum != 55 { - t.Error("Wrong values returned") - } -} - -// Test that a connection stays valid when query results are closed early -func TestConnQueryCloseEarly(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - // Immediately close query without reading any rows - rows, err := conn.Query("select generate_series(1,$1)", 10) - if err != nil { - t.Fatalf("conn.Query failed: ", err) - } - rows.Close() - - ensureConnValid(t, conn) - - // Read partial response then close - rows, err = conn.Query("select generate_series(1,$1)", 10) - if err != nil { - t.Fatalf("conn.Query failed: ", err) - } - - ok := rows.Next() - if !ok { - t.Fatal("rows.Next terminated early") - } - - var n int32 - rows.Scan(&n) - if n != 1 { - t.Fatalf("Expected 1 from first row, but got %v", n) - } - - rows.Close() - - ensureConnValid(t, conn) -} - -// Test that a connection stays valid when query results read incorrectly -func TestConnQueryReadWrongTypeError(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - // Read a single value incorrectly - rows, err := conn.Query("select generate_series(1,$1)", 10) - if err != nil { - t.Fatalf("conn.Query failed: ", err) - } - - rowsRead := 0 - - for rows.Next() { - var t time.Time - rows.Scan(&t) - rowsRead++ - } - - if rowsRead != 1 { - t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead) - } - - if rows.Err() == nil { - t.Fatal("Expected Rows to have an error after an improper read but it didn't") - } - - ensureConnValid(t, conn) -} - -// Test that a connection stays valid when query results read incorrectly -func TestConnQueryReadTooManyValues(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - // Read too many values - rows, err := conn.Query("select generate_series(1,$1)", 10) - if err != nil { - t.Fatalf("conn.Query failed: ", err) - } - - rowsRead := 0 - - for rows.Next() { - var n, m int32 - rows.Scan(&n, &m) - rowsRead++ - } - - if rowsRead != 1 { - t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead) - } - - if rows.Err() == nil { - t.Fatal("Expected Rows to have an error after an improper read but it didn't") - } - - ensureConnValid(t, conn) -} - -func TestConnQueryUnpreparedScanner(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - rows, err := conn.Query("select null::int8, 1::int8") - if err != nil { - t.Fatalf("conn.Query failed: ", err) - } - - ok := rows.Next() - if !ok { - t.Fatal("rows.Next terminated early") - } - - var n, m pgx.NullInt64 - err = rows.Scan(&n, &m) - if err != nil { - t.Fatalf("rows.Scan failed: ", err) - } - rows.Close() - - if n.Valid { - t.Error("Null should not be valid, but it was") - } - - if !m.Valid { - t.Error("1 should be valid, but it wasn't") - } - - if m.Int64 != 1 { - t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64) - } - - ensureConnValid(t, conn) -} - -func TestConnQueryPreparedScanner(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - mustPrepare(t, conn, "scannerTest", "select null::int8, 1::int8") - - rows, err := conn.Query("scannerTest") - if err != nil { - t.Fatalf("conn.Query failed: ", err) - } - - ok := rows.Next() - if !ok { - t.Fatal("rows.Next terminated early") - } - - var n, m pgx.NullInt64 - err = rows.Scan(&n, &m) - if err != nil { - t.Fatalf("rows.Scan failed: ", err) - } - rows.Close() - - if n.Valid { - t.Error("Null should not be valid, but it was") - } - - if !m.Valid { - t.Error("1 should be valid, but it wasn't") - } - - if m.Int64 != 1 { - t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64) - } - - ensureConnValid(t, conn) -} - -func TestConnQueryUnpreparedEncoder(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - n := pgx.NullInt64{Int64: 1, Valid: true} - - rows, err := conn.Query("select $1::int8", &n) - if err != nil { - t.Fatalf("conn.Query failed: ", err) - } - - ok := rows.Next() - if !ok { - t.Fatal("rows.Next terminated early") - } - - var m pgx.NullInt64 - err = rows.Scan(&m) - if err != nil { - t.Fatalf("rows.Scan failed: ", err) - } - rows.Close() - - if !m.Valid { - t.Error("m should be valid, but it wasn't") - } - - if m.Int64 != 1 { - t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64) - } - - ensureConnValid(t, conn) -} - -func TestQueryPreparedEncodeError(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - mustPrepare(t, conn, "testTranscode", "select $1::integer") - defer func() { - if err := conn.Deallocate("testTranscode"); err != nil { - t.Fatalf("Unable to deallocate prepared statement: %v", err) - } - }() - - _, err := conn.Query("testTranscode", "wrong") - switch { - case err == nil: - t.Error("Expected transcode error to return error, but it didn't") - case err.Error() == "Expected integer representable in int32, received string wrong": - // Correct behavior - default: - t.Errorf("Expected transcode error, received %v", err) - } -} - -// Ensure that an argument that implements TextEncoder, but not BinaryEncoder -// works when the parameter type is a core type. -type coreTextEncoder struct{} - -func (n *coreTextEncoder) EncodeText() (string, byte, error) { - return "42", pgx.SafeText, nil -} - -func TestQueryPreparedEncodeCoreTextFormatError(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - mustPrepare(t, conn, "testTranscode", "select $1::integer") - - var n int32 - err := conn.QueryRow("testTranscode", &coreTextEncoder{}).Scan(&n) - if err != nil { - t.Fatalf("Unexpected conn.QueryRow error: %v", err) - } - - if n != 42 { - t.Errorf("Expected 42, got %v", n) - } -} - func TestPrepare(t *testing.T) { t.Parallel() @@ -857,188 +481,3 @@ func TestCommandTag(t *testing.T) { } } } - -func TestQueryRowCoreTypes(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - type allTypes struct { - s string - i16 int16 - i32 int32 - i64 int64 - f32 float32 - f64 float64 - b bool - t time.Time - } - - var actual, zero allTypes - - tests := []struct { - sql string - queryArgs []interface{} - scanArgs []interface{} - expected allTypes - }{ - {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.s}, allTypes{s: "Jack"}}, - {"select $1::int2", []interface{}{int16(42)}, []interface{}{&actual.i16}, allTypes{i16: 42}}, - {"select $1::int4", []interface{}{int32(42)}, []interface{}{&actual.i32}, allTypes{i32: 42}}, - {"select $1::int8", []interface{}{int64(42)}, []interface{}{&actual.i64}, allTypes{i64: 42}}, - {"select $1::float4", []interface{}{float32(1.23)}, []interface{}{&actual.f32}, allTypes{f32: 1.23}}, - {"select $1::float8", []interface{}{float64(1.23)}, []interface{}{&actual.f64}, allTypes{f64: 1.23}}, - {"select $1::bool", []interface{}{true}, []interface{}{&actual.b}, allTypes{b: true}}, - {"select $1::timestamptz", []interface{}{time.Unix(123, 5000)}, []interface{}{&actual.t}, allTypes{t: time.Unix(123, 5000)}}, - {"select $1::date", []interface{}{time.Date(1987, 1, 2, 0, 0, 0, 0, time.Local)}, []interface{}{&actual.t}, allTypes{t: time.Date(1987, 1, 2, 0, 0, 0, 0, time.Local)}}, - } - - for i, tt := range tests { - psName := fmt.Sprintf("success%d", i) - mustPrepare(t, conn, psName, tt.sql) - - for _, sql := range []string{tt.sql, psName} { - actual = zero - - err := conn.QueryRow(sql, tt.queryArgs...).Scan(tt.scanArgs...) - if err != nil { - t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, sql, tt.queryArgs) - } - - if actual != tt.expected { - t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", i, tt.expected, actual, sql, tt.queryArgs) - } - - ensureConnValid(t, conn) - } - } -} - -func TestQueryRowCoreBytea(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - var actual []byte - sql := "select $1::bytea" - queryArg := []byte{0, 15, 255, 17} - expected := []byte{0, 15, 255, 17} - - psName := "selectBytea" - mustPrepare(t, conn, psName, sql) - - for _, sql := range []string{sql, psName} { - actual = nil - - err := conn.QueryRow(sql, queryArg).Scan(&actual) - if err != nil { - t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql) - } - - if bytes.Compare(actual, expected) != 0 { - t.Errorf("Expected %v, got %v (sql -> %v)", expected, actual, sql) - } - - ensureConnValid(t, conn) - } -} - -func TestQueryRowUnpreparedErrors(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - type allTypes struct { - i16 int16 - } - - var actual, zero allTypes - - tests := []struct { - sql string - queryArgs []interface{} - scanArgs []interface{} - err string - }{ - {"select $1", []interface{}{"Jack"}, []interface{}{&actual.i16}, "Expected type oid 21 but received type oid 705"}, - {"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`}, - {"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"}, - } - - for i, tt := range tests { - actual = zero - - err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...) - if err == nil { - t.Errorf("%d. Unexpected success (sql -> %v, queryArgs -> %v)", i, tt.sql, tt.queryArgs) - } - if !strings.Contains(err.Error(), tt.err) { - t.Errorf("%d. Expected error to contain %s, but got %v (sql -> %v, queryArgs -> %v)", i, tt.err, err, tt.sql, tt.queryArgs) - } - - ensureConnValid(t, conn) - } -} - -func TestQueryRowPreparedErrors(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - type allTypes struct { - i16 int16 - } - - var actual, zero allTypes - - tests := []struct { - sql string - queryArgs []interface{} - scanArgs []interface{} - err string - }{ - {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "Expected type oid 21 but received type oid 25"}, - } - - for i, tt := range tests { - psName := fmt.Sprintf("ps%d", i) - mustPrepare(t, conn, psName, tt.sql) - - actual = zero - - err := conn.QueryRow(psName, tt.queryArgs...).Scan(tt.scanArgs...) - if err == nil { - t.Errorf("%d. Unexpected success (sql -> %v, queryArgs -> %v)", i, tt.sql, tt.queryArgs) - } - if !strings.Contains(err.Error(), tt.err) { - t.Errorf("%d. Expected error to contain %s, but got %v (sql -> %v, queryArgs -> %v)", i, tt.err, err, tt.sql, tt.queryArgs) - } - - ensureConnValid(t, conn) - } -} - -func TestQueryRowNoResults(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - sql := "select 1 where 1=0" - psName := "selectNothing" - mustPrepare(t, conn, psName, sql) - - for _, sql := range []string{sql, psName} { - var n int32 - err := conn.QueryRow(sql).Scan(&n) - if err != pgx.ErrNoRows { - t.Errorf("Expected pgx.ErrNoRows, got %v", err) - } - - ensureConnValid(t, conn) - } -} diff --git a/helper_test.go b/helper_test.go index 7eb5062a..039b5811 100644 --- a/helper_test.go +++ b/helper_test.go @@ -33,3 +33,32 @@ func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{} } return } + +// Do a simple query to ensure the connection is still usable +func ensureConnValid(t *testing.T, conn *pgx.Conn) { + var sum, rowCount int32 + + rows, err := conn.Query("select generate_series(1,$1)", 10) + if err != nil { + t.Fatalf("conn.Query failed: ", err) + } + defer rows.Close() + + for rows.Next() { + var n int32 + rows.Scan(&n) + sum += n + rowCount++ + } + + if rows.Err() != nil { + t.Fatalf("conn.Query failed: ", err) + } + + if rowCount != 10 { + t.Error("Select called onDataRow wrong number of times") + } + if sum != 55 { + t.Error("Wrong values returned") + } +} diff --git a/query.go b/query.go new file mode 100644 index 00000000..3c23d52c --- /dev/null +++ b/query.go @@ -0,0 +1,319 @@ +package pgx + +import ( + "errors" + "fmt" + "time" +) + +type Row Rows + +func (r *Row) Scan(dest ...interface{}) (err error) { + rows := (*Rows)(r) + + if rows.Err() != nil { + return rows.Err() + } + + if !rows.Next() { + if rows.Err() == nil { + return ErrNoRows + } else { + return rows.Err() + } + } + + rows.Scan(dest...) + rows.Close() + return rows.Err() +} + +type Rows struct { + pool *ConnPool + conn *Conn + mr *MsgReader + fields []FieldDescription + rowCount int + columnIdx int + err error + closed bool +} + +func (rows *Rows) FieldDescriptions() []FieldDescription { + return rows.fields +} + +func (rows *Rows) MsgReader() *MsgReader { + return rows.mr +} + +func (rows *Rows) close() { + if rows.pool != nil { + rows.pool.Release(rows.conn) + rows.pool = nil + } + + rows.closed = true +} + +func (rows *Rows) readUntilReadyForQuery() { + for { + t, r, err := rows.conn.rxMsg() + if err != nil { + rows.close() + return + } + + switch t { + case readyForQuery: + rows.conn.rxReadyForQuery(r) + rows.close() + return + case rowDescription: + case dataRow: + case commandComplete: + case bindComplete: + default: + err = rows.conn.processContextFreeMsg(t, r) + if err != nil { + rows.close() + return + } + } + } +} + +func (rows *Rows) Close() { + if rows.closed { + return + } + rows.readUntilReadyForQuery() + rows.close() +} + +func (rows *Rows) Err() error { + return rows.err +} + +// abort signals that the query was not successfully sent to the server. +// This differs from Fatal in that it is not necessary to readUntilReadyForQuery +func (rows *Rows) abort(err error) { + if rows.err != nil { + return + } + + rows.err = err + rows.close() +} + +// Fatal signals an error occurred after the query was sent to the server +func (rows *Rows) Fatal(err error) { + if rows.err != nil { + return + } + + rows.err = err + rows.Close() +} + +func (rows *Rows) Next() bool { + if rows.closed { + return false + } + + rows.rowCount++ + rows.columnIdx = 0 + + for { + t, r, err := rows.conn.rxMsg() + if err != nil { + rows.Fatal(err) + return false + } + + switch t { + case readyForQuery: + rows.conn.rxReadyForQuery(r) + rows.close() + return false + case dataRow: + fieldCount := r.ReadInt16() + if int(fieldCount) != len(rows.fields) { + rows.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), fieldCount))) + return false + } + + rows.mr = r + return true + case commandComplete: + case bindComplete: + default: + err = rows.conn.processContextFreeMsg(t, r) + if err != nil { + rows.Fatal(err) + return false + } + } + } +} + +func (rows *Rows) nextColumn() (*FieldDescription, int32, bool) { + if rows.closed { + return nil, 0, false + } + if len(rows.fields) <= rows.columnIdx { + rows.Fatal(ProtocolError("No next column available")) + return nil, 0, false + } + + fd := &rows.fields[rows.columnIdx] + rows.columnIdx++ + size := rows.mr.ReadInt32() + return fd, size, true +} + +func (rows *Rows) Scan(dest ...interface{}) (err error) { + if len(rows.fields) != len(dest) { + err = errors.New("Scan received wrong number of arguments") + rows.Fatal(err) + return err + } + + for _, d := range dest { + fd, size, _ := rows.nextColumn() + switch d := d.(type) { + case *bool: + *d = decodeBool(rows, fd, size) + case *[]byte: + *d = decodeBytea(rows, fd, size) + case *int64: + *d = decodeInt8(rows, fd, size) + case *int16: + *d = decodeInt2(rows, fd, size) + case *int32: + *d = decodeInt4(rows, fd, size) + case *string: + *d = decodeText(rows, fd, size) + case *float32: + *d = decodeFloat4(rows, fd, size) + case *float64: + *d = decodeFloat8(rows, fd, size) + case *time.Time: + if fd.DataType == DateOid { + *d = decodeDate(rows, fd, size) + } else { + *d = decodeTimestampTz(rows, fd, size) + } + + case Scanner: + err = d.Scan(rows, fd, size) + if err != nil { + return err + } + default: + return errors.New("Unknown type") + } + } + + return nil +} + +// Values returns an array of the row values +func (rows *Rows) Values() ([]interface{}, error) { + if rows.closed { + return nil, errors.New("rows is closed") + } + + values := make([]interface{}, 0, len(rows.fields)) + + for _, _ = range rows.fields { + if rows.Err() != nil { + return nil, rows.Err() + } + + fd, size, _ := rows.nextColumn() + + switch fd.DataType { + case BoolOid: + values = append(values, decodeBool(rows, fd, size)) + case ByteaOid: + values = append(values, decodeBytea(rows, fd, size)) + case Int8Oid: + values = append(values, decodeInt8(rows, fd, size)) + case Int2Oid: + values = append(values, decodeInt2(rows, fd, size)) + case Int4Oid: + values = append(values, decodeInt4(rows, fd, size)) + case VarcharOid, TextOid: + values = append(values, decodeText(rows, fd, size)) + case Float4Oid: + values = append(values, decodeFloat4(rows, fd, size)) + case Float8Oid: + values = append(values, decodeFloat8(rows, fd, size)) + case DateOid: + values = append(values, decodeDate(rows, fd, size)) + case TimestampTzOid: + values = append(values, decodeTimestampTz(rows, fd, size)) + default: + // if it is not an intrinsic type then return the text + switch fd.FormatCode { + case TextFormatCode: + values = append(values, rows.MsgReader().ReadString(size)) + case BinaryFormatCode: + return nil, errors.New("Values cannot handle binary format non-intrinsic types") + default: + return nil, errors.New("Unknown format code") + } + } + } + + return values, rows.Err() +} + +// TODO - document +func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) { + c.rows = Rows{conn: c} + rows := &c.rows + + if ps, present := c.preparedStatements[sql]; present { + rows.fields = ps.FieldDescriptions + err := c.sendPreparedQuery(ps, args...) + if err != nil { + rows.abort(err) + } + return rows, rows.err + } + + err := c.sendSimpleQuery(sql, args...) + if err != nil { + rows.abort(err) + return rows, rows.err + } + + // Simple queries don't know the field descriptions of the result. + // Read until that is known before returning + for { + t, r, err := c.rxMsg() + if err != nil { + rows.Fatal(err) + return rows, rows.err + } + + switch t { + case rowDescription: + rows.fields = rows.conn.rxRowDescription(r) + return rows, nil + default: + err = rows.conn.processContextFreeMsg(t, r) + if err != nil { + rows.Fatal(err) + return rows, rows.err + } + } + } +} + +func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { + rows, _ := c.Query(sql, args...) + return (*Row)(rows) +} diff --git a/query_test.go b/query_test.go new file mode 100644 index 00000000..067947c4 --- /dev/null +++ b/query_test.go @@ -0,0 +1,540 @@ +package pgx_test + +import ( + "bytes" + "fmt" + "github.com/jackc/pgx" + "strings" + "testing" + "time" +) + +func TestConnQueryScan(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var sum, rowCount int32 + + rows, err := conn.Query("select generate_series(1,$1)", 10) + if err != nil { + t.Fatalf("conn.Query failed: ", err) + } + defer rows.Close() + + for rows.Next() { + var n int32 + rows.Scan(&n) + sum += n + rowCount++ + } + + if rows.Err() != nil { + t.Fatalf("conn.Query failed: ", err) + } + + if rowCount != 10 { + t.Error("Select called onDataRow wrong number of times") + } + if sum != 55 { + t.Error("Wrong values returned") + } +} + +func TestConnQueryValues(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var rowCount int32 + + rows, err := conn.Query("select 'foo', n from generate_series(1,$1) n", 10) + if err != nil { + t.Fatalf("conn.Query failed: ", err) + } + defer rows.Close() + + for rows.Next() { + rowCount++ + + values, err := rows.Values() + if err != nil { + t.Fatalf("rows.Values failed: %v", err) + } + if len(values) != 2 { + t.Errorf("Expected rows.Values to return 2 values, but it returned %d", len(values)) + } + if values[0] != "foo" { + t.Errorf(`Expected values[0] to be "foo", but it was %v`, values[0]) + } + if values[0] != "foo" { + t.Errorf(`Expected values[0] to be "foo", but it was %v`, values[0]) + } + + if values[1] != rowCount { + t.Errorf(`Expected values[1] to be %d, but it was %d`, rowCount, values[1]) + } + } + + if rows.Err() != nil { + t.Fatalf("conn.Query failed: ", err) + } + + if rowCount != 10 { + t.Error("Select called onDataRow wrong number of times") + } +} + +// Test that a connection stays valid when query results are closed early +func TestConnQueryCloseEarly(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + // Immediately close query without reading any rows + rows, err := conn.Query("select generate_series(1,$1)", 10) + if err != nil { + t.Fatalf("conn.Query failed: ", err) + } + rows.Close() + + ensureConnValid(t, conn) + + // Read partial response then close + rows, err = conn.Query("select generate_series(1,$1)", 10) + if err != nil { + t.Fatalf("conn.Query failed: ", err) + } + + ok := rows.Next() + if !ok { + t.Fatal("rows.Next terminated early") + } + + var n int32 + rows.Scan(&n) + if n != 1 { + t.Fatalf("Expected 1 from first row, but got %v", n) + } + + rows.Close() + + ensureConnValid(t, conn) +} + +// Test that a connection stays valid when query results read incorrectly +func TestConnQueryReadWrongTypeError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + // Read a single value incorrectly + rows, err := conn.Query("select generate_series(1,$1)", 10) + if err != nil { + t.Fatalf("conn.Query failed: ", err) + } + + rowsRead := 0 + + for rows.Next() { + var t time.Time + rows.Scan(&t) + rowsRead++ + } + + if rowsRead != 1 { + t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead) + } + + if rows.Err() == nil { + t.Fatal("Expected Rows to have an error after an improper read but it didn't") + } + + ensureConnValid(t, conn) +} + +// Test that a connection stays valid when query results read incorrectly +func TestConnQueryReadTooManyValues(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + // Read too many values + rows, err := conn.Query("select generate_series(1,$1)", 10) + if err != nil { + t.Fatalf("conn.Query failed: ", err) + } + + rowsRead := 0 + + for rows.Next() { + var n, m int32 + rows.Scan(&n, &m) + rowsRead++ + } + + if rowsRead != 1 { + t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead) + } + + if rows.Err() == nil { + t.Fatal("Expected Rows to have an error after an improper read but it didn't") + } + + ensureConnValid(t, conn) +} + +func TestConnQueryUnpreparedScanner(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + rows, err := conn.Query("select null::int8, 1::int8") + if err != nil { + t.Fatalf("conn.Query failed: ", err) + } + + ok := rows.Next() + if !ok { + t.Fatal("rows.Next terminated early") + } + + var n, m pgx.NullInt64 + err = rows.Scan(&n, &m) + if err != nil { + t.Fatalf("rows.Scan failed: ", err) + } + rows.Close() + + if n.Valid { + t.Error("Null should not be valid, but it was") + } + + if !m.Valid { + t.Error("1 should be valid, but it wasn't") + } + + if m.Int64 != 1 { + t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64) + } + + ensureConnValid(t, conn) +} + +func TestConnQueryPreparedScanner(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustPrepare(t, conn, "scannerTest", "select null::int8, 1::int8") + + rows, err := conn.Query("scannerTest") + if err != nil { + t.Fatalf("conn.Query failed: ", err) + } + + ok := rows.Next() + if !ok { + t.Fatal("rows.Next terminated early") + } + + var n, m pgx.NullInt64 + err = rows.Scan(&n, &m) + if err != nil { + t.Fatalf("rows.Scan failed: ", err) + } + rows.Close() + + if n.Valid { + t.Error("Null should not be valid, but it was") + } + + if !m.Valid { + t.Error("1 should be valid, but it wasn't") + } + + if m.Int64 != 1 { + t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64) + } + + ensureConnValid(t, conn) +} + +func TestConnQueryUnpreparedEncoder(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + n := pgx.NullInt64{Int64: 1, Valid: true} + + rows, err := conn.Query("select $1::int8", &n) + if err != nil { + t.Fatalf("conn.Query failed: ", err) + } + + ok := rows.Next() + if !ok { + t.Fatal("rows.Next terminated early") + } + + var m pgx.NullInt64 + err = rows.Scan(&m) + if err != nil { + t.Fatalf("rows.Scan failed: ", err) + } + rows.Close() + + if !m.Valid { + t.Error("m should be valid, but it wasn't") + } + + if m.Int64 != 1 { + t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64) + } + + ensureConnValid(t, conn) +} + +func TestQueryPreparedEncodeError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustPrepare(t, conn, "testTranscode", "select $1::integer") + defer func() { + if err := conn.Deallocate("testTranscode"); err != nil { + t.Fatalf("Unable to deallocate prepared statement: %v", err) + } + }() + + _, err := conn.Query("testTranscode", "wrong") + switch { + case err == nil: + t.Error("Expected transcode error to return error, but it didn't") + case err.Error() == "Expected integer representable in int32, received string wrong": + // Correct behavior + default: + t.Errorf("Expected transcode error, received %v", err) + } +} + +// Ensure that an argument that implements TextEncoder, but not BinaryEncoder +// works when the parameter type is a core type. +type coreTextEncoder struct{} + +func (n *coreTextEncoder) EncodeText() (string, byte, error) { + return "42", pgx.SafeText, nil +} + +func TestQueryPreparedEncodeCoreTextFormatError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustPrepare(t, conn, "testTranscode", "select $1::integer") + + var n int32 + err := conn.QueryRow("testTranscode", &coreTextEncoder{}).Scan(&n) + if err != nil { + t.Fatalf("Unexpected conn.QueryRow error: %v", err) + } + + if n != 42 { + t.Errorf("Expected 42, got %v", n) + } +} + +func TestQueryRowCoreTypes(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + type allTypes struct { + s string + i16 int16 + i32 int32 + i64 int64 + f32 float32 + f64 float64 + b bool + t time.Time + } + + var actual, zero allTypes + + tests := []struct { + sql string + queryArgs []interface{} + scanArgs []interface{} + expected allTypes + }{ + {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.s}, allTypes{s: "Jack"}}, + {"select $1::int2", []interface{}{int16(42)}, []interface{}{&actual.i16}, allTypes{i16: 42}}, + {"select $1::int4", []interface{}{int32(42)}, []interface{}{&actual.i32}, allTypes{i32: 42}}, + {"select $1::int8", []interface{}{int64(42)}, []interface{}{&actual.i64}, allTypes{i64: 42}}, + {"select $1::float4", []interface{}{float32(1.23)}, []interface{}{&actual.f32}, allTypes{f32: 1.23}}, + {"select $1::float8", []interface{}{float64(1.23)}, []interface{}{&actual.f64}, allTypes{f64: 1.23}}, + {"select $1::bool", []interface{}{true}, []interface{}{&actual.b}, allTypes{b: true}}, + {"select $1::timestamptz", []interface{}{time.Unix(123, 5000)}, []interface{}{&actual.t}, allTypes{t: time.Unix(123, 5000)}}, + {"select $1::date", []interface{}{time.Date(1987, 1, 2, 0, 0, 0, 0, time.Local)}, []interface{}{&actual.t}, allTypes{t: time.Date(1987, 1, 2, 0, 0, 0, 0, time.Local)}}, + } + + for i, tt := range tests { + psName := fmt.Sprintf("success%d", i) + mustPrepare(t, conn, psName, tt.sql) + + for _, sql := range []string{tt.sql, psName} { + actual = zero + + err := conn.QueryRow(sql, tt.queryArgs...).Scan(tt.scanArgs...) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, sql, tt.queryArgs) + } + + if actual != tt.expected { + t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", i, tt.expected, actual, sql, tt.queryArgs) + } + + ensureConnValid(t, conn) + } + } +} + +func TestQueryRowCoreBytea(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var actual []byte + sql := "select $1::bytea" + queryArg := []byte{0, 15, 255, 17} + expected := []byte{0, 15, 255, 17} + + psName := "selectBytea" + mustPrepare(t, conn, psName, sql) + + for _, sql := range []string{sql, psName} { + actual = nil + + err := conn.QueryRow(sql, queryArg).Scan(&actual) + if err != nil { + t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql) + } + + if bytes.Compare(actual, expected) != 0 { + t.Errorf("Expected %v, got %v (sql -> %v)", expected, actual, sql) + } + + ensureConnValid(t, conn) + } +} + +func TestQueryRowUnpreparedErrors(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + type allTypes struct { + i16 int16 + } + + var actual, zero allTypes + + tests := []struct { + sql string + queryArgs []interface{} + scanArgs []interface{} + err string + }{ + {"select $1", []interface{}{"Jack"}, []interface{}{&actual.i16}, "Expected type oid 21 but received type oid 705"}, + {"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`}, + {"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"}, + } + + for i, tt := range tests { + actual = zero + + err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...) + if err == nil { + t.Errorf("%d. Unexpected success (sql -> %v, queryArgs -> %v)", i, tt.sql, tt.queryArgs) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("%d. Expected error to contain %s, but got %v (sql -> %v, queryArgs -> %v)", i, tt.err, err, tt.sql, tt.queryArgs) + } + + ensureConnValid(t, conn) + } +} + +func TestQueryRowPreparedErrors(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + type allTypes struct { + i16 int16 + } + + var actual, zero allTypes + + tests := []struct { + sql string + queryArgs []interface{} + scanArgs []interface{} + err string + }{ + {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "Expected type oid 21 but received type oid 25"}, + } + + for i, tt := range tests { + psName := fmt.Sprintf("ps%d", i) + mustPrepare(t, conn, psName, tt.sql) + + actual = zero + + err := conn.QueryRow(psName, tt.queryArgs...).Scan(tt.scanArgs...) + if err == nil { + t.Errorf("%d. Unexpected success (sql -> %v, queryArgs -> %v)", i, tt.sql, tt.queryArgs) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("%d. Expected error to contain %s, but got %v (sql -> %v, queryArgs -> %v)", i, tt.err, err, tt.sql, tt.queryArgs) + } + + ensureConnValid(t, conn) + } +} + +func TestQueryRowNoResults(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + sql := "select 1 where 1=0" + psName := "selectNothing" + mustPrepare(t, conn, psName, sql) + + for _, sql := range []string{sql, psName} { + var n int32 + err := conn.QueryRow(sql).Scan(&n) + if err != pgx.ErrNoRows { + t.Errorf("Expected pgx.ErrNoRows, got %v", err) + } + + ensureConnValid(t, conn) + } +}