Skip to content

Commit

Permalink
Make LocalCache not use synchronized to detect recursive loads.
Browse files Browse the repository at this point in the history
Fixes #6851
Fixes #6845

RELNOTES=n/a
PiperOrigin-RevId: 586666218
  • Loading branch information
Christian Ortlepp authored and Google Java Core Libraries committed Dec 1, 2023
1 parent f4c1264 commit 4117d35
Show file tree
Hide file tree
Showing 4 changed files with 206 additions and 14 deletions.
Expand Up @@ -58,6 +58,7 @@
import com.google.common.testing.NullPointerTester;
import com.google.common.testing.SerializableTester;
import com.google.common.testing.TestLogHandler;
import com.google.common.util.concurrent.UncheckedExecutionException;
import java.io.Serializable;
import java.lang.ref.Reference;
import java.lang.ref.ReferenceQueue;
Expand Down Expand Up @@ -2639,8 +2640,86 @@ public void testSerializationProxyManual() {
assertEquals(localCacheTwo.ticker, localCacheThree.ticker);
}

public void testLoadDifferentKeyInLoader() throws ExecutionException, InterruptedException {
LocalCache<String, String> cache = makeLocalCache(createCacheBuilder());
String key1 = "key1";
String key2 = "key2";

assertEquals(
key2,
cache.get(
key1,
new CacheLoader<String, String>() {
@Override
public String load(String key) throws Exception {
return cache.get(key2, identityLoader()); // loads a different key, should work
}
}));
}

public void testRecursiveLoad() throws InterruptedException {
LocalCache<String, String> cache = makeLocalCache(createCacheBuilder());
String key = "key";
CacheLoader<String, String> loader =
new CacheLoader<String, String>() {
@Override
public String load(String key) throws Exception {
return cache.get(key, identityLoader()); // recursive load, this should fail
}
};
testLoadThrows(key, cache, loader);
}

public void testRecursiveLoadWithProxy() throws InterruptedException {
String key = "key";
String otherKey = "otherKey";
LocalCache<String, String> cache = makeLocalCache(createCacheBuilder());
CacheLoader<String, String> loader =
new CacheLoader<String, String>() {
@Override
public String load(String key) throws Exception {
return cache.get(
key,
identityLoader()); // recursive load (same as the initial one), this should fail
}
};
CacheLoader<String, String> proxyLoader =
new CacheLoader<String, String>() {
@Override
public String load(String key) throws Exception {
return cache.get(otherKey, loader); // loads another key, is ok
}
};
testLoadThrows(key, cache, proxyLoader);
}

// utility methods

private void testLoadThrows(
String key, LocalCache<String, String> cache, CacheLoader<String, String> loader)
throws InterruptedException {
CountDownLatch doneSignal = new CountDownLatch(1);
Thread thread =
new Thread(
() -> {
try {
cache.get(key, loader);
} catch (UncheckedExecutionException | ExecutionException e) {
doneSignal.countDown();
}
});
thread.start();

boolean done = doneSignal.await(1, TimeUnit.SECONDS);
if (!done) {
StringBuilder builder = new StringBuilder();
for (StackTraceElement trace : thread.getStackTrace()) {
builder.append("\tat ").append(trace).append('\n');
}
fail(builder.toString());
}
}

