From 351eb8ba679c66de3a67db7da9e0cd06f6fecda8 Mon Sep 17 00:00:00 2001 From: Jack Christensen <jack@jackchristensen.com> Date: Mon, 6 Feb 2017 19:39:34 -0600 Subject: [PATCH] Initial proof-of-concept database/sql context support --- conn.go | 52 ++++++++++++++++++++++++++++++++++++++++----------- stdlib/sql.go | 46 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+), 11 deletions(-) diff --git a/conn.go b/conn.go index f7c06014..b8131716 100644 --- a/conn.go +++ b/conn.go @@ -619,6 +619,41 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { // name and sql arguments. This allows a code path to PrepareEx and Query/Exec without // concern for if the statement has already been prepared. func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { + return c.PrepareExContext(context.Background(), name, sql, opts) + +} + +func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + doneChan := make(chan struct{}) + closedChan := make(chan struct{}) + + go func() { + select { + case <-ctx.Done(): + c.cancelQuery() + c.Close() + closedChan <- struct{}{} + case <-doneChan: + } + }() + + ps, err = c.prepareEx(name, sql, opts) + + select { + case <-closedChan: + return nil, ctx.Err() + case doneChan <- struct{}{}: + return ps, err + } +} + +func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { if name != "" { if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql { return ps, nil @@ -1349,29 +1384,24 @@ func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interfa } doneChan := make(chan struct{}) - closedChan := make(chan bool) + closedChan := make(chan struct{}) go func() { select { case <-ctx.Done(): c.cancelQuery() c.Close() - <-doneChan - closedChan <- true + closedChan <- struct{}{} case <-doneChan: - closedChan <- false } }() commandTag, err = c.Exec(sql, arguments...) - // Signal cancelation goroutine that operation is done - doneChan <- struct{}{} - - // If c was closed due to context cancelation then return context err - if <-closedChan { + select { + case <-closedChan: return "", ctx.Err() + case doneChan <- struct{}{}: + return commandTag, err } - - return commandTag, err } diff --git a/stdlib/sql.go b/stdlib/sql.go index 610aefd4..74218a7b 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -44,6 +44,7 @@ package stdlib import ( + "context" "database/sql" "database/sql/driver" "errors" @@ -211,6 +212,21 @@ func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) { return c.queryPrepared("", argsV) } +func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) { + if !c.conn.IsAlive() { + return nil, driver.ErrBadConn + } + + ps, err := c.conn.PrepareExContext(ctx, "", query, nil) + if err != nil { + return nil, err + } + + restrictBinaryToDatabaseSqlTypes(ps) + + return c.queryPreparedContext(ctx, "", argsV) +} + func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) { if !c.conn.IsAlive() { return nil, driver.ErrBadConn @@ -226,6 +242,24 @@ func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, er return &Rows{rows: rows}, nil } +func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []driver.NamedValue) (driver.Rows, error) { + if !c.conn.IsAlive() { + return nil, driver.ErrBadConn + } + + args := namedValueToInterface(argsV) + + rows, err := c.conn.QueryContext(ctx, name, args...) + if err != nil { + fmt.Println(err) + return nil, err + } + + fmt.Println("ere") + + return &Rows{rows: rows}, nil +} + // Anything that isn't a database/sql compatible type needs to be forced to // text format so that pgx.Rows.Values doesn't decode it into a native type // (e.g. []int32) @@ -318,6 +352,18 @@ func valueToInterface(argsV []driver.Value) []interface{} { return args } +func namedValueToInterface(argsV []driver.NamedValue) []interface{} { + args := make([]interface{}, 0, len(argsV)) + for _, v := range argsV { + if v.Value != nil { + args = append(args, v.Value.(interface{})) + } else { + args = append(args, nil) + } + } + return args +} + type Tx struct { conn *pgx.Conn }