Extract writeAll

pull/483/head
Jack Christensen 2019-01-26 12:33:51 -06:00
parent 5b09fe1e0b
commit 5b9108a20c
1 changed files with 21 additions and 66 deletions

View File

@ -398,6 +398,15 @@ func (pgConn *PgConn) hardClose() error {
return pgConn.conn.Close()
}
// writeAll writes the entire buffer successfully or it hard closes the connection.
func (pgConn *PgConn) writeAll(buf []byte) error {
n, err := pgConn.conn.Write(buf)
if err != nil && n > 0 {
pgConn.hardClose()
}
return err
}
// ParameterStatus returns the value of a parameter reported by the server (e.g.
// server_version). Returns an empty string for unknown parameters.
func (pgConn *PgConn) ParameterStatus(key string) string {
@ -482,15 +491,8 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf)
buf = (&pgproto3.Sync{}).Encode(buf)
n, err := pgConn.conn.Write(buf)
err := pgConn.writeAll(buf)
if err != nil {
// Partially sent messages are a fatal error for the connection.
if n > 0 {
// Close connection because cannot recover from partially sent message.
pgConn.conn.Close()
pgConn.closed = true
}
return nil, preferContextOverNetTimeoutError(ctx, err)
}
@ -654,15 +656,8 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
var buf []byte
buf = (&pgproto3.Query{String: sql}).Encode(buf)
n, err := pgConn.conn.Write(buf)
err := pgConn.writeAll(buf)
if err != nil {
// Partially sent messages are a fatal error for the connection.
if n > 0 {
// Close connection because cannot recover from partially sent message.
pgConn.conn.Close()
pgConn.closed = true
}
multiResult.cleanupContextDeadline()
multiResult.closed = true
multiResult.err = preferContextOverNetTimeoutError(ctx, err)
@ -718,15 +713,8 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
buf = (&pgproto3.Execute{}).Encode(buf)
buf = (&pgproto3.Sync{}).Encode(buf)
n, err := pgConn.conn.Write(buf)
err := pgConn.writeAll(buf)
if err != nil {
// Partially sent messages are a fatal error for the connection.
if n > 0 {
// Close connection because cannot recover from partially sent message.
pgConn.conn.Close()
pgConn.closed = true
}
result.concludeCommand("", err)
result.cleanupContextDeadline()
result.closed = true
@ -770,15 +758,8 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
buf = (&pgproto3.Execute{}).Encode(buf)
buf = (&pgproto3.Sync{}).Encode(buf)
n, err := pgConn.conn.Write(buf)
err := pgConn.writeAll(buf)
if err != nil {
// Partially sent messages are a fatal error for the connection.
if n > 0 {
// Close connection because cannot recover from partially sent message.
pgConn.conn.Close()
pgConn.closed = true
}
result.concludeCommand("", err)
result.cleanupContextDeadline()
result.closed = true
@ -801,15 +782,8 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
var buf []byte
buf = (&pgproto3.Query{String: sql}).Encode(buf)
n, err := pgConn.conn.Write(buf)
err := pgConn.writeAll(buf)
if err != nil {
// Partially sent messages are a fatal error for the connection.
if n > 0 {
// Close connection because cannot recover from partially sent message.
pgConn.conn.Close()
pgConn.closed = true
}
cleanupContextDeadline()
<-pgConn.controller
@ -869,15 +843,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
var buf []byte
buf = (&pgproto3.Query{String: sql}).Encode(buf)
n, err := pgConn.conn.Write(buf)
err := pgConn.writeAll(buf)
if err != nil {
// Partially sent messages are a fatal error for the connection.
if n > 0 {
// Close connection because cannot recover from partially sent message.
pgConn.conn.Close()
pgConn.closed = true
}
cleanupContextDeadline()
<-pgConn.controller
@ -913,25 +880,21 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
}
// Send copy data
buf = make([]byte, 0, 65536)
buf = make([]byte, 0, 20000)
// buf = make([]byte, 0, 65536)
buf = append(buf, 'd')
sp := len(buf)
var readErr error
signalMessageChan := pgConn.signalMessage()
for readErr == nil && pgErr == nil {
var n int
n, readErr = r.Read(buf[5:cap(buf)])
if n > 0 {
buf = buf[0 : n+5]
pgio.SetInt32(buf[sp:], int32(n+4))
n, err = pgConn.conn.Write(buf)
err = pgConn.writeAll(buf)
if err != nil {
// Partially sent messages are a fatal error for the connection.
if n > 0 {
// Close connection because cannot recover from partially sent message.
pgConn.conn.Close()
pgConn.closed = true
}
cleanupContextDeadline()
if err, ok := err.(net.Error); ok && err.Timeout() {
go pgConn.recoverFromTimeoutDuringCopyFrom()
@ -975,8 +938,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
}
_, err = pgConn.conn.Write(buf)
if err != nil {
pgConn.conn.Close()
pgConn.closed = true
pgConn.hardClose()
cleanupContextDeadline()
<-pgConn.controller
@ -1414,15 +1376,8 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn)
batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)
n, err := pgConn.conn.Write(batch.buf)
err := pgConn.writeAll(batch.buf)
if err != nil {
// Partially sent messages are a fatal error for the connection.
if n > 0 {
// Close connection because cannot recover from partially sent message.
pgConn.conn.Close()
pgConn.closed = true
}
multiResult.cleanupContextDeadline()
multiResult.closed = true
multiResult.err = preferContextOverNetTimeoutError(ctx, err)