Skip to content

Commit 3c28d1f

Browse files
committed
Address comments
1 parent f993d8a commit 3c28d1f

File tree

5 files changed

+164
-42
lines changed

5 files changed

+164
-42
lines changed

Diff for: msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AcquireTokenByManagedIdentitySupplier.java

+1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ AuthenticationResult execute() throws Exception {
5252
SilentParameters parameters = SilentParameters
5353
.builder(scopes)
5454
.tenant(managedIdentityParameters.tenant())
55+
.claims(managedIdentityParameters.claims())
5556
.build();
5657

5758
RequestContext context = new RequestContext(

Diff for: msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityApplication.java

+7-6
Original file line numberDiff line numberDiff line change
@@ -103,15 +103,16 @@ public Builder resource(String resource) {
103103
}
104104

105105
/**
106-
* Sets client capabilities for the application.
106+
* Informs the token issuer that the application is able to perform complex authentication actions.
107+
* For example, "cp1" means that the application is able to perform conditional access evaluation,
108+
* because the application has been setup to parse WWW-Authenticate headers associated with a 401 response from the protected APIs,
109+
* and to retry the request with claims API.
107110
*
108-
* @param clientCapabilities List of client capabilities to be requested
109-
* @return instance of Builder of ManagedIdentityApplication
111+
* @param clientCapabilities a list of capabilities (e.g., ["cp1"]) recognized by the token service.
112+
* @return instance of Builder of ManagedIdentityApplication.
110113
*/
111114
public Builder clientCapabilities(List<String> clientCapabilities) {
112-
if (clientCapabilities != null) {
113-
this.clientCapabilities = clientCapabilities;
114-
}
115+
this.clientCapabilities = clientCapabilities;
115116
return self();
116117
}
117118

Diff for: msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityParameters.java

+22-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,18 @@ public Set<String> scopes() {
2929

3030
@Override
3131
public ClaimsRequest claims() {
32-
return (claims != null) ? ClaimsRequest.formatAsClaimsRequest(claims) : null;
32+
if (claims == null || claims.isEmpty()) {
33+
throw new MsalClientException("Claims cannot be null or empty",
34+
AuthenticationErrorCode.INVALID_JSON);
35+
}
36+
37+
try {
38+
return ClaimsRequest.formatAsClaimsRequest(claims);
39+
} catch (Exception ex) {
40+
// Log the exception if the claims JSON is invalid
41+
throw new MsalClientException("Failed to parse claims JSON: " + ex.getMessage(),
42+
AuthenticationErrorCode.INVALID_JSON);
43+
}
3344
}
3445

3546
@Override
@@ -86,6 +97,16 @@ public ManagedIdentityParametersBuilder forceRefresh(boolean forceRefresh) {
8697
return this;
8798
}
8899

100+
/**
101+
* Instructs the SDK to bypass any token caches and to request new tokens with an additional claims challenge.
102+
* The claims challenge string is opaque to applications and should not be parsed.
103+
* The claims challenge string is issued either by the STS as part of an error response or by the resource,
104+
* as part of an HTTP 401 response, in the WWW-Authenticate header.
105+
* For more details see https://learn.microsoft.com/entra/identity-platform/app-resilience-continuous-access-evaluation?tabs=dotnet
106+
*
107+
* @param claims a valid JSON string representing additional claims
108+
* @return this builder instance
109+
*/
89110
public ManagedIdentityParametersBuilder claims(String claims) {
90111
this.claims = claims;
91112
return this;

Diff for: msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestDataProvider.java

+8
Original file line numberDiff line numberDiff line change
@@ -114,4 +114,12 @@ public static Stream<Arguments> createDataGetSource() {
114114
Arguments.of(ManagedIdentitySourceType.IMDS, "", ManagedIdentitySourceType.DEFAULT_TO_IMDS),
115115
Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint, ManagedIdentitySourceType.SERVICE_FABRIC));
116116
}
117+
118+
public static Stream<Arguments> createInvalidClaimsData() {
119+
return Stream.of(
120+
Arguments.of("invalid json format"),
121+
Arguments.of("{\"access_token\": }"),
122+
Arguments.of("")
123+
);
124+
}
117125
}

Diff for: msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java

+126-35
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,21 @@
44
package com.microsoft.aad.msal4j;
55

66
import com.nimbusds.oauth2.sdk.util.URLUtils;
7+
import labapi.App;
78
import org.junit.jupiter.api.Nested;
89
import org.junit.jupiter.api.Test;
910
import org.junit.jupiter.api.TestInstance;
1011
import org.junit.jupiter.api.extension.ExtendWith;
1112
import org.junit.jupiter.params.ParameterizedTest;
1213
import org.junit.jupiter.params.provider.MethodSource;
1314
import org.junit.jupiter.params.provider.ValueSource;
15+
import org.mockito.ArgumentCaptor;
1416
import org.mockito.junit.jupiter.MockitoExtension;
1517

1618
import java.net.SocketException;
1719
import java.nio.file.Path;
1820
import java.nio.file.Paths;
21+
import java.util.Collections;
1922
import java.util.HashMap;
2023
import java.util.List;
2124
import java.util.Map;
@@ -78,56 +81,51 @@ private HttpRequest expectedRequest(ManagedIdentitySourceType source, String res
7881
Map<String, List<String>> queryParameters = new HashMap<>();
7982

8083
switch (source) {
81-
case APP_SERVICE: {
84+
case APP_SERVICE:
8285
endpoint = appServiceEndpoint;
83-
84-
queryParameters.put("api-version", singletonList("2019-08-01"));
85-
queryParameters.put("resource", singletonList(resource));
86-
86+
queryParameters.put("api-version", Collections.singletonList("2019-08-01"));
87+
queryParameters.put("resource", Collections.singletonList(resource));
8788
headers.put("X-IDENTITY-HEADER", "secret");
8889
break;
89-
}
90-
case CLOUD_SHELL: {
90+
case CLOUD_SHELL:
9191
endpoint = cloudShellEndpoint;
92-
9392
headers.put("ContentType", "application/x-www-form-urlencoded");
9493
headers.put("Metadata", "true");
95-
96-
queryParameters.put("resource", singletonList(resource));
94+
queryParameters.put("resource", Collections.singletonList(resource));
9795
break;
98-
}
99-
case IMDS: {
96+
case IMDS:
10097
endpoint = IMDS_ENDPOINT;
101-
queryParameters.put("api-version", singletonList("2018-02-01"));
102-
queryParameters.put("resource", singletonList(resource));
98+
queryParameters.put("api-version", Collections.singletonList("2018-02-01"));
99+
queryParameters.put("resource", Collections.singletonList(resource));
103100
headers.put("Metadata", "true");
104101
break;
105-
}
106-
case AZURE_ARC: {
102+
case AZURE_ARC:
107103
endpoint = azureArcEndpoint;
108-
109-
queryParameters.put("api-version", singletonList("2019-11-01"));
110-
queryParameters.put("resource", singletonList(resource));
111-
104+
queryParameters.put("api-version", Collections.singletonList("2019-11-01"));
105+
queryParameters.put("resource", Collections.singletonList(resource));
112106
headers.put("Metadata", "true");
113107
break;
114-
}
115-
case SERVICE_FABRIC: {
108+
case SERVICE_FABRIC:
116109
endpoint = serviceFabricEndpoint;
117-
queryParameters.put("api-version", singletonList("2019-07-01-preview"));
118-
queryParameters.put("resource", singletonList(resource));
119-
110+
queryParameters.put("api-version", Collections.singletonList("2019-07-01-preview"));
111+
queryParameters.put("resource", Collections.singletonList(resource));
120112
headers.put("secret", "secret");
121113
break;
122-
}
114+
case NONE:
115+
case DEFAULT_TO_IMDS:
116+
endpoint = IMDS_ENDPOINT;
117+
queryParameters.put("api-version", Collections.singletonList("2018-02-01"));
118+
queryParameters.put("resource", Collections.singletonList(resource));
119+
headers.put("Metadata", "true");
120+
break;
123121
}
124122

125123
switch (id.getIdType()) {
126124
case CLIENT_ID:
127-
queryParameters.put("client_id", singletonList(id.getUserAssignedId()));
125+
queryParameters.put("client_id", Collections.singletonList(id.getUserAssignedId()));
128126
break;
129127
case RESOURCE_ID:
130-
queryParameters.put("mi_res_id", singletonList(id.getUserAssignedId()));
128+
queryParameters.put("mi_res_id", Collections.singletonList(id.getUserAssignedId()));
131129
break;
132130
case OBJECT_ID:
133131
queryParameters.put("object_id", singletonList(id.getUserAssignedId()));
@@ -314,9 +312,10 @@ void managedIdentityTest_UserAssigned_NotSupported(ManagedIdentitySourceType sou
314312
miApp.tokenCache().accessTokens.clear();
315313

316314
try {
317-
IAuthenticationResult result = miApp.acquireTokenForManagedIdentity(
315+
miApp.acquireTokenForManagedIdentity(
318316
ManagedIdentityParameters.builder(resource)
319317
.build()).get();
318+
fail("MsalServiceException is expected but not thrown.");
320319
} catch (Exception e) {
321320
assertNotNull(e);
322321
assertNotNull(e.getCause());
@@ -325,10 +324,7 @@ void managedIdentityTest_UserAssigned_NotSupported(ManagedIdentitySourceType sou
325324
MsalServiceException msalMsiException = (MsalServiceException) e.getCause();
326325
assertEquals(source.name(), msalMsiException.managedIdentitySource());
327326
assertEquals(MsalError.USER_ASSIGNED_MANAGED_IDENTITY_NOT_SUPPORTED, msalMsiException.errorCode());
328-
return;
329327
}
330-
331-
fail("MsalServiceException is expected but not thrown.");
332328
}
333329

334330
@ParameterizedTest
@@ -637,21 +633,116 @@ void managedIdentityTest_WithClaims(ManagedIdentitySourceType source, String end
637633

638634
miApp = ManagedIdentityApplication
639635
.builder(ManagedIdentityId.systemAssigned())
640-
.clientCapabilities(singletonList("cp1"))
641636
.httpClient(httpClientMock)
642637
.build();
643638

644639
// Clear caching to avoid cross test pollution.
645640
miApp.tokenCache().accessTokens.clear();
646641

647642
String claimsJson = "{\"default\":\"claim\"}";
643+
644+
// First call, get the token from the identity provider.
648645
IAuthenticationResult result = miApp.acquireTokenForManagedIdentity(
649646
ManagedIdentityParameters.builder(resource)
650-
.claims(claimsJson)
651647
.build()).get();
652648

653649
assertNotNull(result.accessToken());
654-
verify(httpClientMock, times(1)).send(any());
650+
assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource());
651+
652+
// Second call, get the token from the cache without passing the claims.
653+
result = miApp.acquireTokenForManagedIdentity(
654+
ManagedIdentityParameters.builder(resource)
655+
.build()).get();
656+
657+
assertNotNull(result.accessToken());
658+
assertEquals(TokenSource.CACHE, result.metadata().tokenSource());
659+
660+
// Third call, when claims are passed bypass the cache.
661+
result = miApp.acquireTokenForManagedIdentity(
662+
ManagedIdentityParameters.builder(resource)
663+
.claims(claimsJson)
664+
.build()).get();
665+
666+
assertNotNull(result.accessToken());
667+
assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource());
668+
669+
verify(httpClientMock, times(2)).send(any());
670+
}
671+
672+
@ParameterizedTest
673+
@MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError")
674+
void managedIdentity_ClaimsAndCapabilities(ManagedIdentitySourceType source, String endpoint) throws Exception {
675+
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
676+
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
677+
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);
678+
if (source == SERVICE_FABRIC) {
679+
ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock);
680+
}
681+
682+
when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource)));
683+
684+
miApp = ManagedIdentityApplication
685+
.builder(ManagedIdentityId.systemAssigned())
686+
.clientCapabilities(singletonList("cp1"))
687+
.httpClient(httpClientMock)
688+
.build();
689+
690+
// Clear caching to avoid cross test pollution.
691+
miApp.tokenCache().accessTokens.clear();
692+
693+
String claimsJson = "{\"default\":\"claim\"}";
694+
// First call, get the token from the identity provider.
695+
IAuthenticationResult result = miApp.acquireTokenForManagedIdentity(
696+
ManagedIdentityParameters.builder(resource)
697+
.build()).get();
698+
699+
assertNotNull(result.accessToken());
700+
assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource());
701+
702+
// Second call, get the token from the cache without passing the claims.
703+
result = miApp.acquireTokenForManagedIdentity(
704+
ManagedIdentityParameters.builder(resource)
705+
.build()).get();
706+
707+
assertNotNull(result.accessToken());
708+
assertEquals(TokenSource.CACHE, result.metadata().tokenSource());
709+
710+
// Third call, when claims are passed bypass the cache.
711+
result = miApp.acquireTokenForManagedIdentity(
712+
ManagedIdentityParameters.builder(resource)
713+
.claims(claimsJson)
714+
.build()).get();
715+
716+
assertNotNull(result.accessToken());
717+
assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource());
718+
719+
verify(httpClientMock, times(2)).send(any());
720+
}
721+
722+
@ParameterizedTest
723+
@MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createInvalidClaimsData")
724+
void managedIdentity_InvalidClaims(String claimsJson) throws Exception {
725+
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(APP_SERVICE, appServiceEndpoint);
726+
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
727+
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);
728+
729+
miApp = ManagedIdentityApplication
730+
.builder(ManagedIdentityId.systemAssigned())
731+
.httpClient(httpClientMock)
732+
.build();
733+
734+
CompletableFuture<IAuthenticationResult> future = miApp.acquireTokenForManagedIdentity(
735+
ManagedIdentityParameters.builder(resource)
736+
.claims(claimsJson)
737+
.build());
738+
739+
ExecutionException ex = assertThrows(ExecutionException.class, future::get);
740+
assertInstanceOf(MsalClientException.class, ex.getCause());
741+
MsalClientException msalException = (MsalClientException) ex.getCause();
742+
assertEquals(AuthenticationErrorCode.INVALID_JSON, msalException.errorCode());
743+
744+
// Verify no HTTP requests were made for invalid claims
745+
verify(httpClientMock, never()).send(any());
655746
}
656747

657748
@Nested

0 commit comments

Comments
 (0)