Allowed nxtf to signal end of data by returning nil,nil

Added some test
Improved documentation
pull/1786/head^2
robford 2023-11-07 09:19:16 +01:00 committed by Jack Christensen
parent 9b6d3809d6
commit d38dd85756
2 changed files with 24 additions and 12 deletions

View File

@ -64,10 +64,10 @@ func (cts *copyFromSlice) Err() error {
return cts.err return cts.err
} }
// CopyFromCh returns a CopyFromSource interface over the provided channel. // CopyFromFunc returns a CopyFromSource interface that relies on nxtf for values.
// FieldNames is an ordered list of field names to copy from the struct, which // nxtf returns rows until it either signals an 'end of data' by returning row=nil and err=nil,
// order must match the order of the columns. // or it returns an error. If nxtf returns an error, the copy is aborted.
func CopyFromFunc(nxtf func() ([]any, error)) CopyFromSource { func CopyFromFunc(nxtf func() (row []any, err error)) CopyFromSource {
return &copyFromFunc{next: nxtf} return &copyFromFunc{next: nxtf}
} }
@ -79,11 +79,12 @@ type copyFromFunc struct {
func (g *copyFromFunc) Next() bool { func (g *copyFromFunc) Next() bool {
g.valueRow, g.err = g.next() g.valueRow, g.err = g.next()
return g.err == nil // only return true if valueRow exists and no error
return g.valueRow != nil && g.err == nil
} }
func (g *copyFromFunc) Values() ([]any, error) { func (g *copyFromFunc) Values() ([]any, error) {
return g.valueRow, nil return g.valueRow, g.err
} }
func (g *copyFromFunc) Err() error { func (g *copyFromFunc) Err() error {

View File

@ -2,7 +2,6 @@ package pgx_test
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"os" "os"
"reflect" "reflect"
@ -815,7 +814,6 @@ func TestCopyFromFunc(t *testing.T) {
)`) )`)
dataCh := make(chan int, 1) dataCh := make(chan int, 1)
closeChanErr := errors.New("closed channel")
const channelItems = 10 const channelItems = 10
go func() { go func() {
@ -829,14 +827,12 @@ func TestCopyFromFunc(t *testing.T) {
pgx.CopyFromFunc(func() ([]any, error) { pgx.CopyFromFunc(func() ([]any, error) {
v, ok := <-dataCh v, ok := <-dataCh
if !ok { if !ok {
return nil, closeChanErr return nil, nil
} }
return []any{v}, nil return []any{v}, nil
})) }))
fmt.Print(copyCount, err, "\n") require.ErrorIs(t, err, nil)
require.ErrorIs(t, err, closeChanErr)
require.EqualValues(t, channelItems, copyCount) require.EqualValues(t, channelItems, copyCount)
rows, err := conn.Query(context.Background(), "select * from foo order by a") rows, err := conn.Query(context.Background(), "select * from foo order by a")
@ -845,5 +841,20 @@ func TestCopyFromFunc(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, nums) require.Equal(t, []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, nums)
// simulate a failure
copyCount, err = conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"},
pgx.CopyFromFunc(func() func() ([]any, error) {
x := 9
return func() ([]any, error) {
x++
if x > 100 {
return nil, fmt.Errorf("simulated error")
}
return []any{x}, nil
}
}()))
require.NotErrorIs(t, err, nil)
require.EqualValues(t, 0, copyCount) // no change, due to error
ensureConnValid(t, conn) ensureConnValid(t, conn)
} }