Skip to content

Commit

Permalink
[8.14] Fix multithreading copies in lib vec (#108802) (#108810)
Browse files Browse the repository at this point in the history
Backport of:  * #108802
  • Loading branch information
ChrisHegarty committed May 19, 2024
1 parent 9f0690d commit f83f0bc
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 1 deletion.
5 changes: 5 additions & 0 deletions docs/changelog/108802.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 108802
summary: Fix multithreading copies in lib vec
area: Vector Search
type: bug
issues: []
5 changes: 5 additions & 0 deletions docs/changelog/108810.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 108810
summary: "[8.14] Fix multithreading copies in lib vec"
area: Vector Search
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,6 @@ public interface VectorScorer {
/** The maximum ordinal of vector this scorer can score. */
int maxOrd();

VectorScorer copy();

}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,6 @@ public int maxOrd() {

@Override
public RandomVectorScorerSupplier copy() throws IOException {
return this; // no need to copy, thread-safe
return new VectorScorerSupplierAdapter(scorer.copy());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,9 @@ public float score(int firstOrd, int secondOrd) throws IOException {
return fallbackScore(firstByteOffset, secondByteOffset);
}
}

@Override
public DotProduct copy() {
return new DotProduct(dims, maxOrd, scoreCorrectionConstant, input.clone());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,9 @@ public float score(int firstOrd, int secondOrd) throws IOException {
return fallbackScore(firstByteOffset, secondByteOffset);
}
}

@Override
public Euclidean copy() {
return new Euclidean(dims, maxOrd, scoreCorrectionConstant, input.clone());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,9 @@ static float scaleMaxInnerProductScore(float rawSimilarity) {
}
return rawSimilarity + 1;
}

@Override
public MaximumInnerProduct copy() {
return new MaximumInnerProduct(dims, maxOrd, scoreCorrectionConstant, input.clone());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,24 @@
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.elasticsearch.test.ESTestCase;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.IntStream;

import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
import static org.elasticsearch.vec.VectorSimilarityType.COSINE;
import static org.elasticsearch.vec.VectorSimilarityType.DOT_PRODUCT;
import static org.elasticsearch.vec.VectorSimilarityType.EUCLIDEAN;
Expand Down Expand Up @@ -278,6 +287,78 @@ public void testLarge() throws IOException {
}
}

public void testRace() throws Exception {
testRaceImpl(COSINE);
testRaceImpl(DOT_PRODUCT);
testRaceImpl(EUCLIDEAN);
testRaceImpl(MAXIMUM_INNER_PRODUCT);
}

// Tests that copies in threads do not interfere with each other
void testRaceImpl(VectorSimilarityType sim) throws Exception {
assumeTrue(notSupportedMsg(), supported());
var factory = AbstractVectorTestCase.factory.get();

final long maxChunkSize = 32;
final int dims = 34; // dimensions that are larger than the chunk size, to force fallback
byte[] vec1 = new byte[dims];
byte[] vec2 = new byte[dims];
IntStream.range(0, dims).forEach(i -> vec1[i] = 1);
IntStream.range(0, dims).forEach(i -> vec2[i] = 2);
try (Directory dir = new MMapDirectory(createTempDir("testRace"), maxChunkSize)) {
String fileName = getTestName() + "-" + dims;
try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
var one = floatToByteArray(1f);
byte[] bytes = concat(vec1, one, vec1, one, vec2, one, vec2, one);
out.writeBytes(bytes, 0, bytes.length);
}
var expectedScore1 = luceneScore(sim, vec1, vec1, 1, 1, 1);
var expectedScore2 = luceneScore(sim, vec2, vec2, 1, 1, 1);

try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
var scoreSupplier = factory.getScalarQuantizedVectorScorer(dims, 4, 1, sim, in).get();
var scorer = new VectorScorerSupplierAdapter(scoreSupplier);
var tasks = List.<Callable<Optional<Throwable>>>of(
new ScoreCallable(scorer.copy().scorer(0), 1, expectedScore1),
new ScoreCallable(scorer.copy().scorer(2), 3, expectedScore2)
);
var executor = Executors.newFixedThreadPool(2);
var results = executor.invokeAll(tasks);
executor.shutdown();
assertTrue(executor.awaitTermination(60, TimeUnit.SECONDS));
assertThat(results.stream().filter(Predicate.not(Future::isDone)).count(), equalTo(0L));
for (var res : results) {
assertThat("Unexpected exception" + res.get(), res.get(), isEmpty());
}
}
}
}

static class ScoreCallable implements Callable<Optional<Throwable>> {

final RandomVectorScorer scorer;
final int ord;
final float expectedScore;

ScoreCallable(RandomVectorScorer scorer, int ord, float expectedScore) {
this.scorer = scorer;
this.ord = ord;
this.expectedScore = expectedScore;
}

@Override
public Optional<Throwable> call() throws Exception {
try {
for (int i = 0; i < 100; i++) {
assertThat(scorer.score(ord), equalTo(expectedScore));
}
} catch (Throwable t) {
return Optional.of(t);
}
return Optional.empty();
}
}

// creates the vector based on the given ordinal, which is reproducible given the ord and dims
static byte[] vector(int ord, int dims) {
var random = new Random(Objects.hash(ord, dims));
Expand Down

0 comments on commit f83f0bc

Please sign in to comment.