240 changes: 145 additions & 95 deletions tests/integration/defs/accuracy/test_disaggregated_serving.py
@@ -115,100 +115,139 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
ctx_total_gpus = ctx_tp * ctx_pp
gen_total_gpus = gen_tp * gen_pp

env_ctx = os.environ.copy()
env_ctx["TRTLLM_USE_UCX_KVCACHE"] = "1"
env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(ctx_total_gpus)))

env_gen = os.environ.copy()
env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(
map(str, range(ctx_total_gpus, ctx_total_gpus + gen_total_gpus)))
ctx_server_args = common_args + [
"--port", "8001", "--extra_llm_api_options", ctx_server_config_path,
f"--tp_size={ctx_tp}", f"--pp_size={ctx_pp}"
]
gen_server_args = common_args + [
"--port", "8002", "--extra_llm_api_options", gen_server_config_path,
f"--tp_size={gen_tp}", f"--pp_size={gen_pp}"
]
if "max_num_tokens" in ctx_server_config:
ctx_server_args.append(
f"--max_num_tokens={ctx_server_config['max_num_tokens']}")
if "max_num_tokens" in gen_server_config:
gen_server_args.append(
f"--max_num_tokens={gen_server_config['max_num_tokens']}")

with (MyThreadPoolExecutor(max_workers=16) as
thread_pool, temp_dir, popen(ctx_server_args, env=env_ctx) as
ctx_server, popen(gen_server_args, env=env_gen) as gen_server,
popen([
trtllm_serve_path, "disaggregated", "-c",
disaggregated_serving_config_path, "--server_start_timeout",
"3600"
]) as disaggregated_server):
while True:
time.sleep(1)
try:
print("Checking health endpoint")
response = requests.get("http://localhost:8000/health")
if response.status_code == 200:
break
except requests.exceptions.ConnectionError:
continue

client = openai.OpenAI(api_key="1234567890",
base_url=f"http://localhost:8000/v1")

def send_request(prompt: str, sampling_params: SamplingParams,
streaming: bool):
response = client.completions.create(
model=model_name,
prompt=prompt,
stream=streaming,
**({
"max_tokens": sampling_params.max_tokens,
"temperature": sampling_params.temperature,
"top_p": sampling_params.top_p,
"stop": sampling_params.stop,
"seed": sampling_params.seed
} if sampling_params else {}))
result = Result(id=0,
sampling_params=sampling_params,
outputs=[
CompletionOutput(text=response.choices[0].text,
index=0)
])
requested_output = RequestOutput._from_generation_result(
result, prompt=prompt)
setattr(requested_output, "result", result.result)
return requested_output

def generate_async(prompt: str,
sampling_params: Optional[SamplingParams] = None,
streaming: bool = False):
future = thread_pool.submit(send_request, prompt, sampling_params,
streaming)
thread_pool.futures.append(future)
return future

ctx_urls = disaggregated_server_config["context_servers"]["urls"]
gen_urls = disaggregated_server_config["generation_servers"]["urls"]

ctx_ports = [int(url.split(":")[1]) for url in ctx_urls]
gen_ports = [int(url.split(":")[1]) for url in gen_urls]

ctx_servers = []
current_gpu_offset = 0

for i, port in enumerate(ctx_ports):
env_ctx = os.environ.copy()
env_ctx["TRTLLM_USE_UCX_KVCACHE"] = "1"
gpu_range = range(current_gpu_offset,
current_gpu_offset + ctx_total_gpus)
env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_range))
current_gpu_offset += ctx_total_gpus

ctx_server_args = common_args + [
"--port",
str(port), "--extra_llm_api_options", ctx_server_config_path,
f"--tp_size={ctx_tp}", f"--pp_size={ctx_pp}"
]
if "max_num_tokens" in ctx_server_config:
ctx_server_args.append(
f"--max_num_tokens={ctx_server_config['max_num_tokens']}")

ctx_servers.append((env_ctx, ctx_server_args))

gen_servers = []

for i, port in enumerate(gen_ports):
env_gen = os.environ.copy()
env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
gpu_range = range(current_gpu_offset,
current_gpu_offset + gen_total_gpus)
env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_range))
current_gpu_offset += gen_total_gpus

gen_server_args = common_args + [
"--port",
str(port), "--extra_llm_api_options", gen_server_config_path,
f"--tp_size={gen_tp}", f"--pp_size={gen_pp}"
]
if "max_num_tokens" in gen_server_config:
gen_server_args.append(
f"--max_num_tokens={gen_server_config['max_num_tokens']}")

gen_servers.append((env_gen, gen_server_args))

@contextlib.contextmanager
def multi_popen(server_configs):
processes = []
try:
yield DuckLLM(args, generate_async)
finally:
ctx_server.terminate()
gen_server.terminate()
disaggregated_server.terminate()

ctx_server.wait()
gen_server.wait()
disaggregated_server.wait()
for env, args in server_configs:
proc = popen(args, env=env)
processes.append(proc)

with contextlib.ExitStack() as stack:
opened_processes = [
stack.enter_context(proc) for proc in processes
]
yield opened_processes
except Exception as e:
print(
f"Failed to start disaggregated server processes in multi_popen: {e}"
)
raise

with (MyThreadPoolExecutor(max_workers=16) as thread_pool, temp_dir):
with multi_popen(ctx_servers + gen_servers):
with popen([
trtllm_serve_path, "disaggregated", "-c",
disaggregated_serving_config_path, "--server_start_timeout",
"3600"
]):
while True:
time.sleep(1)
try:
print("Checking health endpoint")
response = requests.get("http://localhost:8000/health")
if response.status_code == 200:
break
except requests.exceptions.ConnectionError:
continue

client = openai.OpenAI(api_key="1234567890",
base_url=f"http://localhost:8000/v1")

def send_request(prompt: str, sampling_params: SamplingParams,
streaming: bool):
response = client.completions.create(
model=model_name,
prompt=prompt,
stream=streaming,
**({
"max_tokens": sampling_params.max_tokens,
"temperature": sampling_params.temperature,
"top_p": sampling_params.top_p,
"stop": sampling_params.stop,
"seed": sampling_params.seed
} if sampling_params else {}))
result = Result(id=0,
sampling_params=sampling_params,
outputs=[
CompletionOutput(
text=response.choices[0].text,
index=0)
])
requested_output = RequestOutput._from_generation_result(
result, prompt=prompt)
setattr(requested_output, "result", result.result)
return requested_output

def generate_async(
prompt: str,
sampling_params: Optional[SamplingParams] = None,
streaming: bool = False):
future = thread_pool.submit(send_request, prompt,
sampling_params, streaming)
thread_pool.futures.append(future)
return future

yield DuckLLM(args, generate_async)


def run_parallel_test(model_name: str, model_path: str, ctx_pp: int,
ctx_tp: int, gen_pp: int, gen_tp: int,
test_set: LlmapiAccuracyTestHarness):
if ctx_tp * ctx_pp + gen_tp * gen_pp > get_device_count():
ctx_tp: int, gen_pp: int, gen_tp: int, ctx_instances: int,
gen_instances: int, test_set: LlmapiAccuracyTestHarness):
total_ctx_gpus = ctx_tp * ctx_pp * ctx_instances
total_gen_gpus = gen_tp * gen_pp * gen_instances
if total_ctx_gpus + total_gen_gpus > get_device_count():
pytest.fail(
f"Not enough devices for ctx_pp={ctx_pp}+ctx_tp={ctx_tp} and gen_pp={gen_pp}+gen_tp={gen_tp} test"
f"Not enough devices for {ctx_instances} ctx instances (ctx_pp={ctx_pp}*ctx_tp={ctx_tp}) + {gen_instances} gen instances (gen_pp={gen_pp}*gen_tp={gen_tp}), total: {total_ctx_gpus + total_gen_gpus}"
)

kv_cache_config = {
@@ -233,17 +272,21 @@ def run_parallel_test(model_name: str, model_path: str, ctx_pp: int,
"backend": "default"
}
}

ctx_urls = [f"localhost:{8001 + i * 2}" for i in range(ctx_instances)]
gen_urls = [f"localhost:{8002 + i * 2}" for i in range(gen_instances)]

disaggregated_server_config = {
"hostname": "localhost",
"port": 8000,
"backend": "pytorch",
"context_servers": {
"num_instances": 1,
"urls": ["localhost:8001"]
"num_instances": ctx_instances,
"urls": ctx_urls
},
"generation_servers": {
"num_instances": 1,
"urls": ["localhost:8002"]
"num_instances": gen_instances,
"urls": gen_urls
}
}
with launch_disaggregated_llm(disaggregated_server_config,
@@ -399,14 +442,21 @@ def test_eagle3(self, overlap_scheduler, eagle3_one_model):
@pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
def test_tp_pp_symmetric(self, tp, pp, testset):
return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, pp, tp, pp,
tp, get_accuracy_task(testset))
tp, 1, 1, get_accuracy_task(testset))

@pytest.mark.skip_less_device(4)
@parametrize_with_ids("ctx_pp", [2, 4])
@parametrize_with_ids("gen_tp", [1, 2])
@pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
def test_ctx_pp_gen_tp_asymmetric(self, ctx_pp, gen_tp, testset):
return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, ctx_pp, 1, 1,
gen_tp, get_accuracy_task(testset))
gen_tp, 1, 1, get_accuracy_task(testset))

@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
def test_multi_instance(self, testset):
return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, 1, 1, 1, 1,
2, 2, get_accuracy_task(testset))


@pytest.mark.skip_less_device_memory(140000)
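For reference, a minimal sketch (not part of the diff) of how the new multi-instance path assigns ports and GPUs, assuming the port formulas 8001 + i * 2 / 8002 + i * 2 and the sequential CUDA_VISIBLE_DEVICES offsets introduced above, applied to the test_multi_instance case (ctx_pp=ctx_tp=gen_pp=gen_tp=1, 2 context + 2 generation instances):

# Illustration only: reproduces the port/GPU assignment logic from the diff above.
ctx_pp, ctx_tp, gen_pp, gen_tp = 1, 1, 1, 1
ctx_instances, gen_instances = 2, 2

ctx_total_gpus = ctx_tp * ctx_pp  # GPUs per context instance
gen_total_gpus = gen_tp * gen_pp  # GPUs per generation instance

ctx_ports = [8001 + i * 2 for i in range(ctx_instances)]  # 8001, 8003
gen_ports = [8002 + i * 2 for i in range(gen_instances)]  # 8002, 8004

offset = 0
assignments = []
for port in ctx_ports + gen_ports:
    per_instance = ctx_total_gpus if port in ctx_ports else gen_total_gpus
    gpus = list(range(offset, offset + per_instance))
    assignments.append((port, gpus))
    offset += per_instance

print(assignments)
# [(8001, [0]), (8003, [1]), (8002, [2]), (8004, [3])] -> 4 GPUs total,
# matching the @pytest.mark.skip_less_device(4) guard on test_multi_instance.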
2 changes: 2 additions & 0 deletions tests/integration/test_lists/qa/llm_function_full.txt
@@ -527,6 +527,8 @@ accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=4]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=2]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=4]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[GSM8K]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[MMLU]
accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[False]
accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[True]
accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[False]
2 changes: 2 additions & 0 deletions tests/integration/test_lists/test-db/l0_dgx_h100.yml
@@ -50,6 +50,8 @@ l0_dgx_h100:
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=2]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=2]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=2]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[GSM8K]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[MMLU]
- accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_nixl_backend
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend
- test_e2e.py::test_ptp_quickstart_advanced_bs1