From fe0af9b35724be39ea7aa894d63b7cf563a9959c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 29 May 2017 19:15:42 -0500 Subject: [PATCH] Happy-path batch query mode --- batch.go | 211 +++++++++++++++++++++++++++++++++++++++++++++++++ batch_test.go | 150 +++++++++++++++++++++++++++++++++++ conn.go | 20 ++--- helper_test.go | 3 +- query.go | 2 +- 5 files changed, 375 insertions(+), 11 deletions(-) create mode 100644 batch.go create mode 100644 batch_test.go diff --git a/batch.go b/batch.go new file mode 100644 index 00000000..722ce340 --- /dev/null +++ b/batch.go @@ -0,0 +1,211 @@ +package pgx + +import ( + "context" + + "github.com/jackc/pgx/pgproto3" + "github.com/jackc/pgx/pgtype" +) + +type batchItem struct { + query string + arguments []interface{} + parameterOids []pgtype.Oid + resultFormatCodes []int16 +} + +type Batch struct { + conn *Conn + items []*batchItem + resultsRead int + sent bool +} + +// Begin starts a transaction with the default transaction mode for the +// current connection. To use a specific transaction mode see BeginEx. +func (c *Conn) BeginBatch() *Batch { + // TODO - the type stuff below + + // err = c.waitForPreviousCancelQuery(ctx) + // if err != nil { + // return nil, err + // } + + // if err := c.ensureConnectionReadyForQuery(); err != nil { + // return nil, err + // } + + // c.lastActivityTime = time.Now() + + // rows = c.getRows(sql, args) + + // if err := c.lock(); err != nil { + // rows.fatal(err) + // return rows, err + // } + // rows.unlockConn = true + + // err = c.initContext(ctx) + // if err != nil { + // rows.fatal(err) + // return rows, rows.err + // } + + // if options != nil && options.SimpleProtocol { + // err = c.sanitizeAndSendSimpleQuery(sql, args...) + // if err != nil { + // rows.fatal(err) + // return rows, err + // } + + // return rows, nil + // } + + return &Batch{conn: c} +} + +func (b *Batch) Conn() *Conn { + return b.conn +} + +func (b *Batch) Queue(query string, arguments []interface{}, parameterOids []pgtype.Oid, resultFormatCodes []int16) { + b.items = append(b.items, &batchItem{ + query: query, + arguments: arguments, + parameterOids: parameterOids, + resultFormatCodes: resultFormatCodes, + }) +} + +func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error { + buf := appendQuery(b.conn.wbuf, txOptions.beginSQL()) + + for _, bi := range b.items { + // TODO - don't parse if named prepared statement + buf = appendParse(buf, "", bi.query, bi.parameterOids) + + var err error + buf, err = appendBind(buf, "", "", b.conn.ConnInfo, bi.parameterOids, bi.arguments, bi.resultFormatCodes) + if err != nil { + return err + } + + buf = appendDescribe(buf, 'P', "") + buf = appendExecute(buf, "", 0) + } + + buf = appendSync(buf) + buf = appendQuery(buf, "commit") + + n, err := b.conn.conn.Write(buf) + if err != nil { + if fatalWriteErr(n, err) { + b.conn.die(err) + } + return err + } + + // expect ReadyForQuery from sync and from commit + b.conn.pendingReadyForQueryCount = b.conn.pendingReadyForQueryCount + 2 + + b.sent = true + + for { + msg, err := b.conn.rxMsg() + if err != nil { + return err + } + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + return nil + default: + if err := b.conn.processContextFreeMsg(msg); err != nil { + return err + } + } + } + + return nil +} + +func (b *Batch) ExecResults() (CommandTag, error) { + b.resultsRead++ + + for { + msg, err := b.conn.rxMsg() + if err != nil { + return "", err + } + + switch msg := msg.(type) { + case *pgproto3.CommandComplete: + return CommandTag(msg.CommandTag), nil + default: + if err := b.conn.processContextFreeMsg(msg); err != nil { + return "", err + } + } + } +} + +func (b *Batch) QueryResults() (*Rows, error) { + b.resultsRead++ + + rows := b.conn.getRows("batch query", nil) + + fieldDescriptions, err := b.conn.readUntilRowDescription() + if err != nil { + rows.fatal(err) + return nil, err + } + + rows.fields = fieldDescriptions + return rows, nil +} + +func (b *Batch) QueryRowResults() *Row { + rows, _ := b.QueryResults() + return (*Row)(rows) + +} + +func (b *Batch) Finish() error { + for i := b.resultsRead; i < len(b.items); i++ { + _, err := b.ExecResults() + if err != nil { + return err + } + } + + // readyForQueryCount := 0 + + // for { + // msg, err := b.conn.rxMsg() + // if err != nil { + // return "", err + // } + + // switch msg := msg.(type) { + // case *pgproto3.ReadyForQuery: + // c.rxReadyForQuery(msg) + // default: + // if err := b.conn.processContextFreeMsg(msg); err != nil { + // return "", err + // } + // } + // } + + // switch msg := msg.(type) { + // case *pgproto3.ErrorResponse: + // return c.rxErrorResponse(msg) + // case *pgproto3.NotificationResponse: + // c.rxNotificationResponse(msg) + // case *pgproto3.ReadyForQuery: + // c.rxReadyForQuery(msg) + // case *pgproto3.ParameterStatus: + // c.rxParameterStatus(msg) + // } + + return nil +} diff --git a/batch_test.go b/batch_test.go new file mode 100644 index 00000000..aeef52f4 --- /dev/null +++ b/batch_test.go @@ -0,0 +1,150 @@ +package pgx_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" +) + +func TestConnBeginBatch(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + sql := `create temporary table ledger( + id serial primary key, + description varchar not null, + amount int not null +);` + mustExec(t, conn, sql) + + batch := conn.BeginBatch() + batch.Queue("insert into ledger(description, amount) values($1, $2)", + []interface{}{"q1", 1}, + []pgtype.Oid{pgtype.VarcharOid, pgtype.Int4Oid}, + nil, + ) + batch.Queue("insert into ledger(description, amount) values($1, $2)", + []interface{}{"q2", 2}, + []pgtype.Oid{pgtype.VarcharOid, pgtype.Int4Oid}, + nil, + ) + batch.Queue("insert into ledger(description, amount) values($1, $2)", + []interface{}{"q3", 3}, + []pgtype.Oid{pgtype.VarcharOid, pgtype.Int4Oid}, + nil, + ) + batch.Queue("select id, description, amount from ledger order by id", + nil, + nil, + []int16{pgx.BinaryFormatCode, pgx.TextFormatCode, pgx.BinaryFormatCode}, + ) + batch.Queue("select sum(amount) from ledger", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + + err := batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + ct, err := batch.ExecResults() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 1 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) + } + + ct, err = batch.ExecResults() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 1 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) + } + + rows, err := batch.QueryResults() + if err != nil { + t.Error(err) + } + + var id int32 + var description string + var amount int32 + if !rows.Next() { + t.Fatal("expected a row to be available") + } + if err := rows.Scan(&id, &description, &amount); err != nil { + t.Fatal(err) + } + if id != 1 { + t.Errorf("id => %v, want %v", id, 1) + } + if description != "q1" { + t.Errorf("description => %v, want %v", description, "q1") + } + if amount != 1 { + t.Errorf("amount => %v, want %v", amount, 1) + } + + if !rows.Next() { + t.Fatal("expected a row to be available") + } + if err := rows.Scan(&id, &description, &amount); err != nil { + t.Fatal(err) + } + if id != 2 { + t.Errorf("id => %v, want %v", id, 2) + } + if description != "q2" { + t.Errorf("description => %v, want %v", description, "q2") + } + if amount != 2 { + t.Errorf("amount => %v, want %v", amount, 2) + } + + if !rows.Next() { + t.Fatal("expected a row to be available") + } + if err := rows.Scan(&id, &description, &amount); err != nil { + t.Fatal(err) + } + if id != 3 { + t.Errorf("id => %v, want %v", id, 3) + } + if description != "q3" { + t.Errorf("description => %v, want %v", description, "q3") + } + if amount != 3 { + t.Errorf("amount => %v, want %v", amount, 3) + } + + if rows.Next() { + t.Fatal("did not expect a row to be available") + } + + if rows.Err() != nil { + t.Fatal(rows.Err()) + } + + err = batch.QueryRowResults().Scan(&amount) + if err != nil { + t.Error(err) + } + if amount != 6 { + t.Errorf("amount => %v, want %v", amount, 6) + } + + err = batch.Finish() + if err != nil { + t.Fatal(err) + } + + ensureConnValid(t, conn) +} diff --git a/conn.go b/conn.go index 68312222..491f2a9e 100644 --- a/conn.go +++ b/conn.go @@ -107,9 +107,9 @@ type Conn struct { status byte // One of connStatus* constants causeOfDeath error - readyForQuery bool // connection has received ReadyForQuery message since last query was sent - cancelQueryInProgress int32 - cancelQueryCompleted chan struct{} + pendingReadyForQueryCount int // numer of ReadyForQuery messages expected + cancelQueryInProgress int32 + cancelQueryCompleted chan struct{} // context support ctxInProgress bool @@ -329,6 +329,8 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl return err } + c.pendingReadyForQueryCount = 1 + for { msg, err := c.rxMsg() if err != nil { @@ -782,7 +784,7 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared } return nil, err } - c.readyForQuery = false + c.pendingReadyForQueryCount++ ps = &PreparedStatement{Name: name, SQL: sql} @@ -1004,7 +1006,7 @@ func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error { c.die(err) return err } - c.readyForQuery = false + c.pendingReadyForQueryCount++ return nil } @@ -1045,7 +1047,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} } return err } - c.readyForQuery = false + c.pendingReadyForQueryCount++ return nil } @@ -1167,7 +1169,7 @@ func (c *Conn) rxBackendKeyData(msg *pgproto3.BackendKeyData) { } func (c *Conn) rxReadyForQuery(msg *pgproto3.ReadyForQuery) { - c.readyForQuery = true + c.pendingReadyForQueryCount-- c.txStatus = msg.TxStatus } @@ -1429,7 +1431,7 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, c.die(err) return "", err } - c.readyForQuery = false + c.pendingReadyForQueryCount++ } else { if len(arguments) > 0 { ps, ok := c.preparedStatements[sql] @@ -1563,7 +1565,7 @@ func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error { } func (c *Conn) ensureConnectionReadyForQuery() error { - for !c.readyForQuery { + for c.pendingReadyForQueryCount > 0 { msg, err := c.rxMsg() if err != nil { return err diff --git a/helper_test.go b/helper_test.go index 21f86de5..78063107 100644 --- a/helper_test.go +++ b/helper_test.go @@ -1,8 +1,9 @@ package pgx_test import ( - "github.com/jackc/pgx" "testing" + + "github.com/jackc/pgx" ) func mustConnect(t testing.TB, config pgx.ConnConfig) *pgx.Conn { diff --git a/query.go b/query.go index 447a55ac..a3903a22 100644 --- a/query.go +++ b/query.go @@ -409,7 +409,7 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, c.die(err) return nil, err } - c.readyForQuery = false + c.pendingReadyForQueryCount++ fieldDescriptions, err := c.readUntilRowDescription() if err != nil {