diff --git a/pgconn/config.go b/pgconn/config.go index 46b39f14..d7a0fefb 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -51,6 +51,8 @@ type Config struct { KerberosSpn string Fallbacks []*FallbackConfig + Sslnegotiation string // sslnegotiation=postgres or sslnegotiation=direct + // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. // It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next // fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs. @@ -318,6 +320,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con "sslkey": {}, "sslcert": {}, "sslrootcert": {}, + "sslnegotiation": {}, "sslpassword": {}, "sslsni": {}, "krbspn": {}, @@ -386,6 +389,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con config.Port = fallbacks[0].Port config.TLSConfig = fallbacks[0].TLSConfig config.Fallbacks = fallbacks[1:] + config.Sslnegotiation = settings["sslnegotiation"] passfile, err := pgpassfile.ReadPassfile(settings["passfile"]) if err == nil { @@ -449,6 +453,7 @@ func parseEnvSettings() map[string]string { "PGSSLSNI": "sslsni", "PGSSLROOTCERT": "sslrootcert", "PGSSLPASSWORD": "sslpassword", + "PGSSLNEGOTIATION": "sslnegotiation", "PGTARGETSESSIONATTRS": "target_session_attrs", "PGSERVICE": "service", "PGSERVICEFILE": "servicefile", @@ -646,6 +651,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P sslkey := settings["sslkey"] sslpassword := settings["sslpassword"] sslsni := settings["sslsni"] + sslnegotiation := settings["sslnegotiation"] // Match libpq default behavior if sslmode == "" { @@ -657,6 +663,11 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P tlsConfig := &tls.Config{} + if sslnegotiation == "direct" { + tlsConfig.NextProtos = []string{"postgresql"} + sslmode = "require" + } + if sslrootcert != "" { var caCertPool *x509.CertPool diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 14966aa4..fb2dd191 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -325,7 +325,15 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo if connectConfig.tlsConfig != nil { pgConn.contextWatcher = ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: pgConn.conn}) pgConn.contextWatcher.Watch(ctx) - tlsConn, err := startTLS(pgConn.conn, connectConfig.tlsConfig) + var ( + tlsConn net.Conn + err error + ) + if config.Sslnegotiation == "direct" { + tlsConn = tls.Client(pgConn.conn, connectConfig.tlsConfig) + } else { + tlsConn, err = startTLS(pgConn.conn, connectConfig.tlsConfig) + } pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. if err != nil { pgConn.conn.Close()