Prepare returns description

pull/483/head
Jack Christensen 2019-01-01 18:03:20 -06:00
parent 48f563a5f7
commit 413ef99979
3 changed files with 52 additions and 14 deletions

View File

@ -54,7 +54,7 @@ func BenchmarkExecPrepared(b *testing.B) {
require.Nil(b, err)
defer closeConn(b, conn)
err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
_, err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
b.ResetTimer()
@ -69,7 +69,7 @@ func BenchmarkSendExecPrepared(b *testing.B) {
require.Nil(b, err)
defer closeConn(b, conn)
err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
_, err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
b.ResetTimer()

View File

@ -757,13 +757,42 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
return pgConn.bufferLastResult(ctx)
}
type FieldDescription struct {
Name string
TableOID uint32
TableAttributeNumber uint16
DataTypeOID uint32
DataTypeSize int16
TypeModifier int32
FormatCode int16
}
// pgproto3FieldDescriptionToPgconnFieldDescription copies and converts the data from a pgproto3.FieldDescription to a
// FieldDescription.
func pgproto3FieldDescriptionToPgconnFieldDescription(src *pgproto3.FieldDescription, dst *FieldDescription) {
dst.Name = string(src.Name)
dst.TableOID = src.TableOID
dst.TableAttributeNumber = src.TableAttributeNumber
dst.DataTypeOID = src.DataTypeOID
dst.DataTypeSize = src.DataTypeSize
dst.TypeModifier = src.TypeModifier
dst.FormatCode = src.Format
}
type PreparedStatementDescription struct {
Name string
SQL string
ParamOIDs []uint32
Fields []FieldDescription
}
// Prepare creates a prepared statement.
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) error {
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) {
if pgConn.batchCount != 0 {
return errors.New("unflushed previous sends")
return nil, errors.New("unflushed previous sends")
}
if pgConn.pendingReadyForQueryCount != 0 {
return errors.New("unread previous results")
return nil, errors.New("unread previous results")
}
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn)
@ -775,26 +804,32 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
pgConn.batchCount += 1
err := pgConn.Flush(context.Background())
if err != nil {
return preferContextOverNetTimeoutError(ctx, err)
return nil, preferContextOverNetTimeoutError(ctx, err)
}
psd := &PreparedStatementDescription{Name: name, SQL: sql}
for pgConn.pendingReadyForQueryCount > 0 {
msg, err := pgConn.ReceiveMessage()
if err != nil {
return preferContextOverNetTimeoutError(ctx, err)
return nil, preferContextOverNetTimeoutError(ctx, err)
}
switch msg := msg.(type) {
case *pgproto3.ParameterDescription:
// TODO
psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs))
copy(psd.ParamOIDs, msg.ParameterOIDs)
case *pgproto3.RowDescription:
// TODO
psd.Fields = make([]FieldDescription, len(msg.Fields))
for i := range msg.Fields {
pgproto3FieldDescriptionToPgconnFieldDescription(&msg.Fields[i], &psd.Fields[i])
}
case *pgproto3.ErrorResponse:
return errorResponseToPgError(msg)
return nil, errorResponseToPgError(msg)
}
}
return nil
return psd, nil
}
func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {

View File

@ -327,8 +327,11 @@ func TestConnExecPrepared(t *testing.T) {
require.Nil(t, err)
defer closeConn(t, pgConn)
err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
require.Nil(t, err)
require.NotNil(t, psd)
assert.Len(t, psd.ParamOIDs, 1)
assert.Len(t, psd.Fields, 1)
result, err := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil)
require.Nil(t, err)
@ -343,7 +346,7 @@ func TestConnExecPreparedCanceled(t *testing.T) {
require.Nil(t, err)
defer closeConn(t, pgConn)
err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil)
_, err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil)
require.Nil(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
@ -362,7 +365,7 @@ func TestConnBatchedQueries(t *testing.T) {
require.Nil(t, err)
defer closeConn(t, pgConn)
err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
_, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
require.Nil(t, err)
pgConn.SendExec("select 'SendExec 1'")