mirror of
https://github.com/jackc/pgx.git
synced 2025-04-29 07:21:43 +00:00
Added basic null handling
* Connection.SelectRows leaves null values empty * Select* and SelectAll* now error on null refs #4
This commit is contained in:
parent
ee25d4a03a
commit
0b1ac12c0e
59
conn.go
59
conn.go
@ -128,6 +128,10 @@ func (c *Connection) SelectFunc(sql string, onDataRow func(*messageReader, []fie
|
|||||||
panic("Unreachable")
|
panic("Unreachable")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Null values are not included in rows. However, because maps return the 0 value
|
||||||
|
// for missing values this flattens nulls to empty string. If the caller needs to
|
||||||
|
// distinguish between a real empty string and a null it can use the comma ok
|
||||||
|
// pattern when accessing the map
|
||||||
func (c *Connection) SelectRows(sql string) (rows []map[string]string, err error) {
|
func (c *Connection) SelectRows(sql string) (rows []map[string]string, err error) {
|
||||||
rows = make([]map[string]string, 0, 8)
|
rows = make([]map[string]string, 0, 8)
|
||||||
onDataRow := func(r *messageReader, fields []fieldDescription) error {
|
onDataRow := func(r *messageReader, fields []fieldDescription) error {
|
||||||
@ -140,7 +144,11 @@ func (c *Connection) SelectRows(sql string) (rows []map[string]string, err error
|
|||||||
|
|
||||||
func (c *Connection) SelectString(sql string) (s string, err error) {
|
func (c *Connection) SelectString(sql string) (s string, err error) {
|
||||||
onDataRow := func(r *messageReader, _ []fieldDescription) error {
|
onDataRow := func(r *messageReader, _ []fieldDescription) error {
|
||||||
s = c.rxDataRowFirstValue(r)
|
var null bool
|
||||||
|
s, null = c.rxDataRowFirstValue(r)
|
||||||
|
if null {
|
||||||
|
return errors.New("Unexpected NULL")
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
err = c.SelectFunc(sql, onDataRow)
|
err = c.SelectFunc(sql, onDataRow)
|
||||||
@ -201,7 +209,11 @@ func (c *Connection) SelectFloat32(sql string) (f float32, err error) {
|
|||||||
func (c *Connection) SelectAllString(sql string) (strings []string, err error) {
|
func (c *Connection) SelectAllString(sql string) (strings []string, err error) {
|
||||||
strings = make([]string, 0, 8)
|
strings = make([]string, 0, 8)
|
||||||
onDataRow := func(r *messageReader, _ []fieldDescription) error {
|
onDataRow := func(r *messageReader, _ []fieldDescription) error {
|
||||||
strings = append(strings, c.rxDataRowFirstValue(r))
|
s, null := c.rxDataRowFirstValue(r)
|
||||||
|
if null {
|
||||||
|
return errors.New("Unexpected NULL")
|
||||||
|
}
|
||||||
|
strings = append(strings, s)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
err = c.SelectFunc(sql, onDataRow)
|
err = c.SelectFunc(sql, onDataRow)
|
||||||
@ -211,8 +223,12 @@ func (c *Connection) SelectAllString(sql string) (strings []string, err error) {
|
|||||||
func (c *Connection) SelectAllInt64(sql string) (ints []int64, err error) {
|
func (c *Connection) SelectAllInt64(sql string) (ints []int64, err error) {
|
||||||
ints = make([]int64, 0, 8)
|
ints = make([]int64, 0, 8)
|
||||||
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
|
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
|
||||||
|
s, null := c.rxDataRowFirstValue(r)
|
||||||
|
if null {
|
||||||
|
return errors.New("Unexpected NULL")
|
||||||
|
}
|
||||||
var i int64
|
var i int64
|
||||||
i, parseError = strconv.ParseInt(c.rxDataRowFirstValue(r), 10, 64)
|
i, parseError = strconv.ParseInt(s, 10, 64)
|
||||||
ints = append(ints, i)
|
ints = append(ints, i)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -223,8 +239,12 @@ func (c *Connection) SelectAllInt64(sql string) (ints []int64, err error) {
|
|||||||
func (c *Connection) SelectAllInt32(sql string) (ints []int32, err error) {
|
func (c *Connection) SelectAllInt32(sql string) (ints []int32, err error) {
|
||||||
ints = make([]int32, 0, 8)
|
ints = make([]int32, 0, 8)
|
||||||
onDataRow := func(r *messageReader, fields []fieldDescription) (parseError error) {
|
onDataRow := func(r *messageReader, fields []fieldDescription) (parseError error) {
|
||||||
|
s, null := c.rxDataRowFirstValue(r)
|
||||||
|
if null {
|
||||||
|
return errors.New("Unexpected NULL")
|
||||||
|
}
|
||||||
var i int64
|
var i int64
|
||||||
i, parseError = strconv.ParseInt(c.rxDataRowFirstValue(r), 10, 32)
|
i, parseError = strconv.ParseInt(s, 10, 32)
|
||||||
ints = append(ints, int32(i))
|
ints = append(ints, int32(i))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -235,8 +255,12 @@ func (c *Connection) SelectAllInt32(sql string) (ints []int32, err error) {
|
|||||||
func (c *Connection) SelectAllInt16(sql string) (ints []int16, err error) {
|
func (c *Connection) SelectAllInt16(sql string) (ints []int16, err error) {
|
||||||
ints = make([]int16, 0, 8)
|
ints = make([]int16, 0, 8)
|
||||||
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
|
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
|
||||||
|
s, null := c.rxDataRowFirstValue(r)
|
||||||
|
if null {
|
||||||
|
return errors.New("Unexpected NULL")
|
||||||
|
}
|
||||||
var i int64
|
var i int64
|
||||||
i, parseError = strconv.ParseInt(c.rxDataRowFirstValue(r), 10, 16)
|
i, parseError = strconv.ParseInt(s, 10, 16)
|
||||||
ints = append(ints, int16(i))
|
ints = append(ints, int16(i))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -247,8 +271,12 @@ func (c *Connection) SelectAllInt16(sql string) (ints []int16, err error) {
|
|||||||
func (c *Connection) SelectAllFloat64(sql string) (floats []float64, err error) {
|
func (c *Connection) SelectAllFloat64(sql string) (floats []float64, err error) {
|
||||||
floats = make([]float64, 0, 8)
|
floats = make([]float64, 0, 8)
|
||||||
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
|
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
|
||||||
|
s, null := c.rxDataRowFirstValue(r)
|
||||||
|
if null {
|
||||||
|
return errors.New("Unexpected NULL")
|
||||||
|
}
|
||||||
var f float64
|
var f float64
|
||||||
f, parseError = strconv.ParseFloat(c.rxDataRowFirstValue(r), 64)
|
f, parseError = strconv.ParseFloat(s, 64)
|
||||||
floats = append(floats, f)
|
floats = append(floats, f)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -259,8 +287,12 @@ func (c *Connection) SelectAllFloat64(sql string) (floats []float64, err error)
|
|||||||
func (c *Connection) SelectAllFloat32(sql string) (floats []float32, err error) {
|
func (c *Connection) SelectAllFloat32(sql string) (floats []float32, err error) {
|
||||||
floats = make([]float32, 0, 8)
|
floats = make([]float32, 0, 8)
|
||||||
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
|
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
|
||||||
|
s, null := c.rxDataRowFirstValue(r)
|
||||||
|
if null {
|
||||||
|
return errors.New("Unexpected NULL")
|
||||||
|
}
|
||||||
var f float64
|
var f float64
|
||||||
f, parseError = strconv.ParseFloat(c.rxDataRowFirstValue(r), 32)
|
f, parseError = strconv.ParseFloat(s, 32)
|
||||||
floats = append(floats, float32(f))
|
floats = append(floats, float32(f))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -440,20 +472,25 @@ func (c *Connection) rxDataRow(r *messageReader, fields []fieldDescription) (row
|
|||||||
|
|
||||||
row = make(map[string]string, fieldCount)
|
row = make(map[string]string, fieldCount)
|
||||||
for i := int16(0); i < fieldCount; i++ {
|
for i := int16(0); i < fieldCount; i++ {
|
||||||
// TODO - handle nulls
|
|
||||||
size := r.readInt32()
|
size := r.readInt32()
|
||||||
|
if size > -1 {
|
||||||
row[fields[i].name] = r.readByteString(size)
|
row[fields[i].name] = r.readByteString(size)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Connection) rxDataRowFirstValue(r *messageReader) (s string) {
|
func (c *Connection) rxDataRowFirstValue(r *messageReader) (s string, null bool) {
|
||||||
r.readInt16() // ignore field count
|
r.readInt16() // ignore field count
|
||||||
|
|
||||||
// TODO - handle nulls
|
|
||||||
size := r.readInt32()
|
size := r.readInt32()
|
||||||
|
if size > -1 {
|
||||||
s = r.readByteString(size)
|
s = r.readByteString(size)
|
||||||
return s
|
} else {
|
||||||
|
null = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Connection) rxCommandComplete(r *messageReader) string {
|
func (c *Connection) rxCommandComplete(r *messageReader) string {
|
||||||
|
76
conn_test.go
76
conn_test.go
@ -136,7 +136,7 @@ func TestSelectFunc(t *testing.T) {
|
|||||||
func TestSelectRows(t *testing.T) {
|
func TestSelectRows(t *testing.T) {
|
||||||
conn := getSharedConnection()
|
conn := getSharedConnection()
|
||||||
|
|
||||||
rows, err := conn.SelectRows("select 'Jack' as name")
|
rows, err := conn.SelectRows("select 'Jack' as name, null as position")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Query failed")
|
t.Fatal("Query failed")
|
||||||
}
|
}
|
||||||
@ -146,7 +146,15 @@ func TestSelectRows(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if rows[0]["name"] != "Jack" {
|
if rows[0]["name"] != "Jack" {
|
||||||
t.Fatal("Received incorrect name")
|
t.Error("Received incorrect name")
|
||||||
|
}
|
||||||
|
|
||||||
|
value, presence := rows[0]["position"]
|
||||||
|
if value != "" {
|
||||||
|
t.Error("Should have received empty string for null")
|
||||||
|
}
|
||||||
|
if presence != false {
|
||||||
|
t.Error("Null value shouldn't have been present in map")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -155,11 +163,14 @@ func TestSelectString(t *testing.T) {
|
|||||||
|
|
||||||
s, err := conn.SelectString("select 'foo'")
|
s, err := conn.SelectString("select 'foo'")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Unable to select string: " + err.Error())
|
t.Error("Unable to select string: " + err.Error())
|
||||||
|
} else if s != "foo" {
|
||||||
|
t.Error("Received incorrect string")
|
||||||
}
|
}
|
||||||
|
|
||||||
if s != "foo" {
|
_, err = conn.SelectString("select null")
|
||||||
t.Error("Received incorrect string")
|
if err == nil {
|
||||||
|
t.Error("Should have received error on null")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -184,6 +195,11 @@ func TestSelectInt64(t *testing.T) {
|
|||||||
if err == nil || !strings.Contains(err.Error(), "value out of range") {
|
if err == nil || !strings.Contains(err.Error(), "value out of range") {
|
||||||
t.Error("Expected value out of range error when selecting number less than min int64")
|
t.Error("Expected value out of range error when selecting number less than min int64")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, err = conn.SelectInt64("select null")
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "NULL") {
|
||||||
|
t.Error("Should have received error on null")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSelectInt32(t *testing.T) {
|
func TestSelectInt32(t *testing.T) {
|
||||||
@ -207,6 +223,11 @@ func TestSelectInt32(t *testing.T) {
|
|||||||
if err == nil || !strings.Contains(err.Error(), "value out of range") {
|
if err == nil || !strings.Contains(err.Error(), "value out of range") {
|
||||||
t.Error("Expected value out of range error when selecting number less than min int32")
|
t.Error("Expected value out of range error when selecting number less than min int32")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, err = conn.SelectInt32("select null")
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "NULL") {
|
||||||
|
t.Error("Should have received error on null")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSelectInt16(t *testing.T) {
|
func TestSelectInt16(t *testing.T) {
|
||||||
@ -230,6 +251,11 @@ func TestSelectInt16(t *testing.T) {
|
|||||||
if err == nil || !strings.Contains(err.Error(), "value out of range") {
|
if err == nil || !strings.Contains(err.Error(), "value out of range") {
|
||||||
t.Error("Expected value out of range error when selecting number less than min int16")
|
t.Error("Expected value out of range error when selecting number less than min int16")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, err = conn.SelectInt16("select null")
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "NULL") {
|
||||||
|
t.Error("Should have received error on null")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSelectFloat64(t *testing.T) {
|
func TestSelectFloat64(t *testing.T) {
|
||||||
@ -243,6 +269,11 @@ func TestSelectFloat64(t *testing.T) {
|
|||||||
if f != 1.23 {
|
if f != 1.23 {
|
||||||
t.Error("Received incorrect float64")
|
t.Error("Received incorrect float64")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, err = conn.SelectFloat64("select null")
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "NULL") {
|
||||||
|
t.Error("Should have received error on null")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSelectFloat32(t *testing.T) {
|
func TestSelectFloat32(t *testing.T) {
|
||||||
@ -256,6 +287,11 @@ func TestSelectFloat32(t *testing.T) {
|
|||||||
if f != 1.23 {
|
if f != 1.23 {
|
||||||
t.Error("Received incorrect float32")
|
t.Error("Received incorrect float32")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, err = conn.SelectFloat32("select null")
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "NULL") {
|
||||||
|
t.Error("Should have received error on null")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSelectAllString(t *testing.T) {
|
func TestSelectAllString(t *testing.T) {
|
||||||
@ -269,6 +305,11 @@ func TestSelectAllString(t *testing.T) {
|
|||||||
if s[0] != "Matthew" || s[1] != "Mark" || s[2] != "Luke" || s[3] != "John" {
|
if s[0] != "Matthew" || s[1] != "Mark" || s[2] != "Luke" || s[3] != "John" {
|
||||||
t.Error("Received incorrect strings")
|
t.Error("Received incorrect strings")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, err = conn.SelectAllString("select * from (values ('Matthew'), (null)) t")
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "NULL") {
|
||||||
|
t.Error("Should have received error on null")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSelectAllInt64(t *testing.T) {
|
func TestSelectAllInt64(t *testing.T) {
|
||||||
@ -292,6 +333,11 @@ func TestSelectAllInt64(t *testing.T) {
|
|||||||
if err == nil || !strings.Contains(err.Error(), "value out of range") {
|
if err == nil || !strings.Contains(err.Error(), "value out of range") {
|
||||||
t.Error("Expected value out of range error when selecting number less than min int64")
|
t.Error("Expected value out of range error when selecting number less than min int64")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, err = conn.SelectAllInt64("select null")
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "NULL") {
|
||||||
|
t.Error("Should have received error on null")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSelectAllInt32(t *testing.T) {
|
func TestSelectAllInt32(t *testing.T) {
|
||||||
@ -315,6 +361,11 @@ func TestSelectAllInt32(t *testing.T) {
|
|||||||
if err == nil || !strings.Contains(err.Error(), "value out of range") {
|
if err == nil || !strings.Contains(err.Error(), "value out of range") {
|
||||||
t.Error("Expected value out of range error when selecting number less than min int32")
|
t.Error("Expected value out of range error when selecting number less than min int32")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, err = conn.SelectAllInt32("select null")
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "NULL") {
|
||||||
|
t.Error("Should have received error on null")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSelectAllInt16(t *testing.T) {
|
func TestSelectAllInt16(t *testing.T) {
|
||||||
@ -338,6 +389,11 @@ func TestSelectAllInt16(t *testing.T) {
|
|||||||
if err == nil || !strings.Contains(err.Error(), "value out of range") {
|
if err == nil || !strings.Contains(err.Error(), "value out of range") {
|
||||||
t.Error("Expected value out of range error when selecting number less than min int16")
|
t.Error("Expected value out of range error when selecting number less than min int16")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, err = conn.SelectAllInt16("select null")
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "NULL") {
|
||||||
|
t.Error("Should have received error on null")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSelectAllFloat64(t *testing.T) {
|
func TestSelectAllFloat64(t *testing.T) {
|
||||||
@ -351,6 +407,11 @@ func TestSelectAllFloat64(t *testing.T) {
|
|||||||
if f[0] != 1.23 || f[1] != 4.56 {
|
if f[0] != 1.23 || f[1] != 4.56 {
|
||||||
t.Error("Received incorrect float64")
|
t.Error("Received incorrect float64")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, err = conn.SelectAllFloat64("select null")
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "NULL") {
|
||||||
|
t.Error("Should have received error on null")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSelectAllFloat32(t *testing.T) {
|
func TestSelectAllFloat32(t *testing.T) {
|
||||||
@ -364,4 +425,9 @@ func TestSelectAllFloat32(t *testing.T) {
|
|||||||
if f[0] != 1.23 || f[1] != 4.56 {
|
if f[0] != 1.23 || f[1] != 4.56 {
|
||||||
t.Error("Received incorrect float32")
|
t.Error("Received incorrect float32")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, err = conn.SelectAllFloat32("select null")
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "NULL") {
|
||||||
|
t.Error("Should have received error on null")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user