diff --git a/app.go b/app.go index 284fea2d..a79e2841 100644 --- a/app.go +++ b/app.go @@ -115,7 +115,7 @@ type App struct { latestRoute *Route latestGroup *Group // TLS handler - tlsHandler *tlsHandler + tlsHandler *TLSHandler } // Config is a struct holding the server settings. @@ -570,6 +570,14 @@ func (app *App) handleTrustedProxy(ipAddress string) { } } +// You can use SetTLSHandler to use ClientHelloInfo when using TLS with Listener. +func (app *App) SetTLSHandler(tlsHandler *TLSHandler) { + // Attach the tlsHandler to the config + app.mutex.Lock() + app.tlsHandler = tlsHandler + app.mutex.Unlock() +} + // Mount attaches another app instance as a sub-router along a routing path. // It's very useful to split up a large API as many independent routers and // compose them as a single service using Mount. The fiber's error handler and diff --git a/app_test.go b/app_test.go index 716970fb..b20e0c45 100644 --- a/app_test.go +++ b/app_test.go @@ -6,6 +6,7 @@ package fiber import ( "bytes" + "crypto/tls" "errors" "fmt" "io" @@ -1560,3 +1561,17 @@ func Test_App_Test_no_timeout_infinitely(t *testing.T) { t.FailNow() } } + +func Test_App_SetTLSHandler(t *testing.T) { + tlsHandler := &TLSHandler{clientHelloInfo: &tls.ClientHelloInfo{ + ServerName: "example.golang", + }} + + app := New() + app.SetTLSHandler(tlsHandler) + + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + + utils.AssertEqual(t, "example.golang", c.ClientHelloInfo().ServerName) +} diff --git a/ctx.go b/ctx.go index f956c3b6..ed80c48c 100644 --- a/ctx.go +++ b/ctx.go @@ -68,13 +68,13 @@ type Ctx struct { viewBindMap *dictpool.Dict // Default view map to bind template engine } -// tlsHandle object -type tlsHandler struct { +// TLSHandler object +type TLSHandler struct { clientHelloInfo *tls.ClientHelloInfo } // GetClientInfo Callback function to set CHI -func (t *tlsHandler) GetClientInfo(info *tls.ClientHelloInfo) (*tls.Certificate, error) { +func (t *TLSHandler) GetClientInfo(info *tls.ClientHelloInfo) (*tls.Certificate, error) { t.clientHelloInfo = info return nil, nil } diff --git a/ctx_test.go b/ctx_test.go index 273dd14e..31f47e7a 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -1462,7 +1462,7 @@ func Test_Ctx_ClientHelloInfo(t *testing.T) { PSSWithSHA256 = 0x0804 VersionTLS13 = 0x0304 ) - app.tlsHandler = &tlsHandler{clientHelloInfo: &tls.ClientHelloInfo{ + app.tlsHandler = &TLSHandler{clientHelloInfo: &tls.ClientHelloInfo{ ServerName: "example.golang", SignatureSchemes: []tls.SignatureScheme{PSSWithSHA256}, SupportedVersions: []uint16{VersionTLS13}, diff --git a/listen.go b/listen.go index 4d8d25b7..ca55bb06 100644 --- a/listen.go +++ b/listen.go @@ -31,16 +31,20 @@ func (app *App) Listener(ln net.Listener) error { addr, tlsConfig := lnMetadata(app.config.Network, ln) return app.prefork(app.config.Network, addr, tlsConfig) } + // prepare the server for the start app.startupProcess() + // Print startup message if !app.config.DisableStartupMessage { app.startupMessage(ln.Addr().String(), getTlsConfig(ln) != nil, "") } + // Print routes if app.config.EnablePrintRoutes { app.printRoutesMessage() } + // Start listening return app.server.Serve(ln) } @@ -54,21 +58,26 @@ func (app *App) Listen(addr string) error { if app.config.Prefork { return app.prefork(app.config.Network, addr, nil) } + // Setup listener ln, err := net.Listen(app.config.Network, addr) if err != nil { return err } + // prepare the server for the start app.startupProcess() + // Print startup message if !app.config.DisableStartupMessage { app.startupMessage(ln.Addr().String(), false, "") } + // Print routes if app.config.EnablePrintRoutes { app.printRoutesMessage() } + // Start listening return app.server.Serve(ln) } @@ -82,12 +91,14 @@ func (app *App) ListenTLS(addr, certFile, keyFile string) error { if len(certFile) == 0 || len(keyFile) == 0 { return errors.New("tls: provide a valid cert or key path") } + // Set TLS config with handler cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %s", certFile, keyFile, err) } - tlsHandler := &tlsHandler{} + + tlsHandler := &TLSHandler{} config := &tls.Config{ MinVersion: tls.VersionTLS12, Certificates: []tls.Certificate{ @@ -95,6 +106,7 @@ func (app *App) ListenTLS(addr, certFile, keyFile string) error { }, GetCertificate: tlsHandler.GetClientInfo, } + // Prefork is supported if app.config.Prefork { return app.prefork(app.config.Network, addr, config) @@ -103,23 +115,25 @@ func (app *App) ListenTLS(addr, certFile, keyFile string) error { // Setup listener ln, err := net.Listen(app.config.Network, addr) ln = tls.NewListener(ln, config) - if err != nil { return err } + // prepare the server for the start app.startupProcess() + // Print startup message if !app.config.DisableStartupMessage { app.startupMessage(ln.Addr().String(), true, "") } + // Print routes if app.config.EnablePrintRoutes { app.printRoutesMessage() } // Attach the tlsHandler to the config - app.tlsHandler = tlsHandler + app.SetTLSHandler(tlsHandler) // Start listening return app.server.Serve(ln) @@ -147,7 +161,7 @@ func (app *App) ListenMutualTLS(addr, certFile, keyFile, clientCertFile string) clientCertPool := x509.NewCertPool() clientCertPool.AppendCertsFromPEM(clientCACert) - tlsHandler := &tlsHandler{} + tlsHandler := &TLSHandler{} config := &tls.Config{ MinVersion: tls.VersionTLS12, ClientAuth: tls.RequireAndVerifyClientCert, @@ -183,7 +197,7 @@ func (app *App) ListenMutualTLS(addr, certFile, keyFile, clientCertFile string) } // Attach the tlsHandler to the config - app.tlsHandler = tlsHandler + app.SetTLSHandler(tlsHandler) // Start listening return app.server.Serve(ln)