Support pgpass

pull/247/head
j7b 2017-03-17 16:59:10 +00:00
parent ea4b3ffb14
commit 22c850e45d
3 changed files with 151 additions and 3 deletions

12
conn.go
View File

@ -445,7 +445,9 @@ func ParseURI(uri string) (ConnConfig, error) {
cp.RuntimeParams[k] = v[0] cp.RuntimeParams[k] = v[0]
} }
if cp.Password == "" {
pgpass(&cp)
}
return cp, nil return cp, nil
} }
@ -498,7 +500,9 @@ func ParseDSN(s string) (ConnConfig, error) {
if err != nil { if err != nil {
return cp, err return cp, err
} }
if cp.Password == "" {
pgpass(&cp)
}
return cp, nil return cp, nil
} }
@ -561,7 +565,9 @@ func ParseEnvLibpq() (ConnConfig, error) {
if appname := os.Getenv("PGAPPNAME"); appname != "" { if appname := os.Getenv("PGAPPNAME"); appname != "" {
cc.RuntimeParams["application_name"] = appname cc.RuntimeParams["application_name"] = appname
} }
if cc.Password == "" {
pgpass(&cc)
}
return cc, nil return cc, nil
} }

85
pgpass.go Normal file
View File

@ -0,0 +1,85 @@
package pgx
import (
"bufio"
"fmt"
"os"
"os/user"
"path/filepath"
"strings"
)
func parsepgpass(cfg *ConnConfig, line string) *string {
const (
backslash = "\r"
colon = "\n"
)
const (
host int = iota
port
database
username
pw
)
line = strings.Replace(line, `\:`, colon, -1)
line = strings.Replace(line, `\\`, backslash, -1)
parts := strings.Split(line, `:`)
if len(parts) != 5 {
return nil
}
for i := range parts {
if parts[i] == `*` {
continue
}
parts[i] = strings.Replace(strings.Replace(parts[i], backslash, `\`, -1), colon, `:`, -1)
switch i {
case host:
if parts[i] != cfg.Host {
return nil
}
case port:
portstr := fmt.Sprintf(`%v`, cfg.Port)
if portstr == "0" {
portstr = "5432"
}
if parts[i] != portstr {
return nil
}
case database:
if parts[i] != cfg.Database {
return nil
}
case username:
if parts[i] != cfg.User {
return nil
}
}
}
return &parts[4]
}
func pgpass(cfg *ConnConfig) (found bool) {
passfile := os.Getenv("PGPASSFILE")
if passfile == "" {
u, err := user.Current()
if err != nil {
return
}
passfile = filepath.Join(u.HomeDir, ".pgpass")
}
f, err := os.Open(passfile)
if err != nil {
return
}
defer f.Close()
scanner := bufio.NewScanner(f)
var pw *string
for scanner.Scan() {
pw = parsepgpass(cfg, scanner.Text())
if pw != nil {
cfg.Password = *pw
return true
}
}
return false
}

57
pgpass_test.go Normal file
View File

@ -0,0 +1,57 @@
package pgx
import (
"fmt"
"io/ioutil"
"os"
"strings"
"testing"
)
func unescape(s string) string {
s = strings.Replace(s, `\:`, `:`, -1)
s = strings.Replace(s, `\\`, `\`, -1)
return s
}
var passfile = [][]string{
[]string{"test1", "5432", "larrydb", "larry", "whatstheidea"},
[]string{"test1", "5432", "moedb", "moe", "imbecile"},
[]string{"test1", "5432", "curlydb", "curly", "nyuknyuknyuk"},
[]string{"test2", "5432", "*", "shemp", "heymoe"},
[]string{"test2", "5432", "*", "*", `test\\ing\:`},
}
func TestPGPass(t *testing.T) {
tf, err := ioutil.TempFile("", "")
if err != nil {
t.Fatal(err)
}
defer tf.Close()
defer os.Remove(tf.Name())
os.Setenv("PGPASSFILE", tf.Name())
for _, l := range passfile {
_, err := fmt.Fprintln(tf, strings.Join(l, `:`))
if err != nil {
t.Fatal(err)
}
}
if err = tf.Close(); err != nil {
t.Fatal(err)
}
for i, l := range passfile {
cfg := ConnConfig{Host: l[0], Database: l[2], User: l[3]}
found := pgpass(&cfg)
if !found {
t.Fatalf("Entry %v not found", i)
}
if cfg.Password != unescape(l[4]) {
t.Fatalf(`Password mismatch entry %v want %s got %s`, i, unescape(l[4]), cfg.Password)
}
}
cfg := ConnConfig{Host: "derp", Database: "herp", User: "joe"}
found := pgpass(&cfg)
if found {
t.Fatal("bad found")
}
}