Add TxOptions support to stdlib

batch-wip
Jack Christensen 2017-05-06 16:29:37 -05:00
parent ffae1b1345
commit 4cbefbb27e
2 changed files with 104 additions and 18 deletions

View File

@ -170,16 +170,34 @@ func (c *Conn) Close() error {
}
func (c *Conn) Begin() (driver.Tx, error) {
return c.BeginTx(context.Background(), driver.TxOptions{})
}
func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
if !c.conn.IsAlive() {
return nil, driver.ErrBadConn
}
_, err := c.conn.Exec("begin")
if err != nil {
return nil, err
var pgxOpts pgx.TxOptions
switch sql.IsolationLevel(opts.Isolation) {
case sql.LevelDefault:
case sql.LevelReadUncommitted:
pgxOpts.IsoLevel = pgx.ReadUncommitted
case sql.LevelReadCommitted:
pgxOpts.IsoLevel = pgx.ReadCommitted
case sql.LevelSnapshot:
pgxOpts.IsoLevel = pgx.RepeatableRead
case sql.LevelSerializable:
pgxOpts.IsoLevel = pgx.Serializable
default:
return nil, fmt.Errorf("unsupported isolation: %v", opts.Isolation)
}
return &Tx{conn: c.conn}, nil
if opts.ReadOnly {
pgxOpts.AccessMode = pgx.ReadOnly
}
return c.conn.BeginEx(&pgxOpts)
}
func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) {
@ -389,17 +407,3 @@ func namedValueToInterface(argsV []driver.NamedValue) []interface{} {
}
return args
}
type Tx struct {
conn *pgx.Conn
}
func (t *Tx) Commit() error {
_, err := t.conn.Exec("commit")
return err
}
func (t *Tx) Rollback() error {
_, err := t.conn.Exec("rollback")
return err
}

View File

@ -2,6 +2,7 @@ package stdlib_test
import (
"bytes"
"context"
"database/sql"
"testing"
@ -603,3 +604,84 @@ func TestTransactionLifeCycle(t *testing.T) {
ensureConnValid(t, db)
}
func TestConnBeginTxIsolation(t *testing.T) {
db := openDB(t)
defer closeDB(t, db)
var defaultIsoLevel string
err := db.QueryRow("show transaction_isolation").Scan(&defaultIsoLevel)
if err != nil {
t.Fatalf("QueryRow failed: %v", err)
}
supportedTests := []struct {
sqlIso sql.IsolationLevel
pgIso string
}{
{sqlIso: sql.LevelDefault, pgIso: defaultIsoLevel},
{sqlIso: sql.LevelReadUncommitted, pgIso: "read uncommitted"},
{sqlIso: sql.LevelReadCommitted, pgIso: "read committed"},
{sqlIso: sql.LevelSnapshot, pgIso: "repeatable read"},
{sqlIso: sql.LevelSerializable, pgIso: "serializable"},
}
for i, tt := range supportedTests {
func() {
tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso})
if err != nil {
t.Errorf("%d. BeginTx failed: %v", i, err)
return
}
defer tx.Rollback()
var pgIso string
err = tx.QueryRow("show transaction_isolation").Scan(&pgIso)
if err != nil {
t.Errorf("%d. QueryRow failed: %v", i, err)
}
if pgIso != tt.pgIso {
t.Errorf("%d. pgIso => %s, want %s", i, pgIso, tt.pgIso)
}
}()
}
unsupportedTests := []struct {
sqlIso sql.IsolationLevel
}{
{sqlIso: sql.LevelWriteCommitted},
{sqlIso: sql.LevelLinearizable},
}
for i, tt := range unsupportedTests {
tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso})
if err == nil {
t.Errorf("%d. BeginTx should have failed", i)
tx.Rollback()
}
}
ensureConnValid(t, db)
}
func TestConnBeginTxReadOnly(t *testing.T) {
db := openDB(t)
defer closeDB(t, db)
tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true})
if err != nil {
t.Fatalf("BeginTx failed: %v", err)
}
defer tx.Rollback()
var pgReadOnly string
err = tx.QueryRow("show transaction_read_only").Scan(&pgReadOnly)
if err != nil {
t.Errorf("%d. QueryRow failed: %v", err)
}
if pgReadOnly != "on" {
t.Errorf("pgReadOnly => %s, want %s", pgReadOnly, "on")
}
ensureConnValid(t, db)
}