mirror of https://github.com/jackc/pgx.git
Add TxOptions support to stdlib
parent
ffae1b1345
commit
4cbefbb27e
|
@ -170,16 +170,34 @@ func (c *Conn) Close() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Begin() (driver.Tx, error) {
|
func (c *Conn) Begin() (driver.Tx, error) {
|
||||||
|
return c.BeginTx(context.Background(), driver.TxOptions{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
||||||
if !c.conn.IsAlive() {
|
if !c.conn.IsAlive() {
|
||||||
return nil, driver.ErrBadConn
|
return nil, driver.ErrBadConn
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := c.conn.Exec("begin")
|
var pgxOpts pgx.TxOptions
|
||||||
if err != nil {
|
switch sql.IsolationLevel(opts.Isolation) {
|
||||||
return nil, err
|
case sql.LevelDefault:
|
||||||
|
case sql.LevelReadUncommitted:
|
||||||
|
pgxOpts.IsoLevel = pgx.ReadUncommitted
|
||||||
|
case sql.LevelReadCommitted:
|
||||||
|
pgxOpts.IsoLevel = pgx.ReadCommitted
|
||||||
|
case sql.LevelSnapshot:
|
||||||
|
pgxOpts.IsoLevel = pgx.RepeatableRead
|
||||||
|
case sql.LevelSerializable:
|
||||||
|
pgxOpts.IsoLevel = pgx.Serializable
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported isolation: %v", opts.Isolation)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Tx{conn: c.conn}, nil
|
if opts.ReadOnly {
|
||||||
|
pgxOpts.AccessMode = pgx.ReadOnly
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.conn.BeginEx(&pgxOpts)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) {
|
func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) {
|
||||||
|
@ -389,17 +407,3 @@ func namedValueToInterface(argsV []driver.NamedValue) []interface{} {
|
||||||
}
|
}
|
||||||
return args
|
return args
|
||||||
}
|
}
|
||||||
|
|
||||||
type Tx struct {
|
|
||||||
conn *pgx.Conn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tx) Commit() error {
|
|
||||||
_, err := t.conn.Exec("commit")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tx) Rollback() error {
|
|
||||||
_, err := t.conn.Exec("rollback")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
|
@ -2,6 +2,7 @@ package stdlib_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -603,3 +604,84 @@ func TestTransactionLifeCycle(t *testing.T) {
|
||||||
|
|
||||||
ensureConnValid(t, db)
|
ensureConnValid(t, db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnBeginTxIsolation(t *testing.T) {
|
||||||
|
db := openDB(t)
|
||||||
|
defer closeDB(t, db)
|
||||||
|
|
||||||
|
var defaultIsoLevel string
|
||||||
|
err := db.QueryRow("show transaction_isolation").Scan(&defaultIsoLevel)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("QueryRow failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
supportedTests := []struct {
|
||||||
|
sqlIso sql.IsolationLevel
|
||||||
|
pgIso string
|
||||||
|
}{
|
||||||
|
{sqlIso: sql.LevelDefault, pgIso: defaultIsoLevel},
|
||||||
|
{sqlIso: sql.LevelReadUncommitted, pgIso: "read uncommitted"},
|
||||||
|
{sqlIso: sql.LevelReadCommitted, pgIso: "read committed"},
|
||||||
|
{sqlIso: sql.LevelSnapshot, pgIso: "repeatable read"},
|
||||||
|
{sqlIso: sql.LevelSerializable, pgIso: "serializable"},
|
||||||
|
}
|
||||||
|
for i, tt := range supportedTests {
|
||||||
|
func() {
|
||||||
|
tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso})
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%d. BeginTx failed: %v", i, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
var pgIso string
|
||||||
|
err = tx.QueryRow("show transaction_isolation").Scan(&pgIso)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%d. QueryRow failed: %v", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pgIso != tt.pgIso {
|
||||||
|
t.Errorf("%d. pgIso => %s, want %s", i, pgIso, tt.pgIso)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
unsupportedTests := []struct {
|
||||||
|
sqlIso sql.IsolationLevel
|
||||||
|
}{
|
||||||
|
{sqlIso: sql.LevelWriteCommitted},
|
||||||
|
{sqlIso: sql.LevelLinearizable},
|
||||||
|
}
|
||||||
|
for i, tt := range unsupportedTests {
|
||||||
|
tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso})
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("%d. BeginTx should have failed", i)
|
||||||
|
tx.Rollback()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnBeginTxReadOnly(t *testing.T) {
|
||||||
|
db := openDB(t)
|
||||||
|
defer closeDB(t, db)
|
||||||
|
|
||||||
|
tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("BeginTx failed: %v", err)
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
var pgReadOnly string
|
||||||
|
err = tx.QueryRow("show transaction_read_only").Scan(&pgReadOnly)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%d. QueryRow failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pgReadOnly != "on" {
|
||||||
|
t.Errorf("pgReadOnly => %s, want %s", pgReadOnly, "on")
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, db)
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue