diff --git a/copy_from.go b/copy_from.go index a2c227fd..b15a0ae1 100644 --- a/copy_from.go +++ b/copy_from.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "io" + "reflect" "github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/pgconn" @@ -64,6 +65,51 @@ func (cts *copyFromSlice) Err() error { return cts.err } +// CopyFromCh returns a CopyFromSource interface over the provided channel. +// FieldNames is an ordered list of field names to copy from the struct, which +// order must match the order of the columns. +func CopyFromCh[T any](ch chan T, fieldNames []string) CopyFromSource { + return ©FromCh[T]{c: ch, fieldNames: fieldNames} +} + +type copyFromCh[T any] struct { + c chan T + fieldNames []string + valueRow []interface{} + err error +} + +func (g *copyFromCh[T]) Next() bool { + g.valueRow = g.valueRow[:0] // Clear buffer + val, ok := <-g.c + if !ok { + return false + } + // Handle both pointer to struct and struct + s := reflect.ValueOf(val) + if s.Kind() == reflect.Ptr { + s = s.Elem() + } + + for i := 0; i < len(g.fieldNames); i++ { + f := s.FieldByName(g.fieldNames[i]) + if !f.IsValid() { + g.err = fmt.Errorf("'%v' field not found in %#v", g.fieldNames[i], s.Interface()) + return false + } + g.valueRow = append(g.valueRow, f.Interface()) + } + return true +} + +func (g *copyFromCh[T]) Values() ([]interface{}, error) { + return g.valueRow, nil +} + +func (g *copyFromCh[T]) Err() error { + return g.err +} + // CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data. type CopyFromSource interface { // Next returns true if there is another row and makes the next row data