diff --git a/conn.go b/conn.go index 6b0c65ef..ef22daba 100644 --- a/conn.go +++ b/conn.go @@ -106,6 +106,7 @@ var ErrNotificationTimeout = errors.New("notification timeout") var ErrDeadConn = errors.New("conn is dead") var ErrTLSRefused = errors.New("server refused TLS connection") var ErrConnBusy = errors.New("conn is busy") +var ErrInvalidLogLevel = errors.New("invalid log level") type ProtocolError string @@ -128,11 +129,8 @@ func Connect(config ConnConfig) (c *Conn, err error) { c.logLevel = LogLevelDebug } c.logger = c.config.Logger - if c.logger == nil { - c.logLevel = LogLevelNone - } c.mr.log = c.log - c.mr.logLevel = &c.logLevel + c.mr.shouldLog = c.shouldLog if c.config.User == "" { user, err := user.Current() @@ -140,14 +138,14 @@ func Connect(config ConnConfig) (c *Conn, err error) { return nil, err } c.config.User = user.Username - if c.logLevel >= LogLevelDebug { + if c.shouldLog(LogLevelDebug) { c.log(LogLevelDebug, "Using default connection config", "User", c.config.User) } } if c.config.Port == 0 { c.config.Port = 5432 - if c.logLevel >= LogLevelDebug { + if c.shouldLog(LogLevelDebug) { c.log(LogLevelDebug, "Using default connection config", "Port", c.config.Port) } } @@ -180,12 +178,12 @@ func Connect(config ConnConfig) (c *Conn, err error) { } func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) { - if c.logLevel >= LogLevelInfo { + if c.shouldLog(LogLevelInfo) { c.log(LogLevelInfo, fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address)) } c.conn, err = c.config.Dial(network, address) if err != nil { - if c.logLevel >= LogLevelError { + if c.shouldLog(LogLevelError) { c.log(LogLevelError, fmt.Sprintf("Connection failed: %v", err)) } return err @@ -194,7 +192,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl if c != nil && err != nil { c.conn.Close() c.alive = false - if c.logLevel >= LogLevelError { + if c.shouldLog(LogLevelError) { c.log(LogLevelError, err.Error()) } } @@ -207,11 +205,11 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.lastActivityTime = time.Now() if tlsConfig != nil { - if c.logLevel >= LogLevelDebug { + if c.shouldLog(LogLevelDebug) { c.log(LogLevelDebug, "Starting TLS handshake") } if err := c.startTLS(tlsConfig); err != nil { - if c.logLevel >= LogLevelError { + if c.shouldLog(LogLevelError) { c.log(LogLevelError, fmt.Sprintf("TLS failed: %v", err)) } return err @@ -262,7 +260,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl } case readyForQuery: c.rxReadyForQuery(r) - if c.logLevel >= LogLevelInfo { + if c.shouldLog(LogLevelInfo) { c.log(LogLevelInfo, "Connection established") } @@ -338,7 +336,7 @@ func (c *Conn) Close() (err error) { _, err = c.conn.Write(wbuf.buf) c.die(errors.New("Closed")) - if c.logLevel >= LogLevelInfo { + if c.shouldLog(LogLevelInfo) { c.log(LogLevelInfo, "Closed connection") } return err @@ -548,7 +546,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { } } - if c.logLevel >= LogLevelError { + if c.shouldLog(LogLevelError) { defer func() { if err != nil { c.log(LogLevelError, fmt.Sprintf("Prepare `%s` as `%s` failed: %v", name, sql, err)) @@ -975,12 +973,12 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag defer func() { if err == nil { - if c.logLevel >= LogLevelInfo { + if c.shouldLog(LogLevelInfo) { endTime := time.Now() c.log(LogLevelInfo, "Exec", "sql", sql, "args", logQueryArgs(arguments), "time", endTime.Sub(startTime), "commandTag", commandTag) } } else { - if c.logLevel >= LogLevelError { + if c.shouldLog(LogLevelError) { c.log(LogLevelError, "Exec", "sql", sql, "args", logQueryArgs(arguments), "error", err) } } @@ -1055,7 +1053,7 @@ func (c *Conn) rxMsg() (t byte, r *msgReader, err error) { c.lastActivityTime = time.Now() - if c.logLevel >= LogLevelTrace { + if c.shouldLog(LogLevelTrace) { c.log(LogLevelTrace, "rxMsg", "type", string(t), "msgBytesRemaining", c.mr.msgBytesRemaining) } @@ -1252,6 +1250,10 @@ func (c *Conn) unlock() error { return nil } +func (c *Conn) shouldLog(lvl int) bool { + return c.logger != nil && c.logLevel >= lvl +} + func (c *Conn) log(lvl int, msg string, ctx ...interface{}) { if c.Pid != 0 { ctx = append(ctx, "pid", c.Pid) @@ -1277,3 +1279,16 @@ func (c *Conn) SetLogger(logger Logger) Logger { c.logger = logger return oldLogger } + +// SetLogLevel replaces the current log level and returns the previous log +// level. +func (c *Conn) SetLogLevel(lvl int) (int, error) { + oldLvl := c.logLevel + + if lvl < LogLevelNone || lvl > LogLevelTrace { + return oldLvl, ErrInvalidLogLevel + } + + c.logLevel = lvl + return lvl, nil +} diff --git a/conn_test.go b/conn_test.go index b011ef10..cd61fc95 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1345,12 +1345,28 @@ func TestCatchSimultaneousConnectionQueryAndExec(t *testing.T) { } } -type testLogger struct{} +type testLog struct { + lvl int + msg string + ctx []interface{} +} -func (l *testLogger) Debug(msg string, ctx ...interface{}) {} -func (l *testLogger) Info(msg string, ctx ...interface{}) {} -func (l *testLogger) Warn(msg string, ctx ...interface{}) {} -func (l *testLogger) Error(msg string, ctx ...interface{}) {} +type testLogger struct { + logs []testLog +} + +func (l *testLogger) Debug(msg string, ctx ...interface{}) { + l.logs = append(l.logs, testLog{lvl: pgx.LogLevelDebug, msg: msg, ctx: ctx}) +} +func (l *testLogger) Info(msg string, ctx ...interface{}) { + l.logs = append(l.logs, testLog{lvl: pgx.LogLevelInfo, msg: msg, ctx: ctx}) +} +func (l *testLogger) Warn(msg string, ctx ...interface{}) { + l.logs = append(l.logs, testLog{lvl: pgx.LogLevelWarn, msg: msg, ctx: ctx}) +} +func (l *testLogger) Error(msg string, ctx ...interface{}) { + l.logs = append(l.logs, testLog{lvl: pgx.LogLevelError, msg: msg, ctx: ctx}) +} func TestSetLogger(t *testing.T) { t.Parallel() @@ -1364,10 +1380,63 @@ func TestSetLogger(t *testing.T) { t.Fatalf("Expected conn.SetLogger to return %v, but it was %v", nil, oldLogger) } + if err := conn.Listen("foo"); err != nil { + t.Fatal(err) + } + + if len(l1.logs) == 0 { + t.Fatal("Expected new logger l1 to be called, but it wasn't") + } + l2 := &testLogger{} oldLogger = conn.SetLogger(l2) if oldLogger != l1 { t.Fatalf("Expected conn.SetLogger to return %v, but it was %v", l1, oldLogger) } + if err := conn.Listen("bar"); err != nil { + t.Fatal(err) + } + + if len(l2.logs) == 0 { + t.Fatal("Expected new logger l2 to be called, but it wasn't") + } +} + +func TestSetLogLevel(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + logger := &testLogger{} + conn.SetLogger(logger) + + if _, err := conn.SetLogLevel(0); err != pgx.ErrInvalidLogLevel { + t.Fatal("SetLogLevel with invalid level did not return error") + } + + if _, err := conn.SetLogLevel(pgx.LogLevelNone); err != nil { + t.Fatal(err) + } + + if err := conn.Listen("foo"); err != nil { + t.Fatal(err) + } + + if len(logger.logs) != 0 { + t.Fatalf("Expected logger not to be called, but it was: %v", logger.logs) + } + + if _, err := conn.SetLogLevel(pgx.LogLevelTrace); err != nil { + t.Fatal(err) + } + + if err := conn.Listen("bar"); err != nil { + t.Fatal(err) + } + + if len(logger.logs) == 0 { + t.Fatal("Expected logger to be called, but it wasn't") + } } diff --git a/msg_reader.go b/msg_reader.go index c9519e14..a5de6463 100644 --- a/msg_reader.go +++ b/msg_reader.go @@ -15,7 +15,7 @@ type msgReader struct { msgBytesRemaining int32 err error log func(lvl int, msg string, ctx ...interface{}) - logLevel *int + shouldLog func(lvl int) bool } // Err returns any error that the msgReader has experienced @@ -25,7 +25,7 @@ func (r *msgReader) Err() error { // fatal tells r that a Fatal error has occurred func (r *msgReader) fatal(err error) { - if *r.logLevel >= LogLevelTrace { + if r.shouldLog(LogLevelTrace) { r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgBytesRemaining", r.msgBytesRemaining) } r.err = err @@ -38,7 +38,7 @@ func (r *msgReader) rxMsg() (byte, error) { } if r.msgBytesRemaining > 0 { - if *r.logLevel >= LogLevelTrace { + if r.shouldLog(LogLevelTrace) { r.log(LogLevelTrace, "msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining) } @@ -68,7 +68,7 @@ func (r *msgReader) readByte() byte { return 0 } - if *r.logLevel >= LogLevelTrace { + if r.shouldLog(LogLevelTrace) { r.log(LogLevelTrace, "msgReader.readByte", "value", b, "byteAsString", string(b), "msgBytesRemaining", r.msgBytesRemaining) } @@ -95,7 +95,7 @@ func (r *msgReader) readInt16() int16 { n := int16(binary.BigEndian.Uint16(b)) - if *r.logLevel >= LogLevelTrace { + if r.shouldLog(LogLevelTrace) { r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgBytesRemaining", r.msgBytesRemaining) } @@ -122,7 +122,7 @@ func (r *msgReader) readInt32() int32 { n := int32(binary.BigEndian.Uint32(b)) - if *r.logLevel >= LogLevelTrace { + if r.shouldLog(LogLevelTrace) { r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgBytesRemaining", r.msgBytesRemaining) } @@ -149,7 +149,7 @@ func (r *msgReader) readInt64() int64 { n := int64(binary.BigEndian.Uint64(b)) - if *r.logLevel >= LogLevelTrace { + if r.shouldLog(LogLevelTrace) { r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgBytesRemaining", r.msgBytesRemaining) } @@ -180,7 +180,7 @@ func (r *msgReader) readCString() string { s := string(b[0 : len(b)-1]) - if *r.logLevel >= LogLevelTrace { + if r.shouldLog(LogLevelTrace) { r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgBytesRemaining", r.msgBytesRemaining) } @@ -214,7 +214,7 @@ func (r *msgReader) readString(count int32) string { s := string(b) - if *r.logLevel >= LogLevelTrace { + if r.shouldLog(LogLevelTrace) { r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgBytesRemaining", r.msgBytesRemaining) } @@ -241,7 +241,7 @@ func (r *msgReader) readBytes(count int32) []byte { return nil } - if *r.logLevel >= LogLevelTrace { + if r.shouldLog(LogLevelTrace) { r.log(LogLevelTrace, "msgReader.readBytes", "value", b, "msgBytesRemaining", r.msgBytesRemaining) } diff --git a/query.go b/query.go index 264028a5..9c9aa4a8 100644 --- a/query.go +++ b/query.go @@ -52,7 +52,7 @@ type Rows struct { sql string args []interface{} log func(lvl int, msg string, ctx ...interface{}) - logLevel *int + shouldLog func(lvl int) bool unlockConn bool } @@ -78,11 +78,11 @@ func (rows *Rows) close() { rows.closed = true if rows.err == nil { - if *rows.logLevel >= LogLevelInfo { + if rows.shouldLog(LogLevelInfo) { endTime := time.Now() rows.log(LogLevelInfo, "Query", "sql", rows.sql, "args", logQueryArgs(rows.args), "time", endTime.Sub(rows.startTime), "rowCount", rows.rowCount) } - } else if *rows.logLevel >= LogLevelError { + } else if rows.shouldLog(LogLevelError) { rows.log(LogLevelError, "Query", "sql", rows.sql, "args", logQueryArgs(rows.args)) } } @@ -474,7 +474,7 @@ func (rows *Rows) Values() ([]interface{}, error) { // from Query and handle it in *Rows. func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) { c.lastActivityTime = time.Now() - rows := &Rows{conn: c, startTime: c.lastActivityTime, sql: sql, args: args, log: c.log, logLevel: &c.logLevel} + rows := &Rows{conn: c, startTime: c.lastActivityTime, sql: sql, args: args, log: c.log, shouldLog: c.shouldLog} if err := c.lock(); err != nil { rows.abort(err)