Skip to content

Commit 0d658ad

Browse files
committed
Implement Policy Fetcher in framework
1 parent d3c3c8d commit 0d658ad

File tree

4 files changed

+1020
-0
lines changed

4 files changed

+1020
-0
lines changed

internal/framework/fetch/fetch.go

Lines changed: 388 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,388 @@
1+
package fetch
2+
3+
import (
4+
"context"
5+
"crypto/sha256"
6+
"encoding/hex"
7+
"errors"
8+
"fmt"
9+
"io"
10+
"net/http"
11+
"regexp"
12+
"strings"
13+
"time"
14+
15+
"k8s.io/apimachinery/pkg/util/wait"
16+
)
17+
18+
//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate
19+
20+
// ChecksumMismatchError represents an error when the calculated checksum doesn't match the expected checksum.
// This type of error should not trigger retries as it indicates data corruption or tampering.
type ChecksumMismatchError struct {
	Expected string
	Actual   string
}

// Error implements the error interface, reporting both the expected and the
// observed checksum.
func (e *ChecksumMismatchError) Error() string {
	return "checksum mismatch: expected " + e.Expected + ", got " + e.Actual
}
30+
31+
// ChecksumFetchError represents an error when fetching the checksum file fails.
// This type of error should trigger retries as it may be a temporary network issue.
type ChecksumFetchError struct {
	Err error
	URL string
}

// Error implements the error interface, including the checksum URL and the
// underlying cause.
func (e *ChecksumFetchError) Error() string {
	return fmt.Sprintf("failed to fetch checksum from %s: %v", e.URL, e.Err)
}

// Unwrap exposes the underlying error so errors.Is/errors.As can inspect it.
func (e *ChecksumFetchError) Unwrap() error {
	return e.Err
}
45+
46+
// options contains the internal configuration for fetching remote files.
47+
type options struct {
48+
checksumLocation string
49+
retryBackoff RetryBackoffType
50+
validationMethods []string
51+
timeout time.Duration
52+
retryMaxDelay time.Duration
53+
retryAttempts int32
54+
}
55+
56+
// defaults returns options with sensible default values.
57+
func defaults() options {
58+
return options{
59+
timeout: 30 * time.Second,
60+
retryAttempts: 3,
61+
retryMaxDelay: 5 * time.Minute,
62+
retryBackoff: RetryBackoffExponential,
63+
}
64+
}
65+
66+
// Option defines a function that modifies fetch options.
67+
type Option func(*options)
68+
69+
// WithTimeout sets the HTTP request timeout.
70+
func WithTimeout(timeout time.Duration) Option {
71+
return func(o *options) {
72+
o.timeout = timeout
73+
}
74+
}
75+
76+
// WithRetryAttempts sets the number of retry attempts.
77+
func WithRetryAttempts(attempts int32) Option {
78+
return func(o *options) {
79+
o.retryAttempts = attempts
80+
}
81+
}
82+
83+
// WithRetryBackoff sets the retry backoff strategy.
84+
func WithRetryBackoff(backoff RetryBackoffType) Option {
85+
return func(o *options) {
86+
o.retryBackoff = backoff
87+
}
88+
}
89+
90+
// WithMaxRetryDelay sets the maximum delay between retries.
91+
func WithMaxRetryDelay(delay time.Duration) Option {
92+
return func(o *options) {
93+
o.retryMaxDelay = delay
94+
}
95+
}
96+
97+
// WithChecksum enables checksum validation with an optional custom checksum location.
98+
// If no location is provided, defaults to <fileURL>.sha256.
99+
func WithChecksum(checksumLocation ...string) Option {
100+
return func(o *options) {
101+
o.validationMethods = append(o.validationMethods, "checksum")
102+
if len(checksumLocation) > 0 {
103+
o.checksumLocation = checksumLocation[0]
104+
}
105+
}
106+
}
107+
108+
// Fetcher defines the interface for fetching remote files.
//
//counterfeiter:generate . Fetcher
type Fetcher interface {
	// GetRemoteFile downloads the file at url, applying any fetch options,
	// and returns its raw contents.
	GetRemoteFile(url string, opts ...Option) ([]byte, error)
}

// DefaultFetcher is the default implementation of Fetcher.
// It carries no state, so the zero value is ready to use.
type DefaultFetcher struct{}

// RetryBackoffType defines supported backoff strategies.
type RetryBackoffType string

