diff --git a/conn.go b/conn.go index 24aecc0b..7e826878 100644 --- a/conn.go +++ b/conn.go @@ -62,6 +62,7 @@ type Conn struct { fp *fastpath pgsql_af_inet byte pgsql_af_inet6 byte + busy bool } type PreparedStatement struct { @@ -99,6 +100,7 @@ var ErrNoRows = errors.New("no rows in result set") var ErrNotificationTimeout = errors.New("notification timeout") var ErrDeadConn = errors.New("conn is dead") var ErrTLSRefused = errors.New("server refused TLS connection") +var ErrConnBusy = errors.New("conn is busy") type ProtocolError string @@ -878,19 +880,29 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} // Exec executes sql. sql can be either a prepared statement name or an SQL string. // arguments should be referenced positionally from the sql string as $1, $2, etc. func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) { + if err = c.lock(); err != nil { + return commandTag, err + } + startTime := time.Now() c.lastActivityTime = startTime - if c.logLevel >= LogLevelError { - defer func() { - if err == nil { + defer func() { + if err == nil { + if c.logLevel >= LogLevelInfo { endTime := time.Now() c.logger.Info("Exec", "sql", sql, "args", logQueryArgs(arguments), "time", endTime.Sub(startTime), "commandTag", commandTag) - } else { + } + } else { + if c.logLevel >= LogLevelError { c.logger.Error("Exec", "sql", sql, "args", logQueryArgs(arguments), "error", err) } - }() - } + } + + if unlockErr := c.unlock(); unlockErr != nil && err == nil { + err = unlockErr + } + }() if err = c.sendQuery(sql, arguments...); err != nil { return @@ -1137,3 +1149,19 @@ func (c *Conn) die(err error) { c.causeOfDeath = err c.conn.Close() } + +func (c *Conn) lock() error { + if c.busy { + return ErrConnBusy + } + c.busy = true + return nil +} + +func (c *Conn) unlock() error { + if !c.busy { + return errors.New("unlock conn that is not busy") + } + c.busy = false + return nil +} diff --git a/conn_pool_test.go b/conn_pool_test.go index 73519855..93fea4d4 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -487,6 +487,11 @@ func TestConnPoolQueryConcurrentLoad(t *testing.T) { if rowCount != 1000 { t.Error("Select called onDataRow wrong number of times") } + + _, err = pool.Exec("--;") + if err != nil { + t.Fatalf("pool.Exec failed: %v", err) + } }() } diff --git a/conn_test.go b/conn_test.go index 2e3f612e..147a925c 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1030,3 +1030,39 @@ func TestInsertTimestampArray(t *testing.T) { t.Errorf("Unexpected results from Exec: %v", results) } } + +func TestCatchSimultaneousConnectionQueries(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + rows1, err := conn.Query("select generate_series(1,$1)", 10) + if err != nil { + t.Fatalf("conn.Query failed: ", err) + } + defer rows1.Close() + + _, err = conn.Query("select generate_series(1,$1)", 10) + if err != pgx.ErrConnBusy { + t.Fatalf("conn.Query should have failed with pgx.ErrConnBusy, but it was %v", err) + } +} + +func TestCatchSimultaneousConnectionQueryAndExec(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + rows, err := conn.Query("select generate_series(1,$1)", 10) + if err != nil { + t.Fatalf("conn.Query failed: ", err) + } + defer rows.Close() + + _, err = conn.Exec("create temporary table foo(spice timestamp[])") + if err != pgx.ErrConnBusy { + t.Fatalf("conn.Exec should have failed with pgx.ErrConnBusy, but it was %v", err) + } +} diff --git a/query.go b/query.go index 1810f305..1b29c425 100644 --- a/query.go +++ b/query.go @@ -39,20 +39,21 @@ func (r *Row) Scan(dest ...interface{}) (err error) { // the *Conn can be used again. Rows are closed by explicitly calling Close(), // calling Next() until it returns false, or when a fatal error occurs. type Rows struct { - pool *ConnPool - conn *Conn - mr *msgReader - fields []FieldDescription - vr ValueReader - rowCount int - columnIdx int - err error - closed bool - startTime time.Time - sql string - args []interface{} - logger Logger - logLevel int + pool *ConnPool + conn *Conn + mr *msgReader + fields []FieldDescription + vr ValueReader + rowCount int + columnIdx int + err error + closed bool + startTime time.Time + sql string + args []interface{} + logger Logger + logLevel int + unlockConn bool } func (rows *Rows) FieldDescriptions() []FieldDescription { @@ -64,6 +65,11 @@ func (rows *Rows) close() { return } + if rows.unlockConn { + rows.conn.unlock() + rows.unlockConn = false + } + if rows.pool != nil { rows.pool.Release(rows.conn) rows.pool = nil @@ -421,6 +427,12 @@ func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) { c.lastActivityTime = time.Now() rows := &Rows{conn: c, startTime: c.lastActivityTime, sql: sql, args: args, logger: c.logger, logLevel: c.logLevel} + if err := c.lock(); err != nil { + rows.abort(err) + return rows, err + } + rows.unlockConn = true + ps, ok := c.preparedStatements[sql] if !ok { var err error