Detect too many parameters on Prepare

refs #65
This commit is contained in:
Jack Christensen 2015-03-14 18:58:09 -05:00
parent cf674c6958
commit 28ef19702f
2 changed files with 90 additions and 4 deletions

18
conn.go
View File

@ -327,6 +327,9 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
case parseComplete:
case parameterDescription:
ps.ParameterOids = c.rxParameterDescription(r)
if len(ps.ParameterOids) > 65535 && softErr == nil {
softErr = fmt.Errorf("PostgreSQL supports maximum of 65535 parameters, received %d", len(ps.ParameterOids))
}
case rowDescription:
ps.FieldDescriptions = c.rxRowDescription(r)
for i := range ps.FieldDescriptions {
@ -337,7 +340,11 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
case noData:
case readyForQuery:
c.rxReadyForQuery(r)
if softErr == nil {
c.preparedStatements[name] = ps
}
return ps, softErr
default:
if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil {
@ -820,10 +827,17 @@ func (c *Conn) rxRowDescription(r *msgReader) (fields []FieldDescription) {
}
func (c *Conn) rxParameterDescription(r *msgReader) (parameters []Oid) {
parameterCount := r.readInt16()
// Internally, PostgreSQL supports greater than 64k parameters to a prepared
// statement. But the parameter description uses a 16-bit integer for the
// count of parameters. If there are more than 64K parameters, this count is
// wrong. So read the count, ignore it, and compute the proper value from
// the size of the message.
r.readInt16()
parameterCount := r.msgBytesRemaining / 4
parameters = make([]Oid, 0, parameterCount)
for i := int16(0); i < parameterCount; i++ {
for i := int32(0); i < parameterCount; i++ {
parameters = append(parameters, r.readOid())
}
return

View File

@ -1,7 +1,9 @@
package pgx_test
import (
"fmt"
"github.com/jackc/pgx"
"strconv"
"strings"
"sync"
"testing"
@ -357,7 +359,7 @@ func TestPrepare(t *testing.T) {
}
}
func TestPrepareFailure(t *testing.T) {
func TestPrepareBadSQLFailure(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
@ -370,6 +372,76 @@ func TestPrepareFailure(t *testing.T) {
ensureConnValid(t, conn)
}
func TestPrepareQueryManyParameters(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
tests := []struct {
count int
succeed bool
}{
{
count: 65534,
succeed: true,
},
{
count: 65535,
succeed: true,
},
{
count: 65536,
succeed: false,
},
{
count: 65537,
succeed: false,
},
}
for i, tt := range tests {
params := make([]string, 0, tt.count)
args := make([]interface{}, 0, tt.count)
for j := 0; j < tt.count; j++ {
params = append(params, fmt.Sprintf("($%d::text)", j+1))
args = append(args, strconv.FormatInt(int64(j), 10))
}
sql := "values" + strings.Join(params, ", ")
psName := fmt.Sprintf("manyParams%d", i)
_, err := conn.Prepare(psName, sql)
if err != nil {
if tt.succeed {
t.Errorf("%d. %v", i, err)
}
continue
}
if !tt.succeed {
t.Errorf("%d. Expected error but succeeded", i)
continue
}
rows, err := conn.Query(psName, args...)
if err != nil {
t.Errorf("conn.Query failed: %v", err)
continue
}
for rows.Next() {
var s string
rows.Scan(&s)
}
if rows.Err() != nil {
t.Errorf("Reading query result failed: %v", err)
}
}
ensureConnValid(t, conn)
}
func TestListenNotify(t *testing.T) {
t.Parallel()