diff --git a/conn.go b/conn.go index 5b5bda5e..2af381b7 100644 --- a/conn.go +++ b/conn.go @@ -342,6 +342,25 @@ func ParseDSN(s string) (ConnConfig, error) { // PGDATABASE // PGUSER // PGPASSWORD +// PGSSLMODE +// +// Important TLS Security Notes: +// ParseEnvLibpq tries to match libpq behavior with regard to PGSSLMODE. This +// includes defaulting to "prefer" behavior if no environment variable is set. +// +// See http://www.postgresql.org/docs/9.4/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION +// for details on what level of security each sslmode provides. +// +// "require" and "verify-ca" modes currently are treated as "verify-full". e.g. +// "They have stronger security guarantees than they would with libpq. Do not +// "rely on this behavior as it may be possible to match libpq in the match. If +// "you need full security use "verify-full". +// +// Several of the PGSSLMODE options (including the default behavior of "prefer") +// will set UseFallbackTLS to true and FallbackTLSConfig to a disabled or +// weakened TLS mode. This means that if ParseEnvLibpq is used, but TLSConfig is +// later set from a different source that UseFallbackTLS MUST be set false to +// avoid the possibility of falling back to weaker or disabled security. func ParseEnvLibpq() (ConnConfig, error) { var cc ConnConfig @@ -359,6 +378,30 @@ func ParseEnvLibpq() (ConnConfig, error) { cc.User = os.Getenv("PGUSER") cc.Password = os.Getenv("PGPASSWORD") + sslmode := os.Getenv("PGSSLMODE") + + // Match libpq default behavior + if sslmode == "" { + sslmode = "prefer" + } + + switch sslmode { + case "disable": + case "allow": + cc.UseFallbackTLS = true + cc.FallbackTLSConfig = &tls.Config{InsecureSkipVerify: true} + case "prefer": + cc.TLSConfig = &tls.Config{InsecureSkipVerify: true} + cc.UseFallbackTLS = true + cc.FallbackTLSConfig = nil + case "require", "verify-ca", "verify-full": + cc.TLSConfig = &tls.Config{ + ServerName: cc.Host, + } + default: + return cc, errors.New("sslmode is invalid") + } + return cc, nil } diff --git a/conn_test.go b/conn_test.go index 005347d3..32f58e03 100644 --- a/conn_test.go +++ b/conn_test.go @@ -382,14 +382,21 @@ func TestParseEnvLibpq(t *testing.T) { }() tests := []struct { + name string envvars map[string]string config pgx.ConnConfig }{ { + name: "No environment", envvars: map[string]string{}, - config: pgx.ConnConfig{}, + config: pgx.ConnConfig{ + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + }, }, { + name: "Normal PG vars", envvars: map[string]string{ "PGHOST": "123.123.123.123", "PGPORT": "7777", @@ -398,38 +405,169 @@ func TestParseEnvLibpq(t *testing.T) { "PGPASSWORD": "baz", }, config: pgx.ConnConfig{ - Host: "123.123.123.123", - Port: 7777, - Database: "foo", - User: "bar", - Password: "baz", + Host: "123.123.123.123", + Port: 7777, + Database: "foo", + User: "bar", + Password: "baz", + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + }, + }, + { + name: "sslmode=disable", + envvars: map[string]string{ + "PGSSLMODE": "disable", + }, + config: pgx.ConnConfig{ + TLSConfig: nil, + UseFallbackTLS: false, + }, + }, + { + name: "sslmode=allow", + envvars: map[string]string{ + "PGSSLMODE": "allow", + }, + config: pgx.ConnConfig{ + TLSConfig: nil, + UseFallbackTLS: true, + FallbackTLSConfig: &tls.Config{InsecureSkipVerify: true}, + }, + }, + { + name: "sslmode=prefer", + envvars: map[string]string{ + "PGSSLMODE": "prefer", + }, + config: pgx.ConnConfig{ + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + }, + }, + { + name: "sslmode=require", + envvars: map[string]string{ + "PGSSLMODE": "require", + }, + config: pgx.ConnConfig{ + TLSConfig: &tls.Config{}, + UseFallbackTLS: false, + }, + }, + { + name: "sslmode=verify-ca", + envvars: map[string]string{ + "PGSSLMODE": "verify-ca", + }, + config: pgx.ConnConfig{ + TLSConfig: &tls.Config{}, + UseFallbackTLS: false, + }, + }, + { + name: "sslmode=verify-full", + envvars: map[string]string{ + "PGSSLMODE": "verify-full", + }, + config: pgx.ConnConfig{ + TLSConfig: &tls.Config{}, + UseFallbackTLS: false, + }, + }, + { + name: "sslmode=verify-full with host", + envvars: map[string]string{ + "PGHOST": "pgx.example", + "PGSSLMODE": "verify-full", + }, + config: pgx.ConnConfig{ + Host: "pgx.example", + TLSConfig: &tls.Config{ + ServerName: "pgx.example", + }, + UseFallbackTLS: false, }, }, } - for i, tt := range tests { + for _, tt := range tests { for _, n := range pgEnvvars { err := os.Unsetenv(n) if err != nil { - t.Fatalf("%d. Unable to clear environment:", i, err) + t.Fatalf("%s: Unable to clear environment:", tt.name, err) } } for k, v := range tt.envvars { err := os.Setenv(k, v) if err != nil { - t.Fatalf("%d. Unable to set environment:", i, err) + t.Fatalf("%s: Unable to set environment:", tt.name, err) } } config, err := pgx.ParseEnvLibpq() if err != nil { - t.Errorf("%d. Unexpected error from pgx.ParseLibpq() => %v", i, err) + t.Errorf("%s: Unexpected error from pgx.ParseLibpq() => %v", tt.name, err) continue } - if !reflect.DeepEqual(config, tt.config) { - t.Errorf("%d. expected %#v got %#v", i, tt.config, config) + if config.Host != tt.config.Host { + t.Errorf("%s: expected Host to be %v got %v", tt.name, tt.config.Host, config.Host) + } + if config.Port != tt.config.Port { + t.Errorf("%s: expected Port to be %v got %v", tt.name, tt.config.Port, config.Port) + } + if config.Port != tt.config.Port { + t.Errorf("%s: expected Port to be %v got %v", tt.name, tt.config.Port, config.Port) + } + if config.User != tt.config.User { + t.Errorf("%s: expected User to be %v got %v", tt.name, tt.config.User, config.User) + } + if config.Password != tt.config.Password { + t.Errorf("%s: expected Password to be %v got %v", tt.name, tt.config.Password, config.Password) + } + + tlsTests := []struct { + name string + expected *tls.Config + actual *tls.Config + }{ + { + name: "TLSConfig", + expected: tt.config.TLSConfig, + actual: config.TLSConfig, + }, + { + name: "FallbackTLSConfig", + expected: tt.config.FallbackTLSConfig, + actual: config.FallbackTLSConfig, + }, + } + for _, tlsTest := range tlsTests { + name := tlsTest.name + expected := tlsTest.expected + actual := tlsTest.actual + + if expected == nil && actual != nil { + t.Errorf("%s / %s: expected nil, but it was set", tt.name, name) + } else if expected != nil && actual == nil { + t.Errorf("%s / %s: expected to be set, but got nil", tt.name, name) + } else if expected != nil && actual != nil { + if actual.InsecureSkipVerify != expected.InsecureSkipVerify { + t.Errorf("%s / %s: expected InsecureSkipVerify to be %v got %v", tt.name, name, expected.InsecureSkipVerify, actual.InsecureSkipVerify) + } + + if actual.ServerName != expected.ServerName { + t.Errorf("%s / %s: expected ServerName to be %v got %v", tt.name, name, expected.ServerName, actual.ServerName) + } + } + } + + if config.UseFallbackTLS != tt.config.UseFallbackTLS { + t.Errorf("%s: expected UseFallbackTLS to be %v got %v", tt.name, tt.config.UseFallbackTLS, config.UseFallbackTLS) } } }