Skip to content

Commit a7c7d0f

Browse files
authored
Using bulk APIs/calls on entities during saveAll (#552)
* Using bulk APIs/calls on entities during saveAll * Cleaning up DefaultEmbedder * Adding batch size prop
1 parent 02f3bb0 commit a7c7d0f

File tree

8 files changed

+377
-157
lines changed

8 files changed

+377
-157
lines changed

redis-om-spring/src/main/java/com/redis/om/spring/RedisOMAiProperties.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
)
1515
public class RedisOMAiProperties {
1616
private boolean enabled = false;
17+
private int embeddingBatchSize = 1000;
1718
private final Djl djl = new Djl();
1819
private final Transformers transformers = new Transformers();
1920
private final OpenAi openAi = new OpenAi();
@@ -63,7 +64,15 @@ public Ollama getOllama() {
6364
return ollama;
6465
}
6566

66-
// Transformer properties
67+
public int getEmbeddingBatchSize() {
68+
return embeddingBatchSize;
69+
}
70+
71+
public void setEmbeddingBatchSize(int embeddingBatchSize) {
72+
this.embeddingBatchSize = embeddingBatchSize;
73+
}
74+
75+
// Transformer properties
6776
public static class Transformers {
6877
private String tokenizerResource;
6978
private String modelResource;
@@ -87,7 +96,6 @@ public Map<String, String> getTokenizerOptions() {
8796
}
8897
}
8998

90-
// DJL properties
9199
public static class Djl {
92100
private static final String DEFAULT_ENGINE = "PyTorch";
93101
// image embedding settings

redis-om-spring/src/main/java/com/redis/om/spring/repository/support/SimpleRedisDocumentRepository.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,8 @@ public <S extends T> List<S> saveAll(Iterable<S> entities) {
185185
List<S> saved = new ArrayList<>();
186186
List<Object> entityIds = new ArrayList<>();
187187

188+
embedder.processEntities(entities);
189+
188190
try (Jedis jedis = modulesOperations.client().getJedis().get()) {
189191
Pipeline pipeline = jedis.pipelined();
190192
Gson gson = gsonBuilder.create();
@@ -209,7 +211,6 @@ public <S extends T> List<S> saveAll(Iterable<S> entities) {
209211

210212
// process entity pre-save mutation
211213
auditor.processEntity(entity, isNew);
212-
embedder.processEntity(entity);
213214

214215
Optional<Long> maybeTtl = getTTLForEntity(entity);
215216

redis-om-spring/src/main/java/com/redis/om/spring/repository/support/SimpleRedisEnhancedRepository.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,8 @@ public <S extends T> List<S> saveAll(Iterable<S> entities) {
360360
Assert.notNull(entities, "The given Iterable of entities must not be null!");
361361
List<S> saved = new ArrayList<>();
362362

363+
embedder.processEntities(entities);
364+
363365
try (Jedis jedis = modulesOperations.client().getJedis().get()) {
364366
Pipeline pipeline = jedis.pipelined();
365367

@@ -380,7 +382,6 @@ public <S extends T> List<S> saveAll(Iterable<S> entities) {
380382

381383
// process entity pre-save mutation
382384
auditor.processEntity(entity, isNew);
383-
embedder.processEntity(entity);
384385

385386
RedisData rdo = new RedisData();
386387
mappingConverter.write(entity, rdo);

0 commit comments

Comments
 (0)