Skip to content

Commit 6431000

Browse files
dcampora and Copilot
authored and committed
[TRTLLM-5974][feat] Support disaggregated serving in TRTLLM Sampler (NVIDIA#5328)
Signed-off-by: Daniel Campora <[email protected]> Signed-off-by: Daniel Cámpora <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent 206d60d commit 6431000

File tree

4 files changed

+152
-5
lines changed

4 files changed

+152
-5
lines changed
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
from tensorrt_llm.bindings.executor import FinishReason
2+
3+
4+
class FinishedState:
    """Bitmask describing why (and whether) a request's beam has finished.

    Mirrors the flag layout used by the TRT-LLM runtime: one bit per finish
    condition plus a separate skip-decoding bit, with ``FINISHED`` being the
    union of all finish-condition bits.
    """

    # One bit per finish condition.
    FINISHED_EOS = 1 << 0
    FINISHED_STOP_WORDS = 1 << 1
    FINISHED_MAX_LENGTH = 1 << 2
    # Union of every finish-condition bit (does NOT include SKIP_DECODING).
    FINISHED = FINISHED_EOS | FINISHED_STOP_WORDS | FINISHED_MAX_LENGTH
    # Orthogonal flag: decoding should be skipped for this slot.
    SKIP_DECODING = 1 << 3

    def __init__(self, state=0):
        """Wrap a raw integer bitmask; defaults to no flags set."""
        self._state = state

    @classmethod
    def empty(cls):
        """A state with no flags set."""
        return cls(0)

    @classmethod
    def finished(cls):
        """A state with every finish-condition flag set."""
        return cls(cls.FINISHED)

    @classmethod
    def skip_decoding(cls):
        """A state with only the skip-decoding flag set."""
        return cls(cls.SKIP_DECODING)

    @classmethod
    def finished_eos(cls):
        """A state finished by emitting the end-of-sequence token."""
        return cls(cls.FINISHED_EOS)

    @classmethod
    def finished_max_length(cls):
        """A state finished by reaching the maximum sequence length."""
        return cls(cls.FINISHED_MAX_LENGTH)

    @classmethod
    def finished_stop_words(cls):
        """A state finished by matching a stop-word sequence."""
        return cls(cls.FINISHED_STOP_WORDS)

    def set_finished_eos(self):
        """Mark this state as finished via end-of-sequence token."""
        self._state = self._state | self.FINISHED_EOS

    @property
    def is_finished_eos(self):
        return self._has(self.FINISHED_EOS)

    def set_finished_stop_words(self):
        """Mark this state as finished via a stop-word match."""
        self._state = self._state | self.FINISHED_STOP_WORDS

    @property
    def is_finished_stop_words(self):
        return self._has(self.FINISHED_STOP_WORDS)

    def set_finished_max_length(self):
        """Mark this state as finished via the max-length limit."""
        self._state = self._state | self.FINISHED_MAX_LENGTH

    @property
    def is_finished_max_length(self):
        return self._has(self.FINISHED_MAX_LENGTH)

    def set_finished(self):
        """Set every finish-condition flag at once."""
        self._state = self._state | self.FINISHED

    @property
    def is_finished(self):
        """True if any finish-condition bit is set."""
        return self._has(self.FINISHED)

    def set_skip_decoding(self):
        """Mark this slot so decoding is skipped."""
        self._state = self._state | self.SKIP_DECODING

    @property
    def is_skip_decoding(self):
        return self._has(self.SKIP_DECODING)

    def to_finish_reason(self):
        """Translate the bitmask into an executor ``FinishReason``.

        Flags are checked in priority order: EOS, then stop words, then
        max length; with no finish bit set, ``NOT_FINISHED`` is returned.
        """
        priority = (
            (self.FINISHED_EOS, FinishReason.END_ID),
            (self.FINISHED_STOP_WORDS, FinishReason.STOP_WORDS),
            (self.FINISHED_MAX_LENGTH, FinishReason.LENGTH),
        )
        for flag, reason in priority:
            if self._has(flag):
                return reason
        return FinishReason.NOT_FINISHED

    def to_underlying(self):
        """Return the raw integer bitmask."""
        return self._state

    def _has(self, bits):
        # True if any of the given bits is set in the current state.
        return bool(self._state & bits)

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from tensorrt_llm.executor.result import Logprob
2424
from tensorrt_llm.mapping import Mapping
2525

26+
from .finish_reason import FinishedState
2627
from .llm_request import LlmRequest, LlmRequestState
2728
from .scheduler import ScheduledRequests
2829

@@ -683,6 +684,7 @@ def update_requests(self, state: SampleStateTRTLLM):
683684
for beam in range(beam_width):
684685
seq_len = sequence_lengths_host_data[seq_slot * beam_width +
685686
beam].item()
687+
seq_len = seq_len + 1 if self.is_trt_overlap else seq_len
686688
num_new_tokens[beam] = min(
687689
num_generated_tokens,
688690
seq_len - request.get_num_tokens(beam))
@@ -713,9 +715,10 @@ def update_requests(self, state: SampleStateTRTLLM):
713715
state.host.cum_log_probs[seq_slot * beam_width +
714716
beam].item())
715717

