diff --git a/batch.go b/batch.go index 7d515566..b08b271c 100644 --- a/batch.go +++ b/batch.go @@ -45,9 +45,9 @@ func (b *Batch) Conn() *Conn { return b.conn } -// Queue queues a query to batch b. parameterOIDs are required if there are -// parameters and query is not the name of a prepared statement. -// resultFormatCodes are required if there is a result. +// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. parameterOIDs and +// resultFormatCodes should be nil if query is a prepared statement. Otherwise, parameterOIDs are required if there are +// parameters and resultFormatCodes are required if there is a result. func (b *Batch) Queue(query string, arguments []interface{}, parameterOIDs []pgtype.OID, resultFormatCodes []int16) { b.items = append(b.items, &batchItem{ query: query, @@ -95,7 +95,21 @@ func (b *Batch) Send(ctx context.Context) error { } if ps != nil { - batch.ExecPrepared(ps.Name, paramValues, paramFormats, bi.resultFormatCodes) + resultFormats := bi.resultFormatCodes + if resultFormats == nil { + resultFormats = make([]int16, len(ps.FieldDescriptions)) + for i := range resultFormats { + if dt, ok := b.conn.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok { + if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { + resultFormats[i] = BinaryFormatCode + } else { + resultFormats[i] = TextFormatCode + } + } + } + } + + batch.ExecPrepared(ps.Name, paramValues, paramFormats, resultFormats) } else { oids := make([]uint32, len(parameterOIDs)) for i := 0; i < len(parameterOIDs); i++ { diff --git a/batch_test.go b/batch_test.go index 1cc27c4b..3831a678 100644 --- a/batch_test.go +++ b/batch_test.go @@ -177,7 +177,7 @@ func TestConnBeginBatchWithPreparedStatement(t *testing.T) { batch.Queue("ps1", []interface{}{5}, nil, - []int16{pgx.BinaryFormatCode}, + nil, ) }