Skip to content

Commit

Permalink
Fix multithreading copies in lib vec (#108802)
Browse files Browse the repository at this point in the history
This commit fixes a potential multithreading issue with the lib vec
vector scorer. 

Since the implementation falls back to a Lucene scorer, which needs to
read from the index input, we need to make a copy of the index
input. Otherwise, there is a potential for the stateful index input to
be accessed across threads - which would be bad.

The fallback is only used when one or other vector crosses a segment
boundary, which is 16G by default. So the likelihood of this occurring
in practice is small, but the effect is bad.

The fix is deliberately small and targeted, so that it can be
backported. After this change, I'm going to drop the custom VectorScorer
and adapter type, in favour of using the Lucene type directly. These
custom types were initially used when the code lived inside the native
module, where we didn't want to add a dependency on Lucene directly.
  • Loading branch information
ChrisHegarty committed May 19, 2024
1 parent c59322e commit a7e4423
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 1 deletion.
5 changes: 5 additions & 0 deletions docs/changelog/108802.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 108802
summary: Fix multithreading copies in lib vec
area: Vector Search
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,6 @@ public interface VectorScorer {
/** Returns the maximum ordinal of the vectors that this scorer can score. */
int maxOrd();

/**
 * Returns a copy of this scorer.
 *
 * <p>NOTE(review): per the accompanying change description, copies exist so that a scorer
 * handed to another thread does not share a stateful index input with the original -
 * implementations should confirm they honour that.
 */
VectorScorer copy();

}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,6 @@ public int maxOrd() {

/**
 * Returns a new supplier that wraps a copy of the underlying scorer.
 *
 * <p>The adapter must not return {@code this}: the wrapped scorer can fall back to a scorer
 * that reads from an index input, so every copy of the supplier needs its own copy of the
 * scorer rather than sharing one instance.
 *
 * @return a new adapter around {@code scorer.copy()}
 * @throws IOException declared by the overridden method
 */
@Override
public RandomVectorScorerSupplier copy() throws IOException {
    return new VectorScorerSupplierAdapter(scorer.copy());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,9 @@ public float score(int firstOrd, int secondOrd) throws IOException {
return Math.max(fallbackScore(firstByteOffset, secondByteOffset), 0f);
}
}

/**
 * Creates a new {@code Int7DotProduct} with the same dimensions, maximum ordinal and score
 * correction constant as this one, backed by a clone of this scorer's input.
 */
@Override
public Int7DotProduct copy() {
    // each copy gets its own clone of the input rather than sharing this instance's one
    var clonedInput = input.clone();
    return new Int7DotProduct(dims, maxOrd, scoreCorrectionConstant, clonedInput);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,9 @@ public float score(int firstOrd, int secondOrd) throws IOException {
return fallbackScore(firstByteOffset, secondByteOffset);
}
}

/**
 * Creates a new {@code Int7Euclidean} with the same dimensions, maximum ordinal and score
 * correction constant as this one, backed by a clone of this scorer's input.
 */
@Override
public Int7Euclidean copy() {
    // each copy gets its own clone of the input rather than sharing this instance's one
    var clonedInput = input.clone();
    return new Int7Euclidean(dims, maxOrd, scoreCorrectionConstant, clonedInput);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,9 @@ static float scaleMaxInnerProductScore(float rawSimilarity) {
}
return rawSimilarity + 1;
}

/**
 * Creates a new {@code Int7MaximumInnerProduct} with the same dimensions, maximum ordinal and
 * score correction constant as this one, backed by a clone of this scorer's input.
 */
@Override
public Int7MaximumInnerProduct copy() {
    // each copy gets its own clone of the input rather than sharing this instance's one
    var clonedInput = input.clone();
    return new Int7MaximumInnerProduct(dims, maxOrd, scoreCorrectionConstant, clonedInput);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,23 @@
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.hnsw.RandomVectorScorer;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.IntStream;

import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
import static org.elasticsearch.vec.VectorSimilarityType.COSINE;
import static org.elasticsearch.vec.VectorSimilarityType.DOT_PRODUCT;
import static org.elasticsearch.vec.VectorSimilarityType.EUCLIDEAN;
Expand Down Expand Up @@ -327,6 +336,78 @@ public void testLarge() throws IOException {
}
}

/** Runs the copy race check once for each of the four similarity types, in a fixed order. */
public void testRace() throws Exception {
    var similarities = List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT);
    for (var similarity : similarities) {
        testRaceImpl(similarity);
    }
}

// Tests that copies in threads do not interfere with each other.
//
// Four vectors (two copies of vec1 followed by two copies of vec2, each followed by a float
// correction value) are written to a directory whose chunk size is smaller than a vector.
// Two scorer copies are then run concurrently: one scores ord 0 against ord 1, the other
// scores ord 2 against ord 3, and each must keep seeing its own expected score.
void testRaceImpl(VectorSimilarityType sim) throws Exception {
    assumeTrue(notSupportedMsg(), supported());
    var factory = AbstractVectorTestCase.factory.get();

    final long maxChunkSize = 32;
    final int dims = 34; // dimensions that are larger than the chunk size, to force fallback
    byte[] vec1 = new byte[dims];
    byte[] vec2 = new byte[dims];
    Arrays.fill(vec1, (byte) 1);
    Arrays.fill(vec2, (byte) 2);
    try (Directory dir = new MMapDirectory(createTempDir("testRace"), maxChunkSize)) {
        String fileName = getTestName() + "-" + dims;
        try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
            var one = floatToByteArray(1f);
            byte[] bytes = concat(vec1, one, vec1, one, vec2, one, vec2, one);
            out.writeBytes(bytes, 0, bytes.length);
        }
        var expectedScore1 = luceneScore(sim, vec1, vec1, 1, 1, 1);
        var expectedScore2 = luceneScore(sim, vec2, vec2, 1, 1, 1);

        try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
            var scoreSupplier = factory.getInt7ScalarQuantizedVectorScorer(dims, 4, 1, sim, in).get();
            var scorer = new VectorScorerSupplierAdapter(scoreSupplier);
            var tasks = List.<Callable<Optional<Throwable>>>of(
                new ScoreCallable(scorer.copy().scorer(0), 1, expectedScore1),
                new ScoreCallable(scorer.copy().scorer(2), 3, expectedScore2)
            );
            var executor = Executors.newFixedThreadPool(2);
            List<Future<Optional<Throwable>>> results;
            try {
                results = executor.invokeAll(tasks);
            } finally {
                // always release the pool threads, even if invokeAll throws or is interrupted
                executor.shutdown();
            }
            assertTrue(executor.awaitTermination(60, TimeUnit.SECONDS));
            assertThat(results.stream().filter(Predicate.not(Future::isDone)).count(), equalTo(0L));
            for (var res : results) {
                Optional<Throwable> failure = res.get();
                assertThat("Unexpected exception: " + failure, failure, isEmpty());
            }
        }
    }
}

