mirror of https://github.com/jackc/pgx.git
Rename pgx.conn to pgx.Connection
parent
d306d42afb
commit
fa4c70907c
76
conn.go
76
conn.go
|
@ -11,7 +11,7 @@ import (
|
|||
"strconv"
|
||||
)
|
||||
|
||||
type conn struct {
|
||||
type Connection struct {
|
||||
conn net.Conn // the underlying TCP or unix domain socket connection
|
||||
buf []byte // work buffer to avoid constant alloc and dealloc
|
||||
pid int32 // backend pid
|
||||
|
@ -24,8 +24,8 @@ type conn struct {
|
|||
// options:
|
||||
// socket: path to unix domain socket
|
||||
// database: name of database
|
||||
func Connect(options map[string]string) (c *conn, err error) {
|
||||
c = new(conn)
|
||||
func Connect(options map[string]string) (c *Connection, err error) {
|
||||
c = new(Connection)
|
||||
|
||||
c.options = make(map[string]string)
|
||||
for k, v := range options {
|
||||
|
@ -82,7 +82,7 @@ func Connect(options map[string]string) (c *conn, err error) {
|
|||
panic("Unreachable")
|
||||
}
|
||||
|
||||
func (c *conn) Close() (err error) {
|
||||
func (c *Connection) Close() (err error) {
|
||||
buf := c.getBuf(5)
|
||||
buf[0] = 'X'
|
||||
binary.BigEndian.PutUint32(buf[1:], 4)
|
||||
|
@ -90,7 +90,7 @@ func (c *conn) Close() (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (c *conn) query(sql string, onDataRow func(*messageReader, []fieldDescription) error) (err error) {
|
||||
func (c *Connection) query(sql string, onDataRow func(*messageReader, []fieldDescription) error) (err error) {
|
||||
if err = c.sendSimpleQuery(sql); err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -128,7 +128,7 @@ func (c *conn) query(sql string, onDataRow func(*messageReader, []fieldDescripti
|
|||
panic("Unreachable")
|
||||
}
|
||||
|
||||
func (c *conn) Query(sql string) (rows []map[string]string, err error) {
|
||||
func (c *Connection) Query(sql string) (rows []map[string]string, err error) {
|
||||
rows = make([]map[string]string, 0, 8)
|
||||
onDataRow := func(r *messageReader, fields []fieldDescription) error {
|
||||
rows = append(rows, c.rxDataRow(r, fields))
|
||||
|
@ -138,7 +138,7 @@ func (c *conn) Query(sql string) (rows []map[string]string, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (c *conn) SelectString(sql string) (s string, err error) {
|
||||
func (c *Connection) SelectString(sql string) (s string, err error) {
|
||||
onDataRow := func(r *messageReader, _ []fieldDescription) error {
|
||||
s = c.rxDataRowFirstValue(r)
|
||||
return nil
|
||||
|
@ -147,7 +147,7 @@ func (c *conn) SelectString(sql string) (s string, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (c *conn) selectInt(sql string, size int) (i int64, err error) {
|
||||
func (c *Connection) selectInt(sql string, size int) (i int64, err error) {
|
||||
var s string
|
||||
s, err = c.SelectString(sql)
|
||||
if err != nil {
|
||||
|
@ -158,25 +158,25 @@ func (c *conn) selectInt(sql string, size int) (i int64, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (c *conn) SelectInt64(sql string) (i int64, err error) {
|
||||
func (c *Connection) SelectInt64(sql string) (i int64, err error) {
|
||||
return c.selectInt(sql, 64)
|
||||
}
|
||||
|
||||
func (c *conn) SelectInt32(sql string) (i int32, err error) {
|
||||
func (c *Connection) SelectInt32(sql string) (i int32, err error) {
|
||||
var i64 int64
|
||||
i64, err = c.selectInt(sql, 32)
|
||||
i = int32(i64)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) SelectInt16(sql string) (i int16, err error) {
|
||||
func (c *Connection) SelectInt16(sql string) (i int16, err error) {
|
||||
var i64 int64
|
||||
i64, err = c.selectInt(sql, 16)
|
||||
i = int16(i64)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) selectFloat(sql string, size int) (f float64, err error) {
|
||||
func (c *Connection) selectFloat(sql string, size int) (f float64, err error) {
|
||||
var s string
|
||||
s, err = c.SelectString(sql)
|
||||
if err != nil {
|
||||
|
@ -187,18 +187,18 @@ func (c *conn) selectFloat(sql string, size int) (f float64, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (c *conn) SelectFloat64(sql string) (f float64, err error) {
|
||||
func (c *Connection) SelectFloat64(sql string) (f float64, err error) {
|
||||
return c.selectFloat(sql, 64)
|
||||
}
|
||||
|
||||
func (c *conn) SelectFloat32(sql string) (f float32, err error) {
|
||||
func (c *Connection) SelectFloat32(sql string) (f float32, err error) {
|
||||
var f64 float64
|
||||
f64, err = c.selectFloat(sql, 32)
|
||||
f = float32(f64)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) SelectAllString(sql string) (strings []string, err error) {
|
||||
func (c *Connection) SelectAllString(sql string) (strings []string, err error) {
|
||||
strings = make([]string, 0, 8)
|
||||
onDataRow := func(r *messageReader, _ []fieldDescription) error {
|
||||
strings = append(strings, c.rxDataRowFirstValue(r))
|
||||
|
@ -208,7 +208,7 @@ func (c *conn) SelectAllString(sql string) (strings []string, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (c *conn) SelectAllInt64(sql string) (ints []int64, err error) {
|
||||
func (c *Connection) SelectAllInt64(sql string) (ints []int64, err error) {
|
||||
ints = make([]int64, 0, 8)
|
||||
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
|
||||
var i int64
|
||||
|
@ -220,7 +220,7 @@ func (c *conn) SelectAllInt64(sql string) (ints []int64, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (c *conn) SelectAllInt32(sql string) (ints []int32, err error) {
|
||||
func (c *Connection) SelectAllInt32(sql string) (ints []int32, err error) {
|
||||
ints = make([]int32, 0, 8)
|
||||
onDataRow := func(r *messageReader, fields []fieldDescription) (parseError error) {
|
||||
var i int64
|
||||
|
@ -232,7 +232,7 @@ func (c *conn) SelectAllInt32(sql string) (ints []int32, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (c *conn) SelectAllInt16(sql string) (ints []int16, err error) {
|
||||
func (c *Connection) SelectAllInt16(sql string) (ints []int16, err error) {
|
||||
ints = make([]int16, 0, 8)
|
||||
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
|
||||
var i int64
|
||||
|
@ -244,7 +244,7 @@ func (c *conn) SelectAllInt16(sql string) (ints []int16, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (c *conn) SelectAllFloat64(sql string) (floats []float64, err error) {
|
||||
func (c *Connection) SelectAllFloat64(sql string) (floats []float64, err error) {
|
||||
floats = make([]float64, 0, 8)
|
||||
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
|
||||
var f float64
|
||||
|
@ -256,7 +256,7 @@ func (c *conn) SelectAllFloat64(sql string) (floats []float64, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (c *conn) SelectAllFloat32(sql string) (floats []float32, err error) {
|
||||
func (c *Connection) SelectAllFloat32(sql string) (floats []float32, err error) {
|
||||
floats = make([]float32, 0, 8)
|
||||
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
|
||||
var f float64
|
||||
|
@ -268,7 +268,7 @@ func (c *conn) SelectAllFloat32(sql string) (floats []float32, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (c *conn) sendSimpleQuery(sql string) (err error) {
|
||||
func (c *Connection) sendSimpleQuery(sql string) (err error) {
|
||||
bufSize := 5 + len(sql) + 1 // message identifier (1), message size (4), null string terminator (1)
|
||||
buf := c.getBuf(bufSize)
|
||||
buf[0] = 'Q'
|
||||
|
@ -280,7 +280,7 @@ func (c *conn) sendSimpleQuery(sql string) (err error) {
|
|||
return err
|
||||
}
|
||||
|
||||
func (c *conn) Execute(sql string) (commandTag string, err error) {
|
||||
func (c *Connection) Execute(sql string) (commandTag string, err error) {
|
||||
if err = c.sendSimpleQuery(sql); err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -312,7 +312,7 @@ func (c *conn) Execute(sql string) (commandTag string, err error) {
|
|||
// Processes messages that are not exclusive to one context such as
|
||||
// authentication or query response. The response to these messages
|
||||
// is the same regardless of when they occur.
|
||||
func (c *conn) processContextFreeMsg(t byte, r *messageReader) (err error) {
|
||||
func (c *Connection) processContextFreeMsg(t byte, r *messageReader) (err error) {
|
||||
switch t {
|
||||
case 'S':
|
||||
c.rxParameterStatus(r)
|
||||
|
@ -329,7 +329,7 @@ func (c *conn) processContextFreeMsg(t byte, r *messageReader) (err error) {
|
|||
|
||||
}
|
||||
|
||||
func (c *conn) rxMsg() (t byte, r *messageReader, err error) {
|
||||
func (c *Connection) rxMsg() (t byte, r *messageReader, err error) {
|
||||
var bodySize int32
|
||||
t, bodySize, err = c.rxMsgHeader()
|
||||
if err != nil {
|
||||
|
@ -345,7 +345,7 @@ func (c *conn) rxMsg() (t byte, r *messageReader, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (c *conn) rxMsgHeader() (t byte, bodySize int32, err error) {
|
||||
func (c *Connection) rxMsgHeader() (t byte, bodySize int32, err error) {
|
||||
buf := c.buf[:5]
|
||||
if _, err = io.ReadFull(c.conn, buf); err != nil {
|
||||
return 0, 0, err
|
||||
|
@ -356,13 +356,13 @@ func (c *conn) rxMsgHeader() (t byte, bodySize int32, err error) {
|
|||
return t, bodySize, nil
|
||||
}
|
||||
|
||||
func (c *conn) rxMsgBody(bodySize int32) (buf []byte, err error) {
|
||||
func (c *Connection) rxMsgBody(bodySize int32) (buf []byte, err error) {
|
||||
buf = c.getBuf(int(bodySize))
|
||||
_, err = io.ReadFull(c.conn, buf)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) rxAuthenticationX(r *messageReader) (err error) {
|
||||
func (c *Connection) rxAuthenticationX(r *messageReader) (err error) {
|
||||
code := r.readInt32()
|
||||
switch code {
|
||||
case 0: // AuthenticationOk
|
||||
|
@ -385,13 +385,13 @@ func hexMD5(s string) string {
|
|||
return hex.EncodeToString(hash.Sum(nil))
|
||||
}
|
||||
|
||||
func (c *conn) rxParameterStatus(r *messageReader) {
|
||||
func (c *Connection) rxParameterStatus(r *messageReader) {
|
||||
key := r.readString()
|
||||
value := r.readString()
|
||||
c.runtimeParams[key] = value
|
||||
}
|
||||
|
||||
func (c *conn) rxErrorResponse(r *messageReader) (err PgError) {
|
||||
func (c *Connection) rxErrorResponse(r *messageReader) (err PgError) {
|
||||
for {
|
||||
switch r.readByte() {
|
||||
case 'S':
|
||||
|
@ -410,16 +410,16 @@ func (c *conn) rxErrorResponse(r *messageReader) (err PgError) {
|
|||
panic("Unreachable")
|
||||
}
|
||||
|
||||
func (c *conn) rxBackendKeyData(r *messageReader) {
|
||||
func (c *Connection) rxBackendKeyData(r *messageReader) {
|
||||
c.pid = r.readInt32()
|
||||
c.secretKey = r.readInt32()
|
||||
}
|
||||
|
||||
func (c *conn) rxReadyForQuery(r *messageReader) {
|
||||
func (c *Connection) rxReadyForQuery(r *messageReader) {
|
||||
c.txStatus = r.readByte()
|
||||
}
|
||||
|
||||
func (c *conn) rxRowDescription(r *messageReader) (fields []fieldDescription) {
|
||||
func (c *Connection) rxRowDescription(r *messageReader) (fields []fieldDescription) {
|
||||
fieldCount := r.readInt16()
|
||||
fields = make([]fieldDescription, fieldCount)
|
||||
for i := int16(0); i < fieldCount; i++ {
|
||||
|
@ -435,7 +435,7 @@ func (c *conn) rxRowDescription(r *messageReader) (fields []fieldDescription) {
|
|||
return
|
||||
}
|
||||
|
||||
func (c *conn) rxDataRow(r *messageReader, fields []fieldDescription) (row map[string]string) {
|
||||
func (c *Connection) rxDataRow(r *messageReader, fields []fieldDescription) (row map[string]string) {
|
||||
fieldCount := r.readInt16()
|
||||
|
||||
row = make(map[string]string, fieldCount)
|
||||
|
@ -447,7 +447,7 @@ func (c *conn) rxDataRow(r *messageReader, fields []fieldDescription) (row map[s
|
|||
return
|
||||
}
|
||||
|
||||
func (c *conn) rxDataRowFirstValue(r *messageReader) (s string) {
|
||||
func (c *Connection) rxDataRowFirstValue(r *messageReader) (s string) {
|
||||
r.readInt16() // ignore field count
|
||||
|
||||
// TODO - handle nulls
|
||||
|
@ -456,16 +456,16 @@ func (c *conn) rxDataRowFirstValue(r *messageReader) (s string) {
|
|||
return s
|
||||
}
|
||||
|
||||
func (c *conn) rxCommandComplete(r *messageReader) string {
|
||||
func (c *Connection) rxCommandComplete(r *messageReader) string {
|
||||
return r.readString()
|
||||
}
|
||||
|
||||
func (c *conn) txStartupMessage(msg *startupMessage) (err error) {
|
||||
func (c *Connection) txStartupMessage(msg *startupMessage) (err error) {
|
||||
_, err = c.conn.Write(msg.Bytes())
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) txPasswordMessage(password string) (err error) {
|
||||
func (c *Connection) txPasswordMessage(password string) (err error) {
|
||||
bufSize := 5 + len(password) + 1 // message identifier (1), message size (4), password, null string terminator (1)
|
||||
buf := c.getBuf(bufSize)
|
||||
buf[0] = 'p'
|
||||
|
@ -479,7 +479,7 @@ func (c *conn) txPasswordMessage(password string) (err error) {
|
|||
|
||||
// Gets a []byte of n length. If possible it will reuse the connection buffer
|
||||
// otherwise it will allocate a new buffer
|
||||
func (c *conn) getBuf(n int) (buf []byte) {
|
||||
func (c *Connection) getBuf(n int) (buf []byte) {
|
||||
if n <= cap(c.buf) {
|
||||
buf = c.buf[:n]
|
||||
} else {
|
||||
|
|
38
conn_test.go
38
conn_test.go
|
@ -5,18 +5,18 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
var sharedConn *conn
|
||||
var SharedConnection *Connection
|
||||
|
||||
func getSharedConn() (c *conn) {
|
||||
if sharedConn == nil {
|
||||
func getSharedConnection() (c *Connection) {
|
||||
if SharedConnection == nil {
|
||||
var err error
|
||||
sharedConn, err = Connect(map[string]string{"socket": "/private/tmp/.s.PGSQL.5432", "user": "pgx_none", "database": "pgx_test"})
|
||||
SharedConnection, err = Connect(map[string]string{"socket": "/private/tmp/.s.PGSQL.5432", "user": "pgx_none", "database": "pgx_test"})
|
||||
if err != nil {
|
||||
panic("Unable to establish connection")
|
||||
}
|
||||
|
||||
}
|
||||
return sharedConn
|
||||
return SharedConnection
|
||||
}
|
||||
|
||||
func TestConnect(t *testing.T) {
|
||||
|
@ -87,7 +87,7 @@ func TestConnectWithMD5Password(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestExecute(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
conn := getSharedConnection()
|
||||
|
||||
results, err := conn.Execute("create temporary table foo(id serial primary key);")
|
||||
if err != nil {
|
||||
|
@ -116,7 +116,7 @@ func TestExecute(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestQuery(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
conn := getSharedConnection()
|
||||
|
||||
rows, err := conn.Query("select 'Jack' as name")
|
||||
if err != nil {
|
||||
|
@ -133,7 +133,7 @@ func TestQuery(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSelectString(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
conn := getSharedConnection()
|
||||
|
||||
s, err := conn.SelectString("select 'foo'")
|
||||
if err != nil {
|
||||
|
@ -146,7 +146,7 @@ func TestSelectString(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSelectInt64(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
conn := getSharedConnection()
|
||||
|
||||
i, err := conn.SelectInt64("select 1")
|
||||
if err != nil {
|
||||
|
@ -169,7 +169,7 @@ func TestSelectInt64(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSelectInt32(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
conn := getSharedConnection()
|
||||
|
||||
i, err := conn.SelectInt32("select 1")
|
||||
if err != nil {
|
||||
|
@ -192,7 +192,7 @@ func TestSelectInt32(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSelectInt16(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
conn := getSharedConnection()
|
||||
|
||||
i, err := conn.SelectInt16("select 1")
|
||||
if err != nil {
|
||||
|
@ -215,7 +215,7 @@ func TestSelectInt16(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSelectFloat64(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
conn := getSharedConnection()
|
||||
|
||||
f, err := conn.SelectFloat64("select 1.23")
|
||||
if err != nil {
|
||||
|
@ -228,7 +228,7 @@ func TestSelectFloat64(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSelectFloat32(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
conn := getSharedConnection()
|
||||
|
||||
f, err := conn.SelectFloat32("select 1.23")
|
||||
if err != nil {
|
||||
|
@ -241,7 +241,7 @@ func TestSelectFloat32(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSelectAllString(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
conn := getSharedConnection()
|
||||
|
||||
s, err := conn.SelectAllString("select * from (values ('Matthew'), ('Mark'), ('Luke'), ('John')) t")
|
||||
if err != nil {
|
||||
|
@ -254,7 +254,7 @@ func TestSelectAllString(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSelectAllInt64(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
conn := getSharedConnection()
|
||||
|
||||
i, err := conn.SelectAllInt64("select * from (values (1), (2)) t")
|
||||
if err != nil {
|
||||
|
@ -277,7 +277,7 @@ func TestSelectAllInt64(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSelectAllInt32(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
conn := getSharedConnection()
|
||||
|
||||
i, err := conn.SelectAllInt32("select * from (values (1), (2)) t")
|
||||
if err != nil {
|
||||
|
@ -300,7 +300,7 @@ func TestSelectAllInt32(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSelectAllInt16(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
conn := getSharedConnection()
|
||||
|
||||
i, err := conn.SelectAllInt16("select * from (values (1), (2)) t")
|
||||
if err != nil {
|
||||
|
@ -323,7 +323,7 @@ func TestSelectAllInt16(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSelectAllFloat64(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
conn := getSharedConnection()
|
||||
|
||||
f, err := conn.SelectAllFloat64("select * from (values (1.23), (4.56)) t")
|
||||
if err != nil {
|
||||
|
@ -336,7 +336,7 @@ func TestSelectAllFloat64(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSelectAllFloat32(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
conn := getSharedConnection()
|
||||
|
||||
f, err := conn.SelectAllFloat32("select * from (values (1.23), (4.56)) t")
|
||||
if err != nil {
|
||||
|
|
|
@ -9,12 +9,12 @@ import (
|
|||
|
||||
var literalPattern *regexp.Regexp = regexp.MustCompile(`\$\d+`)
|
||||
|
||||
func (c *conn) QuoteString(input string) (output string) {
|
||||
func (c *Connection) QuoteString(input string) (output string) {
|
||||
output = "'" + strings.Replace(input, "'", "''", -1) + "'"
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) SanitizeSql(sql string, args ...interface{}) (output string) {
|
||||
func (c *Connection) SanitizeSql(sql string, args ...interface{}) (output string) {
|
||||
replacer := func(match string) (replacement string) {
|
||||
n, _ := strconv.ParseInt(match[1:], 10, 0)
|
||||
switch arg := args[n-1].(type) {
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
)
|
||||
|
||||
func TestQuoteString(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
conn := getSharedConnection()
|
||||
|
||||
if conn.QuoteString("test") != "'test'" {
|
||||
t.Error("Failed to quote string")
|
||||
|
@ -17,7 +17,7 @@ func TestQuoteString(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSanitizeSql(t *testing.T) {
|
||||
conn := getSharedConn()
|
||||
conn := getSharedConnection()
|
||||
|
||||
if conn.SanitizeSql("select $1, $2, $3", "Jack's", 42, 1.23) != "select 'Jack''s', 42, 1.23" {
|
||||
t.Error("Failed to sanitize sql")
|
||||
|
|
Loading…
Reference in New Issue