From 94052ea9405d1b6b29ea3d32e4c4261f59ee757a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 24 Mar 2016 14:22:16 -0500 Subject: [PATCH] Rows.Scan can ignore column with nil fixes #130 --- query.go | 8 ++++++-- query_test.go | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/query.go b/query.go index a40d17b2..9ccf3e94 100644 --- a/query.go +++ b/query.go @@ -233,8 +233,8 @@ func (e scanArgError) Error() string { // Scan reads the values from the current row into dest values positionally. // dest can include pointers to core types, values implementing the Scanner -// interface, and []byte. []byte will skip the decoding process and directly -// copy the raw bytes received from PostgreSQL. +// interface, []byte, and nil. []byte will skip the decoding process and directly +// copy the raw bytes received from PostgreSQL. nil will skip the value entirely. func (rows *Rows) Scan(dest ...interface{}) (err error) { if len(rows.fields) != len(dest) { err = fmt.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.fields)) @@ -245,6 +245,10 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { for i, d := range dest { vr, _ := rows.nextColumn() + if d == nil { + continue + } + // Check for []byte first as we allow sidestepping the decoding process and retrieving the raw bytes if b, ok := d.(*[]byte); ok { // If it actually is a bytea then pass it through decodeBytea (so it can be decoded if it is in text format) diff --git a/query_test.go b/query_test.go index 30ed341a..664d0bb6 100644 --- a/query_test.go +++ b/query_test.go @@ -218,6 +218,40 @@ func TestConnQueryReadTooManyValues(t *testing.T) { ensureConnValid(t, conn) } +func TestConnQueryScanIgnoreColumn(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + rows, err := conn.Query("select 1::int8, 2::int8, 3::int8") + if err != nil { + t.Fatalf("conn.Query failed: %v", err) + } + + ok := rows.Next() + if !ok { + t.Fatal("rows.Next terminated early") + } + + var n, m int64 + err = rows.Scan(&n, nil, &m) + if err != nil { + t.Fatalf("rows.Scan failed: %v", err) + } + rows.Close() + + if n != 1 { + t.Errorf("Expected n to equal 1, but it was %d", n) + } + + if m != 3 { + t.Errorf("Expected n to equal 3, but it was %d", m) + } + + ensureConnValid(t, conn) +} + func TestConnQueryScanner(t *testing.T) { t.Parallel()