mirror of
https://github.com/jackc/pgx.git
synced 2025-04-27 21:25:53 +00:00
This commit adds a background reader that can optionally buffer reads. It is used whenever a potentially blocking write is made to the server. The background reader is started on a slight delay so there should be no meaningful performance impact as it doesn't run for quick queries and its overhead is minimal relative to slower queries.
101 lines
2.3 KiB
Go
101 lines
2.3 KiB
Go
package pgconn
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
|
|
"github.com/jackc/pgx/v5/pgproto3"
|
|
)
|
|
|
|
// NewGSSFunc creates a GSS authentication provider, for use with
|
|
// RegisterGSSProvider.
|
|
type NewGSSFunc func() (GSS, error)
|
|
|
|
var newGSS NewGSSFunc
|
|
|
|
// RegisterGSSProvider registers a GSS authentication provider. For example, if
|
|
// you need to use Kerberos to authenticate with your server, add this to your
|
|
// main package:
|
|
//
|
|
// import "github.com/otan/gopgkrb5"
|
|
//
|
|
// func init() {
|
|
// pgconn.RegisterGSSProvider(func() (pgconn.GSS, error) { return gopgkrb5.NewGSS() })
|
|
// }
|
|
func RegisterGSSProvider(newGSSArg NewGSSFunc) {
|
|
newGSS = newGSSArg
|
|
}
|
|
|
|
// GSS is implemented by GSSAPI authentication providers (e.g., Kerberos)
// that are registered through RegisterGSSProvider.
//
// During authentication, exactly one of the two GetInitToken* methods is
// called to produce the first token sent to the server. Continue is then
// called with each token the server returns until it reports done.
type GSS interface {
	// GetInitToken returns the initial token for the given server host and
	// Kerberos service name.
	GetInitToken(host, service string) ([]byte, error)

	// GetInitTokenFromSPN returns the initial token for an explicitly
	// configured service principal name.
	GetInitTokenFromSPN(spn string) ([]byte, error)

	// Continue consumes a token received from the server. It reports whether
	// the security context is established and returns the next token for the
	// server, if any.
	Continue(inToken []byte) (done bool, outToken []byte, err error)
}
|
|
|
|
// gssAuth performs the client side of GSSAPI (e.g., Kerberos) authentication
// using the provider installed with RegisterGSSProvider.
//
// The initial token comes from c.config.KerberosSpn when it is set. Otherwise
// it is derived from c.config.Host and the Kerberos service name
// (c.config.KerberosSrvName, defaulting to "postgres"). Each token is sent to
// the server as a GSSResponse, and each AuthenticationGSSContinue reply is fed
// to the provider's Continue until the provider reports done.
//
// Following the GSSAPI convention for gss_init_sec_context, a non-empty output
// token returned together with done == true must still reach the server. It is
// therefore sent before returning. No message is read after the exchange
// finishes; whatever the server sends next is left to the caller.
func (c *PgConn) gssAuth() error {
	if newGSS == nil {
		return errors.New("kerberos error: no GSSAPI provider registered, see https://github.com/otan/gopgkrb5")
	}
	cli, err := newGSS()
	if err != nil {
		return err
	}

	var nextData []byte
	if c.config.KerberosSpn != "" {
		// Use the supplied SPN if provided.
		nextData, err = cli.GetInitTokenFromSPN(c.config.KerberosSpn)
	} else {
		// Allow the Kerberos service name to be overridden.
		service := "postgres"
		if c.config.KerberosSrvName != "" {
			service = c.config.KerberosSrvName
		}
		nextData, err = cli.GetInitToken(c.config.Host, service)
	}
	if err != nil {
		return err
	}

	// sendToken queues a single GSSResponse carrying data and flushes it.
	sendToken := func(data []byte) error {
		c.frontend.Send(&pgproto3.GSSResponse{Data: data})
		return c.flushWithPotentialWriteReadDeadlock()
	}

	for {
		if err := sendToken(nextData); err != nil {
			return err
		}

		resp, err := c.rxGSSContinue()
		if err != nil {
			return err
		}

		done, outToken, err := cli.Continue(resp.Data)
		if err != nil {
			return err
		}
		if done {
			// The context is established on the client side, but the provider
			// may still have produced a final token the server needs in order
			// to finish its side. Previously such a token was dropped.
			if len(outToken) > 0 {
				return sendToken(outToken)
			}
			return nil
		}
		nextData = outToken
	}
}
|
|
|
|
// rxGSSContinue reads the next backend message during GSS authentication.
// An AuthenticationGSSContinue message is returned as is. An ErrorResponse is
// converted into a *PgError via ErrorResponseToPgError. Any other message type
// yields an error naming the unexpected type. Errors from receiveMessage are
// returned unchanged.
func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) {
	received, err := c.receiveMessage()
	if err != nil {
		return nil, err
	}

	switch typed := received.(type) {
	case *pgproto3.ErrorResponse:
		return nil, ErrorResponseToPgError(typed)
	case *pgproto3.AuthenticationGSSContinue:
		return typed, nil
	default:
		return nil, fmt.Errorf("expected AuthenticationGSSContinue message but received unexpected message %T", received)
	}
}
|