diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java index b0738a6ea5bfb..5a951774ce22b 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java @@ -276,12 +276,11 @@ TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher) throws IOException TopDocs doVectorQuery(float[] vector, IndexSearcher searcher) throws IOException { Query knnQuery; int topK = this.topK; - int efSearch = this.efSearch; if (overSamplingFactor > 1f) { // oversample the topK results to get more candidates for the final result topK = (int) Math.ceil(topK * overSamplingFactor); - efSearch = Math.max(topK, efSearch); } + int efSearch = Math.max(topK, this.efSearch); if (indexType == KnnIndexTester.IndexType.IVF) { knnQuery = new IVFKnnFloatVectorQuery(VECTOR_FIELD, vector, topK, efSearch, null, nProbe); } else { diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java new file mode 100644 index 0000000000000..7c0c79e6ab6ca --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.codec.vectors; + +import org.apache.lucene.internal.hppc.IntArrayList; + +final class CentroidAssignments { + + private final int numCentroids; + private final float[][] cachedCentroids; + private final IntArrayList[] assignmentsByCluster; + + private CentroidAssignments(int numCentroids, float[][] cachedCentroids, IntArrayList[] assignmentsByCluster) { + this.numCentroids = numCentroids; + this.cachedCentroids = cachedCentroids; + this.assignmentsByCluster = assignmentsByCluster; + } + + CentroidAssignments(float[][] centroids, IntArrayList[] assignmentsByCluster) { + this(centroids.length, centroids, assignmentsByCluster); + } + + CentroidAssignments(int numCentroids, IntArrayList[] assignmentsByCluster) { + this(numCentroids, null, assignmentsByCluster); + } + + // Accessors; cachedCentroids() is null when the centroids were not cached + public int numCentroids() { + return numCentroids; + } + + public float[][] cachedCentroids() { + return cachedCentroids; + } + + public IntArrayList[] assignmentsByCluster() { + return assignmentsByCluster; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java index e09cf474d09ea..36a7c2084a4a5 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java @@ -112,15 +112,6 @@ public float score(int centroidOrdinal) throws IOException { }; } - @Override - protected FloatVectorValues getCentroids(IndexInput indexInput, int numCentroids, FieldInfo info) { - FieldEntry entry = fields.get(info.number); - if (entry == null) { - return null; - } - return new OffHeapCentroidFloatVectorValues(numCentroids, indexInput, info.getVectorDimension()); - } - @Override NeighborQueue scorePostingLists(FieldInfo fieldInfo, KnnCollector knnCollector, CentroidQueryScorer 
centroidQueryScorer, int nProbe) throws IOException { diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java index 1c431b01e611c..2527a91074e7e 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java @@ -14,21 +14,19 @@ import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; -import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.internal.hppc.IntArrayList; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.InfoStream; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; +import org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeans; +import org.elasticsearch.index.codec.vectors.cluster.KMeansResult; import org.elasticsearch.simdvec.ES91OSQVectorsScorer; -import org.elasticsearch.simdvec.ESVectorUtil; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.ByteOrder; -import java.util.ArrayList; -import java.util.List; import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.INDEX_BITS; import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize; @@ -36,16 +34,12 @@ import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.IVF_VECTOR_COMPONENT; /** - * Default implementation of {@link IVFVectorsWriter}. It uses {@link KMeans} algorithm to - * partition the vector space, and then stores the centroids an posting list in a sequential + * Default implementation of {@link IVFVectorsWriter}. 
It uses {@link HierarchicalKMeans} algorithm to + * partition the vector space, and then stores the centroids and posting list in a sequential * fashion. */ public class DefaultIVFVectorsWriter extends IVFVectorsWriter { - static final float SOAR_LAMBDA = 1.0f; - // What percentage of the centroids do we do a second check on for SOAR assignment - static final float EXT_SOAR_LIMIT_CHECK_RATIO = 0.10f; - private final int vectorPerCluster; public DefaultIVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorDelegate, int vectorPerCluster) throws IOException { @@ -53,77 +47,81 @@ public DefaultIVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVec this.vectorPerCluster = vectorPerCluster; } - @Override - CentroidAssignmentScorer calculateAndWriteCentroids( - FieldInfo fieldInfo, - FloatVectorValues floatVectorValues, - IndexOutput centroidOutput, - float[] globalCentroid - ) throws IOException { - if (floatVectorValues.size() == 0) { - return CentroidAssignmentScorer.EMPTY; - } - // calculate the centroids - int maxNumClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1; - int desiredClusters = (int) Math.max(Math.sqrt(floatVectorValues.size()), maxNumClusters); - final KMeans.Results kMeans = KMeans.cluster( - floatVectorValues, - desiredClusters, - false, - 42L, - KMeans.KmeansInitializationMethod.PLUS_PLUS, - null, - fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE, - 1, - 15, - desiredClusters * 256 - ); - float[][] centroids = kMeans.centroids(); - // write them - writeCentroids(centroids, fieldInfo, globalCentroid, centroidOutput); - return new OnHeapCentroidAssignmentScorer(centroids); - } - @Override long[] buildAndWritePostingsLists( FieldInfo fieldInfo, - InfoStream infoStream, - CentroidAssignmentScorer randomCentroidScorer, + CentroidSupplier centroidSupplier, FloatVectorValues floatVectorValues, - IndexOutput postingsOutput + IndexOutput postingsOutput, + InfoStream infoStream, + 
IntArrayList[] assignmentsByCluster ) throws IOException { - IntArrayList[] clusters = new IntArrayList[randomCentroidScorer.size()]; - for (int i = 0; i < randomCentroidScorer.size(); i++) { - clusters[i] = new IntArrayList(floatVectorValues.size() / randomCentroidScorer.size() / 4); - } - assignCentroids(randomCentroidScorer, floatVectorValues, clusters); - if (infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { - printClusterQualityStatistics(clusters, infoStream); - } // write the posting lists - final long[] offsets = new long[randomCentroidScorer.size()]; + final long[] offsets = new long[centroidSupplier.size()]; OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); BinarizedFloatVectorValues binarizedByteVectorValues = new BinarizedFloatVectorValues(floatVectorValues, quantizer); DocIdsWriter docIdsWriter = new DocIdsWriter(); - for (int i = 0; i < randomCentroidScorer.size(); i++) { - float[] centroid = randomCentroidScorer.centroid(i); + + for (int c = 0; c < centroidSupplier.size(); c++) { + float[] centroid = centroidSupplier.centroid(c); binarizedByteVectorValues.centroid = centroid; - // TODO sort by distance to the centroid - IntArrayList cluster = clusters[i]; + // TODO: add back in sorting vectors by distance to centroid + IntArrayList cluster = assignmentsByCluster[c]; // TODO align??? 
- offsets[i] = postingsOutput.getFilePointer(); + offsets[c] = postingsOutput.getFilePointer(); int size = cluster.size(); postingsOutput.writeVInt(size); postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid))); // TODO we might want to consider putting the docIds in a separate file // to aid with only having to fetch vectors from slower storage when they are required // keeping them in the same file indicates we pull the entire file into cache - docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster.get(j)), cluster.size(), postingsOutput); + docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster.get(j)), size, postingsOutput); writePostingList(cluster, postingsOutput, binarizedByteVectorValues); } + + if (infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { + printClusterQualityStatistics(assignmentsByCluster, infoStream); + } + return offsets; } + private static void printClusterQualityStatistics(IntArrayList[] clusters, InfoStream infoStream) { + float min = Float.MAX_VALUE; + float max = Float.MIN_VALUE; + float mean = 0; + float m2 = 0; + // iteratively compute the variance & mean + int count = 0; + for (IntArrayList cluster : clusters) { + count += 1; + if (cluster == null) { + continue; + } + float delta = cluster.size() - mean; + mean += delta / count; + m2 += delta * (cluster.size() - mean); + min = Math.min(min, cluster.size()); + max = Math.max(max, cluster.size()); + } + float variance = m2 / (clusters.length - 1); + infoStream.message( + IVF_VECTOR_COMPONENT, + "Centroid count: " + + clusters.length + + " min: " + + min + + " max: " + + max + + " mean: " + + mean + + " stdDev: " + + Math.sqrt(variance) + + " variance: " + + variance + ); + } + private void writePostingList(IntArrayList cluster, IndexOutput postingsOutput, BinarizedFloatVectorValues binarizedByteVectorValues) throws IOException { int limit = cluster.size() - ES91OSQVectorsScorer.BULK_SIZE + 1; @@ -173,13 +171,8 @@ private void 
writePostingList(IntArrayList cluster, IndexOutput postingsOutput, } @Override - CentroidAssignmentScorer createCentroidScorer( - IndexInput centroidsInput, - int numCentroids, - FieldInfo fieldInfo, - float[] globalCentroid - ) { - return new OffHeapCentroidAssignmentScorer(centroidsInput, numCentroids, fieldInfo); + CentroidSupplier createCentroidSupplier(IndexInput centroidsInput, int numCentroids, FieldInfo fieldInfo, float[] globalCentroid) { + return new OffHeapCentroidSupplier(centroidsInput, numCentroids, fieldInfo); } static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] globalCentroid, IndexOutput centroidOutput) @@ -188,24 +181,8 @@ static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] glo byte[] quantizedScratch = new byte[fieldInfo.getVectorDimension()]; float[] centroidScratch = new float[fieldInfo.getVectorDimension()]; // TODO do we want to store these distances as well for future use? - float[] distances = new float[centroids.length]; - for (int i = 0; i < centroids.length; i++) { - distances[i] = VectorUtil.squareDistance(centroids[i], globalCentroid); - } - // sort the centroids by distance to globalCentroid, nearest (smallest distance), to furthest - // (largest) - for (int i = 0; i < centroids.length; i++) { - for (int j = i + 1; j < centroids.length; j++) { - if (distances[i] > distances[j]) { - float[] tmp = centroids[i]; - centroids[i] = centroids[j]; - centroids[j] = tmp; - float tmpDistance = distances[i]; - distances[i] = distances[j]; - distances[j] = tmpDistance; - } - } - } + // TODO: sort centroids by global centroid (was doing so previously here) + // TODO: sorting tanks recall possibly because centroids ordinals no longer are aligned for (float[] centroid : centroids) { System.arraycopy(centroid, 0, centroidScratch, 0, centroid.length); OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize( @@ -223,190 +200,60 @@ static void writeCentroids(float[][] centroids, FieldInfo 
fieldInfo, float[] glo } } - static float[][] gatherInitCentroids( - List centroidList, - List segmentCentroids, - int desiredClusters, + CentroidAssignments calculateAndWriteCentroids( FieldInfo fieldInfo, - MergeState mergeState + FloatVectorValues floatVectorValues, + IndexOutput centroidOutput, + MergeState mergeState, + float[] globalCentroid ) throws IOException { - if (centroidList.size() == 0) { - return null; - } - long startTime = System.nanoTime(); - // sort centroid list by floatvector size - FloatVectorValues baseSegment = centroidList.get(0); - for (var l : centroidList) { - if (l.size() > baseSegment.size()) { - baseSegment = l; - } - } - float[] scratch = new float[fieldInfo.getVectorDimension()]; - float minimumDistance = Float.MAX_VALUE; - for (int j = 0; j < baseSegment.size(); j++) { - System.arraycopy(baseSegment.vectorValue(j), 0, scratch, 0, baseSegment.dimension()); - for (int k = j + 1; k < baseSegment.size(); k++) { - float d = VectorUtil.squareDistance(scratch, baseSegment.vectorValue(k)); - if (d < minimumDistance) { - minimumDistance = d; - } - } - } - if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { - mergeState.infoStream.message( - IVF_VECTOR_COMPONENT, - "Agglomerative cluster min distance: " + minimumDistance + " From biggest segment: " + baseSegment.size() - ); - } - int[] labels = new int[segmentCentroids.size()]; - // loop over segments - int clusterIdx = 0; - // keep track of all inter-centroid distances, - // using less than centroid * centroid space (e.g. 
not keeping track of duplicates) - for (int i = 0; i < segmentCentroids.size(); i++) { - if (labels[i] == 0) { - clusterIdx += 1; - labels[i] = clusterIdx; - } - SegmentCentroid segmentCentroid = segmentCentroids.get(i); - System.arraycopy( - centroidList.get(segmentCentroid.segment()).vectorValue(segmentCentroid.centroid), - 0, - scratch, - 0, - baseSegment.dimension() - ); - for (int j = i + 1; j < segmentCentroids.size(); j++) { - float d = VectorUtil.squareDistance( - scratch, - centroidList.get(segmentCentroids.get(j).segment()).vectorValue(segmentCentroids.get(j).centroid()) - ); - if (d < minimumDistance / 2) { - if (labels[j] == 0) { - labels[j] = labels[i]; - } else { - for (int k = 0; k < labels.length; k++) { - if (labels[k] == labels[j]) { - labels[k] = labels[i]; - } - } - } - } - } - } - float[][] initCentroids = new float[clusterIdx][fieldInfo.getVectorDimension()]; - int[] sum = new int[clusterIdx]; - for (int i = 0; i < segmentCentroids.size(); i++) { - SegmentCentroid segmentCentroid = segmentCentroids.get(i); - int label = labels[i]; - FloatVectorValues segment = centroidList.get(segmentCentroid.segment()); - float[] vector = segment.vectorValue(segmentCentroid.centroid); - for (int j = 0; j < vector.length; j++) { - initCentroids[label - 1][j] += (vector[j] * segmentCentroid.centroidSize); - } - sum[label - 1] += segmentCentroid.centroidSize; - } - for (int i = 0; i < initCentroids.length; i++) { - if (sum[i] == 0 || sum[i] == 1) { - continue; - } - for (int j = 0; j < initCentroids[i].length; j++) { - initCentroids[i][j] /= sum[i]; - } - } - if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { - mergeState.infoStream.message( - IVF_VECTOR_COMPONENT, - "Agglomerative cluster time ms: " + ((System.nanoTime() - startTime) / 1000000.0) - ); - mergeState.infoStream.message( - IVF_VECTOR_COMPONENT, - "Gathered initCentroids:" + initCentroids.length + " for desired: " + desiredClusters - ); - } - return initCentroids; + // TODO: take advantage 
of prior generated clusters from mergeState in the future + return calculateAndWriteCentroids(fieldInfo, floatVectorValues, centroidOutput, mergeState.infoStream, globalCentroid, false); } - record SegmentCentroid(int segment, int centroid, int centroidSize) {} + CentroidAssignments calculateAndWriteCentroids( + FieldInfo fieldInfo, + FloatVectorValues floatVectorValues, + IndexOutput centroidOutput, + InfoStream infoStream, + float[] globalCentroid + ) throws IOException { + return calculateAndWriteCentroids(fieldInfo, floatVectorValues, centroidOutput, infoStream, globalCentroid, true); + } /** - * Calculate the centroids for the given field and write them to the given - * temporary centroid output. - * When merging, we first bootstrap the KMeans algorithm with the centroids contained in the merging segments. - * To prevent centroids that are too similar from having an outsized impact, all centroids that are closer than - * the largest segments intra-cluster distance are merged into a single centroid. - * The resulting centroids are then used to initialize the KMeans algorithm. + * Calculate the centroids for the given field and write them to the given centroid output. 
+ * We use the {@link HierarchicalKMeans} algorithm to partition the space of all vectors across merging segments. * * @param fieldInfo merging field info * @param floatVectorValues the float vector values to merge - * @param temporaryCentroidOutput the temporary centroid output - * @param mergeState the merge state + * @param centroidOutput the centroid output + * @param infoStream the info stream used to report clustering time and the final centroid count * @param globalCentroid the global centroid, calculated by this method and used to quantize the centroids - * @return the number of centroids written + * @param cacheCentroids whether the centroids are kept or discarded once computed + * @return the vector assignments, soar assignments, and if asked the centroids themselves that were computed * @throws IOException if an I/O error occurs */ - @Override - protected int calculateAndWriteCentroids( + CentroidAssignments calculateAndWriteCentroids( FieldInfo fieldInfo, FloatVectorValues floatVectorValues, - IndexOutput temporaryCentroidOutput, - MergeState mergeState, - float[] globalCentroid + IndexOutput centroidOutput, + InfoStream infoStream, + float[] globalCentroid, + boolean cacheCentroids ) throws IOException { - if (floatVectorValues.size() == 0) { - return 0; - } - int maxNumClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1; - int desiredClusters = (int) Math.max(Math.sqrt(floatVectorValues.size()), maxNumClusters); - // init centroids from merge state - List centroidList = new ArrayList<>(); - List segmentCentroids = new ArrayList<>(desiredClusters); - - int segmentIdx = 0; - for (var reader : mergeState.knnVectorsReaders) { - IVFVectorsReader ivfVectorsReader = IVFVectorsFormat.getIVFReader(reader, fieldInfo.name); - if (ivfVectorsReader == null) { - continue; - } - FloatVectorValues centroid = ivfVectorsReader.getCentroids(fieldInfo); - if (centroid == null) { - continue; - } - centroidList.add(centroid); - for (int i = 0; i < centroid.size(); i++) { - int size = 
ivfVectorsReader.centroidSize(fieldInfo.name, i); - if (size == 0) { - continue; - } - segmentCentroids.add(new SegmentCentroid(segmentIdx, i, size)); - } - segmentIdx++; - } - - float[][] initCentroids = gatherInitCentroids(centroidList, segmentCentroids, desiredClusters, fieldInfo, mergeState); - - // FIXME: run a custom version of KMeans that is just better... long nanoTime = System.nanoTime(); - final KMeans.Results kMeans = KMeans.cluster( - floatVectorValues, - desiredClusters, - false, - 42L, - KMeans.KmeansInitializationMethod.PLUS_PLUS, - initCentroids, - fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE, - 1, - 5, - desiredClusters * 64 - ); - if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { - mergeState.infoStream.message(IVF_VECTOR_COMPONENT, "KMeans time ms: " + ((System.nanoTime() - nanoTime) / 1000000.0)); - } - float[][] centroids = kMeans.centroids(); - // write them - // calculate the global centroid from all the centroids: + // TODO: consider hinting / bootstrapping hierarchical kmeans with the prior segments centroids + KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, vectorPerCluster); + float[][] centroids = kMeansResult.centroids(); + int[] assignments = kMeansResult.assignments(); + int[] soarAssignments = kMeansResult.soarAssignments(); + + // TODO: for flush we are doing this over the vectors and here centroids which seems duplicative + // preliminary tests suggest recall is good using only centroids but need to do further evaluation + // TODO: push this logic into vector util? 
for (float[] centroid : centroids) { for (int j = 0; j < centroid.length; j++) { globalCentroid[j] += centroid[j]; @@ -415,197 +262,41 @@ protected int calculateAndWriteCentroids( for (int j = 0; j < globalCentroid.length; j++) { globalCentroid[j] /= centroids.length; } - writeCentroids(centroids, fieldInfo, globalCentroid, temporaryCentroidOutput); - return centroids.length; - } - @Override - long[] buildAndWritePostingsLists( - FieldInfo fieldInfo, - CentroidAssignmentScorer centroidAssignmentScorer, - FloatVectorValues floatVectorValues, - IndexOutput postingsOutput, - MergeState mergeState - ) throws IOException { - IntArrayList[] clusters = new IntArrayList[centroidAssignmentScorer.size()]; - for (int i = 0; i < centroidAssignmentScorer.size(); i++) { - clusters[i] = new IntArrayList(floatVectorValues.size() / centroidAssignmentScorer.size() / 4); - } - long nanoTime = System.nanoTime(); - // Can we do a pre-filter by finding the nearest centroids to the original vector centroids? - // We need to be careful on vecOrd vs. 
doc as we need random access to the raw vector for posting list writing - assignCentroids(centroidAssignmentScorer, floatVectorValues, clusters); - if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { - mergeState.infoStream.message(IVF_VECTOR_COMPONENT, "assignCentroids time ms: " + ((System.nanoTime() - nanoTime) / 1000000.0)); - } - - if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { - printClusterQualityStatistics(clusters, mergeState.infoStream); - } - // write the posting lists - final long[] offsets = new long[centroidAssignmentScorer.size()]; - OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); - BinarizedFloatVectorValues binarizedByteVectorValues = new BinarizedFloatVectorValues(floatVectorValues, quantizer); - DocIdsWriter docIdsWriter = new DocIdsWriter(); - for (int i = 0; i < centroidAssignmentScorer.size(); i++) { - float[] centroid = centroidAssignmentScorer.centroid(i); - binarizedByteVectorValues.centroid = centroid; - // TODO: sort by distance to the centroid - IntArrayList cluster = clusters[i]; - // TODO align??? 
- offsets[i] = postingsOutput.getFilePointer(); - int size = cluster.size(); - postingsOutput.writeVInt(size); - postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid))); - // TODO we might want to consider putting the docIds in a separate file - // to aid with only having to fetch vectors from slower storage when they are required - // keeping them in the same file indicates we pull the entire file into cache - docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster.get(j)), size, postingsOutput); - writePostingList(cluster, postingsOutput, binarizedByteVectorValues); - } - return offsets; - } + // write centroids + writeCentroids(centroids, fieldInfo, globalCentroid, centroidOutput); - private static void printClusterQualityStatistics(IntArrayList[] clusters, InfoStream infoStream) { - float min = Float.MAX_VALUE; - float max = Float.MIN_VALUE; - float mean = 0; - float m2 = 0; - // iteratively compute the variance & mean - int count = 0; - for (IntArrayList cluster : clusters) { - count += 1; - if (cluster == null) { - continue; - } - float delta = cluster.size() - mean; - mean += delta / count; - m2 += delta * (cluster.size() - mean); - min = Math.min(min, cluster.size()); - max = Math.max(max, cluster.size()); + if (infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { + infoStream.message( + IVF_VECTOR_COMPONENT, + "calculate centroids and assign vectors time ms: " + ((System.nanoTime() - nanoTime) / 1000000.0) + ); + infoStream.message(IVF_VECTOR_COMPONENT, "final centroid count: " + centroids.length); } - float variance = m2 / (clusters.length - 1); - infoStream.message( - IVF_VECTOR_COMPONENT, - "Centroid count: " - + clusters.length - + " min: " - + min - + " max: " - + max - + " mean: " - + mean - + " stdDev: " - + Math.sqrt(variance) - + " variance: " - + variance - ); - } - static void assignCentroids(CentroidAssignmentScorer scorer, FloatVectorValues vectors, IntArrayList[] clusters) throws IOException { - int 
numCentroids = scorer.size(); - // we at most will look at the EXT_SOAR_LIMIT_CHECK_RATIO nearest centroids if possible - int soarToCheck = (int) (numCentroids * EXT_SOAR_LIMIT_CHECK_RATIO); - int soarClusterCheckCount = Math.min(numCentroids - 1, soarToCheck); - NeighborQueue neighborsToCheck = new NeighborQueue(soarClusterCheckCount + 1, true); - OrdScoreIterator ordScoreIterator = new OrdScoreIterator(soarClusterCheckCount + 1); - float[] scratch = new float[vectors.dimension()]; - for (int docID = 0; docID < vectors.size(); docID++) { - float[] vector = vectors.vectorValue(docID); - scorer.setScoringVector(vector); - int bestCentroid = 0; - float bestScore = Float.MAX_VALUE; - if (numCentroids > 1) { - for (short c = 0; c < numCentroids; c++) { - float squareDist = scorer.score(c); - neighborsToCheck.insertWithOverflow(c, squareDist); + IntArrayList[] assignmentsByCluster = new IntArrayList[centroids.length]; + for (int c = 0; c < centroids.length; c++) { + IntArrayList cluster = new IntArrayList(vectorPerCluster); + for (int j = 0; j < assignments.length; j++) { + if (assignments[j] == c) { + cluster.add(j); } - // pop the best - int sz = neighborsToCheck.size(); - int best = neighborsToCheck.consumeNodesAndScoresMin(ordScoreIterator.ords, ordScoreIterator.scores); - // Set the size to the number of neighbors we actually found - ordScoreIterator.setSize(sz); - bestScore = ordScoreIterator.getScore(best); - bestCentroid = ordScoreIterator.getOrd(best); } - clusters[bestCentroid].add(docID); - if (soarClusterCheckCount > 0) { - assignCentroidSOAR( - ordScoreIterator, - docID, - bestCentroid, - scorer.centroid(bestCentroid), - bestScore, - scratch, - scorer, - vector, - clusters - ); - } - neighborsToCheck.clear(); - } - } - - static void assignCentroidSOAR( - OrdScoreIterator centroidsToCheck, - int vecOrd, - int bestCentroidId, - float[] bestCentroid, - float bestScore, - float[] scratch, - CentroidAssignmentScorer scorer, - float[] vector, - IntArrayList[] 
clusters - ) throws IOException { - ESVectorUtil.subtract(vector, bestCentroid, scratch); - int bestSecondaryCentroid = -1; - float minDist = Float.MAX_VALUE; - for (int i = 0; i < centroidsToCheck.size(); i++) { - float score = centroidsToCheck.getScore(i); - int centroidOrdinal = centroidsToCheck.getOrd(i); - if (centroidOrdinal == bestCentroidId) { - continue; - } - float proj = ESVectorUtil.soarResidual(vector, scorer.centroid(centroidOrdinal), scratch); - score += SOAR_LAMBDA * proj * proj / bestScore; - if (score < minDist) { - bestSecondaryCentroid = centroidOrdinal; - minDist = score; - } - } - if (bestSecondaryCentroid != -1) { - clusters[bestSecondaryCentroid].add(vecOrd); - } - } - - static class OrdScoreIterator { - private final int[] ords; - private final float[] scores; - private int idx = 0; - OrdScoreIterator(int size) { - this.ords = new int[size]; - this.scores = new float[size]; - } - - int setSize(int size) { - if (size > ords.length) { - throw new IllegalArgumentException("size must be <= " + ords.length); + for (int j = 0; j < soarAssignments.length; j++) { + if (soarAssignments[j] == c) { + cluster.add(j); + } } - this.idx = size; - return size; - } - int getOrd(int idx) { - return ords[idx]; + cluster.trimToSize(); + assignmentsByCluster[c] = cluster; } - float getScore(int idx) { - return scores[idx]; - } - - int size() { - return idx; + if (cacheCentroids) { + return new CentroidAssignments(centroids, assignmentsByCluster); + } else { + return new CentroidAssignments(centroids.length, assignmentsByCluster); } } @@ -650,16 +341,25 @@ private void binarize(int ord) throws IOException { } } - static class OffHeapCentroidAssignmentScorer implements CentroidAssignmentScorer { + static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections) + throws IOException { + indexOutput.writeBytes(binaryValue, binaryValue.length); + 
indexOutput.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + indexOutput.writeInt(Float.floatToIntBits(corrections.upperInterval())); + indexOutput.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff; + indexOutput.writeShort((short) corrections.quantizedComponentSum()); + } + + static class OffHeapCentroidSupplier implements CentroidSupplier { private final IndexInput centroidsInput; private final int numCentroids; private final int dimension; private final float[] scratch; - private float[] q; private final long rawCentroidOffset; private int currOrd = -1; - OffHeapCentroidAssignmentScorer(IndexInput centroidsInput, int numCentroids, FieldInfo info) { + OffHeapCentroidSupplier(IndexInput centroidsInput, int numCentroids, FieldInfo info) { this.centroidsInput = centroidsInput; this.numCentroids = numCentroids; this.dimension = info.getVectorDimension(); @@ -682,55 +382,5 @@ public float[] centroid(int centroidOrdinal) throws IOException { this.currOrd = centroidOrdinal; return scratch; } - - @Override - public void setScoringVector(float[] vector) { - q = vector; - } - - @Override - public float score(int centroidOrdinal) throws IOException { - return VectorUtil.squareDistance(centroid(centroidOrdinal), q); - } - } - - // TODO throw away rawCentroids - static class OnHeapCentroidAssignmentScorer implements CentroidAssignmentScorer { - private final float[][] centroids; - private float[] q; - - OnHeapCentroidAssignmentScorer(float[][] centroids) { - this.centroids = centroids; - } - - @Override - public int size() { - return centroids.length; - } - - @Override - public void setScoringVector(float[] vector) { - q = vector; - } - - @Override - public float[] centroid(int centroidOrdinal) throws IOException { - return centroids[centroidOrdinal]; - } - - @Override - public float score(int centroidOrdinal) throws IOException { - return 
VectorUtil.squareDistance(centroid(centroidOrdinal), q); - } - } - - static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections) - throws IOException { - indexOutput.writeBytes(binaryValue, binaryValue.length); - indexOutput.writeInt(Float.floatToIntBits(corrections.lowerInterval())); - indexOutput.writeInt(Float.floatToIntBits(corrections.upperInterval())); - indexOutput.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); - assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff; - indexOutput.writeShort((short) corrections.quantizedComponentSum()); } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java index 1a0a5bd94af35..d5086cf2d479e 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java @@ -97,23 +97,6 @@ abstract CentroidQueryScorer getCentroidScorer( IndexInput clusters ) throws IOException; - protected abstract FloatVectorValues getCentroids(IndexInput indexInput, int numCentroids, FieldInfo info) throws IOException; - - public FloatVectorValues getCentroids(FieldInfo fieldInfo) throws IOException { - FieldEntry entry = fields.get(fieldInfo.number); - if (entry == null) { - return null; - } - return getCentroids(entry.centroidSlice(ivfCentroids), entry.postingListOffsets.length, fieldInfo); - } - - int centroidSize(String fieldName, int centroidOrdinal) throws IOException { - FieldInfo fieldInfo = state.fieldInfos.fieldInfo(fieldName); - FieldEntry entry = fields.get(fieldInfo.number); - ivfClusters.seek(entry.postingListOffsets[centroidOrdinal]); - return ivfClusters.readVInt(); - } - private static IndexInput openDataInput( SegmentReadState state, int versionMeta, diff --git 
a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java index d6188703881a4..73a0bdd5efa48 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java @@ -11,11 +11,9 @@ import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.KnnFieldVectorsWriter; -import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; -import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; @@ -25,6 +23,7 @@ import org.apache.lucene.index.Sorter; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.hppc.IntArrayList; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; @@ -123,38 +122,32 @@ public final KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOExc return rawVectorDelegate; } - protected abstract int calculateAndWriteCentroids( + abstract CentroidAssignments calculateAndWriteCentroids( FieldInfo fieldInfo, FloatVectorValues floatVectorValues, - IndexOutput temporaryCentroidOutput, + IndexOutput centroidOutput, MergeState mergeState, float[] globalCentroid ) throws IOException; - abstract long[] buildAndWritePostingsLists( - FieldInfo fieldInfo, - CentroidAssignmentScorer scorer, - FloatVectorValues floatVectorValues, - IndexOutput postingsOutput, - MergeState mergeState - ) throws IOException; - - abstract CentroidAssignmentScorer calculateAndWriteCentroids( + 
abstract CentroidAssignments calculateAndWriteCentroids( FieldInfo fieldInfo, FloatVectorValues floatVectorValues, IndexOutput centroidOutput, + InfoStream infoStream, float[] globalCentroid ) throws IOException; abstract long[] buildAndWritePostingsLists( FieldInfo fieldInfo, - InfoStream infoStream, - CentroidAssignmentScorer scorer, + CentroidSupplier centroidSupplier, FloatVectorValues floatVectorValues, - IndexOutput postingsOutput + IndexOutput postingsOutput, + InfoStream infoStream, + IntArrayList[] assignmentsByCluster ) throws IOException; - abstract CentroidAssignmentScorer createCentroidScorer( + abstract CentroidSupplier createCentroidSupplier( IndexInput centroidsInput, int numCentroids, FieldInfo fieldInfo, @@ -166,33 +159,31 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { rawVectorDelegate.flush(maxDoc, sortMap); for (FieldWriter fieldWriter : fieldWriters) { float[] globalCentroid = new float[fieldWriter.fieldInfo.getVectorDimension()]; - // calculate global centroid - for (var vector : fieldWriter.delegate.getVectors()) { - for (int i = 0; i < globalCentroid.length; i++) { - globalCentroid[i] += vector[i]; - } - } - for (int i = 0; i < globalCentroid.length; i++) { - globalCentroid[i] /= fieldWriter.delegate.getVectors().size(); - } // build a float vector values with random access final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldWriter.fieldInfo, fieldWriter.delegate, maxDoc); // build centroids long centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES); - final CentroidAssignmentScorer centroidAssignmentScorer = calculateAndWriteCentroids( + + final CentroidAssignments centroidAssignments = calculateAndWriteCentroids( fieldWriter.fieldInfo, floatVectorValues, ivfCentroids, + segmentWriteState.infoStream, globalCentroid ); + + CentroidSupplier centroidSupplier = new OnHeapCentroidSupplier(centroidAssignments.cachedCentroids()); + long centroidLength = ivfCentroids.getFilePointer() - 
centroidOffset; final long[] offsets = buildAndWritePostingsLists( fieldWriter.fieldInfo, - segmentWriteState.infoStream, - centroidAssignmentScorer, + centroidSupplier, floatVectorValues, - ivfClusters + ivfClusters, + segmentWriteState.infoStream, + centroidAssignments.assignmentsByCluster() ); + // write posting lists writeMeta(fieldWriter.fieldInfo, centroidOffset, centroidLength, offsets, globalCentroid); } } @@ -240,16 +231,6 @@ public int ordToDoc(int ord) { }; } - static IVFVectorsReader getIVFReader(KnnVectorsReader vectorsReader, String fieldName) { - if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) { - vectorsReader = candidateReader.getFieldReader(fieldName); - } - if (vectorsReader instanceof IVFVectorsReader reader) { - return reader; - } - return null; - } - @Override @SuppressForbidden(reason = "require usage of Lucene's IOUtils#deleteFilesIgnoringExceptions(...)") public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { @@ -277,22 +258,25 @@ public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) thro float[] calculatedGlobalCentroid = new float[fieldInfo.getVectorDimension()]; final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, in, numVectors); success = false; - CentroidAssignmentScorer centroidAssignmentScorer; long centroidOffset; long centroidLength; String centroidTempName = null; int numCentroids; IndexOutput centroidTemp = null; + CentroidAssignments centroidAssignments; try { centroidTemp = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "civf_", IOContext.DEFAULT); centroidTempName = centroidTemp.getName(); - numCentroids = calculateAndWriteCentroids( + + centroidAssignments = calculateAndWriteCentroids( fieldInfo, floatVectorValues, centroidTemp, mergeState, calculatedGlobalCentroid ); + numCentroids = centroidAssignments.numCentroids(); + success = true; } finally { if (success == false && 
centroidTempName != null) { @@ -311,21 +295,28 @@ public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) thro CodecUtil.writeFooter(centroidTemp); IOUtils.close(centroidTemp); centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES); - try (IndexInput centroidInput = mergeState.segmentInfo.dir.openInput(centroidTempName, IOContext.DEFAULT)) { - ivfCentroids.copyBytes(centroidInput, centroidInput.length() - CodecUtil.footerLength()); + try (IndexInput centroidsInput = mergeState.segmentInfo.dir.openInput(centroidTempName, IOContext.DEFAULT)) { + ivfCentroids.copyBytes(centroidsInput, centroidsInput.length() - CodecUtil.footerLength()); centroidLength = ivfCentroids.getFilePointer() - centroidOffset; - centroidAssignmentScorer = createCentroidScorer(centroidInput, numCentroids, fieldInfo, calculatedGlobalCentroid); - assert centroidAssignmentScorer.size() == numCentroids; + + CentroidSupplier centroidSupplier = createCentroidSupplier( + centroidsInput, + numCentroids, + fieldInfo, + calculatedGlobalCentroid + ); + // build a float vector values with random access // build centroids final long[] offsets = buildAndWritePostingsLists( fieldInfo, - centroidAssignmentScorer, + centroidSupplier, floatVectorValues, ivfClusters, - mergeState + mergeState.infoStream, + centroidAssignments.assignmentsByCluster() ); - assert offsets.length == centroidAssignmentScorer.size(); + assert offsets.length == centroidSupplier.size(); writeMeta(fieldInfo, centroidOffset, centroidLength, offsets, calculatedGlobalCentroid); } } finally { @@ -453,8 +444,8 @@ public final long ramBytesUsed() { private record FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter delegate) {} - interface CentroidAssignmentScorer { - CentroidAssignmentScorer EMPTY = new CentroidAssignmentScorer() { + interface CentroidSupplier { + CentroidSupplier EMPTY = new CentroidSupplier() { @Override public int size() { return 0; @@ -464,24 +455,29 @@ public int size() { public float[] 
centroid(int centroidOrdinal) { throw new IllegalStateException("No centroids"); } - - @Override - public float score(int centroidOrdinal) { - throw new IllegalStateException("No centroids"); - } - - @Override - public void setScoringVector(float[] vector) { - throw new IllegalStateException("No centroids"); - } }; int size(); float[] centroid(int centroidOrdinal) throws IOException; + } - void setScoringVector(float[] vector); + // TODO throw away rawCentroids + static class OnHeapCentroidSupplier implements CentroidSupplier { + private final float[][] centroids; - float score(int centroidOrdinal) throws IOException; + OnHeapCentroidSupplier(float[][] centroids) { + this.centroids = centroids; + } + + @Override + public int size() { + return centroids.length; + } + + @Override + public float[] centroid(int centroidOrdinal) throws IOException { + return centroids[centroidOrdinal]; + } } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/KMeans.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/KMeans.java deleted file mode 100644 index 715791c5cbb54..0000000000000 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/KMeans.java +++ /dev/null @@ -1,494 +0,0 @@ -/* - * @notice - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - * - * Modifications copyright (C) 2025 Elasticsearch B.V. - */ -package org.elasticsearch.index.codec.vectors; - -import org.apache.lucene.index.FloatVectorValues; -import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.internal.hppc.IntArrayList; -import org.apache.lucene.internal.hppc.IntObjectHashMap; -import org.apache.lucene.util.ArrayUtil; -import org.apache.lucene.util.VectorUtil; - -import java.io.IOException; -import java.util.Arrays; -import java.util.HashSet; -import java.util.Random; -import java.util.Set; - -import static org.elasticsearch.index.codec.vectors.SampleReader.createSampleReader; - -/** KMeans clustering algorithm for vectors */ -class KMeans { - public static final int DEFAULT_RESTARTS = 1; - public static final int DEFAULT_ITRS = 10; - public static final int DEFAULT_SAMPLE_VECTORS_PER_CENTROID = 128; - - private static final float EPS = 1f / 1024f; - private final FloatVectorValues vectors; - private final int numVectors; - private final int numCentroids; - private final Random random; - private final KmeansInitializationMethod initializationMethod; - private final float[][] initCentroids; - private final int restarts; - private final int iters; - - /** - * Cluster vectors into a given number of clusters - * - * @param vectors float vectors - * @param similarityFunction vector similarity function. For COSINE similarity, vectors must be - * normalized. 
- * @param numClusters number of cluster to cluster vector into - * @return results of clustering: produced centroids and for each vector its centroid - * @throws IOException when if there is an error accessing vectors - */ - static Results cluster(FloatVectorValues vectors, VectorSimilarityFunction similarityFunction, int numClusters) throws IOException { - return cluster( - vectors, - numClusters, - true, - 42L, - KmeansInitializationMethod.PLUS_PLUS, - null, - similarityFunction == VectorSimilarityFunction.COSINE, - DEFAULT_RESTARTS, - DEFAULT_ITRS, - DEFAULT_SAMPLE_VECTORS_PER_CENTROID * numClusters - ); - } - - /** - * Expert: Cluster vectors into a given number of clusters - * - * @param vectors float vectors - * @param numClusters number of cluster to cluster vector into - * @param assignCentroidsToVectors if {@code true} assign centroids for all vectors. Centroids are - * computed on a sample of vectors. If this parameter is {@code true}, in results also return - * for all vectors what centroids they belong to. 
- * @param seed random seed - * @param initializationMethod Kmeans initialization method - * @param initCentroids initial centroids, if not {@code null} utilize as initial centroids for - * the given initialization method - * @param normalizeCenters for cosine distance, set to true, to use spherical k-means where - * centers are normalized - * @param restarts how many times to run Kmeans algorithm - * @param iters how many iterations to do within a single run - * @param sampleSize sample size to select from all vectors on which to run Kmeans algorithm - * @return results of clustering: produced centroids and if {@code assignCentroidsToVectors == - * true} also for each vector its centroid - * @throws IOException if there is error accessing vectors - */ - static Results cluster( - FloatVectorValues vectors, - int numClusters, - boolean assignCentroidsToVectors, - long seed, - KmeansInitializationMethod initializationMethod, - float[][] initCentroids, - boolean normalizeCenters, - int restarts, - int iters, - int sampleSize - ) throws IOException { - if (vectors.size() == 0) { - return null; - } - // adjust sampleSize and numClusters - sampleSize = Math.max(sampleSize, 100 * numClusters); - if (sampleSize > vectors.size()) { - sampleSize = vectors.size(); - // Decrease the number of clusters if needed - int maxNumClusters = Math.max(1, sampleSize / 100); - numClusters = Math.min(numClusters, maxNumClusters); - } - - Random random = new Random(seed); - float[][] centroids; - if (numClusters == 1) { - centroids = new float[1][vectors.dimension()]; - for (int i = 0; i < vectors.size(); i++) { - float[] vector = vectors.vectorValue(i); - for (int dim = 0; dim < vector.length; dim++) { - centroids[0][dim] += vector[dim]; - } - } - for (int dim = 0; dim < centroids[0].length; dim++) { - centroids[0][dim] /= vectors.size(); - } - } else { - FloatVectorValues sampleVectors = vectors.size() <= sampleSize ? 
vectors : createSampleReader(vectors, sampleSize, seed); - KMeans kmeans = new KMeans(sampleVectors, numClusters, random, initializationMethod, initCentroids, restarts, iters); - centroids = kmeans.computeCentroids(normalizeCenters); - } - - int[] vectorCentroids = null; - int[] centroidSize = null; - // Assign each vector to the nearest centroid and update the centres - if (assignCentroidsToVectors) { - vectorCentroids = new int[vectors.size()]; - centroidSize = new int[centroids.length]; - assignCentroids(random, vectorCentroids, centroidSize, vectors, centroids); - } - if (normalizeCenters) { - for (float[] centroid : centroids) { - VectorUtil.l2normalize(centroid, false); - } - } - return new Results(centroids, centroidSize, vectorCentroids); - } - - private static void assignCentroids( - Random random, - int[] docCentroids, - int[] centroidSize, - FloatVectorValues vectors, - float[][] centroids - ) throws IOException { - short numCentroids = (short) centroids.length; - assert Arrays.stream(centroidSize).allMatch(size -> size == 0); - for (int docID = 0; docID < vectors.size(); docID++) { - float[] vector = vectors.vectorValue(docID); - short bestCentroid = 0; - if (numCentroids > 1) { - float minSquaredDist = Float.MAX_VALUE; - for (short c = 0; c < numCentroids; c++) { - // TODO: replace with RandomVectorScorer::score possible on quantized vectors - float squareDist = VectorUtil.squareDistance(centroids[c], vector); - if (squareDist < minSquaredDist) { - bestCentroid = c; - minSquaredDist = squareDist; - } - } - } - centroidSize[bestCentroid] += 1; - docCentroids[docID] = bestCentroid; - } - - IntArrayList unassignedCentroids = new IntArrayList(); - for (int c = 0; c < numCentroids; c++) { - if (centroidSize[c] == 0) { - unassignedCentroids.add(c); - } - } - if (unassignedCentroids.size() > 0) { - throwAwayAndSplitCentroids(random, vectors, centroids, docCentroids, centroidSize, unassignedCentroids); - } - assert Arrays.stream(centroidSize).sum() == 
vectors.size(); - } - - private final float[] kmeansPlusPlusScratch; - - KMeans( - FloatVectorValues vectors, - int numCentroids, - Random random, - KmeansInitializationMethod initializationMethod, - float[][] initCentroids, - int restarts, - int iters - ) { - this.vectors = vectors; - this.numVectors = vectors.size(); - this.numCentroids = numCentroids; - this.random = random; - this.initializationMethod = initializationMethod; - this.restarts = restarts; - this.iters = iters; - this.initCentroids = initCentroids; - this.kmeansPlusPlusScratch = initializationMethod == KmeansInitializationMethod.PLUS_PLUS ? new float[numVectors] : null; - } - - float[][] computeCentroids(boolean normalizeCenters) throws IOException { - // TODO can we make this off-heap, or reusable? This could be a big array - int[] vectorCentroids = new int[numVectors]; - double minSquaredDist = Double.MAX_VALUE; - double squaredDist = 0; - float[][] bestCentroids = null; - float[][] centroids = new float[numCentroids][vectors.dimension()]; - int restarts = this.restarts; - int numInitializedCentroids = 0; - // The user has given us a solid number of centroids to start of with, so skip restarts, fill in - // where we can, and refine - if (initCentroids != null && initCentroids.length > numCentroids / 2) { - int i = 0; - for (; i < Math.min(numCentroids, initCentroids.length); i++) { - System.arraycopy(initCentroids[i], 0, centroids[i], 0, initCentroids[i].length); - } - numInitializedCentroids = i; - restarts = 1; - } - - for (int restart = 0; restart < restarts; restart++) { - switch (initializationMethod) { - case FORGY -> initializeForgy(centroids, numInitializedCentroids); - case RESERVOIR_SAMPLING -> initializeReservoirSampling(centroids, numInitializedCentroids); - case PLUS_PLUS -> initializePlusPlus(centroids, numInitializedCentroids); - } - double prevSquaredDist = Double.MAX_VALUE; - int[] centroidSize = new int[centroids.length]; - for (int iter = 0; iter < iters; iter++) { - 
squaredDist = runKMeansStep(centroids, centroidSize, vectorCentroids, normalizeCenters); - // Check for convergence - if (prevSquaredDist <= (squaredDist + 1e-6)) { - break; - } - Arrays.fill(centroidSize, 0); - prevSquaredDist = squaredDist; - } - if (squaredDist < minSquaredDist) { - minSquaredDist = squaredDist; - // Copy out the best centroid as it might be overwritten by the next restart - bestCentroids = new float[centroids.length][]; - for (int i = 0; i < centroids.length; i++) { - bestCentroids[i] = ArrayUtil.copyArray(centroids[i]); - } - } - } - return bestCentroids; - } - - /** - * Initialize centroids using Forgy method: randomly select numCentroids vectors for initial - * centroids - */ - private void initializeForgy(float[][] initialCentroids, int fromCentroid) throws IOException { - if (fromCentroid >= numCentroids) { - return; - } - int numCentroids = this.numCentroids - fromCentroid; - Set selection = new HashSet<>(); - while (selection.size() < numCentroids) { - selection.add(random.nextInt(numVectors)); - } - int i = 0; - for (Integer selectedIdx : selection) { - float[] vector = vectors.vectorValue(selectedIdx); - System.arraycopy(vector, 0, initialCentroids[fromCentroid + i++], 0, vector.length); - } - } - - /** Initialize centroids using a reservoir sampling method */ - private void initializeReservoirSampling(float[][] initialCentroids, int fromCentroid) throws IOException { - if (fromCentroid >= numCentroids) { - return; - } - int numCentroids = this.numCentroids - fromCentroid; - for (int index = 0; index < numVectors; index++) { - float[] vector = vectors.vectorValue(index); - if (index < numCentroids) { - System.arraycopy(vector, 0, initialCentroids[index + fromCentroid], 0, vector.length); - } else if (random.nextDouble() < numCentroids * (1.0 / index)) { - int c = random.nextInt(numCentroids); - System.arraycopy(vector, 0, initialCentroids[c + fromCentroid], 0, vector.length); - } - } - } - - /** Initialize centroids using Kmeans++ 
method */ - private void initializePlusPlus(float[][] initialCentroids, int fromCentroid) throws IOException { - if (fromCentroid >= numCentroids) { - return; - } - // Choose the first centroid uniformly at random - int firstIndex = random.nextInt(numVectors); - float[] value = vectors.vectorValue(firstIndex); - System.arraycopy(value, 0, initialCentroids[fromCentroid], 0, value.length); - - // Store distances of each point to the nearest centroid - Arrays.fill(kmeansPlusPlusScratch, Float.MAX_VALUE); - - // Step 2 and 3: Select remaining centroids - for (int i = fromCentroid + 1; i < numCentroids; i++) { - // Update distances with the new centroid - double totalSum = 0; - for (int j = 0; j < numVectors; j++) { - // TODO: replace with RandomVectorScorer::score possible on quantized vectors - float dist = VectorUtil.squareDistance(vectors.vectorValue(j), initialCentroids[i - 1]); - if (dist < kmeansPlusPlusScratch[j]) { - kmeansPlusPlusScratch[j] = dist; - } - totalSum += kmeansPlusPlusScratch[j]; - } - - // Randomly select next centroid - double r = totalSum * random.nextDouble(); - double cumulativeSum = 0; - int nextCentroidIndex = 0; - for (int j = 0; j < numVectors; j++) { - cumulativeSum += kmeansPlusPlusScratch[j]; - if (cumulativeSum >= r && kmeansPlusPlusScratch[j] > 0) { - nextCentroidIndex = j; - break; - } - } - // Update centroid - value = vectors.vectorValue(nextCentroidIndex); - System.arraycopy(value, 0, initialCentroids[i], 0, value.length); - } - } - - /** - * Run kmeans step - * - * @param centroids centroids, new calculated centroids are written here - * @param docCentroids for each document which centroid it belongs to, results will be written - * here - * @param normalizeCentroids if centroids should be normalized; used for cosine similarity only - * @throws IOException if there is an error accessing vector values - */ - private double runKMeansStep(float[][] centroids, int[] centroidSize, int[] docCentroids, boolean normalizeCentroids) - 
throws IOException { - short numCentroids = (short) centroids.length; - assert Arrays.stream(centroidSize).allMatch(size -> size == 0); - float[][] newCentroids = new float[numCentroids][centroids[0].length]; - - double sumSquaredDist = 0; - for (int docID = 0; docID < vectors.size(); docID++) { - float[] vector = vectors.vectorValue(docID); - short bestCentroid = 0; - if (numCentroids > 1) { - float minSquaredDist = Float.MAX_VALUE; - for (short c = 0; c < numCentroids; c++) { - // TODO: replace with RandomVectorScorer::score possible on quantized vectors - float squareDist = VectorUtil.squareDistance(centroids[c], vector); - if (squareDist < minSquaredDist) { - bestCentroid = c; - minSquaredDist = squareDist; - } - } - sumSquaredDist += minSquaredDist; - } - - centroidSize[bestCentroid] += 1; - for (int dim = 0; dim < vector.length; dim++) { - newCentroids[bestCentroid][dim] += vector[dim]; - } - docCentroids[docID] = bestCentroid; - } - - IntArrayList unassignedCentroids = new IntArrayList(); - for (int c = 0; c < numCentroids; c++) { - if (centroidSize[c] > 0) { - for (int dim = 0; dim < newCentroids[c].length; dim++) { - centroids[c][dim] = newCentroids[c][dim] / centroidSize[c]; - } - } else { - unassignedCentroids.add(c); - } - } - if (unassignedCentroids.size() > 0) { - throwAwayAndSplitCentroids(random, vectors, centroids, docCentroids, centroidSize, unassignedCentroids); - } - if (normalizeCentroids) { - for (float[] centroid : centroids) { - VectorUtil.l2normalize(centroid, false); - } - } - assert Arrays.stream(centroidSize).sum() == vectors.size(); - return sumSquaredDist; - } - - static void throwAwayAndSplitCentroids( - Random random, - FloatVectorValues vectors, - float[][] centroids, - int[] docCentroids, - int[] centroidSize, - IntArrayList unassignedCentroidsIdxs - ) throws IOException { - IntObjectHashMap splitCentroids = new IntObjectHashMap<>(unassignedCentroidsIdxs.size()); - // used for splitting logic - int[] splitSizes = 
Arrays.copyOf(centroidSize, centroidSize.length); - // FAISS style algorithm for splitting - for (int i = 0; i < unassignedCentroidsIdxs.size(); i++) { - int toSplit; - for (toSplit = 0; true; toSplit = (toSplit + 1) % centroids.length) { - /* probability to pick this cluster for split */ - double p = (splitSizes[toSplit] - 1.0) / (float) (docCentroids.length - centroids.length); - float r = random.nextFloat(); - if (r < p) { - break; /* found our cluster to be split */ - } - } - int unassignedCentroidIdx = unassignedCentroidsIdxs.get(i); - // keep track of those that are split, this way we reassign docCentroids and fix up true size - // & centroids - splitCentroids.getOrDefault(toSplit, new IntArrayList()).add(unassignedCentroidIdx); - System.arraycopy(centroids[toSplit], 0, centroids[unassignedCentroidIdx], 0, centroids[unassignedCentroidIdx].length); - for (int dim = 0; dim < centroids[unassignedCentroidIdx].length; dim++) { - if (dim % 2 == 0) { - centroids[unassignedCentroidIdx][dim] *= (1 + EPS); - centroids[toSplit][dim] *= (1 - EPS); - } else { - centroids[unassignedCentroidIdx][dim] *= (1 - EPS); - centroids[toSplit][dim] *= (1 + EPS); - } - } - splitSizes[unassignedCentroidIdx] = splitSizes[toSplit] / 2; - splitSizes[toSplit] -= splitSizes[unassignedCentroidIdx]; - } - // now we need to reassign docCentroids and fix up true size & centroids - for (int i = 0; i < docCentroids.length; i++) { - int docCentroid = docCentroids[i]; - IntArrayList split = splitCentroids.get(docCentroid); - if (split != null) { - // we need to reassign this doc - int bestCentroid = docCentroid; - float bestDist = VectorUtil.squareDistance(centroids[docCentroid], vectors.vectorValue(i)); - for (int j = 0; j < split.size(); j++) { - int newCentroid = split.get(j); - float dist = VectorUtil.squareDistance(centroids[newCentroid], vectors.vectorValue(i)); - if (dist < bestDist) { - bestCentroid = newCentroid; - bestDist = dist; - } - } - if (bestCentroid != docCentroid) { - // we need 
to update the centroid size - centroidSize[docCentroid]--; - centroidSize[bestCentroid]++; - docCentroids[i] = (short) bestCentroid; - // we need to update the old and new centroid accounting for size as well - for (int dim = 0; dim < centroids[docCentroid].length; dim++) { - centroids[docCentroid][dim] -= vectors.vectorValue(i)[dim] / centroidSize[docCentroid]; - centroids[bestCentroid][dim] += vectors.vectorValue(i)[dim] / centroidSize[bestCentroid]; - } - } - } - } - } - - /** Kmeans initialization methods */ - public enum KmeansInitializationMethod { - FORGY, - RESERVOIR_SAMPLING, - PLUS_PLUS - } - - /** - * Results of KMeans clustering - * - * @param centroids the produced centroids - * @param centroidsSize for each centroid how many vectors belong to it - * @param vectorCentroids for each vector which centroid it belongs to - */ - public record Results(float[][] centroids, int[] centroidsSize, int[] vectorCentroids) {} -} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/FloatVectorValuesSlice.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/FloatVectorValuesSlice.java new file mode 100644 index 0000000000000..6da6ff196e93e --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/FloatVectorValuesSlice.java @@ -0,0 +1,52 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.codec.vectors.cluster; + +import org.apache.lucene.index.FloatVectorValues; + +import java.io.IOException; + +class FloatVectorValuesSlice extends FloatVectorValues { + + private final FloatVectorValues allValues; + private final int[] slice; + + FloatVectorValuesSlice(FloatVectorValues allValues, int[] slice) { + assert slice != null; + assert slice.length <= allValues.size(); + this.allValues = allValues; + this.slice = slice; + } + + @Override + public float[] vectorValue(int ord) throws IOException { + return this.allValues.vectorValue(this.slice[ord]); + } + + @Override + public int dimension() { + return this.allValues.dimension(); + } + + @Override + public int size() { + return slice.length; + } + + @Override + public int ordToDoc(int ord) { + return this.slice[ord]; + } + + @Override + public FloatVectorValues copy() throws IOException { + return new FloatVectorValuesSlice(this.allValues.copy(), this.slice); + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java new file mode 100644 index 0000000000000..fdd02a0cf752a --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java @@ -0,0 +1,197 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.codec.vectors.cluster; + +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.util.VectorUtil; + +import java.io.IOException; + +/** + * An implementation of the hierarchical k-means algorithm that better partitions data than naive k-means + */ +public class HierarchicalKMeans { + + static final int MAXK = 128; + static final int MAX_ITERATIONS_DEFAULT = 6; + static final int SAMPLES_PER_CLUSTER_DEFAULT = 256; + static final float DEFAULT_SOAR_LAMBDA = 1.0f; + + final int dimension; + final int maxIterations; + final int samplesPerCluster; + final int clustersPerNeighborhood; + final float soarLambda; + + public HierarchicalKMeans(int dimension) { + this(dimension, MAX_ITERATIONS_DEFAULT, SAMPLES_PER_CLUSTER_DEFAULT, MAXK, DEFAULT_SOAR_LAMBDA); + } + + HierarchicalKMeans(int dimension, int maxIterations, int samplesPerCluster, int clustersPerNeighborhood, float soarLambda) { + this.dimension = dimension; + this.maxIterations = maxIterations; + this.samplesPerCluster = samplesPerCluster; + this.clustersPerNeighborhood = clustersPerNeighborhood; + this.soarLambda = soarLambda; + } + + /** + * clusters or moreso partitions the set of vectors by starting with a rough number of partitions and then recursively refining those + * lastly a pass is made to adjust nearby neighborhoods and add an extra assignment per vector to nearby neighborhoods + * + * @param vectors the vectors to cluster + * @param targetSize the rough number of vectors that should be attached to a cluster + * @return the centroids and the vectors assignments and SOAR (spilled from nearby neighborhoods) assignments + * @throws IOException is thrown if vectors is inaccessible + */ + public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IOException { + + if (vectors.size() == 0) { + return new KMeansIntermediate(); + } + + // if we have a small number of vectors pick one and output that as the centroid + if 
(vectors.size() <= targetSize) { + float[] centroid = new float[dimension]; + System.arraycopy(vectors.vectorValue(0), 0, centroid, 0, dimension); + return new KMeansIntermediate(new float[][] { centroid }, new int[vectors.size()]); + } + + // partition the space + KMeansIntermediate kMeansIntermediate = clusterAndSplit(vectors, targetSize); + if (kMeansIntermediate.centroids().length > 1 && kMeansIntermediate.centroids().length < vectors.size()) { + float f = Math.min((float) samplesPerCluster / targetSize, 1.0f); + int localSampleSize = (int) (f * vectors.size()); + KMeansLocal kMeansLocal = new KMeansLocal(localSampleSize, maxIterations, clustersPerNeighborhood, DEFAULT_SOAR_LAMBDA); + kMeansLocal.cluster(vectors, kMeansIntermediate, true); + } + + return kMeansIntermediate; + } + + KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int targetSize) throws IOException { + if (vectors.size() <= targetSize) { + return new KMeansIntermediate(); + } + + int k = Math.clamp((int) ((vectors.size() + targetSize / 2.0f) / (float) targetSize), 2, MAXK); + int m = Math.min(k * samplesPerCluster, vectors.size()); + + // TODO: instead of creating a sub-cluster assignments reuse the parent array each time + int[] assignments = new int[vectors.size()]; + + KMeansLocal kmeans = new KMeansLocal(m, maxIterations); + float[][] centroids = KMeansLocal.pickInitialCentroids(vectors, k); + KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids); + kmeans.cluster(vectors, kMeansIntermediate); + + // TODO: consider adding cluster size counts to the kmeans algo + // handle assignment here so we can track distance and cluster size + int[] centroidVectorCount = new int[centroids.length]; + float[][] nextCentroids = new float[centroids.length][dimension]; + for (int i = 0; i < vectors.size(); i++) { + float smallest = Float.MAX_VALUE; + int centroidIdx = -1; + float[] vector = vectors.vectorValue(i); + for (int j = 0; j < centroids.length; j++) { + 
float[] centroid = centroids[j]; + float d = VectorUtil.squareDistance(vector, centroid); + if (d < smallest) { + smallest = d; + centroidIdx = j; + } + } + centroidVectorCount[centroidIdx]++; + for (int j = 0; j < dimension; j++) { + nextCentroids[centroidIdx][j] += vector[j]; + } + assignments[i] = centroidIdx; + } + + // update centroids based on assignments of all vectors + for (int i = 0; i < centroids.length; i++) { + if (centroidVectorCount[i] > 0) { + for (int j = 0; j < dimension; j++) { + centroids[i][j] = nextCentroids[i][j] / centroidVectorCount[i]; + } + } + } + + int effectiveK = 0; + for (int i = 0; i < centroidVectorCount.length; i++) { + if (centroidVectorCount[i] > 0) { + effectiveK++; + } + } + + kMeansIntermediate = new KMeansIntermediate(centroids, assignments, vectors::ordToDoc); + + if (effectiveK == 1) { + return kMeansIntermediate; + } + + for (int c = 0; c < centroidVectorCount.length; c++) { + // Recurse for each cluster which is larger than targetSize + // Give ourselves 30% margin for the target size + if (100 * centroidVectorCount[c] > 134 * targetSize) { + FloatVectorValues sample = createClusterSlice(centroidVectorCount[c], c, vectors, assignments); + + // TODO: consider iterative here instead of recursive + // recursive call to build out the sub partitions around this centroid c + // subsequently reconcile and flatten the space of all centroids and assignments into one structure we can return + updateAssignmentsWithRecursiveSplit(kMeansIntermediate, c, clusterAndSplit(sample, targetSize)); + } + } + + return kMeansIntermediate; + } + + static FloatVectorValues createClusterSlice(int clusterSize, int cluster, FloatVectorValues vectors, int[] assignments) { + int[] slice = new int[clusterSize]; + int idx = 0; + for (int i = 0; i < assignments.length; i++) { + if (assignments[i] == cluster) { + slice[idx] = i; + idx++; + } + } + + return new FloatVectorValuesSlice(vectors, slice); + } + + void 
updateAssignmentsWithRecursiveSplit(KMeansIntermediate current, int cluster, KMeansIntermediate subPartitions) { + int orgCentroidsSize = current.centroids().length; + int newCentroidsSize = current.centroids().length + subPartitions.centroids().length - 1; + + // update based on the outcomes from the split clusters recursion + if (subPartitions.centroids().length > 1) { + float[][] newCentroids = new float[newCentroidsSize][dimension]; + System.arraycopy(current.centroids(), 0, newCentroids, 0, current.centroids().length); + + // replace the original cluster + int origCentroidOrd = 0; + newCentroids[cluster] = subPartitions.centroids()[0]; + + // append the remainder + System.arraycopy(subPartitions.centroids(), 1, newCentroids, current.centroids().length, subPartitions.centroids().length - 1); + + current.setCentroids(newCentroids); + + for (int i = 0; i < subPartitions.assignments().length; i++) { + // this is a new centroid that was added, and so we'll need to remap it + if (subPartitions.assignments()[i] != origCentroidOrd) { + int parentOrd = subPartitions.ordToDoc(i); + assert current.assignments()[parentOrd] == cluster; + current.assignments()[parentOrd] = subPartitions.assignments()[i] + orgCentroidsSize - 1; + } + } + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansIntermediate.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansIntermediate.java new file mode 100644 index 0000000000000..75caa5c7d3281 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansIntermediate.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.cluster; + +import org.apache.lucene.util.hnsw.IntToIntFunction; + +/** + * Intermediate object for clustering (partitioning) a set of vectors + */ +class KMeansIntermediate extends KMeansResult { + private final IntToIntFunction assignmentOrds; + + private KMeansIntermediate(float[][] centroids, int[] assignments, IntToIntFunction assignmentOrds, int[] soarAssignments) { + super(centroids, assignments, soarAssignments); + assert assignmentOrds != null; + this.assignmentOrds = assignmentOrds; + } + + KMeansIntermediate(float[][] centroids, int[] assignments, IntToIntFunction assignmentOrdinals) { + this(centroids, assignments, assignmentOrdinals, new int[0]); + } + + KMeansIntermediate() { + this(new float[0][0], new int[0], i -> i, new int[0]); + } + + KMeansIntermediate(float[][] centroids) { + this(centroids, new int[0], i -> i, new int[0]); + } + + KMeansIntermediate(float[][] centroids, int[] assignments) { + this(centroids, assignments, i -> i, new int[0]); + } + + public int ordToDoc(int ord) { + return assignmentOrds.apply(ord); + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java new file mode 100644 index 0000000000000..415a082c5a2b1 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java @@ -0,0 +1,306 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.cluster; + +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.simdvec.ESVectorUtil; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; + +/** + * k-means implementation specific to the needs of the {@link HierarchicalKMeans} algorithm that deals specifically + * with finalizing nearby pre-established clusters and generate + * SOAR assignments + */ +class KMeansLocal { + + final int sampleSize; + final int maxIterations; + final int clustersPerNeighborhood; + final float soarLambda; + + KMeansLocal(int sampleSize, int maxIterations, int clustersPerNeighborhood, float soarLambda) { + this.sampleSize = sampleSize; + this.maxIterations = maxIterations; + this.clustersPerNeighborhood = clustersPerNeighborhood; + this.soarLambda = soarLambda; + } + + KMeansLocal(int sampleSize, int maxIterations) { + this(sampleSize, maxIterations, -1, -1f); + } + + /** + * uses a Reservoir Sampling approach to picking the initial centroids which are subsequently expected + * to be used by a clustering algorithm + * + * @param vectors used to pick an initial set of random centroids + * @param centroidCount the total number of centroids to pick + * @return randomly selected centroids that are the min of centroidCount and sampleSize + * @throws IOException is thrown if vectors is inaccessible + */ + static float[][] pickInitialCentroids(FloatVectorValues vectors, int centroidCount) throws IOException { + Random random = new Random(42L); + int 
centroidsSize = Math.min(vectors.size(), centroidCount); + float[][] centroids = new float[centroidsSize][vectors.dimension()]; + for (int i = 0; i < vectors.size(); i++) { + float[] vector; + if (i < centroidCount) { + vector = vectors.vectorValue(i); + System.arraycopy(vector, 0, centroids[i], 0, vector.length); + } else if (random.nextDouble() < centroidCount * (1.0 / i)) { + int c = random.nextInt(centroidCount); + vector = vectors.vectorValue(i); + System.arraycopy(vector, 0, centroids[c], 0, vector.length); + } + } + return centroids; + } + + private boolean stepLloyd( + FloatVectorValues vectors, + float[][] centroids, + float[][] nextCentroids, + int[] assignments, + int sampleSize, + List neighborhoods + ) throws IOException { + boolean changed = false; + int dim = vectors.dimension(); + int[] centroidCounts = new int[centroids.length]; + + for (int i = 0; i < nextCentroids.length; i++) { + Arrays.fill(nextCentroids[i], 0.0f); + } + + for (int i = 0; i < sampleSize; i++) { + float[] vector = vectors.vectorValue(i); + int[] neighborOffsets = null; + int centroidIdx = -1; + if (neighborhoods != null) { + neighborOffsets = neighborhoods.get(assignments[i]); + centroidIdx = assignments[i]; + } + int bestCentroidOffset = getBestCentroidOffset(centroids, vector, centroidIdx, neighborOffsets); + if (assignments[i] != bestCentroidOffset) { + changed = true; + } + assignments[i] = bestCentroidOffset; + centroidCounts[bestCentroidOffset]++; + for (short d = 0; d < dim; d++) { + nextCentroids[bestCentroidOffset][d] += vector[d]; + } + } + + for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) { + if (centroidCounts[clusterIdx] > 0) { + float countF = (float) centroidCounts[clusterIdx]; + for (short d = 0; d < dim; d++) { + centroids[clusterIdx][d] = nextCentroids[clusterIdx][d] / countF; + } + } + } + + return changed; + } + + int getBestCentroidOffset(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) { + int 
bestCentroidOffset = centroidIdx; + float minDsq; + if (centroidIdx > 0 && centroidIdx < centroids.length) { + minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]); + } else { + minDsq = Float.MAX_VALUE; + } + + int k = 0; + for (int j = 0; j < centroids.length; j++) { + if (centroidOffsets == null || j == centroidOffsets[k]) { + float dsq = VectorUtil.squareDistance(vector, centroids[j]); + if (dsq < minDsq) { + minDsq = dsq; + bestCentroidOffset = j; + } + } + } + return bestCentroidOffset; + } + + private void computeNeighborhoods(float[][] centers, List neighborhoods, int clustersPerNeighborhood) { + int k = neighborhoods.size(); + + if (k == 0 || clustersPerNeighborhood <= 0) { + return; + } + + List neighborQueues = new ArrayList<>(k); + for (int i = 0; i < k; i++) { + neighborQueues.add(new NeighborQueue(clustersPerNeighborhood, true)); + } + for (int i = 0; i < k - 1; i++) { + for (int j = i + 1; j < k; j++) { + float dsq = VectorUtil.squareDistance(centers[i], centers[j]); + neighborQueues.get(j).insertWithOverflow(i, dsq); + neighborQueues.get(i).insertWithOverflow(j, dsq); + } + } + + for (int i = 0; i < k; i++) { + NeighborQueue queue = neighborQueues.get(i); + int neighborCount = queue.size(); + int[] neighbors = new int[neighborCount]; + queue.consumeNodes(neighbors); + neighborhoods.set(i, neighbors); + } + } + + private int[] assignSpilled(FloatVectorValues vectors, List neighborhoods, float[][] centroids, int[] assignments) + throws IOException { + // SOAR uses an adjusted distance for assigning spilled documents which is + // given by: + // + // soar(x, c) = ||x - c||^2 + lambda * ((x - c_1)^t (x - c))^2 / ||x - c_1||^2 + // + // Here, x is the document, c is the nearest centroid, and c_1 is the first + // centroid the document was assigned to. The document is assigned to the + // cluster with the smallest soar(x, c). 
+ + int[] spilledAssignments = new int[assignments.length]; + + float[] diffs = new float[vectors.dimension()]; + for (int i = 0; i < vectors.size(); i++) { + float[] vector = vectors.vectorValue(i); + + int currAssignment = assignments[i]; + float[] currentCentroid = centroids[currAssignment]; + for (short j = 0; j < vectors.dimension(); j++) { + float diff = vector[j] - currentCentroid[j]; + diffs[j] = diff; + } + + // TODO: cache these? + // float vectorCentroidDist = assignmentDistances[i]; + float vectorCentroidDist = VectorUtil.squareDistance(vector, currentCentroid); + + int bestAssignment = -1; + float minSoar = Float.MAX_VALUE; + assert neighborhoods.get(currAssignment) != null; + for (int neighbor : neighborhoods.get(currAssignment)) { + if (neighbor == currAssignment) { + continue; + } + float[] neighborCentroid = centroids[neighbor]; + float soar = distanceSoar(diffs, vector, neighborCentroid, vectorCentroidDist); + if (soar < minSoar) { + bestAssignment = neighbor; + minSoar = soar; + } + } + + spilledAssignments[i] = bestAssignment; + } + + return spilledAssignments; + } + + private float distanceSoar(float[] residual, float[] vector, float[] centroid, float rnorm) { + // TODO: combine these to be more efficient + float dsq = VectorUtil.squareDistance(vector, centroid); + float rproj = ESVectorUtil.soarResidual(vector, centroid, residual); + return dsq + soarLambda * rproj * rproj / rnorm; + } + + /** + * cluster using a lloyd k-means algorithm that is not neighbor aware + * + * @param vectors the vectors to cluster + * @param kMeansIntermediate the output object to populate which minimally includes centroids, + * but may include assignments and soar assignments as well; care should be taken in + * passing in a valid output object with a centroids array that is the size of centroids expected + * @throws IOException is thrown if vectors is inaccessible + */ + void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) throws 
IOException { + cluster(vectors, kMeansIntermediate, false); + } + + /** + * cluster using a lloyd kmeans algorithm that also considers prior clustered neighborhoods when adjusting centroids + * this also is used to generate the neighborhood aware additional (SOAR) assignments + * + * @param vectors the vectors to cluster + * @param kMeansIntermediate the output object to populate which minimally includes centroids, + * the prior assignments of the given vectors; care should be taken in + * passing in a valid output object with a centroids array that is the size of centroids expected + * and assignments that are the same size as the vectors. The SOAR assignments are overwritten by this operation. + * @param neighborAware whether nearby neighboring centroids and their vectors should be used to update the centroid positions, + * implies SOAR assignments + * @throws IOException is thrown if vectors is inaccessible + */ + void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, boolean neighborAware) throws IOException { + float[][] centroids = kMeansIntermediate.centroids(); + + List neighborhoods = null; + if (neighborAware) { + int k = centroids.length; + neighborhoods = new ArrayList<>(k); + for (int i = 0; i < k; ++i) { + neighborhoods.add(null); + } + computeNeighborhoods(centroids, neighborhoods, clustersPerNeighborhood); + } + cluster(vectors, kMeansIntermediate, neighborhoods); + if (neighborAware && clustersPerNeighborhood > 0) { + int[] assignments = kMeansIntermediate.assignments(); + assert assignments != null; + assert assignments.length == vectors.size(); + kMeansIntermediate.setSoarAssignments(assignSpilled(vectors, neighborhoods, centroids, assignments)); + } + } + + void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, List neighborhoods) throws IOException { + float[][] centroids = kMeansIntermediate.centroids(); + int k = centroids.length; + int n = vectors.size(); + + if (k == 1 || k >= n) { + return; 
+ } + + int[] assignments = new int[n]; + float[][] nextCentroids = new float[centroids.length][vectors.dimension()]; + for (int i = 0; i < maxIterations; i++) { + if (stepLloyd(vectors, centroids, nextCentroids, assignments, sampleSize, neighborhoods) == false) { + break; + } + } + stepLloyd(vectors, centroids, nextCentroids, assignments, vectors.size(), neighborhoods); + } + + /** + * helper that calls {@link KMeansLocal#cluster(FloatVectorValues, KMeansIntermediate)} given a set of initialized centroids, + * this call is not neighbor aware + * + * @param vectors the vectors to cluster + * @param centroids the initialized centroids to be shifted using k-means + * @param sampleSize the subset of vectors to use when shifting centroids + * @param maxIterations the max iterations to shift centroids + */ + public static void cluster(FloatVectorValues vectors, float[][] centroids, int sampleSize, int maxIterations) throws IOException { + KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids); + KMeansLocal kMeans = new KMeansLocal(sampleSize, maxIterations); + kMeans.cluster(vectors, kMeansIntermediate); + } + +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansResult.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansResult.java new file mode 100644 index 0000000000000..5c2f4afb03f1a --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansResult.java @@ -0,0 +1,48 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
/**
 * Holds what clustering (partitioning) a set of vectors produced: the centroids, the
 * assignments array and the SOAR assignments array.
 * <p>
 * Arrays are stored and handed out by reference; no copies are made.
 */
public class KMeansResult {
    private float[][] centroids;
    private final int[] assignments;
    private int[] soarAssignments;

    /**
     * @param centroids       the centroids, must not be null
     * @param assignments     the assignments, must not be null
     * @param soarAssignments the SOAR assignments, must not be null
     */
    KMeansResult(float[][] centroids, int[] assignments, int[] soarAssignments) {
        assert centroids != null && assignments != null && soarAssignments != null;
        this.soarAssignments = soarAssignments;
        this.assignments = assignments;
        this.centroids = centroids;
    }

    /** @return the current centroids array */
    public float[][] centroids() {
        return this.centroids;
    }

    /** @return the assignments array given at construction */
    public int[] assignments() {
        return this.assignments;
    }

    /** @return the current SOAR assignments array */
    public int[] soarAssignments() {
        return this.soarAssignments;
    }

    /** Swaps in a different centroids array. */
    void setCentroids(float[][] centroids) {
        this.centroids = centroids;
    }

    /** Swaps in a different SOAR assignments array. */
    void setSoarAssignments(int[] soarAssignments) {
        this.soarAssignments = soarAssignments;
    }
}
*/ - -package org.elasticsearch.index.codec.vectors; +package org.elasticsearch.index.codec.vectors.cluster; import org.apache.lucene.util.LongHeap; import org.apache.lucene.util.NumericUtils; diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/KMeansTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/KMeansTests.java deleted file mode 100644 index 001847a521ba5..0000000000000 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/KMeansTests.java +++ /dev/null @@ -1,154 +0,0 @@ -/* - * @notice - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * Modifications copyright (C) 2025 Elasticsearch B.V. 
- */ - -package org.elasticsearch.index.codec.vectors; - -import org.apache.lucene.index.FloatVectorValues; -import org.apache.lucene.index.VectorSimilarityFunction; -import org.elasticsearch.test.ESTestCase; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - -public class KMeansTests extends ESTestCase { - - public void testKMeansAPI() throws IOException { - int nClusters = random().nextInt(1, 10); - int nVectors = random().nextInt(nClusters * 100, nClusters * 200); - int dims = random().nextInt(2, 20); - int randIdx = random().nextInt(VectorSimilarityFunction.values().length); - VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.values()[randIdx]; - FloatVectorValues vectors = generateData(nVectors, dims, nClusters); - - // default case - { - KMeans.Results results = KMeans.cluster(vectors, similarityFunction, nClusters); - assertResults(results, nClusters, nVectors, true); - assertEquals(nClusters, results.centroids().length); - assertEquals(nClusters, results.centroidsSize().length); - assertEquals(nVectors, results.vectorCentroids().length); - } - // expert case - { - boolean assignCentroidsToVectors = random().nextBoolean(); - int randIdx2 = random().nextInt(KMeans.KmeansInitializationMethod.values().length); - KMeans.KmeansInitializationMethod initializationMethod = KMeans.KmeansInitializationMethod.values()[randIdx2]; - int restarts = random().nextInt(1, 6); - int iters = random().nextInt(1, 10); - int sampleSize = random().nextInt(10, nVectors * 2); - - KMeans.Results results = KMeans.cluster( - vectors, - nClusters, - assignCentroidsToVectors, - random().nextLong(), - initializationMethod, - null, - similarityFunction == VectorSimilarityFunction.COSINE, - restarts, - iters, - sampleSize - ); - assertResults(results, nClusters, nVectors, assignCentroidsToVectors); - } - } - - private void assertResults(KMeans.Results results, int nClusters, int nVectors, boolean assignCentroidsToVectors) { - 
assertEquals(nClusters, results.centroids().length); - if (assignCentroidsToVectors) { - assertEquals(nClusters, results.centroidsSize().length); - assertEquals(nVectors, results.vectorCentroids().length); - int[] centroidsSize = new int[nClusters]; - for (int i = 0; i < nVectors; i++) { - centroidsSize[results.vectorCentroids()[i]]++; - } - assertArrayEquals(centroidsSize, results.centroidsSize()); - } else { - assertNull(results.vectorCentroids()); - } - } - - public void testKMeansSpecialCases() throws IOException { - { - // nClusters > nVectors - int nClusters = 20; - int nVectors = 10; - FloatVectorValues vectors = generateData(nVectors, 5, nClusters); - KMeans.Results results = KMeans.cluster(vectors, VectorSimilarityFunction.EUCLIDEAN, nClusters); - // assert that we get 1 centroid, as nClusters will be adjusted - assertEquals(1, results.centroids().length); - assertEquals(nVectors, results.vectorCentroids().length); - } - { - // small sample size - int sampleSize = 2; - int nClusters = 2; - int nVectors = 300; - FloatVectorValues vectors = generateData(nVectors, 5, nClusters); - KMeans.KmeansInitializationMethod initializationMethod = KMeans.KmeansInitializationMethod.PLUS_PLUS; - KMeans.Results results = KMeans.cluster( - vectors, - nClusters, - true, - random().nextLong(), - initializationMethod, - null, - false, - 1, - 2, - sampleSize - ); - assertResults(results, nClusters, nVectors, true); - } - } - - public void testKMeansSAllZero() throws IOException { - int nClusters = 10; - List vectors = new ArrayList<>(); - for (int i = 0; i < 1000; i++) { - float[] vector = new float[5]; - vectors.add(vector); - } - KMeans.Results results = KMeans.cluster(FloatVectorValues.fromFloats(vectors, 5), VectorSimilarityFunction.EUCLIDEAN, nClusters); - assertResults(results, nClusters, 1000, true); - } - - private static FloatVectorValues generateData(int nSamples, int nDims, int nClusters) { - List vectors = new ArrayList<>(nSamples); - float[][] centroids = new 
float[nClusters][nDims]; - // Generate random centroids - for (int i = 0; i < nClusters; i++) { - for (int j = 0; j < nDims; j++) { - centroids[i][j] = random().nextFloat() * 100; - } - } - // Generate data points around centroids - for (int i = 0; i < nSamples; i++) { - int cluster = random().nextInt(nClusters); - float[] vector = new float[nDims]; - for (int j = 0; j < nDims; j++) { - vector[j] = centroids[cluster][j] + random().nextFloat() * 10 - 5; - } - vectors.add(vector); - } - return FloatVectorValues.fromFloats(vectors, nDims); - } -} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeansTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeansTests.java new file mode 100644 index 0000000000000..4c481ca4a5f36 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeansTests.java @@ -0,0 +1,70 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ +package org.elasticsearch.index.codec.vectors.cluster; + +import org.apache.lucene.index.FloatVectorValues; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +public class HierarchicalKMeansTests extends ESTestCase { + + public void testHKmeans() throws IOException { + int nClusters = random().nextInt(1, 10); + int nVectors = random().nextInt(nClusters * 100, nClusters * 200); + int dims = random().nextInt(2, 20); + int sampleSize = random().nextInt(100, nVectors + 1); + int maxIterations = random().nextInt(0, 100); + int clustersPerNeighborhood = random().nextInt(0, 512); + float soarLambda = random().nextFloat(0.5f, 1.5f); + FloatVectorValues vectors = generateData(nVectors, dims, nClusters); + + int targetSize = (int) ((float) nVectors / (float) nClusters); + HierarchicalKMeans hkmeans = new HierarchicalKMeans(dims, maxIterations, sampleSize, clustersPerNeighborhood, soarLambda); + + KMeansResult result = hkmeans.cluster(vectors, targetSize); + + float[][] centroids = result.centroids(); + int[] assignments = result.assignments(); + int[] soarAssignments = result.soarAssignments(); + + assertEquals(nClusters, centroids.length, 6); + assertEquals(nVectors, assignments.length); + if (centroids.length > 1 && clustersPerNeighborhood > 0) { + assertEquals(nVectors, soarAssignments.length); + // verify no duplicates exist + for (int i = 0; i < assignments.length; i++) { + assert assignments[i] != soarAssignments[i]; + } + } + } + + private static FloatVectorValues generateData(int nSamples, int nDims, int nClusters) { + List vectors = new ArrayList<>(nSamples); + float[][] centroids = new float[nClusters][nDims]; + // Generate random centroids + for (int i = 0; i < nClusters; i++) { + for (int j = 0; j < nDims; j++) { + centroids[i][j] = random().nextFloat() * 100; + } + } + // Generate data points around centroids + for (int i = 0; i < nSamples; i++) { + int cluster = 
random().nextInt(nClusters); + float[] vector = new float[nDims]; + for (int j = 0; j < nDims; j++) { + vector[j] = centroids[cluster][j] + random().nextFloat() * 10 - 5; + } + vectors.add(vector); + } + return FloatVectorValues.fromFloats(vectors, nDims); + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocalTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocalTests.java new file mode 100644 index 0000000000000..c0a0ca8341129 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocalTests.java @@ -0,0 +1,127 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */
+
+package org.elasticsearch.index.codec.vectors.cluster;
+
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.util.VectorUtil;
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Tests for {@link KMeansLocal}: an initial clustering is computed with the static
+ * {@code cluster} method, every vector is assigned to its closest centroid, and the
+ * resulting {@link KMeansIntermediate} is then refined by a {@link KMeansLocal} instance.
+ */
+public class KMeansLocalTests extends ESTestCase {
+
+    /**
+     * Clusters randomly generated data with randomized parameters and checks that the number
+     * of centroids is unchanged and that {@code soarAssignments()} is populated afterwards.
+     */
+    public void testKMeansNeighbors() throws IOException {
+        // The order of the random() calls is significant for reproducing a run from its seed.
+        int nClusters = random().nextInt(1, 10);
+        int nVectors = random().nextInt(nClusters * 100, nClusters * 200);
+        int dims = random().nextInt(2, 20);
+        int sampleSize = random().nextInt(100, nVectors + 1);
+        int maxIterations = random().nextInt(0, 100);
+        int clustersPerNeighborhood = random().nextInt(0, 512);
+        float soarLambda = random().nextFloat(0.5f, 1.5f);
+        FloatVectorValues vectors = generateData(nVectors, dims, nClusters);
+
+        float[][] centroids = KMeansLocal.pickInitialCentroids(vectors, nClusters);
+        KMeansLocal.cluster(vectors, centroids, sampleSize, maxIterations);
+
+        int[] assignments = assignToNearestCentroid(vectors, centroids);
+        // Vector ordinals map onto themselves, so the ordinal mapping is the identity.
+        KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, i -> i);
+        KMeansLocal kMeansLocal = new KMeansLocal(sampleSize, maxIterations, clustersPerNeighborhood, soarLambda);
+        // NOTE(review): the meaning of the boolean flag is not visible from this test; confirm against KMeansLocal.
+        kMeansLocal.cluster(vectors, kMeansIntermediate, true);
+
+        assertEquals(nClusters, centroids.length);
+        assertNotNull(kMeansIntermediate.soarAssignments());
+    }
+
+    /**
+     * Runs the same pipeline as {@link #testKMeansNeighbors()} on 1000 all-zero vectors of
+     * dimension 5 and additionally checks that no centroid component exceeds {@code 1e-7}.
+     */
+    public void testKMeansNeighborsAllZero() throws IOException {
+        int nClusters = 10;
+        int maxIterations = 10;
+        int clustersPerNeighborhood = 128;
+        float soarLambda = 1.0f;
+        int nVectors = 1000;
+        int dims = 5;
+        List<float[]> vectors = new ArrayList<>(nVectors);
+        for (int i = 0; i < nVectors; i++) {
+            // A freshly allocated float[] is all zeros.
+            vectors.add(new float[dims]);
+        }
+        int sampleSize = vectors.size();
+        FloatVectorValues fvv = FloatVectorValues.fromFloats(vectors, dims);
+
+        float[][] centroids = KMeansLocal.pickInitialCentroids(fvv, nClusters);
+        KMeansLocal.cluster(fvv, centroids, sampleSize, maxIterations);
+
+        int[] assignments = assignToNearestCentroid(fvv, centroids);
+        // Vector ordinals map onto themselves, so the ordinal mapping is the identity.
+        KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, i -> i);
+        KMeansLocal kMeansLocal = new KMeansLocal(sampleSize, maxIterations, clustersPerNeighborhood, soarLambda);
+        // NOTE(review): the meaning of the boolean flag is not visible from this test; confirm against KMeansLocal.
+        kMeansLocal.cluster(fvv, kMeansIntermediate, true);
+
+        assertEquals(nClusters, centroids.length);
+        assertNotNull(kMeansIntermediate.soarAssignments());
+        for (float[] centroid : centroids) {
+            for (float v : centroid) {
+                // Only components above 1e-7 reach the assertion (which then fails, as the
+                // delta is smaller than the guard); negative and NaN components are not checked.
+                if (v > 0.0000001f) {
+                    assertEquals(0.0f, v, 0.00000001f);
+                }
+            }
+        }
+    }
+
+    /**
+     * Assigns every vector to the centroid with the smallest squared distance to it.
+     * The comparison is strict, so on a tie the lowest centroid index wins.
+     *
+     * @param vectors   the vectors to assign, addressed by ordinal
+     * @param centroids the candidate centroids
+     * @return an array indexed by vector ordinal holding the index of the closest centroid,
+     *         or {@code -1} where no centroid had a distance below {@link Float#MAX_VALUE}
+     * @throws IOException if reading a vector value fails
+     */
+    private static int[] assignToNearestCentroid(FloatVectorValues vectors, float[][] centroids) throws IOException {
+        int[] assignments = new int[vectors.size()];
+        for (int i = 0; i < assignments.length; i++) {
+            // Look the vector up once per ordinal rather than once per centroid.
+            float[] vector = vectors.vectorValue(i);
+            float minDist = Float.MAX_VALUE;
+            int ord = -1;
+            for (int j = 0; j < centroids.length; j++) {
+                float dist = VectorUtil.squareDistance(vector, centroids[j]);
+                if (dist < minDist) {
+                    minDist = dist;
+                    ord = j;
+                }
+            }
+            assignments[i] = ord;
+        }
+        return assignments;
+    }
+
+    /**
+     * Generates clustered test data: {@code nClusters} random centers with components in
+     * {@code [0, 100)}, and {@code nSamples} points that each offset a randomly chosen
+     * center by a per-component value in {@code [-5, 5)}.
+     *
+     * @param nSamples  number of vectors to generate
+     * @param nDims     dimension of every vector
+     * @param nClusters number of centers the vectors are scattered around
+     * @return the generated vectors wrapped as {@link FloatVectorValues}
+     */
+    private static FloatVectorValues generateData(int nSamples, int nDims, int nClusters) {
+        List<float[]> vectors = new ArrayList<>(nSamples);
+        float[][] centroids = new float[nClusters][nDims];
+        // Generate random centroids
+        for (int i = 0; i < nClusters; i++) {
+            for (int j = 0; j < nDims; j++) {
+                centroids[i][j] = random().nextFloat() * 100;
+            }
+        }
+        // Generate data points around centroids
+        for (int i = 0; i < nSamples; i++) {
+            int cluster = random().nextInt(nClusters);
+            float[] vector = new float[nDims];
+            for (int j = 0; j < nDims; j++) {
+                vector[j] = centroids[cluster][j] + random().nextFloat() * 10 - 5;
+            }
+            vectors.add(vector);
+        }
+        return FloatVectorValues.fromFloats(vectors, nDims);
+    }
+}
diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/NeighborQueueTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/NeighborQueueTests.java
similarity index 97%
rename from server/src/test/java/org/elasticsearch/index/codec/vectors/NeighborQueueTests.java
rename to server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/NeighborQueueTests.java
index 7238f58d746dc..56c86b4ef6bc9 100644
--- a/server/src/test/java/org/elasticsearch/index/codec/vectors/NeighborQueueTests.java
+++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/NeighborQueueTests.java
@@ -14,10 +14,10 @@
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
- * Modifications copyright (C) 2025 Elasticsearch B.V.
+ *
+ * Modifications copyright (C) 2024 Elasticsearch B.V.
  */
-
-package org.elasticsearch.index.codec.vectors;
+package org.elasticsearch.index.codec.vectors.cluster;
 
 import org.elasticsearch.test.ESTestCase;