diff --git a/conn.go b/conn.go index afd83414..d429a151 100644 --- a/conn.go +++ b/conn.go @@ -128,6 +128,10 @@ func (c *Connection) SelectFunc(sql string, onDataRow func(*messageReader, []fie panic("Unreachable") } +// Null values are not included in rows. However, because maps return the 0 value +// for missing values this flattens nulls to empty string. If the caller needs to +// distinguish between a real empty string and a null it can use the comma ok +// pattern when accessing the map func (c *Connection) SelectRows(sql string) (rows []map[string]string, err error) { rows = make([]map[string]string, 0, 8) onDataRow := func(r *messageReader, fields []fieldDescription) error { @@ -140,7 +144,11 @@ func (c *Connection) SelectRows(sql string) (rows []map[string]string, err error func (c *Connection) SelectString(sql string) (s string, err error) { onDataRow := func(r *messageReader, _ []fieldDescription) error { - s = c.rxDataRowFirstValue(r) + var null bool + s, null = c.rxDataRowFirstValue(r) + if null { + return errors.New("Unexpected NULL") + } return nil } err = c.SelectFunc(sql, onDataRow) @@ -201,7 +209,11 @@ func (c *Connection) SelectFloat32(sql string) (f float32, err error) { func (c *Connection) SelectAllString(sql string) (strings []string, err error) { strings = make([]string, 0, 8) onDataRow := func(r *messageReader, _ []fieldDescription) error { - strings = append(strings, c.rxDataRowFirstValue(r)) + s, null := c.rxDataRowFirstValue(r) + if null { + return errors.New("Unexpected NULL") + } + strings = append(strings, s) return nil } err = c.SelectFunc(sql, onDataRow) @@ -211,8 +223,12 @@ func (c *Connection) SelectAllString(sql string) (strings []string, err error) { func (c *Connection) SelectAllInt64(sql string) (ints []int64, err error) { ints = make([]int64, 0, 8) onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { + s, null := c.rxDataRowFirstValue(r) + if null { + return errors.New("Unexpected NULL") + } var i int64 - i, parseError = strconv.ParseInt(c.rxDataRowFirstValue(r), 10, 64) + i, parseError = strconv.ParseInt(s, 10, 64) ints = append(ints, i) return } @@ -223,8 +239,12 @@ func (c *Connection) SelectAllInt64(sql string) (ints []int64, err error) { func (c *Connection) SelectAllInt32(sql string) (ints []int32, err error) { ints = make([]int32, 0, 8) onDataRow := func(r *messageReader, fields []fieldDescription) (parseError error) { + s, null := c.rxDataRowFirstValue(r) + if null { + return errors.New("Unexpected NULL") + } var i int64 - i, parseError = strconv.ParseInt(c.rxDataRowFirstValue(r), 10, 32) + i, parseError = strconv.ParseInt(s, 10, 32) ints = append(ints, int32(i)) return } @@ -235,8 +255,12 @@ func (c *Connection) SelectAllInt32(sql string) (ints []int32, err error) { func (c *Connection) SelectAllInt16(sql string) (ints []int16, err error) { ints = make([]int16, 0, 8) onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { + s, null := c.rxDataRowFirstValue(r) + if null { + return errors.New("Unexpected NULL") + } var i int64 - i, parseError = strconv.ParseInt(c.rxDataRowFirstValue(r), 10, 16) + i, parseError = strconv.ParseInt(s, 10, 16) ints = append(ints, int16(i)) return } @@ -247,8 +271,12 @@ func (c *Connection) SelectAllInt16(sql string) (ints []int16, err error) { func (c *Connection) SelectAllFloat64(sql string) (floats []float64, err error) { floats = make([]float64, 0, 8) onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { + s, null := c.rxDataRowFirstValue(r) + if null { + return errors.New("Unexpected NULL") + } var f float64 - f, parseError = strconv.ParseFloat(c.rxDataRowFirstValue(r), 64) + f, parseError = strconv.ParseFloat(s, 64) floats = append(floats, f) return } @@ -259,8 +287,12 @@ func (c *Connection) SelectAllFloat64(sql string) (floats []float64, err error) func (c *Connection) SelectAllFloat32(sql string) (floats []float32, err error) { floats = make([]float32, 0, 8) onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) { + s, null := c.rxDataRowFirstValue(r) + if null { + return errors.New("Unexpected NULL") + } var f float64 - f, parseError = strconv.ParseFloat(c.rxDataRowFirstValue(r), 32) + f, parseError = strconv.ParseFloat(s, 32) floats = append(floats, float32(f)) return } @@ -440,20 +472,25 @@ func (c *Connection) rxDataRow(r *messageReader, fields []fieldDescription) (row row = make(map[string]string, fieldCount) for i := int16(0); i < fieldCount; i++ { - // TODO - handle nulls size := r.readInt32() - row[fields[i].name] = r.readByteString(size) + if size > -1 { + row[fields[i].name] = r.readByteString(size) + } } return } -func (c *Connection) rxDataRowFirstValue(r *messageReader) (s string) { +func (c *Connection) rxDataRowFirstValue(r *messageReader) (s string, null bool) { r.readInt16() // ignore field count - // TODO - handle nulls size := r.readInt32() - s = r.readByteString(size) - return s + if size > -1 { + s = r.readByteString(size) + } else { + null = true + } + + return } func (c *Connection) rxCommandComplete(r *messageReader) string { diff --git a/conn_test.go b/conn_test.go index 60b67663..bf6279f6 100644 --- a/conn_test.go +++ b/conn_test.go @@ -136,7 +136,7 @@ func TestSelectFunc(t *testing.T) { func TestSelectRows(t *testing.T) { conn := getSharedConnection() - rows, err := conn.SelectRows("select 'Jack' as name") + rows, err := conn.SelectRows("select 'Jack' as name, null as position") if err != nil { t.Fatal("Query failed") } @@ -146,7 +146,15 @@ func TestSelectRows(t *testing.T) { } if rows[0]["name"] != "Jack" { - t.Fatal("Received incorrect name") + t.Error("Received incorrect name") + } + + value, presence := rows[0]["position"] + if value != "" { + t.Error("Should have received empty string for null") + } + if presence != false { + t.Error("Null value shouldn't have been present in map") } } @@ -155,11 +163,14 @@ func TestSelectString(t *testing.T) { s, err := conn.SelectString("select 'foo'") if err != nil { - t.Fatal("Unable to select string: " + err.Error()) + t.Error("Unable to select string: " + err.Error()) + } else if s != "foo" { + t.Error("Received incorrect string") } - if s != "foo" { - t.Error("Received incorrect string") + _, err = conn.SelectString("select null") + if err == nil { + t.Error("Should have received error on null") } } @@ -184,6 +195,11 @@ func TestSelectInt64(t *testing.T) { if err == nil || !strings.Contains(err.Error(), "value out of range") { t.Error("Expected value out of range error when selecting number less than min int64") } + + _, err = conn.SelectInt64("select null") + if err == nil || !strings.Contains(err.Error(), "NULL") { + t.Error("Should have received error on null") + } } func TestSelectInt32(t *testing.T) { @@ -207,6 +223,11 @@ func TestSelectInt32(t *testing.T) { if err == nil || !strings.Contains(err.Error(), "value out of range") { t.Error("Expected value out of range error when selecting number less than min int32") } + + _, err = conn.SelectInt32("select null") + if err == nil || !strings.Contains(err.Error(), "NULL") { + t.Error("Should have received error on null") + } } func TestSelectInt16(t *testing.T) { @@ -230,6 +251,11 @@ func TestSelectInt16(t *testing.T) { if err == nil || !strings.Contains(err.Error(), "value out of range") { t.Error("Expected value out of range error when selecting number less than min int16") } + + _, err = conn.SelectInt16("select null") + if err == nil || !strings.Contains(err.Error(), "NULL") { + t.Error("Should have received error on null") + } } func TestSelectFloat64(t *testing.T) { @@ -243,6 +269,11 @@ func TestSelectFloat64(t *testing.T) { if f != 1.23 { t.Error("Received incorrect float64") } + + _, err = conn.SelectFloat64("select null") + if err == nil || !strings.Contains(err.Error(), "NULL") { + t.Error("Should have received error on null") + } } func TestSelectFloat32(t *testing.T) { @@ -256,6 +287,11 @@ func TestSelectFloat32(t *testing.T) { if f != 1.23 { t.Error("Received incorrect float32") } + + _, err = conn.SelectFloat32("select null") + if err == nil || !strings.Contains(err.Error(), "NULL") { + t.Error("Should have received error on null") + } } func TestSelectAllString(t *testing.T) { @@ -269,6 +305,11 @@ func TestSelectAllString(t *testing.T) { if s[0] != "Matthew" || s[1] != "Mark" || s[2] != "Luke" || s[3] != "John" { t.Error("Received incorrect strings") } + + _, err = conn.SelectAllString("select * from (values ('Matthew'), (null)) t") + if err == nil || !strings.Contains(err.Error(), "NULL") { + t.Error("Should have received error on null") + } } func TestSelectAllInt64(t *testing.T) { @@ -292,6 +333,11 @@ func TestSelectAllInt64(t *testing.T) { if err == nil || !strings.Contains(err.Error(), "value out of range") { t.Error("Expected value out of range error when selecting number less than min int64") } + + _, err = conn.SelectAllInt64("select null") + if err == nil || !strings.Contains(err.Error(), "NULL") { + t.Error("Should have received error on null") + } } func TestSelectAllInt32(t *testing.T) { @@ -315,6 +361,11 @@ func TestSelectAllInt32(t *testing.T) { if err == nil || !strings.Contains(err.Error(), "value out of range") { t.Error("Expected value out of range error when selecting number less than min int32") } + + _, err = conn.SelectAllInt32("select null") + if err == nil || !strings.Contains(err.Error(), "NULL") { + t.Error("Should have received error on null") + } } func TestSelectAllInt16(t *testing.T) { @@ -338,6 +389,11 @@ func TestSelectAllInt16(t *testing.T) { if err == nil || !strings.Contains(err.Error(), "value out of range") { t.Error("Expected value out of range error when selecting number less than min int16") } + + _, err = conn.SelectAllInt16("select null") + if err == nil || !strings.Contains(err.Error(), "NULL") { + t.Error("Should have received error on null") + } } func TestSelectAllFloat64(t *testing.T) { @@ -351,6 +407,11 @@ func TestSelectAllFloat64(t *testing.T) { if f[0] != 1.23 || f[1] != 4.56 { t.Error("Received incorrect float64") } + + _, err = conn.SelectAllFloat64("select null") + if err == nil || !strings.Contains(err.Error(), "NULL") { + t.Error("Should have received error on null") + } } func TestSelectAllFloat32(t *testing.T) { @@ -364,4 +425,9 @@ func TestSelectAllFloat32(t *testing.T) { if f[0] != 1.23 || f[1] != 4.56 { t.Error("Received incorrect float32") } + + _, err = conn.SelectAllFloat32("select null") + if err == nil || !strings.Contains(err.Error(), "NULL") { + t.Error("Should have received error on null") + } }