make properties of QueuedQuery and Batch public, closes #1878

pull/1887/head
Pavlo Golub 2024-01-25 11:39:17 +01:00 committed by Jack Christensen
parent a57bb8caea
commit c90f82a4e3
2 changed files with 49 additions and 49 deletions

View File

@ -10,8 +10,8 @@ import (
// QueuedQuery is a query that has been queued for execution via a Batch. // QueuedQuery is a query that has been queued for execution via a Batch.
type QueuedQuery struct { type QueuedQuery struct {
query string SQL string
arguments []any Arguments []any
fn batchItemFunc fn batchItemFunc
sd *pgconn.StatementDescription sd *pgconn.StatementDescription
} }
@ -57,7 +57,7 @@ func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) {
// Batch queries are a way of bundling multiple queries together to avoid // Batch queries are a way of bundling multiple queries together to avoid
// unnecessary network round trips. A Batch must only be sent once. // unnecessary network round trips. A Batch must only be sent once.
type Batch struct { type Batch struct {
queuedQueries []*QueuedQuery QueuedQueries []*QueuedQuery
} }
// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. // Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement.
@ -65,16 +65,16 @@ type Batch struct {
// connection's DefaultQueryExecMode. // connection's DefaultQueryExecMode.
func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery { func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery {
qq := &QueuedQuery{ qq := &QueuedQuery{
query: query, SQL: query,
arguments: arguments, Arguments: arguments,
} }
b.queuedQueries = append(b.queuedQueries, qq) b.QueuedQueries = append(b.QueuedQueries, qq)
return qq return qq
} }
// Len returns number of queries that have been queued so far. // Len returns number of queries that have been queued so far.
func (b *Batch) Len() int { func (b *Batch) Len() int {
return len(b.queuedQueries) return len(b.QueuedQueries)
} }
type BatchResults interface { type BatchResults interface {
@ -227,9 +227,9 @@ func (br *batchResults) Close() error {
} }
// Read and run fn for all remaining items // Read and run fn for all remaining items
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) { for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
if br.b.queuedQueries[br.qqIdx].fn != nil { if br.b.QueuedQueries[br.qqIdx].fn != nil {
err := br.b.queuedQueries[br.qqIdx].fn(br) err := br.b.QueuedQueries[br.qqIdx].fn(br)
if err != nil { if err != nil {
br.err = err br.err = err
} }
@ -253,10 +253,10 @@ func (br *batchResults) earlyError() error {
} }
func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) { func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
if br.b != nil && br.qqIdx < len(br.b.queuedQueries) { if br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
bi := br.b.queuedQueries[br.qqIdx] bi := br.b.QueuedQueries[br.qqIdx]
query = bi.query query = bi.SQL
args = bi.arguments args = bi.Arguments
ok = true ok = true
br.qqIdx++ br.qqIdx++
} }
@ -396,9 +396,9 @@ func (br *pipelineBatchResults) Close() error {
} }
// Read and run fn for all remaining items // Read and run fn for all remaining items
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) { for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
if br.b.queuedQueries[br.qqIdx].fn != nil { if br.b.QueuedQueries[br.qqIdx].fn != nil {
err := br.b.queuedQueries[br.qqIdx].fn(br) err := br.b.QueuedQueries[br.qqIdx].fn(br)
if err != nil { if err != nil {
br.err = err br.err = err
} }
@ -422,10 +422,10 @@ func (br *pipelineBatchResults) earlyError() error {
} }
func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) { func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
if br.b != nil && br.qqIdx < len(br.b.queuedQueries) { if br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
bi := br.b.queuedQueries[br.qqIdx] bi := br.b.QueuedQueries[br.qqIdx]
query = bi.query query = bi.SQL
args = bi.arguments args = bi.Arguments
ok = true ok = true
br.qqIdx++ br.qqIdx++
} }

56
conn.go
View File

@ -903,10 +903,10 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
return &batchResults{ctx: ctx, conn: c, err: err} return &batchResults{ctx: ctx, conn: c, err: err}
} }
for _, bi := range b.queuedQueries { for _, bi := range b.QueuedQueries {
var queryRewriter QueryRewriter var queryRewriter QueryRewriter
sql := bi.query sql := bi.SQL
arguments := bi.arguments arguments := bi.Arguments
optionLoop: optionLoop:
for len(arguments) > 0 { for len(arguments) > 0 {
@ -928,8 +928,8 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
} }
} }
bi.query = sql bi.SQL = sql
bi.arguments = arguments bi.Arguments = arguments
} }
// TODO: changing mode per batch? Update Batch.Queue function comment when implemented // TODO: changing mode per batch? Update Batch.Queue function comment when implemented
@ -939,8 +939,8 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
} }
// All other modes use extended protocol and thus can use prepared statements. // All other modes use extended protocol and thus can use prepared statements.
for _, bi := range b.queuedQueries { for _, bi := range b.QueuedQueries {
if sd, ok := c.preparedStatements[bi.query]; ok { if sd, ok := c.preparedStatements[bi.SQL]; ok {
bi.sd = sd bi.sd = sd
} }
} }
@ -961,11 +961,11 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults { func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults {
var sb strings.Builder var sb strings.Builder
for i, bi := range b.queuedQueries { for i, bi := range b.QueuedQueries {
if i > 0 { if i > 0 {
sb.WriteByte(';') sb.WriteByte(';')
} }
sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...) sql, err := c.sanitizeForSimpleQuery(bi.SQL, bi.Arguments...)
if err != nil { if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err} return &batchResults{ctx: ctx, conn: c, err: err}
} }
@ -984,21 +984,21 @@ func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batc
func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults { func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults {
batch := &pgconn.Batch{} batch := &pgconn.Batch{}
for _, bi := range b.queuedQueries { for _, bi := range b.QueuedQueries {
sd := bi.sd sd := bi.sd
if sd != nil { if sd != nil {
err := c.eqb.Build(c.typeMap, sd, bi.arguments) err := c.eqb.Build(c.typeMap, sd, bi.Arguments)
if err != nil { if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err} return &batchResults{ctx: ctx, conn: c, err: err}
} }
batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats) batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
} else { } else {
err := c.eqb.Build(c.typeMap, nil, bi.arguments) err := c.eqb.Build(c.typeMap, nil, bi.Arguments)
if err != nil { if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err} return &batchResults{ctx: ctx, conn: c, err: err}
} }
batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats) batch.ExecParams(bi.SQL, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
} }
} }
@ -1023,18 +1023,18 @@ func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batc
distinctNewQueries := []*pgconn.StatementDescription{} distinctNewQueries := []*pgconn.StatementDescription{}
distinctNewQueriesIdxMap := make(map[string]int) distinctNewQueriesIdxMap := make(map[string]int)
for _, bi := range b.queuedQueries { for _, bi := range b.QueuedQueries {
if bi.sd == nil { if bi.sd == nil {
sd := c.statementCache.Get(bi.query) sd := c.statementCache.Get(bi.SQL)
if sd != nil { if sd != nil {
bi.sd = sd bi.sd = sd
} else { } else {
if idx, present := distinctNewQueriesIdxMap[bi.query]; present { if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
bi.sd = distinctNewQueries[idx] bi.sd = distinctNewQueries[idx]
} else { } else {
sd = &pgconn.StatementDescription{ sd = &pgconn.StatementDescription{
Name: stmtcache.StatementName(bi.query), Name: stmtcache.StatementName(bi.SQL),
SQL: bi.query, SQL: bi.SQL,
} }
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries) distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
distinctNewQueries = append(distinctNewQueries, sd) distinctNewQueries = append(distinctNewQueries, sd)
@ -1055,17 +1055,17 @@ func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch
distinctNewQueries := []*pgconn.StatementDescription{} distinctNewQueries := []*pgconn.StatementDescription{}
distinctNewQueriesIdxMap := make(map[string]int) distinctNewQueriesIdxMap := make(map[string]int)
for _, bi := range b.queuedQueries { for _, bi := range b.QueuedQueries {
if bi.sd == nil { if bi.sd == nil {
sd := c.descriptionCache.Get(bi.query) sd := c.descriptionCache.Get(bi.SQL)
if sd != nil { if sd != nil {
bi.sd = sd bi.sd = sd
} else { } else {
if idx, present := distinctNewQueriesIdxMap[bi.query]; present { if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
bi.sd = distinctNewQueries[idx] bi.sd = distinctNewQueries[idx]
} else { } else {
sd = &pgconn.StatementDescription{ sd = &pgconn.StatementDescription{
SQL: bi.query, SQL: bi.SQL,
} }
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries) distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
distinctNewQueries = append(distinctNewQueries, sd) distinctNewQueries = append(distinctNewQueries, sd)
@ -1082,13 +1082,13 @@ func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch)
distinctNewQueries := []*pgconn.StatementDescription{} distinctNewQueries := []*pgconn.StatementDescription{}
distinctNewQueriesIdxMap := make(map[string]int) distinctNewQueriesIdxMap := make(map[string]int)
for _, bi := range b.queuedQueries { for _, bi := range b.QueuedQueries {
if bi.sd == nil { if bi.sd == nil {
if idx, present := distinctNewQueriesIdxMap[bi.query]; present { if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
bi.sd = distinctNewQueries[idx] bi.sd = distinctNewQueries[idx]
} else { } else {
sd := &pgconn.StatementDescription{ sd := &pgconn.StatementDescription{
SQL: bi.query, SQL: bi.SQL,
} }
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries) distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
distinctNewQueries = append(distinctNewQueries, sd) distinctNewQueries = append(distinctNewQueries, sd)
@ -1154,11 +1154,11 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
} }
// Queue the queries. // Queue the queries.
for _, bi := range b.queuedQueries { for _, bi := range b.QueuedQueries {
err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments) err := c.eqb.Build(c.typeMap, bi.sd, bi.Arguments)
if err != nil { if err != nil {
// we wrap the error so we the user can understand which query failed inside the batch // we wrap the error so we the user can understand which query failed inside the batch
err = fmt.Errorf("error building query %s: %w", bi.query, err) err = fmt.Errorf("error building query %s: %w", bi.SQL, err)
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
} }