Skip to content

Commit 1a384e6

Browse files
authored
chore: add protocol priority tests (#932)
1 parent 36be46c commit 1a384e6

File tree

1 file changed

+228
-0
lines changed
  • smithy-swift-codegen/src/test/kotlin/software/amazon/smithy/swift/codegen/codegencomponents

1 file changed

+228
-0
lines changed

smithy-swift-codegen/src/test/kotlin/software/amazon/smithy/swift/codegen/codegencomponents/SwiftSettingsTest.kt

Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,22 @@ package software.amazon.smithy.swift.codegen.codegencomponents
77

88
import org.junit.jupiter.api.Assertions.assertEquals
99
import org.junit.jupiter.api.Test
10+
import org.junit.jupiter.api.assertThrows
11+
import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
12+
import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
13+
import software.amazon.smithy.aws.traits.protocols.AwsQueryTrait
14+
import software.amazon.smithy.aws.traits.protocols.Ec2QueryTrait
15+
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
16+
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
17+
import software.amazon.smithy.model.Model
18+
import software.amazon.smithy.model.knowledge.ServiceIndex
19+
import software.amazon.smithy.model.shapes.ServiceShape
1020
import software.amazon.smithy.model.shapes.ShapeId
21+
import software.amazon.smithy.model.shapes.StringShape
22+
import software.amazon.smithy.model.traits.ProtocolDefinitionTrait
23+
import software.amazon.smithy.protocol.traits.Rpcv2CborTrait
24+
import software.amazon.smithy.swift.codegen.SwiftSettings
25+
import software.amazon.smithy.swift.codegen.UnresolvableProtocolException
1126
import software.amazon.smithy.swift.codegen.asSmithy
1227
import software.amazon.smithy.swift.codegen.defaultSettings
1328

@@ -25,4 +40,217 @@ class SwiftSettingsTest {
2540
assertEquals("https://github.com/aws-amplify/amplify-codegen.git", settings.gitRepo)
2641
assertEquals(false, settings.mergeModels)
2742
}
43+
44+
// Smithy Protocol Selection Tests
45+
46+
// Row 1: SDK supports all protocols
47+
private val allProtocolsSupported =
48+
setOf(
49+
Rpcv2CborTrait.ID,
50+
AwsJson1_0Trait.ID,
51+
AwsJson1_1Trait.ID,
52+
RestJson1Trait.ID,
53+
RestXmlTrait.ID,
54+
AwsQueryTrait.ID,
55+
Ec2QueryTrait.ID,
56+
)
57+
58+
@Test
59+
fun `when SDK supports all protocols and service has rpcv2Cbor and awsJson1_0 then resolves rpcv2Cbor`() {
60+
val settings = createTestSettings()
61+
val service = createServiceWithProtocols(setOf(Rpcv2CborTrait.ID, AwsJson1_0Trait.ID))
62+
val serviceIndex = createServiceIndex(service)
63+
64+
val resolvedProtocol = settings.resolveServiceProtocol(serviceIndex, service, allProtocolsSupported)
65+
66+
assertEquals(Rpcv2CborTrait.ID, resolvedProtocol)
67+
}
68+
69+
@Test
70+
fun `when SDK supports all protocols and service has only rpcv2Cbor then resolves rpcv2Cbor`() {
71+
val settings = createTestSettings()
72+
val service = createServiceWithProtocols(setOf(Rpcv2CborTrait.ID))
73+
val serviceIndex = createServiceIndex(service)
74+
75+
val resolvedProtocol = settings.resolveServiceProtocol(serviceIndex, service, allProtocolsSupported)
76+
77+
assertEquals(Rpcv2CborTrait.ID, resolvedProtocol)
78+
}
79+
80+
@Test
81+
fun `when SDK supports all protocols and service has rpcv2Cbor awsJson1_0 and awsQuery then resolves rpcv2Cbor`() {
82+
val settings = createTestSettings()
83+
val service = createServiceWithProtocols(setOf(Rpcv2CborTrait.ID, AwsJson1_0Trait.ID, AwsQueryTrait.ID))
84+
val serviceIndex = createServiceIndex(service)
85+
86+
val resolvedProtocol = settings.resolveServiceProtocol(serviceIndex, service, allProtocolsSupported)
87+
88+
assertEquals(Rpcv2CborTrait.ID, resolvedProtocol)
89+
}
90+
91+
@Test
92+
fun `when SDK supports all protocols and service has awsJson1_0 and awsQuery then resolves awsJson1_0`() {
93+
val settings = createTestSettings()
94+
val service = createServiceWithProtocols(setOf(AwsJson1_0Trait.ID, AwsQueryTrait.ID))
95+
val serviceIndex = createServiceIndex(service)
96+
97+
val resolvedProtocol = settings.resolveServiceProtocol(serviceIndex, service, allProtocolsSupported)
98+
99+
assertEquals(AwsJson1_0Trait.ID, resolvedProtocol)
100+
}
101+
102+
@Test
103+
fun `when SDK supports all protocols and service has only awsQuery then resolves awsQuery`() {
104+
val settings = createTestSettings()
105+
val service = createServiceWithProtocols(setOf(AwsQueryTrait.ID))
106+
val serviceIndex = createServiceIndex(service)
107+
108+
val resolvedProtocol = settings.resolveServiceProtocol(serviceIndex, service, allProtocolsSupported)
109+
110+
assertEquals(AwsQueryTrait.ID, resolvedProtocol)
111+
}
112+
113+
// Row 2: SDK does not support rpcv2Cbor
114+
private val withoutRpcv2CborSupport =
115+
setOf(
116+
AwsJson1_0Trait.ID,
117+
AwsJson1_1Trait.ID,
118+
RestJson1Trait.ID,
119+
RestXmlTrait.ID,
120+
AwsQueryTrait.ID,
121+
Ec2QueryTrait.ID,
122+
)
123+
124+
@Test
125+
fun `when SDK does not support rpcv2Cbor and service has rpcv2Cbor and awsJson1_0 then resolves awsJson1_0`() {
126+
val settings = createTestSettings()
127+
val service = createServiceWithProtocols(setOf(Rpcv2CborTrait.ID, AwsJson1_0Trait.ID))
128+
val serviceIndex = createServiceIndex(service)
129+
130+
val resolvedProtocol = settings.resolveServiceProtocol(serviceIndex, service, withoutRpcv2CborSupport)
131+
132+
assertEquals(AwsJson1_0Trait.ID, resolvedProtocol)
133+
}
134+
135+
@Test
136+
fun `when SDK does not support rpcv2Cbor and service has only rpcv2Cbor then throws exception`() {
137+
val settings = createTestSettings()
138+
val service = createServiceWithProtocols(setOf(Rpcv2CborTrait.ID))
139+
val serviceIndex = createServiceIndex(service)
140+
141+
assertThrows<UnresolvableProtocolException> {
142+
settings.resolveServiceProtocol(serviceIndex, service, withoutRpcv2CborSupport)
143+
}
144+
}
145+
146+
@Test
147+
fun `when SDK does not support rpcv2Cbor and service has rpcv2Cbor awsJson1_0 and awsQuery then resolves awsJson1_0`() {
148+
val settings = createTestSettings()
149+
val service = createServiceWithProtocols(setOf(Rpcv2CborTrait.ID, AwsJson1_0Trait.ID, AwsQueryTrait.ID))
150+
val serviceIndex = createServiceIndex(service)
151+
152+
val resolvedProtocol = settings.resolveServiceProtocol(serviceIndex, service, withoutRpcv2CborSupport)
153+
154+
assertEquals(AwsJson1_0Trait.ID, resolvedProtocol)
155+
}
156+
157+
@Test
158+
fun `when SDK does not support rpcv2Cbor and service has awsJson1_0 and awsQuery then resolves awsJson1_0`() {
159+
val settings = createTestSettings()
160+
val service = createServiceWithProtocols(setOf(AwsJson1_0Trait.ID, AwsQueryTrait.ID))
161+
val serviceIndex = createServiceIndex(service)
162+
163+
val resolvedProtocol = settings.resolveServiceProtocol(serviceIndex, service, withoutRpcv2CborSupport)
164+
165+
assertEquals(AwsJson1_0Trait.ID, resolvedProtocol)
166+
}
167+
168+
@Test
169+
fun `when SDK does not support rpcv2Cbor and service has only awsQuery then resolves awsQuery`() {
170+
val settings = createTestSettings()
171+
val service = createServiceWithProtocols(setOf(AwsQueryTrait.ID))
172+
val serviceIndex = createServiceIndex(service)
173+
174+
val resolvedProtocol = settings.resolveServiceProtocol(serviceIndex, service, withoutRpcv2CborSupport)
175+
176+
assertEquals(AwsQueryTrait.ID, resolvedProtocol)
177+
}
178+
179+
// Helper functions
180+
181+
private fun createTestSettings(): SwiftSettings =
182+
SwiftSettings(
183+
service = ShapeId.from("test#TestService"),
184+
moduleName = "TestModule",
185+
moduleVersion = "1.0.0",
186+
moduleDescription = "Test module",
187+
author = "Test Author",
188+
homepage = "https://test.com",
189+
sdkId = "Test",
190+
gitRepo = "https://github.com/test/test.git",
191+
swiftVersion = "5.7",
192+
mergeModels = false,
193+
copyrightNotice = "// Test copyright",
194+
)
195+
196+
private fun createServiceWithProtocols(protocols: Set<ShapeId>): ServiceShape {
197+
var builder =
198+
ServiceShape
199+
.builder()
200+
.id("test#TestService")
201+
.version("1.0")
202+
203+
// Apply the actual protocol traits to the service
204+
for (protocolId in protocols) {
205+
when (protocolId) {
206+
Rpcv2CborTrait.ID -> builder = builder.addTrait(Rpcv2CborTrait.builder().build())
207+
AwsJson1_0Trait.ID -> builder = builder.addTrait(AwsJson1_0Trait.builder().build())
208+
AwsJson1_1Trait.ID -> builder = builder.addTrait(AwsJson1_1Trait.builder().build())
209+
RestJson1Trait.ID -> builder = builder.addTrait(RestJson1Trait.builder().build())
210+
RestXmlTrait.ID -> builder = builder.addTrait(RestXmlTrait.builder().build())
211+
AwsQueryTrait.ID -> builder = builder.addTrait(AwsQueryTrait())
212+
Ec2QueryTrait.ID -> builder = builder.addTrait(Ec2QueryTrait())
213+
}
214+
}
215+
216+
return builder.build()
217+
}
218+
219+
private fun createServiceIndex(service: ServiceShape): ServiceIndex {
220+
val modelBuilder = Model.builder()
221+
222+
// Add the service shape
223+
modelBuilder.addShape(service)
224+
225+
// Add protocol definition shapes to the model
226+
// These are needed for ServiceIndex to recognize the protocols
227+
val protocolShapes =
228+
listOf(
229+
Rpcv2CborTrait.ID,
230+
AwsJson1_0Trait.ID,
231+
AwsJson1_1Trait.ID,
232+
RestJson1Trait.ID,
233+
RestXmlTrait.ID,
234+
AwsQueryTrait.ID,
235+
Ec2QueryTrait.ID,
236+
)
237+
238+
for (protocolId in protocolShapes) {
239+
// Create a shape that represents the protocol definition
240+
// and add the ProtocolDefinitionTrait to it
241+
val protocolShape =
242+
StringShape
243+
.builder()
244+
.id(protocolId)
245+
.addTrait(
246+
ProtocolDefinitionTrait
247+
.builder()
248+
.build(),
249+
).build()
250+
modelBuilder.addShape(protocolShape)
251+
}
252+
253+
val model = modelBuilder.build()
254+
return ServiceIndex.of(model)
255+
}
28256
}

0 commit comments

Comments
 (0)