package fiber import ( "bytes" "context" "crypto/tls" "errors" "fmt" "io" "log" //nolint:depguard // TODO: Required to capture output, use internal log package instead "net" "os" "strings" "sync" "testing" "time" "github.com/gofiber/utils/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" "github.com/valyala/fasthttp/fasthttputil" ) // go test -run Test_Listen func Test_Listen(t *testing.T) { app := New() require.Error(t, app.Listen(":99999")) go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() require.NoError(t, app.Listen(":4003", ListenConfig{DisableStartupMessage: true})) } // go test -run Test_Listen_Graceful_Shutdown func Test_Listen_Graceful_Shutdown(t *testing.T) { t.Run("Basic Graceful Shutdown", func(t *testing.T) { testGracefulShutdown(t, 0) }) t.Run("Shutdown With Timeout", func(t *testing.T) { testGracefulShutdown(t, 500*time.Millisecond) }) t.Run("Shutdown With Timeout Error", func(t *testing.T) { testGracefulShutdown(t, 1*time.Nanosecond) }) } func testGracefulShutdown(t *testing.T, shutdownTimeout time.Duration) { t.Helper() var mu sync.Mutex var shutdown bool var receivedErr error app := New() app.Get("/", func(c Ctx) error { time.Sleep(10 * time.Millisecond) return c.SendString(c.Hostname()) }) ln := fasthttputil.NewInmemoryListener() errs := make(chan error, 1) app.hooks.OnPostShutdown(func(err error) error { mu.Lock() defer mu.Unlock() shutdown = true receivedErr = err return nil }) go func() { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() errs <- app.Listener(ln, ListenConfig{ DisableStartupMessage: true, GracefulContext: ctx, ShutdownTimeout: shutdownTimeout, }) }() require.Eventually(t, func() bool { conn, err := ln.Dial() if err == nil { if err := conn.Close(); err != nil { t.Logf("error closing connection: %v", err) } return true } return false }, time.Second, 100*time.Millisecond, "Server failed to become ready") client := fasthttp.HostClient{ Dial: func(_ string) (net.Conn, error) { return ln.Dial() }, } type testCase struct { expectedErr error expectedBody string name string waitTime time.Duration expectedStatusCode int closeConnection bool } testCases := []testCase{ { name: "Server running normally", waitTime: 500 * time.Millisecond, expectedBody: "example.com", expectedStatusCode: StatusOK, expectedErr: nil, closeConnection: true, }, { name: "Server shutdown complete", waitTime: 3 * time.Second, expectedBody: "", expectedStatusCode: StatusOK, expectedErr: fasthttputil.ErrInmemoryListenerClosed, closeConnection: true, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { time.Sleep(tc.waitTime) req := fasthttp.AcquireRequest() defer fasthttp.ReleaseRequest(req) req.SetRequestURI("http://example.com") resp := fasthttp.AcquireResponse() defer fasthttp.ReleaseResponse(resp) err := client.Do(req, resp) if tc.expectedErr == nil { require.NoError(t, err) require.Equal(t, tc.expectedStatusCode, resp.StatusCode()) require.Equal(t, tc.expectedBody, utils.UnsafeString(resp.Body())) } else { require.ErrorIs(t, err, tc.expectedErr) } }) } mu.Lock() require.True(t, shutdown) if shutdownTimeout == 1*time.Nanosecond { require.Error(t, receivedErr) require.ErrorIs(t, receivedErr, context.DeadlineExceeded) } require.NoError(t, <-errs) mu.Unlock() } // go test -run Test_Listen_Prefork func Test_Listen_Prefork(t *testing.T) { testPreforkMaster = true app := New() require.NoError(t, app.Listen(":99999", ListenConfig{DisableStartupMessage: true, EnablePrefork: true})) } // go test -run Test_Listen_TLSMinVersion func Test_Listen_TLSMinVersion(t *testing.T) { testPreforkMaster = true app := New() // Invalid TLSMinVersion require.Panics(t, func() { _ = app.Listen(":443", ListenConfig{TLSMinVersion: tls.VersionTLS10}) //nolint:errcheck // ignore error }) require.Panics(t, func() { _ = app.Listen(":443", ListenConfig{TLSMinVersion: tls.VersionTLS11}) //nolint:errcheck // ignore error }) // Prefork require.Panics(t, func() { _ = app.Listen(":443", ListenConfig{DisableStartupMessage: true, EnablePrefork: true, TLSMinVersion: tls.VersionTLS10}) //nolint:errcheck // ignore error }) require.Panics(t, func() { _ = app.Listen(":443", ListenConfig{DisableStartupMessage: true, EnablePrefork: true, TLSMinVersion: tls.VersionTLS11}) //nolint:errcheck // ignore error }) // Valid TLSMinVersion go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() require.NoError(t, app.Listen(":0", ListenConfig{TLSMinVersion: tls.VersionTLS13})) // Valid TLSMinVersion with Prefork go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() require.NoError(t, app.Listen(":99999", ListenConfig{DisableStartupMessage: true, EnablePrefork: true, TLSMinVersion: tls.VersionTLS13})) } // go test -run Test_Listen_TLS func Test_Listen_TLS(t *testing.T) { app := New() // invalid port require.Error(t, app.Listen(":99999", ListenConfig{ CertFile: "./.github/testdata/ssl.pem", CertKeyFile: "./.github/testdata/ssl.key", })) go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() require.NoError(t, app.Listen(":0", ListenConfig{ CertFile: "./.github/testdata/ssl.pem", CertKeyFile: "./.github/testdata/ssl.key", })) } // go test -run Test_Listen_TLS_Prefork func Test_Listen_TLS_Prefork(t *testing.T) { testPreforkMaster = true app := New() // invalid key file content require.Error(t, app.Listen(":0", ListenConfig{ DisableStartupMessage: true, EnablePrefork: true, CertFile: "./.github/testdata/ssl.pem", CertKeyFile: "./.github/testdata/template.tmpl", })) go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() require.NoError(t, app.Listen(":99999", ListenConfig{ DisableStartupMessage: true, EnablePrefork: true, CertFile: "./.github/testdata/ssl.pem", CertKeyFile: "./.github/testdata/ssl.key", })) } // go test -run Test_Listen_MutualTLS func Test_Listen_MutualTLS(t *testing.T) { app := New() // invalid port require.Error(t, app.Listen(":99999", ListenConfig{ CertFile: "./.github/testdata/ssl.pem", CertKeyFile: "./.github/testdata/ssl.key", CertClientFile: "./.github/testdata/ca-chain.cert.pem", })) go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() require.NoError(t, app.Listen(":0", ListenConfig{ CertFile: "./.github/testdata/ssl.pem", CertKeyFile: "./.github/testdata/ssl.key", CertClientFile: "./.github/testdata/ca-chain.cert.pem", })) } // go test -run Test_Listen_MutualTLS_Prefork func Test_Listen_MutualTLS_Prefork(t *testing.T) { testPreforkMaster = true app := New() // invalid key file content require.Error(t, app.Listen(":0", ListenConfig{ DisableStartupMessage: true, EnablePrefork: true, CertFile: "./.github/testdata/ssl.pem", CertKeyFile: "./.github/testdata/template.html", CertClientFile: "./.github/testdata/ca-chain.cert.pem", })) go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() require.NoError(t, app.Listen(":99999", ListenConfig{ DisableStartupMessage: true, EnablePrefork: true, CertFile: "./.github/testdata/ssl.pem", CertKeyFile: "./.github/testdata/ssl.key", CertClientFile: "./.github/testdata/ca-chain.cert.pem", })) } // go test -run Test_Listener func Test_Listener(t *testing.T) { app := New() go func() { time.Sleep(500 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() ln := fasthttputil.NewInmemoryListener() require.NoError(t, app.Listener(ln)) } func Test_App_Listener_TLS_Listener(t *testing.T) { // Create tls certificate cer, err := tls.LoadX509KeyPair("./.github/testdata/ssl.pem", "./.github/testdata/ssl.key") if err != nil { require.NoError(t, err) } //nolint:gosec // We're in a test so using old ciphers is fine config := &tls.Config{Certificates: []tls.Certificate{cer}} //nolint:gosec // We're in a test so listening on all interfaces is fine ln, err := tls.Listen(NetworkTCP4, ":0", config) require.NoError(t, err) app := New() go func() { time.Sleep(time.Millisecond * 500) assert.NoError(t, app.Shutdown()) }() require.NoError(t, app.Listener(ln)) } // go test -run Test_Listen_TLSConfigFunc func Test_Listen_TLSConfigFunc(t *testing.T) { var callTLSConfig bool app := New() go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() require.NoError(t, app.Listen(":0", ListenConfig{ DisableStartupMessage: true, TLSConfigFunc: func(_ *tls.Config) { callTLSConfig = true }, CertFile: "./.github/testdata/ssl.pem", CertKeyFile: "./.github/testdata/ssl.key", })) require.True(t, callTLSConfig) } // go test -run Test_Listen_ListenerAddrFunc func Test_Listen_ListenerAddrFunc(t *testing.T) { var network string app := New() go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() require.NoError(t, app.Listen(":0", ListenConfig{ DisableStartupMessage: true, ListenerAddrFunc: func(addr net.Addr) { network = addr.Network() }, CertFile: "./.github/testdata/ssl.pem", CertKeyFile: "./.github/testdata/ssl.key", })) require.Equal(t, "tcp", network) } // go test -run Test_Listen_BeforeServeFunc func Test_Listen_BeforeServeFunc(t *testing.T) { var handlers uint32 app := New() go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() wantErr := errors.New("test") require.ErrorIs(t, app.Listen(":0", ListenConfig{ DisableStartupMessage: true, BeforeServeFunc: func(fiber *App) error { handlers = fiber.HandlersCount() return wantErr }, }), wantErr) require.Zero(t, handlers) } // go test -run Test_Listen_ListenerNetwork func Test_Listen_ListenerNetwork(t *testing.T) { var network string app := New() go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() require.NoError(t, app.Listen(":0", ListenConfig{ DisableStartupMessage: true, ListenerNetwork: NetworkTCP6, ListenerAddrFunc: func(addr net.Addr) { network = addr.String() }, })) require.Contains(t, network, "[::]:") go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() require.NoError(t, app.Listen(":0", ListenConfig{ DisableStartupMessage: true, ListenerNetwork: NetworkTCP4, ListenerAddrFunc: func(addr net.Addr) { network = addr.String() }, })) require.Contains(t, network, "0.0.0.0:") } // go test -run Test_Listen_Master_Process_Show_Startup_Message func Test_Listen_Master_Process_Show_Startup_Message(t *testing.T) { cfg := ListenConfig{ EnablePrefork: true, } startupMessage := captureOutput(func() { New(). startupMessage(":3000", true, strings.Repeat(",11111,22222,33333,44444,55555,60000", 10), cfg) }) colors := Colors{} require.Contains(t, startupMessage, "https://127.0.0.1:3000") require.Contains(t, startupMessage, "(bound on host 0.0.0.0 and port 3000)") require.Contains(t, startupMessage, "Child PIDs") require.Contains(t, startupMessage, "11111, 22222, 33333, 44444, 55555, 60000") require.Contains(t, startupMessage, fmt.Sprintf("Prefork: \t\t\t%sEnabled%s", colors.Blue, colors.Reset)) } // go test -run Test_Listen_Master_Process_Show_Startup_MessageWithAppName func Test_Listen_Master_Process_Show_Startup_MessageWithAppName(t *testing.T) { cfg := ListenConfig{ EnablePrefork: true, } app := New(Config{AppName: "Test App v3.0.0"}) startupMessage := captureOutput(func() { app.startupMessage(":3000", true, strings.Repeat(",11111,22222,33333,44444,55555,60000", 10), cfg) }) require.Equal(t, "Test App v3.0.0", app.Config().AppName) require.Contains(t, startupMessage, app.Config().AppName) } // go test -run Test_Listen_Master_Process_Show_Startup_MessageWithAppNameNonAscii func Test_Listen_Master_Process_Show_Startup_MessageWithAppNameNonAscii(t *testing.T) { cfg := ListenConfig{ EnablePrefork: true, } appName := "Serveur de vérification des données" app := New(Config{AppName: appName}) startupMessage := captureOutput(func() { app.startupMessage(":3000", false, "", cfg) }) require.Contains(t, startupMessage, "Serveur de vérification des données") } // go test -run Test_Listen_Master_Process_Show_Startup_MessageWithDisabledPreforkAndCustomEndpoint func Test_Listen_Master_Process_Show_Startup_MessageWithDisabledPreforkAndCustomEndpoint(t *testing.T) { cfg := ListenConfig{ EnablePrefork: false, } appName := "Fiber Example Application" app := New(Config{AppName: appName}) startupMessage := captureOutput(func() { app.startupMessage("server.com:8081", true, strings.Repeat(",11111,22222,33333,44444,55555,60000", 5), cfg) }) colors := Colors{} require.Contains(t, startupMessage, fmt.Sprintf("%sINFO%s", colors.Green, colors.Reset)) require.Contains(t, startupMessage, fmt.Sprintf("%s%s%s", colors.Blue, appName, colors.Reset)) require.Contains(t, startupMessage, fmt.Sprintf("%s%s%s", colors.Blue, "https://server.com:8081", colors.Reset)) require.Contains(t, startupMessage, fmt.Sprintf("Prefork: \t\t\t%sDisabled%s", colors.Red, colors.Reset)) } // go test -run Test_Listen_Print_Route func Test_Listen_Print_Route(t *testing.T) { app := New() app.Get("/", emptyHandler).Name("routeName") printRoutesMessage := captureOutput(func() { app.printRoutesMessage() }) require.Contains(t, printRoutesMessage, MethodGet) require.Contains(t, printRoutesMessage, "/") require.Contains(t, printRoutesMessage, "emptyHandler") require.Contains(t, printRoutesMessage, "routeName") } // go test -run Test_Listen_Print_Route_With_Group func Test_Listen_Print_Route_With_Group(t *testing.T) { app := New() app.Get("/", emptyHandler) v1 := app.Group("v1") v1.Get("/test", emptyHandler).Name("v1") v1.Post("/test/fiber", emptyHandler) v1.Put("/test/fiber/*", emptyHandler) printRoutesMessage := captureOutput(func() { app.printRoutesMessage() }) require.Contains(t, printRoutesMessage, MethodGet) require.Contains(t, printRoutesMessage, "/") require.Contains(t, printRoutesMessage, "emptyHandler") require.Contains(t, printRoutesMessage, "/v1/test") require.Contains(t, printRoutesMessage, "POST") require.Contains(t, printRoutesMessage, "/v1/test/fiber") require.Contains(t, printRoutesMessage, "PUT") require.Contains(t, printRoutesMessage, "/v1/test/fiber/*") } func captureOutput(f func()) string { reader, writer, err := os.Pipe() if err != nil { panic(err) } stdout := os.Stdout stderr := os.Stderr defer func() { os.Stdout = stdout os.Stderr = stderr log.SetOutput(os.Stderr) }() os.Stdout = writer os.Stderr = writer log.SetOutput(writer) out := make(chan string) wg := new(sync.WaitGroup) wg.Add(1) go func() { var buf bytes.Buffer wg.Done() _, err := io.Copy(&buf, reader) if err != nil { panic(err) } out <- buf.String() }() wg.Wait() f() err = writer.Close() if err != nil { panic(err) } return <-out } func emptyHandler(_ Ctx) error { return nil }