From 351eb8ba679c66de3a67db7da9e0cd06f6fecda8 Mon Sep 17 00:00:00 2001
From: Jack Christensen <jack@jackchristensen.com>
Date: Mon, 6 Feb 2017 19:39:34 -0600
Subject: [PATCH] Initial proof-of-concept database/sql context support

---
 conn.go       | 52 ++++++++++++++++++++++++++++++++++++++++-----------
 stdlib/sql.go | 46 +++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 87 insertions(+), 11 deletions(-)

diff --git a/conn.go b/conn.go
index f7c06014..b8131716 100644
--- a/conn.go
+++ b/conn.go
@@ -619,6 +619,41 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
 // name and sql arguments. This allows a code path to PrepareEx and Query/Exec without
 // concern for if the statement has already been prepared.
 func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
+	return c.PrepareExContext(context.Background(), name, sql, opts)
+
+}
+
+func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
+	select {
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	default:
+	}
+
+	doneChan := make(chan struct{})
+	closedChan := make(chan struct{})
+
+	go func() {
+		select {
+		case <-ctx.Done():
+			c.cancelQuery()
+			c.Close()
+			closedChan <- struct{}{}
+		case <-doneChan:
+		}
+	}()
+
+	ps, err = c.prepareEx(name, sql, opts)
+
+	select {
+	case <-closedChan:
+		return nil, ctx.Err()
+	case doneChan <- struct{}{}:
+		return ps, err
+	}
+}
+
+func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
 	if name != "" {
 		if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql {
 			return ps, nil
@@ -1349,29 +1384,24 @@ func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interfa
 	}
 
 	doneChan := make(chan struct{})
-	closedChan := make(chan bool)
+	closedChan := make(chan struct{})
 
 	go func() {
 		select {
 		case <-ctx.Done():
 			c.cancelQuery()
 			c.Close()
-			<-doneChan
-			closedChan <- true
+			closedChan <- struct{}{}
 		case <-doneChan:
-			closedChan <- false
 		}
 	}()
 
 	commandTag, err = c.Exec(sql, arguments...)
 
-	// Signal cancelation goroutine that operation is done
-	doneChan <- struct{}{}
-
-	// If c was closed due to context cancelation then return context err
-	if <-closedChan {
+	select {
+	case <-closedChan:
 		return "", ctx.Err()
+	case doneChan <- struct{}{}:
+		return commandTag, err
 	}
-
-	return commandTag, err
 }
diff --git a/stdlib/sql.go b/stdlib/sql.go
index 610aefd4..74218a7b 100644
--- a/stdlib/sql.go
+++ b/stdlib/sql.go
@@ -44,6 +44,7 @@
 package stdlib
 
 import (
+	"context"
 	"database/sql"
 	"database/sql/driver"
 	"errors"
@@ -211,6 +212,21 @@ func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) {
 	return c.queryPrepared("", argsV)
 }
 
+func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) {
+	if !c.conn.IsAlive() {
+		return nil, driver.ErrBadConn
+	}
+
+	ps, err := c.conn.PrepareExContext(ctx, "", query, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	restrictBinaryToDatabaseSqlTypes(ps)
+
+	return c.queryPreparedContext(ctx, "", argsV)
+}
+
 func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) {
 	if !c.conn.IsAlive() {
 		return nil, driver.ErrBadConn
@@ -226,6 +242,24 @@ func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, er
 	return &Rows{rows: rows}, nil
 }
 
+func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []driver.NamedValue) (driver.Rows, error) {
+	if !c.conn.IsAlive() {
+		return nil, driver.ErrBadConn
+	}
+
+	args := namedValueToInterface(argsV)
+
+	rows, err := c.conn.QueryContext(ctx, name, args...)
+	if err != nil {
+		fmt.Println(err)
+		return nil, err
+	}
+
+	fmt.Println("ere")
+
+	return &Rows{rows: rows}, nil
+}
+
 // Anything that isn't a database/sql compatible type needs to be forced to
 // text format so that pgx.Rows.Values doesn't decode it into a native type
 // (e.g. []int32)
@@ -318,6 +352,18 @@ func valueToInterface(argsV []driver.Value) []interface{} {
 	return args
 }
 
+func namedValueToInterface(argsV []driver.NamedValue) []interface{} {
+	args := make([]interface{}, 0, len(argsV))
+	for _, v := range argsV {
+		if v.Value != nil {
+			args = append(args, v.Value.(interface{}))
+		} else {
+			args = append(args, nil)
+		}
+	}
+	return args
+}
+
 type Tx struct {
 	conn *pgx.Conn
 }