mirror of https://github.com/jackc/pgx.git
Add guards against usage of busy connection
parent
51407590eb
commit
86837e5576
40
conn.go
40
conn.go
|
@ -62,6 +62,7 @@ type Conn struct {
|
|||
fp *fastpath
|
||||
pgsql_af_inet byte
|
||||
pgsql_af_inet6 byte
|
||||
busy bool
|
||||
}
|
||||
|
||||
type PreparedStatement struct {
|
||||
|
@ -99,6 +100,7 @@ var ErrNoRows = errors.New("no rows in result set")
|
|||
var ErrNotificationTimeout = errors.New("notification timeout")
|
||||
var ErrDeadConn = errors.New("conn is dead")
|
||||
var ErrTLSRefused = errors.New("server refused TLS connection")
|
||||
var ErrConnBusy = errors.New("conn is busy")
|
||||
|
||||
type ProtocolError string
|
||||
|
||||
|
@ -878,19 +880,29 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||
// Exec executes sql. sql can be either a prepared statement name or an SQL string.
|
||||
// arguments should be referenced positionally from the sql string as $1, $2, etc.
|
||||
func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||
if err = c.lock(); err != nil {
|
||||
return commandTag, err
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
c.lastActivityTime = startTime
|
||||
|
||||
if c.logLevel >= LogLevelError {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
if c.logLevel >= LogLevelInfo {
|
||||
endTime := time.Now()
|
||||
c.logger.Info("Exec", "sql", sql, "args", logQueryArgs(arguments), "time", endTime.Sub(startTime), "commandTag", commandTag)
|
||||
} else {
|
||||
}
|
||||
} else {
|
||||
if c.logLevel >= LogLevelError {
|
||||
c.logger.Error("Exec", "sql", sql, "args", logQueryArgs(arguments), "error", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
if unlockErr := c.unlock(); unlockErr != nil && err == nil {
|
||||
err = unlockErr
|
||||
}
|
||||
}()
|
||||
|
||||
if err = c.sendQuery(sql, arguments...); err != nil {
|
||||
return
|
||||
|
@ -1137,3 +1149,19 @@ func (c *Conn) die(err error) {
|
|||
c.causeOfDeath = err
|
||||
c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *Conn) lock() error {
|
||||
if c.busy {
|
||||
return ErrConnBusy
|
||||
}
|
||||
c.busy = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) unlock() error {
|
||||
if !c.busy {
|
||||
return errors.New("unlock conn that is not busy")
|
||||
}
|
||||
c.busy = false
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -487,6 +487,11 @@ func TestConnPoolQueryConcurrentLoad(t *testing.T) {
|
|||
if rowCount != 1000 {
|
||||
t.Error("Select called onDataRow wrong number of times")
|
||||
}
|
||||
|
||||
_, err = pool.Exec("--;")
|
||||
if err != nil {
|
||||
t.Fatalf("pool.Exec failed: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
|
|
36
conn_test.go
36
conn_test.go
|
@ -1030,3 +1030,39 @@ func TestInsertTimestampArray(t *testing.T) {
|
|||
t.Errorf("Unexpected results from Exec: %v", results)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCatchSimultaneousConnectionQueries(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
rows1, err := conn.Query("select generate_series(1,$1)", 10)
|
||||
if err != nil {
|
||||
t.Fatalf("conn.Query failed: ", err)
|
||||
}
|
||||
defer rows1.Close()
|
||||
|
||||
_, err = conn.Query("select generate_series(1,$1)", 10)
|
||||
if err != pgx.ErrConnBusy {
|
||||
t.Fatalf("conn.Query should have failed with pgx.ErrConnBusy, but it was %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCatchSimultaneousConnectionQueryAndExec(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
rows, err := conn.Query("select generate_series(1,$1)", 10)
|
||||
if err != nil {
|
||||
t.Fatalf("conn.Query failed: ", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
_, err = conn.Exec("create temporary table foo(spice timestamp[])")
|
||||
if err != pgx.ErrConnBusy {
|
||||
t.Fatalf("conn.Exec should have failed with pgx.ErrConnBusy, but it was %v", err)
|
||||
}
|
||||
}
|
||||
|
|
40
query.go
40
query.go
|
@ -39,20 +39,21 @@ func (r *Row) Scan(dest ...interface{}) (err error) {
|
|||
// the *Conn can be used again. Rows are closed by explicitly calling Close(),
|
||||
// calling Next() until it returns false, or when a fatal error occurs.
|
||||
type Rows struct {
|
||||
pool *ConnPool
|
||||
conn *Conn
|
||||
mr *msgReader
|
||||
fields []FieldDescription
|
||||
vr ValueReader
|
||||
rowCount int
|
||||
columnIdx int
|
||||
err error
|
||||
closed bool
|
||||
startTime time.Time
|
||||
sql string
|
||||
args []interface{}
|
||||
logger Logger
|
||||
logLevel int
|
||||
pool *ConnPool
|
||||
conn *Conn
|
||||
mr *msgReader
|
||||
fields []FieldDescription
|
||||
vr ValueReader
|
||||
rowCount int
|
||||
columnIdx int
|
||||
err error
|
||||
closed bool
|
||||
startTime time.Time
|
||||
sql string
|
||||
args []interface{}
|
||||
logger Logger
|
||||
logLevel int
|
||||
unlockConn bool
|
||||
}
|
||||
|
||||
func (rows *Rows) FieldDescriptions() []FieldDescription {
|
||||
|
@ -64,6 +65,11 @@ func (rows *Rows) close() {
|
|||
return
|
||||
}
|
||||
|
||||
if rows.unlockConn {
|
||||
rows.conn.unlock()
|
||||
rows.unlockConn = false
|
||||
}
|
||||
|
||||
if rows.pool != nil {
|
||||
rows.pool.Release(rows.conn)
|
||||
rows.pool = nil
|
||||
|
@ -421,6 +427,12 @@ func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) {
|
|||
c.lastActivityTime = time.Now()
|
||||
rows := &Rows{conn: c, startTime: c.lastActivityTime, sql: sql, args: args, logger: c.logger, logLevel: c.logLevel}
|
||||
|
||||
if err := c.lock(); err != nil {
|
||||
rows.abort(err)
|
||||
return rows, err
|
||||
}
|
||||
rows.unlockConn = true
|
||||
|
||||
ps, ok := c.preparedStatements[sql]
|
||||
if !ok {
|
||||
var err error
|
||||
|
|
Loading…
Reference in New Issue