@@ -1,5 +1,6 @@
 import random
 from contextlib import contextmanager, nullcontext
+from typing import Optional

 import pytest

@@ -18,8 +19,9 @@
 from .test_llm import (_test_llm_capture_request_error, get_model_path,
                        global_kvcache_config, llama_model_path,
                        llm_get_stats_async_test_harness,
-                       llm_get_stats_test_harness, llm_test_harness, prompts,
-                       run_llm_abort_request,
+                       llm_get_stats_test_harness,
+                       llm_return_logprobs_test_harness, llm_test_harness,
+                       prompts, run_llm_abort_request,
                        run_llm_with_postprocess_parallel_and_result_handler,
                        tinyllama_logits_processor_test_harness)
 from utils.util import (force_ampere, similar, skip_gpu_memory_less_than_40gb,
@@ -836,3 +838,45 @@ def test_max_num_token_check(self):
                            match="should not exceed max_num_tokens"):
             ids = [random.randint(10, 100) for _ in range(101)]
             llm.generate([ids])
+
+
+@pytest.mark.parametrize(
+    "prompt_logprobs, logprobs, return_context_logits, return_generation_logits, backend",
+    [
+        (2, None, True, False,
+         "pytorch"),  # prompt_logprobs with context_logits
+        (None, 1, False, False,
+         "pytorch"),  # generation logprobs only (top-1, PyTorch limit)
+        (2, None, False, False,
+         "pytorch"),  # prompt_logprobs without context_logits
+        (None, None, False, False, "pytorch"),  # no logprobs at all
+    ])
+def test_llm_return_logprobs(prompt_logprobs: Optional[int],
+                             logprobs: Optional[int],
+                             return_context_logits: bool,
+                             return_generation_logits: bool, backend: str):
+    llm_return_logprobs_test_harness(prompt_logprobs,
+                                     logprobs,
+                                     return_context_logits,
+                                     return_generation_logits,
+                                     backend=backend)
+
+
+@pytest.mark.parametrize(
+    "prompt_logprobs, logprobs, return_context_logits, return_generation_logits",
+    [
+        (None, 1, False,
+         False),  # generation logprobs only (top-1, PyTorch limit)
+        (2, None, True, False),  # prompt_logprobs with context_logits
+        (2, None, False, False),  # prompt_logprobs only
+        (2, 1, False, False),  # both prompt and generation logprobs
+    ])
+def test_llm_return_logprobs_streaming(prompt_logprobs, logprobs,
+                                       return_context_logits,
+                                       return_generation_logits):
+    llm_return_logprobs_test_harness(prompt_logprobs,
+                                     logprobs,
+                                     return_context_logits,
+                                     return_generation_logits,
+                                     streaming=True,
+                                     backend="pytorch")