mirror of https://github.com/jackc/pgx.git
436 lines
10 KiB
Go
436 lines
10 KiB
Go
package pgx_test
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx"
|
|
"github.com/jackc/pgx/pgconn"
|
|
"github.com/jackc/pgx/pgmock"
|
|
"github.com/jackc/pgx/pgproto3"
|
|
)
|
|
|
|
func TestTransactionSuccessfulCommit(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
|
defer closeConn(t, conn)
|
|
|
|
createSql := `
|
|
create temporary table foo(
|
|
id integer,
|
|
unique (id) initially deferred
|
|
);
|
|
`
|
|
|
|
if _, err := conn.Exec(createSql); err != nil {
|
|
t.Fatalf("Failed to create table: %v", err)
|
|
}
|
|
|
|
tx, err := conn.Begin()
|
|
if err != nil {
|
|
t.Fatalf("conn.Begin failed: %v", err)
|
|
}
|
|
|
|
_, err = tx.Exec("insert into foo(id) values (1)")
|
|
if err != nil {
|
|
t.Fatalf("tx.Exec failed: %v", err)
|
|
}
|
|
|
|
err = tx.Commit()
|
|
if err != nil {
|
|
t.Fatalf("tx.Commit failed: %v", err)
|
|
}
|
|
|
|
var n int64
|
|
err = conn.QueryRow("select count(*) from foo").Scan(&n)
|
|
if err != nil {
|
|
t.Fatalf("QueryRow Scan failed: %v", err)
|
|
}
|
|
if n != 1 {
|
|
t.Fatalf("Did not receive correct number of rows: %v", n)
|
|
}
|
|
}
|
|
|
|
func TestTxCommitWhenTxBroken(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
|
defer closeConn(t, conn)
|
|
|
|
createSql := `
|
|
create temporary table foo(
|
|
id integer,
|
|
unique (id) initially deferred
|
|
);
|
|
`
|
|
|
|
if _, err := conn.Exec(createSql); err != nil {
|
|
t.Fatalf("Failed to create table: %v", err)
|
|
}
|
|
|
|
tx, err := conn.Begin()
|
|
if err != nil {
|
|
t.Fatalf("conn.Begin failed: %v", err)
|
|
}
|
|
|
|
if _, err := tx.Exec("insert into foo(id) values (1)"); err != nil {
|
|
t.Fatalf("tx.Exec failed: %v", err)
|
|
}
|
|
|
|
// Purposely break transaction
|
|
if _, err := tx.Exec("syntax error"); err == nil {
|
|
t.Fatal("Unexpected success")
|
|
}
|
|
|
|
err = tx.Commit()
|
|
if err != pgx.ErrTxCommitRollback {
|
|
t.Fatalf("Expected error %v, got %v", pgx.ErrTxCommitRollback, err)
|
|
}
|
|
|
|
var n int64
|
|
err = conn.QueryRow("select count(*) from foo").Scan(&n)
|
|
if err != nil {
|
|
t.Fatalf("QueryRow Scan failed: %v", err)
|
|
}
|
|
if n != 0 {
|
|
t.Fatalf("Did not receive correct number of rows: %v", n)
|
|
}
|
|
}
|
|
|
|
func TestTxCommitSerializationFailure(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
pool := createConnPool(t, 5)
|
|
defer pool.Close()
|
|
|
|
pool.Exec(`drop table if exists tx_serializable_sums`)
|
|
_, err := pool.Exec(`create table tx_serializable_sums(num integer);`)
|
|
if err != nil {
|
|
t.Fatalf("Unable to create temporary table: %v", err)
|
|
}
|
|
defer pool.Exec(`drop table tx_serializable_sums`)
|
|
|
|
tx1, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable})
|
|
if err != nil {
|
|
t.Fatalf("BeginEx failed: %v", err)
|
|
}
|
|
defer tx1.Rollback()
|
|
|
|
tx2, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable})
|
|
if err != nil {
|
|
t.Fatalf("BeginEx failed: %v", err)
|
|
}
|
|
defer tx2.Rollback()
|
|
|
|
_, err = tx1.Exec(`insert into tx_serializable_sums(num) select sum(num) from tx_serializable_sums`)
|
|
if err != nil {
|
|
t.Fatalf("Exec failed: %v", err)
|
|
}
|
|
|
|
_, err = tx2.Exec(`insert into tx_serializable_sums(num) select sum(num) from tx_serializable_sums`)
|
|
if err != nil {
|
|
t.Fatalf("Exec failed: %v", err)
|
|
}
|
|
|
|
err = tx1.Commit()
|
|
if err != nil {
|
|
t.Fatalf("Commit failed: %v", err)
|
|
}
|
|
|
|
err = tx2.Commit()
|
|
if pgErr, ok := err.(pgx.PgError); !ok || pgErr.Code != "40001" {
|
|
t.Fatalf("Expected serialization error 40001, got %#v", err)
|
|
}
|
|
}
|
|
|
|
func TestTransactionSuccessfulRollback(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
|
defer closeConn(t, conn)
|
|
|
|
createSql := `
|
|
create temporary table foo(
|
|
id integer,
|
|
unique (id) initially deferred
|
|
);
|
|
`
|
|
|
|
if _, err := conn.Exec(createSql); err != nil {
|
|
t.Fatalf("Failed to create table: %v", err)
|
|
}
|
|
|
|
tx, err := conn.Begin()
|
|
if err != nil {
|
|
t.Fatalf("conn.Begin failed: %v", err)
|
|
}
|
|
|
|
_, err = tx.Exec("insert into foo(id) values (1)")
|
|
if err != nil {
|
|
t.Fatalf("tx.Exec failed: %v", err)
|
|
}
|
|
|
|
err = tx.Rollback()
|
|
if err != nil {
|
|
t.Fatalf("tx.Rollback failed: %v", err)
|
|
}
|
|
|
|
var n int64
|
|
err = conn.QueryRow("select count(*) from foo").Scan(&n)
|
|
if err != nil {
|
|
t.Fatalf("QueryRow Scan failed: %v", err)
|
|
}
|
|
if n != 0 {
|
|
t.Fatalf("Did not receive correct number of rows: %v", n)
|
|
}
|
|
}
|
|
|
|
func TestBeginExIsoLevels(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
|
defer closeConn(t, conn)
|
|
|
|
isoLevels := []pgx.TxIsoLevel{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted}
|
|
for _, iso := range isoLevels {
|
|
tx, err := conn.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: iso})
|
|
if err != nil {
|
|
t.Fatalf("conn.BeginEx failed: %v", err)
|
|
}
|
|
|
|
var level pgx.TxIsoLevel
|
|
conn.QueryRow("select current_setting('transaction_isolation')").Scan(&level)
|
|
if level != iso {
|
|
t.Errorf("Expected to be in isolation level %v but was %v", iso, level)
|
|
}
|
|
|
|
err = tx.Rollback()
|
|
if err != nil {
|
|
t.Fatalf("tx.Rollback failed: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestBeginExReadOnly(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
|
defer closeConn(t, conn)
|
|
|
|
tx, err := conn.BeginEx(context.Background(), &pgx.TxOptions{AccessMode: pgx.ReadOnly})
|
|
if err != nil {
|
|
t.Fatalf("conn.BeginEx failed: %v", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
_, err = conn.Exec("create table foo(id serial primary key)")
|
|
if pgErr, ok := err.(pgx.PgError); !ok || pgErr.Code != "25006" {
|
|
t.Errorf("Expected error SQLSTATE 25006, but got %#v", err)
|
|
}
|
|
}
|
|
|
|
func TestConnBeginExContextCancel(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
script := &pgmock.Script{
|
|
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
|
|
}
|
|
script.Steps = append(script.Steps, pgmock.PgxInitSteps()...)
|
|
script.Steps = append(script.Steps,
|
|
pgmock.ExpectMessage(&pgproto3.Query{String: "begin"}),
|
|
pgmock.WaitForClose(),
|
|
)
|
|
|
|
server, err := pgmock.NewServer(script)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer server.Close()
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() {
|
|
errChan <- server.ServeOne()
|
|
}()
|
|
|
|
pc, err := pgconn.ParseConfig(fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr()))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
conn := mustConnect(t, pgx.ConnConfig{Config: *pc})
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
|
defer cancel()
|
|
|
|
_, err = conn.BeginEx(ctx, nil)
|
|
if err != context.DeadlineExceeded {
|
|
t.Errorf("err => %v, want %v", err, context.DeadlineExceeded)
|
|
}
|
|
|
|
if conn.IsAlive() {
|
|
t.Error("expected conn to be dead after BeginEx failure")
|
|
}
|
|
|
|
if err := <-errChan; err != nil {
|
|
t.Errorf("mock server err: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestTxCommitExCancel(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
script := &pgmock.Script{
|
|
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
|
|
}
|
|
script.Steps = append(script.Steps, pgmock.PgxInitSteps()...)
|
|
script.Steps = append(script.Steps,
|
|
pgmock.ExpectMessage(&pgproto3.Query{String: "begin"}),
|
|
pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: "BEGIN"}),
|
|
pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'T'}),
|
|
pgmock.WaitForClose(),
|
|
)
|
|
|
|
server, err := pgmock.NewServer(script)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer server.Close()
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() {
|
|
errChan <- server.ServeOne()
|
|
}()
|
|
|
|
pc, err := pgconn.ParseConfig(fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr()))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
conn := mustConnect(t, pgx.ConnConfig{Config: *pc})
|
|
defer conn.Close()
|
|
|
|
tx, err := conn.Begin()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
|
defer cancel()
|
|
err = tx.CommitEx(ctx)
|
|
if err != context.DeadlineExceeded {
|
|
t.Errorf("err => %v, want %v", err, context.DeadlineExceeded)
|
|
}
|
|
|
|
if conn.IsAlive() {
|
|
t.Error("expected conn to be dead after CommitEx failure")
|
|
}
|
|
|
|
if err := <-errChan; err != nil {
|
|
t.Errorf("mock server err: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestTxStatus(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
|
defer closeConn(t, conn)
|
|
|
|
tx, err := conn.Begin()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if status := tx.Status(); status != pgx.TxStatusInProgress {
|
|
t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInProgress, status)
|
|
}
|
|
|
|
if err := tx.Rollback(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if status := tx.Status(); status != pgx.TxStatusRollbackSuccess {
|
|
t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusRollbackSuccess, status)
|
|
}
|
|
}
|
|
|
|
func TestTxStatusErrorInTransactions(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
|
defer closeConn(t, conn)
|
|
|
|
tx, err := conn.Begin()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if status := tx.Status(); status != pgx.TxStatusInProgress {
|
|
t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInProgress, status)
|
|
}
|
|
|
|
_, err = tx.Exec("savepoint s")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
_, err = tx.Exec("syntax error")
|
|
if err == nil {
|
|
t.Fatal("expected an error but did not get one")
|
|
}
|
|
|
|
if status := tx.Status(); status != pgx.TxStatusInFailure {
|
|
t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInFailure, status)
|
|
}
|
|
|
|
_, err = tx.Exec("rollback to s")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if status := tx.Status(); status != pgx.TxStatusInProgress {
|
|
t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInProgress, status)
|
|
}
|
|
|
|
if err := tx.Rollback(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if status := tx.Status(); status != pgx.TxStatusRollbackSuccess {
|
|
t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusRollbackSuccess, status)
|
|
}
|
|
}
|
|
|
|
func TestTxErr(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
|
defer closeConn(t, conn)
|
|
|
|
tx, err := conn.Begin()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Purposely break transaction
|
|
if _, err := tx.Exec("syntax error"); err == nil {
|
|
t.Fatal("Unexpected success")
|
|
}
|
|
|
|
if err := tx.Commit(); err != pgx.ErrTxCommitRollback {
|
|
t.Fatalf("Expected error %v, got %v", pgx.ErrTxCommitRollback, err)
|
|
}
|
|
|
|
if status := tx.Status(); status != pgx.TxStatusCommitFailure {
|
|
t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusRollbackSuccess, status)
|
|
}
|
|
|
|
if err := tx.Err(); err != pgx.ErrTxCommitRollback {
|
|
t.Fatalf("Expected error %v, got %v", pgx.ErrTxCommitRollback, err)
|
|
}
|
|
}
|