Skip to content

Commit 9699898

Browse files
committed
Enhance thread-safety in ClientSideCaching key retrieval redis#2402
Added per-key locking mechanism to ensure that only a single thread fetches a value from Redis cache for a specific key. This prevents redundant Redis server calls under high concurrent load.
1 parent 966d8a9 commit 9699898

File tree

2 files changed

+197
-44
lines changed

2 files changed

+197
-44
lines changed

src/main/java/io/lettuce/core/support/caching/ClientSideCaching.java

+56-31
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
package io.lettuce.core.support.caching;
22

3-
import java.util.List;
4-
import java.util.concurrent.Callable;
5-
import java.util.concurrent.CopyOnWriteArrayList;
6-
import java.util.function.Consumer;
7-
83
import io.lettuce.core.StatefulRedisConnectionImpl;
94
import io.lettuce.core.TrackingArgs;
105
import io.lettuce.core.api.StatefulRedisConnection;
116
import io.lettuce.core.codec.RedisCodec;
7+
import java.util.List;
8+
import java.util.concurrent.Callable;
9+
import java.util.concurrent.ConcurrentHashMap;
10+
import java.util.concurrent.CopyOnWriteArrayList;
11+
import java.util.concurrent.locks.ReentrantLock;
12+
import java.util.function.Consumer;
1213

