Skip to content

ThreadLocalAccessor API capable of working in scoped scenarios #103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
*
* @author Rossen Stoyanchev
* @author Brian Clozel
* @author Dariusz Jędrzejczyk
* @since 1.0.0
*/
final class DefaultContextSnapshot extends HashMap<Object, Object> implements ContextSnapshot {
Expand Down Expand Up @@ -206,7 +207,7 @@ private <V> void resetThreadLocalValue(ThreadLocalAccessor<?> accessor, @Nullabl
((ThreadLocalAccessor<V>) accessor).restore(previousValue);
}
else {
accessor.reset();
accessor.restore();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
*
* @author Rossen Stoyanchev
* @author Marcin Grzejszczak
* @author Dariusz Jędrzejczyk
* @since 1.0.0
* @see ContextRegistry#registerThreadLocalAccessor(ThreadLocalAccessor)
* @see ContextRegistry#registerThreadLocalAccessor(String, Supplier, Consumer, Runnable)
Expand All @@ -38,7 +39,18 @@ public interface ThreadLocalAccessor<V> {
Object key();

/**
* Return the current {@link ThreadLocal} value.
* Return the current {@link ThreadLocal} value, or {@code null} if not set. This
* method is called in two scenarios:
* <ul>
* <li>When capturing a {@link ContextSnapshot}. A {@code null} value would not end up
* in the snapshot and would mean the snapshot is missing a mapping for this
* accessor's {@link #key()}.</li>
* <li>When setting {@link ThreadLocal} values from a {@link ContextSnapshot} or a
* Context object (operated upon by {@link ContextAccessor}) to check for existing
* values: {@code null} means the {@link ThreadLocal} is not set and upon closing a
* {@link io.micrometer.context.ContextSnapshot.Scope}, the {@link #restore()} variant
* with no argument would be called.</li>
* </ul>
*/
@Nullable
V getValue();
Expand All @@ -53,9 +65,29 @@ public interface ThreadLocalAccessor<V> {
void setValue(V value);

/**
* Remove the {@link ThreadLocal} value.
* Called instead of {@link #setValue(Object)} in order to remove the current
* {@link ThreadLocal} value at the start of a
* {@link io.micrometer.context.ContextSnapshot.Scope}.
*
* @since 1.0.3
*/
void reset();
default void setValue() {
reset();
}

/**
* Remove the {@link ThreadLocal} value when setting {@link ThreadLocal} values in
* case of missing mapping for a {@link #key()} from a {@link ContextSnapshot}, or a
* Context object (operated upon by {@link ContextAccessor}).
* @deprecated To be replaced by calls to {@link #setValue()} (and/or
* {@link #restore()}), which needs to be implemented when this implementation is
* removed.
*/
@Deprecated
default void reset() {
throw new IllegalStateException(this.getClass().getName() + "#reset() should "
+ "not be called. Please implement #setValue() method when removing the " + "#reset() implementation.");
}

/**
* Remove the current {@link ThreadLocal} value and set the previously stored one.
Expand All @@ -69,4 +101,15 @@ default void restore(V previousValue) {
setValue(previousValue);
}

/**
* Remove the current {@link ThreadLocal} value when restoring values after a
* {@link io.micrometer.context.ContextSnapshot.Scope} closes but there was no
* {@link ThreadLocal} value present prior to the closed scope.
* @see #getValue()
* @since 1.0.3
*/
default void restore() {
setValue();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,16 @@
class ContextWrappingTests {

private final ContextRegistry registry = new ContextRegistry()
.registerThreadLocalAccessor(new ObservationThreadLocalAccessor());
.registerThreadLocalAccessor(new StringThreadLocalAccessor());

@AfterEach
void clear() {
ObservationThreadLocalHolder.reset();
StringThreadLocalHolder.reset();
}

@Test
void should_instrument_runnable() throws InterruptedException {
ObservationThreadLocalHolder.setValue("hello");
StringThreadLocalHolder.setValue("hello");
AtomicReference<String> valueInNewThread = new AtomicReference<>();
Runnable runnable = runnable(valueInNewThread);
runInNewThread(runnable);
Expand All @@ -61,10 +61,10 @@ void should_instrument_runnable() throws InterruptedException {

@Test
void should_instrument_callable() throws ExecutionException, InterruptedException, TimeoutException {
ObservationThreadLocalHolder.setValue("hello");
StringThreadLocalHolder.setValue("hello");
AtomicReference<String> valueInNewThread = new AtomicReference<>();
Callable<String> callable = () -> {
valueInNewThread.set(ObservationThreadLocalHolder.getValue());
valueInNewThread.set(StringThreadLocalHolder.getValue());
return "foo";
};
runInNewThread(callable);
Expand All @@ -78,7 +78,7 @@ void should_instrument_callable() throws ExecutionException, InterruptedExceptio

@Test
void should_instrument_executor() throws InterruptedException {
ObservationThreadLocalHolder.setValue("hello");
StringThreadLocalHolder.setValue("hello");
AtomicReference<String> valueInNewThread = new AtomicReference<>();
Executor executor = command -> new Thread(command).start();
runInNewThread(executor, valueInNewThread);
Expand All @@ -95,7 +95,7 @@ void should_instrument_executor() throws InterruptedException {
void should_instrument_executor_service() throws InterruptedException, ExecutionException, TimeoutException {
ExecutorService executorService = Executors.newSingleThreadExecutor();
try {
ObservationThreadLocalHolder.setValue("hello");
StringThreadLocalHolder.setValue("hello");
AtomicReference<String> valueInNewThread = new AtomicReference<>();
runInNewThread(executorService, valueInNewThread,
atomic -> then(atomic.get()).as("By default thread local information should not be propagated")
Expand All @@ -119,13 +119,13 @@ void should_instrument_scheduled_executor_service()
throws InterruptedException, ExecutionException, TimeoutException {
ScheduledExecutorService executorService = Executors.newSingleThreadScheduledExecutor();
try {
ObservationThreadLocalHolder.setValue("hello at time of creation of the executor");
StringThreadLocalHolder.setValue("hello at time of creation of the executor");
AtomicReference<String> valueInNewThread = new AtomicReference<>();
runInNewThread(executorService, valueInNewThread,
atomic -> then(atomic.get()).as("By default thread local information should not be propagated")
.isNull());

ObservationThreadLocalHolder.setValue("hello at time of creation of the executor");
StringThreadLocalHolder.setValue("hello at time of creation of the executor");
runInNewThread(
ContextExecutorService
.wrap(executorService, () -> ContextSnapshot.captureAllUsing(key -> true, this.registry)),
Expand Down Expand Up @@ -166,9 +166,9 @@ private void runInNewThread(ExecutorService executor, AtomicReference<String> va
Consumer<AtomicReference<String>> assertion)
throws InterruptedException, ExecutionException, TimeoutException {

ObservationThreadLocalHolder.setValue("hello"); // IMPORTANT: We are setting the
// thread local value as late as
// possible
StringThreadLocalHolder.setValue("hello"); // IMPORTANT: We are setting the
// thread local value as late as
// possible
executor.execute(runnable(valueInNewThread));
Thread.sleep(5);
assertion.accept(valueInNewThread);
Expand Down Expand Up @@ -215,12 +215,12 @@ private void runInNewThread(ScheduledExecutorService executor, AtomicReference<S
}

private Runnable runnable(AtomicReference<String> valueInNewThread) {
return () -> valueInNewThread.set(ObservationThreadLocalHolder.getValue());
return () -> valueInNewThread.set(StringThreadLocalHolder.getValue());
}

private Callable<Object> callable(AtomicReference<String> valueInNewThread) {
return () -> {
valueInNewThread.set(ObservationThreadLocalHolder.getValue());
valueInNewThread.set(StringThreadLocalHolder.getValue());
return "foo";
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
import java.util.Map;

import io.micrometer.context.ContextSnapshot.Scope;
import io.micrometer.context.observation.Observation;
import io.micrometer.context.observation.ObservationThreadLocalAccessor;
import io.micrometer.context.observation.ObservationScopeThreadLocalHolder;
import org.assertj.core.api.BDDAssertions;
import org.junit.jupiter.api.Test;

Expand All @@ -38,86 +41,86 @@ public class DefaultContextSnapshotTests {

@Test
void should_propagate_thread_local() {
this.registry.registerThreadLocalAccessor(new ObservationThreadLocalAccessor());
this.registry.registerThreadLocalAccessor(new StringThreadLocalAccessor());

ObservationThreadLocalHolder.setValue("hello");
StringThreadLocalHolder.setValue("hello");
ContextSnapshot snapshot = ContextSnapshot.captureAllUsing(key -> true, this.registry);

ObservationThreadLocalHolder.setValue("hola");
StringThreadLocalHolder.setValue("hola");
try {
try (Scope scope = snapshot.setThreadLocals()) {
then(ObservationThreadLocalHolder.getValue()).isEqualTo("hello");
then(StringThreadLocalHolder.getValue()).isEqualTo("hello");
}
then(ObservationThreadLocalHolder.getValue()).isEqualTo("hola");
then(StringThreadLocalHolder.getValue()).isEqualTo("hola");
}
finally {
ObservationThreadLocalHolder.reset();
StringThreadLocalHolder.reset();
}
}

@Test
void should_propagate_single_thread_local_value() {
this.registry.registerContextAccessor(new TestContextAccessor());
this.registry.registerThreadLocalAccessor(new ObservationThreadLocalAccessor());
this.registry.registerThreadLocalAccessor(new StringThreadLocalAccessor());

String key = ObservationThreadLocalAccessor.KEY;
String key = StringThreadLocalAccessor.KEY;
Map<String, String> sourceContext = Collections.singletonMap(key, "hello");

ObservationThreadLocalHolder.setValue("hola");
StringThreadLocalHolder.setValue("hola");
try {
try (Scope scope = ContextSnapshot.setThreadLocalsFrom(sourceContext, this.registry, key)) {
then(ObservationThreadLocalHolder.getValue()).isEqualTo("hello");
then(StringThreadLocalHolder.getValue()).isEqualTo("hello");
}
then(ObservationThreadLocalHolder.getValue()).isEqualTo("hola");
then(StringThreadLocalHolder.getValue()).isEqualTo("hola");
}
finally {
ObservationThreadLocalHolder.reset();
StringThreadLocalHolder.reset();
}
}

@Test
void should_propagate_all_single_thread_local_value() {
this.registry.registerContextAccessor(new TestContextAccessor());
this.registry.registerThreadLocalAccessor(new ObservationThreadLocalAccessor());
this.registry.registerThreadLocalAccessor(new StringThreadLocalAccessor());

String key = ObservationThreadLocalAccessor.KEY;
String key = StringThreadLocalAccessor.KEY;
Map<String, String> sourceContext = Collections.singletonMap(key, "hello");

ObservationThreadLocalHolder.setValue("hola");
StringThreadLocalHolder.setValue("hola");
try {
try (Scope scope = ContextSnapshot.setAllThreadLocalsFrom(sourceContext, this.registry)) {
then(ObservationThreadLocalHolder.getValue()).isEqualTo("hello");
then(StringThreadLocalHolder.getValue()).isEqualTo("hello");
}
then(ObservationThreadLocalHolder.getValue()).isEqualTo("hola");
then(StringThreadLocalHolder.getValue()).isEqualTo("hola");
}
finally {
ObservationThreadLocalHolder.reset();
StringThreadLocalHolder.reset();
}
}

@Test
void should_override_context_values_when_many_contexts() {
this.registry.registerContextAccessor(new TestContextAccessor());

String key = ObservationThreadLocalAccessor.KEY;
String key = StringThreadLocalAccessor.KEY;
Map<String, String> firstContext = Collections.singletonMap(key, "hello");
Map<String, String> secondContext = Collections.singletonMap(key, "override");
try {
ContextSnapshot contextSnapshot = ContextSnapshot.captureFromContext(this.registry, firstContext,
secondContext);
contextSnapshot.wrap(() -> {
then(ObservationThreadLocalHolder.getValue()).isEqualTo("override");
then(StringThreadLocalHolder.getValue()).isEqualTo("override");
});
}
finally {
ObservationThreadLocalHolder.reset();
StringThreadLocalHolder.reset();
}
}

@Test
void should_throw_an_exception_when_no_keys_are_passed() {
this.registry.registerContextAccessor(new TestContextAccessor());
this.registry.registerThreadLocalAccessor(new ObservationThreadLocalAccessor());
this.registry.registerThreadLocalAccessor(new StringThreadLocalAccessor());

Map<String, String> sourceContext = Collections.singletonMap("foo", "hello");

Expand All @@ -129,7 +132,7 @@ void should_throw_an_exception_when_no_keys_are_passed() {
@Test
void should_throw_an_exception_when_no_keys_are_passed_for_version_with_no_registry() {
this.registry.registerContextAccessor(new TestContextAccessor());
this.registry.registerThreadLocalAccessor(new ObservationThreadLocalAccessor());
this.registry.registerThreadLocalAccessor(new StringThreadLocalAccessor());

Map<String, String> sourceContext = Collections.singletonMap("foo", "hello");

Expand Down Expand Up @@ -195,20 +198,20 @@ void should_filter_thread_locals_on_restore() {

@Test
void should_not_fail_on_empty_thread_local() {
this.registry.registerThreadLocalAccessor(new ObservationThreadLocalAccessor());
this.registry.registerThreadLocalAccessor(new StringThreadLocalAccessor());

then(ObservationThreadLocalHolder.getValue()).isNull();
then(StringThreadLocalHolder.getValue()).isNull();

ContextSnapshot snapshot = ContextSnapshot.captureAll(this.registry);

ObservationThreadLocalHolder.reset();
then(ObservationThreadLocalHolder.getValue()).isNull();
StringThreadLocalHolder.reset();
then(StringThreadLocalHolder.getValue()).isNull();

try (Scope scope = snapshot.setThreadLocals()) {
then(ObservationThreadLocalHolder.getValue()).isNull();
then(StringThreadLocalHolder.getValue()).isNull();
}

then(ObservationThreadLocalHolder.getValue()).isNull();
then(StringThreadLocalHolder.getValue()).isNull();
}

@Test
Expand Down Expand Up @@ -237,7 +240,6 @@ void should_ignore_null_value_in_source_context() {
}

@Test
@SuppressWarnings("unchecked")
void should_ignore_null_mapping_in_source_context_when_skipping_intermediate_snapshot() {
String key = "foo";
ThreadLocal<String> fooThreadLocal = new ThreadLocal<>();
Expand Down Expand Up @@ -301,4 +303,31 @@ void toString_should_include_values() {
barThreadLocal.remove();
}

@Test
void should_work_with_scope_based_thread_local_accessor() {
this.registry.registerContextAccessor(new TestContextAccessor());
this.registry.registerThreadLocalAccessor(new ObservationThreadLocalAccessor());

String key = ObservationThreadLocalAccessor.KEY;
Observation observation = new Observation();
Map<String, Observation> sourceContext = Collections.singletonMap(key, observation);

then(ObservationScopeThreadLocalHolder.getCurrentObservation()).isNull();
try (Scope scope1 = ContextSnapshot.setAllThreadLocalsFrom(sourceContext, this.registry)) {
then(ObservationScopeThreadLocalHolder.getCurrentObservation()).isSameAs(observation);
try (Scope scope2 = ContextSnapshot.setAllThreadLocalsFrom(Collections.emptyMap(), this.registry)) {
then(ObservationScopeThreadLocalHolder.getCurrentObservation()).isSameAs(observation);
// TODO: This should work like this in the future
// then(ObservationScopeThreadLocalHolder.getCurrentObservation()).as("We're
// resetting the observation").isNull();
// then(ObservationScopeThreadLocalHolder.getValue()).as("This is the
// 'null' scope").isNotNull();
}
then(ObservationScopeThreadLocalHolder.getCurrentObservation()).as("We're back to previous observation")
.isSameAs(observation);
}
then(ObservationScopeThreadLocalHolder.getCurrentObservation()).as("There was no observation at the beginning")
.isNull();
}

}
Loading