Add guards against usage of busy connection

log-bench
Jack Christensen 2015-09-16 10:22:16 -05:00
parent 51407590eb
commit 86837e5576
4 changed files with 101 additions and 20 deletions

40
conn.go
View File

@ -62,6 +62,7 @@ type Conn struct {
fp *fastpath
pgsql_af_inet byte
pgsql_af_inet6 byte
busy bool
}
type PreparedStatement struct {
@ -99,6 +100,7 @@ var ErrNoRows = errors.New("no rows in result set")
var ErrNotificationTimeout = errors.New("notification timeout")
var ErrDeadConn = errors.New("conn is dead")
var ErrTLSRefused = errors.New("server refused TLS connection")
var ErrConnBusy = errors.New("conn is busy")
type ProtocolError string
@ -878,19 +880,29 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
// Exec executes sql. sql can be either a prepared statement name or an SQL string.
// arguments should be referenced positionally from the sql string as $1, $2, etc.
func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
if err = c.lock(); err != nil {
return commandTag, err
}
startTime := time.Now()
c.lastActivityTime = startTime
if c.logLevel >= LogLevelError {
defer func() {
if err == nil {
defer func() {
if err == nil {
if c.logLevel >= LogLevelInfo {
endTime := time.Now()
c.logger.Info("Exec", "sql", sql, "args", logQueryArgs(arguments), "time", endTime.Sub(startTime), "commandTag", commandTag)
} else {
}
} else {
if c.logLevel >= LogLevelError {
c.logger.Error("Exec", "sql", sql, "args", logQueryArgs(arguments), "error", err)
}
}()
}
}
if unlockErr := c.unlock(); unlockErr != nil && err == nil {
err = unlockErr
}
}()
if err = c.sendQuery(sql, arguments...); err != nil {
return
@ -1137,3 +1149,19 @@ func (c *Conn) die(err error) {
c.causeOfDeath = err
c.conn.Close()
}
func (c *Conn) lock() error {
if c.busy {
return ErrConnBusy
}
c.busy = true
return nil
}
func (c *Conn) unlock() error {
if !c.busy {
return errors.New("unlock conn that is not busy")
}
c.busy = false
return nil
}

View File

@ -487,6 +487,11 @@ func TestConnPoolQueryConcurrentLoad(t *testing.T) {
if rowCount != 1000 {
t.Error("Select called onDataRow wrong number of times")
}
_, err = pool.Exec("--;")
if err != nil {
t.Fatalf("pool.Exec failed: %v", err)
}
}()
}

View File

@ -1030,3 +1030,39 @@ func TestInsertTimestampArray(t *testing.T) {
t.Errorf("Unexpected results from Exec: %v", results)
}
}
func TestCatchSimultaneousConnectionQueries(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
rows1, err := conn.Query("select generate_series(1,$1)", 10)
if err != nil {
t.Fatalf("conn.Query failed: ", err)
}
defer rows1.Close()
_, err = conn.Query("select generate_series(1,$1)", 10)
if err != pgx.ErrConnBusy {
t.Fatalf("conn.Query should have failed with pgx.ErrConnBusy, but it was %v", err)
}
}
func TestCatchSimultaneousConnectionQueryAndExec(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
rows, err := conn.Query("select generate_series(1,$1)", 10)
if err != nil {
t.Fatalf("conn.Query failed: ", err)
}
defer rows.Close()
_, err = conn.Exec("create temporary table foo(spice timestamp[])")
if err != pgx.ErrConnBusy {
t.Fatalf("conn.Exec should have failed with pgx.ErrConnBusy, but it was %v", err)
}
}

View File

@ -39,20 +39,21 @@ func (r *Row) Scan(dest ...interface{}) (err error) {
// the *Conn can be used again. Rows are closed by explicitly calling Close(),
// calling Next() until it returns false, or when a fatal error occurs.
type Rows struct {
pool *ConnPool
conn *Conn
mr *msgReader
fields []FieldDescription
vr ValueReader
rowCount int
columnIdx int
err error
closed bool
startTime time.Time
sql string
args []interface{}
logger Logger
logLevel int
pool *ConnPool
conn *Conn
mr *msgReader
fields []FieldDescription
vr ValueReader
rowCount int
columnIdx int
err error
closed bool
startTime time.Time
sql string
args []interface{}
logger Logger
logLevel int
unlockConn bool
}
func (rows *Rows) FieldDescriptions() []FieldDescription {
@ -64,6 +65,11 @@ func (rows *Rows) close() {
return
}
if rows.unlockConn {
rows.conn.unlock()
rows.unlockConn = false
}
if rows.pool != nil {
rows.pool.Release(rows.conn)
rows.pool = nil
@ -421,6 +427,12 @@ func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) {
c.lastActivityTime = time.Now()
rows := &Rows{conn: c, startTime: c.lastActivityTime, sql: sql, args: args, logger: c.logger, logLevel: c.logLevel}
if err := c.lock(); err != nil {
rows.abort(err)
return rows, err
}
rows.unlockConn = true
ps, ok := c.preparedStatements[sql]
if !ok {
var err error