更新 main.go
This commit is contained in:
309
main.go
309
main.go
@@ -17,6 +17,9 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-containerregistry/pkg/authn"
|
||||
"github.com/google/go-containerregistry/pkg/name"
|
||||
"github.com/google/go-containerregistry/pkg/v1/remote"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -32,9 +35,6 @@ type Config struct {
|
||||
var config Config
|
||||
|
||||
var client = &http.Client{
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
DisableKeepAlives: false,
|
||||
@@ -46,6 +46,11 @@ var client = &http.Client{
|
||||
},
|
||||
}
|
||||
|
||||
var remoteOptions = []remote.Option{
|
||||
remote.WithAuth(authn.Anonymous),
|
||||
remote.WithTransport(client.Transport),
|
||||
}
|
||||
|
||||
// Token 缓存
|
||||
type TokenCache struct {
|
||||
mu sync.RWMutex
|
||||
@@ -64,7 +69,6 @@ var tokenCache = &TokenCache{
|
||||
func (tc *TokenCache) Get(key string) (string, bool) {
|
||||
tc.mu.RLock()
|
||||
defer tc.mu.RUnlock()
|
||||
|
||||
if cached, ok := tc.cache[key]; ok {
|
||||
if time.Now().Before(cached.ExpiresAt) {
|
||||
return cached.Token, true
|
||||
@@ -77,7 +81,6 @@ func (tc *TokenCache) Get(key string) (string, bool) {
|
||||
func (tc *TokenCache) Set(key, token string, ttl time.Duration) {
|
||||
tc.mu.Lock()
|
||||
defer tc.mu.Unlock()
|
||||
|
||||
tc.cache[key] = &CachedToken{
|
||||
Token: token,
|
||||
ExpiresAt: time.Now().Add(ttl),
|
||||
@@ -92,7 +95,8 @@ type ManifestCache struct {
|
||||
|
||||
type CachedManifest struct {
|
||||
Data []byte
|
||||
Headers http.Header
|
||||
MediaType string
|
||||
Digest string
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
@@ -103,7 +107,6 @@ var manifestCache = &ManifestCache{
|
||||
func (mc *ManifestCache) Get(key string) (*CachedManifest, bool) {
|
||||
mc.mu.RLock()
|
||||
defer mc.mu.RUnlock()
|
||||
|
||||
if cached, ok := mc.cache[key]; ok {
|
||||
if time.Now().Before(cached.ExpiresAt) {
|
||||
return cached, true
|
||||
@@ -113,13 +116,13 @@ func (mc *ManifestCache) Get(key string) (*CachedManifest, bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (mc *ManifestCache) Set(key string, data []byte, headers http.Header, ttl time.Duration) {
|
||||
func (mc *ManifestCache) Set(key string, data []byte, mediaType, digest string, ttl time.Duration) {
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
|
||||
mc.cache[key] = &CachedManifest{
|
||||
Data: data,
|
||||
Headers: headers,
|
||||
MediaType: mediaType,
|
||||
Digest: digest,
|
||||
ExpiresAt: time.Now().Add(ttl),
|
||||
}
|
||||
}
|
||||
@@ -313,100 +316,170 @@ func getCleanHost(r *http.Request) string {
|
||||
}
|
||||
|
||||
func handleRegistryRequest(w http.ResponseWriter, r *http.Request) {
|
||||
const targetHost = "registry-1.docker.io"
|
||||
path := strings.TrimPrefix(r.URL.Path, "/v2/")
|
||||
|
||||
// Manifest 缓存检查
|
||||
if r.Method == http.MethodGet && strings.Contains(path, "/manifests/") {
|
||||
cacheKey := fmt.Sprintf("manifest:%s", path)
|
||||
// /v2/ 端点
|
||||
if path == "" {
|
||||
w.Header().Set("Docker-Distribution-API-Version", "registry/2.0")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("{}"))
|
||||
return
|
||||
}
|
||||
|
||||
imageName, apiType, reference := parseRegistryPath(path)
|
||||
if imageName == "" || apiType == "" {
|
||||
http.Error(w, "Invalid path format", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if !strings.Contains(imageName, "/") {
|
||||
imageName = "library/" + imageName
|
||||
}
|
||||
|
||||
imageRef := fmt.Sprintf("registry-1.docker.io/%s", imageName)
|
||||
|
||||
switch apiType {
|
||||
case "manifests":
|
||||
handleManifestRequest(w, r, imageRef, reference)
|
||||
case "blobs":
|
||||
handleBlobRequest(w, r, imageRef, reference)
|
||||
default:
|
||||
http.Error(w, "API endpoint not found", http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
func parseRegistryPath(path string) (imageName, apiType, reference string) {
|
||||
if idx := strings.Index(path, "/manifests/"); idx != -1 {
|
||||
imageName = path[:idx]
|
||||
apiType = "manifests"
|
||||
reference = path[idx+len("/manifests/"):]
|
||||
return
|
||||
}
|
||||
if idx := strings.Index(path, "/blobs/"); idx != -1 {
|
||||
imageName = path[:idx]
|
||||
apiType = "blobs"
|
||||
reference = path[idx+len("/blobs/"):]
|
||||
return
|
||||
}
|
||||
return "", "", ""
|
||||
}
|
||||
|
||||
func handleManifestRequest(w http.ResponseWriter, r *http.Request, imageRef, reference string) {
|
||||
cacheKey := fmt.Sprintf("%s@%s", imageRef, reference)
|
||||
|
||||
// 检查缓存
|
||||
if r.Method == http.MethodGet {
|
||||
if cached, ok := manifestCache.Get(cacheKey); ok {
|
||||
logrus.Debugf("Docker镜像: 使用缓存的 manifest")
|
||||
for k, v := range cached.Headers {
|
||||
for _, val := range v {
|
||||
w.Header().Add(k, val)
|
||||
}
|
||||
}
|
||||
w.Header().Set("Content-Type", cached.MediaType)
|
||||
w.Header().Set("Docker-Content-Digest", cached.Digest)
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(cached.Data)))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(cached.Data)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
|
||||
defer cancel()
|
||||
var ref name.Reference
|
||||
var err error
|
||||
|
||||
url := &url.URL{
|
||||
Scheme: "https",
|
||||
Host: targetHost,
|
||||
Path: "/v2/" + path,
|
||||
RawQuery: r.URL.RawQuery,
|
||||
if strings.HasPrefix(reference, "sha256:") {
|
||||
ref, err = name.NewDigest(fmt.Sprintf("%s@%s", imageRef, reference))
|
||||
} else {
|
||||
ref, err = name.NewTag(fmt.Sprintf("%s:%s", imageRef, reference))
|
||||
}
|
||||
|
||||
headers := copyHeaders(r.Header)
|
||||
headers.Set("Host", targetHost)
|
||||
|
||||
logrus.Debugf("Docker镜像: 转发请求至 %s", url.String())
|
||||
|
||||
resp, err := sendRequestWithContext(ctx, r.Method, url.String(), headers, r.Body)
|
||||
if err != nil {
|
||||
logrus.Errorf("Docker镜像: 请求失败 - %v", err)
|
||||
http.Error(w, "服务暂时不可用", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized {
|
||||
handleAuthChallenge(w, r, resp)
|
||||
logrus.Errorf("解析镜像引用失败: %v", err)
|
||||
http.Error(w, "Invalid reference", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
respHeaders := copyHeaders(resp.Header)
|
||||
if respHeaders.Get("WWW-Authenticate") != "" {
|
||||
currentDomain := getCleanHost(r)
|
||||
respHeaders.Set("WWW-Authenticate",
|
||||
fmt.Sprintf(`Bearer realm="https://%s/auth/token", service="registry.docker.io"`, currentDomain))
|
||||
}
|
||||
|
||||
for k, v := range respHeaders {
|
||||
for _, val := range v {
|
||||
w.Header().Add(k, val)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
|
||||
// 缓存 manifest
|
||||
if resp.StatusCode == http.StatusOK && r.Method == http.MethodGet && strings.Contains(path, "/manifests/") {
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err == nil {
|
||||
cacheKey := fmt.Sprintf("manifest:%s", path)
|
||||
ttl := 10 * time.Minute
|
||||
if strings.Contains(path, "sha256:") {
|
||||
ttl = 1 * time.Hour
|
||||
}
|
||||
manifestCache.Set(cacheKey, data, respHeaders, ttl)
|
||||
w.Write(data)
|
||||
logrus.Debugf("Docker镜像: manifest 已缓存 [大小: %.2f KB]", float64(len(data))/1024)
|
||||
if r.Method == http.MethodHead {
|
||||
desc, err := remote.Head(ref, remoteOptions...)
|
||||
if err != nil {
|
||||
logrus.Errorf("HEAD请求失败: %v", err)
|
||||
http.Error(w, "Manifest not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
written, err := io.Copy(w, resp.Body)
|
||||
if err != nil {
|
||||
logrus.Errorf("Docker镜像: 传输响应失败 - %v", err)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", string(desc.MediaType))
|
||||
w.Header().Set("Docker-Content-Digest", desc.Digest.String())
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", desc.Size))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
logrus.Debugf("Docker镜像: HEAD 响应完成")
|
||||
} else {
|
||||
desc, err := remote.Get(ref, remoteOptions...)
|
||||
if err != nil {
|
||||
logrus.Errorf("GET请求失败: %v", err)
|
||||
http.Error(w, "Manifest not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
if logrus.IsLevelEnabled(logrus.DebugLevel) {
|
||||
logrus.Debugf("Docker镜像: 响应完成 [状态: %d] [大小: %.2f KB]",
|
||||
resp.StatusCode, float64(written)/1024)
|
||||
w.Header().Set("Content-Type", string(desc.MediaType))
|
||||
w.Header().Set("Docker-Content-Digest", desc.Digest.String())
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(desc.Manifest)))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(desc.Manifest)
|
||||
|
||||
// 缓存 manifest
|
||||
ttl := 10 * time.Minute
|
||||
if strings.HasPrefix(reference, "sha256:") {
|
||||
ttl = 1 * time.Hour
|
||||
}
|
||||
manifestCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), desc.Digest.String(), ttl)
|
||||
|
||||
logrus.Debugf("Docker镜像: manifest 响应完成 [大小: %.2f KB]", float64(len(desc.Manifest))/1024)
|
||||
}
|
||||
}
|
||||
|
||||
func handleAuthRequest(w http.ResponseWriter, r *http.Request) {
|
||||
const targetHost = "auth.docker.io"
|
||||
func handleBlobRequest(w http.ResponseWriter, r *http.Request, imageRef, digest string) {
|
||||
digestRef, err := name.NewDigest(fmt.Sprintf("%s@%s", imageRef, digest))
|
||||
if err != nil {
|
||||
logrus.Errorf("解析digest引用失败: %v", err)
|
||||
http.Error(w, "Invalid digest reference", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Token 缓存检查
|
||||
layer, err := remote.Layer(digestRef, remoteOptions...)
|
||||
if err != nil {
|
||||
logrus.Errorf("获取layer失败: %v", err)
|
||||
http.Error(w, "Layer not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
size, err := layer.Size()
|
||||
if err != nil {
|
||||
logrus.Errorf("获取layer大小失败: %v", err)
|
||||
http.Error(w, "Failed to get layer size", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
reader, err := layer.Compressed()
|
||||
if err != nil {
|
||||
logrus.Errorf("获取layer内容失败: %v", err)
|
||||
http.Error(w, "Failed to get layer content", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", size))
|
||||
w.Header().Set("Docker-Content-Digest", digest)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
written, err := io.Copy(w, reader)
|
||||
if err != nil {
|
||||
logrus.Errorf("传输layer失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logrus.Debugf("Docker镜像: blob 传输完成 [大小: %.2f MB]", float64(written)/(1024*1024))
|
||||
}
|
||||
|
||||
func handleAuthRequest(w http.ResponseWriter, r *http.Request) {
|
||||
cacheKey := r.URL.RawQuery
|
||||
|
||||
if token, ok := tokenCache.Get(cacheKey); ok {
|
||||
logrus.Debugf("认证服务: 使用缓存的 token")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
@@ -415,11 +488,12 @@ func handleAuthRequest(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
const targetHost = "auth.docker.io"
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
path := strings.TrimPrefix(r.URL.Path, "/auth/")
|
||||
url := &url.URL{
|
||||
targetURL := &url.URL{
|
||||
Scheme: "https",
|
||||
Host: targetHost,
|
||||
Path: "/" + path,
|
||||
@@ -429,9 +503,9 @@ func handleAuthRequest(w http.ResponseWriter, r *http.Request) {
|
||||
headers := copyHeaders(r.Header)
|
||||
headers.Set("Host", targetHost)
|
||||
|
||||
logrus.Debugf("认证服务: 转发请求至 %s", url.String())
|
||||
logrus.Debugf("认证服务: 转发请求至 %s", targetURL.String())
|
||||
|
||||
resp, err := sendRequestWithContext(ctx, r.Method, url.String(), headers, r.Body)
|
||||
resp, err := sendRequestWithContext(ctx, r.Method, targetURL.String(), headers, r.Body)
|
||||
if err != nil {
|
||||
logrus.Errorf("认证服务: 请求失败 - %v", err)
|
||||
http.Error(w, "服务暂时不可用", http.StatusBadGateway)
|
||||
@@ -446,7 +520,6 @@ func handleAuthRequest(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
|
||||
// 缓存 token
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err == nil {
|
||||
@@ -464,16 +537,7 @@ func handleAuthRequest(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
written, err := io.Copy(w, resp.Body)
|
||||
if err != nil {
|
||||
logrus.Errorf("认证服务: 传输响应失败 - %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if logrus.IsLevelEnabled(logrus.DebugLevel) {
|
||||
logrus.Debugf("认证服务: 响应完成 [状态: %d] [大小: %.2f KB]",
|
||||
resp.StatusCode, float64(written)/1024)
|
||||
}
|
||||
io.Copy(w, resp.Body)
|
||||
}
|
||||
|
||||
func handleCloudflareRequest(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -482,7 +546,7 @@ func handleCloudflareRequest(w http.ResponseWriter, r *http.Request) {
|
||||
defer cancel()
|
||||
|
||||
path := strings.TrimPrefix(r.URL.Path, "/production-cloudflare/")
|
||||
url := &url.URL{
|
||||
targetURL := &url.URL{
|
||||
Scheme: "https",
|
||||
Host: targetHost,
|
||||
Path: "/" + path,
|
||||
@@ -492,9 +556,9 @@ func handleCloudflareRequest(w http.ResponseWriter, r *http.Request) {
|
||||
headers := copyHeaders(r.Header)
|
||||
headers.Set("Host", targetHost)
|
||||
|
||||
logrus.Debugf("Cloudflare: 转发请求至 %s", url.String())
|
||||
logrus.Debugf("Cloudflare: 转发请求至 %s", targetURL.String())
|
||||
|
||||
resp, err := sendRequestWithContext(ctx, r.Method, url.String(), headers, r.Body)
|
||||
resp, err := sendRequestWithContext(ctx, r.Method, targetURL.String(), headers, r.Body)
|
||||
if err != nil {
|
||||
logrus.Errorf("Cloudflare: 请求失败 - %v", err)
|
||||
http.Error(w, "服务暂时不可用", http.StatusBadGateway)
|
||||
@@ -507,33 +571,6 @@ func handleCloudflareRequest(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Add(k, val)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
|
||||
written, err := io.Copy(w, resp.Body)
|
||||
if err != nil {
|
||||
logrus.Errorf("Cloudflare: 传输响应失败 - %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if logrus.IsLevelEnabled(logrus.DebugLevel) {
|
||||
logrus.Debugf("Cloudflare: 响应完成 [状态: %d] [大小: %.2f KB]",
|
||||
resp.StatusCode, float64(written)/1024)
|
||||
}
|
||||
}
|
||||
|
||||
func handleAuthChallenge(w http.ResponseWriter, r *http.Request, resp *http.Response) {
|
||||
for k, v := range resp.Header {
|
||||
for _, val := range v {
|
||||
w.Header().Add(k, val)
|
||||
}
|
||||
}
|
||||
|
||||
if authHeader := w.Header().Get("WWW-Authenticate"); authHeader != "" {
|
||||
currentDomain := getCleanHost(r)
|
||||
w.Header().Set("WWW-Authenticate",
|
||||
fmt.Sprintf(`Bearer realm="https://%s/auth/token", service="registry.docker.io"`, currentDomain))
|
||||
}
|
||||
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
io.Copy(w, resp.Body)
|
||||
}
|
||||
@@ -546,10 +583,6 @@ func handleDisguise(w http.ResponseWriter, r *http.Request) {
|
||||
RawQuery: r.URL.RawQuery,
|
||||
}
|
||||
|
||||
if logrus.IsLevelEnabled(logrus.DebugLevel) {
|
||||
logrus.Debugf("伪装页面: 转发请求至 %s", targetURL.String())
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -570,17 +603,7 @@ func handleDisguise(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
|
||||
written, err := io.Copy(w, resp.Body)
|
||||
if err != nil {
|
||||
logrus.Errorf("伪装页面: 传输响应失败 - %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if logrus.IsLevelEnabled(logrus.DebugLevel) {
|
||||
logrus.Debugf("伪装页面: 响应完成 [状态: %d] [大小: %.2f KB]",
|
||||
resp.StatusCode, float64(written)/1024)
|
||||
}
|
||||
io.Copy(w, resp.Body)
|
||||
}
|
||||
|
||||
func sendRequestWithContext(ctx context.Context, method, url string, headers http.Header, body io.ReadCloser) (*http.Response, error) {
|
||||
@@ -588,18 +611,8 @@ func sendRequestWithContext(ctx context.Context, method, url string, headers htt
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %v", err)
|
||||
}
|
||||
|
||||
req.Header = headers
|
||||
|
||||
startTime := time.Now()
|
||||
resp, err := client.Do(req)
|
||||
|
||||
if err == nil && logrus.IsLevelEnabled(logrus.DebugLevel) {
|
||||
duration := time.Since(startTime)
|
||||
logrus.Debugf("请求耗时: %.2f 秒 (%s)", duration.Seconds(), url)
|
||||
}
|
||||
|
||||
return resp, err
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func copyHeaders(src http.Header) http.Header {
|
||||
|
||||
Reference in New Issue
Block a user