Skip to content

Commit d81660d

Browse files
committed
ES|QL random sampling (elastic#125570)
1 parent fd2bac4 commit d81660d

File tree

59 files changed

+4591
-2525
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

59 files changed

+4591
-2525
lines changed

docs/changelog/125570.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 125570
2+
summary: ES|QL random sampling
3+
area: Machine Learning
4+
type: feature
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ static TransportVersion def(int id) {
242242
public static final TransportVersion NONE_CHUNKING_STRATEGY_8_19 = def(8_841_0_49);
243243
public static final TransportVersion IDP_CUSTOM_SAML_ATTRIBUTES_ALLOW_LIST_8_19 = def(8_841_0_50);
244244
public static final TransportVersion SETTINGS_IN_DATA_STREAMS_8_19 = def(8_841_0_51);
245+
public static final TransportVersion RANDOM_SAMPLER_QUERY_BUILDER_8_19 = def(8_841_0_52);
245246

246247
/*
247248
* STOP! READ THIS FIRST! No, really,

server/src/main/java/org/elasticsearch/search/SearchModule.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@
137137
import org.elasticsearch.search.aggregations.bucket.sampler.UnmappedSampler;
138138
import org.elasticsearch.search.aggregations.bucket.sampler.random.InternalRandomSampler;
139139
import org.elasticsearch.search.aggregations.bucket.sampler.random.RandomSamplerAggregationBuilder;
140+
import org.elasticsearch.search.aggregations.bucket.sampler.random.RandomSamplingQueryBuilder;
140141
import org.elasticsearch.search.aggregations.bucket.terms.DoubleTerms;
141142
import org.elasticsearch.search.aggregations.bucket.terms.LongRareTerms;
142143
import org.elasticsearch.search.aggregations.bucket.terms.LongTerms;
@@ -1209,6 +1210,9 @@ private void registerQueryParsers(List<SearchPlugin> plugins) {
12091210
registerQuery(new QuerySpec<>(ExactKnnQueryBuilder.NAME, ExactKnnQueryBuilder::new, parser -> {
12101211
throw new IllegalArgumentException("[exact_knn] queries cannot be provided directly");
12111212
}));
1213+
registerQuery(
1214+
new QuerySpec<>(RandomSamplingQueryBuilder.NAME, RandomSamplingQueryBuilder::new, RandomSamplingQueryBuilder::fromXContent)
1215+
);
12121216

12131217
registerFromPlugin(plugins, SearchPlugin::getQueries, this::registerQuery);
12141218

server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplingQuery.java

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,34 @@ public final class RandomSamplingQuery extends Query {
4343
* can be generated
4444
*/
4545
public RandomSamplingQuery(double p, int seed, int hash) {
46-
if (p <= 0.0 || p >= 1.0) {
47-
throw new IllegalArgumentException("RandomSampling probability must be between 0.0 and 1.0, was [" + p + "]");
48-
}
46+
checkProbabilityRange(p);
4947
this.p = p;
5048
this.seed = seed;
5149
this.hash = hash;
5250
}
5351

52+
/**
53+
* Verifies that the probability is within the (0.0, 1.0) range.
54+
* @throws IllegalArgumentException in case of an invalid probability.
55+
*/
56+
public static void checkProbabilityRange(double p) throws IllegalArgumentException {
57+
if (p <= 0.0 || p >= 1.0) {
58+
throw new IllegalArgumentException("RandomSampling probability must be strictly between 0.0 and 1.0, was [" + p + "]");
59+
}
60+
}
61+
62+
public double probability() {
63+
return p;
64+
}
65+
66+
public int seed() {
67+
return seed;
68+
}
69+
70+
public int hash() {
71+
return hash;
72+
}
73+
5474
@Override
5575
public String toString(String field) {
5676
return "RandomSamplingQuery{" + "p=" + p + ", seed=" + seed + ", hash=" + hash + '}';
@@ -97,13 +117,13 @@ public void visit(QueryVisitor visitor) {
97117
/**
98118
* A DocIDSetIter that skips a geometrically random number of documents
99119
*/
100-
static class RandomSamplingIterator extends DocIdSetIterator {
120+
public static class RandomSamplingIterator extends DocIdSetIterator {
101121
private final int maxDoc;
102122
private final double p;
103123
private final FastGeometric distribution;
104124
private int doc = -1;
105125

106-
RandomSamplingIterator(int maxDoc, double p, IntSupplier rng) {
126+
public RandomSamplingIterator(int maxDoc, double p, IntSupplier rng) {
107127
this.maxDoc = maxDoc;
108128
this.p = p;
109129
this.distribution = new FastGeometric(rng, p);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.aggregations.bucket.sampler.random;
11+
12+
import org.apache.lucene.search.Query;
13+
import org.elasticsearch.TransportVersion;
14+
import org.elasticsearch.TransportVersions;
15+
import org.elasticsearch.common.Randomness;
16+
import org.elasticsearch.common.io.stream.StreamInput;
17+
import org.elasticsearch.common.io.stream.StreamOutput;
18+
import org.elasticsearch.index.query.AbstractQueryBuilder;
19+
import org.elasticsearch.index.query.SearchExecutionContext;
20+
import org.elasticsearch.xcontent.ConstructingObjectParser;
21+
import org.elasticsearch.xcontent.ParseField;
22+
import org.elasticsearch.xcontent.XContentBuilder;
23+
import org.elasticsearch.xcontent.XContentParser;
24+
25+
import java.io.IOException;
26+
import java.util.Objects;
27+
28+
import static org.elasticsearch.search.aggregations.bucket.sampler.random.RandomSamplingQuery.checkProbabilityRange;
29+
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
30+
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
31+
32+
public class RandomSamplingQueryBuilder extends AbstractQueryBuilder<RandomSamplingQueryBuilder> {
33+
34+
public static final String NAME = "random_sampling";
35+
static final ParseField PROBABILITY = new ParseField("query");
36+
static final ParseField SEED = new ParseField("seed");
37+
static final ParseField HASH = new ParseField("hash");
38+
39+
private final double probability;
40+
private int seed = Randomness.get().nextInt();
41+
private int hash = 0;
42+
43+
public RandomSamplingQueryBuilder(double probability) {
44+
checkProbabilityRange(probability);
45+
this.probability = probability;
46+
}
47+
48+
public RandomSamplingQueryBuilder seed(int seed) {
49+
checkProbabilityRange(probability);
50+
this.seed = seed;
51+
return this;
52+
}
53+
54+
public RandomSamplingQueryBuilder(StreamInput in) throws IOException {
55+
super(in);
56+
this.probability = in.readDouble();
57+
this.seed = in.readInt();
58+
this.hash = in.readInt();
59+
}
60+
61+
public RandomSamplingQueryBuilder hash(Integer hash) {
62+
this.hash = hash;
63+
return this;
64+
}
65+
66+
public double probability() {
67+
return probability;
68+
}
69+
70+
public int seed() {
71+
return seed;
72+
}
73+
74+
public int hash() {
75+
return hash;
76+
}
77+
78+
@Override
79+
protected void doWriteTo(StreamOutput out) throws IOException {
80+
out.writeDouble(probability);
81+
out.writeInt(seed);
82+
out.writeInt(hash);
83+
}
84+
85+
@Override
86+
protected void doXContent(XContentBuilder builder, Params params) throws IOException {
87+
builder.startObject(NAME);
88+
builder.field(PROBABILITY.getPreferredName(), probability);
89+
builder.field(SEED.getPreferredName(), seed);
90+
builder.field(HASH.getPreferredName(), hash);
91+
builder.endObject();
92+
}
93+
94+
private static final ConstructingObjectParser<RandomSamplingQueryBuilder, Void> PARSER = new ConstructingObjectParser<>(
95+
NAME,
96+
false,
97+
args -> {
98+
var randomSamplingQueryBuilder = new RandomSamplingQueryBuilder((double) args[0]);
99+
if (args[1] != null) {
100+
randomSamplingQueryBuilder.seed((int) args[1]);
101+
}
102+
if (args[2] != null) {
103+
randomSamplingQueryBuilder.hash((int) args[2]);
104+
}
105+
return randomSamplingQueryBuilder;
106+
}
107+
);
108+
109+
static {
110+
PARSER.declareDouble(constructorArg(), PROBABILITY);
111+
PARSER.declareInt(optionalConstructorArg(), SEED);
112+
PARSER.declareInt(optionalConstructorArg(), HASH);
113+
}
114+
115+
public static RandomSamplingQueryBuilder fromXContent(XContentParser parser) throws IOException {
116+
return PARSER.apply(parser, null);
117+
}
118+
119+
@Override
120+
protected Query doToQuery(SearchExecutionContext context) throws IOException {
121+
return new RandomSamplingQuery(probability, seed, hash);
122+
}
123+
124+
@Override
125+
protected boolean doEquals(RandomSamplingQueryBuilder other) {
126+
return probability == other.probability && seed == other.seed && hash == other.hash;
127+
}
128+
129+
@Override
130+
protected int doHashCode() {
131+
return Objects.hash(probability, seed, hash);
132+
}
133+
134+
/**
135+
* Returns the name of the writeable object
136+
*/
137+
@Override
138+
public String getWriteableName() {
139+
return NAME;
140+
}
141+
142+
/**
143+
* The minimal version of the recipient this object can be sent to
144+
*/
145+
@Override
146+
public TransportVersion getMinimalSupportedVersion() {
147+
return TransportVersions.RANDOM_SAMPLER_QUERY_BUILDER_8_19;
148+
}
149+
}

server/src/test/java/org/elasticsearch/search/SearchModuleTests.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,7 @@ public CheckedBiConsumer<ShardSearchRequest, StreamOutput, IOException> getReque
449449
"range",
450450
"regexp",
451451
"knn_score_doc",
452+
"random_sampling",
452453
"script",
453454
"script_score",
454455
"simple_query_string",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.aggregations.bucket.sampler.random;
11+
12+
import org.apache.lucene.search.Query;
13+
import org.elasticsearch.index.query.SearchExecutionContext;
14+
import org.elasticsearch.test.AbstractQueryTestCase;
15+
import org.elasticsearch.xcontent.XContentParseException;
16+
17+
import java.io.IOException;
18+
19+
import static org.hamcrest.Matchers.equalTo;
20+
21+
public class RandomSamplingQueryBuilderTests extends AbstractQueryTestCase<RandomSamplingQueryBuilder> {
22+
23+
@Override
24+
protected RandomSamplingQueryBuilder doCreateTestQueryBuilder() {
25+
double probability = randomDoubleBetween(0.0, 1.0, false);
26+
var builder = new RandomSamplingQueryBuilder(probability);
27+
if (randomBoolean()) {
28+
builder.seed(randomInt());
29+
}
30+
if (randomBoolean()) {
31+
builder.hash(randomInt());
32+
}
33+
return builder;
34+
}
35+
36+
@Override
37+
protected void doAssertLuceneQuery(RandomSamplingQueryBuilder queryBuilder, Query query, SearchExecutionContext context)
38+
throws IOException {
39+
var rsQuery = asInstanceOf(RandomSamplingQuery.class, query);
40+
assertThat(rsQuery.probability(), equalTo(queryBuilder.probability()));
41+
assertThat(rsQuery.seed(), equalTo(queryBuilder.seed()));
42+
assertThat(rsQuery.hash(), equalTo(queryBuilder.hash()));
43+
}
44+
45+
@Override
46+
protected boolean supportsBoost() {
47+
return false;
48+
}
49+
50+
@Override
51+
protected boolean supportsQueryName() {
52+
return false;
53+
}
54+
55+
@Override
56+
public void testUnknownField() {
57+
var json = "{ \""
58+
+ RandomSamplingQueryBuilder.NAME
59+
+ "\" : {\"bogusField\" : \"someValue\", \""
60+
+ RandomSamplingQueryBuilder.PROBABILITY.getPreferredName()
61+
+ "\" : \""
62+
+ randomBoolean()
63+
+ "\", \""
64+
+ RandomSamplingQueryBuilder.SEED.getPreferredName()
65+
+ "\" : \""
66+
+ randomInt()
67+
+ "\", \""
68+
+ RandomSamplingQueryBuilder.HASH.getPreferredName()
69+
+ "\" : \""
70+
+ randomInt()
71+
+ "\" } }";
72+
var e = expectThrows(XContentParseException.class, () -> parseQuery(json));
73+
assertTrue(e.getMessage().contains("bogusField"));
74+
}
75+
}

x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MetadataAttribute.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,10 @@ public static boolean isSupported(String name) {
172172
return ATTRIBUTES_MAP.containsKey(name);
173173
}
174174

175+
public static boolean isScoreAttribute(Expression a) {
176+
return a instanceof MetadataAttribute ma && ma.name().equals(SCORE);
177+
}
178+
175179
@Override
176180
@SuppressWarnings("checkstyle:EqualsHashCode")// equals is implemented in parent. See innerEquals instead
177181
public int hashCode() {

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/Page.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,4 +294,21 @@ public Page projectBlocks(int[] blockMapping) {
294294
}
295295
}
296296
}
297+
298+
public Page filter(int... positions) {
299+
Block[] filteredBlocks = new Block[blocks.length];
300+
boolean success = false;
301+
try {
302+
for (int i = 0; i < blocks.length; i++) {
303+
filteredBlocks[i] = getBlock(i).filter(positions);
304+
}
305+
success = true;
306+
} finally {
307+
releaseBlocks();
308+
if (success == false) {
309+
Releasables.closeExpectNoException(filteredBlocks);
310+
}
311+
}
312+
return new Page(filteredBlocks);
313+
}
297314
}

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ChangePointOperator.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
import org.elasticsearch.xpack.ml.aggs.changepoint.ChangePointDetector;
2020
import org.elasticsearch.xpack.ml.aggs.changepoint.ChangeType;
2121

22+
import java.util.ArrayDeque;
2223
import java.util.ArrayList;
2324
import java.util.Deque;
24-
import java.util.LinkedList;
2525
import java.util.List;
2626

2727
/**
@@ -68,8 +68,8 @@ public ChangePointOperator(DriverContext driverContext, int channel, String sour
6868
this.sourceColumn = sourceColumn;
6969

7070
finished = false;
71-
inputPages = new LinkedList<>();
72-
outputPages = new LinkedList<>();
71+
inputPages = new ArrayDeque<>();
72+
outputPages = new ArrayDeque<>();
7373
warnings = null;
7474
}
7575

0 commit comments

Comments
 (0)