Skip to content

Commit 5ec96c5

Browse files
authored
Fix socket interceptor and add unit test (#17784)
The optimization of using a Stream instead of a Collection caused problems with class not found and/or illegal access errors when using the lambda function in the `Stream::forEach` call in the intercept method. Signed-off-by: Andrew Ross <[email protected]>
1 parent eb90570 commit 5ec96c5

File tree

5 files changed

+65
-22
lines changed

5 files changed

+65
-22
lines changed

libs/agent-sm/agent/src/main/java/org/opensearch/javaagent/SocketChannelInterceptor.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import java.net.UnixDomainSocketAddress;
1919
import java.security.Policy;
2020
import java.security.ProtectionDomain;
21-
import java.util.stream.Stream;
21+
import java.util.Collection;
2222

2323
import net.bytebuddy.asm.Advice;
2424
import net.bytebuddy.asm.Advice.Origin;
@@ -47,26 +47,26 @@ public static void intercept(@Advice.AllArguments Object[] args, @Origin Method
4747
}
4848

4949
final StackWalker walker = StackWalker.getInstance(Option.RETAIN_CLASS_REFERENCE);
50-
final Stream<ProtectionDomain> callers = walker.walk(StackCallerProtectionDomainChainExtractor.INSTANCE);
50+
final Collection<ProtectionDomain> callers = walker.walk(StackCallerProtectionDomainChainExtractor.INSTANCE);
5151

5252
if (args[0] instanceof InetSocketAddress address) {
5353
if (!AgentPolicy.isTrustedHost(address.getHostString())) {
5454
final String host = address.getHostString() + ":" + address.getPort();
5555

5656
final SocketPermission permission = new SocketPermission(host, "connect,resolve");
57-
callers.forEach(domain -> {
57+
for (ProtectionDomain domain : callers) {
5858
if (!policy.implies(domain, permission)) {
5959
throw new SecurityException("Denied access to: " + host + ", domain " + domain);
6060
}
61-
});
61+
}
6262
}
6363
} else if (args[0] instanceof UnixDomainSocketAddress address) {
6464
final NetPermission permission = new NetPermission("accessUnixDomainSocket");
65-
callers.forEach(domain -> {
65+
for (ProtectionDomain domain : callers) {
6666
if (!policy.implies(domain, permission)) {
6767
throw new SecurityException("Denied access to: " + address + ", domain " + domain);
6868
}
69-
});
69+
}
7070
}
7171
}
7272
}

libs/agent-sm/agent/src/main/java/org/opensearch/javaagent/StackCallerProtectionDomainChainExtractor.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@
1010

1111
import java.lang.StackWalker.StackFrame;
1212
import java.security.ProtectionDomain;
13+
import java.util.Collection;
1314
import java.util.function.Function;
15+
import java.util.stream.Collectors;
1416
import java.util.stream.Stream;
1517

1618
/**
1719
* Stack Caller Chain Extractor
1820
*/
19-
public final class StackCallerProtectionDomainChainExtractor implements Function<Stream<StackFrame>, Stream<ProtectionDomain>> {
21+
public final class StackCallerProtectionDomainChainExtractor implements Function<Stream<StackFrame>, Collection<ProtectionDomain>> {
2022
/**
2123
* Single instance of stateless class.
2224
*/
@@ -32,10 +34,10 @@ private StackCallerProtectionDomainChainExtractor() {}
3234
* @param frames stack frames
3335
*/
3436
@Override
35-
public Stream<ProtectionDomain> apply(Stream<StackFrame> frames) {
37+
public Collection<ProtectionDomain> apply(Stream<StackFrame> frames) {
3638
return frames.map(StackFrame::getDeclaringClass)
3739
.map(Class::getProtectionDomain)
3840
.filter(pd -> pd.getCodeSource() != null) /* JDK */
39-
.distinct();
41+
.collect(Collectors.toSet());
4042
}
4143
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*/
8+
9+
package org.opensearch.javaagent;
10+
11+
import org.opensearch.javaagent.bootstrap.AgentPolicy;
12+
import org.junit.BeforeClass;
13+
14+
import java.security.Policy;
15+
import java.util.Set;
16+
17+
public abstract class AgentTestCase {
18+
@SuppressWarnings("removal")
19+
@BeforeClass
20+
public static void setUp() {
21+
AgentPolicy.setPolicy(new Policy() {
22+
}, Set.of(), (caller, chain) -> caller.getName().equalsIgnoreCase("worker.org.gradle.process.internal.worker.GradleWorkerMain"));
23+
}
24+
}

libs/agent-sm/agent/src/test/java/org/opensearch/javaagent/AgentTests.java

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,9 @@
88

99
package org.opensearch.javaagent;
1010

11-
import org.opensearch.javaagent.bootstrap.AgentPolicy;
12-
import org.junit.BeforeClass;
1311
import org.junit.Test;
1412

15-
import java.security.Policy;
16-
import java.util.Set;
17-
18-
public class AgentTests {
19-
@SuppressWarnings("removal")
20-
@BeforeClass
21-
public static void setUp() {
22-
AgentPolicy.setPolicy(new Policy() {
23-
}, Set.of(), (caller, chain) -> caller.getName().equalsIgnoreCase("worker.org.gradle.process.internal.worker.GradleWorkerMain"));
24-
}
25-
13+
public class AgentTests extends AgentTestCase {
2614
@Test(expected = SecurityException.class)
2715
public void testSystemExitIsForbidden() {
2816
System.exit(0);
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*/
8+
9+
package org.opensearch.javaagent;
10+
11+
import org.junit.Test;
12+
13+
import java.io.IOException;
14+
import java.net.InetSocketAddress;
15+
import java.net.UnixDomainSocketAddress;
16+
import java.nio.channels.SocketChannel;
17+
18+
import static org.junit.Assert.assertThrows;
19+
20+
public class SocketChannelInterceptorTests extends AgentTestCase {
21+
@Test
22+
public void test() throws IOException {
23+
try (SocketChannel channel = SocketChannel.open()) {
24+
assertThrows(SecurityException.class, () -> channel.connect(new InetSocketAddress("localhost", 9200)));
25+
26+
assertThrows(SecurityException.class, () -> channel.connect(UnixDomainSocketAddress.of("fake-path")));
27+
}
28+
}
29+
}

0 commit comments

Comments
 (0)