drone/ssh/middleware.go

139 lines
3.8 KiB
Go

// Copyright 2023 Harness, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ssh
import (
"runtime/debug"
"time"
"github.com/harness/gitness/app/api/request"
"github.com/gliderlabs/ssh"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
type Middleware func(ssh.Handler) ssh.Handler
// ChainMiddleware combines multiple middleware into a single ssh.Handler.
func ChainMiddleware(handler ssh.Handler, middlewares ...Middleware) ssh.Handler {
for i := len(middlewares) - 1; i >= 0; i-- { // Reverse order to maintain correct chaining
handler = middlewares[i](handler)
}
return handler
}
// PanicRecoverMiddleware wraps the SSH handler to recover from panics and log them.
func PanicRecoverMiddleware(next ssh.Handler) ssh.Handler {
return func(s ssh.Session) {
defer func() {
if r := recover(); r != nil {
// Log the panic and stack trace
// Get the context and logger
ctx := s.Context()
logger := getLogger(ctx)
logger.Error().Msgf("encountered panic while processing ssh operation: %v\n%s", r, debug.Stack())
_, _ = s.Write([]byte("Internal server error. Please try again later.\n"))
}
}()
// Call the next handler
next(s)
}
}
func HLogAccessLogHandler(next ssh.Handler) ssh.Handler {
return func(s ssh.Session) {
start := time.Now()
user := s.User()
remoteAddr := s.RemoteAddr()
command := s.Command()
// Get the context and logger
ctx := s.Context()
logger := getLogger(ctx)
// Log session start
logger.Info().
Str("ssh.user", user).
Str("ssh.remote", remoteAddr.String()).
Strs("ssh.command", command).
Msg("SSH session started")
// Call the next handler
next(s)
// Log session completion
duration := time.Since(start)
logger.Info().
Dur("ssh.elapsed_ms", duration).
Str("ssh.user", user).
Msg("SSH session completed")
}
}
func HLogRequestIDHandler(next ssh.Handler) ssh.Handler {
return func(s ssh.Session) {
sshCtx := s.Context() // This is ssh.Context
reqID := getRequestID(sshCtx.SessionID())
request.WithRequestIDSSH(sshCtx, reqID)
log := getLoggerWithRequestID(reqID)
sshCtx.SetValue(loggerKey, log)
// continue serving request
next(s)
}
}
type PublicKeyMiddleware func(next ssh.PublicKeyHandler) ssh.PublicKeyHandler
func ChainPublicKeyMiddleware(handler ssh.PublicKeyHandler, middlewares ...PublicKeyMiddleware) ssh.PublicKeyHandler {
for i := len(middlewares) - 1; i >= 0; i-- { // Reverse order for correct chaining
handler = middlewares[i](handler)
}
return handler
}
func LogPublicKeyMiddleware(next ssh.PublicKeyHandler) ssh.PublicKeyHandler {
return func(ctx ssh.Context, key ssh.PublicKey) bool {
reqID := getRequestID(ctx.SessionID())
request.WithRequestIDSSH(ctx, reqID)
log := getLoggerWithRequestID(reqID)
start := time.Now()
log.Info().
Str("ssh.user", ctx.User()).
Str("ssh.remote", ctx.RemoteAddr().String()).
Msg("Public key authentication attempt")
v := next(ctx, key)
// Log session completion
duration := time.Since(start)
log.Info().
Dur("ssh.elapsed_ms", duration).
Str("ssh.user", ctx.User()).
Msg("Public key authentication attempt completed")
return v
}
}
func getLogger(ctx ssh.Context) zerolog.Logger {
logger, ok := ctx.Value(loggerKey).(zerolog.Logger)
if !ok {
logger = log.Logger
}
return logger
}