Skip to content

Commit d8972a7

Browse files
authored
[Inference API] Use dimensions field in JinaAI text_embedding requests (#139549)
The dimensions field was parsed from the service settings but never used. This commit includes the dimensions field in text embedding requests sent to JinaAI when set by the user.

* Add dimensions_set_by_user to JinaAIEmbeddingsServiceSettings
* Move similarity and max_input_tokens to exposed fields

(cherry picked from commit b2ecf87)
1 parent 630b65b commit d8972a7

File tree

15 files changed

+465
-127
lines changed

15 files changed

+465
-127
lines changed

docs/changelog/139413.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 139413
2+
summary: "[Inference API] Use dimensions field in JinaAI `text_embedding` requests"
3+
area: Inference
4+
type: bug
5+
issues: []
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
9244000,9185014,9112017,8841078
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
downsample_add_multi_field_sources,8841077
1+
jina_ai_embedding_dimensions_support_added,8841078
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
downsample_add_multi_field_sources,9112016
1+
jina_ai_embedding_dimensions_support_added,9112017
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
downsample_add_multi_field_sources,9185013
1+
jina_ai_embedding_dimensions_support_added,9185014
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
downsample_add_multi_field_sources,9232000
1+
jina_ai_embedding_dimensions_support_added,9244000

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceFields.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ public final class ServiceFields {
1414

1515
public static final String SIMILARITY = "similarity";
1616
public static final String DIMENSIONS = "dimensions";
17+
public static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user";
1718
// Typically we use this to define the maximum tokens for the input text (text being sent to an integration)
1819
public static final String MAX_INPUT_TOKENS = "max_input_tokens";
1920
public static final String URL = "url";

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,8 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
304304
similarityToUse,
305305
embeddingSize,
306306
maxInputTokens,
307-
serviceSettings.getEmbeddingType()
307+
serviceSettings.getEmbeddingType(),
308+
serviceSettings.dimensionsSetByUser()
308309
);
309310

310311
return new JinaAIEmbeddingsModel(embeddingsModel, updatedServiceSettings);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettings.java

Lines changed: 60 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,11 @@
2727
import java.util.Objects;
2828

2929
import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
30+
import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS_SET_BY_USER;
3031
import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
3132
import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
3233
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
34+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
3335
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
3436
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType;
3537

@@ -42,16 +44,33 @@ public static JinaAIEmbeddingsServiceSettings fromMap(Map<String, Object> map, C
4244
ValidationException validationException = new ValidationException();
4345
var commonServiceSettings = JinaAIServiceSettings.fromMap(map, context);
4446
SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
45-
Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
47+
Integer dimensions = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException);
4648
Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
4749

4850
JinaAIEmbeddingType embeddingTypes = parseEmbeddingType(map, validationException);
4951

52+
Boolean dimensionsSetByUser;
53+
if (context == ConfigurationParseContext.PERSISTENT) {
54+
dimensionsSetByUser = removeAsType(map, DIMENSIONS_SET_BY_USER, Boolean.class);
55+
if (dimensionsSetByUser == null) {
56+
dimensionsSetByUser = Boolean.FALSE;
57+
}
58+
} else {
59+
dimensionsSetByUser = dimensions != null;
60+
}
61+
5062
if (validationException.validationErrors().isEmpty() == false) {
5163
throw validationException;
5264
}
5365

54-
return new JinaAIEmbeddingsServiceSettings(commonServiceSettings, similarity, dims, maxInputTokens, embeddingTypes);
66+
return new JinaAIEmbeddingsServiceSettings(
67+
commonServiceSettings,
68+
similarity,
69+
dimensions,
70+
maxInputTokens,
71+
embeddingTypes,
72+
dimensionsSetByUser
73+
);
5574
}
5675

