Add *Conn.SetLogLevel

Allow changing log level after connection is established. Because
log level and loggers can be set independently, it is now possible
to have a log level above none when there is a nil logger. This
means all log statements need to check for nil logger and an
appropriate log level. This check has been factored out into
*Conn.shouldLog.
pull/120/head
Jack Christensen 2016-02-13 10:13:10 -06:00
parent cffae7ff5d
commit 0f7bf19387
4 changed files with 120 additions and 36 deletions

49
conn.go
View File

@ -106,6 +106,7 @@ var ErrNotificationTimeout = errors.New("notification timeout")
var ErrDeadConn = errors.New("conn is dead")
var ErrTLSRefused = errors.New("server refused TLS connection")
var ErrConnBusy = errors.New("conn is busy")
var ErrInvalidLogLevel = errors.New("invalid log level")
type ProtocolError string
@ -128,11 +129,8 @@ func Connect(config ConnConfig) (c *Conn, err error) {
c.logLevel = LogLevelDebug
}
c.logger = c.config.Logger
if c.logger == nil {
c.logLevel = LogLevelNone
}
c.mr.log = c.log
c.mr.logLevel = &c.logLevel
c.mr.shouldLog = c.shouldLog
if c.config.User == "" {
user, err := user.Current()
@ -140,14 +138,14 @@ func Connect(config ConnConfig) (c *Conn, err error) {
return nil, err
}
c.config.User = user.Username
if c.logLevel >= LogLevelDebug {
if c.shouldLog(LogLevelDebug) {
c.log(LogLevelDebug, "Using default connection config", "User", c.config.User)
}
}
if c.config.Port == 0 {
c.config.Port = 5432
if c.logLevel >= LogLevelDebug {
if c.shouldLog(LogLevelDebug) {
c.log(LogLevelDebug, "Using default connection config", "Port", c.config.Port)
}
}
@ -180,12 +178,12 @@ func Connect(config ConnConfig) (c *Conn, err error) {
}
func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) {
if c.logLevel >= LogLevelInfo {
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address))
}
c.conn, err = c.config.Dial(network, address)
if err != nil {
if c.logLevel >= LogLevelError {
if c.shouldLog(LogLevelError) {
c.log(LogLevelError, fmt.Sprintf("Connection failed: %v", err))
}
return err
@ -194,7 +192,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
if c != nil && err != nil {
c.conn.Close()
c.alive = false
if c.logLevel >= LogLevelError {
if c.shouldLog(LogLevelError) {
c.log(LogLevelError, err.Error())
}
}
@ -207,11 +205,11 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
c.lastActivityTime = time.Now()
if tlsConfig != nil {
if c.logLevel >= LogLevelDebug {
if c.shouldLog(LogLevelDebug) {
c.log(LogLevelDebug, "Starting TLS handshake")
}
if err := c.startTLS(tlsConfig); err != nil {
if c.logLevel >= LogLevelError {
if c.shouldLog(LogLevelError) {
c.log(LogLevelError, fmt.Sprintf("TLS failed: %v", err))
}
return err
@ -262,7 +260,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
}
case readyForQuery:
c.rxReadyForQuery(r)
if c.logLevel >= LogLevelInfo {
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "Connection established")
}
@ -338,7 +336,7 @@ func (c *Conn) Close() (err error) {
_, err = c.conn.Write(wbuf.buf)
c.die(errors.New("Closed"))
if c.logLevel >= LogLevelInfo {
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "Closed connection")
}
return err
@ -548,7 +546,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
}
}
if c.logLevel >= LogLevelError {
if c.shouldLog(LogLevelError) {
defer func() {
if err != nil {
c.log(LogLevelError, fmt.Sprintf("Prepare `%s` as `%s` failed: %v", name, sql, err))
@ -975,12 +973,12 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag
defer func() {
if err == nil {
if c.logLevel >= LogLevelInfo {
if c.shouldLog(LogLevelInfo) {
endTime := time.Now()
c.log(LogLevelInfo, "Exec", "sql", sql, "args", logQueryArgs(arguments), "time", endTime.Sub(startTime), "commandTag", commandTag)
}
} else {
if c.logLevel >= LogLevelError {
if c.shouldLog(LogLevelError) {
c.log(LogLevelError, "Exec", "sql", sql, "args", logQueryArgs(arguments), "error", err)
}
}
@ -1055,7 +1053,7 @@ func (c *Conn) rxMsg() (t byte, r *msgReader, err error) {
c.lastActivityTime = time.Now()
if c.logLevel >= LogLevelTrace {
if c.shouldLog(LogLevelTrace) {
c.log(LogLevelTrace, "rxMsg", "type", string(t), "msgBytesRemaining", c.mr.msgBytesRemaining)
}
@ -1252,6 +1250,10 @@ func (c *Conn) unlock() error {
return nil
}
func (c *Conn) shouldLog(lvl int) bool {
return c.logger != nil && c.logLevel >= lvl
}
func (c *Conn) log(lvl int, msg string, ctx ...interface{}) {
if c.Pid != 0 {
ctx = append(ctx, "pid", c.Pid)
@ -1277,3 +1279,16 @@ func (c *Conn) SetLogger(logger Logger) Logger {
c.logger = logger
return oldLogger
}
// SetLogLevel replaces the current log level and returns the previous log
// level.
func (c *Conn) SetLogLevel(lvl int) (int, error) {
oldLvl := c.logLevel
if lvl < LogLevelNone || lvl > LogLevelTrace {
return oldLvl, ErrInvalidLogLevel
}
c.logLevel = lvl
return lvl, nil
}

View File

@ -1345,12 +1345,28 @@ func TestCatchSimultaneousConnectionQueryAndExec(t *testing.T) {
}
}
type testLogger struct{}
type testLog struct {
lvl int
msg string
ctx []interface{}
}
func (l *testLogger) Debug(msg string, ctx ...interface{}) {}
func (l *testLogger) Info(msg string, ctx ...interface{}) {}
func (l *testLogger) Warn(msg string, ctx ...interface{}) {}
func (l *testLogger) Error(msg string, ctx ...interface{}) {}
type testLogger struct {
logs []testLog
}
func (l *testLogger) Debug(msg string, ctx ...interface{}) {
l.logs = append(l.logs, testLog{lvl: pgx.LogLevelDebug, msg: msg, ctx: ctx})
}
func (l *testLogger) Info(msg string, ctx ...interface{}) {
l.logs = append(l.logs, testLog{lvl: pgx.LogLevelInfo, msg: msg, ctx: ctx})
}
func (l *testLogger) Warn(msg string, ctx ...interface{}) {
l.logs = append(l.logs, testLog{lvl: pgx.LogLevelWarn, msg: msg, ctx: ctx})
}
func (l *testLogger) Error(msg string, ctx ...interface{}) {
l.logs = append(l.logs, testLog{lvl: pgx.LogLevelError, msg: msg, ctx: ctx})
}
func TestSetLogger(t *testing.T) {
t.Parallel()
@ -1364,10 +1380,63 @@ func TestSetLogger(t *testing.T) {
t.Fatalf("Expected conn.SetLogger to return %v, but it was %v", nil, oldLogger)
}
if err := conn.Listen("foo"); err != nil {
t.Fatal(err)
}
if len(l1.logs) == 0 {
t.Fatal("Expected new logger l1 to be called, but it wasn't")
}
l2 := &testLogger{}
oldLogger = conn.SetLogger(l2)
if oldLogger != l1 {
t.Fatalf("Expected conn.SetLogger to return %v, but it was %v", l1, oldLogger)
}
if err := conn.Listen("bar"); err != nil {
t.Fatal(err)
}
if len(l2.logs) == 0 {
t.Fatal("Expected new logger l2 to be called, but it wasn't")
}
}
func TestSetLogLevel(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
logger := &testLogger{}
conn.SetLogger(logger)
if _, err := conn.SetLogLevel(0); err != pgx.ErrInvalidLogLevel {
t.Fatal("SetLogLevel with invalid level did not return error")
}
if _, err := conn.SetLogLevel(pgx.LogLevelNone); err != nil {
t.Fatal(err)
}
if err := conn.Listen("foo"); err != nil {
t.Fatal(err)
}
if len(logger.logs) != 0 {
t.Fatalf("Expected logger not to be called, but it was: %v", logger.logs)
}
if _, err := conn.SetLogLevel(pgx.LogLevelTrace); err != nil {
t.Fatal(err)
}
if err := conn.Listen("bar"); err != nil {
t.Fatal(err)
}
if len(logger.logs) == 0 {
t.Fatal("Expected logger to be called, but it wasn't")
}
}

View File

@ -15,7 +15,7 @@ type msgReader struct {
msgBytesRemaining int32
err error
log func(lvl int, msg string, ctx ...interface{})
logLevel *int
shouldLog func(lvl int) bool
}
// Err returns any error that the msgReader has experienced
@ -25,7 +25,7 @@ func (r *msgReader) Err() error {
// fatal tells r that a Fatal error has occurred
func (r *msgReader) fatal(err error) {
if *r.logLevel >= LogLevelTrace {
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgBytesRemaining", r.msgBytesRemaining)
}
r.err = err
@ -38,7 +38,7 @@ func (r *msgReader) rxMsg() (byte, error) {
}
if r.msgBytesRemaining > 0 {
if *r.logLevel >= LogLevelTrace {
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining)
}
@ -68,7 +68,7 @@ func (r *msgReader) readByte() byte {
return 0
}
if *r.logLevel >= LogLevelTrace {
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readByte", "value", b, "byteAsString", string(b), "msgBytesRemaining", r.msgBytesRemaining)
}
@ -95,7 +95,7 @@ func (r *msgReader) readInt16() int16 {
n := int16(binary.BigEndian.Uint16(b))
if *r.logLevel >= LogLevelTrace {
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
}
@ -122,7 +122,7 @@ func (r *msgReader) readInt32() int32 {
n := int32(binary.BigEndian.Uint32(b))
if *r.logLevel >= LogLevelTrace {
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
}
@ -149,7 +149,7 @@ func (r *msgReader) readInt64() int64 {
n := int64(binary.BigEndian.Uint64(b))
if *r.logLevel >= LogLevelTrace {
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
}
@ -180,7 +180,7 @@ func (r *msgReader) readCString() string {
s := string(b[0 : len(b)-1])
if *r.logLevel >= LogLevelTrace {
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgBytesRemaining", r.msgBytesRemaining)
}
@ -214,7 +214,7 @@ func (r *msgReader) readString(count int32) string {
s := string(b)
if *r.logLevel >= LogLevelTrace {
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgBytesRemaining", r.msgBytesRemaining)
}
@ -241,7 +241,7 @@ func (r *msgReader) readBytes(count int32) []byte {
return nil
}
if *r.logLevel >= LogLevelTrace {
if r.shouldLog(LogLevelTrace) {
r.log(LogLevelTrace, "msgReader.readBytes", "value", b, "msgBytesRemaining", r.msgBytesRemaining)
}

View File

@ -52,7 +52,7 @@ type Rows struct {
sql string
args []interface{}
log func(lvl int, msg string, ctx ...interface{})
logLevel *int
shouldLog func(lvl int) bool
unlockConn bool
}
@ -78,11 +78,11 @@ func (rows *Rows) close() {
rows.closed = true
if rows.err == nil {
if *rows.logLevel >= LogLevelInfo {
if rows.shouldLog(LogLevelInfo) {
endTime := time.Now()
rows.log(LogLevelInfo, "Query", "sql", rows.sql, "args", logQueryArgs(rows.args), "time", endTime.Sub(rows.startTime), "rowCount", rows.rowCount)
}
} else if *rows.logLevel >= LogLevelError {
} else if rows.shouldLog(LogLevelError) {
rows.log(LogLevelError, "Query", "sql", rows.sql, "args", logQueryArgs(rows.args))
}
}
@ -474,7 +474,7 @@ func (rows *Rows) Values() ([]interface{}, error) {
// from Query and handle it in *Rows.
func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) {
c.lastActivityTime = time.Now()
rows := &Rows{conn: c, startTime: c.lastActivityTime, sql: sql, args: args, log: c.log, logLevel: &c.logLevel}
rows := &Rows{conn: c, startTime: c.lastActivityTime, sql: sql, args: args, log: c.log, shouldLog: c.shouldLog}
if err := c.lock(); err != nil {
rows.abort(err)