From 22c850e45de8dd7db8606774b37459fde861ec9d Mon Sep 17 00:00:00 2001 From: j7b Date: Fri, 17 Mar 2017 16:59:10 +0000 Subject: [PATCH] Support pgpass --- conn.go | 12 +++++-- pgpass.go | 85 ++++++++++++++++++++++++++++++++++++++++++++++++++ pgpass_test.go | 57 +++++++++++++++++++++++++++++++++ 3 files changed, 151 insertions(+), 3 deletions(-) create mode 100644 pgpass.go create mode 100644 pgpass_test.go diff --git a/conn.go b/conn.go index 602ecbff..a8f1e386 100644 --- a/conn.go +++ b/conn.go @@ -445,7 +445,9 @@ func ParseURI(uri string) (ConnConfig, error) { cp.RuntimeParams[k] = v[0] } - + if cp.Password == "" { + pgpass(&cp) + } return cp, nil } @@ -498,7 +500,9 @@ func ParseDSN(s string) (ConnConfig, error) { if err != nil { return cp, err } - + if cp.Password == "" { + pgpass(&cp) + } return cp, nil } @@ -561,7 +565,9 @@ func ParseEnvLibpq() (ConnConfig, error) { if appname := os.Getenv("PGAPPNAME"); appname != "" { cc.RuntimeParams["application_name"] = appname } - + if cc.Password == "" { + pgpass(&cc) + } return cc, nil } diff --git a/pgpass.go b/pgpass.go new file mode 100644 index 00000000..b6f028d2 --- /dev/null +++ b/pgpass.go @@ -0,0 +1,85 @@ +package pgx + +import ( + "bufio" + "fmt" + "os" + "os/user" + "path/filepath" + "strings" +) + +func parsepgpass(cfg *ConnConfig, line string) *string { + const ( + backslash = "\r" + colon = "\n" + ) + const ( + host int = iota + port + database + username + pw + ) + line = strings.Replace(line, `\:`, colon, -1) + line = strings.Replace(line, `\\`, backslash, -1) + parts := strings.Split(line, `:`) + if len(parts) != 5 { + return nil + } + for i := range parts { + if parts[i] == `*` { + continue + } + parts[i] = strings.Replace(strings.Replace(parts[i], backslash, `\`, -1), colon, `:`, -1) + switch i { + case host: + if parts[i] != cfg.Host { + return nil + } + case port: + portstr := fmt.Sprintf(`%v`, cfg.Port) + if portstr == "0" { + portstr = "5432" + } + if parts[i] != portstr { + return nil + } + case database: + if parts[i] != cfg.Database { + return nil + } + case username: + if parts[i] != cfg.User { + return nil + } + } + } + return &parts[4] +} + +func pgpass(cfg *ConnConfig) (found bool) { + passfile := os.Getenv("PGPASSFILE") + if passfile == "" { + u, err := user.Current() + if err != nil { + return + } + passfile = filepath.Join(u.HomeDir, ".pgpass") + } + f, err := os.Open(passfile) + if err != nil { + return + } + defer f.Close() + scanner := bufio.NewScanner(f) + var pw *string + for scanner.Scan() { + pw = parsepgpass(cfg, scanner.Text()) + if pw != nil { + cfg.Password = *pw + return true + } + } + return false +} diff --git a/pgpass_test.go b/pgpass_test.go new file mode 100644 index 00000000..f6094c82 --- /dev/null +++ b/pgpass_test.go @@ -0,0 +1,57 @@ +package pgx + +import ( + "fmt" + "io/ioutil" + "os" + "strings" + "testing" +) + +func unescape(s string) string { + s = strings.Replace(s, `\:`, `:`, -1) + s = strings.Replace(s, `\\`, `\`, -1) + return s +} + +var passfile = [][]string{ + []string{"test1", "5432", "larrydb", "larry", "whatstheidea"}, + []string{"test1", "5432", "moedb", "moe", "imbecile"}, + []string{"test1", "5432", "curlydb", "curly", "nyuknyuknyuk"}, + []string{"test2", "5432", "*", "shemp", "heymoe"}, + []string{"test2", "5432", "*", "*", `test\\ing\:`}, +} + +func TestPGPass(t *testing.T) { + tf, err := ioutil.TempFile("", "") + if err != nil { + t.Fatal(err) + } + defer tf.Close() + defer os.Remove(tf.Name()) + os.Setenv("PGPASSFILE", tf.Name()) + for _, l := range passfile { + _, err := fmt.Fprintln(tf, strings.Join(l, `:`)) + if err != nil { + t.Fatal(err) + } + } + if err = tf.Close(); err != nil { + t.Fatal(err) + } + for i, l := range passfile { + cfg := ConnConfig{Host: l[0], Database: l[2], User: l[3]} + found := pgpass(&cfg) + if !found { + t.Fatalf("Entry %v not found", i) + } + if cfg.Password != unescape(l[4]) { + t.Fatalf(`Password mismatch entry %v want %s got %s`, i, unescape(l[4]), cfg.Password) + } + } + cfg := ConnConfig{Host: "derp", Database: "herp", User: "joe"} + found := pgpass(&cfg) + if found { + t.Fatal("bad found") + } +}