Skip to content

Commit 3d79e52

Browse files
benwtrent authored and valeriy42 committed
Fix ivf nodestats impl for getOffHeapByteSize (elastic#129259)
This fixes a silly bug where we didn't override `OffHeapStats` for IVF.
1 parent 6504b27 commit 3d79e52

File tree

2 files changed

+40
-53
lines changed

2 files changed

+40
-53
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java

Lines changed: 6 additions & 53 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@
1111

1212
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
1313
import org.apache.lucene.index.FieldInfo;
14-
import org.apache.lucene.index.FloatVectorValues;
1514
import org.apache.lucene.index.SegmentReadState;
1615
import org.apache.lucene.index.VectorSimilarityFunction;
1716
import org.apache.lucene.search.KnnCollector;
@@ -20,10 +19,12 @@
2019
import org.apache.lucene.util.VectorUtil;
2120
import org.apache.lucene.util.hnsw.NeighborQueue;
2221
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
22+
import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats;
2323
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
2424
import org.elasticsearch.simdvec.ESVectorUtil;
2525

2626
import java.io.IOException;
27+
import java.util.Map;
2728
import java.util.function.IntPredicate;
2829

2930
import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.QUERY_BITS;
@@ -38,7 +39,7 @@
3839
* Default implementation of {@link IVFVectorsReader}. It scores the posting lists centroids using
3940
* brute force and then scores the top ones using the posting list.
4041
*/
41-
public class DefaultIVFVectorsReader extends IVFVectorsReader {
42+
public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeapStats {
4243
private static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1);
4344

4445
public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException {
@@ -163,57 +164,9 @@ static float int4QuantizedScore(
163164
}
164165
}
165166

166-
static class OffHeapCentroidFloatVectorValues extends FloatVectorValues {
167-
private final int numCentroids;
168-
private final IndexInput input;
169-
private final int dimension;
170-
private final float[] centroid;
171-
private final long centroidByteSize;
172-
private int ord = -1;
173-
174-
OffHeapCentroidFloatVectorValues(int numCentroids, IndexInput input, int dimension) {
175-
this.numCentroids = numCentroids;
176-
this.input = input;
177-
this.dimension = dimension;
178-
this.centroid = new float[dimension];
179-
this.centroidByteSize = dimension + 3 * Float.BYTES + Short.BYTES;
180-
}
181-
182-
@Override
183-
public float[] vectorValue(int ord) throws IOException {
184-
if (ord < 0 || ord >= numCentroids) {
185-
throw new IllegalArgumentException("ord must be in [0, " + numCentroids + "]");
186-
}
187-
if (ord == this.ord) {
188-
return centroid;
189-
}
190-
readQuantizedCentroid(ord);
191-
return centroid;
192-
}
193-
194-
private void readQuantizedCentroid(int centroidOrdinal) throws IOException {
195-
if (centroidOrdinal == ord) {
196-
return;
197-
}
198-
input.seek(numCentroids * centroidByteSize + (long) Float.BYTES * dimension * centroidOrdinal);
199-
input.readFloats(centroid, 0, centroid.length);
200-
ord = centroidOrdinal;
201-
}
202-
203-
@Override
204-
public int dimension() {
205-
return dimension;
206-
}
207-
208-
@Override
209-
public int size() {
210-
return numCentroids;
211-
}
212-
213-
@Override
214-
public FloatVectorValues copy() throws IOException {
215-
return new OffHeapCentroidFloatVectorValues(numCentroids, input.clone(), dimension);
216-
}
167+
@Override
168+
public Map<String, Long> getOffHeapByteSize(FieldInfo fieldInfo) {
169+
return Map.of();
217170
}
218171

219172
private static class MemorySegmentPostingsVisitor implements PostingVisitor {

server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java

Lines changed: 34 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,13 +13,25 @@
1313
import org.apache.lucene.codecs.Codec;
1414
import org.apache.lucene.codecs.FilterCodec;
1515
import org.apache.lucene.codecs.KnnVectorsFormat;
16+
import org.apache.lucene.codecs.KnnVectorsReader;
17+
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
18+
import org.apache.lucene.document.Document;
19+
import org.apache.lucene.document.KnnFloatVectorField;
20+
import org.apache.lucene.index.CodecReader;
21+
import org.apache.lucene.index.DirectoryReader;
22+
import org.apache.lucene.index.IndexReader;
23+
import org.apache.lucene.index.IndexWriter;
24+
import org.apache.lucene.index.LeafReader;
1625
import org.apache.lucene.index.VectorEncoding;
1726
import org.apache.lucene.index.VectorSimilarityFunction;
27+
import org.apache.lucene.store.Directory;
1828
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
1929
import org.apache.lucene.tests.util.TestUtil;
2030
import org.elasticsearch.common.logging.LogConfigurator;
31+
import org.elasticsearch.index.codec.vectors.reflect.OffHeapByteSizeUtils;
2132
import org.junit.Before;
2233

34+
import java.io.IOException;
2335
import java.util.List;
2436
import java.util.Locale;
2537

@@ -94,4 +106,26 @@ public void testLimits() {
94106
expectThrows(IllegalArgumentException.class, () -> new IVFVectorsFormat(MIN_VECTORS_PER_CLUSTER - 1));
95107
expectThrows(IllegalArgumentException.class, () -> new IVFVectorsFormat(MAX_VECTORS_PER_CLUSTER + 1));
96108
}
109+
110+
public void testSimpleOffHeapSize() throws IOException {
111+
float[] vector = randomVector(random().nextInt(12, 500));
112+
try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
113+
Document doc = new Document();
114+
doc.add(new KnnFloatVectorField("f", vector, VectorSimilarityFunction.EUCLIDEAN));
115+
w.addDocument(doc);
116+
w.commit();
117+
try (IndexReader reader = DirectoryReader.open(w)) {
118+
LeafReader r = getOnlyLeafReader(reader);
119+
if (r instanceof CodecReader codecReader) {
120+
KnnVectorsReader knnVectorsReader = codecReader.getVectorReader();
121+
if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) {
122+
knnVectorsReader = fieldsReader.getFieldReader("f");
123+
}
124+
var fieldInfo = r.getFieldInfos().fieldInfo("f");
125+
var offHeap = OffHeapByteSizeUtils.getOffHeapByteSize(knnVectorsReader, fieldInfo);
126+
assertEquals(0, offHeap.size());
127+
}
128+
}
129+
}
130+
}
97131
}

0 commit comments

Comments
 (0)