diff --git a/conn.go b/conn.go index 6d12fc17..a9129ade 100644 --- a/conn.go +++ b/conn.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net" + "strconv" ) type conn struct { @@ -122,6 +123,90 @@ func (c *conn) Query(sql string) (rows []map[string]string, err error) { panic("Unreachable") } +func (c *conn) selectOne(sql string) (s string, err error) { + if err = c.sendSimpleQuery(sql); err != nil { + return + } + + for { + var t byte + var r *messageReader + if t, r, err = c.rxMsg(); err == nil { + switch t { + case readyForQuery: + return + case rowDescription: + case dataRow: + s = c.rxDataRowFirstValue(r) + case commandComplete: + default: + if err = c.processContextFreeMsg(t, r); err != nil { + return + } + } + } else { + return + } + } + + panic("Unreachable") +} + +func (c *conn) SelectString(sql string) (s string, err error) { + return c.selectOne(sql) +} + +func (c *conn) selectInt(sql string, size int) (i int64, err error) { + var s string + s, err = c.selectOne(sql) + if err != nil { + return + } + + i, err = strconv.ParseInt(s, 10, size) + return +} + +func (c *conn) SelectInt64(sql string) (i int64, err error) { + return c.selectInt(sql, 64) +} + +func (c *conn) SelectInt32(sql string) (i int32, err error) { + var i64 int64 + i64, err = c.selectInt(sql, 32) + i = int32(i64) + return +} + +func (c *conn) SelectInt16(sql string) (i int16, err error) { + var i64 int64 + i64, err = c.selectInt(sql, 16) + i = int16(i64) + return +} + +func (c *conn) selectFloat(sql string, size int) (f float64, err error) { + var s string + s, err = c.selectOne(sql) + if err != nil { + return + } + + f, err = strconv.ParseFloat(s, size) + return +} + +func (c *conn) SelectFloat64(sql string) (f float64, err error) { + return c.selectFloat(sql, 64) +} + +func (c *conn) SelectFloat32(sql string) (f float32, err error) { + var f64 float64 + f64, err = c.selectFloat(sql, 32) + f = float32(f64) + return +} + func (c *conn) sendSimpleQuery(sql string) (err error) { bufSize := 5 + len(sql) + 1 // message identifier (1), message size (4), null string terminator (1) buf := c.getBuf(bufSize) @@ -270,6 +355,15 @@ func (c *conn) rxDataRow(r *messageReader, fields []fieldDescription) (row map[s return } +func (c *conn) rxDataRowFirstValue(r *messageReader) (s string) { + r.readInt16() // ignore field count + + // TODO - handle nulls + size := r.readInt32() + s = r.readByteString(size) + return s +} + func (c *conn) rxCommandComplete(r *messageReader) string { return r.readString() } diff --git a/conn_test.go b/conn_test.go index f20dc145..884c6990 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1,6 +1,7 @@ package pgx import ( + "strings" "testing" ) @@ -96,3 +97,135 @@ func TestQuery(t *testing.T) { t.Fatal("Unable to close connection") } } + +func TestSelectString(t *testing.T) { + conn, err := Connect(map[string]string{"socket": "/private/tmp/.s.PGSQL.5432", "user": "pgx_none", "database": "pgx_test"}) + if err != nil { + t.Fatal("Unable to establish connection") + } + + var s string + s, err = conn.SelectString("select 'foo'") + if err != nil { + t.Fatal("Unable to select string: " + err.Error()) + } + + if s != "foo" { + t.Error("Received incorrect string") + } +} + +func TestSelectInt64(t *testing.T) { + conn, err := Connect(map[string]string{"socket": "/private/tmp/.s.PGSQL.5432", "user": "pgx_none", "database": "pgx_test"}) + if err != nil { + t.Fatal("Unable to establish connection") + } + + var i int64 + i, err = conn.SelectInt64("select 1") + if err != nil { + t.Fatal("Unable to select int64: " + err.Error()) + } + + if i != 1 { + t.Error("Received incorrect int64") + } + + i, err = conn.SelectInt64("select power(2,65)::numeric") + if err == nil || !strings.Contains(err.Error(), "value out of range") { + t.Error("Expected value out of range error when selecting number greater than max int64") + } + + i, err = conn.SelectInt64("select -power(2,65)::numeric") + if err == nil || !strings.Contains(err.Error(), "value out of range") { + t.Error("Expected value out of range error when selecting number less than min int64") + } +} + +func TestSelectInt32(t *testing.T) { + conn, err := Connect(map[string]string{"socket": "/private/tmp/.s.PGSQL.5432", "user": "pgx_none", "database": "pgx_test"}) + if err != nil { + t.Fatal("Unable to establish connection") + } + + var i int32 + i, err = conn.SelectInt32("select 1") + if err != nil { + t.Fatal("Unable to select int32: " + err.Error()) + } + + if i != 1 { + t.Error("Received incorrect int32") + } + + i, err = conn.SelectInt32("select power(2,33)::numeric") + if err == nil || !strings.Contains(err.Error(), "value out of range") { + t.Error("Expected value out of range error when selecting number greater than max int32") + } + + i, err = conn.SelectInt32("select -power(2,33)::numeric") + if err == nil || !strings.Contains(err.Error(), "value out of range") { + t.Error("Expected value out of range error when selecting number less than min int32") + } +} + +func TestSelectInt16(t *testing.T) { + conn, err := Connect(map[string]string{"socket": "/private/tmp/.s.PGSQL.5432", "user": "pgx_none", "database": "pgx_test"}) + if err != nil { + t.Fatal("Unable to establish connection") + } + + var i int16 + i, err = conn.SelectInt16("select 1") + if err != nil { + t.Fatal("Unable to select int16: " + err.Error()) + } + + if i != 1 { + t.Error("Received incorrect int16") + } + + i, err = conn.SelectInt16("select power(2,17)::numeric") + if err == nil || !strings.Contains(err.Error(), "value out of range") { + t.Error("Expected value out of range error when selecting number greater than max int16") + } + + i, err = conn.SelectInt16("select -power(2,17)::numeric") + if err == nil || !strings.Contains(err.Error(), "value out of range") { + t.Error("Expected value out of range error when selecting number less than min int16") + } +} + +func TestSelectFloat64(t *testing.T) { + conn, err := Connect(map[string]string{"socket": "/private/tmp/.s.PGSQL.5432", "user": "pgx_none", "database": "pgx_test"}) + if err != nil { + t.Fatal("Unable to establish connection") + } + + var f float64 + f, err = conn.SelectFloat64("select 1.23") + if err != nil { + t.Fatal("Unable to select float64: " + err.Error()) + } + + if f != 1.23 { + t.Error("Received incorrect float64") + } +} + +func TestSelectFloat32(t *testing.T) { + conn, err := Connect(map[string]string{"socket": "/private/tmp/.s.PGSQL.5432", "user": "pgx_none", "database": "pgx_test"}) + if err != nil { + t.Fatal("Unable to establish connection") + } + + var f float32 + f, err = conn.SelectFloat32("select 1.23") + if err != nil { + t.Fatal("Unable to select float32: " + err.Error()) + } + + if f != 1.23 { + t.Error("Received incorrect float32") + } +}