package client import ( "context" "errors" "net" "testing" "time" "github.com/gofiber/fiber/v3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp/fasthttputil" ) func Test_AddMissing_Port(t *testing.T) { t.Parallel() type args struct { addr string isTLS bool } tests := []struct { name string want string args args }{ { name: "do anything", args: args{ addr: "example.com:1234", }, want: "example.com:1234", }, { name: "add 80 port", args: args{ addr: "example.com", }, want: "example.com:80", }, { name: "add 443 port", args: args{ addr: "example.com", isTLS: true, }, want: "example.com:443", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() require.Equal(t, tt.want, addMissingPort(tt.args.addr, tt.args.isTLS)) }) } } func Test_Exec_Func(t *testing.T) { t.Parallel() ln := fasthttputil.NewInmemoryListener() app := fiber.New() app.Get("/normal", func(c fiber.Ctx) error { return c.SendString(c.Hostname()) }) app.Get("/return-error", func(_ fiber.Ctx) error { return errors.New("the request is error") }) app.Get("/hang-up", func(c fiber.Ctx) error { time.Sleep(time.Second) return c.SendString(c.Hostname() + " hang up") }) go func() { assert.NoError(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) }() time.Sleep(300 * time.Millisecond) t.Run("normal request", func(t *testing.T) { t.Parallel() core, client, req := newCore(), New(), AcquireRequest() core.ctx = context.Background() core.client = client core.req = req client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) req.RawRequest.SetRequestURI("http://example.com/normal") resp, err := core.execFunc() require.NoError(t, err) require.Equal(t, 200, resp.RawResponse.StatusCode()) require.Equal(t, "example.com", string(resp.RawResponse.Body())) }) t.Run("the request return an error", func(t *testing.T) { t.Parallel() core, client, req := newCore(), New(), AcquireRequest() core.ctx = context.Background() core.client = client core.req = req client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) req.RawRequest.SetRequestURI("http://example.com/return-error") resp, err := core.execFunc() require.NoError(t, err) require.Equal(t, 500, resp.RawResponse.StatusCode()) require.Equal(t, "the request is error", string(resp.RawResponse.Body())) }) t.Run("the request timeout", func(t *testing.T) { t.Parallel() core, client, req := newCore(), New(), AcquireRequest() ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() core.ctx = ctx core.client = client core.req = req client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) req.RawRequest.SetRequestURI("http://example.com/hang-up") _, err := core.execFunc() require.Equal(t, ErrTimeoutOrCancel, err) }) } func Test_Execute(t *testing.T) { t.Parallel() ln := fasthttputil.NewInmemoryListener() app := fiber.New() app.Get("/normal", func(c fiber.Ctx) error { return c.SendString(c.Hostname()) }) app.Get("/return-error", func(_ fiber.Ctx) error { return errors.New("the request is error") }) app.Get("/hang-up", func(c fiber.Ctx) error { time.Sleep(time.Second) return c.SendString(c.Hostname() + " hang up") }) go func() { assert.NoError(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) }() t.Run("add user request hooks", func(t *testing.T) { t.Parallel() core, client, req := newCore(), New(), AcquireRequest() client.AddRequestHook(func(_ *Client, _ *Request) error { require.Equal(t, "http://example.com", req.URL()) return nil }) client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) req.SetURL("http://example.com") resp, err := core.execute(context.Background(), client, req) require.NoError(t, err) require.Equal(t, "Cannot GET /", string(resp.RawResponse.Body())) }) t.Run("add user response hooks", func(t *testing.T) { t.Parallel() core, client, req := newCore(), New(), AcquireRequest() client.AddResponseHook(func(_ *Client, _ *Response, req *Request) error { require.Equal(t, "http://example.com", req.URL()) return nil }) client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) req.SetURL("http://example.com") resp, err := core.execute(context.Background(), client, req) require.NoError(t, err) require.Equal(t, "Cannot GET /", string(resp.RawResponse.Body())) }) t.Run("no timeout", func(t *testing.T) { t.Parallel() core, client, req := newCore(), New(), AcquireRequest() client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) req.SetURL("http://example.com/hang-up") resp, err := core.execute(context.Background(), client, req) require.NoError(t, err) require.Equal(t, "example.com hang up", string(resp.RawResponse.Body())) }) t.Run("client timeout", func(t *testing.T) { t.Parallel() core, client, req := newCore(), New(), AcquireRequest() client.SetTimeout(500 * time.Millisecond) client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) req.SetURL("http://example.com/hang-up") _, err := core.execute(context.Background(), client, req) require.Equal(t, ErrTimeoutOrCancel, err) }) t.Run("request timeout", func(t *testing.T) { t.Parallel() core, client, req := newCore(), New(), AcquireRequest() client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) req.SetURL("http://example.com/hang-up"). SetTimeout(300 * time.Millisecond) _, err := core.execute(context.Background(), client, req) require.Equal(t, ErrTimeoutOrCancel, err) }) t.Run("request timeout has higher level", func(t *testing.T) { t.Parallel() core, client, req := newCore(), New(), AcquireRequest() client.SetTimeout(30 * time.Millisecond) client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) req.SetURL("http://example.com/hang-up"). SetTimeout(3000 * time.Millisecond) resp, err := core.execute(context.Background(), client, req) require.NoError(t, err) require.Equal(t, "example.com hang up", string(resp.RawResponse.Body())) }) }