diff --git a/main.go b/main.go index cbf712a..071b950 100644 --- a/main.go +++ b/main.go @@ -17,6 +17,9 @@ import ( "syscall" "time" + "github.com/google/go-containerregistry/pkg/authn" + "github.com/google/go-containerregistry/pkg/name" + "github.com/google/go-containerregistry/pkg/v1/remote" "github.com/sirupsen/logrus" ) @@ -32,9 +35,6 @@ type Config struct { var config Config var client = &http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, Timeout: 30 * time.Second, Transport: &http.Transport{ DisableKeepAlives: false, @@ -46,6 +46,11 @@ var client = &http.Client{ }, } +var remoteOptions = []remote.Option{ + remote.WithAuth(authn.Anonymous), + remote.WithTransport(client.Transport), +} + // Token 缓存 type TokenCache struct { mu sync.RWMutex @@ -64,7 +69,6 @@ var tokenCache = &TokenCache{ func (tc *TokenCache) Get(key string) (string, bool) { tc.mu.RLock() defer tc.mu.RUnlock() - if cached, ok := tc.cache[key]; ok { if time.Now().Before(cached.ExpiresAt) { return cached.Token, true @@ -77,7 +81,6 @@ func (tc *TokenCache) Get(key string) (string, bool) { func (tc *TokenCache) Set(key, token string, ttl time.Duration) { tc.mu.Lock() defer tc.mu.Unlock() - tc.cache[key] = &CachedToken{ Token: token, ExpiresAt: time.Now().Add(ttl), @@ -92,7 +95,8 @@ type ManifestCache struct { type CachedManifest struct { Data []byte - Headers http.Header + MediaType string + Digest string ExpiresAt time.Time } @@ -103,7 +107,6 @@ var manifestCache = &ManifestCache{ func (mc *ManifestCache) Get(key string) (*CachedManifest, bool) { mc.mu.RLock() defer mc.mu.RUnlock() - if cached, ok := mc.cache[key]; ok { if time.Now().Before(cached.ExpiresAt) { return cached, true @@ -113,13 +116,13 @@ func (mc *ManifestCache) Get(key string) (*CachedManifest, bool) { return nil, false } -func (mc *ManifestCache) Set(key string, data []byte, headers http.Header, ttl time.Duration) { +func (mc *ManifestCache) Set(key string, data []byte, mediaType, digest string, ttl time.Duration) { mc.mu.Lock() defer mc.mu.Unlock() - mc.cache[key] = &CachedManifest{ Data: data, - Headers: headers, + MediaType: mediaType, + Digest: digest, ExpiresAt: time.Now().Add(ttl), } } @@ -313,100 +316,170 @@ func getCleanHost(r *http.Request) string { } func handleRegistryRequest(w http.ResponseWriter, r *http.Request) { - const targetHost = "registry-1.docker.io" path := strings.TrimPrefix(r.URL.Path, "/v2/") - // Manifest 缓存检查 - if r.Method == http.MethodGet && strings.Contains(path, "/manifests/") { - cacheKey := fmt.Sprintf("manifest:%s", path) + // /v2/ 端点 + if path == "" { + w.Header().Set("Docker-Distribution-API-Version", "registry/2.0") + w.WriteHeader(http.StatusOK) + w.Write([]byte("{}")) + return + } + + imageName, apiType, reference := parseRegistryPath(path) + if imageName == "" || apiType == "" { + http.Error(w, "Invalid path format", http.StatusBadRequest) + return + } + + if !strings.Contains(imageName, "/") { + imageName = "library/" + imageName + } + + imageRef := fmt.Sprintf("registry-1.docker.io/%s", imageName) + + switch apiType { + case "manifests": + handleManifestRequest(w, r, imageRef, reference) + case "blobs": + handleBlobRequest(w, r, imageRef, reference) + default: + http.Error(w, "API endpoint not found", http.StatusNotFound) + } +} + +func parseRegistryPath(path string) (imageName, apiType, reference string) { + if idx := strings.Index(path, "/manifests/"); idx != -1 { + imageName = path[:idx] + apiType = "manifests" + reference = path[idx+len("/manifests/"):] + return + } + if idx := strings.Index(path, "/blobs/"); idx != -1 { + imageName = path[:idx] + apiType = "blobs" + reference = path[idx+len("/blobs/"):] + return + } + return "", "", "" +} + +func handleManifestRequest(w http.ResponseWriter, r *http.Request, imageRef, reference string) { + cacheKey := fmt.Sprintf("%s@%s", imageRef, reference) + + // 检查缓存 + if r.Method == http.MethodGet { if cached, ok := manifestCache.Get(cacheKey); ok { logrus.Debugf("Docker镜像: 使用缓存的 manifest") - for k, v := range cached.Headers { - for _, val := range v { - w.Header().Add(k, val) - } - } + w.Header().Set("Content-Type", cached.MediaType) + w.Header().Set("Docker-Content-Digest", cached.Digest) + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(cached.Data))) w.WriteHeader(http.StatusOK) w.Write(cached.Data) return } } - ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) - defer cancel() + var ref name.Reference + var err error - url := &url.URL{ - Scheme: "https", - Host: targetHost, - Path: "/v2/" + path, - RawQuery: r.URL.RawQuery, + if strings.HasPrefix(reference, "sha256:") { + ref, err = name.NewDigest(fmt.Sprintf("%s@%s", imageRef, reference)) + } else { + ref, err = name.NewTag(fmt.Sprintf("%s:%s", imageRef, reference)) } - headers := copyHeaders(r.Header) - headers.Set("Host", targetHost) - - logrus.Debugf("Docker镜像: 转发请求至 %s", url.String()) - - resp, err := sendRequestWithContext(ctx, r.Method, url.String(), headers, r.Body) if err != nil { - logrus.Errorf("Docker镜像: 请求失败 - %v", err) - http.Error(w, "服务暂时不可用", http.StatusBadGateway) - return - } - defer resp.Body.Close() - - if resp.StatusCode == http.StatusUnauthorized { - handleAuthChallenge(w, r, resp) + logrus.Errorf("解析镜像引用失败: %v", err) + http.Error(w, "Invalid reference", http.StatusBadRequest) return } - respHeaders := copyHeaders(resp.Header) - if respHeaders.Get("WWW-Authenticate") != "" { - currentDomain := getCleanHost(r) - respHeaders.Set("WWW-Authenticate", - fmt.Sprintf(`Bearer realm="https://%s/auth/token", service="registry.docker.io"`, currentDomain)) - } - - for k, v := range respHeaders { - for _, val := range v { - w.Header().Add(k, val) - } - } - w.WriteHeader(resp.StatusCode) - - // 缓存 manifest - if resp.StatusCode == http.StatusOK && r.Method == http.MethodGet && strings.Contains(path, "/manifests/") { - data, err := io.ReadAll(resp.Body) - if err == nil { - cacheKey := fmt.Sprintf("manifest:%s", path) - ttl := 10 * time.Minute - if strings.Contains(path, "sha256:") { - ttl = 1 * time.Hour - } - manifestCache.Set(cacheKey, data, respHeaders, ttl) - w.Write(data) - logrus.Debugf("Docker镜像: manifest 已缓存 [大小: %.2f KB]", float64(len(data))/1024) + if r.Method == http.MethodHead { + desc, err := remote.Head(ref, remoteOptions...) + if err != nil { + logrus.Errorf("HEAD请求失败: %v", err) + http.Error(w, "Manifest not found", http.StatusNotFound) return } - } - written, err := io.Copy(w, resp.Body) - if err != nil { - logrus.Errorf("Docker镜像: 传输响应失败 - %v", err) - return - } + w.Header().Set("Content-Type", string(desc.MediaType)) + w.Header().Set("Docker-Content-Digest", desc.Digest.String()) + w.Header().Set("Content-Length", fmt.Sprintf("%d", desc.Size)) + w.WriteHeader(http.StatusOK) + logrus.Debugf("Docker镜像: HEAD 响应完成") + } else { + desc, err := remote.Get(ref, remoteOptions...) + if err != nil { + logrus.Errorf("GET请求失败: %v", err) + http.Error(w, "Manifest not found", http.StatusNotFound) + return + } - if logrus.IsLevelEnabled(logrus.DebugLevel) { - logrus.Debugf("Docker镜像: 响应完成 [状态: %d] [大小: %.2f KB]", - resp.StatusCode, float64(written)/1024) + w.Header().Set("Content-Type", string(desc.MediaType)) + w.Header().Set("Docker-Content-Digest", desc.Digest.String()) + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(desc.Manifest))) + w.WriteHeader(http.StatusOK) + w.Write(desc.Manifest) + + // 缓存 manifest + ttl := 10 * time.Minute + if strings.HasPrefix(reference, "sha256:") { + ttl = 1 * time.Hour + } + manifestCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), desc.Digest.String(), ttl) + + logrus.Debugf("Docker镜像: manifest 响应完成 [大小: %.2f KB]", float64(len(desc.Manifest))/1024) } } -func handleAuthRequest(w http.ResponseWriter, r *http.Request) { - const targetHost = "auth.docker.io" +func handleBlobRequest(w http.ResponseWriter, r *http.Request, imageRef, digest string) { + digestRef, err := name.NewDigest(fmt.Sprintf("%s@%s", imageRef, digest)) + if err != nil { + logrus.Errorf("解析digest引用失败: %v", err) + http.Error(w, "Invalid digest reference", http.StatusBadRequest) + return + } - // Token 缓存检查 + layer, err := remote.Layer(digestRef, remoteOptions...) + if err != nil { + logrus.Errorf("获取layer失败: %v", err) + http.Error(w, "Layer not found", http.StatusNotFound) + return + } + + size, err := layer.Size() + if err != nil { + logrus.Errorf("获取layer大小失败: %v", err) + http.Error(w, "Failed to get layer size", http.StatusInternalServerError) + return + } + + reader, err := layer.Compressed() + if err != nil { + logrus.Errorf("获取layer内容失败: %v", err) + http.Error(w, "Failed to get layer content", http.StatusInternalServerError) + return + } + defer reader.Close() + + w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("Content-Length", fmt.Sprintf("%d", size)) + w.Header().Set("Docker-Content-Digest", digest) + w.WriteHeader(http.StatusOK) + + written, err := io.Copy(w, reader) + if err != nil { + logrus.Errorf("传输layer失败: %v", err) + return + } + + logrus.Debugf("Docker镜像: blob 传输完成 [大小: %.2f MB]", float64(written)/(1024*1024)) +} + +func handleAuthRequest(w http.ResponseWriter, r *http.Request) { cacheKey := r.URL.RawQuery + if token, ok := tokenCache.Get(cacheKey); ok { logrus.Debugf("认证服务: 使用缓存的 token") w.Header().Set("Content-Type", "application/json") @@ -415,11 +488,12 @@ func handleAuthRequest(w http.ResponseWriter, r *http.Request) { return } + const targetHost = "auth.docker.io" ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) defer cancel() path := strings.TrimPrefix(r.URL.Path, "/auth/") - url := &url.URL{ + targetURL := &url.URL{ Scheme: "https", Host: targetHost, Path: "/" + path, @@ -429,9 +503,9 @@ func handleAuthRequest(w http.ResponseWriter, r *http.Request) { headers := copyHeaders(r.Header) headers.Set("Host", targetHost) - logrus.Debugf("认证服务: 转发请求至 %s", url.String()) + logrus.Debugf("认证服务: 转发请求至 %s", targetURL.String()) - resp, err := sendRequestWithContext(ctx, r.Method, url.String(), headers, r.Body) + resp, err := sendRequestWithContext(ctx, r.Method, targetURL.String(), headers, r.Body) if err != nil { logrus.Errorf("认证服务: 请求失败 - %v", err) http.Error(w, "服务暂时不可用", http.StatusBadGateway) @@ -446,7 +520,6 @@ func handleAuthRequest(w http.ResponseWriter, r *http.Request) { } w.WriteHeader(resp.StatusCode) - // 缓存 token if resp.StatusCode == http.StatusOK { data, err := io.ReadAll(resp.Body) if err == nil { @@ -464,16 +537,7 @@ func handleAuthRequest(w http.ResponseWriter, r *http.Request) { } } - written, err := io.Copy(w, resp.Body) - if err != nil { - logrus.Errorf("认证服务: 传输响应失败 - %v", err) - return - } - - if logrus.IsLevelEnabled(logrus.DebugLevel) { - logrus.Debugf("认证服务: 响应完成 [状态: %d] [大小: %.2f KB]", - resp.StatusCode, float64(written)/1024) - } + io.Copy(w, resp.Body) } func handleCloudflareRequest(w http.ResponseWriter, r *http.Request) { @@ -482,7 +546,7 @@ func handleCloudflareRequest(w http.ResponseWriter, r *http.Request) { defer cancel() path := strings.TrimPrefix(r.URL.Path, "/production-cloudflare/") - url := &url.URL{ + targetURL := &url.URL{ Scheme: "https", Host: targetHost, Path: "/" + path, @@ -492,9 +556,9 @@ func handleCloudflareRequest(w http.ResponseWriter, r *http.Request) { headers := copyHeaders(r.Header) headers.Set("Host", targetHost) - logrus.Debugf("Cloudflare: 转发请求至 %s", url.String()) + logrus.Debugf("Cloudflare: 转发请求至 %s", targetURL.String()) - resp, err := sendRequestWithContext(ctx, r.Method, url.String(), headers, r.Body) + resp, err := sendRequestWithContext(ctx, r.Method, targetURL.String(), headers, r.Body) if err != nil { logrus.Errorf("Cloudflare: 请求失败 - %v", err) http.Error(w, "服务暂时不可用", http.StatusBadGateway) @@ -507,33 +571,6 @@ func handleCloudflareRequest(w http.ResponseWriter, r *http.Request) { w.Header().Add(k, val) } } - w.WriteHeader(resp.StatusCode) - - written, err := io.Copy(w, resp.Body) - if err != nil { - logrus.Errorf("Cloudflare: 传输响应失败 - %v", err) - return - } - - if logrus.IsLevelEnabled(logrus.DebugLevel) { - logrus.Debugf("Cloudflare: 响应完成 [状态: %d] [大小: %.2f KB]", - resp.StatusCode, float64(written)/1024) - } -} - -func handleAuthChallenge(w http.ResponseWriter, r *http.Request, resp *http.Response) { - for k, v := range resp.Header { - for _, val := range v { - w.Header().Add(k, val) - } - } - - if authHeader := w.Header().Get("WWW-Authenticate"); authHeader != "" { - currentDomain := getCleanHost(r) - w.Header().Set("WWW-Authenticate", - fmt.Sprintf(`Bearer realm="https://%s/auth/token", service="registry.docker.io"`, currentDomain)) - } - w.WriteHeader(resp.StatusCode) io.Copy(w, resp.Body) } @@ -546,10 +583,6 @@ func handleDisguise(w http.ResponseWriter, r *http.Request) { RawQuery: r.URL.RawQuery, } - if logrus.IsLevelEnabled(logrus.DebugLevel) { - logrus.Debugf("伪装页面: 转发请求至 %s", targetURL.String()) - } - ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) defer cancel() @@ -570,17 +603,7 @@ func handleDisguise(w http.ResponseWriter, r *http.Request) { } } w.WriteHeader(resp.StatusCode) - - written, err := io.Copy(w, resp.Body) - if err != nil { - logrus.Errorf("伪装页面: 传输响应失败 - %v", err) - return - } - - if logrus.IsLevelEnabled(logrus.DebugLevel) { - logrus.Debugf("伪装页面: 响应完成 [状态: %d] [大小: %.2f KB]", - resp.StatusCode, float64(written)/1024) - } + io.Copy(w, resp.Body) } func sendRequestWithContext(ctx context.Context, method, url string, headers http.Header, body io.ReadCloser) (*http.Response, error) { @@ -588,18 +611,8 @@ func sendRequestWithContext(ctx context.Context, method, url string, headers htt if err != nil { return nil, fmt.Errorf("创建请求失败: %v", err) } - req.Header = headers - - startTime := time.Now() - resp, err := client.Do(req) - - if err == nil && logrus.IsLevelEnabled(logrus.DebugLevel) { - duration := time.Since(startTime) - logrus.Debugf("请求耗时: %.2f 秒 (%s)", duration.Seconds(), url) - } - - return resp, err + return client.Do(req) } func copyHeaders(src http.Header) http.Header {