diff --git a/connection.go b/connection.go index 56e82f5d..96963dfb 100644 --- a/connection.go +++ b/connection.go @@ -142,7 +142,9 @@ func (c *Connection) SelectFunc(sql string, onDataRow func(*DataRowReader) error case commandComplete: case bindComplete: default: - err = c.processContextFreeMsg(t, r) + if e := c.processContextFreeMsg(t, r); e != nil && err == nil { + err = e + } } } else { return rxErr @@ -406,9 +408,7 @@ func (c *Connection) Execute(sql string, arguments ...interface{}) (commandTag s } for { - var t byte - var r *MessageReader - if t, r, err = c.rxMsg(); err == nil { + if t, r, rxErr := c.rxMsg(); rxErr == nil { switch t { case readyForQuery: return @@ -418,12 +418,12 @@ func (c *Connection) Execute(sql string, arguments ...interface{}) (commandTag s case commandComplete: commandTag = r.ReadString() default: - if err = c.processContextFreeMsg(t, r); err != nil { - return + if e := c.processContextFreeMsg(t, r); e != nil && err == nil { + err = e } } } else { - return + return "", rxErr } } } diff --git a/connection_test.go b/connection_test.go index 7206c2b3..3881aafd 100644 --- a/connection_test.go +++ b/connection_test.go @@ -179,7 +179,22 @@ func TestExecute(t *testing.T) { if results != "SELECT 1" { t.Errorf("Unexpected results from Execute: %v", results) } +} +func TestExecuteFailure(t *testing.T) { + conn, err := Connect(*defaultConnectionParameters) + if err != nil { + t.Fatalf("Unable to establish connection: %v", err) + } + defer conn.Close() + + if _, err := conn.Execute("select;"); err == nil { + t.Fatal("Expected SQL syntax error") + } + + if _, err := conn.SelectValue("select 1"); err != nil { + t.Fatalf("Execute failure appears to have broken connection: %v", err) + } } func TestSelectFunc(t *testing.T) { @@ -204,6 +219,23 @@ func TestSelectFunc(t *testing.T) { } } +func TestSelectFuncFailure(t *testing.T) { + conn, err := Connect(*defaultConnectionParameters) + if err != nil { + t.Fatalf("Unable to establish connection: %v", err) + } + defer conn.Close() + + // using SelectValue as it delegates to SelectFunc and is easier to work with + if _, err := conn.SelectValue("select;"); err == nil { + t.Fatal("Expected SQL syntax error") + } + + if _, err := conn.SelectValue("select 1"); err != nil { + t.Fatalf("SelectFunc failure appears to have broken connection: %v", err) + } +} + func TestSelectRows(t *testing.T) { conn := getSharedConnection()