diff --git a/pgconn.go b/pgconn.go index 6e1fb7e3..3157f17e 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1151,7 +1151,10 @@ func (rr *ResultReader) Close() (CommandTag, error) { return nil, rr.err } - switch msg.(type) { + switch msg := msg.(type) { + // Detect a deferred constraint violation where the ErrorResponse is sent after CommandComplete. + case *pgproto3.ErrorResponse: + rr.err = errorResponseToPgError(msg) case *pgproto3.ReadyForQuery: rr.pgConn.contextWatcher.Unwatch() rr.pgConn.unlock() diff --git a/pgconn_test.go b/pgconn_test.go index feb78641..1b90b9d2 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -381,6 +381,34 @@ func TestConnExecMultipleQueriesError(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnExecDeferredError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) + + _, err = pgConn.Exec(context.Background(), `update t set n=n+1 where id='b' returning *`).ReadAll() + require.NotNil(t, err) + + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) + + ensureConnValid(t, pgConn) +} + func TestConnExecContextCanceled(t *testing.T) { t.Parallel() @@ -437,6 +465,33 @@ func TestConnExecParams(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnExecParamsDeferredError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) + + result := pgConn.ExecParams(context.Background(), `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read() + require.NotNil(t, result.Err) + var pgErr *pgconn.PgError + require.True(t, errors.As(result.Err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) + + ensureConnValid(t, pgConn) +} + func TestConnExecParamsMaxNumberOfParams(t *testing.T) { t.Parallel() @@ -683,6 +738,36 @@ func TestConnExecBatch(t *testing.T) { assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) } +func TestConnExecBatchDeferredError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) + + batch := &pgconn.Batch{} + + batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil) + _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.NotNil(t, err) + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) + + ensureConnValid(t, pgConn) +} + func TestConnExecBatchPrecanceled(t *testing.T) { t.Parallel()