diff --git a/copy_from.go b/copy_from.go index d26a69dd..abcd2239 100644 --- a/copy_from.go +++ b/copy_from.go @@ -64,10 +64,10 @@ func (cts *copyFromSlice) Err() error { return cts.err } -// CopyFromCh returns a CopyFromSource interface over the provided channel. -// FieldNames is an ordered list of field names to copy from the struct, which -// order must match the order of the columns. -func CopyFromFunc(nxtf func() ([]any, error)) CopyFromSource { +// CopyFromFunc returns a CopyFromSource interface that relies on nxtf for values. +// nxtf returns rows until it either signals an 'end of data' by returning row=nil and err=nil, +// or it returns an error. If nxtf returns an error, the copy is aborted. +func CopyFromFunc(nxtf func() (row []any, err error)) CopyFromSource { return ©FromFunc{next: nxtf} } @@ -79,11 +79,12 @@ type copyFromFunc struct { func (g *copyFromFunc) Next() bool { g.valueRow, g.err = g.next() - return g.err == nil + // only return true if valueRow exists and no error + return g.valueRow != nil && g.err == nil } func (g *copyFromFunc) Values() ([]any, error) { - return g.valueRow, nil + return g.valueRow, g.err } func (g *copyFromFunc) Err() error { diff --git a/copy_from_test.go b/copy_from_test.go index faed1d46..9da23c04 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -2,7 +2,6 @@ package pgx_test import ( "context" - "errors" "fmt" "os" "reflect" @@ -815,7 +814,6 @@ func TestCopyFromFunc(t *testing.T) { )`) dataCh := make(chan int, 1) - closeChanErr := errors.New("closed channel") const channelItems = 10 go func() { @@ -829,14 +827,12 @@ func TestCopyFromFunc(t *testing.T) { pgx.CopyFromFunc(func() ([]any, error) { v, ok := <-dataCh if !ok { - return nil, closeChanErr + return nil, nil } return []any{v}, nil })) - fmt.Print(copyCount, err, "\n") - - require.ErrorIs(t, err, closeChanErr) + require.ErrorIs(t, err, nil) require.EqualValues(t, channelItems, copyCount) rows, err := conn.Query(context.Background(), "select * from foo order by a") @@ -845,5 +841,20 @@ func TestCopyFromFunc(t *testing.T) { require.NoError(t, err) require.Equal(t, []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, nums) + // simulate a failure + copyCount, err = conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, + pgx.CopyFromFunc(func() func() ([]any, error) { + x := 9 + return func() ([]any, error) { + x++ + if x > 100 { + return nil, fmt.Errorf("simulated error") + } + return []any{x}, nil + } + }())) + require.NotErrorIs(t, err, nil) + require.EqualValues(t, 0, copyCount) // no change, due to error + ensureConnValid(t, conn) }