Skip to content

Commit 6b26033

Browse files
committed
feat: Validate the Universe Domain (#2330)
* feat: Validate the universe domain * chore: Merge in from origin/main * chore: Add comments for ApiCallContext * chore: Add comments * chore: Address PR comments * chore: Merge endpoint context in both transports * chore: Use @throws for the exceptions * chore: Provide a default EndpointContext * chore: Address PR comments * chore: Update error message * chore: Address PR comments * chore: Address PR comments * chore: Address PR comments
1 parent c4adaed commit 6b26033

23 files changed

+822
-120
lines changed

gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java

+137-15
Large diffs are not rendered by default.

gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcClientCalls.java

+4
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,10 @@ public static <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(
9595
channel = ClientInterceptors.intercept(channel, interceptor);
9696
}
9797

98+
// Validate the Universe Domain prior to the call. Only allow the call to go through
99+
// if the Universe Domain is valid.
100+
grpcContext.validateUniverseDomain();
101+
98102
try (Scope ignored = grpcContext.getTracer().inScope()) {
99103
return channel.newCall(descriptor, callOptions);
100104
}

gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java

+19-2
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,15 @@
3737
import com.google.api.gax.grpc.testing.FakeChannelFactory;
3838
import com.google.api.gax.grpc.testing.FakeMethodDescriptor;
3939
import com.google.api.gax.rpc.ClientContext;
40+
import com.google.api.gax.rpc.EndpointContext;
4041
import com.google.api.gax.rpc.ResponseObserver;
4142
import com.google.api.gax.rpc.ServerStreamingCallSettings;
4243
import com.google.api.gax.rpc.ServerStreamingCallable;
4344
import com.google.api.gax.rpc.StreamController;
4445
import com.google.api.gax.rpc.UnaryCallSettings;
4546
import com.google.api.gax.rpc.UnaryCallable;
4647
import com.google.api.gax.util.FakeLogHandler;
48+
import com.google.auth.Credentials;
4749
import com.google.common.base.Preconditions;
4850
import com.google.common.collect.ImmutableList;
4951
import com.google.common.collect.Lists;
@@ -628,10 +630,17 @@ public void testReleasingClientCallCancelEarly() throws IOException {
628630
ChannelPoolSettings channelPoolSettings = ChannelPoolSettings.staticallySized(1);
629631
ChannelFactory factory = new FakeChannelFactory(ImmutableList.of(fakeChannel));
630632
pool = ChannelPool.create(channelPoolSettings, factory);
633+
634+
EndpointContext endpointContext = Mockito.mock(EndpointContext.class);
635+
Mockito.doNothing()
636+
.when(endpointContext)
637+
.validateUniverseDomain(Mockito.any(Credentials.class), Mockito.any(GrpcStatusCode.class));
638+
631639
ClientContext context =
632640
ClientContext.newBuilder()
633641
.setTransportChannel(GrpcTransportChannel.create(pool))
634-
.setDefaultCallContext(GrpcCallContext.of(pool, CallOptions.DEFAULT))
642+
.setDefaultCallContext(
643+
GrpcCallContext.of(pool, CallOptions.DEFAULT).withEndpointContext(endpointContext))
635644
.build();
636645
ServerStreamingCallSettings settings =
637646
ServerStreamingCallSettings.<Color, Money>newBuilder().build();
@@ -680,11 +689,19 @@ public void testDoubleRelease() throws Exception {
680689

681690
pool = ChannelPool.create(channelPoolSettings, factory);
682691

692+
EndpointContext endpointContext = Mockito.mock(EndpointContext.class);
693+
Mockito.doNothing()
694+
.when(endpointContext)
695+
.validateUniverseDomain(
696+
Mockito.any(Credentials.class), Mockito.any(GrpcStatusCode.class));
697+
683698
// Construct a fake callable to use the channel pool
684699
ClientContext context =
685700
ClientContext.newBuilder()
686701
.setTransportChannel(GrpcTransportChannel.create(pool))
687-
.setDefaultCallContext(GrpcCallContext.of(pool, CallOptions.DEFAULT))
702+
.setDefaultCallContext(
703+
GrpcCallContext.of(pool, CallOptions.DEFAULT)
704+
.withEndpointContext(endpointContext))
688705
.build();
689706

690707
UnaryCallSettings<Color, Money> settings =

gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcCallContextTest.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import io.grpc.CallOptions;
4747
import io.grpc.ManagedChannel;
4848
import io.grpc.Metadata.Key;
49+
import java.io.IOException;
4950
import java.util.ArrayList;
5051
import java.util.Collections;
5152
import java.util.HashMap;
@@ -373,7 +374,7 @@ public void testWithOptions() {
373374
}
374375

375376
@Test
376-
public void testMergeOptions() {
377+
public void testMergeOptions() throws IOException {
377378
GrpcCallContext emptyCallContext = GrpcCallContext.createDefault();
378379
ApiCallContext.Key<String> contextKey1 = ApiCallContext.Key.create("testKey1");
379380
ApiCallContext.Key<String> contextKey2 = ApiCallContext.Key.create("testKey2");

gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcCallableFactoryTest.java

+13-7
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,15 @@
3535
import com.google.api.gax.grpc.testing.FakeServiceImpl;
3636
import com.google.api.gax.grpc.testing.InProcessServer;
3737
import com.google.api.gax.retrying.RetrySettings;
38+
import com.google.api.gax.rpc.ApiCallContext;
3839
import com.google.api.gax.rpc.ClientContext;
40+
import com.google.api.gax.rpc.EndpointContext;
3941
import com.google.api.gax.rpc.InvalidArgumentException;
4042
import com.google.api.gax.rpc.ServerStreamingCallSettings;
4143
import com.google.api.gax.rpc.ServerStreamingCallable;
4244
import com.google.api.gax.rpc.StatusCode.Code;
4345
import com.google.api.gax.tracing.SpanName;
46+
import com.google.auth.Credentials;
4447
import com.google.common.collect.ImmutableList;
4548
import com.google.common.truth.Truth;
4649
import com.google.type.Color;
@@ -74,10 +77,16 @@ public void setUp() throws Exception {
7477
inprocessServer.start();
7578

7679
channel = InProcessChannelBuilder.forName(serverName).directExecutor().usePlaintext().build();
80+
EndpointContext endpointContext = Mockito.mock(EndpointContext.class);
81+
Mockito.doNothing()
82+
.when(endpointContext)
83+
.validateUniverseDomain(Mockito.any(Credentials.class), Mockito.any(GrpcStatusCode.class));
7784
clientContext =
7885
ClientContext.newBuilder()
7986
.setTransportChannel(GrpcTransportChannel.create(channel))
80-
.setDefaultCallContext(GrpcCallContext.of(channel, CallOptions.DEFAULT))
87+
.setDefaultCallContext(
88+
GrpcCallContext.of(channel, CallOptions.DEFAULT)
89+
.withEndpointContext(endpointContext))
8190
.build();
8291
}
8392

@@ -106,11 +115,10 @@ public void createServerStreamingCallableRetryableExceptions() {
106115
GrpcCallableFactory.createServerStreamingCallable(
107116
grpcCallSettings, nonRetryableSettings, clientContext);
108117

118+
ApiCallContext defaultCallContext = clientContext.getDefaultCallContext();
109119
Throwable actualError = null;
110120
try {
111-
nonRetryableCallable
112-
.first()
113-
.call(Color.getDefaultInstance(), clientContext.getDefaultCallContext());
121+
nonRetryableCallable.first().call(Color.getDefaultInstance(), defaultCallContext);
114122
} catch (Throwable e) {
115123
actualError = e;
116124
}
@@ -134,9 +142,7 @@ public void createServerStreamingCallableRetryableExceptions() {
134142

135143
Throwable actualError2 = null;
136144
try {
137-
retryableCallable
138-
.first()
139-
.call(Color.getDefaultInstance(), clientContext.getDefaultCallContext());
145+
retryableCallable.first().call(Color.getDefaultInstance(), defaultCallContext);
140146
} catch (Throwable e) {
141147
actualError2 = e;
142148
}

gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcClientCallsTest.java

+117-10
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,16 @@
3030
package com.google.api.gax.grpc;
3131

3232
import static com.google.common.truth.Truth.assertThat;
33+
import static org.junit.Assert.assertThrows;
3334
import static org.mockito.Mockito.verify;
3435

3536
import com.google.api.gax.grpc.testing.FakeChannelFactory;
3637
import com.google.api.gax.grpc.testing.FakeServiceGrpc;
38+
import com.google.api.gax.rpc.EndpointContext;
39+
import com.google.api.gax.rpc.UnauthenticatedException;
40+
import com.google.api.gax.rpc.UnavailableException;
41+
import com.google.auth.Credentials;
42+
import com.google.auth.Retryable;
3743
import com.google.common.collect.ImmutableList;
3844
import com.google.common.truth.Truth;
3945
import com.google.type.Color;
@@ -45,18 +51,58 @@
4551
import io.grpc.ManagedChannel;
4652
import io.grpc.Metadata;
4753
import io.grpc.MethodDescriptor;
54+
import io.grpc.Status;
4855
import java.io.IOException;
4956
import java.util.Arrays;
5057
import java.util.HashMap;
5158
import java.util.List;
5259
import java.util.Map;
5360
import java.util.concurrent.TimeUnit;
61+
import org.junit.Before;
5462
import org.junit.Test;
5563
import org.mockito.ArgumentCaptor;
5664
import org.mockito.Mockito;
5765
import org.threeten.bp.Duration;
5866

5967
public class GrpcClientCallsTest {
68+
69+
// Auth Library's GoogleAuthException is package-private. Copy basic functionality for tests
70+
private static class GoogleAuthException extends IOException implements Retryable {
71+
72+
private final boolean isRetryable;
73+
74+
private GoogleAuthException(boolean isRetryable) {
75+
this.isRetryable = isRetryable;
76+
}
77+
78+
@Override
79+
public boolean isRetryable() {
80+
return isRetryable;
81+
}
82+
83+
@Override
84+
public int getRetryCount() {
85+
return 0;
86+
}
87+
}
88+
89+
private GrpcCallContext defaultCallContext;
90+
private EndpointContext endpointContext;
91+
private Credentials credentials;
92+
private Channel mockChannel;
93+
94+
@Before
95+
public void setUp() throws IOException {
96+
credentials = Mockito.mock(Credentials.class);
97+
endpointContext = Mockito.mock(EndpointContext.class);
98+
mockChannel = Mockito.mock(Channel.class);
99+
100+
defaultCallContext = GrpcCallContext.createDefault().withEndpointContext(endpointContext);
101+
Mockito.doNothing()
102+
.when(endpointContext)
103+
.validateUniverseDomain(Mockito.any(Credentials.class), Mockito.any(GrpcStatusCode.class));
104+
}
105+
60106
@Test
61107
public void testAffinity() throws IOException {
62108
MethodDescriptor<Color, Money> descriptor = FakeServiceGrpc.METHOD_RECOGNIZE;
@@ -78,7 +124,7 @@ public void testAffinity() throws IOException {
78124
ChannelPool.create(
79125
ChannelPoolSettings.staticallySized(2),
80126
new FakeChannelFactory(Arrays.asList(channel0, channel1)));
81-
GrpcCallContext context = GrpcCallContext.createDefault().withChannel(pool);
127+
GrpcCallContext context = defaultCallContext.withChannel(pool);
82128

83129
ClientCall<Color, Money> gotCallA =
84130
GrpcClientCalls.newCall(descriptor, context.withChannelAffinity(0));
@@ -92,7 +138,7 @@ public void testAffinity() throws IOException {
92138
}
93139

94140
@Test
95-
public void testExtraHeaders() {
141+
public void testExtraHeaders() throws IOException {
96142
Metadata emptyHeaders = new Metadata();
97143
final Map<String, List<String>> extraHeaders = new HashMap<>();
98144
extraHeaders.put(
@@ -128,12 +174,12 @@ public void testExtraHeaders() {
128174
.thenReturn(mockClientCall);
129175

130176
GrpcCallContext context =
131-
GrpcCallContext.createDefault().withChannel(mockChannel).withExtraHeaders(extraHeaders);
177+
defaultCallContext.withChannel(mockChannel).withExtraHeaders(extraHeaders);
132178
GrpcClientCalls.newCall(descriptor, context).start(mockListener, emptyHeaders);
133179
}
134180

135181
@Test
136-
public void testTimeoutToDeadlineConversion() {
182+
public void testTimeoutToDeadlineConversion() throws IOException {
137183
MethodDescriptor<Color, Money> descriptor = FakeServiceGrpc.METHOD_RECOGNIZE;
138184

139185
@SuppressWarnings("unchecked")
@@ -152,8 +198,7 @@ public void testTimeoutToDeadlineConversion() {
152198
Duration timeout = Duration.ofSeconds(10);
153199
Deadline minExpectedDeadline = Deadline.after(timeout.getSeconds(), TimeUnit.SECONDS);
154200

155-
GrpcCallContext context =
156-
GrpcCallContext.createDefault().withChannel(mockChannel).withTimeout(timeout);
201+
GrpcCallContext context = defaultCallContext.withChannel(mockChannel).withTimeout(timeout);
157202

158203
GrpcClientCalls.newCall(descriptor, context).start(mockListener, new Metadata());
159204

@@ -164,7 +209,7 @@ public void testTimeoutToDeadlineConversion() {
164209
}
165210

166211
@Test
167-
public void testTimeoutAfterDeadline() {
212+
public void testTimeoutAfterDeadline() throws IOException {
168213
MethodDescriptor<Color, Money> descriptor = FakeServiceGrpc.METHOD_RECOGNIZE;
169214

170215
@SuppressWarnings("unchecked")
@@ -185,7 +230,7 @@ public void testTimeoutAfterDeadline() {
185230
Duration timeout = Duration.ofSeconds(10);
186231

187232
GrpcCallContext context =
188-
GrpcCallContext.createDefault()
233+
defaultCallContext
189234
.withChannel(mockChannel)
190235
.withCallOptions(CallOptions.DEFAULT.withDeadline(priorDeadline))
191236
.withTimeout(timeout);
@@ -197,7 +242,7 @@ public void testTimeoutAfterDeadline() {
197242
}
198243

199244
@Test
200-
public void testTimeoutBeforeDeadline() {
245+
public void testTimeoutBeforeDeadline() throws IOException {
201246
MethodDescriptor<Color, Money> descriptor = FakeServiceGrpc.METHOD_RECOGNIZE;
202247

203248
@SuppressWarnings("unchecked")
@@ -219,7 +264,7 @@ public void testTimeoutBeforeDeadline() {
219264
Deadline minExpectedDeadline = Deadline.after(timeout.getSeconds(), TimeUnit.SECONDS);
220265

221266
GrpcCallContext context =
222-
GrpcCallContext.createDefault()
267+
defaultCallContext
223268
.withChannel(mockChannel)
224269
.withCallOptions(CallOptions.DEFAULT.withDeadline(subsequentDeadline))
225270
.withTimeout(timeout);
@@ -232,4 +277,66 @@ public void testTimeoutBeforeDeadline() {
232277
Truth.assertThat(capturedCallOptions.getValue().getDeadline()).isAtLeast(minExpectedDeadline);
233278
Truth.assertThat(capturedCallOptions.getValue().getDeadline()).isAtMost(maxExpectedDeadline);
234279
}
280+
281+
@Test
282+
public void testValidUniverseDomain() throws IOException {
283+
GrpcCallContext context =
284+
GrpcCallContext.createDefault()
285+
.withChannel(mockChannel)
286+
.withCredentials(credentials)
287+
.withEndpointContext(endpointContext);
288+
289+
CallOptions callOptions = context.getCallOptions();
290+
291+
MethodDescriptor<Color, Money> descriptor = FakeServiceGrpc.METHOD_RECOGNIZE;
292+
GrpcClientCalls.newCall(descriptor, context);
293+
Mockito.verify(mockChannel, Mockito.times(1)).newCall(descriptor, callOptions);
294+
}
295+
296+
// This test is when the universe domain does not match
297+
@Test
298+
public void testInvalidUniverseDomain() throws IOException {
299+
Mockito.doThrow(
300+
new UnauthenticatedException(
301+
null, GrpcStatusCode.of(Status.Code.UNAUTHENTICATED), false))
302+
.when(endpointContext)
303+
.validateUniverseDomain(Mockito.any(Credentials.class), Mockito.any(GrpcStatusCode.class));
304+
GrpcCallContext context =
305+
GrpcCallContext.createDefault()
306+
.withChannel(mockChannel)
307+
.withCredentials(credentials)
308+
.withEndpointContext(endpointContext);
309+
310+
CallOptions callOptions = context.getCallOptions();
311+
312+
MethodDescriptor<Color, Money> descriptor = FakeServiceGrpc.METHOD_RECOGNIZE;
313+
UnauthenticatedException exception =
314+
assertThrows(
315+
UnauthenticatedException.class, () -> GrpcClientCalls.newCall(descriptor, context));
316+
assertThat(exception.getStatusCode().getCode()).isEqualTo(GrpcStatusCode.Code.UNAUTHENTICATED);
317+
Mockito.verify(mockChannel, Mockito.never()).newCall(descriptor, callOptions);
318+
}
319+
320+
// This test is when the MDS is unable to return a valid universe domain
321+
@Test
322+
public void testUniverseDomainNotReady_shouldRetry() throws IOException {
323+
Mockito.doThrow(new GoogleAuthException(true))
324+
.when(endpointContext)
325+
.validateUniverseDomain(Mockito.any(Credentials.class), Mockito.any(GrpcStatusCode.class));
326+
GrpcCallContext context =
327+
GrpcCallContext.createDefault()
328+
.withChannel(mockChannel)
329+
.withCredentials(credentials)
330+
.withEndpointContext(endpointContext);
331+
332+
CallOptions callOptions = context.getCallOptions();
333+
334+
MethodDescriptor<Color, Money> descriptor = FakeServiceGrpc.METHOD_RECOGNIZE;
335+
UnavailableException exception =
336+
assertThrows(
337+
UnavailableException.class, () -> GrpcClientCalls.newCall(descriptor, context));
338+
assertThat(exception.getStatusCode().getCode()).isEqualTo(GrpcStatusCode.Code.UNAVAILABLE);
339+
Truth.assertThat(exception.isRetryable()).isTrue();
340+
Mockito.verify(mockChannel, Mockito.never()).newCall(descriptor, callOptions);
341+
}
235342
}

0 commit comments

Comments
 (0)