@@ -7,7 +7,22 @@ package software.amazon.smithy.swift.codegen.codegencomponents
7
7
8
8
import org.junit.jupiter.api.Assertions.assertEquals
9
9
import org.junit.jupiter.api.Test
10
+ import org.junit.jupiter.api.assertThrows
11
+ import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
12
+ import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
13
+ import software.amazon.smithy.aws.traits.protocols.AwsQueryTrait
14
+ import software.amazon.smithy.aws.traits.protocols.Ec2QueryTrait
15
+ import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
16
+ import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
17
+ import software.amazon.smithy.model.Model
18
+ import software.amazon.smithy.model.knowledge.ServiceIndex
19
+ import software.amazon.smithy.model.shapes.ServiceShape
10
20
import software.amazon.smithy.model.shapes.ShapeId
21
+ import software.amazon.smithy.model.shapes.StringShape
22
+ import software.amazon.smithy.model.traits.ProtocolDefinitionTrait
23
+ import software.amazon.smithy.protocol.traits.Rpcv2CborTrait
24
+ import software.amazon.smithy.swift.codegen.SwiftSettings
25
+ import software.amazon.smithy.swift.codegen.UnresolvableProtocolException
11
26
import software.amazon.smithy.swift.codegen.asSmithy
12
27
import software.amazon.smithy.swift.codegen.defaultSettings
13
28
@@ -25,4 +40,217 @@ class SwiftSettingsTest {
25
40
assertEquals(" https://github.com/aws-amplify/amplify-codegen.git" , settings.gitRepo)
26
41
assertEquals(false , settings.mergeModels)
27
42
}
43
+
44
+ // Smithy Protocol Selection Tests
45
+
46
+ // Row 1: SDK supports all protocols
47
+ private val allProtocolsSupported =
48
+ setOf (
49
+ Rpcv2CborTrait .ID ,
50
+ AwsJson1_0Trait .ID ,
51
+ AwsJson1_1Trait .ID ,
52
+ RestJson1Trait .ID ,
53
+ RestXmlTrait .ID ,
54
+ AwsQueryTrait .ID ,
55
+ Ec2QueryTrait .ID ,
56
+ )
57
+
58
+ @Test
59
+ fun `when SDK supports all protocols and service has rpcv2Cbor and awsJson1_0 then resolves rpcv2Cbor` () {
60
+ val settings = createTestSettings()
61
+ val service = createServiceWithProtocols(setOf (Rpcv2CborTrait .ID , AwsJson1_0Trait .ID ))
62
+ val serviceIndex = createServiceIndex(service)
63
+
64
+ val resolvedProtocol = settings.resolveServiceProtocol(serviceIndex, service, allProtocolsSupported)
65
+
66
+ assertEquals(Rpcv2CborTrait .ID , resolvedProtocol)
67
+ }
68
+
69
+ @Test
70
+ fun `when SDK supports all protocols and service has only rpcv2Cbor then resolves rpcv2Cbor` () {
71
+ val settings = createTestSettings()
72
+ val service = createServiceWithProtocols(setOf (Rpcv2CborTrait .ID ))
73
+ val serviceIndex = createServiceIndex(service)
74
+
75
+ val resolvedProtocol = settings.resolveServiceProtocol(serviceIndex, service, allProtocolsSupported)
76
+
77
+ assertEquals(Rpcv2CborTrait .ID , resolvedProtocol)
78
+ }
79
+
80
+ @Test
81
+ fun `when SDK supports all protocols and service has rpcv2Cbor awsJson1_0 and awsQuery then resolves rpcv2Cbor` () {
82
+ val settings = createTestSettings()
83
+ val service = createServiceWithProtocols(setOf (Rpcv2CborTrait .ID , AwsJson1_0Trait .ID , AwsQueryTrait .ID ))
84
+ val serviceIndex = createServiceIndex(service)
85
+
86
+ val resolvedProtocol = settings.resolveServiceProtocol(serviceIndex, service, allProtocolsSupported)
87
+
88
+ assertEquals(Rpcv2CborTrait .ID , resolvedProtocol)
89
+ }
90
+
91
+ @Test
92
+ fun `when SDK supports all protocols and service has awsJson1_0 and awsQuery then resolves awsJson1_0` () {
93
+ val settings = createTestSettings()
94
+ val service = createServiceWithProtocols(setOf (AwsJson1_0Trait .ID , AwsQueryTrait .ID ))
95
+ val serviceIndex = createServiceIndex(service)
96
+
97
+ val resolvedProtocol = settings.resolveServiceProtocol(serviceIndex, service, allProtocolsSupported)
98
+
99
+ assertEquals(AwsJson1_0Trait .ID , resolvedProtocol)
100
+ }
101
+
102
+ @Test
103
+ fun `when SDK supports all protocols and service has only awsQuery then resolves awsQuery` () {
104
+ val settings = createTestSettings()
105
+ val service = createServiceWithProtocols(setOf (AwsQueryTrait .ID ))
106
+ val serviceIndex = createServiceIndex(service)
107
+
108
+ val resolvedProtocol = settings.resolveServiceProtocol(serviceIndex, service, allProtocolsSupported)
109
+
110
+ assertEquals(AwsQueryTrait .ID , resolvedProtocol)
111
+ }
112
+
113
+ // Row 2: SDK does not support rpcv2Cbor
114
+ private val withoutRpcv2CborSupport =
115
+ setOf (
116
+ AwsJson1_0Trait .ID ,
117
+ AwsJson1_1Trait .ID ,
118
+ RestJson1Trait .ID ,
119
+ RestXmlTrait .ID ,
120
+ AwsQueryTrait .ID ,
121
+ Ec2QueryTrait .ID ,
122
+ )
123
+
124
+ @Test
125
+ fun `when SDK does not support rpcv2Cbor and service has rpcv2Cbor and awsJson1_0 then resolves awsJson1_0` () {
126
+ val settings = createTestSettings()
127
+ val service = createServiceWithProtocols(setOf (Rpcv2CborTrait .ID , AwsJson1_0Trait .ID ))
128
+ val serviceIndex = createServiceIndex(service)
129
+
130
+ val resolvedProtocol = settings.resolveServiceProtocol(serviceIndex, service, withoutRpcv2CborSupport)
131
+
132
+ assertEquals(AwsJson1_0Trait .ID , resolvedProtocol)
133
+ }
134
+
135
+ @Test
136
+ fun `when SDK does not support rpcv2Cbor and service has only rpcv2Cbor then throws exception` () {
137
+ val settings = createTestSettings()
138
+ val service = createServiceWithProtocols(setOf (Rpcv2CborTrait .ID ))
139
+ val serviceIndex = createServiceIndex(service)
140
+
141
+ assertThrows<UnresolvableProtocolException > {
142
+ settings.resolveServiceProtocol(serviceIndex, service, withoutRpcv2CborSupport)
143
+ }
144
+ }
145
+
146
+ @Test
147
+ fun `when SDK does not support rpcv2Cbor and service has rpcv2Cbor awsJson1_0 and awsQuery then resolves awsJson1_0` () {
148
+ val settings = createTestSettings()
149
+ val service = createServiceWithProtocols(setOf (Rpcv2CborTrait .ID , AwsJson1_0Trait .ID , AwsQueryTrait .ID ))
150
+ val serviceIndex = createServiceIndex(service)
151
+
152
+ val resolvedProtocol = settings.resolveServiceProtocol(serviceIndex, service, withoutRpcv2CborSupport)
153
+
154
+ assertEquals(AwsJson1_0Trait .ID , resolvedProtocol)
155
+ }
156
+
157
+ @Test
158
+ fun `when SDK does not support rpcv2Cbor and service has awsJson1_0 and awsQuery then resolves awsJson1_0` () {
159
+ val settings = createTestSettings()
160
+ val service = createServiceWithProtocols(setOf (AwsJson1_0Trait .ID , AwsQueryTrait .ID ))
161
+ val serviceIndex = createServiceIndex(service)
162
+
163
+ val resolvedProtocol = settings.resolveServiceProtocol(serviceIndex, service, withoutRpcv2CborSupport)
164
+
165
+ assertEquals(AwsJson1_0Trait .ID , resolvedProtocol)
166
+ }
167
+
168
+ @Test
169
+ fun `when SDK does not support rpcv2Cbor and service has only awsQuery then resolves awsQuery` () {
170
+ val settings = createTestSettings()
171
+ val service = createServiceWithProtocols(setOf (AwsQueryTrait .ID ))
172
+ val serviceIndex = createServiceIndex(service)
173
+
174
+ val resolvedProtocol = settings.resolveServiceProtocol(serviceIndex, service, withoutRpcv2CborSupport)
175
+
176
+ assertEquals(AwsQueryTrait .ID , resolvedProtocol)
177
+ }
178
+
179
+ // Helper functions
180
+
181
+ private fun createTestSettings (): SwiftSettings =
182
+ SwiftSettings (
183
+ service = ShapeId .from(" test#TestService" ),
184
+ moduleName = " TestModule" ,
185
+ moduleVersion = " 1.0.0" ,
186
+ moduleDescription = " Test module" ,
187
+ author = " Test Author" ,
188
+ homepage = " https://test.com" ,
189
+ sdkId = " Test" ,
190
+ gitRepo = " https://github.com/test/test.git" ,
191
+ swiftVersion = " 5.7" ,
192
+ mergeModels = false ,
193
+ copyrightNotice = " // Test copyright" ,
194
+ )
195
+
196
+ private fun createServiceWithProtocols (protocols : Set <ShapeId >): ServiceShape {
197
+ var builder =
198
+ ServiceShape
199
+ .builder()
200
+ .id(" test#TestService" )
201
+ .version(" 1.0" )
202
+
203
+ // Apply the actual protocol traits to the service
204
+ for (protocolId in protocols) {
205
+ when (protocolId) {
206
+ Rpcv2CborTrait .ID -> builder = builder.addTrait(Rpcv2CborTrait .builder().build())
207
+ AwsJson1_0Trait .ID -> builder = builder.addTrait(AwsJson1_0Trait .builder().build())
208
+ AwsJson1_1Trait .ID -> builder = builder.addTrait(AwsJson1_1Trait .builder().build())
209
+ RestJson1Trait .ID -> builder = builder.addTrait(RestJson1Trait .builder().build())
210
+ RestXmlTrait .ID -> builder = builder.addTrait(RestXmlTrait .builder().build())
211
+ AwsQueryTrait .ID -> builder = builder.addTrait(AwsQueryTrait ())
212
+ Ec2QueryTrait .ID -> builder = builder.addTrait(Ec2QueryTrait ())
213
+ }
214
+ }
215
+
216
+ return builder.build()
217
+ }
218
+
219
+ private fun createServiceIndex (service : ServiceShape ): ServiceIndex {
220
+ val modelBuilder = Model .builder()
221
+
222
+ // Add the service shape
223
+ modelBuilder.addShape(service)
224
+
225
+ // Add protocol definition shapes to the model
226
+ // These are needed for ServiceIndex to recognize the protocols
227
+ val protocolShapes =
228
+ listOf (
229
+ Rpcv2CborTrait .ID ,
230
+ AwsJson1_0Trait .ID ,
231
+ AwsJson1_1Trait .ID ,
232
+ RestJson1Trait .ID ,
233
+ RestXmlTrait .ID ,
234
+ AwsQueryTrait .ID ,
235
+ Ec2QueryTrait .ID ,
236
+ )
237
+
238
+ for (protocolId in protocolShapes) {
239
+ // Create a shape that represents the protocol definition
240
+ // and add the ProtocolDefinitionTrait to it
241
+ val protocolShape =
242
+ StringShape
243
+ .builder()
244
+ .id(protocolId)
245
+ .addTrait(
246
+ ProtocolDefinitionTrait
247
+ .builder()
248
+ .build(),
249
+ ).build()
250
+ modelBuilder.addShape(protocolShape)
251
+ }
252
+
253
+ val model = modelBuilder.build()
254
+ return ServiceIndex .of(model)
255
+ }
28
256
}
0 commit comments