diff --git a/pgpass.go b/pgpass.go index b6f028d2..34b9bdf5 100644 --- a/pgpass.go +++ b/pgpass.go @@ -9,7 +9,7 @@ import ( "strings" ) -func parsepgpass(cfg *ConnConfig, line string) *string { +func parsepgpass(line, cfgHost, cfgPort, cfgDatabase, cfgUsername string) *string { const ( backslash = "\r" colon = "\n" @@ -21,6 +21,9 @@ func parsepgpass(cfg *ConnConfig, line string) *string { username pw ) + if strings.HasPrefix(line, "#") { + return nil + } line = strings.Replace(line, `\:`, colon, -1) line = strings.Replace(line, `\\`, backslash, -1) parts := strings.Split(line, `:`) @@ -34,23 +37,19 @@ func parsepgpass(cfg *ConnConfig, line string) *string { parts[i] = strings.Replace(strings.Replace(parts[i], backslash, `\`, -1), colon, `:`, -1) switch i { case host: - if parts[i] != cfg.Host { + if parts[i] != cfgHost { return nil } case port: - portstr := fmt.Sprintf(`%v`, cfg.Port) - if portstr == "0" { - portstr = "5432" - } - if parts[i] != portstr { + if parts[i] != cfgPort { return nil } case database: - if parts[i] != cfg.Database { + if parts[i] != cfgDatabase { return nil } case username: - if parts[i] != cfg.User { + if parts[i] != cfgUsername { return nil } } @@ -72,10 +71,32 @@ func pgpass(cfg *ConnConfig) (found bool) { return } defer f.Close() + + host := cfg.Host + if _, err := os.Stat(host); err == nil { + host = "localhost" + } + port := fmt.Sprintf(`%v`, cfg.Port) + if port == "0" { + port = "5432" + } + username := cfg.User + if username == "" { + user, err := user.Current() + if err != nil { + return + } + username = user.Username + } + database := cfg.Database + if database == "" { + database = username + } + scanner := bufio.NewScanner(f) var pw *string for scanner.Scan() { - pw = parsepgpass(cfg, scanner.Text()) + pw = parsepgpass(scanner.Text(), host, port, database, username) if pw != nil { cfg.Password = *pw return true diff --git a/pgpass_test.go b/pgpass_test.go index d36e811a..2c63f130 100644 --- a/pgpass_test.go +++ b/pgpass_test.go @@ -4,6 +4,7 @@ import ( "fmt" "io/ioutil" "os" + "os/user" "strings" "testing" ) @@ -20,6 +21,8 @@ var passfile = [][]string{ {"test1", "5432", "curlydb", "curly", "nyuknyuknyuk"}, {"test2", "5432", "*", "shemp", "heymoe"}, {"test2", "5432", "*", "*", `test\\ing\:`}, + {"localhost", "*", "*", "*", "sesam"}, + {"test3", "*", "", "", "swordfish"}, // user will be filled later } func TestPGPass(t *testing.T) { @@ -27,9 +30,20 @@ func TestPGPass(t *testing.T) { if err != nil { t.Fatal(err) } + user, err := user.Current() + if err != nil { + t.Fatal(err) + } + passfile[len(passfile)-1][2] = user.Username + passfile[len(passfile)-1][3] = user.Username + defer tf.Close() defer os.Remove(tf.Name()) os.Setenv("PGPASSFILE", tf.Name()) + _, err = fmt.Fprintln(tf, "#some comment\n\n#more comment") + if err != nil { + t.Fatal(err) + } for _, l := range passfile { _, err := fmt.Fprintln(tf, strings.Join(l, `:`)) if err != nil { @@ -48,9 +62,28 @@ func TestPGPass(t *testing.T) { if cfg.Password != unescape(l[4]) { t.Fatalf(`Password mismatch entry %v want %s got %s`, i, unescape(l[4]), cfg.Password) } + if l[0] == "localhost" { + // using some existing path as socket + cfg := ConnConfig{Host: tf.Name(), Database: l[2], User: l[3]} + found := pgpass(&cfg) + if !found { + t.Fatalf("Entry %v not found", i) + } + if cfg.Password != unescape(l[4]) { + t.Fatalf(`Password mismatch entry %v want %s got %s`, i, unescape(l[4]), cfg.Password) + } + } } - cfg := ConnConfig{Host: "derp", Database: "herp", User: "joe"} + cfg := ConnConfig{Host: "test3"} found := pgpass(&cfg) + if !found { + t.Fatalf("Entry for default user name") + } + if cfg.Password != "swordfish" { + t.Fatalf(`Password mismatch for default user entry, want %s got %s`, "swordfish", cfg.Password) + } + cfg = ConnConfig{Host: "derp", Database: "herp", User: "joe"} + found = pgpass(&cfg) if found { t.Fatal("bad found") }