1314
/**
1415
* Utility to provide server-side assistance for client-side caches. This is a {@link CacheFrontend} that represents a two-level
@@ -41,6 +42,8 @@ public class ClientSideCaching<K, V> implements CacheFrontend<K, V> {
4142

4243
private final List<Consumer<K>> invalidationListeners = new CopyOnWriteArrayList<>();
4344

45+
private final ConcurrentHashMap<K, ReentrantLock> keyLocks = new ConcurrentHashMap<>();
46+
4447
private ClientSideCaching(CacheAccessor<K, V> cacheAccessor, RedisCache<K, V> redisCache) {
4548
this.cacheAccessor = cacheAccessor;
4649
this.redisCache = redisCache;
@@ -49,12 +52,12 @@ private ClientSideCaching(CacheAccessor<K, V> cacheAccessor, RedisCache<K, V> re
4952
/**
5053
* Enable server-assisted Client side caching for the given {@link CacheAccessor} and {@link StatefulRedisConnection}.
5154
* <p>
52-
* Note that the {@link CacheFrontend} is associated with a Redis connection. Make sure to {@link CacheFrontend#close()
53-
* close} the frontend object to release the Redis connection after use.
55+
* Note that the {@link CacheFrontend} is associated with a Redis connection. Make sure to
56+
* {@link CacheFrontend#close() close} the frontend object to release the Redis connection after use.
5457
*
5558
* @param cacheAccessor the accessor used to interact with the client-side cache.
5659
* @param connection the Redis connection to use. The connection will be associated with {@link CacheFrontend} and must be
57-
* closed through {@link CacheFrontend#close()}.
60+
* closed through {@link CacheFrontend#close()}.
5861
* @param tracking the tracking parameters.
5962
* @param <K> Key type.
6063
* @param <V> Value type.
@@ -72,12 +75,12 @@ public static <K, V> CacheFrontend<K, V> enable(CacheAccessor<K, V> cacheAccesso
7275
* Create a server-assisted Client side caching for the given {@link CacheAccessor} and {@link StatefulRedisConnection}.
7376
* This method expects that client key tracking is already configured.
7477
* <p>
75-
* Note that the {@link CacheFrontend} is associated with a Redis connection. Make sure to {@link CacheFrontend#close()
76-
* close} the frontend object to release the Redis connection after use.
78+
* Note that the {@link CacheFrontend} is associated with a Redis connection. Make sure to
79+
* {@link CacheFrontend#close() close} the frontend object to release the Redis connection after use.
7780
*
7881
* @param cacheAccessor the accessor used to interact with the client-side cache.
7982
* @param connection the Redis connection to use. The connection will be associated with {@link CacheFrontend} and must be
80-
* closed through {@link CacheFrontend#close()}.
83+
* closed through {@link CacheFrontend#close()}.
8184
* @param <K> Key type.
8285
* @param <V> Value type.
8386
* @return the {@link CacheFrontend} for value retrieval.
@@ -103,6 +106,7 @@ private static <K, V> CacheFrontend<K, V> create(CacheAccessor<K, V> cacheAccess
103106
}
104107

105108
private void notifyInvalidate(K key) {
109+
keyLocks.remove(key);
106110

107111
for (java.util.function.Consumer<K> invalidationListener : invalidationListeners) {
108112
invalidationListener.accept(key);
@@ -111,6 +115,7 @@ private void notifyInvalidate(K key) {
111115

112116
@Override
113117
public void close() {
118+
keyLocks.clear();
114119
redisCache.close();
115120
}
116121

@@ -124,10 +129,20 @@ public V get(K key) {
124129
V value = cacheAccessor.get(key);
125130

126131
if (value == null) {
127-
value = redisCache.get(key);
132+
ReentrantLock keyLock = keyLocks.computeIfAbsent(key, k -> new ReentrantLock());
133+
keyLock.lock();
134+
try {
135+
value = cacheAccessor.get(key);
136+
137+
if (value == null) {
138+
value = redisCache.get(key);
128139

129-
if (value != null) {
130-
cacheAccessor.put(key, value);
140+
if (value != null) {
141+
cacheAccessor.put(key, value);
142+
}
143+
}
144+
} finally {
145+
keyLock.unlock();
131146
}
132147
}
133148

@@ -140,28 +155,38 @@ public V get(K key, Callable<V> valueLoader) {
140155
V value = cacheAccessor.get(key);
141156

142157
if (value == null) {
143-
value = redisCache.get(key);
144-
145-
if (value == null) {
158+
ReentrantLock keyLock = keyLocks.computeIfAbsent(key, k -> new ReentrantLock());
159+
keyLock.lock();
146160

147-
try {
148-
value = valueLoader.call();
149-
} catch (Exception e) {
150-
throw new ValueRetrievalException(
151-
String.format("Value loader %s failed with an exception for key %s", valueLoader, key), e);
152-
}
161+
try {
162+
value = cacheAccessor.get(key);
153163

154164
if (value == null) {
155-
throw new ValueRetrievalException(
156-
String.format("Value loader %s returned a null value for key %s", valueLoader, key));
157-
}
158-
redisCache.put(key, value);
165+
value = redisCache.get(key);
159166

160-
// register interest in key
161-
redisCache.get(key);
162-
}
167+
if (value == null) {
168+
try {
169+
value = valueLoader.call();
170+
} catch (Exception e) {
171+
throw new ValueRetrievalException(
172+
String.format("Value loader %s failed with an exception for key %s", valueLoader, key), e);
173+
}
163174

164-
cacheAccessor.put(key, value);
175+
if (value == null) {
176+
throw new ValueRetrievalException(
177+
String.format("Value loader %s returned a null value for key %s", valueLoader, key));
178+
}
179+
180+
redisCache.put(key, value);
181+
182+
redisCache.get(key);
183+
}
184+
185+
cacheAccessor.put(key, value);
186+
}
187+
} finally {
188+
keyLock.unlock();
189+
}
165190
}
166191

167192
return value;

src/test/java/io/lettuce/core/support/caching/ClientsideCachingIntegrationTests.java

+141-13
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,6 @@
33
import static io.lettuce.TestTags.INTEGRATION_TEST;
44
import static org.assertj.core.api.Assertions.assertThat;
55

6-
import java.util.HashMap;
7-
import java.util.List;
8-
import java.util.Map;
9-
import java.util.concurrent.ConcurrentHashMap;
10-
import java.util.concurrent.CopyOnWriteArrayList;
11-
12-
import javax.inject.Inject;
13-
14-
import org.junit.jupiter.api.BeforeEach;
15-
import org.junit.jupiter.api.Tag;
16-
import org.junit.jupiter.api.Test;
17-
import org.junit.jupiter.api.extension.ExtendWith;
18-
196
import io.lettuce.core.ClientOptions;
207
import io.lettuce.core.RedisClient;
218
import io.lettuce.core.TestSupport;
@@ -29,11 +16,29 @@
2916
import io.lettuce.test.LettuceExtension;
3017
import io.lettuce.test.Wait;
3118
import io.lettuce.test.condition.EnabledOnCommand;
19+
import java.lang.reflect.Field;
20+
import java.util.HashMap;
21+
import java.util.List;
22+
import java.util.Map;
23+
import java.util.concurrent.ConcurrentHashMap;
24+
import java.util.concurrent.CopyOnWriteArrayList;
25+
import java.util.concurrent.CountDownLatch;
26+
import java.util.concurrent.ExecutorService;
27+
import java.util.concurrent.Executors;
28+
import java.util.concurrent.TimeUnit;
29+
import java.util.concurrent.atomic.AtomicInteger;
30+
import java.util.concurrent.locks.ReentrantLock;
31+
import javax.inject.Inject;
32+
import org.junit.jupiter.api.BeforeEach;
33+
import org.junit.jupiter.api.Tag;
34+
import org.junit.jupiter.api.Test;
35+
import org.junit.jupiter.api.extension.ExtendWith;
3236

3337
/**
3438
* Integration tests for server-side assisted cache invalidation.
3539
*
3640
* @author Mark Paluch
41+
* @author Yoobin Yoon
3742
*/
3843
@Tag(INTEGRATION_TEST)
3944
@ExtendWith(LettuceExtension.class)
@@ -227,4 +232,127 @@ void serverAssistedCachingShouldUseValueLoader() throws InterruptedException {
227232
frontend.close();
228233
}
229234