716-
finish_reason = finish_reasons_host[seq_slot * beam_width +
717-
beam].item()
718-
request.set_finished_reason(FinishReason(finish_reason), beam)
718+
finish_reason = FinishedState(
719+
finish_reasons_host[seq_slot * beam_width +
720+
beam].item()).to_finish_reason()
721+
request.set_finished_reason(finish_reason, beam)
719722

720723
if request.py_return_log_probs:
721724
request.py_result.append_log_probs([log_probs], cum_log_probs)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Disaggregated-serving config exercising the TRTLLM sampler with TinyLlama:
# one context (prefill) server and one generation server on localhost.
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
hostname: localhost
port: 8000
backend: "pytorch"
free_gpu_memory_fraction: 0.2
# Server(s) handling the context (prefill) phase.
context_servers:
  num_instances: 1
  max_batch_size: 1
  max_num_tokens: 3000
  max_seq_len: 4096
  tensor_parallel_size: 1
  pipeline_parallel_size: 1
  enable_trtllm_sampler: True
  kv_cache_config:
    free_gpu_memory_fraction: 0.2
    enable_partial_reuse: False
  use_cuda_graph: False
  # Overlap scheduler off for the context phase (it is on for generation).
  disable_overlap_scheduler: True
  urls:
      - "localhost:8001"
# Server(s) handling the generation (decode) phase.
generation_servers:
  num_instances: 1
  tensor_parallel_size: 1
  pipeline_parallel_size: 1
  max_batch_size: 256
  max_num_tokens: 4096
  max_seq_len: 4096
  enable_trtllm_sampler: True
  kv_cache_config:
    free_gpu_memory_fraction: 0.2
    enable_partial_reuse: False
  use_cuda_graph: False
  disable_overlap_scheduler: False
  urls:
      - "localhost:8002"

tests/integration/defs/disaggregated/test_disaggregated.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ def get_test_config(test_desc, example_dir, test_root):
5050
(2, f"{test_configs_root}/disagg_config_cuda_graph_padding.yaml"),
5151
"mixed": (2, f"{test_configs_root}/disagg_config_mixed.yaml"),
5252
"overlap": (2, f"{test_configs_root}/disagg_config_overlap.yaml"),
53+
"trtllm_sampler":
54+
(2, f"{test_configs_root}/disagg_config_trtllm_sampler.yaml"),
5355
"load_balance":
5456
(4, f"{test_configs_root}/disagg_config_load_balance.yaml"),
5557
"cache_aware_balance":
@@ -179,7 +181,7 @@ def run_disaggregated_test(example_dir,
179181
poll_procs=[workers_proc, server_proc])
180182

181183
# Run the chat completion endpoint test only for TinyLlama
182-
if test_desc == "overlap":
184+
if test_desc == "overlap" or test_desc == "trtllm_sampler":
183185
chat_client_cmd = client_cmd + [
184186
'-e', 'chat', '-o', 'output_chat.json'
185187
]
@@ -198,7 +200,7 @@ def run_disaggregated_test(example_dir,
198200
not_expected_strings = ["Berlin Berlin"]
199201

200202
output_files = ['output.json', 'output_streaming.json']
201-
if test_desc == "overlap":
203+
if test_desc == "overlap" or test_desc == "trtllm_sampler":
202204
# Disable streaming chat completion for overlap test
203205
# due to bug
204206
output_files.extend(['output_chat.json'])
@@ -420,6 +422,26 @@ def test_disaggregated_overlap(disaggregated_test_root, llm_venv,
420422
cwd=llm_venv.get_working_directory())
421423

422424

425+
@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
                         indirect=True)
def test_disaggregated_trtllm_sampler(disaggregated_test_root, llm_venv,
                                      disaggregated_example_root,
                                      llama_model_root):
    """Run the disaggregated-serving flow with the TRTLLM sampler config."""
    workdir = llm_venv.get_working_directory()
    # Expose the model checkpoint under the path the YAML config expects.
    links = {
        llama_model_root:
        f"{workdir}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    }
    for source, link in links.items():
        if os.path.islink(link):
            continue
        os.makedirs(os.path.dirname(link), exist_ok=True)
        os.symlink(source, link, target_is_directory=True)

    run_disaggregated_test(disaggregated_example_root,
                           "trtllm_sampler",
                           env=llm_venv._new_env,
                           cwd=workdir)
444+
423445
@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
424446
indirect=True)
425447
def test_disaggregated_load_balance(disaggregated_test_root, llm_venv,

0 commit comments

Comments
 (0)