diff --git a/pgconn/benchmark_test.go b/pgconn/benchmark_test.go index bdc550cb..269ac59b 100644 --- a/pgconn/benchmark_test.go +++ b/pgconn/benchmark_test.go @@ -54,7 +54,7 @@ func BenchmarkExecPrepared(b *testing.B) { require.Nil(b, err) defer closeConn(b, conn) - err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) + _, err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) b.ResetTimer() @@ -69,7 +69,7 @@ func BenchmarkSendExecPrepared(b *testing.B) { require.Nil(b, err) defer closeConn(b, conn) - err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) + _, err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) b.ResetTimer() diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 1e70a82b..de7020b2 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -757,13 +757,42 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return pgConn.bufferLastResult(ctx) } +type FieldDescription struct { + Name string + TableOID uint32 + TableAttributeNumber uint16 + DataTypeOID uint32 + DataTypeSize int16 + TypeModifier int32 + FormatCode int16 +} + +// pgproto3FieldDescriptionToPgconnFieldDescription copies and converts the data from a pgproto3.FieldDescription to a +// FieldDescription. +func pgproto3FieldDescriptionToPgconnFieldDescription(src *pgproto3.FieldDescription, dst *FieldDescription) { + dst.Name = string(src.Name) + dst.TableOID = src.TableOID + dst.TableAttributeNumber = src.TableAttributeNumber + dst.DataTypeOID = src.DataTypeOID + dst.DataTypeSize = src.DataTypeSize + dst.TypeModifier = src.TypeModifier + dst.FormatCode = src.Format +} + +type PreparedStatementDescription struct { + Name string + SQL string + ParamOIDs []uint32 + Fields []FieldDescription +} + // Prepare creates a prepared statement. -func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) error { +func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) { if pgConn.batchCount != 0 { - return errors.New("unflushed previous sends") + return nil, errors.New("unflushed previous sends") } if pgConn.pendingReadyForQueryCount != 0 { - return errors.New("unread previous results") + return nil, errors.New("unread previous results") } cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) @@ -775,26 +804,32 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ pgConn.batchCount += 1 err := pgConn.Flush(context.Background()) if err != nil { - return preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } + psd := &PreparedStatementDescription{Name: name, SQL: sql} + for pgConn.pendingReadyForQueryCount > 0 { msg, err := pgConn.ReceiveMessage() if err != nil { - return preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { case *pgproto3.ParameterDescription: - // TODO + psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) + copy(psd.ParamOIDs, msg.ParameterOIDs) case *pgproto3.RowDescription: - // TODO + psd.Fields = make([]FieldDescription, len(msg.Fields)) + for i := range msg.Fields { + pgproto3FieldDescriptionToPgconnFieldDescription(&msg.Fields[i], &psd.Fields[i]) + } case *pgproto3.ErrorResponse: - return errorResponseToPgError(msg) + return nil, errorResponseToPgError(msg) } } - return nil + return psd, nil } func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 8f976d87..ee573d42 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -327,8 +327,11 @@ func TestConnExecPrepared(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) - err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) require.Nil(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, 1) + assert.Len(t, psd.Fields, 1) result, err := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) require.Nil(t, err) @@ -343,7 +346,7 @@ func TestConnExecPreparedCanceled(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) - err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil) + _, err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil) require.Nil(t, err) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) @@ -362,7 +365,7 @@ func TestConnBatchedQueries(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) - err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + _, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) require.Nil(t, err) pgConn.SendExec("select 'SendExec 1'")