235+
@Test
236+
void valueLoaderShouldBeInvokedOnceForConcurrentRequests() throws Exception {
237+
238+
Map<String, String> clientCache = new ConcurrentHashMap<>();
239+
240+
StatefulRedisConnection<String, String> connection = redisClient.connect();
241+
242+
final String testKey = "concurrent-loader-key";
243+
connection.sync().del(testKey);
244+
245+
AtomicInteger loaderCallCount = new AtomicInteger(0);
246+
247+
CacheFrontend<String, String> frontend = ClientSideCaching.enable(CacheAccessor.forMap(clientCache), connection,
248+
TrackingArgs.Builder.enabled());
249+
250+
try {
251+
int threadCount = 10;
252+
CountDownLatch startLatch = new CountDownLatch(1);
253+
CountDownLatch finishLatch = new CountDownLatch(threadCount);
254+
List<String> results = new CopyOnWriteArrayList<>();
255+
256+
ExecutorService executor = Executors.newFixedThreadPool(threadCount);
257+
for (int i = 0; i < threadCount; i++) {
258+
executor.submit(() -> {
259+
try {
260+
startLatch.await();
261+
262+
String result = frontend.get(testKey, () -> {
263+
loaderCallCount.incrementAndGet();
264+
265+
try {
266+
Thread.sleep(100);
267+
} catch (InterruptedException e) {
268+
Thread.currentThread().interrupt();
269+
}
270+
271+
return "loaded-value";
272+
});
273+
274+
results.add(result);
275+
} catch (Exception e) {
276+
e.printStackTrace();
277+
} finally {
278+
finishLatch.countDown();
279+
}
280+
});
281+
}
282+
283+
startLatch.countDown();
284+
285+
finishLatch.await(5, TimeUnit.SECONDS);
286+
executor.shutdown();
287+
288+
assertThat(loaderCallCount.get()).isEqualTo(1);
289+
290+
assertThat(results).hasSize(threadCount);
291+
assertThat(results).containsOnly("loaded-value");
292+
293+
assertThat(connection.sync().get(testKey)).isEqualTo("loaded-value");
294+
295+
assertThat(clientCache).containsEntry(testKey, "loaded-value");
296+
} finally {
297+
frontend.close();
298+
connection.close();
299+
}
300+
}
301+
302+
@Test
303+
void locksShouldBeProperlyCleanedUp() throws Exception {
304+
305+
Map<String, String> clientCache = new ConcurrentHashMap<>();
306+
307+
StatefulRedisConnection<String, String> connection = redisClient.connect();
308+
StatefulRedisConnection<String, String> otherClient = redisClient.connect();
309+
310+
final String testKey1 = "lock-test-key1";
311+
final String testKey2 = "lock-test-key2";
312+
final String initialValue = "initial-value";
313+
final String updatedValue = "updated-value";
314+
315+
connection.sync().del(testKey1, testKey2);
316+
connection.sync().set(testKey1, initialValue);
317+
connection.sync().set(testKey2, initialValue);
318+
319+
ClientSideCaching<String, String> frontend = (ClientSideCaching<String, String>) ClientSideCaching.enable(
320+
CacheAccessor.forMap(clientCache), connection, TrackingArgs.Builder.enabled());
321+
322+
Field keyLocksField = ClientSideCaching.class.getDeclaredField("keyLocks");
323+
keyLocksField.setAccessible(true);
324+
ConcurrentHashMap<String, ReentrantLock> keyLocks = (ConcurrentHashMap<String, ReentrantLock>) keyLocksField.get(
325+
frontend);
326+
327+
try {
328+
frontend.get(testKey1);
329+
frontend.get(testKey2);
330+
331+
assertThat(keyLocks).containsKey(testKey1);
332+
assertThat(keyLocks).containsKey(testKey2);
333+
assertThat(keyLocks).hasSize(2);
334+
335+
otherClient.sync().set(testKey1, updatedValue);
336+
337+
Thread.sleep(200);
338+
339+
assertThat(keyLocks).doesNotContainKey(testKey1);
340+
assertThat(keyLocks).containsKey(testKey2);
341+
assertThat(keyLocks).hasSize(1);
342+
343+
frontend.get(testKey1);
344+
345+
assertThat(keyLocks).containsKey(testKey1);
346+
assertThat(keyLocks).hasSize(2);
347+
348+
frontend.close();
349+
350+
assertThat(keyLocks).isEmpty();
351+
352+
} finally {
353+
connection.close();
354+
otherClient.close();
355+
}
356+
}
357+
230358
}

0 commit comments

Comments
 (0)