mirror of https://github.com/jackc/pgx.git
Prepare returns description
parent
48f563a5f7
commit
413ef99979
|
@ -54,7 +54,7 @@ func BenchmarkExecPrepared(b *testing.B) {
|
||||||
require.Nil(b, err)
|
require.Nil(b, err)
|
||||||
defer closeConn(b, conn)
|
defer closeConn(b, conn)
|
||||||
|
|
||||||
err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
|
_, err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
|
||||||
|
@ -69,7 +69,7 @@ func BenchmarkSendExecPrepared(b *testing.B) {
|
||||||
require.Nil(b, err)
|
require.Nil(b, err)
|
||||||
defer closeConn(b, conn)
|
defer closeConn(b, conn)
|
||||||
|
|
||||||
err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
|
_, err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil)
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
|
||||||
|
|
|
@ -757,13 +757,42 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
|
||||||
return pgConn.bufferLastResult(ctx)
|
return pgConn.bufferLastResult(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type FieldDescription struct {
|
||||||
|
Name string
|
||||||
|
TableOID uint32
|
||||||
|
TableAttributeNumber uint16
|
||||||
|
DataTypeOID uint32
|
||||||
|
DataTypeSize int16
|
||||||
|
TypeModifier int32
|
||||||
|
FormatCode int16
|
||||||
|
}
|
||||||
|
|
||||||
|
// pgproto3FieldDescriptionToPgconnFieldDescription copies and converts the data from a pgproto3.FieldDescription to a
|
||||||
|
// FieldDescription.
|
||||||
|
func pgproto3FieldDescriptionToPgconnFieldDescription(src *pgproto3.FieldDescription, dst *FieldDescription) {
|
||||||
|
dst.Name = string(src.Name)
|
||||||
|
dst.TableOID = src.TableOID
|
||||||
|
dst.TableAttributeNumber = src.TableAttributeNumber
|
||||||
|
dst.DataTypeOID = src.DataTypeOID
|
||||||
|
dst.DataTypeSize = src.DataTypeSize
|
||||||
|
dst.TypeModifier = src.TypeModifier
|
||||||
|
dst.FormatCode = src.Format
|
||||||
|
}
|
||||||
|
|
||||||
|
type PreparedStatementDescription struct {
|
||||||
|
Name string
|
||||||
|
SQL string
|
||||||
|
ParamOIDs []uint32
|
||||||
|
Fields []FieldDescription
|
||||||
|
}
|
||||||
|
|
||||||
// Prepare creates a prepared statement.
|
// Prepare creates a prepared statement.
|
||||||
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) error {
|
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) {
|
||||||
if pgConn.batchCount != 0 {
|
if pgConn.batchCount != 0 {
|
||||||
return errors.New("unflushed previous sends")
|
return nil, errors.New("unflushed previous sends")
|
||||||
}
|
}
|
||||||
if pgConn.pendingReadyForQueryCount != 0 {
|
if pgConn.pendingReadyForQueryCount != 0 {
|
||||||
return errors.New("unread previous results")
|
return nil, errors.New("unread previous results")
|
||||||
}
|
}
|
||||||
|
|
||||||
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn)
|
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||||
|
@ -775,26 +804,32 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
|
||||||
pgConn.batchCount += 1
|
pgConn.batchCount += 1
|
||||||
err := pgConn.Flush(context.Background())
|
err := pgConn.Flush(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return preferContextOverNetTimeoutError(ctx, err)
|
return nil, preferContextOverNetTimeoutError(ctx, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
psd := &PreparedStatementDescription{Name: name, SQL: sql}
|
||||||
|
|
||||||
for pgConn.pendingReadyForQueryCount > 0 {
|
for pgConn.pendingReadyForQueryCount > 0 {
|
||||||
msg, err := pgConn.ReceiveMessage()
|
msg, err := pgConn.ReceiveMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return preferContextOverNetTimeoutError(ctx, err)
|
return nil, preferContextOverNetTimeoutError(ctx, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
switch msg := msg.(type) {
|
switch msg := msg.(type) {
|
||||||
case *pgproto3.ParameterDescription:
|
case *pgproto3.ParameterDescription:
|
||||||
// TODO
|
psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs))
|
||||||
|
copy(psd.ParamOIDs, msg.ParameterOIDs)
|
||||||
case *pgproto3.RowDescription:
|
case *pgproto3.RowDescription:
|
||||||
// TODO
|
psd.Fields = make([]FieldDescription, len(msg.Fields))
|
||||||
|
for i := range msg.Fields {
|
||||||
|
pgproto3FieldDescriptionToPgconnFieldDescription(&msg.Fields[i], &psd.Fields[i])
|
||||||
|
}
|
||||||
case *pgproto3.ErrorResponse:
|
case *pgproto3.ErrorResponse:
|
||||||
return errorResponseToPgError(msg)
|
return nil, errorResponseToPgError(msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return psd, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
|
func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
|
||||||
|
|
|
@ -327,8 +327,11 @@ func TestConnExecPrepared(t *testing.T) {
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
defer closeConn(t, pgConn)
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
|
psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
require.NotNil(t, psd)
|
||||||
|
assert.Len(t, psd.ParamOIDs, 1)
|
||||||
|
assert.Len(t, psd.Fields, 1)
|
||||||
|
|
||||||
result, err := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil)
|
result, err := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
@ -343,7 +346,7 @@ func TestConnExecPreparedCanceled(t *testing.T) {
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
defer closeConn(t, pgConn)
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil)
|
_, err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
@ -362,7 +365,7 @@ func TestConnBatchedQueries(t *testing.T) {
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
defer closeConn(t, pgConn)
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
|
_, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
pgConn.SendExec("select 'SendExec 1'")
|
pgConn.SendExec("select 'SendExec 1'")
|
||||||
|
|
Loading…
Reference in New Issue