mirror of
https://github.com/jackc/pgx.git
synced 2025-04-27 13:14:32 +00:00
In case of an error it was possible for the goroutine that builds the copy stream to still be running after CopyFrom returned. Since that goroutine uses the connection's ConnInfo data types to encode the copy data, it was possible for those types to be used concurrently in an unsafe fashion. CopyFrom will no longer return until that goroutine has completed.
170 lines
4.0 KiB
Go
170 lines
4.0 KiB
Go
package pgx
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
|
|
"github.com/jackc/pgconn"
|
|
"github.com/jackc/pgio"
|
|
errors "golang.org/x/xerrors"
|
|
)
|
|
|
|
// CopyFromRows returns a CopyFromSource interface over the provided rows slice
|
|
// making it usable by *Conn.CopyFrom.
|
|
func CopyFromRows(rows [][]interface{}) CopyFromSource {
|
|
return ©FromRows{rows: rows, idx: -1}
|
|
}
|
|
|
|
// copyFromRows is a CopyFromSource backed by an in-memory slice of rows.
type copyFromRows struct {
	rows [][]interface{} // all rows to be copied, in order
	idx  int             // index into rows of the current row
}

// Next advances to the following row and reports whether such a row exists.
func (src *copyFromRows) Next() bool {
	src.idx++
	remaining := len(src.rows) - src.idx
	return remaining > 0
}

// Values returns the row at the current position. The error is always nil.
// It indexes rows directly, so it must only be called after Next has
// returned true.
func (src *copyFromRows) Values() ([]interface{}, error) {
	current := src.rows[src.idx]
	return current, nil
}

// Err always returns nil: iterating an in-memory slice cannot fail.
func (src *copyFromRows) Err() error {
	return nil
}
|
|
|
|
// CopyFromSource is the interface used by *Conn.CopyFrom as the source for
// copy data. It is consumed as an iterator: Next is called to advance, Values
// to read the current row, and Err to learn whether iteration failed.
type CopyFromSource interface {
	// Next advances to the next row and makes its data available to
	// Values(). It returns true if a row is available, and false once the
	// rows are exhausted or an error has occurred.
	Next() bool

	// Values returns the values for the current row.
	Values() ([]interface{}, error)

	// Err returns any error that has been encountered by the
	// CopyFromSource. A non-nil result causes *Conn.CopyFrom to abort the
	// copy.
	Err() error
}
|
|
|
|
type copyFrom struct {
|
|
conn *Conn
|
|
tableName Identifier
|
|
columnNames []string
|
|
rowSrc CopyFromSource
|
|
readerErrChan chan error
|
|
}
|
|
|
|
func (ct *copyFrom) run(ctx context.Context) (int64, error) {
|
|
quotedTableName := ct.tableName.Sanitize()
|
|
cbuf := &bytes.Buffer{}
|
|
for i, cn := range ct.columnNames {
|
|
if i != 0 {
|
|
cbuf.WriteString(", ")
|
|
}
|
|
cbuf.WriteString(quoteIdentifier(cn))
|
|
}
|
|
quotedColumnNames := cbuf.String()
|
|
|
|
sd, err := ct.conn.Prepare(ctx, "", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
r, w := io.Pipe()
|
|
doneChan := make(chan struct{})
|
|
|
|
go func() {
|
|
defer close(doneChan)
|
|
|
|
// Purposely NOT using defer w.Close(). See https://github.com/golang/go/issues/24283.
|
|
buf := ct.conn.wbuf
|
|
|
|
buf = append(buf, "PGCOPY\n\377\r\n\000"...)
|
|
buf = pgio.AppendInt32(buf, 0)
|
|
buf = pgio.AppendInt32(buf, 0)
|
|
|
|
moreRows := true
|
|
for moreRows {
|
|
var err error
|
|
moreRows, buf, err = ct.buildCopyBuf(buf, sd)
|
|
if err != nil {
|
|
w.CloseWithError(err)
|
|
return
|
|
}
|
|
|
|
if ct.rowSrc.Err() != nil {
|
|
w.CloseWithError(ct.rowSrc.Err())
|
|
return
|
|
}
|
|
|
|
if len(buf) > 0 {
|
|
_, err = w.Write(buf)
|
|
if err != nil {
|
|
w.Close()
|
|
return
|
|
}
|
|
}
|
|
|
|
buf = buf[:0]
|
|
}
|
|
|
|
w.Close()
|
|
}()
|
|
|
|
commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
|
|
|
|
r.Close()
|
|
<-doneChan
|
|
|
|
return commandTag.RowsAffected(), err
|
|
}
|
|
|
|
func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) {
|
|
|
|
for ct.rowSrc.Next() {
|
|
values, err := ct.rowSrc.Values()
|
|
if err != nil {
|
|
return false, nil, err
|
|
}
|
|
if len(values) != len(ct.columnNames) {
|
|
return false, nil, errors.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
|
|
}
|
|
|
|
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
|
|
for i, val := range values {
|
|
buf, err = encodePreparedStatementArgument(ct.conn.connInfo, buf, sd.Fields[i].DataTypeOID, val)
|
|
if err != nil {
|
|
return false, nil, err
|
|
}
|
|
}
|
|
|
|
if len(buf) > 65536 {
|
|
return true, buf, nil
|
|
}
|
|
}
|
|
|
|
return false, buf, nil
|
|
}
|
|
|
|
// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion.
|
|
// It returns the number of rows copied and an error.
|
|
//
|
|
// CopyFrom requires all values use the binary format. Almost all types
|
|
// implemented by pgx use the binary format by default. Types implementing
|
|
// Encoder can only be used if they encode to the binary format.
|
|
func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) {
|
|
ct := ©From{
|
|
conn: c,
|
|
tableName: tableName,
|
|
columnNames: columnNames,
|
|
rowSrc: rowSrc,
|
|
readerErrChan: make(chan error),
|
|
}
|
|
|
|
return ct.run(ctx)
|
|
}
|