mirror of https://github.com/harness/drone.git
feat: [CODE-2865]: ssh support changes (#3052)
parent
8b06e30bb9
commit
64d66772d4
|
@ -22,6 +22,8 @@ import (
|
||||||
|
|
||||||
"github.com/harness/gitness/app/auth"
|
"github.com/harness/gitness/app/auth"
|
||||||
"github.com/harness/gitness/types"
|
"github.com/harness/gitness/types"
|
||||||
|
|
||||||
|
"github.com/gliderlabs/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
type key int
|
type key int
|
||||||
|
@ -119,3 +121,7 @@ func RequestIDFrom(ctx context.Context) (string, bool) {
|
||||||
v, ok := ctx.Value(requestIDKey).(string)
|
v, ok := ctx.Value(requestIDKey).(string)
|
||||||
return v, ok && v != ""
|
return v, ok && v != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithRequestIDSSH(parent ssh.Context, v string) {
|
||||||
|
ssh.Context.SetValue(parent, requestIDKey, v)
|
||||||
|
}
|
||||||
|
|
|
@ -28,7 +28,11 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Service interface {
|
type Service interface {
|
||||||
ValidateKey(ctx context.Context, publicKey ssh.PublicKey, usage enum.PublicKeyUsage) (*types.PrincipalInfo, error)
|
ValidateKey(ctx context.Context,
|
||||||
|
username string,
|
||||||
|
publicKey ssh.PublicKey,
|
||||||
|
usage enum.PublicKeyUsage,
|
||||||
|
) (*types.PrincipalInfo, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService(
|
func NewService(
|
||||||
|
@ -50,6 +54,7 @@ type LocalService struct {
|
||||||
// It updates the verified timestamp of the matched key to mark it as used.
|
// It updates the verified timestamp of the matched key to mark it as used.
|
||||||
func (s LocalService) ValidateKey(
|
func (s LocalService) ValidateKey(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
_ string,
|
||||||
publicKey ssh.PublicKey,
|
publicKey ssh.PublicKey,
|
||||||
usage enum.PublicKeyUsage,
|
usage enum.PublicKeyUsage,
|
||||||
) (*types.PrincipalInfo, error) {
|
) (*types.PrincipalInfo, error) {
|
||||||
|
|
|
@ -0,0 +1,33 @@
|
||||||
|
// Copyright 2023 Harness, Inc.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package ssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
const loggerKey contextKey = "logger"
|
||||||
|
|
||||||
|
func getRequestID(reqID string) string {
|
||||||
|
if len(reqID) > 20 {
|
||||||
|
reqID = reqID[:20]
|
||||||
|
}
|
||||||
|
return reqID
|
||||||
|
}
|
||||||
|
|
||||||
|
func getLoggerWithRequestID(sessionID string) zerolog.Logger {
|
||||||
|
return log.Logger.With().Str("request_id", getRequestID(sessionID)).Logger()
|
||||||
|
}
|
|
@ -0,0 +1,138 @@
|
||||||
|
// Copyright 2023 Harness, Inc.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package ssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"runtime/debug"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/harness/gitness/app/api/request"
|
||||||
|
|
||||||
|
"github.com/gliderlabs/ssh"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Middleware func(ssh.Handler) ssh.Handler
|
||||||
|
|
||||||
|
// ChainMiddleware combines multiple middleware into a single ssh.Handler.
|
||||||
|
func ChainMiddleware(handler ssh.Handler, middlewares ...Middleware) ssh.Handler {
|
||||||
|
for i := len(middlewares) - 1; i >= 0; i-- { // Reverse order to maintain correct chaining
|
||||||
|
handler = middlewares[i](handler)
|
||||||
|
}
|
||||||
|
return handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// PanicRecoverMiddleware wraps the SSH handler to recover from panics and log them.
|
||||||
|
func PanicRecoverMiddleware(next ssh.Handler) ssh.Handler {
|
||||||
|
return func(s ssh.Session) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
// Log the panic and stack trace
|
||||||
|
// Get the context and logger
|
||||||
|
ctx := s.Context()
|
||||||
|
logger := getLogger(ctx)
|
||||||
|
logger.Error().Msgf("encountered panic while processing ssh operation: %v\n%s", r, debug.Stack())
|
||||||
|
_, _ = s.Write([]byte("Internal server error. Please try again later.\n"))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Call the next handler
|
||||||
|
next(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func HLogAccessLogHandler(next ssh.Handler) ssh.Handler {
|
||||||
|
return func(s ssh.Session) {
|
||||||
|
start := time.Now()
|
||||||
|
user := s.User()
|
||||||
|
remoteAddr := s.RemoteAddr()
|
||||||
|
command := s.Command()
|
||||||
|
|
||||||
|
// Get the context and logger
|
||||||
|
ctx := s.Context()
|
||||||
|
logger := getLogger(ctx)
|
||||||
|
// Log session start
|
||||||
|
logger.Info().
|
||||||
|
Str("ssh.user", user).
|
||||||
|
Str("ssh.remote", remoteAddr.String()).
|
||||||
|
Strs("ssh.command", command).
|
||||||
|
Msg("SSH session started")
|
||||||
|
|
||||||
|
// Call the next handler
|
||||||
|
next(s)
|
||||||
|
|
||||||
|
// Log session completion
|
||||||
|
duration := time.Since(start)
|
||||||
|
logger.Info().
|
||||||
|
Dur("ssh.elapsed_ms", duration).
|
||||||
|
Str("ssh.user", user).
|
||||||
|
Msg("SSH session completed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func HLogRequestIDHandler(next ssh.Handler) ssh.Handler {
|
||||||
|
return func(s ssh.Session) {
|
||||||
|
sshCtx := s.Context() // This is ssh.Context
|
||||||
|
reqID := getRequestID(sshCtx.SessionID())
|
||||||
|
request.WithRequestIDSSH(sshCtx, reqID)
|
||||||
|
|
||||||
|
log := getLoggerWithRequestID(reqID)
|
||||||
|
sshCtx.SetValue(loggerKey, log)
|
||||||
|
|
||||||
|
// continue serving request
|
||||||
|
next(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type PublicKeyMiddleware func(next ssh.PublicKeyHandler) ssh.PublicKeyHandler
|
||||||
|
|
||||||
|
func ChainPublicKeyMiddleware(handler ssh.PublicKeyHandler, middlewares ...PublicKeyMiddleware) ssh.PublicKeyHandler {
|
||||||
|
for i := len(middlewares) - 1; i >= 0; i-- { // Reverse order for correct chaining
|
||||||
|
handler = middlewares[i](handler)
|
||||||
|
}
|
||||||
|
return handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func LogPublicKeyMiddleware(next ssh.PublicKeyHandler) ssh.PublicKeyHandler {
|
||||||
|
return func(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||||
|
reqID := getRequestID(ctx.SessionID())
|
||||||
|
request.WithRequestIDSSH(ctx, reqID)
|
||||||
|
log := getLoggerWithRequestID(reqID)
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
log.Info().
|
||||||
|
Str("ssh.user", ctx.User()).
|
||||||
|
Str("ssh.remote", ctx.RemoteAddr().String()).
|
||||||
|
Msg("Public key authentication attempt")
|
||||||
|
|
||||||
|
v := next(ctx, key)
|
||||||
|
// Log session completion
|
||||||
|
duration := time.Since(start)
|
||||||
|
log.Info().
|
||||||
|
Dur("ssh.elapsed_ms", duration).
|
||||||
|
Str("ssh.user", ctx.User()).
|
||||||
|
Msg("Public key authentication attempt completed")
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getLogger(ctx ssh.Context) zerolog.Logger {
|
||||||
|
logger, ok := ctx.Value(loggerKey).(zerolog.Logger)
|
||||||
|
if !ok {
|
||||||
|
logger = log.Logger
|
||||||
|
}
|
||||||
|
return logger
|
||||||
|
}
|
|
@ -31,6 +31,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/harness/gitness/app/api/controller/repo"
|
"github.com/harness/gitness/app/api/controller/repo"
|
||||||
|
"github.com/harness/gitness/app/api/request"
|
||||||
"github.com/harness/gitness/app/auth"
|
"github.com/harness/gitness/app/auth"
|
||||||
"github.com/harness/gitness/app/services/publickey"
|
"github.com/harness/gitness/app/services/publickey"
|
||||||
"github.com/harness/gitness/errors"
|
"github.com/harness/gitness/errors"
|
||||||
|
@ -76,7 +77,6 @@ var (
|
||||||
"hmac-sha2-256",
|
"hmac-sha2-256",
|
||||||
"hmac-sha2-512",
|
"hmac-sha2-512",
|
||||||
}
|
}
|
||||||
defaultServerKeyPath = "ssh/gitness.rsa"
|
|
||||||
KeepAliveMsg = "keepalive@openssh.com"
|
KeepAliveMsg = "keepalive@openssh.com"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -97,6 +97,8 @@ type Server struct {
|
||||||
|
|
||||||
Verifier publickey.Service
|
Verifier publickey.Service
|
||||||
RepoCtrl *repo.Controller
|
RepoCtrl *repo.Controller
|
||||||
|
|
||||||
|
ServerKeyPath string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) sanitize() error {
|
func (s *Server) sanitize() error {
|
||||||
|
@ -133,8 +135,16 @@ func (s *Server) ListenAndServe() error {
|
||||||
}
|
}
|
||||||
s.internal = &ssh.Server{
|
s.internal = &ssh.Server{
|
||||||
Addr: net.JoinHostPort(s.Host, strconv.Itoa(s.Port)),
|
Addr: net.JoinHostPort(s.Host, strconv.Itoa(s.Port)),
|
||||||
Handler: s.sessionHandler,
|
Handler: ChainMiddleware(
|
||||||
PublicKeyHandler: s.publicKeyHandler,
|
s.sessionHandler,
|
||||||
|
PanicRecoverMiddleware,
|
||||||
|
HLogRequestIDHandler,
|
||||||
|
HLogAccessLogHandler,
|
||||||
|
),
|
||||||
|
PublicKeyHandler: ChainPublicKeyMiddleware(
|
||||||
|
s.publicKeyHandler,
|
||||||
|
LogPublicKeyMiddleware,
|
||||||
|
),
|
||||||
PtyCallback: func(ssh.Context, ssh.Pty) bool {
|
PtyCallback: func(ssh.Context, ssh.Pty) bool {
|
||||||
return false
|
return false
|
||||||
},
|
},
|
||||||
|
@ -147,7 +157,6 @@ func (s *Server) ListenAndServe() error {
|
||||||
return config
|
return config
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
err = s.setupHostKeys()
|
err = s.setupHostKeys()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to setup host keys: %w", err)
|
return fmt.Errorf("failed to setup host keys: %w", err)
|
||||||
|
@ -173,11 +182,11 @@ func (s *Server) setupHostKeys() error {
|
||||||
|
|
||||||
if len(keys) == 0 {
|
if len(keys) == 0 {
|
||||||
log.Debug().Msg("no host key provided - setup default key if it doesn't exist yet")
|
log.Debug().Msg("no host key provided - setup default key if it doesn't exist yet")
|
||||||
err := createKeyIfNotExists(defaultServerKeyPath)
|
err := createKeyIfNotExists(s.ServerKeyPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to setup default key %q: %w", defaultServerKeyPath, err)
|
return fmt.Errorf("failed to setup default key %q: %w", s.ServerKeyPath, err)
|
||||||
}
|
}
|
||||||
keys = append(keys, defaultServerKeyPath)
|
keys = append(keys, s.ServerKeyPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
// set keys to internal ssh server
|
// set keys to internal ssh server
|
||||||
|
@ -247,12 +256,14 @@ func (s *Server) sessionHandler(session ssh.Session) {
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(session.Context())
|
ctx, cancel := context.WithCancel(session.Context())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
log := log.Logger.With().Logger()
|
||||||
|
ctx = request.WithRequestID(ctx, getRequestID(session.Context().SessionID()))
|
||||||
|
ctx = log.WithContext(ctx)
|
||||||
|
|
||||||
// set keep alive connection
|
// set keep alive connection
|
||||||
if s.KeepAliveInterval > 0 {
|
if s.KeepAliveInterval > 0 {
|
||||||
go sendKeepAliveMsg(ctx, session, s.KeepAliveInterval)
|
go sendKeepAliveMsg(ctx, session, s.KeepAliveInterval)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = s.RepoCtrl.GitServicePack(
|
err = s.RepoCtrl.GitServicePack(
|
||||||
ctx,
|
ctx,
|
||||||
&auth.Session{
|
&auth.Session{
|
||||||
|
@ -273,6 +284,7 @@ func (s *Server) sessionHandler(session ssh.Session) {
|
||||||
Stdin: session,
|
Stdin: session,
|
||||||
Stderr: session.Stderr(),
|
Stderr: session.Stderr(),
|
||||||
Protocol: gitProtocol,
|
Protocol: gitProtocol,
|
||||||
|
StatelessRPC: false,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -304,6 +316,9 @@ func sendKeepAliveMsg(ctx context.Context, session ssh.Session, interval time.Du
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
func (s *Server) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||||
|
log := getLoggerWithRequestID(ctx.SessionID())
|
||||||
|
request.WithRequestIDSSH(ctx, getRequestID(ctx.SessionID()))
|
||||||
|
|
||||||
if slices.Contains(publickey.DisallowedTypes, key.Type()) {
|
if slices.Contains(publickey.DisallowedTypes, key.Type()) {
|
||||||
log.Warn().Msgf("public key type not supported: %s", key.Type())
|
log.Warn().Msgf("public key type not supported: %s", key.Type())
|
||||||
return false
|
return false
|
||||||
|
@ -316,7 +331,7 @@ func (s *Server) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
principal, err := s.Verifier.ValidateKey(ctx, key, enum.PublicKeyUsageAuth)
|
principal, err := s.Verifier.ValidateKey(ctx, ctx.User(), key, enum.PublicKeyUsageAuth)
|
||||||
if errors.IsNotFound(err) {
|
if errors.IsNotFound(err) {
|
||||||
log.Debug().Err(err).Msg("public key is unknown")
|
log.Debug().Err(err).Msg("public key is unknown")
|
||||||
return false
|
return false
|
||||||
|
@ -325,6 +340,7 @@ func (s *Server) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||||
log.Warn().Err(err).Msg("failed to validate public key")
|
log.Warn().Err(err).Msg("failed to validate public key")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
log.Debug().Msg("public key verified")
|
||||||
|
|
||||||
// check if we have a certificate
|
// check if we have a certificate
|
||||||
if cert, ok := key.(*gossh.Certificate); ok {
|
if cert, ok := key.(*gossh.Certificate); ok {
|
||||||
|
|
|
@ -28,7 +28,7 @@ var WireSet = wire.NewSet(
|
||||||
|
|
||||||
func ProvideServer(
|
func ProvideServer(
|
||||||
config *types.Config,
|
config *types.Config,
|
||||||
vierifier publickey.Service,
|
verifier publickey.Service,
|
||||||
repoctrl *repo.Controller,
|
repoctrl *repo.Controller,
|
||||||
) *Server {
|
) *Server {
|
||||||
return &Server{
|
return &Server{
|
||||||
|
@ -42,7 +42,8 @@ func ProvideServer(
|
||||||
TrustedUserCAKeys: config.SSH.TrustedUserCAKeys,
|
TrustedUserCAKeys: config.SSH.TrustedUserCAKeys,
|
||||||
TrustedUserCAKeysParsed: config.SSH.TrustedUserCAKeysParsed,
|
TrustedUserCAKeysParsed: config.SSH.TrustedUserCAKeysParsed,
|
||||||
KeepAliveInterval: config.SSH.KeepAliveInterval,
|
KeepAliveInterval: config.SSH.KeepAliveInterval,
|
||||||
Verifier: vierifier,
|
Verifier: verifier,
|
||||||
RepoCtrl: repoctrl,
|
RepoCtrl: repoctrl,
|
||||||
|
ServerKeyPath: config.SSH.ServerKeyPath,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -155,6 +155,7 @@ type Config struct {
|
||||||
TrustedUserCAKeysFile string `envconfig:"GITNESS_SSH_TRUSTED_USER_CA_KEYS_FILENAME"`
|
TrustedUserCAKeysFile string `envconfig:"GITNESS_SSH_TRUSTED_USER_CA_KEYS_FILENAME"`
|
||||||
TrustedUserCAKeysParsed []gossh.PublicKey
|
TrustedUserCAKeysParsed []gossh.PublicKey
|
||||||
KeepAliveInterval time.Duration `envconfig:"GITNESS_SSH_KEEP_ALIVE_INTERVAL" default:"5s"`
|
KeepAliveInterval time.Duration `envconfig:"GITNESS_SSH_KEEP_ALIVE_INTERVAL" default:"5s"`
|
||||||
|
ServerKeyPath string `envconfig:"GITNESS_SSH_SERVER_KEY_PATH" default:"ssh/gitness.rsa"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// CI defines configuration related to build executions.
|
// CI defines configuration related to build executions.
|
||||||
|
|
Loading…
Reference in New Issue