const (
	// RetryBackoffExponential doubles the delay between retry attempts
	// (backoff factor 2.0).
	RetryBackoffExponential RetryBackoffType = "exponential"
	// RetryBackoffLinear keeps a fixed delay between retry attempts
	// (backoff factor 1.0).
	RetryBackoffLinear RetryBackoffType = "linear"
)
125+
126+
// GetRemoteFile fetches a remote file with retry logic and validation.
127+
func (f *DefaultFetcher) GetRemoteFile(url string, opts ...Option) ([]byte, error) {
128+
ctx := context.Background()
129+
130+
// Apply options to defaults
131+
options := defaults()
132+
for _, opt := range opts {
133+
opt(&options)
134+
}
135+
136+
fetchURL, err := f.convertS3URLToHTTPS(url)
137+
if err != nil {
138+
return nil, fmt.Errorf("failed to convert S3 URL: %w", err)
139+
}
140+
141+
backoff := f.createBackoffConfig(options.retryBackoff, options.retryAttempts, options.retryMaxDelay)
142+
143+
var lastErr error
144+
var result []byte
145+
146+
err = wait.ExponentialBackoffWithContext(ctx, backoff, func(ctx context.Context) (bool, error) {
147+
client := f.createHTTPClientWithTimeout(options.timeout)
148+
data, err := f.fetchFileContent(ctx, client, fetchURL)
149+
if err != nil {
150+
lastErr = fmt.Errorf("failed to fetch file from %s: %w", url, err)
151+
return false, nil
152+
}
153+
154+
if len(options.validationMethods) > 0 {
155+
if err := f.validateFileContent(ctx, data, url, options); err != nil {
156+
lastErr = err
157+
// Don't retry on checksum mismatches as they indicate data corruption
158+
var checksumMismatchErr *ChecksumMismatchError
159+
if errors.As(err, &checksumMismatchErr) {
160+
return false, err
161+
}
162+
return false, nil
163+
}
164+
}
165+
166+
result = data
167+
return true, nil
168+
})
169+
170+
// Return the most meaningful error
171+
if result != nil {
172+
return result, nil
173+
}
174+
175+
if lastErr != nil {
176+
return nil, lastErr
177+
}
178+
179+
if err != nil {
180+
return nil, fmt.Errorf("retry operation failed: %w", err)
181+
}
182+
183+
return nil, fmt.Errorf("failed to fetch file from %s: unknown error", url)
184+
}
185+
186+
func (f *DefaultFetcher) createBackoffConfig(
187+
backoffType RetryBackoffType,
188+
attempts int32,
189+
maxDelay time.Duration,
190+
) wait.Backoff {
191+
switch backoffType {
192+
case RetryBackoffLinear:
193+
return wait.Backoff{
194+
Duration: 200 * time.Millisecond,
195+
Factor: 1.0,
196+
Jitter: 0.1,
197+
Steps: int(attempts + 1),
198+
Cap: maxDelay,
199+
}
200+
case RetryBackoffExponential:
201+
fallthrough
202+
default:
203+
return wait.Backoff{
204+
Duration: 200 * time.Millisecond,
205+
Factor: 2.0,
206+
Jitter: 0.1,
207+
Steps: int(attempts + 1),
208+
Cap: maxDelay,
209+
}
210+
}
211+
}
212+
213+
// validateFileContent validates the fetched file content using the specified methods.
214+
func (f *DefaultFetcher) validateFileContent(ctx context.Context, data []byte, url string, options options) error {
215+
for _, method := range options.validationMethods {
216+
switch method {
217+
case "checksum":
218+
if err := f.validateChecksum(ctx, data, url, options.checksumLocation); err != nil {
219+
return fmt.Errorf("checksum validation failed: %w", err)
220+
}
221+
default:
222+
return fmt.Errorf("unsupported validation method: %s", method)
223+
}
224+
}
225+
return nil
226+
}
227+
228+
// validateChecksum validates the file content against a SHA256 checksum.
229+
func (f *DefaultFetcher) validateChecksum(ctx context.Context, data []byte, url, checksumLocation string) error {
230+
// If no checksum location is provided, default to <url>.sha256
231+
checksumURL := checksumLocation
232+
if checksumURL == "" {
233+
checksumURL = url + ".sha256"
234+
}
235+
236+
fetchChecksumURL, err := f.convertS3URLToHTTPS(checksumURL)
237+
if err != nil {
238+
return &ChecksumFetchError{URL: checksumURL, Err: fmt.Errorf("failed to convert S3 checksum URL: %w", err)}
239+
}
240+
241+
client := f.createHTTPClientWithTimeout(30 * time.Second)
242+
checksumData, err := f.fetchFileContent(ctx, client, fetchChecksumURL)
243+
if err != nil {
244+
return &ChecksumFetchError{URL: checksumURL, Err: err}
245+
}
246+
247+
// Parse the checksum (assume it's in the format "hash filename" or just "hash")
248+
checksumStr := strings.TrimSpace(string(checksumData))
249+
expectedChecksum := strings.Fields(checksumStr)[0] // Take the first field (the hash)
250+
251+
// Calculate the actual checksum
252+
hasher := sha256.New()
253+
hasher.Write(data)
254+
actualChecksum := hex.EncodeToString(hasher.Sum(nil))
255+
256+
if actualChecksum != expectedChecksum {
257+
return &ChecksumMismatchError{Expected: expectedChecksum, Actual: actualChecksum}
258+
}
259+
260+
return nil
261+
}
262+
263+
// createHTTPClientWithTimeout creates an HTTP client with the specified timeout duration.
264+
func (f *DefaultFetcher) createHTTPClientWithTimeout(timeout time.Duration) *http.Client {
265+
return &http.Client{
266+
Timeout: timeout,
267+
}
268+
}
269+
270+
// fetchFileContent performs the actual HTTP GET request and reads the response body.
271+
func (f *DefaultFetcher) fetchFileContent(ctx context.Context, client *http.Client, url string) ([]byte, error) {
272+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
273+
if err != nil {
274+
return nil, fmt.Errorf("failed to create request: %w", err)
275+
}
276+
277+
// Set a reasonable User-Agent header
278+
req.Header.Set("User-Agent", "nginx-gateway-fabric/1.0")
279+
280+
resp, err := client.Do(req)
281+
if err != nil {
282+
return nil, fmt.Errorf("failed to fetch file from %s: %w", url, err)
283+
}
284+
defer resp.Body.Close()
285+
286+
if resp.StatusCode != http.StatusOK {
287+
return nil, fmt.Errorf("HTTP request failed with status %d: %s", resp.StatusCode, resp.Status)
288+
}
289+
290+
content, err := io.ReadAll(resp.Body)
291+
if err != nil {
292+
return nil, fmt.Errorf("failed to read response body: %w", err)
293+
}
294+
295+
return content, nil
296+
}
297+
298+
// convertS3URLToHTTPS converts S3 URLs to HTTPS URLs for fetching.
299+
// Supports both standard S3 URLs (s3://bucket/key) and regional URLs (s3://bucket.region/key).
300+
func (f *DefaultFetcher) convertS3URLToHTTPS(url string) (string, error) {
301+
if !strings.HasPrefix(url, "s3://") {
302+
return url, nil
303+
}
304+
305+
s3Path := strings.TrimPrefix(url, "s3://")
306+
307+
// Split into bucket and object key
308+
parts := strings.SplitN(s3Path, "/", 2)
309+
if len(parts) < 1 {
310+
return "", fmt.Errorf("invalid S3 URL format: %s", url)
311+
}
312+
313+
bucketInfo := parts[0]
314+
var objectKey string
315+
if len(parts) > 1 {
316+
objectKey = parts[1]
317+
}
318+
319+
if bucketInfo == "" {
320+
return "", fmt.Errorf("S3 bucket name cannot be empty")
321+
}
322+
323+
bucket, region := f.parseBucketAndRegion(bucketInfo)
324+
325+
if bucket == "" {
326+
return "", fmt.Errorf("S3 bucket name cannot be empty after parsing")
327+
}
328+
329+
var httpsURL string
330+
if region != "" {
331+
httpsURL = fmt.Sprintf("https://s3.%s.amazonaws.com/%s", region, bucket)
332+
} else {
333+
httpsURL = fmt.Sprintf("https://s3.amazonaws.com/%s", bucket)
334+
}
335+
336+
if objectKey != "" {
337+
httpsURL = fmt.Sprintf("%s/%s", httpsURL, objectKey)
338+
}
339+
340+
return httpsURL, nil
341+
}
342+
343+
// parseBucketAndRegion extracts bucket name and region from the bucket info part of an S3 URL.
344+
// Handles various formats:
345+
// - "my-bucket" -> ("my-bucket", "")
346+
// - "my-bucket.us-west-2" -> ("my-bucket", "us-west-2")
347+
// - "my-bucket.s3.us-west-2.amazonaws.com" -> ("my-bucket", "us-west-2").
348+
func (f *DefaultFetcher) parseBucketAndRegion(bucketInfo string) (bucket, region string) {
349+
// Handle legacy S3 website/FQDN format: bucket.s3.region.amazonaws.com
350+
if strings.Contains(bucketInfo, ".s3.") && strings.HasSuffix(bucketInfo, ".amazonaws.com") {
351+
parts := strings.Split(bucketInfo, ".")
352+
if len(parts) >= 4 && parts[1] == "s3" && parts[len(parts)-1] == "com" && parts[len(parts)-2] == "amazonaws" {
353+
bucket = parts[0]
354+
// Extract region (everything between s3 and amazonaws)
355+
regionParts := parts[2 : len(parts)-2]
356+
region = strings.Join(regionParts, ".")
357+
return bucket, region
358+
}
359+
}
360+
361+
if strings.Contains(bucketInfo, ".") {
362+
parts := strings.SplitN(bucketInfo, ".", 2)
363+
bucket = parts[0]
364+
potentialRegion := parts[1]
365+
366+
if f.isValidAWSRegion(potentialRegion) {
367+
region = potentialRegion
368+
} else {
369+
bucket = bucketInfo
370+
region = ""
371+
}
372+
return bucket, region
373+
}
374+
375+
// Simple bucket name with no region
376+
return bucketInfo, ""
377+
}
378+
379+
// isValidAWSRegion performs basic validation to check if a string looks like an AWS region.
380+
func (f *DefaultFetcher) isValidAWSRegion(region string) bool {
381+
if region == "" {
382+
return false
383+
}
384+
385+
regionPattern := `^[a-z]{2,}-[a-z]+-[0-9]+$|^[a-z]{2,}-[a-z]+-[a-z]+-[0-9]+$`
386+
matched, _ := regexp.MatchString(regionPattern, region)
387+
return matched
388+
}

0 commit comments

Comments
 (0)