"""Micro-benchmark for passing float embeddings to YDB as query parameters.

Two ways of shipping a float vector to the server are compared:

* ``MODE = "list"``: send a ``List<Float>`` parameter and convert it on the
  server with ``Knn::ToBinaryStringFloat``;
* ``MODE = "string"``: serialize the vector on the client and send a ``String``.

Both queries only return ``Len(...)`` of the binary string; no table is read.
"""
from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED
import random
import time
import os
import yaml
import ydb
import numpy as np

# === Configuration ===
ENDPOINT = "grpc://hostname:2135"
DATABASE = "/Root/vector-db1"
MODE = "list"
# MODE = "string"
DURATION = 10  # seconds; no new queries are submitted after this much time
MAX_WORKERS = 8  # number of queries kept in flight concurrently
PRINT_PERIOD = 500  # print the result of every N-th completed query
NUM_EMBEDDINGS = 1000
EMBEDDING_DIM = 1024

# Pre-generated query parameters, filled by pre_generate_embeddings().
embeddings = []  # each item: list of Python floats, sent as List<Float>
embeddings_binary = []  # each item: bytes from float_embedding_to_binary()

QUERY_LIST = """
DECLARE $EmbeddingList AS List<Float>;
$EmbeddingString = UnTag(Knn::ToBinaryStringFloat($EmbeddingList), 'FloatVector');
SELECT Len($EmbeddingString);
"""

QUERY_STRING = """
DECLARE $EmbeddingString AS String;
SELECT Len($EmbeddingString);
"""


def float_embedding_to_binary(arr):
    """Serialize a 1-D float vector to bytes: float32 values plus a trailing 0x01.

    This is meant to mirror the layout produced server-side by
    ``Knn::ToBinaryStringFloat`` (assumed: little-endian float32 values
    followed by a one-byte format tag 0x01 -- confirm against YDB docs).

    Args:
        arr: 1-D array-like of floats (numpy array or sequence).

    Returns:
        bytes of length ``4 * len(arr) + 1``.
    """
    # Force little-endian float32 so the output does not depend on host byte order.
    return np.asarray(arr, dtype="<f4").tobytes() + b"\x01"


def pre_generate_embeddings():
    """Append NUM_EMBEDDINGS random vectors to both module-level parameter pools."""
    for _ in range(NUM_EMBEDDINGS):
        arr = np.random.randn(EMBEDDING_DIM).astype(np.float32)
        # Convert once up front so the hot path hands plain Python floats to the
        # SDK instead of iterating a numpy array per query. The values come from
        # a float32 array, so both modes carry exactly the same numbers.
        embeddings.append(arr.tolist())
        embeddings_binary.append(float_embedding_to_binary(arr))


def _build_query():
    """Return ``(query_text, parameters)`` for MODE with a randomly chosen embedding."""
    if MODE == "list":
        return QUERY_LIST, {
            "$EmbeddingList": ydb.TypedValue(
                random.choice(embeddings), ydb.ListType(ydb.PrimitiveType.Float)
            ),
        }
    return QUERY_STRING, {
        "$EmbeddingString": ydb.TypedValue(
            random.choice(embeddings_binary), ydb.PrimitiveType.String
        ),
    }


def run_query(pool):
    """Execute one benchmark query through ``pool`` with retries.

    Args:
        pool: ``ydb.QuerySessionPool`` used to obtain sessions.

    Returns:
        The rows of the first result set.
    """
    # Parameters are chosen once per logical query, so a retry re-sends the same data.
    query, parameters = _build_query()

    def callee(session: ydb.QuerySession):
        # Consume the result stream inside the retried callable, so a failure
        # while reading results also goes through retry_operation_sync.
        return list(session.transaction().execute(
            query,
            commit_tx=True,
            parameters=parameters,
        ))

    result = pool.retry_operation_sync(callee)
    return result[0].rows


def _timed_run_query(pool):
    """Run one query and return ``(rows, elapsed_seconds)`` measured in the worker thread."""
    started = time.perf_counter()
    rows = run_query(pool)
    return rows, time.perf_counter() - started


def pretty_print(rows):
    """Print each row, or "No results." when ``rows`` is empty."""
    if not rows:
        print("No results.")
        return
    for row in rows:
        print(row)


def run_benchmark(pool):
    """Keep MAX_WORKERS queries in flight until DURATION seconds have elapsed.

    After the deadline no new queries are submitted. Queries already in flight
    are drained and counted.

    Args:
        pool: ``ydb.QuerySessionPool`` passed to every query.

    Returns:
        ``(query_counter, total_duration, times)``: number of completed queries,
        elapsed seconds for the whole run, and the list of per-query latencies
        in seconds.

    Raises:
        Whatever ``run_query`` raises for a query that fails after retries.
    """
    times = []
    query_counter = 0
    start_time = time.perf_counter()
    deadline = start_time + DURATION

    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        in_flight = {executor.submit(_timed_run_query, pool) for _ in range(MAX_WORKERS)}
        while in_flight:
            # `wait` returns (done, not_done). not_done becomes the new in-flight set.
            done, in_flight = wait(in_flight, return_when=FIRST_COMPLETED)
            for future in done:
                rows, query_time = future.result()
                query_counter += 1
                times.append(query_time)
                if query_counter % PRINT_PERIOD == 0:
                    print(f"\n=== Query result #{query_counter} ({query_time:.3f}s) ===")
                    pretty_print(rows)
                # Refill the slot only while the benchmark window is still open.
                if time.perf_counter() < deadline:
                    in_flight.add(executor.submit(_timed_run_query, pool))

    total_duration = time.perf_counter() - start_time
    return query_counter, total_duration, times


def main():
    """Generate parameters, connect to YDB, run the benchmark and print a summary."""
    if MODE not in ("list", "string"):
        raise ValueError(f"Unsupported MODE {MODE!r}; expected 'list' or 'string'")

    pre_generate_embeddings()

    driver = ydb.Driver(endpoint=ENDPOINT, database=DATABASE)
    try:
        # Inside the try, so the driver is stopped even if the connection fails.
        driver.wait(fail_fast=True, timeout=5)
        pool = ydb.QuerySessionPool(driver)
        try:
            query_counter, total_duration, times = run_benchmark(pool)
        finally:
            pool.stop()

        print(f"\nCompleted {query_counter} queries in {total_duration:.3f} seconds")
        if times:
            print(f"Queries per second: {query_counter / total_duration:.2f}")
            print(f"Avg query time: {sum(times)/len(times):.3f} seconds")
    finally:
        driver.stop()


if __name__ == "__main__":
    main()