From 28ef19702f269af3fdaf9367c6ac7a265a859228 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 14 Mar 2015 18:58:09 -0500 Subject: [PATCH] Detect too many parameters on Prepare refs #65 --- conn.go | 20 +++++++++++--- conn_test.go | 74 +++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 90 insertions(+), 4 deletions(-) diff --git a/conn.go b/conn.go index 97500229..898eea94 100644 --- a/conn.go +++ b/conn.go @@ -327,6 +327,9 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { case parseComplete: case parameterDescription: ps.ParameterOids = c.rxParameterDescription(r) + if len(ps.ParameterOids) > 65535 && softErr == nil { + softErr = fmt.Errorf("PostgreSQL supports maximum of 65535 parameters, received %d", len(ps.ParameterOids)) + } case rowDescription: ps.FieldDescriptions = c.rxRowDescription(r) for i := range ps.FieldDescriptions { @@ -337,7 +340,11 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { case noData: case readyForQuery: c.rxReadyForQuery(r) - c.preparedStatements[name] = ps + + if softErr == nil { + c.preparedStatements[name] = ps + } + return ps, softErr default: if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil { @@ -820,10 +827,17 @@ func (c *Conn) rxRowDescription(r *msgReader) (fields []FieldDescription) { } func (c *Conn) rxParameterDescription(r *msgReader) (parameters []Oid) { - parameterCount := r.readInt16() + // Internally, PostgreSQL supports greater than 64k parameters to a prepared + // statement. But the parameter description uses a 16-bit integer for the + // count of parameters. If there are more than 64K parameters, this count is + // wrong. So read the count, ignore it, and compute the proper value from + // the size of the message. + r.readInt16() + parameterCount := r.msgBytesRemaining / 4 + parameters = make([]Oid, 0, parameterCount) - for i := int16(0); i < parameterCount; i++ { + for i := int32(0); i < parameterCount; i++ { parameters = append(parameters, r.readOid()) } return diff --git a/conn_test.go b/conn_test.go index 244b46c6..8c6e24cf 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1,7 +1,9 @@ package pgx_test import ( + "fmt" "github.com/jackc/pgx" + "strconv" "strings" "sync" "testing" @@ -357,7 +359,7 @@ func TestPrepare(t *testing.T) { } } -func TestPrepareFailure(t *testing.T) { +func TestPrepareBadSQLFailure(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -370,6 +372,76 @@ func TestPrepareFailure(t *testing.T) { ensureConnValid(t, conn) } +func TestPrepareQueryManyParameters(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + tests := []struct { + count int + succeed bool + }{ + { + count: 65534, + succeed: true, + }, + { + count: 65535, + succeed: true, + }, + { + count: 65536, + succeed: false, + }, + { + count: 65537, + succeed: false, + }, + } + + for i, tt := range tests { + params := make([]string, 0, tt.count) + args := make([]interface{}, 0, tt.count) + for j := 0; j < tt.count; j++ { + params = append(params, fmt.Sprintf("($%d::text)", j+1)) + args = append(args, strconv.FormatInt(int64(j), 10)) + } + + sql := "values" + strings.Join(params, ", ") + + psName := fmt.Sprintf("manyParams%d", i) + _, err := conn.Prepare(psName, sql) + if err != nil { + if tt.succeed { + t.Errorf("%d. %v", i, err) + } + continue + } + if !tt.succeed { + t.Errorf("%d. Expected error but succeeded", i) + continue + } + + rows, err := conn.Query(psName, args...) + if err != nil { + t.Errorf("conn.Query failed: %v", err) + continue + } + + for rows.Next() { + var s string + rows.Scan(&s) + } + + if rows.Err() != nil { + t.Errorf("Reading query result failed: %v", err) + } + } + + ensureConnValid(t, conn) +} + func TestListenNotify(t *testing.T) { t.Parallel()