Extract writeAll

pull/483/head
Jack Christensen 2019-01-26 12:33:51 -06:00
parent 5b09fe1e0b
commit 5b9108a20c
1 changed files with 21 additions and 66 deletions

View File

@ -398,6 +398,15 @@ func (pgConn *PgConn) hardClose() error {
return pgConn.conn.Close() return pgConn.conn.Close()
} }
// writeAll writes the entire buffer successfully or it hard closes the connection.
func (pgConn *PgConn) writeAll(buf []byte) error {
n, err := pgConn.conn.Write(buf)
if err != nil && n > 0 {
pgConn.hardClose()
}
return err
}
// ParameterStatus returns the value of a parameter reported by the server (e.g. // ParameterStatus returns the value of a parameter reported by the server (e.g.
// server_version). Returns an empty string for unknown parameters. // server_version). Returns an empty string for unknown parameters.
func (pgConn *PgConn) ParameterStatus(key string) string { func (pgConn *PgConn) ParameterStatus(key string) string {
@ -482,15 +491,8 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf) buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf)
buf = (&pgproto3.Sync{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf)
n, err := pgConn.conn.Write(buf) err := pgConn.writeAll(buf)
if err != nil { if err != nil {
// Partially sent messages are a fatal error for the connection.
if n > 0 {
// Close connection because cannot recover from partially sent message.
pgConn.conn.Close()
pgConn.closed = true
}
return nil, preferContextOverNetTimeoutError(ctx, err) return nil, preferContextOverNetTimeoutError(ctx, err)
} }
@ -654,15 +656,8 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
var buf []byte var buf []byte
buf = (&pgproto3.Query{String: sql}).Encode(buf) buf = (&pgproto3.Query{String: sql}).Encode(buf)
n, err := pgConn.conn.Write(buf) err := pgConn.writeAll(buf)
if err != nil { if err != nil {
// Partially sent messages are a fatal error for the connection.
if n > 0 {
// Close connection because cannot recover from partially sent message.
pgConn.conn.Close()
pgConn.closed = true
}
multiResult.cleanupContextDeadline() multiResult.cleanupContextDeadline()
multiResult.closed = true multiResult.closed = true
multiResult.err = preferContextOverNetTimeoutError(ctx, err) multiResult.err = preferContextOverNetTimeoutError(ctx, err)
@ -718,15 +713,8 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Execute{}).Encode(buf)
buf = (&pgproto3.Sync{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf)
n, err := pgConn.conn.Write(buf) err := pgConn.writeAll(buf)
if err != nil { if err != nil {
// Partially sent messages are a fatal error for the connection.
if n > 0 {
// Close connection because cannot recover from partially sent message.
pgConn.conn.Close()
pgConn.closed = true
}
result.concludeCommand("", err) result.concludeCommand("", err)
result.cleanupContextDeadline() result.cleanupContextDeadline()
result.closed = true result.closed = true
@ -770,15 +758,8 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Execute{}).Encode(buf)
buf = (&pgproto3.Sync{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf)
n, err := pgConn.conn.Write(buf) err := pgConn.writeAll(buf)
if err != nil { if err != nil {
// Partially sent messages are a fatal error for the connection.
if n > 0 {
// Close connection because cannot recover from partially sent message.
pgConn.conn.Close()
pgConn.closed = true
}
result.concludeCommand("", err) result.concludeCommand("", err)
result.cleanupContextDeadline() result.cleanupContextDeadline()
result.closed = true result.closed = true
@ -801,15 +782,8 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
var buf []byte var buf []byte
buf = (&pgproto3.Query{String: sql}).Encode(buf) buf = (&pgproto3.Query{String: sql}).Encode(buf)
n, err := pgConn.conn.Write(buf) err := pgConn.writeAll(buf)
if err != nil { if err != nil {
// Partially sent messages are a fatal error for the connection.
if n > 0 {
// Close connection because cannot recover from partially sent message.
pgConn.conn.Close()
pgConn.closed = true
}
cleanupContextDeadline() cleanupContextDeadline()
<-pgConn.controller <-pgConn.controller
@ -869,15 +843,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
var buf []byte var buf []byte
buf = (&pgproto3.Query{String: sql}).Encode(buf) buf = (&pgproto3.Query{String: sql}).Encode(buf)
n, err := pgConn.conn.Write(buf) err := pgConn.writeAll(buf)
if err != nil { if err != nil {
// Partially sent messages are a fatal error for the connection.
if n > 0 {
// Close connection because cannot recover from partially sent message.
pgConn.conn.Close()
pgConn.closed = true
}
cleanupContextDeadline() cleanupContextDeadline()
<-pgConn.controller <-pgConn.controller
@ -913,25 +880,21 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
} }
// Send copy data // Send copy data
buf = make([]byte, 0, 65536) buf = make([]byte, 0, 20000)
// buf = make([]byte, 0, 65536)
buf = append(buf, 'd') buf = append(buf, 'd')
sp := len(buf) sp := len(buf)
var readErr error var readErr error
signalMessageChan := pgConn.signalMessage() signalMessageChan := pgConn.signalMessage()
for readErr == nil && pgErr == nil { for readErr == nil && pgErr == nil {
var n int
n, readErr = r.Read(buf[5:cap(buf)]) n, readErr = r.Read(buf[5:cap(buf)])
if n > 0 { if n > 0 {
buf = buf[0 : n+5] buf = buf[0 : n+5]
pgio.SetInt32(buf[sp:], int32(n+4)) pgio.SetInt32(buf[sp:], int32(n+4))
n, err = pgConn.conn.Write(buf) err = pgConn.writeAll(buf)
if err != nil { if err != nil {
// Partially sent messages are a fatal error for the connection.
if n > 0 {
// Close connection because cannot recover from partially sent message.
pgConn.conn.Close()
pgConn.closed = true
}
cleanupContextDeadline() cleanupContextDeadline()
if err, ok := err.(net.Error); ok && err.Timeout() { if err, ok := err.(net.Error); ok && err.Timeout() {
go pgConn.recoverFromTimeoutDuringCopyFrom() go pgConn.recoverFromTimeoutDuringCopyFrom()
@ -975,8 +938,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
} }
_, err = pgConn.conn.Write(buf) _, err = pgConn.conn.Write(buf)
if err != nil { if err != nil {
pgConn.conn.Close() pgConn.hardClose()
pgConn.closed = true
cleanupContextDeadline() cleanupContextDeadline()
<-pgConn.controller <-pgConn.controller
@ -1414,15 +1376,8 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn)
batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)
n, err := pgConn.conn.Write(batch.buf) err := pgConn.writeAll(batch.buf)
if err != nil { if err != nil {
// Partially sent messages are a fatal error for the connection.
if n > 0 {
// Close connection because cannot recover from partially sent message.
pgConn.conn.Close()
pgConn.closed = true
}
multiResult.cleanupContextDeadline() multiResult.cleanupContextDeadline()
multiResult.closed = true multiResult.closed = true
multiResult.err = preferContextOverNetTimeoutError(ctx, err) multiResult.err = preferContextOverNetTimeoutError(ctx, err)