From cffae7ff5d4d98d571adc450df0fb20d548c45eb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 12 Feb 2016 17:49:04 -0600 Subject: [PATCH] Add SetLogger to *Conn Allow replacing logger after connection is established. Also refactor internals of logging such that there is a log method that adds the pid to all log calls instead of making a new logger object. The reason for this is so pid will be logged regardless of whether loggers are replaced and restored. --- conn.go | 57 ++++++++++++++++++++++++++++++++++++--------------- conn_test.go | 27 ++++++++++++++++++++++++ logger.go | 25 ---------------------- msg_reader.go | 40 ++++++++++++++++++------------------ query.go | 14 ++++++------- 5 files changed, 95 insertions(+), 68 deletions(-) diff --git a/conn.go b/conn.go index 1c10d449..6b0c65ef 100644 --- a/conn.go +++ b/conn.go @@ -131,8 +131,8 @@ func Connect(config ConnConfig) (c *Conn, err error) { if c.logger == nil { c.logLevel = LogLevelNone } - c.mr.logger = c.logger - c.mr.logLevel = c.logLevel + c.mr.log = c.log + c.mr.logLevel = &c.logLevel if c.config.User == "" { user, err := user.Current() @@ -141,14 +141,14 @@ func Connect(config ConnConfig) (c *Conn, err error) { } c.config.User = user.Username if c.logLevel >= LogLevelDebug { - c.logger.Debug("Using default connection config", "User", c.config.User) + c.log(LogLevelDebug, "Using default connection config", "User", c.config.User) } } if c.config.Port == 0 { c.config.Port = 5432 if c.logLevel >= LogLevelDebug { - c.logger.Debug("Using default connection config", "Port", c.config.Port) + c.log(LogLevelDebug, "Using default connection config", "Port", c.config.Port) } } @@ -181,12 +181,12 @@ func Connect(config ConnConfig) (c *Conn, err error) { func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) { if c.logLevel >= LogLevelInfo { - c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address)) + c.log(LogLevelInfo, fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address)) } c.conn, err = c.config.Dial(network, address) if err != nil { if c.logLevel >= LogLevelError { - c.logger.Error(fmt.Sprintf("Connection failed: %v", err)) + c.log(LogLevelError, fmt.Sprintf("Connection failed: %v", err)) } return err } @@ -195,7 +195,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.conn.Close() c.alive = false if c.logLevel >= LogLevelError { - c.logger.Error(err.Error()) + c.log(LogLevelError, err.Error()) } } }() @@ -208,11 +208,11 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl if tlsConfig != nil { if c.logLevel >= LogLevelDebug { - c.logger.Debug("Starting TLS handshake") + c.log(LogLevelDebug, "Starting TLS handshake") } if err := c.startTLS(tlsConfig); err != nil { if c.logLevel >= LogLevelError { - c.logger.Error(fmt.Sprintf("TLS failed: %v", err)) + c.log(LogLevelError, fmt.Sprintf("TLS failed: %v", err)) } return err } @@ -263,8 +263,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl case readyForQuery: c.rxReadyForQuery(r) if c.logLevel >= LogLevelInfo { - c.logger = &connLogger{logger: c.logger, pid: c.Pid} - c.logger.Info("Connection established") + c.log(LogLevelInfo, "Connection established") } err = c.loadPgTypes() @@ -340,7 +339,7 @@ func (c *Conn) Close() (err error) { c.die(errors.New("Closed")) if c.logLevel >= LogLevelInfo { - c.logger.Info("Closed connection") + c.log(LogLevelInfo, "Closed connection") } return err } @@ -552,7 +551,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { if c.logLevel >= LogLevelError { defer func() { if err != nil { - c.logger.Error(fmt.Sprintf("Prepare `%s` as `%s` failed: %v", name, sql, err)) + c.log(LogLevelError, fmt.Sprintf("Prepare `%s` as `%s` failed: %v", name, sql, err)) } }() } @@ -978,11 +977,11 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag if err == nil { if c.logLevel >= LogLevelInfo { endTime := time.Now() - c.logger.Info("Exec", "sql", sql, "args", logQueryArgs(arguments), "time", endTime.Sub(startTime), "commandTag", commandTag) + c.log(LogLevelInfo, "Exec", "sql", sql, "args", logQueryArgs(arguments), "time", endTime.Sub(startTime), "commandTag", commandTag) } } else { if c.logLevel >= LogLevelError { - c.logger.Error("Exec", "sql", sql, "args", logQueryArgs(arguments), "error", err) + c.log(LogLevelError, "Exec", "sql", sql, "args", logQueryArgs(arguments), "error", err) } } @@ -1057,7 +1056,7 @@ func (c *Conn) rxMsg() (t byte, r *msgReader, err error) { c.lastActivityTime = time.Now() if c.logLevel >= LogLevelTrace { - c.logger.Debug("rxMsg", "type", string(t), "msgBytesRemaining", c.mr.msgBytesRemaining) + c.log(LogLevelTrace, "rxMsg", "type", string(t), "msgBytesRemaining", c.mr.msgBytesRemaining) } return t, &c.mr, err @@ -1252,3 +1251,29 @@ func (c *Conn) unlock() error { c.busy = false return nil } + +func (c *Conn) log(lvl int, msg string, ctx ...interface{}) { + if c.Pid != 0 { + ctx = append(ctx, "pid", c.Pid) + } + + switch lvl { + case LogLevelTrace: + c.logger.Debug(msg, ctx...) + case LogLevelDebug: + c.logger.Debug(msg, ctx...) + case LogLevelInfo: + c.logger.Info(msg, ctx...) + case LogLevelWarn: + c.logger.Warn(msg, ctx...) + case LogLevelError: + c.logger.Error(msg, ctx...) + } +} + +// SetLogger replaces the current logger and returns the previous logger. +func (c *Conn) SetLogger(logger Logger) Logger { + oldLogger := c.logger + c.logger = logger + return oldLogger +} diff --git a/conn_test.go b/conn_test.go index bfdd35ea..b011ef10 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1344,3 +1344,30 @@ func TestCatchSimultaneousConnectionQueryAndExec(t *testing.T) { t.Fatalf("conn.Exec should have failed with pgx.ErrConnBusy, but it was %v", err) } } + +type testLogger struct{} + +func (l *testLogger) Debug(msg string, ctx ...interface{}) {} +func (l *testLogger) Info(msg string, ctx ...interface{}) {} +func (l *testLogger) Warn(msg string, ctx ...interface{}) {} +func (l *testLogger) Error(msg string, ctx ...interface{}) {} + +func TestSetLogger(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + l1 := &testLogger{} + oldLogger := conn.SetLogger(l1) + if oldLogger != nil { + t.Fatalf("Expected conn.SetLogger to return %v, but it was %v", nil, oldLogger) + } + + l2 := &testLogger{} + oldLogger = conn.SetLogger(l2) + if oldLogger != l1 { + t.Fatalf("Expected conn.SetLogger to return %v, but it was %v", l1, oldLogger) + } + +} diff --git a/logger.go b/logger.go index ea1a1818..536a2cfa 100644 --- a/logger.go +++ b/logger.go @@ -30,31 +30,6 @@ type Logger interface { Error(msg string, ctx ...interface{}) } -type connLogger struct { - logger Logger - pid int32 -} - -func (l *connLogger) Debug(msg string, ctx ...interface{}) { - ctx = append(ctx, "pid", l.pid) - l.logger.Debug(msg, ctx...) -} - -func (l *connLogger) Info(msg string, ctx ...interface{}) { - ctx = append(ctx, "pid", l.pid) - l.logger.Info(msg, ctx...) -} - -func (l *connLogger) Warn(msg string, ctx ...interface{}) { - ctx = append(ctx, "pid", l.pid) - l.logger.Warn(msg, ctx...) -} - -func (l *connLogger) Error(msg string, ctx ...interface{}) { - ctx = append(ctx, "pid", l.pid) - l.logger.Error(msg, ctx...) -} - // Converts log level string to constant // // Valid levels: diff --git a/msg_reader.go b/msg_reader.go index 4c2d9805..c9519e14 100644 --- a/msg_reader.go +++ b/msg_reader.go @@ -14,8 +14,8 @@ type msgReader struct { buf [128]byte msgBytesRemaining int32 err error - logger Logger - logLevel int + log func(lvl int, msg string, ctx ...interface{}) + logLevel *int } // Err returns any error that the msgReader has experienced @@ -25,8 +25,8 @@ func (r *msgReader) Err() error { // fatal tells r that a Fatal error has occurred func (r *msgReader) fatal(err error) { - if r.logLevel >= LogLevelTrace { - r.logger.Debug("msgReader.fatal", "error", err, "msgBytesRemaining", r.msgBytesRemaining) + if *r.logLevel >= LogLevelTrace { + r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgBytesRemaining", r.msgBytesRemaining) } r.err = err } @@ -38,8 +38,8 @@ func (r *msgReader) rxMsg() (byte, error) { } if r.msgBytesRemaining > 0 { - if r.logLevel >= LogLevelTrace { - r.logger.Debug("msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining) + if *r.logLevel >= LogLevelTrace { + r.log(LogLevelTrace, "msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining) } io.CopyN(ioutil.Discard, r.reader, int64(r.msgBytesRemaining)) @@ -68,8 +68,8 @@ func (r *msgReader) readByte() byte { return 0 } - if r.logLevel >= LogLevelTrace { - r.logger.Debug("msgReader.readByte", "value", b, "byteAsString", string(b), "msgBytesRemaining", r.msgBytesRemaining) + if *r.logLevel >= LogLevelTrace { + r.log(LogLevelTrace, "msgReader.readByte", "value", b, "byteAsString", string(b), "msgBytesRemaining", r.msgBytesRemaining) } return b @@ -95,8 +95,8 @@ func (r *msgReader) readInt16() int16 { n := int16(binary.BigEndian.Uint16(b)) - if r.logLevel >= LogLevelTrace { - r.logger.Debug("msgReader.readInt16", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + if *r.logLevel >= LogLevelTrace { + r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgBytesRemaining", r.msgBytesRemaining) } return n @@ -122,8 +122,8 @@ func (r *msgReader) readInt32() int32 { n := int32(binary.BigEndian.Uint32(b)) - if r.logLevel >= LogLevelTrace { - r.logger.Debug("msgReader.readInt32", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + if *r.logLevel >= LogLevelTrace { + r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgBytesRemaining", r.msgBytesRemaining) } return n @@ -149,8 +149,8 @@ func (r *msgReader) readInt64() int64 { n := int64(binary.BigEndian.Uint64(b)) - if r.logLevel >= LogLevelTrace { - r.logger.Debug("msgReader.readInt64", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + if *r.logLevel >= LogLevelTrace { + r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgBytesRemaining", r.msgBytesRemaining) } return n @@ -180,8 +180,8 @@ func (r *msgReader) readCString() string { s := string(b[0 : len(b)-1]) - if r.logLevel >= LogLevelTrace { - r.logger.Debug("msgReader.readCString", "value", s, "msgBytesRemaining", r.msgBytesRemaining) + if *r.logLevel >= LogLevelTrace { + r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgBytesRemaining", r.msgBytesRemaining) } return s @@ -214,8 +214,8 @@ func (r *msgReader) readString(count int32) string { s := string(b) - if r.logLevel >= LogLevelTrace { - r.logger.Debug("msgReader.readString", "value", s, "msgBytesRemaining", r.msgBytesRemaining) + if *r.logLevel >= LogLevelTrace { + r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgBytesRemaining", r.msgBytesRemaining) } return s @@ -241,8 +241,8 @@ func (r *msgReader) readBytes(count int32) []byte { return nil } - if r.logLevel >= LogLevelTrace { - r.logger.Debug("msgReader.readBytes", "value", b, "msgBytesRemaining", r.msgBytesRemaining) + if *r.logLevel >= LogLevelTrace { + r.log(LogLevelTrace, "msgReader.readBytes", "value", b, "msgBytesRemaining", r.msgBytesRemaining) } return b diff --git a/query.go b/query.go index 8398562b..264028a5 100644 --- a/query.go +++ b/query.go @@ -51,8 +51,8 @@ type Rows struct { startTime time.Time sql string args []interface{} - logger Logger - logLevel int + log func(lvl int, msg string, ctx ...interface{}) + logLevel *int unlockConn bool } @@ -78,12 +78,12 @@ func (rows *Rows) close() { rows.closed = true if rows.err == nil { - if rows.logLevel >= LogLevelInfo { + if *rows.logLevel >= LogLevelInfo { endTime := time.Now() - rows.logger.Info("Query", "sql", rows.sql, "args", logQueryArgs(rows.args), "time", endTime.Sub(rows.startTime), "rowCount", rows.rowCount) + rows.log(LogLevelInfo, "Query", "sql", rows.sql, "args", logQueryArgs(rows.args), "time", endTime.Sub(rows.startTime), "rowCount", rows.rowCount) } - } else if rows.logLevel >= LogLevelError { - rows.logger.Error("Query", "sql", rows.sql, "args", logQueryArgs(rows.args)) + } else if *rows.logLevel >= LogLevelError { + rows.log(LogLevelError, "Query", "sql", rows.sql, "args", logQueryArgs(rows.args)) } } @@ -474,7 +474,7 @@ func (rows *Rows) Values() ([]interface{}, error) { // from Query and handle it in *Rows. func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) { c.lastActivityTime = time.Now() - rows := &Rows{conn: c, startTime: c.lastActivityTime, sql: sql, args: args, logger: c.logger, logLevel: c.logLevel} + rows := &Rows{conn: c, startTime: c.lastActivityTime, sql: sql, args: args, log: c.log, logLevel: &c.logLevel} if err := c.lock(); err != nil { rows.abort(err)