Skip to content

Commit cb98dee

Browse files
authored
Fix flex checksum validation cfg (#2981)
* update checksum validation setter * add new line * add new line * add changelog * update changelog content * update changelog * add integ test case for validation skip and crc64 checksum case
1 parent 9c76401 commit cb98dee

File tree

7 files changed

+181
-64
lines changed

7 files changed

+181
-64
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
{
2+
"id": "df93fff8-f662-441f-a12d-ac87614a5064",
3+
"type": "bugfix",
4+
"description": "Enable request checksum validation mode by default",
5+
"modules": [
6+
"service/internal/checksum",
7+
"service/s3"
8+
]
9+
}

codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AwsHttpChecksumGenerator.java

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import software.amazon.smithy.model.Model;
2727
import software.amazon.smithy.model.knowledge.TopDownIndex;
2828
import software.amazon.smithy.model.shapes.MemberShape;
29+
import software.amazon.smithy.model.shapes.ShapeType;
2930
import software.amazon.smithy.model.shapes.OperationShape;
3031
import software.amazon.smithy.model.shapes.ServiceShape;
3132
import software.amazon.smithy.aws.traits.HttpChecksumTrait;
@@ -54,6 +55,10 @@ private static String getRequestValidationModeAccessorFuncName(String operationN
5455
return String.format("get%s%s", operationName, "RequestValidationModeMember");
5556
}
5657

58+
private static String setRequestValidationModeAccessorFuncName(String operationName) {
59+
return String.format("set%s%s", operationName, "RequestValidationModeMember");
60+
}
61+
5762
private static String getAddInputMiddlewareFuncName(String operationName) {
5863
return String.format("add%sInputChecksumMiddlewares", operationName);
5964
}
@@ -158,7 +163,7 @@ public void writeAdditionalFiles(
158163

159164
goDelegator.useShapeWriter(operation, writer -> {
160165
// generate getter helper function to access input member value
161-
writeGetInputMemberAccessorHelper(writer, model, symbolProvider, operation);
166+
writeInputMemberAccessorHelper(writer, model, symbolProvider, operation);
162167

163168
// generate middleware helper function
164169
if (generateComputeInputChecksums) {
@@ -212,7 +217,7 @@ public static boolean hasInputChecksumTrait(Model model, ServiceShape service) {
212217
return false;
213218
}
214219

215-
private static boolean hasOutputChecksumTrait(Model model, ServiceShape service, OperationShape operation) {
220+
private static boolean hasOutputChecksumTrait(Model model, ServiceShape service, OperationShape operation) {
216221
if (!hasChecksumTrait(model, service, operation)) {
217222
return false;
218223
}
@@ -356,6 +361,7 @@ private void writeOutputMiddlewareHelper(
356361
writer.write("""
357362
return $T(stack, $T{
358363
GetValidationMode: $L,
364+
SetValidationMode: $L,
359365
ResponseChecksumValidation: options.ResponseChecksumValidation,
360366
ValidationAlgorithms: $L,
361367
IgnoreMultipartValidation: $L,
@@ -367,6 +373,7 @@ private void writeOutputMiddlewareHelper(
367373
SymbolUtils.createValueSymbolBuilder("OutputMiddlewareOptions",
368374
AwsGoDependency.SERVICE_INTERNAL_CHECKSUM).build(),
369375
getRequestValidationModeAccessorFuncName(operationName),
376+
setRequestValidationModeAccessorFuncName(operationName),
370377
convertToGoStringList(responseAlgorithms),
371378
ignoreMultipartChecksumValidationMap.getOrDefault(
372379
service.toShapeId(), new HashSet<>()).contains(operation.toShapeId())
@@ -389,7 +396,7 @@ private String convertToGoStringList(List<String> list) {
389396
return sb.toString();
390397
}
391398

392-
private void writeGetInputMemberAccessorHelper(
399+
private void writeInputMemberAccessorHelper(
393400
GoWriter writer,
394401
Model model,
395402
SymbolProvider symbolProvider,
@@ -438,6 +445,9 @@ private void writeGetInputMemberAccessorHelper(
438445
String.format("%s gets the request checksum validation mode provided as input.", funcName));
439446
getInputTemplate(writer, symbolProvider, input, funcName, memberName);
440447
writer.insertTrailingNewline();
448+
funcName = setRequestValidationModeAccessorFuncName(operationSymbol.getName());
449+
setInputTemplate(writer, symbolProvider, input, funcName, memberName);
450+
writer.insertTrailingNewline();
441451
}
442452
}
443453

@@ -459,6 +469,26 @@ private void getInputTemplate(
459469
writer.write("");
460470
}
461471

472+
private void setInputTemplate(
473+
GoWriter writer,
474+
SymbolProvider symbolProvider,
475+
StructureShape input,
476+
String funcName,
477+
String memberName
478+
) {
479+
writer.write(GoWriter.goTemplate("""
480+
func $fn:L(input interface{}, mode string) {
481+
in := input.(*$inputType:L)
482+
in.$member:L = types.$member:L(mode)
483+
}""",
484+
Map.of(
485+
"fn", funcName,
486+
"inputType", symbolProvider.toSymbol(input).getName(),
487+
"member", memberName
488+
)));
489+
writer.write("");
490+
}
491+
462492
private void generateInputComputedChecksumMetadataHelpers(
463493
GoWriter writer,
464494
Model model,

service/internal/checksum/middleware_add.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ type OutputMiddlewareOptions struct {
113113
// mode and true, or false if no mode is specified.
114114
GetValidationMode func(interface{}) (string, bool)
115115

116+
// SetValidationMode is a function to set the checksum validation mode of input parameters
117+
SetValidationMode func(interface{}, string)
118+
116119
// ResponseChecksumValidation is the user config to opt-in/out response checksum validation
117120
ResponseChecksumValidation aws.ResponseChecksumValidation
118121

@@ -141,6 +144,7 @@ type OutputMiddlewareOptions struct {
141144
func AddOutputMiddleware(stack *middleware.Stack, options OutputMiddlewareOptions) error {
142145
err := stack.Initialize.Add(&setupOutputContext{
143146
GetValidationMode: options.GetValidationMode,
147+
SetValidationMode: options.SetValidationMode,
144148
ResponseChecksumValidation: options.ResponseChecksumValidation,
145149
}, middleware.Before)
146150
if err != nil {

service/internal/checksum/middleware_setup_context.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ type setupOutputContext struct {
7070
// mode and true, or false if no mode is specified.
7171
GetValidationMode func(interface{}) (string, bool)
7272

73+
// SetValidationMode is a function to set the checksum validation mode of input parameters
74+
SetValidationMode func(interface{}, string)
75+
7376
// ResponseChecksumValidation states user config to opt-in/out checksum validation
7477
ResponseChecksumValidation aws.ResponseChecksumValidation
7578
}
@@ -90,6 +93,7 @@ func (m *setupOutputContext) HandleInitialize(
9093
mode, _ := m.GetValidationMode(in.Parameters)
9194

9295
if m.ResponseChecksumValidation == aws.ResponseChecksumValidationWhenSupported || mode == checksumValidationModeEnabled {
96+
m.SetValidationMode(in.Parameters, checksumValidationModeEnabled)
9397
ctx = setContextOutputValidationMode(ctx, checksumValidationModeEnabled)
9498
}
9599

service/internal/checksum/middleware_setup_context_test.go

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -131,50 +131,74 @@ func TestSetupOutput(t *testing.T) {
131131
inputParams interface{}
132132
ResponseChecksumValidation aws.ResponseChecksumValidation
133133
getValidationMode func(interface{}) (string, bool)
134-
expectValue string
134+
setValidationMode func(interface{}, string)
135+
expectCtxValue string
136+
expectInputValue string
135137
}{
136138
"user config support checksum found empty": {
137139
ResponseChecksumValidation: aws.ResponseChecksumValidationWhenSupported,
138-
inputParams: Params{Value: ""},
140+
inputParams: &Params{Value: ""},
139141
getValidationMode: func(v interface{}) (string, bool) {
140-
vv := v.(Params)
142+
vv := v.(*Params)
141143
return vv.Value, true
142144
},
143-
expectValue: "ENABLED",
145+
setValidationMode: func(v interface{}, m string) {
146+
vv := v.(*Params)
147+
vv.Value = m
148+
},
149+
expectCtxValue: "ENABLED",
150+
expectInputValue: "ENABLED",
144151
},
145152
"user config support checksum found invalid value": {
146153
ResponseChecksumValidation: aws.ResponseChecksumValidationWhenSupported,
147-
inputParams: Params{Value: "abc123"},
154+
inputParams: &Params{Value: "abc123"},
148155
getValidationMode: func(v interface{}) (string, bool) {
149-
vv := v.(Params)
156+
vv := v.(*Params)
150157
return vv.Value, true
158+
151159
},
152-
expectValue: "ENABLED",
160+
setValidationMode: func(v interface{}, m string) {
161+
vv := v.(*Params)
162+
vv.Value = m
163+
},
164+
expectCtxValue: "ENABLED",
165+
expectInputValue: "ENABLED",
153166
},
154167
"user config require checksum found invalid value": {
155168
ResponseChecksumValidation: aws.ResponseChecksumValidationWhenRequired,
156-
inputParams: Params{Value: "abc123"},
169+
inputParams: &Params{Value: "abc123"},
157170
getValidationMode: func(v interface{}) (string, bool) {
158-
vv := v.(Params)
171+
vv := v.(*Params)
159172
return vv.Value, true
160173
},
161-
expectValue: "",
174+
setValidationMode: func(v interface{}, m string) {
175+
vv := v.(*Params)
176+
vv.Value = m
177+
},
178+
expectCtxValue: "",
179+
expectInputValue: "abc123",
162180
},
163181
"user config require checksum found valid value": {
164182
ResponseChecksumValidation: aws.ResponseChecksumValidationWhenRequired,
165-
inputParams: Params{Value: "ENABLED"},
183+
inputParams: &Params{Value: "ENABLED"},
166184
getValidationMode: func(v interface{}) (string, bool) {
167-
vv := v.(Params)
185+
vv := v.(*Params)
168186
return vv.Value, true
169187
},
170-
expectValue: "ENABLED",
188+
setValidationMode: func(v interface{}, m string) {
189+
vv := v.(*Params)
190+
vv.Value = m
191+
},
192+
expectCtxValue: "ENABLED",
193+
expectInputValue: "ENABLED",
171194
},
172195
}
173196

174197
for name, c := range cases {
175198
t.Run(name, func(t *testing.T) {
176199
m := setupOutputContext{
177200
GetValidationMode: c.getValidationMode,
201+
SetValidationMode: c.setValidationMode,
178202
ResponseChecksumValidation: c.ResponseChecksumValidation,
179203
}
180204

@@ -185,10 +209,14 @@ func TestSetupOutput(t *testing.T) {
185209
out middleware.InitializeOutput, metadata middleware.Metadata, err error,
186210
) {
187211
v := getContextOutputValidationMode(ctx)
188-
if e, a := c.expectValue, v; e != a {
189-
t.Errorf("expect value %v, got %v", e, a)
212+
if e, a := c.expectCtxValue, v; e != a {
213+
t.Errorf("expect ctx checksum validation mode to be %v, got %v", e, a)
190214
}
191-
215+
in := input.Parameters.(*Params)
216+
if e, a := c.expectInputValue, in.Value; e != a {
217+
t.Errorf("expect input checksum validation mode to be %v, got %v", e, a)
218+
}
219+
192220
return out, metadata, nil
193221
},
194222
))

0 commit comments

Comments
 (0)