diff --git a/tests/test_search.py b/tests/test_search.py index 5b45cfc0a3..c4598f3773 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -2856,6 +2856,64 @@ def test_vector_search_with_default_dialect(client): assert res["total_results"] == 2 +@pytest.mark.redismod +@skip_if_server_version_lt("7.9.0") +def test_vector_search_with_int8_type(client): + client.ft().create_index( + (VectorField("v", "FLAT", {"TYPE": "INT8", "DIM": 2, "DISTANCE_METRIC": "L2"}),) + ) + + a = [1.5, 10] + b = [123, 100] + c = [1, 1] + + client.hset("a", "v", np.array(a, dtype=np.int8).tobytes()) + client.hset("b", "v", np.array(b, dtype=np.int8).tobytes()) + client.hset("c", "v", np.array(c, dtype=np.int8).tobytes()) + + query = Query("*=>[KNN 2 @v $vec as score]") + query_params = {"vec": np.array(a, dtype=np.int8).tobytes()} + + assert 2 in query.get_args() + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 2 + else: + assert res["total_results"] == 2 + + +@pytest.mark.redismod +@skip_if_server_version_lt("7.9.0") +def test_vector_search_with_uint8_type(client): + client.ft().create_index( + ( + VectorField( + "v", "FLAT", {"TYPE": "UINT8", "DIM": 2, "DISTANCE_METRIC": "L2"} + ), + ) + ) + + a = [1.5, 10] + b = [123, 100] + c = [1, 1] + + client.hset("a", "v", np.array(a, dtype=np.uint8).tobytes()) + client.hset("b", "v", np.array(b, dtype=np.uint8).tobytes()) + client.hset("c", "v", np.array(c, dtype=np.uint8).tobytes()) + + query = Query("*=>[KNN 2 @v $vec as score]") + query_params = {"vec": np.array(a, dtype=np.uint8).tobytes()} + + assert 2 in query.get_args() + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 2 + else: + assert res["total_results"] == 2 + + @pytest.mark.redismod @skip_ifmodversion_lt("2.4.3", "search") def test_search_query_with_different_dialects(client):