diff --git a/registry/app/remote/adapter/awsecr/adapter.go b/registry/app/remote/adapter/awsecr/adapter.go index cc699144e..0c87f0237 100644 --- a/registry/app/remote/adapter/awsecr/adapter.go +++ b/registry/app/remote/adapter/awsecr/adapter.go @@ -51,15 +51,19 @@ func init() { func newAdapter( ctx context.Context, spacePathStore store2.SpacePathStore, service secret.Service, registry types.UpstreamProxy, ) (adp.Adapter, error) { - accessKey, secretKey, err := getCreds(ctx, spacePathStore, service, registry) + accessKey, secretKey, isPublic, err := getCreds(ctx, spacePathStore, service, registry) if err != nil { return nil, err } - svc, err := getAwsSvc(accessKey, secretKey, registry) - if err != nil { - return nil, err + var svc *awsecrapi.ECR + if !isPublic { + svc, err = getAwsSvc(accessKey, secretKey, registry) + if err != nil { + return nil, err + } } - authorizer := NewAuth(accessKey, svc) + + authorizer := NewAuth(accessKey, svc, isPublic) return &adapter{ cacheSvc: svc, diff --git a/registry/app/remote/adapter/awsecr/auth.go b/registry/app/remote/adapter/awsecr/auth.go index 500612650..a287bda28 100644 --- a/registry/app/remote/adapter/awsecr/auth.go +++ b/registry/app/remote/adapter/awsecr/auth.go @@ -19,8 +19,10 @@ package awsecr import ( "context" "encoding/base64" + "encoding/json" "errors" "fmt" + "io" "net/http" "net/url" "strings" @@ -51,6 +53,7 @@ type awsAuthCredential struct { cacheToken *cacheToken cacheExpired *time.Time + isPublic bool } type cacheToken struct { @@ -69,7 +72,7 @@ func (a *awsAuthCredential) Modify(req *http.Request) error { return nil } if !a.isTokenValid() { - endpoint, user, pass, expiresAt, err := a.getAuthorization(req.URL.String()) + endpoint, user, pass, expiresAt, err := a.getAuthorization(req.URL.String(), req.URL.Host) if err != nil { return err @@ -84,7 +87,7 @@ func (a *awsAuthCredential) Modify(req *http.Request) error { a.cacheToken.password = pass a.cacheToken.endpoint = endpoint t := time.Now().Add(DefaultCacheExpiredTime) - if t.Before(*expiresAt) { + if expiresAt == nil || t.Before(*expiresAt) { a.cacheExpired = &t } else { a.cacheExpired = expiresAt @@ -92,6 +95,10 @@ func (a *awsAuthCredential) Modify(req *http.Request) error { } req.Host = a.cacheToken.host req.URL.Host = a.cacheToken.host + if a.isPublic { + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", a.cacheToken.password)) + return nil + } req.SetBasicAuth(a.cacheToken.user, a.cacheToken.password) return nil } @@ -136,25 +143,28 @@ func parseAccountRegion(url string) (string, string, error) { func getCreds( ctx context.Context, spacePathStore store.SpacePathStore, secretService secret.Service, reg types.UpstreamProxy, -) (string, string, error) { +) (string, string, bool, error) { + if api.AuthType(reg.RepoAuthType) == api.AuthTypeAnonymous { + return "", "", true, nil + } if api.AuthType(reg.RepoAuthType) != api.AuthTypeAccessKeySecretKey { log.Debug().Msgf("invalid auth type: %s", reg.RepoAuthType) - return "", "", nil + return "", "", false, nil } secretKey, err := getSecretValue(ctx, spacePathStore, secretService, reg.SecretSpaceID, reg.SecretIdentifier) if err != nil { - return "", "", err + return "", "", false, err } if reg.UserName != "" { - return reg.UserName, secretKey, nil + return reg.UserName, secretKey, false, nil } accessKey, err := getSecretValue(ctx, spacePathStore, secretService, reg.UserNameSecretSpaceID, reg.UserNameSecretIdentifier) if err != nil { - return "", "", err + return "", "", false, err } - return accessKey, secretKey, nil + return accessKey, secretKey, false, nil } func getSecretValue(ctx context.Context, spacePathStore store.SpacePathStore, secretService secret.Service, @@ -172,7 +182,14 @@ func getSecretValue(ctx context.Context, spacePathStore store.SpacePathStore, se return decryptSecret, nil } -func (a *awsAuthCredential) getAuthorization(url string) (string, string, string, *time.Time, error) { +func (a *awsAuthCredential) getAuthorization(url, host string) (string, string, string, *time.Time, error) { + if a.isPublic { + token, err := a.getPublicECRToken(host) + if err != nil { + return "", "", "", nil, err + } + return url, "", token, nil, nil + } id, _, err := parseAccountRegion(url) if err != nil { return "", "", "", nil, err @@ -225,9 +242,51 @@ func (a *awsAuthCredential) isTokenValid() bool { } // NewAuth new aws auth. -func NewAuth(accessKey string, awssvc *awsecrapi.ECR) Credential { +func NewAuth(accessKey string, awssvc *awsecrapi.ECR, isPublic bool) Credential { return &awsAuthCredential{ accessKey: accessKey, awssvc: awssvc, + isPublic: isPublic, } } + +func (a *awsAuthCredential) getPublicECRToken(host string) (string, error) { + c := &http.Client{ + Transport: commonhttp.GetHTTPTransport(commonhttp.WithInsecure(true)), + } + req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, buildTokenURL(host, host), nil) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + resp, err := c.Do(req) + if err != nil { + return "", fmt.Errorf("failed to fetch token: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("non-200 response: %d %s", resp.StatusCode, http.StatusText(resp.StatusCode)) + } + + // Parse the response body + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read response body: %w", err) + } + + // Unmarshal JSON + var tokenResponse TokenResponse + if err := json.Unmarshal(body, &tokenResponse); err != nil { + return "", fmt.Errorf("failed to parse token response: %w", err) + } + + return tokenResponse.Token, nil +} + +type TokenResponse struct { + Token string `json:"token"` +} + +func buildTokenURL(host, service string) string { + return fmt.Sprintf("https://%s/token?service=%s", host, service) +}