5776
static JinaAIEmbeddingType parseEmbeddingType(Map<String, Object> map, ValidationException validationException) {
@@ -72,24 +91,31 @@ static JinaAIEmbeddingType parseEmbeddingType(Map<String, Object> map, Validatio
7291
"jina_ai_embedding_type_support_added"
7392
);
7493

94+
static final TransportVersion JINA_AI_EMBEDDING_DIMENSIONS_SUPPORT_ADDED = TransportVersion.fromName(
95+
"jina_ai_embedding_dimensions_support_added"
96+
);
97+
7598
private final JinaAIServiceSettings commonSettings;
7699
private final SimilarityMeasure similarity;
77100
private final Integer dimensions;
78101
private final Integer maxInputTokens;
79102
private final JinaAIEmbeddingType embeddingType;
103+
private final Boolean dimensionsSetByUser;
80104

81105
public JinaAIEmbeddingsServiceSettings(
82106
JinaAIServiceSettings commonSettings,
83107
@Nullable SimilarityMeasure similarity,
84108
@Nullable Integer dimensions,
85109
@Nullable Integer maxInputTokens,
86-
@Nullable JinaAIEmbeddingType embeddingType
110+
@Nullable JinaAIEmbeddingType embeddingType,
111+
Boolean dimensionsSetByUser
87112
) {
88113
this.commonSettings = commonSettings;
89114
this.similarity = similarity;
90115
this.dimensions = dimensions;
91116
this.maxInputTokens = maxInputTokens;
92117
this.embeddingType = embeddingType != null ? embeddingType : JinaAIEmbeddingType.FLOAT;
118+
this.dimensionsSetByUser = Objects.requireNonNull(dimensionsSetByUser);
93119
}
94120

95121
public JinaAIEmbeddingsServiceSettings(StreamInput in) throws IOException {
@@ -101,6 +127,12 @@ public JinaAIEmbeddingsServiceSettings(StreamInput in) throws IOException {
101127
this.embeddingType = (in.getTransportVersion().supports(JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED))
102128
? Objects.requireNonNullElse(in.readOptionalEnum(JinaAIEmbeddingType.class), JinaAIEmbeddingType.FLOAT)
103129
: JinaAIEmbeddingType.FLOAT;
130+
131+
if (in.getTransportVersion().supports(JINA_AI_EMBEDDING_DIMENSIONS_SUPPORT_ADDED)) {
132+
this.dimensionsSetByUser = in.readBoolean();
133+
} else {
134+
this.dimensionsSetByUser = false;
135+
}
104136
}
105137

106138
public JinaAIServiceSettings getCommonSettings() {
@@ -117,6 +149,11 @@ public Integer dimensions() {
117149
return dimensions;
118150
}
119151

152+
@Override
153+
public Boolean dimensionsSetByUser() {
154+
return dimensionsSetByUser;
155+
}
156+
120157
public Integer maxInputTokens() {
121158
return maxInputTokens;
122159
}
@@ -144,18 +181,10 @@ public String getWriteableName() {
144181
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
145182
builder.startObject();
146183

147-
builder = commonSettings.toXContentFragment(builder, params);
148-
if (similarity != null) {
149-
builder.field(SIMILARITY, similarity);
150-
}
151-
if (dimensions != null) {
152-
builder.field(DIMENSIONS, dimensions);
153-
}
154-
if (maxInputTokens != null) {
155-
builder.field(MAX_INPUT_TOKENS, maxInputTokens);
156-
}
157-
if (embeddingType != null) {
158-
builder.field(EMBEDDING_TYPE, embeddingType);
184+
toXContentFragmentOfExposedFields(builder, params);
185+
186+
if (dimensionsSetByUser != null) {
187+
builder.field(DIMENSIONS_SET_BY_USER, dimensionsSetByUser);
159188
}
160189

161190
builder.endObject();
@@ -165,9 +194,18 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
165194
@Override
166195
protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
167196
commonSettings.toXContentFragmentOfExposedFields(builder, params);
197+
if (dimensions != null) {
198+
builder.field(DIMENSIONS, dimensions);
199+
}
168200
if (embeddingType != null) {
169201
builder.field(EMBEDDING_TYPE, embeddingType);
170202
}
203+
if (maxInputTokens != null) {
204+
builder.field(MAX_INPUT_TOKENS, maxInputTokens);
205+
}
206+
if (similarity != null) {
207+
builder.field(SIMILARITY, similarity);
208+
}
171209
return builder;
172210
}
173211

@@ -186,6 +224,10 @@ public void writeTo(StreamOutput out) throws IOException {
186224
if (out.getTransportVersion().supports(JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED)) {
187225
out.writeOptionalEnum(JinaAIEmbeddingType.translateToVersion(embeddingType, out.getTransportVersion()));
188226
}
227+
228+
if (out.getTransportVersion().supports(JINA_AI_EMBEDDING_DIMENSIONS_SUPPORT_ADDED)) {
229+
out.writeBoolean(dimensionsSetByUser);
230+
}
189231
}
190232

191233
@Override
@@ -197,11 +239,12 @@ public boolean equals(Object o) {
197239
&& Objects.equals(similarity, that.similarity)
198240
&& Objects.equals(dimensions, that.dimensions)
199241
&& Objects.equals(maxInputTokens, that.maxInputTokens)
200-
&& Objects.equals(embeddingType, that.embeddingType);
242+
&& Objects.equals(embeddingType, that.embeddingType)
243+
&& Objects.equals(dimensionsSetByUser, that.dimensionsSetByUser);
201244
}
202245

203246
@Override
204247
public int hashCode() {
205-
return Objects.hash(commonSettings, similarity, dimensions, maxInputTokens, embeddingType);
248+
return Objects.hash(commonSettings, similarity, dimensions, maxInputTokens, embeddingType, dimensionsSetByUser);
206249
}
207250
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequest.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ public class JinaAIEmbeddingsRequest extends JinaAIRequest {
3434
private final String model;
3535
private final String inferenceEntityId;
3636
private final JinaAIEmbeddingType embeddingType;
37+
private final Integer dimensions;
38+
private final boolean dimensionsSetByUser;
3739

3840
public JinaAIEmbeddingsRequest(List<String> input, InputType inputType, JinaAIEmbeddingsModel embeddingsModel) {
3941
Objects.requireNonNull(embeddingsModel);
@@ -45,15 +47,18 @@ public JinaAIEmbeddingsRequest(List<String> input, InputType inputType, JinaAIEm
4547
model = embeddingsModel.getServiceSettings().getCommonSettings().modelId();
4648
embeddingType = embeddingsModel.getServiceSettings().getEmbeddingType();
4749
inferenceEntityId = embeddingsModel.getInferenceEntityId();
50+
dimensions = embeddingsModel.getServiceSettings().dimensions();
51+
dimensionsSetByUser = embeddingsModel.getServiceSettings().dimensionsSetByUser();
4852
}
4953

5054
@Override
5155
public HttpRequest createHttpRequest() {
5256
HttpPost httpPost = new HttpPost(account.uri());
5357

5458
ByteArrayEntity byteEntity = new ByteArrayEntity(
55-
Strings.toString(new JinaAIEmbeddingsRequestEntity(input, inputType, taskSettings, model, embeddingType))
56-
.getBytes(StandardCharsets.UTF_8)
59+
Strings.toString(
60+
new JinaAIEmbeddingsRequestEntity(input, inputType, taskSettings, model, embeddingType, dimensions, dimensionsSetByUser)
61+
).getBytes(StandardCharsets.UTF_8)
5762
);
5863
httpPost.setEntity(byteEntity);
5964

0 commit comments

Comments
 (0)