diff --git a/connection_pool.go b/connection_pool.go index 1c410796..ee3eb07d 100644 --- a/connection_pool.go +++ b/connection_pool.go @@ -33,6 +33,9 @@ func (p *ConnectionPool) Acquire() (c *Connection) { } func (p *ConnectionPool) Release(c *Connection) { + if c.txStatus != 'I' { + c.Execute("rollback") + } p.connectionChannel <- c } diff --git a/connection_pool_test.go b/connection_pool_test.go index ea978845..d211104b 100644 --- a/connection_pool_test.go +++ b/connection_pool_test.go @@ -108,3 +108,40 @@ func TestPoolAcquireAndReleaseCycle(t *testing.T) { pool.Release(c) } } + +func TestPoolReleaseWithTransactions(t *testing.T) { + pool := createConnectionPool(1) + defer pool.Close() + + var err error + conn := pool.Acquire() + if _, err = conn.Execute("begin"); err != nil { + t.Fatalf("Unexpected error begining transaction: %v", err) + } + if _, err = conn.Execute("select"); err == nil { + t.Fatal("Did not receive expected error") + } + if conn.txStatus != 'E' { + t.Fatalf("Expected txStatus to be 'E', instead it was '%c'", conn.txStatus) + } + + pool.Release(conn) + + if conn.txStatus != 'I' { + t.Fatalf("Expected release to rollback errored transaction, but it did not: '%c'", conn.txStatus) + } + + conn = pool.Acquire() + if _, err = conn.Execute("begin"); err != nil { + t.Fatalf("Unexpected error begining transaction: %v", err) + } + if conn.txStatus != 'T' { + t.Fatalf("Expected txStatus to be 'T', instead it was '%c'", conn.txStatus) + } + + pool.Release(conn) + + if conn.txStatus != 'I' { + t.Fatalf("Expected release to rollback uncommitted transaction, but it did not: '%c'", conn.txStatus) + } +}