diff --git a/copy_from.go b/copy_from.go index afa80a1d..c4540f3e 100644 --- a/copy_from.go +++ b/copy_from.go @@ -35,6 +35,31 @@ func (ctr *copyFromRows) Err() error { return nil } +// CopyFromSlice returns a CopyFromSource interface over a dynamic func +// making it usable by *Conn.CopyFrom. +func CopyFromSlice(length int, next func(int) ([]interface{}, error)) CopyFromSource { + return ©FromSlice{next: next, idx: -1, len: length} +} + +type copyFromSlice struct { + next func(int) ([]interface{}, error) + idx int + len int +} + +func (cts *copyFromSlice) Next() bool { + cts.idx++ + return cts.idx < cts.len +} + +func (cts *copyFromSlice) Values() ([]interface{}, error) { + return cts.next(cts.idx) +} + +func (cts *copyFromSlice) Err() error { + return nil +} + // CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data. type CopyFromSource interface { // Next returns true if there is another row and makes the next row data diff --git a/copy_from_test.go b/copy_from_test.go index 9eaca011..35328205 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -69,6 +69,65 @@ func TestConnCopyFromSmall(t *testing.T) { ensureConnValid(t, conn) } +func TestConnCopyFromSliceSmall(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g timestamptz + )`) + + tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) + + inputRows := [][]interface{}{ + {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime}, + {nil, nil, nil, nil, nil, nil, nil}, + } + + copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, + pgx.CopyFromSlice(len(inputRows), func(i int) ([]interface{}, error) { + return inputRows[i], nil + })) + if err != nil { + t.Errorf("Unexpected error for CopyFrom: %v", err) + } + if int(copyCount) != len(inputRows) { + t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) + } + + rows, err := conn.Query(context.Background(), "select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) + } + + ensureConnValid(t, conn) +} + func TestConnCopyFromLarge(t *testing.T) { t.Parallel() diff --git a/doc.go b/doc.go index 2a4ede48..2c8f3b0d 100644 --- a/doc.go +++ b/doc.go @@ -260,6 +260,22 @@ interface. Or implement CopyFromSource to avoid buffering the entire data set in pgx.CopyFromRows(rows), ) +When you already have a typed array using CopyFromSlice can be more convenient. + + rows := []User{ + {"John", "Smith", 36}, + {"Jane", "Doe", 29}, + } + + copyCount, err := conn.CopyFrom( + context.Background(), + pgx.Identifier{"people"}, + []string{"first_name", "last_name", "age"}, + pgx.CopyFromSlice(len(rows), func(i int) ([]interface{}, error) { + return []interface{user.FirstName, user.LastName, user.Age}, nil + }), + ) + CopyFrom can be faster than an insert with as few as 5 rows. Listen and Notify