Skip to content

Add FieldName signature to VectorField #104

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/NRedisStack/Search/Query.cs
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ public HighlightTags(string open, string close)
/// <summary>
/// Set the query parameter to sort by ASC by default
/// </summary>
public bool SortAscending { get; set; } = true;
public bool? SortAscending { get; set; } = null;

// highlight and summarize
internal bool _wantsHighlight = false, _wantsSummarize = false;
Expand Down Expand Up @@ -260,7 +260,8 @@ internal void SerializeRedisArgs(List<object> args)
{
args.Add("SORTBY");
args.Add(SortBy);
args.Add((SortAscending ? "ASC" : "DESC"));
if (SortAscending != null)
args.Add(((bool)SortAscending ? "ASC" : "DESC"));
}
if (Payload != null)
{
Expand Down Expand Up @@ -605,7 +606,7 @@ public Query SummarizeFields(int contextLen, int fragmentCount, string separator
/// <param name="field">the sorting field's name</param>
/// <param name="ascending">if set to true, the sorting order is ascending, else descending</param>
/// <returns>the query object itself</returns>
public Query SetSortBy(string field, bool ascending = true)
public Query SetSortBy(string field, bool? ascending = null)
{
SortBy = field;
SortAscending = ascending;
Expand Down
18 changes: 17 additions & 1 deletion src/NRedisStack/Search/Schema.cs
Original file line number Diff line number Diff line change
Expand Up @@ -209,13 +209,16 @@ public enum VectorAlgo

public VectorAlgo Algorithm { get; }
public Dictionary<string, object>? Attributes { get; }
public VectorField(string name, VectorAlgo algorithm, Dictionary<string, object>? attributes = null)
public VectorField(FieldName name, VectorAlgo algorithm, Dictionary<string, object>? attributes = null)
: base(name, FieldType.Vector)
{
Algorithm = algorithm;
Attributes = attributes;
}

public VectorField(string name, VectorAlgo algorithm, Dictionary<string, object>? attributes = null)
: this(FieldName.Of(name), algorithm, attributes) { }

internal override void AddFieldTypeArgs(List<object> args)
{
args.Add(Algorithm.ToString());
Expand Down Expand Up @@ -376,6 +379,19 @@ public Schema AddTagField(string name, bool sortable = false, bool unf = false,
return this;
}

/// <summary>
/// Add a Vector field to the schema.
/// </summary>
/// <param name="name">The field's name.</param>
/// <param name="algorithm">The vector similarity algorithm to use.</param>
/// <param name="attribute">The algorithm attributes for the creation of the vector index.</param>
/// <returns>The <see cref="Schema"/> object.</returns>
public Schema AddVectorField(FieldName name, VectorAlgo algorithm, Dictionary<string, object>? attributes = null)
{
Fields.Add(new VectorField(name, algorithm, attributes));
return this;
}

/// <summary>
/// Add a Vector field to the schema.
/// </summary>
Expand Down
4 changes: 2 additions & 2 deletions src/NRedisStack/Search/SearchResult.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ public class SearchResult
/// <summary>
/// Converts the documents to a list of json strings. only works on a json documents index.
/// </summary>
public IEnumerable<string>? ToJson() => Documents.Select(x => x["json"].ToString())
.Where(x => !string.IsNullOrEmpty(x));
public List<string>? ToJson() => Documents.Select(x => x["json"].ToString())
.Where(x => !string.IsNullOrEmpty(x)).ToList();

internal SearchResult(RedisResult[] resp, bool hasContent, bool hasScores, bool hasPayloads/*, bool shouldExplainScore*/)
{
Expand Down
93 changes: 93 additions & 0 deletions tests/NRedisStack.Tests/Search/SearchTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using static NRedisStack.Search.Schema;
using NRedisStack.Search.Aggregation;
using NRedisStack.Search.Literals.Enums;
using System.Runtime.InteropServices;

namespace NRedisStack.Tests.Search;

Expand Down Expand Up @@ -1912,6 +1913,98 @@ public async Task TestVectorCount_Issue70()
Assert.Equal(expected.Count(), actual.Args.Length);
}

[Fact]
public void VectorSimilaritySearch()
{
IDatabase db = redisFixture.Redis.GetDatabase();
db.Execute("FLUSHALL");
var ft = db.FT();
var json = db.JSON();

json.Set("vec:1", "$", "{\"vector\":[1,1,1,1]}");
json.Set("vec:2", "$", "{\"vector\":[2,2,2,2]}");
json.Set("vec:3", "$", "{\"vector\":[3,3,3,3]}");
json.Set("vec:4", "$", "{\"vector\":[4,4,4,4]}");

var schema = new Schema().AddVectorField(FieldName.Of("$.vector").As("vector"), Schema.VectorField.VectorAlgo.FLAT, new Dictionary<string, object>()
{
["TYPE"] = "FLOAT32",
["DIM"] = "4",
["DISTANCE_METRIC"] = "L2",
});

var idxDef = new FTCreateParams().On(IndexDataType.JSON).Prefix("vec:");
Assert.True(ft.Create("vss_idx", idxDef, schema));

float[] vec = new float[] { 2, 2, 2, 2 };
byte[] queryVec = MemoryMarshal.Cast<float, byte>(vec).ToArray();


var query = new Query("*=>[KNN 3 @vector $query_vec]")
.AddParam("query_vec", queryVec)
.SetSortBy("__vector_score")
.Dialect(2);
var res = ft.Search("vss_idx", query);

Assert.Equal(3, res.TotalResults);

Assert.Equal("vec:2", res.Documents[0].Id.ToString());

Assert.Equal(0, res.Documents[0]["__vector_score"]);

var jsonRes = res.ToJson();
Assert.Equal("{\"vector\":[2,2,2,2]}", jsonRes![0]);
}

[Fact]
public void QueryingVectorFields()
{
IDatabase db = redisFixture.Redis.GetDatabase();
db.Execute("FLUSHALL");
var ft = db.FT();
var json = db.JSON();

var schema = new Schema().AddVectorField("v", Schema.VectorField.VectorAlgo.HNSW, new Dictionary<string, object>()
{
["TYPE"] = "FLOAT32",
["DIM"] = "2",
["DISTANCE_METRIC"] = "L2",
});

ft.Create("idx", new FTCreateParams(), schema);

db.HashSet("a", "v", "aaaaaaaa");
db.HashSet("b", "v", "aaaabaaa");
db.HashSet("c", "v", "aaaaabaa");

var q = new Query("*=>[KNN 2 @v $vec]").ReturnFields("__v_score").Dialect(2);
var res = ft.Search("idx", q.AddParam("vec", "aaaaaaaa"));
Assert.Equal(2, res.TotalResults);
}

[Fact]
public async Task TestVectorFieldJson_Issue102Async()
{
IDatabase db = redisFixture.Redis.GetDatabase();
db.Execute("FLUSHALL");
var ft = db.FT();
var json = db.JSON();

// JSON.SET 1 $ '{"vec":[1,2,3,4]}'
await json.SetAsync("1", "$", "{\"vec\":[1,2,3,4]}");

// FT.CREATE my_index ON JSON SCHEMA $.vec as vector VECTOR FLAT 6 TYPE FLOAT32 DIM 4 DISTANCE_METRIC L2
var schema = new Schema().AddVectorField(FieldName.Of("$.vec").As("vector"), Schema.VectorField.VectorAlgo.FLAT, new Dictionary<string, object>()
{
["TYPE"] = "FLOAT32",
["DIM"] = "4",
["DISTANCE_METRIC"] = "L2",
});

Assert.True(await ft.CreateAsync("my_index", new FTCreateParams().On(IndexDataType.JSON), schema));

}

[Fact]
public void TestModulePrefixs1()
{
Expand Down