Skip to content

Commit 40e83cb

Browse files
committed
[ES|QL] COMPLETION command analysis. (#126677)
* [ES|QL] COMPLETION command analysis. * Moving prompt type test in postAnalysisVerification * Test lint.
1 parent d5a78ba commit 40e83cb

File tree

4 files changed

+177
-24
lines changed

4 files changed

+177
-24
lines changed

x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@
109109
import java.time.Period;
110110
import java.util.ArrayList;
111111
import java.util.Arrays;
112+
import java.util.Collection;
112113
import java.util.EnumSet;
113114
import java.util.HashMap;
114115
import java.util.HashSet;
@@ -126,6 +127,7 @@
126127
import static java.util.Collections.emptyMap;
127128
import static java.util.Collections.unmodifiableMap;
128129
import static org.elasticsearch.test.ESTestCase.assertEquals;
130+
import static org.elasticsearch.test.ESTestCase.assertThat;
129131
import static org.elasticsearch.test.ESTestCase.between;
130132
import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength;
131133
import static org.elasticsearch.test.ESTestCase.randomArray;
@@ -151,6 +153,7 @@
151153
import static org.elasticsearch.xpack.esql.parser.ParserUtils.ParamClassification.IDENTIFIER;
152154
import static org.elasticsearch.xpack.esql.parser.ParserUtils.ParamClassification.PATTERN;
153155
import static org.elasticsearch.xpack.esql.parser.ParserUtils.ParamClassification.VALUE;
156+
import static org.hamcrest.Matchers.hasSize;
154157
import static org.hamcrest.Matchers.instanceOf;
155158
import static org.junit.Assert.assertNotNull;
156159
import static org.junit.Assert.assertNull;
@@ -884,6 +887,19 @@ public static void assertEsqlFailure(Exception e) {
884887
.ifPresent(transportFailure -> assertNull("remote transport exception must be unwrapped", transportFailure.getCause()));
885888
}
886889

890+
public static <T> T singleValue(Collection<T> collection) {
891+
assertThat(collection, hasSize(1));
892+
return collection.iterator().next();
893+
}
894+
895+
public static Attribute getAttributeByName(Collection<Attribute> attributes, String name) {
896+
return attributes.stream().filter(attr -> attr.name().equals(name)).findAny().orElse(null);
897+
}
898+
899+
public static Map<String, Object> jsonEntityToMap(HttpEntity entity) throws IOException {
900+
return entityToMap(entity, XContentType.JSON);
901+
}
902+
887903
public static Map<String, Object> entityToMap(HttpEntity entity, XContentType expectedContentType) throws IOException {
888904
try (InputStream content = entity.getContent()) {
889905
XContentType xContentType = XContentType.fromMediaType(entity.getContentType().getValue());

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
import org.elasticsearch.xpack.esql.plan.logical.Project;
8484
import org.elasticsearch.xpack.esql.plan.logical.Rename;
8585
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
86+
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
8687
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
8788
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
8889
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
@@ -480,6 +481,10 @@ protected LogicalPlan doRule(LogicalPlan plan) {
480481
return resolveAggregate(aggregate, childrenOutput);
481482
}
482483

484+
if (plan instanceof Completion c) {
485+
return resolveCompletion(c, childrenOutput);
486+
}
487+
483488
if (plan instanceof Drop d) {
484489
return resolveDrop(d, childrenOutput);
485490
}
@@ -586,6 +591,21 @@ private Aggregate resolveAggregate(Aggregate aggregate, List<Attribute> children
586591
return aggregate;
587592
}
588593

594+
private LogicalPlan resolveCompletion(Completion p, List<Attribute> childrenOutput) {
595+
Attribute targetField = p.targetField();
596+
Expression prompt = p.prompt();
597+
598+
if (targetField instanceof UnresolvedAttribute ua) {
599+
targetField = new ReferenceAttribute(ua.source(), ua.name(), TEXT);
600+
}
601+
602+
if (prompt.resolved() == false) {
603+
prompt = prompt.transformUp(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput));
604+
}
605+
606+
return new Completion(p.source(), p.child(), p.inferenceId(), prompt, targetField);
607+
}
608+
589609
private LogicalPlan resolveMvExpand(MvExpand p, List<Attribute> childrenOutput) {
590610
if (p.target() instanceof UnresolvedAttribute ua) {
591611
Attribute resolved = maybeResolveAttribute(ua, childrenOutput);

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,15 @@
1111
import org.elasticsearch.common.io.stream.StreamInput;
1212
import org.elasticsearch.common.io.stream.StreamOutput;
1313
import org.elasticsearch.inference.TaskType;
14+
import org.elasticsearch.xpack.esql.capabilities.PostAnalysisVerificationAware;
15+
import org.elasticsearch.xpack.esql.common.Failures;
1416
import org.elasticsearch.xpack.esql.core.expression.Attribute;
1517
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
1618
import org.elasticsearch.xpack.esql.core.expression.Expression;
1719
import org.elasticsearch.xpack.esql.core.expression.NameId;
1820
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
1921
import org.elasticsearch.xpack.esql.core.tree.Source;
22+
import org.elasticsearch.xpack.esql.core.type.DataType;
2023
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
2124
import org.elasticsearch.xpack.esql.plan.GeneratingPlan;
2225
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
@@ -26,9 +29,15 @@
2629
import java.util.List;
2730
import java.util.Objects;
2831

32+
import static org.elasticsearch.xpack.esql.common.Failure.fail;
33+
import static org.elasticsearch.xpack.esql.core.type.DataType.TEXT;
2934
import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;
3035

31-
public class Completion extends InferencePlan<Completion> implements GeneratingPlan<Completion>, SortAgnostic {
36+
public class Completion extends InferencePlan<Completion>
37+
implements
38+
GeneratingPlan<Completion>,
39+
SortAgnostic,
40+
PostAnalysisVerificationAware {
3241

3342
public static final String DEFAULT_OUTPUT_FIELD_NAME = "completion";
3443

@@ -130,6 +139,13 @@ public boolean expressionsResolved() {
130139
return super.expressionsResolved() && prompt.resolved();
131140
}
132141

142+
@Override
143+
public void postAnalysisVerification(Failures failures) {
144+
if (prompt.resolved() && DataType.isString(prompt.dataType()) == false) {
145+
failures.add(fail(prompt, "prompt must be of type [{}] but is [{}]", TEXT.typeName(), prompt.dataType().typeName()));
146+
}
147+
}
148+
133149
@Override
134150
protected NodeInfo<? extends LogicalPlan> info() {
135151
return NodeInfo.create(this, Completion::new, child(), inferenceId(), prompt, targetField);

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java

Lines changed: 124 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString;
4949
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
5050
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger;
51+
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat;
5152
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Substring;
5253
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add;
5354
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
@@ -67,6 +68,7 @@
6768
import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
6869
import org.elasticsearch.xpack.esql.plan.logical.Row;
6970
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
71+
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
7072
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
7173
import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject;
7274
import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
@@ -89,9 +91,11 @@
8991
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
9092
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
9193
import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration;
94+
import static org.elasticsearch.xpack.esql.EsqlTestUtils.getAttributeByName;
9295
import static org.elasticsearch.xpack.esql.EsqlTestUtils.paramAsConstant;
9396
import static org.elasticsearch.xpack.esql.EsqlTestUtils.paramAsIdentifier;
9497
import static org.elasticsearch.xpack.esql.EsqlTestUtils.paramAsPattern;
98+
import static org.elasticsearch.xpack.esql.EsqlTestUtils.referenceAttribute;
9599
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
96100
import static org.elasticsearch.xpack.esql.analysis.Analyzer.NO_FIELDS;
97101
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyze;
@@ -3050,7 +3054,7 @@ public void testResolveRerankInferenceId() {
30503054

30513055
{
30523056
LogicalPlan plan = analyze(
3053-
" FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH `reranking-inference-id`",
3057+
"FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH `reranking-inference-id`",
30543058
"mapping-books.json"
30553059
);
30563060
Rerank rerank = as(as(plan, Limit.class).child(), Rerank.class);
@@ -3120,16 +3124,13 @@ public void testResolveRerankFields() {
31203124
Filter filter = as(drop.child(), Filter.class);
31213125
EsRelation relation = as(filter.child(), EsRelation.class);
31223126

3123-
Attribute titleAttribute = relation.output().stream().filter(attribute -> attribute.name().equals("title")).findFirst().get();
3124-
assertThat(titleAttribute, notNullValue());
3127+
Attribute titleAttribute = getAttributeByName(relation.output(), "title");
3128+
assertThat(getAttributeByName(relation.output(), "title"), notNullValue());
31253129

31263130
assertThat(rerank.queryText(), equalTo(string("italian food recipe")));
31273131
assertThat(rerank.inferenceId(), equalTo(string("reranking-inference-id")));
31283132
assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", titleAttribute))));
3129-
assertThat(
3130-
rerank.scoreAttribute(),
3131-
equalTo(relation.output().stream().filter(attr -> attr.name().equals(MetadataAttribute.SCORE)).findFirst().get())
3132-
);
3133+
assertThat(rerank.scoreAttribute(), equalTo(getAttributeByName(relation.output(), MetadataAttribute.SCORE)));
31333134
}
31343135

31353136
{
@@ -3149,15 +3150,11 @@ public void testResolveRerankFields() {
31493150
assertThat(rerank.inferenceId(), equalTo(string("reranking-inference-id")));
31503151

31513152
assertThat(rerank.rerankFields(), hasSize(3));
3152-
Attribute titleAttribute = relation.output().stream().filter(attribute -> attribute.name().equals("title")).findFirst().get();
3153+
Attribute titleAttribute = getAttributeByName(relation.output(), "title");
31533154
assertThat(titleAttribute, notNullValue());
31543155
assertThat(rerank.rerankFields().get(0), equalTo(alias("title", titleAttribute)));
31553156

3156-
Attribute descriptionAttribute = relation.output()
3157-
.stream()
3158-
.filter(attribute -> attribute.name().equals("description"))
3159-
.findFirst()
3160-
.get();
3157+
Attribute descriptionAttribute = getAttributeByName(relation.output(), "description");
31613158
assertThat(descriptionAttribute, notNullValue());
31623159
Alias descriptionAlias = rerank.rerankFields().get(1);
31633160
assertThat(descriptionAlias.name(), equalTo("description"));
@@ -3166,13 +3163,11 @@ public void testResolveRerankFields() {
31663163
equalTo(List.of(descriptionAttribute, literal(0), literal(100)))
31673164
);
31683165

3169-
Attribute yearAttribute = relation.output().stream().filter(attribute -> attribute.name().equals("year")).findFirst().get();
3166+
Attribute yearAttribute = getAttributeByName(relation.output(), "year");
31703167
assertThat(yearAttribute, notNullValue());
31713168
assertThat(rerank.rerankFields().get(2), equalTo(alias("yearRenamed", yearAttribute)));
3172-
assertThat(
3173-
rerank.scoreAttribute(),
3174-
equalTo(relation.output().stream().filter(attr -> attr.name().equals(MetadataAttribute.SCORE)).findFirst().get())
3175-
);
3169+
3170+
assertThat(rerank.scoreAttribute(), equalTo(getAttributeByName(relation.output(), MetadataAttribute.SCORE)));
31763171
}
31773172

31783173
{
@@ -3204,11 +3199,7 @@ public void testResolveRerankScoreField() {
32043199
Filter filter = as(rerank.child(), Filter.class);
32053200
EsRelation relation = as(filter.child(), EsRelation.class);
32063201

3207-
Attribute metadataScoreAttribute = relation.output()
3208-
.stream()
3209-
.filter(attr -> attr.name().equals(MetadataAttribute.SCORE))
3210-
.findFirst()
3211-
.get();
3202+
Attribute metadataScoreAttribute = getAttributeByName(relation.output(), MetadataAttribute.SCORE);
32123203
assertThat(rerank.scoreAttribute(), equalTo(metadataScoreAttribute));
32133204
assertThat(rerank.output(), hasItem(metadataScoreAttribute));
32143205
}
@@ -3232,6 +3223,116 @@ public void testResolveRerankScoreField() {
32323223
}
32333224
}
32343225

3226+
public void testResolveCompletionInferenceId() {
3227+
assumeTrue("Requires COMPLETION command", EsqlCapabilities.Cap.COMPLETION.isEnabled());
3228+
3229+
LogicalPlan plan = analyze("""
3230+
FROM books METADATA _score
3231+
| COMPLETION CONCAT("Translate the following text in French\\n", description) WITH `completion-inference-id`
3232+
""", "mapping-books.json");
3233+
Completion completion = as(as(plan, Limit.class).child(), Completion.class);
3234+
assertThat(completion.inferenceId(), equalTo(string("completion-inference-id")));
3235+
}
3236+
3237+
public void testResolveCompletionInferenceIdInvalidTaskType() {
3238+
assumeTrue("Requires COMPLETION command", EsqlCapabilities.Cap.COMPLETION.isEnabled());
3239+
3240+
assertError(
3241+
"""
3242+
FROM books METADATA _score
3243+
| COMPLETION CONCAT("Translate the following text in French\\n", description) WITH `reranking-inference-id`
3244+
""",
3245+
"mapping-books.json",
3246+
new QueryParams(),
3247+
"cannot use inference endpoint [reranking-inference-id] with task type [rerank] within a Completion command."
3248+
+ " Only inference endpoints with the task type [completion] are supported"
3249+
);
3250+
}
3251+
3252+
public void testResolveCompletionInferenceMissingInferenceId() {
3253+
assumeTrue("Requires COMPLETION command", EsqlCapabilities.Cap.COMPLETION.isEnabled());
3254+
3255+
assertError("""
3256+
FROM books METADATA _score
3257+
| COMPLETION CONCAT("Translate the following text in French\\n", description) WITH `unknown-inference-id`
3258+
""", "mapping-books.json", new QueryParams(), "unresolved inference [unknown-inference-id]");
3259+
}
3260+
3261+
public void testResolveCompletionInferenceIdResolutionError() {
3262+
assumeTrue("Requires COMPLETION command", EsqlCapabilities.Cap.COMPLETION.isEnabled());
3263+
3264+
assertError("""
3265+
FROM books METADATA _score
3266+
| COMPLETION CONCAT("Translate the following text in French\\n", description) WITH `error-inference-id`
3267+
""", "mapping-books.json", new QueryParams(), "error with inference resolution");
3268+
}
3269+
3270+
public void testResolveCompletionTargetField() {
3271+
assumeTrue("Requires COMPLETION command", EsqlCapabilities.Cap.COMPLETION.isEnabled());
3272+
3273+
LogicalPlan plan = analyze("""
3274+
FROM books METADATA _score
3275+
| COMPLETION CONCAT("Translate the following text in French\\n", description) WITH `completion-inference-id` AS translation
3276+
""", "mapping-books.json");
3277+
3278+
Completion completion = as(as(plan, Limit.class).child(), Completion.class);
3279+
assertThat(completion.targetField(), equalTo(referenceAttribute("translation", DataType.TEXT)));
3280+
}
3281+
3282+
public void testResolveCompletionDefaultTargetField() {
3283+
assumeTrue("Requires COMPLETION command", EsqlCapabilities.Cap.COMPLETION.isEnabled());
3284+
3285+
LogicalPlan plan = analyze("""
3286+
FROM books METADATA _score
3287+
| COMPLETION CONCAT("Translate the following text in French\\n", description) WITH `completion-inference-id`
3288+
""", "mapping-books.json");
3289+
3290+
Completion completion = as(as(plan, Limit.class).child(), Completion.class);
3291+
assertThat(completion.targetField(), equalTo(referenceAttribute("completion", DataType.TEXT)));
3292+
}
3293+
3294+
public void testResolveCompletionPrompt() {
3295+
assumeTrue("Requires COMPLETION command", EsqlCapabilities.Cap.COMPLETION.isEnabled());
3296+
3297+
LogicalPlan plan = analyze("""
3298+
FROM books METADATA _score
3299+
| COMPLETION CONCAT("Translate the following text in French\\n", description) WITH `completion-inference-id`
3300+
""", "mapping-books.json");
3301+
3302+
Completion completion = as(as(plan, Limit.class).child(), Completion.class);
3303+
EsRelation esRelation = as(completion.child(), EsRelation.class);
3304+
3305+
assertThat(
3306+
as(completion.prompt(), Concat.class).children(),
3307+
equalTo(List.of(string("Translate the following text in French\n"), getAttributeByName(esRelation.output(), "description")))
3308+
);
3309+
}
3310+
3311+
public void testResolveCompletionPromptInvalidType() {
3312+
assumeTrue("Requires COMPLETION command", EsqlCapabilities.Cap.COMPLETION.isEnabled());
3313+
3314+
assertError("""
3315+
FROM books METADATA _score
3316+
| COMPLETION LENGTH(description) WITH `completion-inference-id`
3317+
""", "mapping-books.json", new QueryParams(), "prompt must be of type [text] but is [integer]");
3318+
}
3319+
3320+
public void testResolveCompletionOutputField() {
3321+
assumeTrue("Requires COMPLETION command", EsqlCapabilities.Cap.COMPLETION.isEnabled());
3322+
3323+
LogicalPlan plan = analyze("""
3324+
FROM books METADATA _score
3325+
| COMPLETION CONCAT("Translate the following text in French\\n", description) WITH `completion-inference-id` AS description
3326+
""", "mapping-books.json");
3327+
3328+
Completion completion = as(as(plan, Limit.class).child(), Completion.class);
3329+
assertThat(completion.targetField(), equalTo(referenceAttribute("description", DataType.TEXT)));
3330+
3331+
EsRelation esRelation = as(completion.child(), EsRelation.class);
3332+
assertThat(getAttributeByName(completion.output(), "description"), equalTo(completion.targetField()));
3333+
assertThat(getAttributeByName(esRelation.output(), "description"), not(equalTo(completion.targetField())));
3334+
}
3335+
32353336
@Override
32363337
protected IndexAnalyzers createDefaultIndexAnalyzers() {
32373338
return super.createDefaultIndexAnalyzers();

0 commit comments

Comments
 (0)