diff --git a/stdlib/sql.go b/stdlib/sql.go index 439a5262..7e635324 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -170,16 +170,34 @@ func (c *Conn) Close() error { } func (c *Conn) Begin() (driver.Tx, error) { + return c.BeginTx(context.Background(), driver.TxOptions{}) +} + +func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { if !c.conn.IsAlive() { return nil, driver.ErrBadConn } - _, err := c.conn.Exec("begin") - if err != nil { - return nil, err + var pgxOpts pgx.TxOptions + switch sql.IsolationLevel(opts.Isolation) { + case sql.LevelDefault: + case sql.LevelReadUncommitted: + pgxOpts.IsoLevel = pgx.ReadUncommitted + case sql.LevelReadCommitted: + pgxOpts.IsoLevel = pgx.ReadCommitted + case sql.LevelSnapshot: + pgxOpts.IsoLevel = pgx.RepeatableRead + case sql.LevelSerializable: + pgxOpts.IsoLevel = pgx.Serializable + default: + return nil, fmt.Errorf("unsupported isolation: %v", opts.Isolation) } - return &Tx{conn: c.conn}, nil + if opts.ReadOnly { + pgxOpts.AccessMode = pgx.ReadOnly + } + + return c.conn.BeginEx(&pgxOpts) } func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) { @@ -389,17 +407,3 @@ func namedValueToInterface(argsV []driver.NamedValue) []interface{} { } return args } - -type Tx struct { - conn *pgx.Conn -} - -func (t *Tx) Commit() error { - _, err := t.conn.Exec("commit") - return err -} - -func (t *Tx) Rollback() error { - _, err := t.conn.Exec("rollback") - return err -} diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index bdafdd48..fdc93c0a 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -2,6 +2,7 @@ package stdlib_test import ( "bytes" + "context" "database/sql" "testing" @@ -603,3 +604,84 @@ func TestTransactionLifeCycle(t *testing.T) { ensureConnValid(t, db) } + +func TestConnBeginTxIsolation(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + var defaultIsoLevel string + err := db.QueryRow("show transaction_isolation").Scan(&defaultIsoLevel) + if err != nil { + t.Fatalf("QueryRow failed: %v", err) + } + + supportedTests := []struct { + sqlIso sql.IsolationLevel + pgIso string + }{ + {sqlIso: sql.LevelDefault, pgIso: defaultIsoLevel}, + {sqlIso: sql.LevelReadUncommitted, pgIso: "read uncommitted"}, + {sqlIso: sql.LevelReadCommitted, pgIso: "read committed"}, + {sqlIso: sql.LevelSnapshot, pgIso: "repeatable read"}, + {sqlIso: sql.LevelSerializable, pgIso: "serializable"}, + } + for i, tt := range supportedTests { + func() { + tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso}) + if err != nil { + t.Errorf("%d. BeginTx failed: %v", i, err) + return + } + defer tx.Rollback() + + var pgIso string + err = tx.QueryRow("show transaction_isolation").Scan(&pgIso) + if err != nil { + t.Errorf("%d. QueryRow failed: %v", i, err) + } + + if pgIso != tt.pgIso { + t.Errorf("%d. pgIso => %s, want %s", i, pgIso, tt.pgIso) + } + }() + } + + unsupportedTests := []struct { + sqlIso sql.IsolationLevel + }{ + {sqlIso: sql.LevelWriteCommitted}, + {sqlIso: sql.LevelLinearizable}, + } + for i, tt := range unsupportedTests { + tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso}) + if err == nil { + t.Errorf("%d. BeginTx should have failed", i) + tx.Rollback() + } + } + + ensureConnValid(t, db) +} + +func TestConnBeginTxReadOnly(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}) + if err != nil { + t.Fatalf("BeginTx failed: %v", err) + } + defer tx.Rollback() + + var pgReadOnly string + err = tx.QueryRow("show transaction_read_only").Scan(&pgReadOnly) + if err != nil { + t.Errorf("%d. QueryRow failed: %v", err) + } + + if pgReadOnly != "on" { + t.Errorf("pgReadOnly => %s, want %s", pgReadOnly, "on") + } + + ensureConnValid(t, db) +}