diff --git a/bench_test.go b/bench_test.go index c5418fbb..d555da45 100644 --- a/bench_test.go +++ b/bench_test.go @@ -73,9 +73,7 @@ func BenchmarkSelectRowSimpleNarrow(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - if _, err := conn.SelectRow("select * from narrow where id=$1", ids[i]); err != nil { - b.Fatalf("Failure while benchmarking: %v", err) - } + _ = mustSelectRow(b, conn, "select * from narrow where id=$1", ids[i]) } } @@ -91,9 +89,7 @@ func BenchmarkSelectRowPreparedNarrow(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - if _, err := conn.SelectRow("getNarrowById", ids[i]); err != nil { - b.Fatalf("Failure while benchmarking: %v", err) - } + mustSelectRow(b, conn, "getNarrowById", ids[i]) } } @@ -109,9 +105,7 @@ func BenchmarkSelectRowsSimpleNarrow(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - if _, err := conn.SelectRows("select * from narrow where id between $1 and $2", ids[i], ids[i]+10); err != nil { - b.Fatalf("Failure while benchmarking: %v", err) - } + mustSelectRows(b, conn, "select * from narrow where id between $1 and $2", ids[i], ids[i]+10) } } @@ -127,9 +121,7 @@ func BenchmarkSelectRowsPreparedNarrow(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - if _, err := conn.SelectRows("getMultipleNarrowById", ids[i], ids[i]+10); err != nil { - b.Fatalf("Failure while benchmarking: %v", err) - } + mustSelectRows(b, conn, "getMultipleNarrowById", ids[i], ids[i]+10) } } @@ -138,7 +130,7 @@ func createJoinsTestData(b *testing.B, conn *Connection) { return } - if _, err := conn.Execute(` + mustExecute(b, conn, ` drop table if exists product_component; drop table if exists component; drop table if exists product; @@ -189,9 +181,7 @@ func createJoinsTestData(b *testing.B, conn *Connection) { create index on product_component(component_id); analyze; - `); err != nil { - panic(fmt.Sprintf("Unable to create test data: %v", err)) - } + `) mustPrepare(b, conn, "joinAggregate", ` select product.id, sum(cost*quantity) as total_cost @@ -222,9 +212,7 @@ func BenchmarkSelectRowsSimpleJoins(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - if _, err := conn.SelectRows(sql); err != nil { - b.Fatalf("Failure while benchmarking: %v", err) - } + mustSelectRows(b, conn, sql) } } @@ -234,9 +222,7 @@ func BenchmarkSelectRowsPreparedJoins(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - if _, err := conn.SelectRows("joinAggregate"); err != nil { - b.Fatalf("Failure while benchmarking: %v", err) - } + mustSelectRows(b, conn, "joinAggregate") } } @@ -245,7 +231,7 @@ func createInt2TextVsBinaryTestData(b *testing.B, conn *Connection) { return } - if _, err := conn.Execute(` + mustExecute(b, conn, ` drop table if exists t; create temporary table t( @@ -260,9 +246,7 @@ func createInt2TextVsBinaryTestData(b *testing.B, conn *Connection) { select (random() * 32000)::int2, (random() * 32000)::int2, (random() * 32000)::int2, (random() * 32000)::int2, (random() * 32000)::int2 from generate_series(1, 10); - `); err != nil { - b.Fatalf("Could not set up test data: %v", err) - } + `) int2TextVsBinaryTestDataLoaded = true } @@ -279,9 +263,7 @@ func BenchmarkInt2Text(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - if _, err := conn.SelectRows("selectInt16"); err != nil { - b.Fatalf("Failure while benchmarking: %v", err) - } + mustSelectRows(b, conn, "selectInt16") } } @@ -293,9 +275,7 @@ func BenchmarkInt2Binary(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - if _, err := conn.SelectRows("selectInt16"); err != nil { - b.Fatalf("Failure while benchmarking: %v", err) - } + mustSelectRows(b, conn, "selectInt16") } } @@ -304,7 +284,7 @@ func createInt4TextVsBinaryTestData(b *testing.B, conn *Connection) { return } - if _, err := conn.Execute(` + mustExecute(b, conn, ` drop table if exists t; create temporary table t( @@ -319,9 +299,7 @@ func createInt4TextVsBinaryTestData(b *testing.B, conn *Connection) { select (random() * 1000000)::int4, (random() * 1000000)::int4, (random() * 1000000)::int4, (random() * 1000000)::int4, (random() * 1000000)::int4 from generate_series(1, 10); - `); err != nil { - b.Fatalf("Could not set up test data: %v", err) - } + `) int4TextVsBinaryTestDataLoaded = true } @@ -338,9 +316,7 @@ func BenchmarkInt4Text(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - if _, err := conn.SelectRows("selectInt32"); err != nil { - b.Fatalf("Failure while benchmarking: %v", err) - } + mustSelectRows(b, conn, "selectInt32") } } @@ -352,9 +328,7 @@ func BenchmarkInt4Binary(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - if _, err := conn.SelectRows("selectInt32"); err != nil { - b.Fatalf("Failure while benchmarking: %v", err) - } + mustSelectRows(b, conn, "selectInt32") } } @@ -363,7 +337,7 @@ func createInt8TextVsBinaryTestData(b *testing.B, conn *Connection) { return } - if _, err := conn.Execute(` + mustExecute(b, conn, ` drop table if exists t; create temporary table t( @@ -378,9 +352,7 @@ func createInt8TextVsBinaryTestData(b *testing.B, conn *Connection) { select (random() * 1000000)::int8, (random() * 1000000)::int8, (random() * 1000000)::int8, (random() * 1000000)::int8, (random() * 1000000)::int8 from generate_series(1, 10); - `); err != nil { - b.Fatalf("Could not set up test data: %v", err) - } + `) int8TextVsBinaryTestDataLoaded = true } @@ -397,9 +369,7 @@ func BenchmarkInt8Text(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - if _, err := conn.SelectRows("selectInt64"); err != nil { - b.Fatalf("Failure while benchmarking: %v", err) - } + mustSelectRows(b, conn, "selectInt64") } } @@ -411,9 +381,7 @@ func BenchmarkInt8Binary(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - if _, err := conn.SelectRows("selectInt64"); err != nil { - b.Fatalf("Failure while benchmarking: %v", err) - } + mustSelectRows(b, conn, "selectInt64") } } @@ -422,7 +390,7 @@ func createFloat4TextVsBinaryTestData(b *testing.B, conn *Connection) { return } - if _, err := conn.Execute(` + mustExecute(b, conn, ` drop table if exists t; create temporary table t( @@ -437,9 +405,7 @@ func createFloat4TextVsBinaryTestData(b *testing.B, conn *Connection) { select (random() * 1000000)::float4, (random() * 1000000)::float4, (random() * 1000000)::float4, (random() * 1000000)::float4, (random() * 1000000)::float4 from generate_series(1, 10); - `); err != nil { - b.Fatalf("Could not set up test data: %v", err) - } + `) float4TextVsBinaryTestDataLoaded = true } @@ -456,9 +422,7 @@ func BenchmarkFloat4Text(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - if _, err := conn.SelectRows("selectFloat32"); err != nil { - b.Fatalf("Failure while benchmarking: %v", err) - } + mustSelectRows(b, conn, "selectFloat32") } } @@ -470,9 +434,7 @@ func BenchmarkFloat4Binary(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - if _, err := conn.SelectRows("selectFloat32"); err != nil { - b.Fatalf("Failure while benchmarking: %v", err) - } + mustSelectRows(b, conn, "selectFloat32") } } @@ -481,7 +443,7 @@ func createFloat8TextVsBinaryTestData(b *testing.B, conn *Connection) { return } - if _, err := conn.Execute(` + mustExecute(b, conn, ` drop table if exists t; create temporary table t( @@ -496,9 +458,7 @@ func createFloat8TextVsBinaryTestData(b *testing.B, conn *Connection) { select (random() * 1000000)::float8, (random() * 1000000)::float8, (random() * 1000000)::float8, (random() * 1000000)::float8, (random() * 1000000)::float8 from generate_series(1, 10); - `); err != nil { - b.Fatalf("Could not set up test data: %v", err) - } + `) float8TextVsBinaryTestDataLoaded = true } @@ -515,9 +475,7 @@ func BenchmarkFloat8Text(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - if _, err := conn.SelectRows("selectFloat32"); err != nil { - b.Fatalf("Failure while benchmarking: %v", err) - } + mustSelectRows(b, conn, "selectFloat32") } } @@ -529,9 +487,7 @@ func BenchmarkFloat8Binary(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - if _, err := conn.SelectRows("selectFloat32"); err != nil { - b.Fatalf("Failure while benchmarking: %v", err) - } + mustSelectRows(b, conn, "selectFloat32") } } @@ -540,7 +496,7 @@ func createBoolTextVsBinaryTestData(b *testing.B, conn *Connection) { return } - if _, err := conn.Execute(` + mustExecute(b, conn, ` drop table if exists t; create temporary table t( @@ -555,9 +511,7 @@ func createBoolTextVsBinaryTestData(b *testing.B, conn *Connection) { select random() > 0.5, random() > 0.5, random() > 0.5, random() > 0.5, random() > 0.5 from generate_series(1, 10); - `); err != nil { - b.Fatalf("Could not set up test data: %v", err) - } + `) boolTextVsBinaryTestDataLoaded = true } @@ -574,9 +528,7 @@ func BenchmarkBoolText(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - if _, err := conn.SelectRows("selectBool"); err != nil { - b.Fatalf("Failure while benchmarking: %v", err) - } + mustSelectRows(b, conn, "selectBool") } } @@ -588,8 +540,6 @@ func BenchmarkBoolBinary(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - if _, err := conn.SelectRows("selectBool"); err != nil { - b.Fatalf("Failure while benchmarking: %v", err) - } + mustSelectRows(b, conn, "selectBool") } } diff --git a/connection_pool_test.go b/connection_pool_test.go index d211104b..bfc76dbf 100644 --- a/connection_pool_test.go +++ b/connection_pool_test.go @@ -43,13 +43,8 @@ func TestPoolAcquireAndReleaseCycle(t *testing.T) { allConnections := acquireAll() for _, c := range allConnections { - var err error - if _, err = c.Execute("create temporary table t(counter integer not null)"); err != nil { - t.Fatal("Unable to create temp table:" + err.Error()) - } - if _, err = c.Execute("insert into t(counter) values(0);"); err != nil { - t.Fatal("Unable to insert initial counter row: " + err.Error()) - } + mustExecute(t, c, "create temporary table t(counter integer not null)") + mustExecute(t, c, "insert into t(counter) values(0);") } for _, c := range allConnections { @@ -65,10 +60,7 @@ func TestPoolAcquireAndReleaseCycle(t *testing.T) { defer pool.Release(conn) // Increment counter... - _, err = conn.Execute("update t set counter = counter + 1") - if err != nil { - t.Fatal("Unable to update counter: " + err.Error()) - } + mustExecute(t, conn, "update t set counter = counter + 1") completeSync <- 0 } @@ -86,11 +78,7 @@ func TestPoolAcquireAndReleaseCycle(t *testing.T) { allConnections = acquireAll() for _, c := range allConnections { - v, err := c.SelectValue("select counter from t") - if err != nil { - t.Fatal("Unable to read back execution counter: " + err.Error()) - } - + v := mustSelectValue(t, c, "select counter from t") n := v.(int32) if n == 0 { t.Error("A connection was never used") @@ -115,9 +103,7 @@ func TestPoolReleaseWithTransactions(t *testing.T) { var err error conn := pool.Acquire() - if _, err = conn.Execute("begin"); err != nil { - t.Fatalf("Unexpected error begining transaction: %v", err) - } + mustExecute(t, conn, "begin") if _, err = conn.Execute("select"); err == nil { t.Fatal("Did not receive expected error") } @@ -132,9 +118,7 @@ func TestPoolReleaseWithTransactions(t *testing.T) { } conn = pool.Acquire() - if _, err = conn.Execute("begin"); err != nil { - t.Fatalf("Unexpected error begining transaction: %v", err) - } + mustExecute(t, conn, "begin") if conn.txStatus != 'T' { t.Fatalf("Expected txStatus to be 'T', instead it was '%c'", conn.txStatus) } diff --git a/connection_test.go b/connection_test.go index a79d3e2c..9e3baf42 100644 --- a/connection_test.go +++ b/connection_test.go @@ -44,8 +44,7 @@ func TestConnect(t *testing.T) { t.Errorf("Did not connect to specified database (%v)", defaultConnectionParameters.Database) } - rows, err = conn.SelectRows("select current_user") - if err != nil || rows[0]["current_user"] != defaultConnectionParameters.User { + if user := mustSelectValue(t, conn, "select current_user"); user != defaultConnectionParameters.User { t.Errorf("Did not connect as specified user (%v)", defaultConnectionParameters.User) } @@ -137,46 +136,26 @@ func TestConnectWithMD5Password(t *testing.T) { func TestExecute(t *testing.T) { conn := getSharedConnection() - results, err := conn.Execute("create temporary table foo(id integer primary key);") - if err != nil { - t.Fatal("Execute failed: " + err.Error()) - } - if results != "CREATE TABLE" { + if results := mustExecute(t, conn, "create temporary table foo(id integer primary key);"); results != "CREATE TABLE" { t.Error("Unexpected results from Execute") } // Accept parameters - results, err = conn.Execute("insert into foo(id) values($1)", 1) - if err != nil { - t.Errorf("Execute failed: %v", err) - } - if results != "INSERT 0 1" { + if results := mustExecute(t, conn, "insert into foo(id) values($1)", 1); results != "INSERT 0 1" { t.Errorf("Unexpected results from Execute: %v", results) } - results, err = conn.Execute("drop table foo;") - if err != nil { - t.Fatal("Execute failed: " + err.Error()) - } - if results != "DROP TABLE" { + if results := mustExecute(t, conn, "drop table foo;"); results != "DROP TABLE" { t.Error("Unexpected results from Execute") } // Multiple statements can be executed -- last command tag is returned - results, err = conn.Execute("create temporary table foo(id serial primary key); drop table foo;") - if err != nil { - t.Fatal("Execute failed: " + err.Error()) - } - if results != "DROP TABLE" { + if results := mustExecute(t, conn, "create temporary table foo(id serial primary key); drop table foo;"); results != "DROP TABLE" { t.Error("Unexpected results from Execute") } // Can execute longer SQL strings than sharedBufferSize - results, err = conn.Execute(strings.Repeat("select 42; ", 1000)) - if err != nil { - t.Fatal("Execute failed: " + err.Error()) - } - if results != "SELECT 1" { + if results := mustExecute(t, conn, strings.Repeat("select 42; ", 1000)); results != "SELECT 1" { t.Errorf("Unexpected results from Execute: %v", results) } } @@ -239,10 +218,7 @@ func TestSelectFuncFailure(t *testing.T) { func TestSelectRows(t *testing.T) { conn := getSharedConnection() - rows, err := conn.SelectRows("select $1 as name, null as position", "Jack") - if err != nil { - t.Fatal("Query failed") - } + rows := mustSelectRows(t, conn, "select $1 as name, null as position", "Jack") if len(rows) != 1 { t.Fatal("Received wrong number of rows") @@ -264,11 +240,7 @@ func TestSelectRows(t *testing.T) { func TestSelectRow(t *testing.T) { conn := getSharedConnection() - row, err := conn.SelectRow("select $1 as name, null as position", "Jack") - if err != nil { - t.Fatal("Query failed") - } - + row := mustSelectRow(t, conn, "select $1 as name, null as position", "Jack") if row["name"] != "Jack" { t.Error("Received incorrect name") } @@ -281,7 +253,7 @@ func TestSelectRow(t *testing.T) { t.Error("Null value should have been present in map as nil") } - _, err = conn.SelectRow("select 'Jack' as name where 1=2") + _, err := conn.SelectRow("select 'Jack' as name where 1=2") if _, ok := err.(NotSingleRowError); !ok { t.Error("No matching row should have returned NotSingleRowError") } @@ -476,9 +448,7 @@ func TestTransaction(t *testing.T) { // Transaction happy path -- it executes function and commits committed, err = conn.Transaction(func() bool { - if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil { - t.Fatalf("Failed to insert into table: %v", err) - } + mustExecute(t, conn, "insert into foo(id) values (1)") return true }) if err != nil { @@ -489,24 +459,16 @@ func TestTransaction(t *testing.T) { } var n interface{} - n, err = conn.SelectValue("select count(*) from foo") - if err != nil { - t.Fatalf("Unexpected error selecting value from foo: %v", err) - } + n = mustSelectValue(t, conn, "select count(*) from foo") if n.(int64) != 1 { t.Fatalf("Did not receive correct number of rows: %v", n) } - _, err = conn.Execute("truncate foo") - if err != nil { - t.Fatalf("Unexpected error truncating foo: %v", err) - } + mustExecute(t, conn, "truncate foo") // It rolls back when passed function returns false committed, err = conn.Transaction(func() bool { - if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil { - t.Fatalf("Failed to insert into table: %v", err) - } + mustExecute(t, conn, "insert into foo(id) values (1)") return false }) if err != nil { @@ -515,19 +477,14 @@ func TestTransaction(t *testing.T) { if committed { t.Fatal("Transaction should not have been committed") } - n, err = conn.SelectValue("select count(*) from foo") - if err != nil { - t.Fatalf("Unexpected error selecting value from foo: %v", err) - } + n = mustSelectValue(t, conn, "select count(*) from foo") if n.(int64) != 0 { t.Fatalf("Did not receive correct number of rows: %v", n) } // it rolls back changes when connection is in error state committed, err = conn.Transaction(func() bool { - if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil { - t.Fatalf("Failed to insert into table: %v", err) - } + mustExecute(t, conn, "insert into foo(id) values (1)") if _, err := conn.Execute("invalid"); err == nil { t.Fatal("Execute was supposed to error but didn't") } @@ -539,22 +496,15 @@ func TestTransaction(t *testing.T) { if committed { t.Fatal("Transaction was committed when it shouldn't have been") } - n, err = conn.SelectValue("select count(*) from foo") - if err != nil { - t.Fatalf("Unexpected error selecting value from foo: %v", err) - } + n = mustSelectValue(t, conn, "select count(*) from foo") if n.(int64) != 0 { t.Fatalf("Did not receive correct number of rows: %v", n) } // when commit fails committed, err = conn.Transaction(func() bool { - if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil { - t.Fatalf("Failed to insert into table: %v", err) - } - if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil { - t.Fatalf("Failed to insert into table: %v", err) - } + mustExecute(t, conn, "insert into foo(id) values (1)") + mustExecute(t, conn, "insert into foo(id) values (1)") return true }) if err == nil { @@ -564,10 +514,7 @@ func TestTransaction(t *testing.T) { t.Fatal("Transaction was committed when it should have failed") } - n, err = conn.SelectValue("select count(*) from foo") - if err != nil { - t.Fatalf("Unexpected error selecting value from foo: %v", err) - } + n = mustSelectValue(t, conn, "select count(*) from foo") if n.(int64) != 0 { t.Fatalf("Did not receive correct number of rows: %v", n) } @@ -579,17 +526,11 @@ func TestTransaction(t *testing.T) { }() committed, err = conn.Transaction(func() bool { - if _, err := conn.Execute("insert into foo(id) values (1)"); err != nil { - t.Fatalf("Failed to insert into table: %v", err) - } + mustExecute(t, conn, "insert into foo(id) values (1)") panic("stop!") - return true }) - n, err = conn.SelectValue("select count(*) from foo") - if err != nil { - t.Fatalf("Unexpected error selecting value from foo: %v", err) - } + n = mustSelectValue(t, conn, "select count(*) from foo") if n.(int64) != 0 { t.Fatalf("Did not receive correct number of rows: %v", n) } diff --git a/helper_test.go b/helper_test.go index bc2eb945..4e429e66 100644 --- a/helper_test.go +++ b/helper_test.go @@ -9,3 +9,35 @@ func mustPrepare(t test, conn *Connection, name, sql string) { t.Fatalf("Could not prepare %v: %v", name, err) } } + +func mustExecute(t test, conn *Connection, sql string, arguments ...interface{}) (commandTag string) { + var err error + if commandTag, err = conn.Execute(sql, arguments...); err != nil { + t.Fatalf("Execute unexpectedly failed with %v: %v", sql, err) + } + return +} + +func mustSelectRow(t test, conn *Connection, sql string, arguments ...interface{}) (row map[string]interface{}) { + var err error + if row, err = conn.SelectRow(sql, arguments...); err != nil { + t.Fatalf("SelectRow unexpectedly failed with %v: %v", sql, err) + } + return +} + +func mustSelectRows(t test, conn *Connection, sql string, arguments ...interface{}) (rows []map[string]interface{}) { + var err error + if rows, err = conn.SelectRows(sql, arguments...); err != nil { + t.Fatalf("SelectRows unexpected failed with %v: %v", sql, err) + } + return +} + +func mustSelectValue(t test, conn *Connection, sql string, arguments ...interface{}) (value interface{}) { + var err error + if value, err = conn.SelectValue(sql, arguments...); err != nil { + t.Fatalf("SelectValue unexpectedly failed with %v: %v", sql, err) + } + return +}