mirror of https://github.com/jackc/pgx.git
Prepare returns description
parent
48f563a5f7
commit
413ef99979
|
@ -54,7 +54,7 @@ func BenchmarkExecPrepared(b *testing.B) {
|
|||
require.Nil(b, err)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
|
||||
_, err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
|
@ -69,7 +69,7 @@ func BenchmarkSendExecPrepared(b *testing.B) {
|
|||
require.Nil(b, err)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
|
||||
_, err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
|
|
|
@ -757,13 +757,42 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
|
|||
return pgConn.bufferLastResult(ctx)
|
||||
}
|
||||
|
||||
type FieldDescription struct {
|
||||
Name string
|
||||
TableOID uint32
|
||||
TableAttributeNumber uint16
|
||||
DataTypeOID uint32
|
||||
DataTypeSize int16
|
||||
TypeModifier int32
|
||||
FormatCode int16
|
||||
}
|
||||
|
||||
// pgproto3FieldDescriptionToPgconnFieldDescription copies and converts the data from a pgproto3.FieldDescription to a
|
||||
// FieldDescription.
|
||||
func pgproto3FieldDescriptionToPgconnFieldDescription(src *pgproto3.FieldDescription, dst *FieldDescription) {
|
||||
dst.Name = string(src.Name)
|
||||
dst.TableOID = src.TableOID
|
||||
dst.TableAttributeNumber = src.TableAttributeNumber
|
||||
dst.DataTypeOID = src.DataTypeOID
|
||||
dst.DataTypeSize = src.DataTypeSize
|
||||
dst.TypeModifier = src.TypeModifier
|
||||
dst.FormatCode = src.Format
|
||||
}
|
||||
|
||||
type PreparedStatementDescription struct {
|
||||
Name string
|
||||
SQL string
|
||||
ParamOIDs []uint32
|
||||
Fields []FieldDescription
|
||||
}
|
||||
|
||||
// Prepare creates a prepared statement.
|
||||
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) error {
|
||||
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) {
|
||||
if pgConn.batchCount != 0 {
|
||||
return errors.New("unflushed previous sends")
|
||||
return nil, errors.New("unflushed previous sends")
|
||||
}
|
||||
if pgConn.pendingReadyForQueryCount != 0 {
|
||||
return errors.New("unread previous results")
|
||||
return nil, errors.New("unread previous results")
|
||||
}
|
||||
|
||||
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||
|
@ -775,26 +804,32 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
|
|||
pgConn.batchCount += 1
|
||||
err := pgConn.Flush(context.Background())
|
||||
if err != nil {
|
||||
return preferContextOverNetTimeoutError(ctx, err)
|
||||
return nil, preferContextOverNetTimeoutError(ctx, err)
|
||||
}
|
||||
|
||||
psd := &PreparedStatementDescription{Name: name, SQL: sql}
|
||||
|
||||
for pgConn.pendingReadyForQueryCount > 0 {
|
||||
msg, err := pgConn.ReceiveMessage()
|
||||
if err != nil {
|
||||
return preferContextOverNetTimeoutError(ctx, err)
|
||||
return nil, preferContextOverNetTimeoutError(ctx, err)
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ParameterDescription:
|
||||
// TODO
|
||||
psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs))
|
||||
copy(psd.ParamOIDs, msg.ParameterOIDs)
|
||||
case *pgproto3.RowDescription:
|
||||
// TODO
|
||||
psd.Fields = make([]FieldDescription, len(msg.Fields))
|
||||
for i := range msg.Fields {
|
||||
pgproto3FieldDescriptionToPgconnFieldDescription(&msg.Fields[i], &psd.Fields[i])
|
||||
}
|
||||
case *pgproto3.ErrorResponse:
|
||||
return errorResponseToPgError(msg)
|
||||
return nil, errorResponseToPgError(msg)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return psd, nil
|
||||
}
|
||||
|
||||
func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
|
||||
|
|
|
@ -327,8 +327,11 @@ func TestConnExecPrepared(t *testing.T) {
|
|||
require.Nil(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
|
||||
psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, psd)
|
||||
assert.Len(t, psd.ParamOIDs, 1)
|
||||
assert.Len(t, psd.Fields, 1)
|
||||
|
||||
result, err := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil)
|
||||
require.Nil(t, err)
|
||||
|
@ -343,7 +346,7 @@ func TestConnExecPreparedCanceled(t *testing.T) {
|
|||
require.Nil(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil)
|
||||
_, err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil)
|
||||
require.Nil(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
|
@ -362,7 +365,7 @@ func TestConnBatchedQueries(t *testing.T) {
|
|||
require.Nil(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
|
||||
_, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
|
||||
require.Nil(t, err)
|
||||
|
||||
pgConn.SendExec("select 'SendExec 1'")
|
||||
|
|
Loading…
Reference in New Issue