diff --git a/batch.go b/batch.go index 67dafc29..02f93daa 100644 --- a/batch.go +++ b/batch.go @@ -17,13 +17,14 @@ type batchItem struct { // Batch queries are a way of bundling multiple queries together to avoid // unnecessary network round trips. type Batch struct { - conn *Conn - connPool *ConnPool - items []*batchItem - resultsRead int - ctx context.Context - err error - inTx bool + conn *Conn + connPool *ConnPool + items []*batchItem + resultsRead int + pendingCommandComplete bool + ctx context.Context + err error + inTx bool } // BeginBatch returns a *Batch query for c. @@ -153,8 +154,15 @@ func (b *Batch) ExecResults() (CommandTag, error) { default: } + if err := b.ensureCommandComplete(); err != nil { + b.die(err) + return "", err + } + b.resultsRead++ + b.pendingCommandComplete = true + for { msg, err := b.conn.rxMsg() if err != nil { @@ -163,6 +171,7 @@ func (b *Batch) ExecResults() (CommandTag, error) { switch msg := msg.(type) { case *pgproto3.CommandComplete: + b.pendingCommandComplete = false return CommandTag(msg.CommandTag), nil default: if err := b.conn.processContextFreeMsg(msg); err != nil { @@ -190,8 +199,16 @@ func (b *Batch) QueryResults() (*Rows, error) { default: } + if err := b.ensureCommandComplete(); err != nil { + b.die(err) + rows.fatal(err) + return rows, err + } + b.resultsRead++ + b.pendingCommandComplete = true + fieldDescriptions, err := b.conn.readUntilRowDescription() if err != nil { b.die(err) @@ -252,3 +269,25 @@ func (b *Batch) die(err error) { b.connPool.Release(b.conn) } } + +func (b *Batch) ensureCommandComplete() error { + for b.pendingCommandComplete { + msg, err := b.conn.rxMsg() + if err != nil { + return err + } + + switch msg := msg.(type) { + case *pgproto3.CommandComplete: + b.pendingCommandComplete = false + return nil + default: + err = b.conn.processContextFreeMsg(msg) + if err != nil { + return err + } + } + } + + return nil +} diff --git a/batch_test.go b/batch_test.go index 54785f79..3b51971c 100644 --- a/batch_test.go +++ b/batch_test.go @@ -700,4 +700,4 @@ func TestTxBeginBatchRollback(t *testing.T) { } ensureConnValid(t, conn) -} +} \ No newline at end of file diff --git a/query.go b/query.go index e37e6120..407a792c 100644 --- a/query.go +++ b/query.go @@ -149,6 +149,9 @@ func (rows *Rows) Next() bool { rows.values = msg.Values return true case *pgproto3.CommandComplete: + if rows.batch != nil { + rows.batch.pendingCommandComplete = false + } rows.Close() return false diff --git a/replication.go b/replication.go index bfa81e54..7dd5efe4 100644 --- a/replication.go +++ b/replication.go @@ -440,6 +440,18 @@ func (rc *ReplicationConn) CreateReplicationSlot(slotName, outputPlugin string) return } +// Create the replication slot, using the given name and output plugin, and return the consistent_point and snapshot_name values. +func (rc *ReplicationConn) CreateReplicationSlotEx(slotName, outputPlugin string) (consistentPoint string, snapshotName string, err error) { + var dummy string + var rows *Rows + rows, err = rc.sendReplicationModeQuery(fmt.Sprintf("CREATE_REPLICATION_SLOT %s LOGICAL %s", slotName, outputPlugin)) + defer rows.Close() + for rows.Next() { + rows.Scan(&dummy, &consistentPoint, &snapshotName, &dummy) + } + return +} + // Drop the replication slot for the given name func (rc *ReplicationConn) DropReplicationSlot(slotName string) (err error) { _, err = rc.c.Exec(fmt.Sprintf("DROP_REPLICATION_SLOT %s", slotName)) diff --git a/replication_test.go b/replication_test.go index d75233c1..d06d73cd 100644 --- a/replication_test.go +++ b/replication_test.go @@ -56,10 +56,18 @@ func TestSimpleReplicationConnection(t *testing.T) { replicationConn := mustReplicationConnect(t, *replicationConnConfig) defer closeReplicationConn(t, replicationConn) - err = replicationConn.CreateReplicationSlot("pgx_test", "test_decoding") + var cp string + var snapshot_name string + cp, snapshot_name, err = replicationConn.CreateReplicationSlotEx("pgx_test", "test_decoding") if err != nil { t.Fatalf("replication slot create failed: %v", err) } + if cp == "" { + t.Logf("consistent_point is empty") + } + if snapshot_name == "" { + t.Logf("snapshot_name is empty") + } // Do a simple change so we can get some wal data _, err = conn.Exec("create table if not exists replication_test (a integer)") @@ -178,20 +186,35 @@ func TestReplicationConn_DropReplicationSlot(t *testing.T) { replicationConn := mustReplicationConnect(t, *replicationConnConfig) defer closeReplicationConn(t, replicationConn) - err := replicationConn.CreateReplicationSlot("pgx_slot_test", "test_decoding") + var cp string + var snapshot_name string + cp, snapshot_name, err := replicationConn.CreateReplicationSlotEx("pgx_slot_test", "test_decoding") if err != nil { t.Logf("replication slot create failed: %v", err) } + if cp == "" { + t.Logf("consistent_point is empty") + } + if snapshot_name == "" { + t.Logf("snapshot_name is empty") + } + err = replicationConn.DropReplicationSlot("pgx_slot_test") if err != nil { t.Fatalf("Failed to drop replication slot: %v", err) } // We re-create to ensure the drop worked. - err = replicationConn.CreateReplicationSlot("pgx_slot_test", "test_decoding") + cp, snapshot_name, err = replicationConn.CreateReplicationSlotEx("pgx_slot_test", "test_decoding") if err != nil { t.Logf("replication slot create failed: %v", err) } + if cp == "" { + t.Logf("consistent_point is empty") + } + if snapshot_name == "" { + t.Logf("snapshot_name is empty") + } // And finally we drop to ensure we don't leave dirty state err = replicationConn.DropReplicationSlot("pgx_slot_test")