Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
public class RedisOMAiProperties {
private boolean enabled = false;
private int embeddingBatchSize = 1000;
private final Djl djl = new Djl();
private final Transformers transformers = new Transformers();
private final OpenAi openAi = new OpenAi();
Expand Down Expand Up @@ -63,7 +64,15 @@ public Ollama getOllama() {
return ollama;
}

// Transformer properties
public int getEmbeddingBatchSize() {
return embeddingBatchSize;
}

public void setEmbeddingBatchSize(int embeddingBatchSize) {
this.embeddingBatchSize = embeddingBatchSize;
}

// Transformer properties
public static class Transformers {
private String tokenizerResource;
private String modelResource;
Expand All @@ -87,7 +96,6 @@ public Map<String, String> getTokenizerOptions() {
}
}

// DJL properties
public static class Djl {
private static final String DEFAULT_ENGINE = "PyTorch";
// image embedding settings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ public <S extends T> List<S> saveAll(Iterable<S> entities) {
List<S> saved = new ArrayList<>();
List<Object> entityIds = new ArrayList<>();

embedder.processEntities(entities);

try (Jedis jedis = modulesOperations.client().getJedis().get()) {
Pipeline pipeline = jedis.pipelined();
Gson gson = gsonBuilder.create();
Expand All @@ -209,7 +211,6 @@ public <S extends T> List<S> saveAll(Iterable<S> entities) {

// process entity pre-save mutation
auditor.processEntity(entity, isNew);
embedder.processEntity(entity);

Optional<Long> maybeTtl = getTTLForEntity(entity);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,8 @@ public <S extends T> List<S> saveAll(Iterable<S> entities) {
Assert.notNull(entities, "The given Iterable of entities must not be null!");
List<S> saved = new ArrayList<>();

embedder.processEntities(entities);

try (Jedis jedis = modulesOperations.client().getJedis().get()) {
Pipeline pipeline = jedis.pipelined();

Expand All @@ -380,7 +382,6 @@ public <S extends T> List<S> saveAll(Iterable<S> entities) {

// process entity pre-save mutation
auditor.processEntity(entity, isNew);
embedder.processEntity(entity);

RedisData rdo = new RedisData();
mappingConverter.write(entity, rdo);
Expand Down
Loading