48
48
import org .elasticsearch .xpack .esql .expression .function .fulltext .QueryString ;
49
49
import org .elasticsearch .xpack .esql .expression .function .grouping .Bucket ;
50
50
import org .elasticsearch .xpack .esql .expression .function .scalar .convert .ToInteger ;
51
+ import org .elasticsearch .xpack .esql .expression .function .scalar .string .Concat ;
51
52
import org .elasticsearch .xpack .esql .expression .function .scalar .string .Substring ;
52
53
import org .elasticsearch .xpack .esql .expression .predicate .operator .arithmetic .Add ;
53
54
import org .elasticsearch .xpack .esql .expression .predicate .operator .comparison .Equals ;
67
68
import org .elasticsearch .xpack .esql .plan .logical .OrderBy ;
68
69
import org .elasticsearch .xpack .esql .plan .logical .Row ;
69
70
import org .elasticsearch .xpack .esql .plan .logical .UnresolvedRelation ;
71
+ import org .elasticsearch .xpack .esql .plan .logical .inference .Completion ;
70
72
import org .elasticsearch .xpack .esql .plan .logical .inference .Rerank ;
71
73
import org .elasticsearch .xpack .esql .plan .logical .local .EsqlProject ;
72
74
import org .elasticsearch .xpack .esql .plugin .EsqlPlugin ;
89
91
import static org .elasticsearch .xpack .esql .EsqlTestUtils .TEST_VERIFIER ;
90
92
import static org .elasticsearch .xpack .esql .EsqlTestUtils .as ;
91
93
import static org .elasticsearch .xpack .esql .EsqlTestUtils .configuration ;
94
+ import static org .elasticsearch .xpack .esql .EsqlTestUtils .getAttributeByName ;
92
95
import static org .elasticsearch .xpack .esql .EsqlTestUtils .paramAsConstant ;
93
96
import static org .elasticsearch .xpack .esql .EsqlTestUtils .paramAsIdentifier ;
94
97
import static org .elasticsearch .xpack .esql .EsqlTestUtils .paramAsPattern ;
98
+ import static org .elasticsearch .xpack .esql .EsqlTestUtils .referenceAttribute ;
95
99
import static org .elasticsearch .xpack .esql .EsqlTestUtils .withDefaultLimitWarning ;
96
100
import static org .elasticsearch .xpack .esql .analysis .Analyzer .NO_FIELDS ;
97
101
import static org .elasticsearch .xpack .esql .analysis .AnalyzerTestUtils .analyze ;
@@ -3050,7 +3054,7 @@ public void testResolveRerankInferenceId() {
3050
3054
3051
3055
{
3052
3056
LogicalPlan plan = analyze (
3053
- " FROM books METADATA _score | RERANK \" italian food recipe\" ON title WITH `reranking-inference-id`" ,
3057
+ "FROM books METADATA _score | RERANK \" italian food recipe\" ON title WITH `reranking-inference-id`" ,
3054
3058
"mapping-books.json"
3055
3059
);
3056
3060
Rerank rerank = as (as (plan , Limit .class ).child (), Rerank .class );
@@ -3120,16 +3124,13 @@ public void testResolveRerankFields() {
3120
3124
Filter filter = as (drop .child (), Filter .class );
3121
3125
EsRelation relation = as (filter .child (), EsRelation .class );
3122
3126
3123
- Attribute titleAttribute = relation .output (). stream (). filter ( attribute -> attribute . name (). equals ( "title" )). findFirst (). get ( );
3124
- assertThat (titleAttribute , notNullValue ());
3127
+ Attribute titleAttribute = getAttributeByName ( relation .output (), "title" );
3128
+ assertThat (getAttributeByName ( relation . output (), "title" ) , notNullValue ());
3125
3129
3126
3130
assertThat (rerank .queryText (), equalTo (string ("italian food recipe" )));
3127
3131
assertThat (rerank .inferenceId (), equalTo (string ("reranking-inference-id" )));
3128
3132
assertThat (rerank .rerankFields (), equalTo (List .of (alias ("title" , titleAttribute ))));
3129
- assertThat (
3130
- rerank .scoreAttribute (),
3131
- equalTo (relation .output ().stream ().filter (attr -> attr .name ().equals (MetadataAttribute .SCORE )).findFirst ().get ())
3132
- );
3133
+ assertThat (rerank .scoreAttribute (), equalTo (getAttributeByName (relation .output (), MetadataAttribute .SCORE )));
3133
3134
}
3134
3135
3135
3136
{
@@ -3149,15 +3150,11 @@ public void testResolveRerankFields() {
3149
3150
assertThat (rerank .inferenceId (), equalTo (string ("reranking-inference-id" )));
3150
3151
3151
3152
assertThat (rerank .rerankFields (), hasSize (3 ));
3152
- Attribute titleAttribute = relation .output (). stream (). filter ( attribute -> attribute . name (). equals ( "title" )). findFirst (). get ( );
3153
+ Attribute titleAttribute = getAttributeByName ( relation .output (), "title" );
3153
3154
assertThat (titleAttribute , notNullValue ());
3154
3155
assertThat (rerank .rerankFields ().get (0 ), equalTo (alias ("title" , titleAttribute )));
3155
3156
3156
- Attribute descriptionAttribute = relation .output ()
3157
- .stream ()
3158
- .filter (attribute -> attribute .name ().equals ("description" ))
3159
- .findFirst ()
3160
- .get ();
3157
+ Attribute descriptionAttribute = getAttributeByName (relation .output (), "description" );
3161
3158
assertThat (descriptionAttribute , notNullValue ());
3162
3159
Alias descriptionAlias = rerank .rerankFields ().get (1 );
3163
3160
assertThat (descriptionAlias .name (), equalTo ("description" ));
@@ -3166,13 +3163,11 @@ public void testResolveRerankFields() {
3166
3163
equalTo (List .of (descriptionAttribute , literal (0 ), literal (100 )))
3167
3164
);
3168
3165
3169
- Attribute yearAttribute = relation .output (). stream (). filter ( attribute -> attribute . name (). equals ( "year" )). findFirst (). get ( );
3166
+ Attribute yearAttribute = getAttributeByName ( relation .output (), "year" );
3170
3167
assertThat (yearAttribute , notNullValue ());
3171
3168
assertThat (rerank .rerankFields ().get (2 ), equalTo (alias ("yearRenamed" , yearAttribute )));
3172
- assertThat (
3173
- rerank .scoreAttribute (),
3174
- equalTo (relation .output ().stream ().filter (attr -> attr .name ().equals (MetadataAttribute .SCORE )).findFirst ().get ())
3175
- );
3169
+
3170
+ assertThat (rerank .scoreAttribute (), equalTo (getAttributeByName (relation .output (), MetadataAttribute .SCORE )));
3176
3171
}
3177
3172
3178
3173
{
@@ -3204,11 +3199,7 @@ public void testResolveRerankScoreField() {
3204
3199
Filter filter = as (rerank .child (), Filter .class );
3205
3200
EsRelation relation = as (filter .child (), EsRelation .class );
3206
3201
3207
- Attribute metadataScoreAttribute = relation .output ()
3208
- .stream ()
3209
- .filter (attr -> attr .name ().equals (MetadataAttribute .SCORE ))
3210
- .findFirst ()
3211
- .get ();
3202
+ Attribute metadataScoreAttribute = getAttributeByName (relation .output (), MetadataAttribute .SCORE );
3212
3203
assertThat (rerank .scoreAttribute (), equalTo (metadataScoreAttribute ));
3213
3204
assertThat (rerank .output (), hasItem (metadataScoreAttribute ));
3214
3205
}
@@ -3232,6 +3223,116 @@ public void testResolveRerankScoreField() {
3232
3223
}
3233
3224
}
3234
3225
3226
+ public void testResolveCompletionInferenceId () {
3227
+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3228
+
3229
+ LogicalPlan plan = analyze ("""
3230
+ FROM books METADATA _score
3231
+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `completion-inference-id`
3232
+ """ , "mapping-books.json" );
3233
+ Completion completion = as (as (plan , Limit .class ).child (), Completion .class );
3234
+ assertThat (completion .inferenceId (), equalTo (string ("completion-inference-id" )));
3235
+ }
3236
+
3237
+ public void testResolveCompletionInferenceIdInvalidTaskType () {
3238
+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3239
+
3240
+ assertError (
3241
+ """
3242
+ FROM books METADATA _score
3243
+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `reranking-inference-id`
3244
+ """ ,
3245
+ "mapping-books.json" ,
3246
+ new QueryParams (),
3247
+ "cannot use inference endpoint [reranking-inference-id] with task type [rerank] within a Completion command."
3248
+ + " Only inference endpoints with the task type [completion] are supported"
3249
+ );
3250
+ }
3251
+
3252
+ public void testResolveCompletionInferenceMissingInferenceId () {
3253
+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3254
+
3255
+ assertError ("""
3256
+ FROM books METADATA _score
3257
+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `unknown-inference-id`
3258
+ """ , "mapping-books.json" , new QueryParams (), "unresolved inference [unknown-inference-id]" );
3259
+ }
3260
+
3261
+ public void testResolveCompletionInferenceIdResolutionError () {
3262
+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3263
+
3264
+ assertError ("""
3265
+ FROM books METADATA _score
3266
+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `error-inference-id`
3267
+ """ , "mapping-books.json" , new QueryParams (), "error with inference resolution" );
3268
+ }
3269
+
3270
+ public void testResolveCompletionTargetField () {
3271
+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3272
+
3273
+ LogicalPlan plan = analyze ("""
3274
+ FROM books METADATA _score
3275
+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `completion-inference-id` AS translation
3276
+ """ , "mapping-books.json" );
3277
+
3278
+ Completion completion = as (as (plan , Limit .class ).child (), Completion .class );
3279
+ assertThat (completion .targetField (), equalTo (referenceAttribute ("translation" , DataType .TEXT )));
3280
+ }
3281
+
3282
+ public void testResolveCompletionDefaultTargetField () {
3283
+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3284
+
3285
+ LogicalPlan plan = analyze ("""
3286
+ FROM books METADATA _score
3287
+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `completion-inference-id`
3288
+ """ , "mapping-books.json" );
3289
+
3290
+ Completion completion = as (as (plan , Limit .class ).child (), Completion .class );
3291
+ assertThat (completion .targetField (), equalTo (referenceAttribute ("completion" , DataType .TEXT )));
3292
+ }
3293
+
3294
+ public void testResolveCompletionPrompt () {
3295
+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3296
+
3297
+ LogicalPlan plan = analyze ("""
3298
+ FROM books METADATA _score
3299
+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `completion-inference-id`
3300
+ """ , "mapping-books.json" );
3301
+
3302
+ Completion completion = as (as (plan , Limit .class ).child (), Completion .class );
3303
+ EsRelation esRelation = as (completion .child (), EsRelation .class );
3304
+
3305
+ assertThat (
3306
+ as (completion .prompt (), Concat .class ).children (),
3307
+ equalTo (List .of (string ("Translate the following text in French\n " ), getAttributeByName (esRelation .output (), "description" )))
3308
+ );
3309
+ }
3310
+
3311
+ public void testResolveCompletionPromptInvalidType () {
3312
+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3313
+
3314
+ assertError ("""
3315
+ FROM books METADATA _score
3316
+ | COMPLETION LENGTH(description) WITH `completion-inference-id`
3317
+ """ , "mapping-books.json" , new QueryParams (), "prompt must be of type [text] but is [integer]" );
3318
+ }
3319
+
3320
+ public void testResolveCompletionOutputField () {
3321
+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3322
+
3323
+ LogicalPlan plan = analyze ("""
3324
+ FROM books METADATA _score
3325
+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `completion-inference-id` AS description
3326
+ """ , "mapping-books.json" );
3327
+
3328
+ Completion completion = as (as (plan , Limit .class ).child (), Completion .class );
3329
+ assertThat (completion .targetField (), equalTo (referenceAttribute ("description" , DataType .TEXT )));
3330
+
3331
+ EsRelation esRelation = as (completion .child (), EsRelation .class );
3332
+ assertThat (getAttributeByName (completion .output (), "description" ), equalTo (completion .targetField ()));
3333
+ assertThat (getAttributeByName (esRelation .output (), "description" ), not (equalTo (completion .targetField ())));
3334
+ }
3335
+
3235
3336
@ Override
3236
3337
protected IndexAnalyzers createDefaultIndexAnalyzers () {
3237
3338
return super .createDefaultIndexAnalyzers ();
0 commit comments