diff --git a/conn_pool.go b/conn_pool.go index b97ccb28..27ca3531 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -116,10 +116,14 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) { } // A connection is available - if len(p.availableConnections) > 0 { - c := p.availableConnections[len(p.availableConnections)-1] + // The pool works like a queue. Available connection will be returned + // from the head. A new connection will be added to the tail. + numAvailable := len(p.availableConnections) + if numAvailable > 0 { + c := p.availableConnections[0] c.poolResetCount = p.resetCount - p.availableConnections = p.availableConnections[:len(p.availableConnections)-1] + copy(p.availableConnections, p.availableConnections[1:]) + p.availableConnections = p.availableConnections[:numAvailable-1] return c, nil } diff --git a/conn_test.go b/conn_test.go index f208be8e..763efd11 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1553,7 +1553,8 @@ func TestListenNotify(t *testing.T) { } // when timeout occurs - ctx, _ = context.WithTimeout(context.Background(), time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() notification, err = listener.WaitForNotification(ctx) if err != context.DeadlineExceeded { t.Errorf("WaitForNotification returned the wrong kind of error: %v", err) @@ -1610,7 +1611,8 @@ func TestUnlistenSpecificChannel(t *testing.T) { t.Fatalf("Unexpected error on Query: %v", rows.Err()) } - ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() notification, err = listener.WaitForNotification(ctx) if err != context.DeadlineExceeded { t.Errorf("WaitForNotification returned the wrong kind of error: %v", err) @@ -1690,7 +1692,8 @@ func TestListenNotifySelfNotification(t *testing.T) { // Notify self and WaitForNotification immediately mustExec(t, conn, "notify self") - ctx, _ := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() notification, err := conn.WaitForNotification(ctx) if err != nil { t.Fatalf("Unexpected error on WaitForNotification: %v", err) @@ -1708,7 +1711,8 @@ func TestListenNotifySelfNotification(t *testing.T) { t.Fatalf("Unexpected error on Query: %v", rows.Err()) } - ctx, _ = context.WithTimeout(context.Background(), time.Second) + ctx, cncl := context.WithTimeout(context.Background(), time.Second) + defer cncl() notification, err = conn.WaitForNotification(ctx) if err != nil { t.Fatalf("Unexpected error on WaitForNotification: %v", err) diff --git a/copy_from.go b/copy_from.go index 314d441f..27e2fc9a 100644 --- a/copy_from.go +++ b/copy_from.go @@ -298,7 +298,7 @@ func (c *Conn) CopyFromReader(r io.Reader, sql string) error { sp := len(buf) for { n, err := r.Read(buf[5:cap(buf)]) - if err == io.EOF { + if err == io.EOF && n == 0 { break } buf = buf[0 : n+5] diff --git a/copy_from_test.go b/copy_from_test.go index 0ed88b72..4c239b05 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -1,7 +1,12 @@ package pgx_test import ( + "compress/gzip" + "fmt" + "io/ioutil" + "os" "reflect" + "strconv" "strings" "testing" "time" @@ -639,3 +644,91 @@ func TestConnCopyFromReaderNoTableError(t *testing.T) { ensureConnValid(t, conn) } + +func TestConnCopyFromGzipReader(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int4, + b varchar + )`) + + f, err := ioutil.TempFile("", "*") + if err != nil { + t.Fatalf("Unexpected error for ioutil.TempFile: %v", err) + } + + gw := gzip.NewWriter(f) + + inputRows := [][]interface{}{} + for i := 0; i < 1000; i++ { + val := strconv.Itoa(i * i) + inputRows = append(inputRows, []interface{}{int32(i), val}) + _, err = gw.Write([]byte(fmt.Sprintf("%d,\"%s\"\n", i, val))) + if err != nil { + t.Errorf("Unexpected error for gw.Write: %v", err) + } + } + + err = gw.Close() + if err != nil { + t.Fatalf("Unexpected error for gw.Close: %v", err) + } + + _, err = f.Seek(0, 0) + if err != nil { + t.Fatalf("Unexpected error for f.Seek: %v", err) + } + + gr, err := gzip.NewReader(f) + if err != nil { + t.Fatalf("Unexpected error for gzip.NewReader: %v", err) + } + + err = conn.CopyFromReader(gr, "COPY foo FROM STDIN WITH (FORMAT csv)") + if err != nil { + t.Errorf("Unexpected error for CopyFromReader: %v", err) + } + + err = gr.Close() + if err != nil { + t.Errorf("Unexpected error for gr.Close: %v", err) + } + + err = f.Close() + if err != nil { + t.Errorf("Unexpected error for f.Close: %v", err) + } + + err = os.Remove(f.Name()) + if err != nil { + t.Errorf("Unexpected error for os.Remove: %v", err) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal") + } + + ensureConnValid(t, conn) +} diff --git a/doc.go b/doc.go index a1ddced0..a4ff00e2 100644 --- a/doc.go +++ b/doc.go @@ -121,7 +121,7 @@ database/sql. The second is to use a pointer to a pointer. var foo pgtype.Varchar var bar *string - err := conn.QueryRow("select foo, bar from widgets where id=$1", 42).Scan(&a, &b) + err := conn.QueryRow("select foo, bar from widgets where id=$1", 42).Scan(&foo, &bar) if err != nil { return err } diff --git a/messages.go b/messages.go index 01b799b2..aca6ae2e 100644 --- a/messages.go +++ b/messages.go @@ -24,7 +24,7 @@ type FieldDescription struct { DataType pgtype.OID DataTypeSize int16 DataTypeName string - Modifier uint32 + Modifier int32 FormatCode int16 } @@ -53,6 +53,10 @@ func (fd FieldDescription) PrecisionScale() (precision, scale int64, ok bool) { func (fd FieldDescription) Type() reflect.Type { switch fd.DataType { + case pgtype.Float8OID: + return reflect.TypeOf(float64(0)) + case pgtype.Float4OID: + return reflect.TypeOf(float32(0)) case pgtype.Int8OID: return reflect.TypeOf(int64(0)) case pgtype.Int4OID: diff --git a/pgmock/pgmock.go b/pgmock/pgmock.go index fec0d3f6..4d15f7b8 100644 --- a/pgmock/pgmock.go +++ b/pgmock/pgmock.go @@ -229,7 +229,7 @@ where ( TableAttributeNumber: 65534, DataTypeOID: 26, DataTypeSize: 4, - TypeModifier: 4294967295, + TypeModifier: -1, Format: 0, }, {Name: "typname", @@ -237,7 +237,7 @@ where ( TableAttributeNumber: 1, DataTypeOID: 19, DataTypeSize: 64, - TypeModifier: 4294967295, + TypeModifier: -1, Format: 0, }, }, @@ -455,7 +455,7 @@ where ( TableAttributeNumber: 65534, DataTypeOID: 26, DataTypeSize: 4, - TypeModifier: 4294967295, + TypeModifier: -1, Format: 0, }, {Name: "typname", @@ -463,7 +463,7 @@ where ( TableAttributeNumber: 1, DataTypeOID: 19, DataTypeSize: 64, - TypeModifier: 4294967295, + TypeModifier: -1, Format: 0, }, }, @@ -496,7 +496,7 @@ where ( TableAttributeNumber: 65534, DataTypeOID: 26, DataTypeSize: 4, - TypeModifier: 4294967295, + TypeModifier: -1, Format: 0, }, {Name: "typname", @@ -504,7 +504,7 @@ where ( TableAttributeNumber: 1, DataTypeOID: 19, DataTypeSize: 64, - TypeModifier: 4294967295, + TypeModifier: -1, Format: 0, }, {Name: "typbasetype", @@ -512,7 +512,7 @@ where ( TableAttributeNumber: 65534, DataTypeOID: 26, DataTypeSize: 4, - TypeModifier: 4294967295, + TypeModifier: -1, Format: 0, }, }, diff --git a/pgproto3/row_description.go b/pgproto3/row_description.go index d0df11b0..3c5a6faa 100644 --- a/pgproto3/row_description.go +++ b/pgproto3/row_description.go @@ -19,7 +19,7 @@ type FieldDescription struct { TableAttributeNumber uint16 DataTypeOID uint32 DataTypeSize int16 - TypeModifier uint32 + TypeModifier int32 Format int16 } @@ -57,7 +57,7 @@ func (dst *RowDescription) Decode(src []byte) error { fd.TableAttributeNumber = binary.BigEndian.Uint16(buf.Next(2)) fd.DataTypeOID = binary.BigEndian.Uint32(buf.Next(4)) fd.DataTypeSize = int16(binary.BigEndian.Uint16(buf.Next(2))) - fd.TypeModifier = binary.BigEndian.Uint32(buf.Next(4)) + fd.TypeModifier = int32(binary.BigEndian.Uint32(buf.Next(4))) fd.Format = int16(binary.BigEndian.Uint16(buf.Next(2))) dst.Fields[i] = fd @@ -80,7 +80,7 @@ func (src *RowDescription) Encode(dst []byte) []byte { dst = pgio.AppendUint16(dst, fd.TableAttributeNumber) dst = pgio.AppendUint32(dst, fd.DataTypeOID) dst = pgio.AppendInt16(dst, fd.DataTypeSize) - dst = pgio.AppendUint32(dst, fd.TypeModifier) + dst = pgio.AppendInt32(dst, fd.TypeModifier) dst = pgio.AppendInt16(dst, fd.Format) } diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index a3a96ffd..78f3e6d4 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -869,7 +869,8 @@ func TestConnPingContextCancel(t *testing.T) { } defer closeDB(t, db) - ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() err = db.PingContext(ctx) if err != context.DeadlineExceeded { @@ -923,7 +924,8 @@ func TestConnPrepareContextCancel(t *testing.T) { } defer closeDB(t, db) - ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() _, err = db.PrepareContext(ctx, "select now()") if err != context.DeadlineExceeded { @@ -974,7 +976,8 @@ func TestConnExecContextCancel(t *testing.T) { } defer closeDB(t, db) - ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() _, err = db.ExecContext(ctx, "create temporary table exec_context_test(id serial primary key)") if err != context.DeadlineExceeded { @@ -1027,7 +1030,7 @@ func TestConnQueryContextCancel(t *testing.T) { Name: "n", DataTypeOID: 23, DataTypeSize: 4, - TypeModifier: 4294967295, + TypeModifier: -1, }, }, }), @@ -1145,7 +1148,8 @@ func TestStmtExecContextCancel(t *testing.T) { } defer stmt.Close() - ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() _, err = stmt.ExecContext(ctx, 42) if err != context.DeadlineExceeded { @@ -1202,7 +1206,7 @@ func TestStmtQueryContextCancel(t *testing.T) { Name: "n", DataTypeOID: 23, DataTypeSize: 4, - TypeModifier: 4294967295, + TypeModifier: -1, }, }, }), diff --git a/stress_test.go b/stress_test.go index 114bec81..d6b89c51 100644 --- a/stress_test.go +++ b/stress_test.go @@ -213,7 +213,8 @@ func listenAndPoolUnlistens(pool *pgx.ConnPool, actionNum int) error { return err } - ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() _, err = conn.WaitForNotification(ctx) if err == context.DeadlineExceeded { return nil diff --git a/tx.go b/tx.go index eb6b6805..0fb428fb 100644 --- a/tx.go +++ b/tx.go @@ -147,7 +147,8 @@ func (tx *Tx) CommitEx(ctx context.Context) error { // defer tx.Rollback() is safe even if tx.Commit() will be called first in a // non-error condition. func (tx *Tx) Rollback() error { - ctx, _ := context.WithTimeout(context.Background(), 15*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() return tx.RollbackEx(ctx) } diff --git a/tx_test.go b/tx_test.go index f9a9d5c7..eff5604e 100644 --- a/tx_test.go +++ b/tx_test.go @@ -261,7 +261,8 @@ func TestConnBeginExContextCancel(t *testing.T) { conn := mustConnect(t, mockConfig) - ctx, _ := context.WithTimeout(context.Background(), 50*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() _, err = conn.BeginEx(ctx, nil) if err != context.DeadlineExceeded { @@ -315,7 +316,8 @@ func TestTxCommitExCancel(t *testing.T) { t.Fatal(err) } - ctx, _ := context.WithTimeout(context.Background(), 50*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() err = tx.CommitEx(ctx) if err != context.DeadlineExceeded { t.Errorf("err => %v, want %v", err, context.DeadlineExceeded)