/**
 * A task that repeatedly scores a fixed ordinal with the given scorer and checks that every
 * result equals the expected score. Any throwable raised while doing so (including assertion
 * failures) is captured and returned instead of being propagated.
 */
static class ScoreCallable implements Callable<Optional<Throwable>> {

    final RandomVectorScorer scorer;
    final int ord;
    final float expectedScore;

    ScoreCallable(RandomVectorScorer scorer, int ord, float expectedScore) {
        this.scorer = scorer;
        this.ord = ord;
        this.expectedScore = expectedScore;
    }

    /**
     * Scores {@code ord} 100 times, asserting the expected score each time.
     *
     * @return the first throwable encountered, or an empty optional if all iterations passed
     */
    @Override
    public Optional<Throwable> call() throws Exception {
        Throwable failure = null;
        try {
            int iteration = 0;
            while (iteration < 100) {
                assertThat(scorer.score(ord), equalTo(expectedScore));
                iteration++;
            }
        } catch (Throwable caught) {
            failure = caught;
        }
        return Optional.ofNullable(failure);
    }
}

// creates the vector based on the given ordinal, which is reproducible given the ord and dims
static byte[] vector(int ord, int dims) {
var random = new Random(Objects.hash(ord, dims));
Expand Down

0 comments on commit a7e4423

Please sign in to comment.