diff --git a/android/guava-tests/test/com/google/common/cache/LocalCacheTest.java b/android/guava-tests/test/com/google/common/cache/LocalCacheTest.java index 431c962f4792..cd1d340d5ce1 100644 --- a/android/guava-tests/test/com/google/common/cache/LocalCacheTest.java +++ b/android/guava-tests/test/com/google/common/cache/LocalCacheTest.java @@ -58,6 +58,7 @@ import com.google.common.testing.NullPointerTester; import com.google.common.testing.SerializableTester; import com.google.common.testing.TestLogHandler; +import com.google.common.util.concurrent.UncheckedExecutionException; import java.io.Serializable; import java.lang.ref.Reference; import java.lang.ref.ReferenceQueue; @@ -2639,8 +2640,86 @@ public void testSerializationProxyManual() { assertEquals(localCacheTwo.ticker, localCacheThree.ticker); } + public void testLoadDifferentKeyInLoader() throws ExecutionException, InterruptedException { + LocalCache cache = makeLocalCache(createCacheBuilder()); + String key1 = "key1"; + String key2 = "key2"; + + assertEquals( + key2, + cache.get( + key1, + new CacheLoader() { + @Override + public String load(String key) throws Exception { + return cache.get(key2, identityLoader()); // loads a different key, should work + } + })); + } + + public void testRecursiveLoad() throws InterruptedException { + LocalCache cache = makeLocalCache(createCacheBuilder()); + String key = "key"; + CacheLoader loader = + new CacheLoader() { + @Override + public String load(String key) throws Exception { + return cache.get(key, identityLoader()); // recursive load, this should fail + } + }; + testLoadThrows(key, cache, loader); + } + + public void testRecursiveLoadWithProxy() throws InterruptedException { + String key = "key"; + String otherKey = "otherKey"; + LocalCache cache = makeLocalCache(createCacheBuilder()); + CacheLoader loader = + new CacheLoader() { + @Override + public String load(String key) throws Exception { + return cache.get( + key, + identityLoader()); // recursive load (same as the initial one), this should fail + } + }; + CacheLoader proxyLoader = + new CacheLoader() { + @Override + public String load(String key) throws Exception { + return cache.get(otherKey, loader); // loads another key, is ok + } + }; + testLoadThrows(key, cache, proxyLoader); + } + // utility methods + private void testLoadThrows( + String key, LocalCache cache, CacheLoader loader) + throws InterruptedException { + CountDownLatch doneSignal = new CountDownLatch(1); + Thread thread = + new Thread( + () -> { + try { + cache.get(key, loader); + } catch (UncheckedExecutionException | ExecutionException e) { + doneSignal.countDown(); + } + }); + thread.start(); + + boolean done = doneSignal.await(1, TimeUnit.SECONDS); + if (!done) { + StringBuilder builder = new StringBuilder(); + for (StackTraceElement trace : thread.getStackTrace()) { + builder.append("\tat ").append(trace).append('\n'); + } + fail(builder.toString()); + } + } + /** * Returns an iterable containing all combinations of maximumSize, expireAfterAccess/Write, * weakKeys and weak/softValues. diff --git a/android/guava/src/com/google/common/cache/LocalCache.java b/android/guava/src/com/google/common/cache/LocalCache.java index f0a7699e0700..f4a8d850759b 100644 --- a/android/guava/src/com/google/common/cache/LocalCache.java +++ b/android/guava/src/com/google/common/cache/LocalCache.java @@ -2180,12 +2180,7 @@ V lockedGetOrLoad(K key, int hash, CacheLoader loader) throws Exec if (createNewEntry) { try { - // Synchronizes on the entry to allow failing fast when a recursive load is - // detected. This may be circumvented when an entry is copied, but will fail fast most - // of the time. - synchronized (e) { - return loadSync(key, hash, loadingValueReference, loader); - } + return loadSync(key, hash, loadingValueReference, loader); } finally { statsCounter.recordMisses(1); } @@ -2201,7 +2196,20 @@ V waitForLoadingValue(ReferenceEntry e, K key, ValueReference valueR throw new AssertionError(); } - checkState(!Thread.holdsLock(e), "Recursive load of: %s", key); + // As of this writing, the only prod ValueReference implementation for which isLoading() is + // true is LoadingValueReference. However, that might change, and we already have a *test* + // implementation for which it doesn't hold. So we check instanceof to be safe. + if (valueReference instanceof LoadingValueReference) { + // We check whether the thread that is loading the entry is our current thread, which would + // mean that we both load and wait for the entry. In this case, we fail fast instead of + // deadlocking. + checkState( + ((LoadingValueReference) valueReference).getLoadingThread() + != Thread.currentThread(), + "Recursive load of: %s", + key); + } + // don't consider expiration as we're concurrent with loading try { V value = valueReference.waitForValue(); @@ -3427,6 +3435,8 @@ static class LoadingValueReference implements ValueReference { final SettableFuture futureValue = SettableFuture.create(); final Stopwatch stopwatch = Stopwatch.createUnstarted(); + final Thread loadingThread; + public LoadingValueReference() { this(LocalCache.unset()); } @@ -3438,6 +3448,7 @@ public LoadingValueReference() { */ public LoadingValueReference(ValueReference oldValue) { this.oldValue = oldValue; + this.loadingThread = Thread.currentThread(); } @Override @@ -3541,6 +3552,10 @@ public ValueReference copyFor( ReferenceQueue queue, @CheckForNull V value, ReferenceEntry entry) { return this; } + + Thread getLoadingThread() { + return this.loadingThread; + } } // Queues diff --git a/guava-tests/test/com/google/common/cache/LocalCacheTest.java b/guava-tests/test/com/google/common/cache/LocalCacheTest.java index 7cc67e840dc2..a16c0ab37039 100644 --- a/guava-tests/test/com/google/common/cache/LocalCacheTest.java +++ b/guava-tests/test/com/google/common/cache/LocalCacheTest.java @@ -58,6 +58,7 @@ import com.google.common.testing.NullPointerTester; import com.google.common.testing.SerializableTester; import com.google.common.testing.TestLogHandler; +import com.google.common.util.concurrent.UncheckedExecutionException; import java.io.Serializable; import java.lang.ref.Reference; import java.lang.ref.ReferenceQueue; @@ -2688,8 +2689,86 @@ public void testSerializationProxyManual() { assertEquals(localCacheTwo.ticker, localCacheThree.ticker); } + public void testLoadDifferentKeyInLoader() throws ExecutionException, InterruptedException { + LocalCache cache = makeLocalCache(createCacheBuilder()); + String key1 = "key1"; + String key2 = "key2"; + + assertEquals( + key2, + cache.get( + key1, + new CacheLoader() { + @Override + public String load(String key) throws Exception { + return cache.get(key2, identityLoader()); // loads a different key, should work + } + })); + } + + public void testRecursiveLoad() throws InterruptedException { + LocalCache cache = makeLocalCache(createCacheBuilder()); + String key = "key"; + CacheLoader loader = + new CacheLoader() { + @Override + public String load(String key) throws Exception { + return cache.get(key, identityLoader()); // recursive load, this should fail + } + }; + testLoadThrows(key, cache, loader); + } + + public void testRecursiveLoadWithProxy() throws InterruptedException { + String key = "key"; + String otherKey = "otherKey"; + LocalCache cache = makeLocalCache(createCacheBuilder()); + CacheLoader loader = + new CacheLoader() { + @Override + public String load(String key) throws Exception { + return cache.get( + key, + identityLoader()); // recursive load (same as the initial one), this should fail + } + }; + CacheLoader proxyLoader = + new CacheLoader() { + @Override + public String load(String key) throws Exception { + return cache.get(otherKey, loader); // loads another key, is ok + } + }; + testLoadThrows(key, cache, proxyLoader); + } + // utility methods + private void testLoadThrows( + String key, LocalCache cache, CacheLoader loader) + throws InterruptedException { + CountDownLatch doneSignal = new CountDownLatch(1); + Thread thread = + new Thread( + () -> { + try { + cache.get(key, loader); + } catch (UncheckedExecutionException | ExecutionException e) { + doneSignal.countDown(); + } + }); + thread.start(); + + boolean done = doneSignal.await(1, TimeUnit.SECONDS); + if (!done) { + StringBuilder builder = new StringBuilder(); + for (StackTraceElement trace : thread.getStackTrace()) { + builder.append("\tat ").append(trace).append('\n'); + } + fail(builder.toString()); + } + } + /** * Returns an iterable containing all combinations of maximumSize, expireAfterAccess/Write, * weakKeys and weak/softValues. diff --git a/guava/src/com/google/common/cache/LocalCache.java b/guava/src/com/google/common/cache/LocalCache.java index a38543dbb7b6..f6ac7b40b241 100644 --- a/guava/src/com/google/common/cache/LocalCache.java +++ b/guava/src/com/google/common/cache/LocalCache.java @@ -2184,12 +2184,7 @@ V lockedGetOrLoad(K key, int hash, CacheLoader loader) throws Exec if (createNewEntry) { try { - // Synchronizes on the entry to allow failing fast when a recursive load is - // detected. This may be circumvented when an entry is copied, but will fail fast most - // of the time. - synchronized (e) { - return loadSync(key, hash, loadingValueReference, loader); - } + return loadSync(key, hash, loadingValueReference, loader); } finally { statsCounter.recordMisses(1); } @@ -2205,7 +2200,20 @@ V waitForLoadingValue(ReferenceEntry e, K key, ValueReference valueR throw new AssertionError(); } - checkState(!Thread.holdsLock(e), "Recursive load of: %s", key); + // As of this writing, the only prod ValueReference implementation for which isLoading() is + // true is LoadingValueReference. However, that might change, and we already have a *test* + // implementation for which it doesn't hold. So we check instanceof to be safe. + if (valueReference instanceof LoadingValueReference) { + // We check whether the thread that is loading the entry is our current thread, which would + // mean that we both load and wait for the entry. In this case, we fail fast instead of + // deadlocking. + checkState( + ((LoadingValueReference) valueReference).getLoadingThread() + != Thread.currentThread(), + "Recursive load of: %s", + key); + } + // don't consider expiration as we're concurrent with loading try { V value = valueReference.waitForValue(); @@ -3517,12 +3525,15 @@ static class LoadingValueReference implements ValueReference { final SettableFuture futureValue = SettableFuture.create(); final Stopwatch stopwatch = Stopwatch.createUnstarted(); + final Thread loadingThread; + public LoadingValueReference() { this(null); } public LoadingValueReference(@CheckForNull ValueReference oldValue) { this.oldValue = (oldValue == null) ? LocalCache.unset() : oldValue; + this.loadingThread = Thread.currentThread(); } @Override @@ -3647,6 +3658,10 @@ public ValueReference copyFor( ReferenceQueue queue, @CheckForNull V value, ReferenceEntry entry) { return this; } + + Thread getLoadingThread() { + return this.loadingThread; + } } static class ComputingValueReference extends LoadingValueReference {