4
4
package com .microsoft .aad .msal4j ;
5
5
6
6
import com .nimbusds .oauth2 .sdk .util .URLUtils ;
7
+ import labapi .App ;
7
8
import org .junit .jupiter .api .Nested ;
8
9
import org .junit .jupiter .api .Test ;
9
10
import org .junit .jupiter .api .TestInstance ;
10
11
import org .junit .jupiter .api .extension .ExtendWith ;
11
12
import org .junit .jupiter .params .ParameterizedTest ;
12
13
import org .junit .jupiter .params .provider .MethodSource ;
13
14
import org .junit .jupiter .params .provider .ValueSource ;
15
+ import org .mockito .ArgumentCaptor ;
14
16
import org .mockito .junit .jupiter .MockitoExtension ;
15
17
16
18
import java .net .SocketException ;
17
19
import java .nio .file .Path ;
18
20
import java .nio .file .Paths ;
21
+ import java .util .Collections ;
19
22
import java .util .HashMap ;
20
23
import java .util .List ;
21
24
import java .util .Map ;
@@ -78,56 +81,51 @@ private HttpRequest expectedRequest(ManagedIdentitySourceType source, String res
78
81
Map <String , List <String >> queryParameters = new HashMap <>();
79
82
80
83
switch (source ) {
81
- case APP_SERVICE : {
84
+ case APP_SERVICE :
82
85
endpoint = appServiceEndpoint ;
83
-
84
- queryParameters .put ("api-version" , singletonList ("2019-08-01" ));
85
- queryParameters .put ("resource" , singletonList (resource ));
86
-
86
+ queryParameters .put ("api-version" , Collections .singletonList ("2019-08-01" ));
87
+ queryParameters .put ("resource" , Collections .singletonList (resource ));
87
88
headers .put ("X-IDENTITY-HEADER" , "secret" );
88
89
break ;
89
- }
90
- case CLOUD_SHELL : {
90
+ case CLOUD_SHELL :
91
91
endpoint = cloudShellEndpoint ;
92
-
93
92
headers .put ("ContentType" , "application/x-www-form-urlencoded" );
94
93
headers .put ("Metadata" , "true" );
95
-
96
- queryParameters .put ("resource" , singletonList (resource ));
94
+ queryParameters .put ("resource" , Collections .singletonList (resource ));
97
95
break ;
98
- }
99
- case IMDS : {
96
+ case IMDS :
100
97
endpoint = IMDS_ENDPOINT ;
101
- queryParameters .put ("api-version" , singletonList ("2018-02-01" ));
102
- queryParameters .put ("resource" , singletonList (resource ));
98
+ queryParameters .put ("api-version" , Collections . singletonList ("2018-02-01" ));
99
+ queryParameters .put ("resource" , Collections . singletonList (resource ));
103
100
headers .put ("Metadata" , "true" );
104
101
break ;
105
- }
106
- case AZURE_ARC : {
102
+ case AZURE_ARC :
107
103
endpoint = azureArcEndpoint ;
108
-
109
- queryParameters .put ("api-version" , singletonList ("2019-11-01" ));
110
- queryParameters .put ("resource" , singletonList (resource ));
111
-
104
+ queryParameters .put ("api-version" , Collections .singletonList ("2019-11-01" ));
105
+ queryParameters .put ("resource" , Collections .singletonList (resource ));
112
106
headers .put ("Metadata" , "true" );
113
107
break ;
114
- }
115
- case SERVICE_FABRIC : {
108
+ case SERVICE_FABRIC :
116
109
endpoint = serviceFabricEndpoint ;
117
- queryParameters .put ("api-version" , singletonList ("2019-07-01-preview" ));
118
- queryParameters .put ("resource" , singletonList (resource ));
119
-
110
+ queryParameters .put ("api-version" , Collections .singletonList ("2019-07-01-preview" ));
111
+ queryParameters .put ("resource" , Collections .singletonList (resource ));
120
112
headers .put ("secret" , "secret" );
121
113
break ;
122
- }
114
+ case NONE :
115
+ case DEFAULT_TO_IMDS :
116
+ endpoint = IMDS_ENDPOINT ;
117
+ queryParameters .put ("api-version" , Collections .singletonList ("2018-02-01" ));
118
+ queryParameters .put ("resource" , Collections .singletonList (resource ));
119
+ headers .put ("Metadata" , "true" );
120
+ break ;
123
121
}
124
122
125
123
switch (id .getIdType ()) {
126
124
case CLIENT_ID :
127
- queryParameters .put ("client_id" , singletonList (id .getUserAssignedId ()));
125
+ queryParameters .put ("client_id" , Collections . singletonList (id .getUserAssignedId ()));
128
126
break ;
129
127
case RESOURCE_ID :
130
- queryParameters .put ("mi_res_id" , singletonList (id .getUserAssignedId ()));
128
+ queryParameters .put ("mi_res_id" , Collections . singletonList (id .getUserAssignedId ()));
131
129
break ;
132
130
case OBJECT_ID :
133
131
queryParameters .put ("object_id" , singletonList (id .getUserAssignedId ()));
@@ -314,9 +312,10 @@ void managedIdentityTest_UserAssigned_NotSupported(ManagedIdentitySourceType sou
314
312
miApp .tokenCache ().accessTokens .clear ();
315
313
316
314
try {
317
- IAuthenticationResult result = miApp .acquireTokenForManagedIdentity (
315
+ miApp .acquireTokenForManagedIdentity (
318
316
ManagedIdentityParameters .builder (resource )
319
317
.build ()).get ();
318
+ fail ("MsalServiceException is expected but not thrown." );
320
319
} catch (Exception e ) {
321
320
assertNotNull (e );
322
321
assertNotNull (e .getCause ());
@@ -325,10 +324,7 @@ void managedIdentityTest_UserAssigned_NotSupported(ManagedIdentitySourceType sou
325
324
MsalServiceException msalMsiException = (MsalServiceException ) e .getCause ();
326
325
assertEquals (source .name (), msalMsiException .managedIdentitySource ());
327
326
assertEquals (MsalError .USER_ASSIGNED_MANAGED_IDENTITY_NOT_SUPPORTED , msalMsiException .errorCode ());
328
- return ;
329
327
}
330
-
331
- fail ("MsalServiceException is expected but not thrown." );
332
328
}
333
329
334
330
@ ParameterizedTest
@@ -637,21 +633,116 @@ void managedIdentityTest_WithClaims(ManagedIdentitySourceType source, String end
637
633
638
634
miApp = ManagedIdentityApplication
639
635
.builder (ManagedIdentityId .systemAssigned ())
640
- .clientCapabilities (singletonList ("cp1" ))
641
636
.httpClient (httpClientMock )
642
637
.build ();
643
638
644
639
// Clear caching to avoid cross test pollution.
645
640
miApp .tokenCache ().accessTokens .clear ();
646
641
647
642
String claimsJson = "{\" default\" :\" claim\" }" ;
643
+
644
+ // First call, get the token from the identity provider.
648
645
IAuthenticationResult result = miApp .acquireTokenForManagedIdentity (
649
646
ManagedIdentityParameters .builder (resource )
650
- .claims (claimsJson )
651
647
.build ()).get ();
652
648
653
649
assertNotNull (result .accessToken ());
654
- verify (httpClientMock , times (1 )).send (any ());
650
+ assertEquals (TokenSource .IDENTITY_PROVIDER , result .metadata ().tokenSource ());
651
+
652
+ // Second call, get the token from the cache without passing the claims.
653
+ result = miApp .acquireTokenForManagedIdentity (
654
+ ManagedIdentityParameters .builder (resource )
655
+ .build ()).get ();
656
+
657
+ assertNotNull (result .accessToken ());
658
+ assertEquals (TokenSource .CACHE , result .metadata ().tokenSource ());
659
+
660
+ // Third call, when claims are passed bypass the cache.
661
+ result = miApp .acquireTokenForManagedIdentity (
662
+ ManagedIdentityParameters .builder (resource )
663
+ .claims (claimsJson )
664
+ .build ()).get ();
665
+
666
+ assertNotNull (result .accessToken ());
667
+ assertEquals (TokenSource .IDENTITY_PROVIDER , result .metadata ().tokenSource ());
668
+
669
+ verify (httpClientMock , times (2 )).send (any ());
670
+ }
671
+
672
+ @ ParameterizedTest
673
+ @ MethodSource ("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError" )
674
+ void managedIdentity_ClaimsAndCapabilities (ManagedIdentitySourceType source , String endpoint ) throws Exception {
675
+ IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper (source , endpoint );
676
+ ManagedIdentityApplication .setEnvironmentVariables (environmentVariables );
677
+ DefaultHttpClient httpClientMock = mock (DefaultHttpClient .class );
678
+ if (source == SERVICE_FABRIC ) {
679
+ ServiceFabricManagedIdentitySource .setHttpClient (httpClientMock );
680
+ }
681
+
682
+ when (httpClientMock .send (expectedRequest (source , resource ))).thenReturn (expectedResponse (200 , getSuccessfulResponse (resource )));
683
+
684
+ miApp = ManagedIdentityApplication
685
+ .builder (ManagedIdentityId .systemAssigned ())
686
+ .clientCapabilities (singletonList ("cp1" ))
687
+ .httpClient (httpClientMock )
688
+ .build ();
689
+
690
+ // Clear caching to avoid cross test pollution.
691
+ miApp .tokenCache ().accessTokens .clear ();
692
+
693
+ String claimsJson = "{\" default\" :\" claim\" }" ;
694
+ // First call, get the token from the identity provider.
695
+ IAuthenticationResult result = miApp .acquireTokenForManagedIdentity (
696
+ ManagedIdentityParameters .builder (resource )
697
+ .build ()).get ();
698
+
699
+ assertNotNull (result .accessToken ());
700
+ assertEquals (TokenSource .IDENTITY_PROVIDER , result .metadata ().tokenSource ());
701
+
702
+ // Second call, get the token from the cache without passing the claims.
703
+ result = miApp .acquireTokenForManagedIdentity (
704
+ ManagedIdentityParameters .builder (resource )
705
+ .build ()).get ();
706
+
707
+ assertNotNull (result .accessToken ());
708
+ assertEquals (TokenSource .CACHE , result .metadata ().tokenSource ());
709
+
710
+ // Third call, when claims are passed bypass the cache.
711
+ result = miApp .acquireTokenForManagedIdentity (
712
+ ManagedIdentityParameters .builder (resource )
713
+ .claims (claimsJson )
714
+ .build ()).get ();
715
+
716
+ assertNotNull (result .accessToken ());
717
+ assertEquals (TokenSource .IDENTITY_PROVIDER , result .metadata ().tokenSource ());
718
+
719
+ verify (httpClientMock , times (2 )).send (any ());
720
+ }
721
+
722
+ @ ParameterizedTest
723
+ @ MethodSource ("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createInvalidClaimsData" )
724
+ void managedIdentity_InvalidClaims (String claimsJson ) throws Exception {
725
+ IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper (APP_SERVICE , appServiceEndpoint );
726
+ ManagedIdentityApplication .setEnvironmentVariables (environmentVariables );
727
+ DefaultHttpClient httpClientMock = mock (DefaultHttpClient .class );
728
+
729
+ miApp = ManagedIdentityApplication
730
+ .builder (ManagedIdentityId .systemAssigned ())
731
+ .httpClient (httpClientMock )
732
+ .build ();
733
+
734
+ CompletableFuture <IAuthenticationResult > future = miApp .acquireTokenForManagedIdentity (
735
+ ManagedIdentityParameters .builder (resource )
736
+ .claims (claimsJson )
737
+ .build ());
738
+
739
+ ExecutionException ex = assertThrows (ExecutionException .class , future ::get );
740
+ assertInstanceOf (MsalClientException .class , ex .getCause ());
741
+ MsalClientException msalException = (MsalClientException ) ex .getCause ();
742
+ assertEquals (AuthenticationErrorCode .INVALID_JSON , msalException .errorCode ());
743
+
744
+ // Verify no HTTP requests were made for invalid claims
745
+ verify (httpClientMock , never ()).send (any ());
655
746
}
656
747
657
748
@ Nested
0 commit comments