|
11 | 11 |
|
12 | 12 | import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
|
13 | 13 | import org.apache.lucene.index.FieldInfo;
|
14 |
| -import org.apache.lucene.index.FloatVectorValues; |
15 | 14 | import org.apache.lucene.index.SegmentReadState;
|
16 | 15 | import org.apache.lucene.index.VectorSimilarityFunction;
|
17 | 16 | import org.apache.lucene.search.KnnCollector;
|
|
20 | 19 | import org.apache.lucene.util.VectorUtil;
|
21 | 20 | import org.apache.lucene.util.hnsw.NeighborQueue;
|
22 | 21 | import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
|
| 22 | +import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats; |
23 | 23 | import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
24 | 24 | import org.elasticsearch.simdvec.ESVectorUtil;
|
25 | 25 |
|
26 | 26 | import java.io.IOException;
|
| 27 | +import java.util.Map; |
27 | 28 | import java.util.function.IntPredicate;
|
28 | 29 |
|
29 | 30 | import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.QUERY_BITS;
|
|
38 | 39 | * Default implementation of {@link IVFVectorsReader}. It scores the posting lists centroids using
|
39 | 40 | * brute force and then scores the top ones using the posting list.
|
40 | 41 | */
|
41 |
| -public class DefaultIVFVectorsReader extends IVFVectorsReader { |
| 42 | +public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeapStats { |
42 | 43 | private static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1);
|
43 | 44 |
|
44 | 45 | public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException {
|
@@ -163,57 +164,9 @@ static float int4QuantizedScore(
|
163 | 164 | }
|
164 | 165 | }
|
165 | 166 |
|
166 |
| - static class OffHeapCentroidFloatVectorValues extends FloatVectorValues { |
167 |
| - private final int numCentroids; |
168 |
| - private final IndexInput input; |
169 |
| - private final int dimension; |
170 |
| - private final float[] centroid; |
171 |
| - private final long centroidByteSize; |
172 |
| - private int ord = -1; |
173 |
| - |
174 |
| - OffHeapCentroidFloatVectorValues(int numCentroids, IndexInput input, int dimension) { |
175 |
| - this.numCentroids = numCentroids; |
176 |
| - this.input = input; |
177 |
| - this.dimension = dimension; |
178 |
| - this.centroid = new float[dimension]; |
179 |
| - this.centroidByteSize = dimension + 3 * Float.BYTES + Short.BYTES; |
180 |
| - } |
181 |
| - |
182 |
| - @Override |
183 |
| - public float[] vectorValue(int ord) throws IOException { |
184 |
| - if (ord < 0 || ord >= numCentroids) { |
185 |
| - throw new IllegalArgumentException("ord must be in [0, " + numCentroids + "]"); |
186 |
| - } |
187 |
| - if (ord == this.ord) { |
188 |
| - return centroid; |
189 |
| - } |
190 |
| - readQuantizedCentroid(ord); |
191 |
| - return centroid; |
192 |
| - } |
193 |
| - |
194 |
| - private void readQuantizedCentroid(int centroidOrdinal) throws IOException { |
195 |
| - if (centroidOrdinal == ord) { |
196 |
| - return; |
197 |
| - } |
198 |
| - input.seek(numCentroids * centroidByteSize + (long) Float.BYTES * dimension * centroidOrdinal); |
199 |
| - input.readFloats(centroid, 0, centroid.length); |
200 |
| - ord = centroidOrdinal; |
201 |
| - } |
202 |
| - |
203 |
| - @Override |
204 |
| - public int dimension() { |
205 |
| - return dimension; |
206 |
| - } |
207 |
| - |
208 |
| - @Override |
209 |
| - public int size() { |
210 |
| - return numCentroids; |
211 |
| - } |
212 |
| - |
213 |
| - @Override |
214 |
| - public FloatVectorValues copy() throws IOException { |
215 |
| - return new OffHeapCentroidFloatVectorValues(numCentroids, input.clone(), dimension); |
216 |
| - } |
| 167 | + @Override |
| 168 | + public Map<String, Long> getOffHeapByteSize(FieldInfo fieldInfo) { |
| 169 | + return Map.of(); |
217 | 170 | }
|
218 | 171 |
|
219 | 172 | private static class MemorySegmentPostingsVisitor implements PostingVisitor {
|
|
0 commit comments