From c03ac1519e710a7c8b88d2f4f74e3883d2116b54 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 8 May 2020 00:06:42 -0500 Subject: [PATCH] Improve stdlib performance with large result sets MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In a hot path type assertions are expensive. Cache the already correctly typed interfaces. ~20% improvement with 1000 rows. Before: jack@glados ~/dev/pgx/stdlib ±master » PGX_BENCH_SELECT_ROWS_COUNTS='1 10 1000' got ./... -bench=SelectRows -benchmem goos: darwin goarch: amd64 pkg: github.com/jackc/pgx/v4/stdlib BenchmarkSelectRowsScanSimple/1_rows-16 21465 55060 ns/op 1679 B/op 40 allocs/op BenchmarkSelectRowsScanSimple/10_rows-16 16692 71176 ns/op 3827 B/op 148 allocs/op BenchmarkSelectRowsScanSimple/1000_rows-16 800 1369547 ns/op 248855 B/op 12938 allocs/op BenchmarkSelectRowsScanNull/1_rows-16 20306 57883 ns/op 1940 B/op 54 allocs/op BenchmarkSelectRowsScanNull/10_rows-16 15942 74729 ns/op 4294 B/op 171 allocs/op BenchmarkSelectRowsScanNull/1000_rows-16 829 1326788 ns/op 261291 B/op 13051 allocs/op PASS ok github.com/jackc/pgx/v4/stdlib 10.429s After: jack@glados ~/dev/pgx/stdlib ±master⚡ » PGX_BENCH_SELECT_ROWS_COUNTS='1 10 1000' got ./... -bench=SelectRows -benchmem goos: darwin goarch: amd64 pkg: github.com/jackc/pgx/v4/stdlib BenchmarkSelectRowsScanSimple/1_rows-16 21327 55097 ns/op 2127 B/op 43 allocs/op BenchmarkSelectRowsScanSimple/10_rows-16 16724 69496 ns/op 4276 B/op 151 allocs/op BenchmarkSelectRowsScanSimple/1000_rows-16 1009 1124573 ns/op 250037 B/op 12941 allocs/op BenchmarkSelectRowsScanNull/1_rows-16 20577 58117 ns/op 2396 B/op 57 allocs/op BenchmarkSelectRowsScanNull/10_rows-16 16402 72533 ns/op 4750 B/op 174 allocs/op BenchmarkSelectRowsScanNull/1000_rows-16 1010 1161437 ns/op 261735 B/op 13054 allocs/op PASS ok github.com/jackc/pgx/v4/stdlib 10.363s --- stdlib/sql.go | 142 +++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 112 insertions(+), 30 deletions(-) diff --git a/stdlib/sql.go b/stdlib/sql.go index 8ff3fd49..db8eefc0 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -349,11 +349,14 @@ func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (dri } type Rows struct { - conn *Conn - rows pgx.Rows - values []interface{} - skipNext bool - skipNextMore bool + conn *Conn + rows pgx.Rows + values []interface{} + driverValuers []driver.Valuer + textDecoders []pgtype.TextDecoder + binaryDecoders []pgtype.BinaryDecoder + skipNext bool + skipNextMore bool } func (r *Rows) Columns() []string { @@ -444,42 +447,112 @@ func (r *Rows) Close() error { } func (r *Rows) Next(dest []driver.Value) error { + ci := r.conn.conn.ConnInfo() + fieldDescriptions := r.rows.FieldDescriptions() + if r.values == nil { - r.values = make([]interface{}, len(r.rows.FieldDescriptions())) - for i, fd := range r.rows.FieldDescriptions() { + r.values = make([]interface{}, len(fieldDescriptions)) + r.driverValuers = make([]driver.Valuer, len(fieldDescriptions)) + r.textDecoders = make([]pgtype.TextDecoder, len(fieldDescriptions)) + r.binaryDecoders = make([]pgtype.BinaryDecoder, len(fieldDescriptions)) + + for i, fd := range fieldDescriptions { switch fd.DataTypeOID { case pgtype.BoolOID: - r.values[i] = &pgtype.Bool{} + v := &pgtype.Bool{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v case pgtype.ByteaOID: - r.values[i] = &pgtype.Bytea{} + v := &pgtype.Bytea{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v case pgtype.CIDOID: - r.values[i] = &pgtype.CID{} + v := &pgtype.CID{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v case pgtype.DateOID: - r.values[i] = &pgtype.Date{} + v := &pgtype.Date{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v case pgtype.Float4OID: - r.values[i] = &pgtype.Float4{} + v := &pgtype.Float4{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v case pgtype.Float8OID: - r.values[i] = &pgtype.Float8{} + v := &pgtype.Float8{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v case pgtype.Int2OID: - r.values[i] = &pgtype.Int2{} + v := &pgtype.Int2{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v case pgtype.Int4OID: - r.values[i] = &pgtype.Int4{} + v := &pgtype.Int4{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v case pgtype.Int8OID: - r.values[i] = &pgtype.Int8{} + v := &pgtype.Int8{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v case pgtype.JSONOID: - r.values[i] = &pgtype.JSON{} + v := &pgtype.JSON{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v case pgtype.JSONBOID: - r.values[i] = &pgtype.JSONB{} + v := &pgtype.JSONB{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v case pgtype.OIDOID: - r.values[i] = &pgtype.OIDValue{} + v := &pgtype.OIDValue{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v case pgtype.TimestampOID: - r.values[i] = &pgtype.Timestamp{} + v := &pgtype.Timestamp{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v case pgtype.TimestamptzOID: - r.values[i] = &pgtype.Timestamptz{} + v := &pgtype.Timestamptz{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v case pgtype.XIDOID: - r.values[i] = &pgtype.XID{} + v := &pgtype.XID{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v default: - r.values[i] = &pgtype.GenericText{} + v := &pgtype.GenericText{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v } } } @@ -500,15 +573,24 @@ func (r *Rows) Next(dest []driver.Value) error { } } - err := r.rows.Scan(r.values...) - if err != nil { - return err - } + for i, rv := range r.rows.RawValues() { + fd := fieldDescriptions[i] + if fd.Format == pgx.BinaryFormatCode { + err := r.binaryDecoders[i].DecodeBinary(ci, rv) + if err != nil { + return fmt.Errorf("scan field %d failed: %v", i, err) + } + } else { + err := r.textDecoders[i].DecodeText(ci, rv) + if err != nil { + return fmt.Errorf("scan field %d failed: %v", i, err) + } + } - for i, v := range r.values { - dest[i], err = v.(driver.Valuer).Value() + var err error + dest[i], err = r.driverValuers[i].Value() if err != nil { - return err + return fmt.Errorf("convert field %d failed: %v", i, err) } }