mirror of https://github.com/jackc/pgx.git
Add SetLogger to *Conn
Allow replacing logger after connection is established. Also refactor internals of logging such that there is a log method that adds the pid to all log calls instead of making a new logger object. The reason for this is so pid will be logged regardless of whether loggers are replaced and restored.pull/120/head
parent
beed0c0e5f
commit
cffae7ff5d
57
conn.go
57
conn.go
|
@ -131,8 +131,8 @@ func Connect(config ConnConfig) (c *Conn, err error) {
|
|||
if c.logger == nil {
|
||||
c.logLevel = LogLevelNone
|
||||
}
|
||||
c.mr.logger = c.logger
|
||||
c.mr.logLevel = c.logLevel
|
||||
c.mr.log = c.log
|
||||
c.mr.logLevel = &c.logLevel
|
||||
|
||||
if c.config.User == "" {
|
||||
user, err := user.Current()
|
||||
|
@ -141,14 +141,14 @@ func Connect(config ConnConfig) (c *Conn, err error) {
|
|||
}
|
||||
c.config.User = user.Username
|
||||
if c.logLevel >= LogLevelDebug {
|
||||
c.logger.Debug("Using default connection config", "User", c.config.User)
|
||||
c.log(LogLevelDebug, "Using default connection config", "User", c.config.User)
|
||||
}
|
||||
}
|
||||
|
||||
if c.config.Port == 0 {
|
||||
c.config.Port = 5432
|
||||
if c.logLevel >= LogLevelDebug {
|
||||
c.logger.Debug("Using default connection config", "Port", c.config.Port)
|
||||
c.log(LogLevelDebug, "Using default connection config", "Port", c.config.Port)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -181,12 +181,12 @@ func Connect(config ConnConfig) (c *Conn, err error) {
|
|||
|
||||
func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) {
|
||||
if c.logLevel >= LogLevelInfo {
|
||||
c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address))
|
||||
c.log(LogLevelInfo, fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address))
|
||||
}
|
||||
c.conn, err = c.config.Dial(network, address)
|
||||
if err != nil {
|
||||
if c.logLevel >= LogLevelError {
|
||||
c.logger.Error(fmt.Sprintf("Connection failed: %v", err))
|
||||
c.log(LogLevelError, fmt.Sprintf("Connection failed: %v", err))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -195,7 +195,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
|||
c.conn.Close()
|
||||
c.alive = false
|
||||
if c.logLevel >= LogLevelError {
|
||||
c.logger.Error(err.Error())
|
||||
c.log(LogLevelError, err.Error())
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
@ -208,11 +208,11 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
|||
|
||||
if tlsConfig != nil {
|
||||
if c.logLevel >= LogLevelDebug {
|
||||
c.logger.Debug("Starting TLS handshake")
|
||||
c.log(LogLevelDebug, "Starting TLS handshake")
|
||||
}
|
||||
if err := c.startTLS(tlsConfig); err != nil {
|
||||
if c.logLevel >= LogLevelError {
|
||||
c.logger.Error(fmt.Sprintf("TLS failed: %v", err))
|
||||
c.log(LogLevelError, fmt.Sprintf("TLS failed: %v", err))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -263,8 +263,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
|||
case readyForQuery:
|
||||
c.rxReadyForQuery(r)
|
||||
if c.logLevel >= LogLevelInfo {
|
||||
c.logger = &connLogger{logger: c.logger, pid: c.Pid}
|
||||
c.logger.Info("Connection established")
|
||||
c.log(LogLevelInfo, "Connection established")
|
||||
}
|
||||
|
||||
err = c.loadPgTypes()
|
||||
|
@ -340,7 +339,7 @@ func (c *Conn) Close() (err error) {
|
|||
|
||||
c.die(errors.New("Closed"))
|
||||
if c.logLevel >= LogLevelInfo {
|
||||
c.logger.Info("Closed connection")
|
||||
c.log(LogLevelInfo, "Closed connection")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -552,7 +551,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
|
|||
if c.logLevel >= LogLevelError {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
c.logger.Error(fmt.Sprintf("Prepare `%s` as `%s` failed: %v", name, sql, err))
|
||||
c.log(LogLevelError, fmt.Sprintf("Prepare `%s` as `%s` failed: %v", name, sql, err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
@ -978,11 +977,11 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag
|
|||
if err == nil {
|
||||
if c.logLevel >= LogLevelInfo {
|
||||
endTime := time.Now()
|
||||
c.logger.Info("Exec", "sql", sql, "args", logQueryArgs(arguments), "time", endTime.Sub(startTime), "commandTag", commandTag)
|
||||
c.log(LogLevelInfo, "Exec", "sql", sql, "args", logQueryArgs(arguments), "time", endTime.Sub(startTime), "commandTag", commandTag)
|
||||
}
|
||||
} else {
|
||||
if c.logLevel >= LogLevelError {
|
||||
c.logger.Error("Exec", "sql", sql, "args", logQueryArgs(arguments), "error", err)
|
||||
c.log(LogLevelError, "Exec", "sql", sql, "args", logQueryArgs(arguments), "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1057,7 +1056,7 @@ func (c *Conn) rxMsg() (t byte, r *msgReader, err error) {
|
|||
c.lastActivityTime = time.Now()
|
||||
|
||||
if c.logLevel >= LogLevelTrace {
|
||||
c.logger.Debug("rxMsg", "type", string(t), "msgBytesRemaining", c.mr.msgBytesRemaining)
|
||||
c.log(LogLevelTrace, "rxMsg", "type", string(t), "msgBytesRemaining", c.mr.msgBytesRemaining)
|
||||
}
|
||||
|
||||
return t, &c.mr, err
|
||||
|
@ -1252,3 +1251,29 @@ func (c *Conn) unlock() error {
|
|||
c.busy = false
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) log(lvl int, msg string, ctx ...interface{}) {
|
||||
if c.Pid != 0 {
|
||||
ctx = append(ctx, "pid", c.Pid)
|
||||
}
|
||||
|
||||
switch lvl {
|
||||
case LogLevelTrace:
|
||||
c.logger.Debug(msg, ctx...)
|
||||
case LogLevelDebug:
|
||||
c.logger.Debug(msg, ctx...)
|
||||
case LogLevelInfo:
|
||||
c.logger.Info(msg, ctx...)
|
||||
case LogLevelWarn:
|
||||
c.logger.Warn(msg, ctx...)
|
||||
case LogLevelError:
|
||||
c.logger.Error(msg, ctx...)
|
||||
}
|
||||
}
|
||||
|
||||
// SetLogger replaces the current logger and returns the previous logger.
|
||||
func (c *Conn) SetLogger(logger Logger) Logger {
|
||||
oldLogger := c.logger
|
||||
c.logger = logger
|
||||
return oldLogger
|
||||
}
|
||||
|
|
27
conn_test.go
27
conn_test.go
|
@ -1344,3 +1344,30 @@ func TestCatchSimultaneousConnectionQueryAndExec(t *testing.T) {
|
|||
t.Fatalf("conn.Exec should have failed with pgx.ErrConnBusy, but it was %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
type testLogger struct{}
|
||||
|
||||
func (l *testLogger) Debug(msg string, ctx ...interface{}) {}
|
||||
func (l *testLogger) Info(msg string, ctx ...interface{}) {}
|
||||
func (l *testLogger) Warn(msg string, ctx ...interface{}) {}
|
||||
func (l *testLogger) Error(msg string, ctx ...interface{}) {}
|
||||
|
||||
func TestSetLogger(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
l1 := &testLogger{}
|
||||
oldLogger := conn.SetLogger(l1)
|
||||
if oldLogger != nil {
|
||||
t.Fatalf("Expected conn.SetLogger to return %v, but it was %v", nil, oldLogger)
|
||||
}
|
||||
|
||||
l2 := &testLogger{}
|
||||
oldLogger = conn.SetLogger(l2)
|
||||
if oldLogger != l1 {
|
||||
t.Fatalf("Expected conn.SetLogger to return %v, but it was %v", l1, oldLogger)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
25
logger.go
25
logger.go
|
@ -30,31 +30,6 @@ type Logger interface {
|
|||
Error(msg string, ctx ...interface{})
|
||||
}
|
||||
|
||||
type connLogger struct {
|
||||
logger Logger
|
||||
pid int32
|
||||
}
|
||||
|
||||
func (l *connLogger) Debug(msg string, ctx ...interface{}) {
|
||||
ctx = append(ctx, "pid", l.pid)
|
||||
l.logger.Debug(msg, ctx...)
|
||||
}
|
||||
|
||||
func (l *connLogger) Info(msg string, ctx ...interface{}) {
|
||||
ctx = append(ctx, "pid", l.pid)
|
||||
l.logger.Info(msg, ctx...)
|
||||
}
|
||||
|
||||
func (l *connLogger) Warn(msg string, ctx ...interface{}) {
|
||||
ctx = append(ctx, "pid", l.pid)
|
||||
l.logger.Warn(msg, ctx...)
|
||||
}
|
||||
|
||||
func (l *connLogger) Error(msg string, ctx ...interface{}) {
|
||||
ctx = append(ctx, "pid", l.pid)
|
||||
l.logger.Error(msg, ctx...)
|
||||
}
|
||||
|
||||
// Converts log level string to constant
|
||||
//
|
||||
// Valid levels:
|
||||
|
|
|
@ -14,8 +14,8 @@ type msgReader struct {
|
|||
buf [128]byte
|
||||
msgBytesRemaining int32
|
||||
err error
|
||||
logger Logger
|
||||
logLevel int
|
||||
log func(lvl int, msg string, ctx ...interface{})
|
||||
logLevel *int
|
||||
}
|
||||
|
||||
// Err returns any error that the msgReader has experienced
|
||||
|
@ -25,8 +25,8 @@ func (r *msgReader) Err() error {
|
|||
|
||||
// fatal tells r that a Fatal error has occurred
|
||||
func (r *msgReader) fatal(err error) {
|
||||
if r.logLevel >= LogLevelTrace {
|
||||
r.logger.Debug("msgReader.fatal", "error", err, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
if *r.logLevel >= LogLevelTrace {
|
||||
r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
r.err = err
|
||||
}
|
||||
|
@ -38,8 +38,8 @@ func (r *msgReader) rxMsg() (byte, error) {
|
|||
}
|
||||
|
||||
if r.msgBytesRemaining > 0 {
|
||||
if r.logLevel >= LogLevelTrace {
|
||||
r.logger.Debug("msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining)
|
||||
if *r.logLevel >= LogLevelTrace {
|
||||
r.log(LogLevelTrace, "msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
|
||||
io.CopyN(ioutil.Discard, r.reader, int64(r.msgBytesRemaining))
|
||||
|
@ -68,8 +68,8 @@ func (r *msgReader) readByte() byte {
|
|||
return 0
|
||||
}
|
||||
|
||||
if r.logLevel >= LogLevelTrace {
|
||||
r.logger.Debug("msgReader.readByte", "value", b, "byteAsString", string(b), "msgBytesRemaining", r.msgBytesRemaining)
|
||||
if *r.logLevel >= LogLevelTrace {
|
||||
r.log(LogLevelTrace, "msgReader.readByte", "value", b, "byteAsString", string(b), "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
|
||||
return b
|
||||
|
@ -95,8 +95,8 @@ func (r *msgReader) readInt16() int16 {
|
|||
|
||||
n := int16(binary.BigEndian.Uint16(b))
|
||||
|
||||
if r.logLevel >= LogLevelTrace {
|
||||
r.logger.Debug("msgReader.readInt16", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
if *r.logLevel >= LogLevelTrace {
|
||||
r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
|
||||
return n
|
||||
|
@ -122,8 +122,8 @@ func (r *msgReader) readInt32() int32 {
|
|||
|
||||
n := int32(binary.BigEndian.Uint32(b))
|
||||
|
||||
if r.logLevel >= LogLevelTrace {
|
||||
r.logger.Debug("msgReader.readInt32", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
if *r.logLevel >= LogLevelTrace {
|
||||
r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
|
||||
return n
|
||||
|
@ -149,8 +149,8 @@ func (r *msgReader) readInt64() int64 {
|
|||
|
||||
n := int64(binary.BigEndian.Uint64(b))
|
||||
|
||||
if r.logLevel >= LogLevelTrace {
|
||||
r.logger.Debug("msgReader.readInt64", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
if *r.logLevel >= LogLevelTrace {
|
||||
r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
|
||||
return n
|
||||
|
@ -180,8 +180,8 @@ func (r *msgReader) readCString() string {
|
|||
|
||||
s := string(b[0 : len(b)-1])
|
||||
|
||||
if r.logLevel >= LogLevelTrace {
|
||||
r.logger.Debug("msgReader.readCString", "value", s, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
if *r.logLevel >= LogLevelTrace {
|
||||
r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
|
||||
return s
|
||||
|
@ -214,8 +214,8 @@ func (r *msgReader) readString(count int32) string {
|
|||
|
||||
s := string(b)
|
||||
|
||||
if r.logLevel >= LogLevelTrace {
|
||||
r.logger.Debug("msgReader.readString", "value", s, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
if *r.logLevel >= LogLevelTrace {
|
||||
r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
|
||||
return s
|
||||
|
@ -241,8 +241,8 @@ func (r *msgReader) readBytes(count int32) []byte {
|
|||
return nil
|
||||
}
|
||||
|
||||
if r.logLevel >= LogLevelTrace {
|
||||
r.logger.Debug("msgReader.readBytes", "value", b, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
if *r.logLevel >= LogLevelTrace {
|
||||
r.log(LogLevelTrace, "msgReader.readBytes", "value", b, "msgBytesRemaining", r.msgBytesRemaining)
|
||||
}
|
||||
|
||||
return b
|
||||
|
|
14
query.go
14
query.go
|
@ -51,8 +51,8 @@ type Rows struct {
|
|||
startTime time.Time
|
||||
sql string
|
||||
args []interface{}
|
||||
logger Logger
|
||||
logLevel int
|
||||
log func(lvl int, msg string, ctx ...interface{})
|
||||
logLevel *int
|
||||
unlockConn bool
|
||||
}
|
||||
|
||||
|
@ -78,12 +78,12 @@ func (rows *Rows) close() {
|
|||
rows.closed = true
|
||||
|
||||
if rows.err == nil {
|
||||
if rows.logLevel >= LogLevelInfo {
|
||||
if *rows.logLevel >= LogLevelInfo {
|
||||
endTime := time.Now()
|
||||
rows.logger.Info("Query", "sql", rows.sql, "args", logQueryArgs(rows.args), "time", endTime.Sub(rows.startTime), "rowCount", rows.rowCount)
|
||||
rows.log(LogLevelInfo, "Query", "sql", rows.sql, "args", logQueryArgs(rows.args), "time", endTime.Sub(rows.startTime), "rowCount", rows.rowCount)
|
||||
}
|
||||
} else if rows.logLevel >= LogLevelError {
|
||||
rows.logger.Error("Query", "sql", rows.sql, "args", logQueryArgs(rows.args))
|
||||
} else if *rows.logLevel >= LogLevelError {
|
||||
rows.log(LogLevelError, "Query", "sql", rows.sql, "args", logQueryArgs(rows.args))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -474,7 +474,7 @@ func (rows *Rows) Values() ([]interface{}, error) {
|
|||
// from Query and handle it in *Rows.
|
||||
func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) {
|
||||
c.lastActivityTime = time.Now()
|
||||
rows := &Rows{conn: c, startTime: c.lastActivityTime, sql: sql, args: args, logger: c.logger, logLevel: c.logLevel}
|
||||
rows := &Rows{conn: c, startTime: c.lastActivityTime, sql: sql, args: args, log: c.log, logLevel: &c.logLevel}
|
||||
|
||||
if err := c.lock(); err != nil {
|
||||
rows.abort(err)
|
||||
|
|
Loading…
Reference in New Issue