/**
* Returns an iterable containing all combinations of maximumSize, expireAfterAccess/Write,
* weakKeys and weak/softValues.
Expand Down
31 changes: 24 additions & 7 deletions android/guava/src/com/google/common/cache/LocalCache.java
Expand Up @@ -2180,12 +2180,7 @@ V lockedGetOrLoad(K key, int hash, CacheLoader<? super K, V> loader) throws Exec

if (createNewEntry) {
try {
// Synchronizes on the entry to allow failing fast when a recursive load is
// detected. This may be circumvented when an entry is copied, but will fail fast most
// of the time.
synchronized (e) {
return loadSync(key, hash, loadingValueReference, loader);
}
return loadSync(key, hash, loadingValueReference, loader);
} finally {
statsCounter.recordMisses(1);
}
Expand All @@ -2201,7 +2196,22 @@ V waitForLoadingValue(ReferenceEntry<K, V> e, K key, ValueReference<K, V> valueR
throw new AssertionError();
}

checkState(!Thread.holdsLock(e), "Recursive load of: %s", key);
// As of this writing, the only prod ValueReference implementation for which isLoading() is
// true is LoadingValueReference. (Note, however, that not all LoadingValueReference instances
// have isLoading()==true: LoadingValueReference has a subclass, ComputingValueReference, for
// which isLoading() is false!) However, that might change, and we already have a *test*
// implementation for which it doesn't hold. So we check instanceof to be safe.
if (valueReference instanceof LoadingValueReference) {
// We check whether the thread that is loading the entry is our current thread, which would
// mean that we both load and wait for the entry. In this case, we fail fast instead of
// deadlocking.
checkState(
((LoadingValueReference<K, V>) valueReference).getLoadingThread()
!= Thread.currentThread(),
"Recursive load of: %s",
key);
}

// don't consider expiration as we're concurrent with loading
try {
V value = valueReference.waitForValue();
Expand Down Expand Up @@ -3427,6 +3437,8 @@ static class LoadingValueReference<K, V> implements ValueReference<K, V> {
final SettableFuture<V> futureValue = SettableFuture.create();
final Stopwatch stopwatch = Stopwatch.createUnstarted();

final Thread loadingThread;

public LoadingValueReference() {
this(LocalCache.<K, V>unset());
}
Expand All @@ -3438,6 +3450,7 @@ public LoadingValueReference() {
*/
public LoadingValueReference(ValueReference<K, V> oldValue) {
this.oldValue = oldValue;
this.loadingThread = Thread.currentThread();
}

@Override
Expand Down Expand Up @@ -3541,6 +3554,10 @@ public ValueReference<K, V> copyFor(
ReferenceQueue<V> queue, @CheckForNull V value, ReferenceEntry<K, V> entry) {
return this;
}

Thread getLoadingThread() {
return this.loadingThread;
}
}

// Queues
Expand Down
79 changes: 79 additions & 0 deletions guava-tests/test/com/google/common/cache/LocalCacheTest.java
Expand Up @@ -58,6 +58,7 @@
import com.google.common.testing.NullPointerTester;
import com.google.common.testing.SerializableTester;
import com.google.common.testing.TestLogHandler;
import com.google.common.util.concurrent.UncheckedExecutionException;
import java.io.Serializable;
import java.lang.ref.Reference;
import java.lang.ref.ReferenceQueue;
Expand Down Expand Up @@ -2688,8 +2689,86 @@ public void testSerializationProxyManual() {
assertEquals(localCacheTwo.ticker, localCacheThree.ticker);
}

public void testLoadDifferentKeyInLoader() throws ExecutionException, InterruptedException {
LocalCache<String, String> cache = makeLocalCache(createCacheBuilder());
String key1 = "key1";
String key2 = "key2";

assertEquals(
key2,
cache.get(
key1,
new CacheLoader<String, String>() {
@Override
public String load(String key) throws Exception {
return cache.get(key2, identityLoader()); // loads a different key, should work
}
}));
}

public void testRecursiveLoad() throws InterruptedException {
LocalCache<String, String> cache = makeLocalCache(createCacheBuilder());
String key = "key";
CacheLoader<String, String> loader =
new CacheLoader<String, String>() {
@Override
public String load(String key) throws Exception {
return cache.get(key, identityLoader()); // recursive load, this should fail
}
};
testLoadThrows(key, cache, loader);
}

