diff --git a/config.go b/config.go index 5cee9297..6e6930ee 100644 --- a/config.go +++ b/config.go @@ -257,6 +257,8 @@ func ParseConfig(connString string) (*Config, error) { "sslkey": {}, "sslcert": {}, "sslrootcert": {}, + "krbspn": {}, + "krbsrvname": {}, "target_session_attrs": {}, "min_read_buffer_size": {}, "service": {}, diff --git a/go.mod b/go.mod index fb3ed181..2a2d6810 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.1.1 + github.com/jackc/pgproto3/v2 v2.2.1-0.20220412121321-175856ffd3c8 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/stretchr/testify v1.7.0 golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 diff --git a/go.sum b/go.sum index bdb5ee8c..c558564b 100644 --- a/go.sum +++ b/go.sum @@ -31,8 +31,9 @@ github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.1.1 h1:7PQ/4gLoqnl87ZxL7xjO0DR5gYuviDCZxQJsUlFW1eI= github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.2.1-0.20220412121321-175856ffd3c8 h1:KxsCQec+1iwJXtxnbbS/dY0EJ6rJEUlFsrJUnL5A2XI= +github.com/jackc/pgproto3/v2 v2.2.1-0.20220412121321-175856ffd3c8/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= @@ -112,7 +113,6 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= diff --git a/krb5.go b/krb5.go new file mode 100644 index 00000000..1f9ce97c --- /dev/null +++ b/krb5.go @@ -0,0 +1,94 @@ +package pgconn + +import ( + "errors" + "github.com/jackc/pgproto3/v2" +) + +// NewGSSFunc creates a GSS authentication provider, for use with +// RegisterGSSProvider. +type NewGSSFunc func() (GSS, error) + +var newGSS NewGSSFunc + +// RegisterGSSProvider registers a GSS authentication provider. For example, if +// you need to use Kerberos to authenticate with your server, add this to your +// main package: +// +// import "github.com/otan/gopgkrb5" +// +// func init() { +// pgconn.RegisterGSSProvider(func() (pgconn.GSS, error) { return gopgkrb5.NewGSS() }) +// } +func RegisterGSSProvider(newGSSArg NewGSSFunc) { + newGSS = newGSSArg +} + +// GSS provides GSSAPI authentication (e.g., Kerberos). +type GSS interface { + GetInitToken(host string, service string) ([]byte, error) + GetInitTokenFromSPN(spn string) ([]byte, error) + Continue(inToken []byte) (done bool, outToken []byte, err error) +} + +func (c *PgConn) gssAuth() error { + if newGSS == nil { + return errors.New("kerberos error: no GSSAPI provider registered, see https://github.com/otan/gopgkrb5") + } + cli, err := newGSS() + if err != nil { + return err + } + + var nextData []byte + if spn, ok := c.config.RuntimeParams["krbspn"]; ok { + // Use the supplied SPN if provided. + nextData, err = cli.GetInitTokenFromSPN(spn) + } else { + // Allow the kerberos service name to be overridden + service := "postgres" + if val, ok := c.config.RuntimeParams["krbsrvname"]; ok { + service = val + } + nextData, err = cli.GetInitToken(c.config.Host, service) + } + if err != nil { + return err + } + + for { + gssResponse := &pgproto3.GSSResponse{ + Data: nextData, + } + _, err = c.conn.Write(gssResponse.Encode(nil)) + if err != nil { + return err + } + resp, err := c.rxGSSContinue() + if err != nil { + return err + } + var done bool + done, nextData, err = cli.Continue(resp.Data) + if err != nil { + return err + } + if done { + break + } + } + return nil +} + +func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) { + msg, err := c.receiveMessage() + if err != nil { + return nil, err + } + gssContinue, ok := msg.(*pgproto3.AuthenticationGSSContinue) + if ok { + return gssContinue, nil + } + + return nil, errors.New("expected AuthenticationGSSContinue message but received unexpected message") +} diff --git a/pgconn.go b/pgconn.go index 9a496ed0..0d07ac57 100644 --- a/pgconn.go +++ b/pgconn.go @@ -320,7 +320,12 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.conn.Close() return nil, &connectError{config: config, msg: "failed SASL auth", err: err} } - + case *pgproto3.AuthenticationGSS: + err = pgConn.gssAuth() + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "failed GSS auth", err: err} + } case *pgproto3.ReadyForQuery: pgConn.status = connStatusIdle if config.ValidateConnect != nil {