All writes errors are fatal

pull/483/head
Jack Christensen 2019-01-28 23:13:03 -06:00
parent d3a2c1c107
commit 4eff30fa70
2 changed files with 18 additions and 27 deletions

View File

@ -404,19 +404,6 @@ func (pgConn *PgConn) IsAlive() bool {
return !pgConn.closed
}
// writeAll writes the entire buffer. The connection is hard closed on a partial write or a non-temporary error.
func (pgConn *PgConn) writeAll(buf []byte) error {
n, err := pgConn.conn.Write(buf)
if err != nil {
if n > 0 {
pgConn.hardClose()
} else if ne, ok := err.(net.Error); ok && !ne.Temporary() {
pgConn.hardClose()
}
}
return err
}
// ParameterStatus returns the value of a parameter reported by the server (e.g.
// server_version). Returns an empty string for unknown parameters.
func (pgConn *PgConn) ParameterStatus(key string) string {
@ -501,8 +488,9 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf)
buf = (&pgproto3.Sync{}).Encode(buf)
err := pgConn.writeAll(buf)
_, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
return nil, preferContextOverNetTimeoutError(ctx, err)
}
@ -666,8 +654,9 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
var buf []byte
buf = (&pgproto3.Query{String: sql}).Encode(buf)
err := pgConn.writeAll(buf)
_, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
multiResult.cleanupContextDeadline()
multiResult.closed = true
multiResult.err = preferContextOverNetTimeoutError(ctx, err)
@ -723,8 +712,9 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
buf = (&pgproto3.Execute{}).Encode(buf)
buf = (&pgproto3.Sync{}).Encode(buf)
err := pgConn.writeAll(buf)
_, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
result.concludeCommand("", err)
result.cleanupContextDeadline()
result.closed = true
@ -768,8 +758,9 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
buf = (&pgproto3.Execute{}).Encode(buf)
buf = (&pgproto3.Sync{}).Encode(buf)
err := pgConn.writeAll(buf)
_, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
result.concludeCommand("", err)
result.cleanupContextDeadline()
result.closed = true
@ -792,8 +783,9 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
var buf []byte
buf = (&pgproto3.Query{String: sql}).Encode(buf)
err := pgConn.writeAll(buf)
_, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
cleanupContextDeadline()
<-pgConn.controller
@ -853,8 +845,9 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
var buf []byte
buf = (&pgproto3.Query{String: sql}).Encode(buf)
err := pgConn.writeAll(buf)
_, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
cleanupContextDeadline()
<-pgConn.controller
@ -903,14 +896,11 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
buf = buf[0 : n+5]
pgio.SetInt32(buf[sp:], int32(n+4))
err = pgConn.writeAll(buf)
_, err = pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
cleanupContextDeadline()
if err, ok := err.(net.Error); ok && err.Timeout() {
go pgConn.recoverFromTimeoutDuringCopyFrom()
} else {
<-pgConn.controller
}
<-pgConn.controller
return "", preferContextOverNetTimeoutError(ctx, err)
}
@ -1386,8 +1376,9 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn)
batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)
err := pgConn.writeAll(batch.buf)
_, err := pgConn.conn.Write(batch.buf)
if err != nil {
pgConn.hardClose()
multiResult.cleanupContextDeadline()
multiResult.closed = true
multiResult.err = preferContextOverNetTimeoutError(ctx, err)

View File

@ -863,7 +863,7 @@ func TestConnCopyFromCanceled(t *testing.T) {
assert.Equal(t, int64(0), ct.RowsAffected())
require.Equal(t, context.DeadlineExceeded, err)
ensureConnValid(t, pgConn)
assert.False(t, pgConn.IsAlive())
}
func TestConnCopyFromGzipReader(t *testing.T) {