public void testRecursiveLoadWithProxy() throws InterruptedException {
String key = "key";
String otherKey = "otherKey";
LocalCache<String, String> cache = makeLocalCache(createCacheBuilder());
CacheLoader<String, String> loader =
new CacheLoader<String, String>() {
@Override
public String load(String key) throws Exception {
return cache.get(
key,
identityLoader()); // recursive load (same as the initial one), this should fail
}
};
CacheLoader<String, String> proxyLoader =
new CacheLoader<String, String>() {
@Override
public String load(String key) throws Exception {
return cache.get(otherKey, loader); // loads another key, is ok
}
};
testLoadThrows(key, cache, proxyLoader);
}

// utility methods

private void testLoadThrows(
String key, LocalCache<String, String> cache, CacheLoader<String, String> loader)
throws InterruptedException {
CountDownLatch doneSignal = new CountDownLatch(1);
Thread thread =
new Thread(
() -> {
try {
cache.get(key, loader);
} catch (UncheckedExecutionException | ExecutionException e) {
doneSignal.countDown();
}
});
thread.start();

boolean done = doneSignal.await(1, TimeUnit.SECONDS);
if (!done) {
StringBuilder builder = new StringBuilder();
for (StackTraceElement trace : thread.getStackTrace()) {
builder.append("\tat ").append(trace).append('\n');
}
fail(builder.toString());
}
}

/**
* Returns an iterable containing all combinations of maximumSize, expireAfterAccess/Write,
* weakKeys and weak/softValues.
Expand Down
31 changes: 24 additions & 7 deletions guava/src/com/google/common/cache/LocalCache.java
Expand Up @@ -2184,12 +2184,7 @@ V lockedGetOrLoad(K key, int hash, CacheLoader<? super K, V> loader) throws Exec

if (createNewEntry) {
try {
// Synchronizes on the entry to allow failing fast when a recursive load is
// detected. This may be circumvented when an entry is copied, but will fail fast most
// of the time.
synchronized (e) {
return loadSync(key, hash, loadingValueReference, loader);
}
return loadSync(key, hash, loadingValueReference, loader);
} finally {
statsCounter.recordMisses(1);
}
Expand All @@ -2205,7 +2200,22 @@ V waitForLoadingValue(ReferenceEntry<K, V> e, K key, ValueReference<K, V> valueR
throw new AssertionError();
}

checkState(!Thread.holdsLock(e), "Recursive load of: %s", key);
// As of this writing, the only prod ValueReference implementation for which isLoading() is
// true is LoadingValueReference. (Note, however, that not all LoadingValueReference instances
// have isLoading()==true: LoadingValueReference has a subclass, ComputingValueReference, for
// which isLoading() is false!) However, that might change, and we already have a *test*
// implementation for which it doesn't hold. So we check instanceof to be safe.
if (valueReference instanceof LoadingValueReference) {
// We check whether the thread that is loading the entry is our current thread, which would
// mean that we both load and wait for the entry. In this case, we fail fast instead of
// deadlocking.
checkState(
((LoadingValueReference<K, V>) valueReference).getLoadingThread()
!= Thread.currentThread(),
"Recursive load of: %s",
key);
}

// don't consider expiration as we're concurrent with loading
try {
V value = valueReference.waitForValue();
Expand Down Expand Up @@ -3517,12 +3527,15 @@ static class LoadingValueReference<K, V> implements ValueReference<K, V> {
final SettableFuture<V> futureValue = SettableFuture.create();
final Stopwatch stopwatch = Stopwatch.createUnstarted();

final Thread loadingThread;

public LoadingValueReference() {
this(null);
}

public LoadingValueReference(@CheckForNull ValueReference<K, V> oldValue) {
this.oldValue = (oldValue == null) ? LocalCache.unset() : oldValue;
this.loadingThread = Thread.currentThread();
}

@Override
Expand Down Expand Up @@ -3647,6 +3660,10 @@ public ValueReference<K, V> copyFor(
ReferenceQueue<V> queue, @CheckForNull V value, ReferenceEntry<K, V> entry) {
return this;
}

Thread getLoadingThread() {
return this.loadingThread;
}
}

static class ComputingValueReference<K, V> extends LoadingValueReference<K, V> {
Expand Down

0 comments on commit 4117d35

Please sign in to comment.