diff --git a/replication.go b/replication.go index 6d78a7a1..8923d533 100644 --- a/replication.go +++ b/replication.go @@ -8,10 +8,10 @@ import ( ) const ( - copyBothResponse = 'W' - walData = 'w' - senderKeepalive = 'k' - standbyStatusUpdate = 'r' + copyBothResponse = 'W' + walData = 'w' + senderKeepalive = 'k' + standbyStatusUpdate = 'r' initialReplicationResponseTimeout = 5 * time.Second ) @@ -157,7 +157,7 @@ func ReplicationConnect(config ConnConfig) (r *ReplicationConn, err error) { } config.RuntimeParams["replication"] = "database" - c,err := Connect(config) + c, err := Connect(config) if err != nil { return } @@ -208,6 +208,13 @@ func (rc *ReplicationConn) Close() error { return rc.c.Close() } +func (rc *ReplicationConn) IsAlive() bool { + return rc.c.IsAlive() +} + +func (rc *ReplicationConn) CauseOfDeath() error { + return rc.c.CauseOfDeath() +} func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err error) { var t byte @@ -257,12 +264,12 @@ func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err return &ReplicationMessage{ServerHeartbeat: h}, nil default: if rc.c.shouldLog(LogLevelError) { - rc.c.log(LogLevelError,"Unexpected data playload message type %v", t) + rc.c.log(LogLevelError, "Unexpected data playload message type %v", t) } } default: if rc.c.shouldLog(LogLevelError) { - rc.c.log(LogLevelError,"Unexpected replication message type %v", t) + rc.c.log(LogLevelError, "Unexpected replication message type %v", t) } } return diff --git a/replication_test.go b/replication_test.go index 20572edd..ee187ec2 100644 --- a/replication_test.go +++ b/replication_test.go @@ -1,13 +1,13 @@ package pgx_test import ( + "fmt" "github.com/jackc/pgx" + "reflect" "strconv" "strings" "testing" "time" - "reflect" - "fmt" ) // This function uses a postgresql 9.6 specific column @@ -51,7 +51,7 @@ func TestSimpleReplicationConnection(t *testing.T) { replicationConn := mustReplicationConnect(t, *replicationConnConfig) defer closeReplicationConn(t, replicationConn) - err = replicationConn.CreateReplicationSlot("pgx_test","test_decoding") + err = replicationConn.CreateReplicationSlot("pgx_test", "test_decoding") if err != nil { t.Logf("replication slot create failed: %v", err) } @@ -152,14 +152,23 @@ func TestSimpleReplicationConnection(t *testing.T) { replicationConn.SendStandbyStatus(status) replicationConn.StopReplication() + if replicationConn.IsAlive() == false { + t.Errorf("Connection died: %v", replicationConn.CauseOfDeath()) + } + + err = replicationConn.Close() + if err != nil { + t.Fatalf("Replication connection close failed: %v", err) + } + // Let's push the boundary conditions of the standby status and ensure it errors correctly - status, err = pgx.NewStandbyStatus(0,1,2,3,4) + status, err = pgx.NewStandbyStatus(0, 1, 2, 3, 4) if err == nil { - t.Errorf("Expected error from new standby status, got %v",status) + t.Errorf("Expected error from new standby status, got %v", status) } // And if you provide 3 args, ensure the right fields are set - status, err = pgx.NewStandbyStatus(1,2,3) + status, err = pgx.NewStandbyStatus(1, 2, 3) if err != nil { t.Errorf("Failed to create test status: %v", err) } @@ -173,20 +182,23 @@ func TestSimpleReplicationConnection(t *testing.T) { t.Errorf("Unexpected write position %d", status.WalWritePosition) } - err = replicationConn.Close() - if err != nil { - t.Fatalf("Replication connection close failed: %v", err) - } - restartLsn := getConfirmedFlushLsnFor(t, conn, "pgx_test") integerRestartLsn, _ := pgx.ParseLSN(restartLsn) if integerRestartLsn != maxWal { t.Fatalf("Wal offset update failed, expected %s found %s", pgx.FormatLSN(maxWal), restartLsn) } - _, err = conn.Exec("select pg_drop_replication_slot($1)", "pgx_test") + replicationConn2 := mustReplicationConnect(t, *replicationConnConfig) + defer closeReplicationConn(t, replicationConn2) + + err = replicationConn2.DropReplicationSlot("pgx_test") if err != nil { t.Fatalf("Failed to drop replication slot: %v", err) } + droppedLsn := getConfirmedFlushLsnFor(t, conn, "pgx_test") + if droppedLsn != "" { + t.Errorf("Got odd flush lsn %s for supposedly dropped slot", droppedLsn) + } + }