mirror of
https://github.com/jackc/pgx.git
synced 2025-05-29 10:42:31 +00:00
Add basic PGSSLMODE support to ParseEnvLibpq
This commit is contained in:
parent
51d6d1a3a6
commit
07a11abc07
43
conn.go
43
conn.go
@ -342,6 +342,25 @@ func ParseDSN(s string) (ConnConfig, error) {
|
||||
// PGDATABASE
|
||||
// PGUSER
|
||||
// PGPASSWORD
|
||||
// PGSSLMODE
|
||||
//
|
||||
// Important TLS Security Notes:
|
||||
// ParseEnvLibpq tries to match libpq behavior with regard to PGSSLMODE. This
|
||||
// includes defaulting to "prefer" behavior if no environment variable is set.
|
||||
//
|
||||
// See http://www.postgresql.org/docs/9.4/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION
|
||||
// for details on what level of security each sslmode provides.
|
||||
//
|
||||
// "require" and "verify-ca" modes currently are treated as "verify-full". e.g.
|
||||
// "They have stronger security guarantees than they would with libpq. Do not
|
||||
// "rely on this behavior as it may be possible to match libpq in the match. If
|
||||
// "you need full security use "verify-full".
|
||||
//
|
||||
// Several of the PGSSLMODE options (including the default behavior of "prefer")
|
||||
// will set UseFallbackTLS to true and FallbackTLSConfig to a disabled or
|
||||
// weakened TLS mode. This means that if ParseEnvLibpq is used, but TLSConfig is
|
||||
// later set from a different source that UseFallbackTLS MUST be set false to
|
||||
// avoid the possibility of falling back to weaker or disabled security.
|
||||
func ParseEnvLibpq() (ConnConfig, error) {
|
||||
var cc ConnConfig
|
||||
|
||||
@ -359,6 +378,30 @@ func ParseEnvLibpq() (ConnConfig, error) {
|
||||
cc.User = os.Getenv("PGUSER")
|
||||
cc.Password = os.Getenv("PGPASSWORD")
|
||||
|
||||
sslmode := os.Getenv("PGSSLMODE")
|
||||
|
||||
// Match libpq default behavior
|
||||
if sslmode == "" {
|
||||
sslmode = "prefer"
|
||||
}
|
||||
|
||||
switch sslmode {
|
||||
case "disable":
|
||||
case "allow":
|
||||
cc.UseFallbackTLS = true
|
||||
cc.FallbackTLSConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
case "prefer":
|
||||
cc.TLSConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
cc.UseFallbackTLS = true
|
||||
cc.FallbackTLSConfig = nil
|
||||
case "require", "verify-ca", "verify-full":
|
||||
cc.TLSConfig = &tls.Config{
|
||||
ServerName: cc.Host,
|
||||
}
|
||||
default:
|
||||
return cc, errors.New("sslmode is invalid")
|
||||
}
|
||||
|
||||
return cc, nil
|
||||
}
|
||||
|
||||
|
162
conn_test.go
162
conn_test.go
@ -382,14 +382,21 @@ func TestParseEnvLibpq(t *testing.T) {
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
envvars map[string]string
|
||||
config pgx.ConnConfig
|
||||
}{
|
||||
{
|
||||
name: "No environment",
|
||||
envvars: map[string]string{},
|
||||
config: pgx.ConnConfig{},
|
||||
config: pgx.ConnConfig{
|
||||
TLSConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
UseFallbackTLS: true,
|
||||
FallbackTLSConfig: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Normal PG vars",
|
||||
envvars: map[string]string{
|
||||
"PGHOST": "123.123.123.123",
|
||||
"PGPORT": "7777",
|
||||
@ -398,38 +405,169 @@ func TestParseEnvLibpq(t *testing.T) {
|
||||
"PGPASSWORD": "baz",
|
||||
},
|
||||
config: pgx.ConnConfig{
|
||||
Host: "123.123.123.123",
|
||||
Port: 7777,
|
||||
Database: "foo",
|
||||
User: "bar",
|
||||
Password: "baz",
|
||||
Host: "123.123.123.123",
|
||||
Port: 7777,
|
||||
Database: "foo",
|
||||
User: "bar",
|
||||
Password: "baz",
|
||||
TLSConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
UseFallbackTLS: true,
|
||||
FallbackTLSConfig: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sslmode=disable",
|
||||
envvars: map[string]string{
|
||||
"PGSSLMODE": "disable",
|
||||
},
|
||||
config: pgx.ConnConfig{
|
||||
TLSConfig: nil,
|
||||
UseFallbackTLS: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sslmode=allow",
|
||||
envvars: map[string]string{
|
||||
"PGSSLMODE": "allow",
|
||||
},
|
||||
config: pgx.ConnConfig{
|
||||
TLSConfig: nil,
|
||||
UseFallbackTLS: true,
|
||||
FallbackTLSConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sslmode=prefer",
|
||||
envvars: map[string]string{
|
||||
"PGSSLMODE": "prefer",
|
||||
},
|
||||
config: pgx.ConnConfig{
|
||||
TLSConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
UseFallbackTLS: true,
|
||||
FallbackTLSConfig: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sslmode=require",
|
||||
envvars: map[string]string{
|
||||
"PGSSLMODE": "require",
|
||||
},
|
||||
config: pgx.ConnConfig{
|
||||
TLSConfig: &tls.Config{},
|
||||
UseFallbackTLS: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sslmode=verify-ca",
|
||||
envvars: map[string]string{
|
||||
"PGSSLMODE": "verify-ca",
|
||||
},
|
||||
config: pgx.ConnConfig{
|
||||
TLSConfig: &tls.Config{},
|
||||
UseFallbackTLS: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sslmode=verify-full",
|
||||
envvars: map[string]string{
|
||||
"PGSSLMODE": "verify-full",
|
||||
},
|
||||
config: pgx.ConnConfig{
|
||||
TLSConfig: &tls.Config{},
|
||||
UseFallbackTLS: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sslmode=verify-full with host",
|
||||
envvars: map[string]string{
|
||||
"PGHOST": "pgx.example",
|
||||
"PGSSLMODE": "verify-full",
|
||||
},
|
||||
config: pgx.ConnConfig{
|
||||
Host: "pgx.example",
|
||||
TLSConfig: &tls.Config{
|
||||
ServerName: "pgx.example",
|
||||
},
|
||||
UseFallbackTLS: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
for _, tt := range tests {
|
||||
for _, n := range pgEnvvars {
|
||||
err := os.Unsetenv(n)
|
||||
if err != nil {
|
||||
t.Fatalf("%d. Unable to clear environment:", i, err)
|
||||
t.Fatalf("%s: Unable to clear environment:", tt.name, err)
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range tt.envvars {
|
||||
err := os.Setenv(k, v)
|
||||
if err != nil {
|
||||
t.Fatalf("%d. Unable to set environment:", i, err)
|
||||
t.Fatalf("%s: Unable to set environment:", tt.name, err)
|
||||
}
|
||||
}
|
||||
|
||||
config, err := pgx.ParseEnvLibpq()
|
||||
if err != nil {
|
||||
t.Errorf("%d. Unexpected error from pgx.ParseLibpq() => %v", i, err)
|
||||
t.Errorf("%s: Unexpected error from pgx.ParseLibpq() => %v", tt.name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(config, tt.config) {
|
||||
t.Errorf("%d. expected %#v got %#v", i, tt.config, config)
|
||||
if config.Host != tt.config.Host {
|
||||
t.Errorf("%s: expected Host to be %v got %v", tt.name, tt.config.Host, config.Host)
|
||||
}
|
||||
if config.Port != tt.config.Port {
|
||||
t.Errorf("%s: expected Port to be %v got %v", tt.name, tt.config.Port, config.Port)
|
||||
}
|
||||
if config.Port != tt.config.Port {
|
||||
t.Errorf("%s: expected Port to be %v got %v", tt.name, tt.config.Port, config.Port)
|
||||
}
|
||||
if config.User != tt.config.User {
|
||||
t.Errorf("%s: expected User to be %v got %v", tt.name, tt.config.User, config.User)
|
||||
}
|
||||
if config.Password != tt.config.Password {
|
||||
t.Errorf("%s: expected Password to be %v got %v", tt.name, tt.config.Password, config.Password)
|
||||
}
|
||||
|
||||
tlsTests := []struct {
|
||||
name string
|
||||
expected *tls.Config
|
||||
actual *tls.Config
|
||||
}{
|
||||
{
|
||||
name: "TLSConfig",
|
||||
expected: tt.config.TLSConfig,
|
||||
actual: config.TLSConfig,
|
||||
},
|
||||
{
|
||||
name: "FallbackTLSConfig",
|
||||
expected: tt.config.FallbackTLSConfig,
|
||||
actual: config.FallbackTLSConfig,
|
||||
},
|
||||
}
|
||||
for _, tlsTest := range tlsTests {
|
||||
name := tlsTest.name
|
||||
expected := tlsTest.expected
|
||||
actual := tlsTest.actual
|
||||
|
||||
if expected == nil && actual != nil {
|
||||
t.Errorf("%s / %s: expected nil, but it was set", tt.name, name)
|
||||
} else if expected != nil && actual == nil {
|
||||
t.Errorf("%s / %s: expected to be set, but got nil", tt.name, name)
|
||||
} else if expected != nil && actual != nil {
|
||||
if actual.InsecureSkipVerify != expected.InsecureSkipVerify {
|
||||
t.Errorf("%s / %s: expected InsecureSkipVerify to be %v got %v", tt.name, name, expected.InsecureSkipVerify, actual.InsecureSkipVerify)
|
||||
}
|
||||
|
||||
if actual.ServerName != expected.ServerName {
|
||||
t.Errorf("%s / %s: expected ServerName to be %v got %v", tt.name, name, expected.ServerName, actual.ServerName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if config.UseFallbackTLS != tt.config.UseFallbackTLS {
|
||||
t.Errorf("%s: expected UseFallbackTLS to be %v got %v", tt.name, tt.config.UseFallbackTLS, config.UseFallbackTLS)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user