diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 08fce16e..2a3c5936 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -438,6 +438,8 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ psd := &PreparedStatementDescription{Name: name, SQL: sql} + var parseErr error + readloop: for { msg, err := pgConn.ReceiveMessage() @@ -454,14 +456,17 @@ readloop: psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields)) copy(psd.Fields, msg.Fields) case *pgproto3.ErrorResponse: - go pgConn.recoverFromTimeout() - return nil, errorResponseToPgError(msg) + parseErr = errorResponseToPgError(msg) case *pgproto3.ReadyForQuery: break readloop } } <-pgConn.controller + + if parseErr != nil { + return nil, parseErr + } return psd, nil } diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 9452ffc0..90f99325 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -235,6 +235,20 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { } } +func TestConnPrepareFailure(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + psd, err := pgConn.Prepare(context.Background(), "ps1", "SYNTAX ERROR", nil) + require.Nil(t, psd) + require.NotNil(t, err) + + ensureConnValid(t, pgConn) +} + func TestConnExec(t *testing.T) { t.Parallel()