From c6cec81e2cf5966c1c6acdeba1d59acea9b71225 Mon Sep 17 00:00:00 2001 From: Fredrik Petrini Date: Mon, 8 Oct 2018 11:39:18 +0200 Subject: [PATCH] Fix: Handle (n > 0 and err == io.EOF) in CopyFromReader as per io.Reader documentation --- copy_from.go | 2 +- copy_from_test.go | 93 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 1 deletion(-) diff --git a/copy_from.go b/copy_from.go index 314d441f..27e2fc9a 100644 --- a/copy_from.go +++ b/copy_from.go @@ -298,7 +298,7 @@ func (c *Conn) CopyFromReader(r io.Reader, sql string) error { sp := len(buf) for { n, err := r.Read(buf[5:cap(buf)]) - if err == io.EOF { + if err == io.EOF && n == 0 { break } buf = buf[0 : n+5] diff --git a/copy_from_test.go b/copy_from_test.go index 0ed88b72..4c239b05 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -1,7 +1,12 @@ package pgx_test import ( + "compress/gzip" + "fmt" + "io/ioutil" + "os" "reflect" + "strconv" "strings" "testing" "time" @@ -639,3 +644,91 @@ func TestConnCopyFromReaderNoTableError(t *testing.T) { ensureConnValid(t, conn) } + +func TestConnCopyFromGzipReader(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int4, + b varchar + )`) + + f, err := ioutil.TempFile("", "*") + if err != nil { + t.Fatalf("Unexpected error for ioutil.TempFile: %v", err) + } + + gw := gzip.NewWriter(f) + + inputRows := [][]interface{}{} + for i := 0; i < 1000; i++ { + val := strconv.Itoa(i * i) + inputRows = append(inputRows, []interface{}{int32(i), val}) + _, err = gw.Write([]byte(fmt.Sprintf("%d,\"%s\"\n", i, val))) + if err != nil { + t.Errorf("Unexpected error for gw.Write: %v", err) + } + } + + err = gw.Close() + if err != nil { + t.Fatalf("Unexpected error for gw.Close: %v", err) + } + + _, err = f.Seek(0, 0) + if err != nil { + t.Fatalf("Unexpected error for f.Seek: %v", err) + } + + gr, err := gzip.NewReader(f) + if err != nil { + t.Fatalf("Unexpected error for gzip.NewReader: %v", err) + } + + err = conn.CopyFromReader(gr, "COPY foo FROM STDIN WITH (FORMAT csv)") + if err != nil { + t.Errorf("Unexpected error for CopyFromReader: %v", err) + } + + err = gr.Close() + if err != nil { + t.Errorf("Unexpected error for gr.Close: %v", err) + } + + err = f.Close() + if err != nil { + t.Errorf("Unexpected error for f.Close: %v", err) + } + + err = os.Remove(f.Name()) + if err != nil { + t.Errorf("Unexpected error for os.Remove: %v", err) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal") + } + + ensureConnValid(t, conn) +}