diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 512c9a88..c785f367 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -404,19 +404,6 @@ func (pgConn *PgConn) IsAlive() bool { return !pgConn.closed } -// writeAll writes the entire buffer. The connection is hard closed on a partial write or a non-temporary error. -func (pgConn *PgConn) writeAll(buf []byte) error { - n, err := pgConn.conn.Write(buf) - if err != nil { - if n > 0 { - pgConn.hardClose() - } else if ne, ok := err.(net.Error); ok && !ne.Temporary() { - pgConn.hardClose() - } - } - return err -} - // ParameterStatus returns the value of a parameter reported by the server (e.g. // server_version). Returns an empty string for unknown parameters. func (pgConn *PgConn) ParameterStatus(key string) string { @@ -501,8 +488,9 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) - err := pgConn.writeAll(buf) + _, err := pgConn.conn.Write(buf) if err != nil { + pgConn.hardClose() return nil, preferContextOverNetTimeoutError(ctx, err) } @@ -666,8 +654,9 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { var buf []byte buf = (&pgproto3.Query{String: sql}).Encode(buf) - err := pgConn.writeAll(buf) + _, err := pgConn.conn.Write(buf) if err != nil { + pgConn.hardClose() multiResult.cleanupContextDeadline() multiResult.closed = true multiResult.err = preferContextOverNetTimeoutError(ctx, err) @@ -723,8 +712,9 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) - err := pgConn.writeAll(buf) + _, err := pgConn.conn.Write(buf) if err != nil { + pgConn.hardClose() result.concludeCommand("", err) result.cleanupContextDeadline() result.closed = true @@ -768,8 +758,9 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) - err := pgConn.writeAll(buf) + _, err := pgConn.conn.Write(buf) if err != nil { + pgConn.hardClose() result.concludeCommand("", err) result.cleanupContextDeadline() result.closed = true @@ -792,8 +783,9 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm var buf []byte buf = (&pgproto3.Query{String: sql}).Encode(buf) - err := pgConn.writeAll(buf) + _, err := pgConn.conn.Write(buf) if err != nil { + pgConn.hardClose() cleanupContextDeadline() <-pgConn.controller @@ -853,8 +845,9 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co var buf []byte buf = (&pgproto3.Query{String: sql}).Encode(buf) - err := pgConn.writeAll(buf) + _, err := pgConn.conn.Write(buf) if err != nil { + pgConn.hardClose() cleanupContextDeadline() <-pgConn.controller @@ -903,14 +896,11 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co buf = buf[0 : n+5] pgio.SetInt32(buf[sp:], int32(n+4)) - err = pgConn.writeAll(buf) + _, err = pgConn.conn.Write(buf) if err != nil { + pgConn.hardClose() cleanupContextDeadline() - if err, ok := err.(net.Error); ok && err.Timeout() { - go pgConn.recoverFromTimeoutDuringCopyFrom() - } else { - <-pgConn.controller - } + <-pgConn.controller return "", preferContextOverNetTimeoutError(ctx, err) } @@ -1386,8 +1376,9 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) - err := pgConn.writeAll(batch.buf) + _, err := pgConn.conn.Write(batch.buf) if err != nil { + pgConn.hardClose() multiResult.cleanupContextDeadline() multiResult.closed = true multiResult.err = preferContextOverNetTimeoutError(ctx, err) diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 7fb01e2c..dbf9b840 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -863,7 +863,7 @@ func TestConnCopyFromCanceled(t *testing.T) { assert.Equal(t, int64(0), ct.RowsAffected()) require.Equal(t, context.DeadlineExceeded, err) - ensureConnValid(t, pgConn) + assert.False(t, pgConn.IsAlive()) } func TestConnCopyFromGzipReader(t *testing.T) {