diff --git a/conn.go b/conn.go index a9129ade..e9641508 100644 --- a/conn.go +++ b/conn.go @@ -207,6 +207,111 @@ func (c *conn) SelectFloat32(sql string) (f float32, err error) { return } +func (c *conn) SelectAllString(sql string) (strings []string, err error) { + if err = c.sendSimpleQuery(sql); err != nil { + return + } + + strings = make([]string, 0) + + for { + var t byte + var r *messageReader + if t, r, err = c.rxMsg(); err == nil { + switch t { + case readyForQuery: + return + case rowDescription: + case dataRow: + strings = append(strings, c.rxDataRowFirstValue(r)) + case commandComplete: + default: + if err = c.processContextFreeMsg(t, r); err != nil { + return + } + } + } else { + return + } + } + + panic("Unreachable") +} + +func (c *conn) selectAllInt(sql string, size int) (ints []int64, err error) { + var strings []string + strings, err = c.SelectAllString(sql) + if err != nil { + return + } + + ints = make([]int64, len(strings)) + for i, s := range strings { + ints[i], err = strconv.ParseInt(s, 10, size) + if err != nil { + return + } + } + + return +} + +func (c *conn) SelectAllInt64(sql string) (ints []int64, err error) { + return c.selectAllInt(sql, 64) +} + +func (c *conn) SelectAllInt32(sql string) (ints []int32, err error) { + var int64s []int64 + int64s, err = c.selectAllInt(sql, 32) + ints = make([]int32, len(int64s)) + for i := 0; i < len(int64s); i++ { + ints[i] = int32(int64s[i]) + } + return +} + +func (c *conn) SelectAllInt16(sql string) (ints []int16, err error) { + var int64s []int64 + int64s, err = c.selectAllInt(sql, 16) + ints = make([]int16, len(int64s)) + for i := 0; i < len(int64s); i++ { + ints[i] = int16(int64s[i]) + } + return +} + +func (c *conn) selectAllFloat(sql string, size int) (floats []float64, err error) { + var strings []string + strings, err = c.SelectAllString(sql) + if err != nil { + return + } + + floats = make([]float64, len(strings)) + for i, s := range strings { + floats[i], err = strconv.ParseFloat(s, size) + if err != nil { + return + } + } + + return +} + +func (c *conn) SelectAllFloat64(sql string) (floats []float64, err error) { + return c.selectAllFloat(sql, 64) +} + +func (c *conn) SelectAllFloat32(sql string) (floats []float32, err error) { + var float64s []float64 + float64s, err = c.selectAllFloat(sql, 32) + floats = make([]float32, len(float64s)) + for i := 0; i < len(float64s); i++ { + floats[i] = float32(float64s[i]) + } + return +} + func (c *conn) sendSimpleQuery(sql string) (err error) { bufSize := 5 + len(sql) + 1 // message identifier (1), message size (4), null string terminator (1) buf := c.getBuf(bufSize) diff --git a/conn_test.go b/conn_test.go index 5dd4b4c1..19838d46 100644 --- a/conn_test.go +++ b/conn_test.go @@ -210,3 +210,112 @@ func TestSelectFloat32(t *testing.T) { t.Error("Received incorrect float32") } } + + +func TestSelectAllString(t *testing.T) { + conn := getSharedConn() + + s, err := conn.SelectAllString("select * from (values ('Matthew'), ('Mark'), ('Luke'), ('John')) t") + if err != nil { + t.Fatal("Unable to select all strings: " + err.Error()) + } + + if s[0] != "Matthew" || s[1] != "Mark" || s[2] != "Luke" || s[3] != "John" { + t.Error("Received incorrect strings") + } +} + +func TestSelectAllInt64(t *testing.T) { + conn := getSharedConn() + + i, err := conn.SelectAllInt64("select * from (values (1), (2)) t") + if err != nil { + t.Fatal("Unable to select all int64: " + err.Error()) + } + + if i[0] != 1 || i[1] != 2 { + t.Error("Received incorrect int64s") + } + + i, err = conn.SelectAllInt64("select power(2,65)::numeric") + if err == nil || !strings.Contains(err.Error(), "value out of range") { + t.Error("Expected value out of range error when selecting number greater than max int64") + } + + i, err = conn.SelectAllInt64("select -power(2,65)::numeric") + if err == nil || !strings.Contains(err.Error(), "value out of range") { + t.Error("Expected value out of range error when selecting number less than min int64") + } +} + +func TestSelectAllInt32(t *testing.T) { + conn := getSharedConn() + + i, err := conn.SelectAllInt32("select * from (values (1), (2)) t") + if err != nil { + t.Fatal("Unable to select all int32: " + err.Error()) + } + + if i[0] != 1 || i[1] != 2 { + t.Error("Received incorrect int32") + } + + i, err = conn.SelectAllInt32("select power(2,33)::numeric") + if err == nil || !strings.Contains(err.Error(), "value out of range") { + t.Error("Expected value out of range error when selecting number greater than max int32") + } + + i, err = conn.SelectAllInt32("select -power(2,33)::numeric") + if err == nil || !strings.Contains(err.Error(), "value out of range") { + t.Error("Expected value out of range error when selecting number less than min int32") + } +} + +func TestSelectAllInt16(t *testing.T) { + conn := getSharedConn() + + i, err := conn.SelectAllInt16("select * from (values (1), (2)) t") + if err != nil { + t.Fatal("Unable to select all int16: " + err.Error()) + } + + if i[0] != 1 || i[1] != 2 { + t.Error("Received incorrect int16") + } + + i, err = conn.SelectAllInt16("select power(2,17)::numeric") + if err == nil || !strings.Contains(err.Error(), "value out of range") { + t.Error("Expected value out of range error when selecting number greater than max int16") + } + + i, err = conn.SelectAllInt16("select -power(2,17)::numeric") + if err == nil || !strings.Contains(err.Error(), "value out of range") { + t.Error("Expected value out of range error when selecting number less than min int16") + } +} + +func TestSelectAllFloat64(t *testing.T) { + conn := getSharedConn() + + f, err := conn.SelectAllFloat64("select * from (values (1.23), (4.56)) t") + if err != nil { + t.Fatal("Unable to select all float64: " + err.Error()) + } + + if f[0] != 1.23 || f[1] != 4.56 { + t.Error("Received incorrect float64") + } +} + +func TestSelectAllFloat32(t *testing.T) { + conn := getSharedConn() + + f, err := conn.SelectAllFloat32("select * from (values (1.23), (4.56)) t") + if err != nil { + t.Fatal("Unable to select all float32: " + err.Error()) + } + + if f[0] != 1.23 || f[1] != 4.56 { + t.Error("Received incorrect float32") + } +} \ No newline at end of file