diff --git a/connection.go b/connection.go index 2ff3cc57..e30ce681 100644 --- a/connection.go +++ b/connection.go @@ -37,6 +37,15 @@ func (e NotSingleRowError) Error() string { return fmt.Sprintf("Expected to find 1 row exactly, instead found %d", e.RowCount) } +type UnexpectedColumnCountError struct { + ExpectedCount int16 + ActualCount int16 +} + +func (e UnexpectedColumnCountError) Error() string { + return fmt.Sprintf("Expected result to have %d column(s), instead it has %d", e.ExpectedCount, e.ActualCount) +} + func Connect(parameters ConnectionParameters) (c *Connection, err error) { c = new(Connection) @@ -165,11 +174,16 @@ func (c *Connection) SelectRow(sql string) (row map[string]interface{}, err erro return } +// Returns a UnexpectedColumnCountError if exactly one column is not found // Returns a NotSingleRowError if exactly one row is not found func (c *Connection) SelectValue(sql string) (v interface{}, err error) { var numRowsFound int64 onDataRow := func(r *DataRowReader) error { + if len(r.fields) != 1 { + return UnexpectedColumnCountError{ExpectedCount: 1, ActualCount: int16(len(r.fields))} + } + numRowsFound++ v = r.ReadValue() return nil @@ -183,9 +197,14 @@ func (c *Connection) SelectValue(sql string) (v interface{}, err error) { return } +// Returns a UnexpectedColumnCountError if exactly one column is not found func (c *Connection) SelectValues(sql string) (values []interface{}, err error) { values = make([]interface{}, 0, 8) onDataRow := func(r *DataRowReader) error { + if len(r.fields) != 1 { + return UnexpectedColumnCountError{ExpectedCount: 1, ActualCount: int16(len(r.fields))} + } + values = append(values, r.ReadValue()) return nil } diff --git a/connection_test.go b/connection_test.go index 70fae107..efb3952f 100644 --- a/connection_test.go +++ b/connection_test.go @@ -236,6 +236,11 @@ func TestConnectionSelectValue(t *testing.T) { if _, ok := err.(NotSingleRowError); !ok { t.Error("Multiple matching rows should have returned NotSingleRowError") } + + _, err = conn.SelectValue("select 'Matthew', 'Mark'") + if _, ok := err.(UnexpectedColumnCountError); !ok { + t.Error("Multiple columns should have returned UnexpectedColumnCountError") + } } func TestSelectValues(t *testing.T) { @@ -262,4 +267,9 @@ func TestSelectValues(t *testing.T) { test("select * from (values ('Matthew'), ('Mark'), ('Luke'), ('John')) t", []interface{}{"Matthew", "Mark", "Luke", "John"}) test("select * from (values ('Matthew'), (null)) t", []interface{}{"Matthew", nil}) test("select * from (values (1::int4), (2::int4), (null), (3::int4)) t", []interface{}{int32(1), int32(2), nil, int32(3)}) + + _, err := conn.SelectValues("select 'Matthew', 'Mark'") + if _, ok := err.(UnexpectedColumnCountError); !ok { + t.Error("Multiple columns should have returned UnexpectedColumnCountError") + } }