|
14 | 14 | from flexible_inference_benchmark.engine.client import Client
|
15 | 15 | from flexible_inference_benchmark.engine.backend_functions import ASYNC_REQUEST_FUNCS
|
16 | 16 | from flexible_inference_benchmark.engine.workloads import WORKLOADS_TYPES
|
| 17 | +from flexible_inference_benchmark.data_postprocessors.performance import add_performance_parser |
| 18 | +from flexible_inference_benchmark.data_postprocessors.ttft import add_ttft_parser |
| 19 | +from flexible_inference_benchmark.data_postprocessors.itl import add_itl_parser |
17 | 20 |
|
18 | 21 | logger = logging.getLogger(__name__)
|
19 | 22 |
|
@@ -99,134 +102,148 @@ def send_requests(
|
99 | 102 | return asyncio.run(client.benchmark(requests_prompts, requests_times))
|
100 | 103 |
|
101 | 104 |
|
102 |
| -def parse_args() -> argparse.Namespace: |
| 105 | +def add_benchmark_subparser(subparsers: argparse._SubParsersAction) -> None: # type: ignore [type-arg] |
103 | 106 |
|
104 |
| - parser = argparse.ArgumentParser(description="CentML Inference Benchmark") |
| 107 | + benchmark_parser = subparsers.add_parser('benchmark') |
105 | 108 |
|
106 |
| - parser.add_argument("--seed", type=int, default=None, help="seed for reproducibility") |
| 109 | + benchmark_parser.add_argument("--seed", type=int, default=None, help="seed for reproducibility") |
107 | 110 |
|
108 |
| - parser.add_argument( |
| 111 | + benchmark_parser.add_argument( |
109 | 112 | "--backend",
|
110 | 113 | type=str,
|
111 | 114 | default='cserve',
|
112 | 115 | choices=list(ASYNC_REQUEST_FUNCS.keys()),
|
113 | 116 | help="Backend inference engine.",
|
114 | 117 | )
|
115 | 118 |
|
116 |
| - parser.add_argument( |
| 119 | + benchmark_parser.add_argument( |
117 | 120 | "--workload-type",
|
118 | 121 | type=str,
|
119 | 122 | default=None,
|
120 | 123 | choices=list(WORKLOADS_TYPES.keys()),
|
121 | 124 | help="choose a workload type, this will overwrite some arguments",
|
122 | 125 | )
|
123 | 126 |
|
124 |
| - url_group = parser.add_mutually_exclusive_group() |
| 127 | + url_group = benchmark_parser.add_mutually_exclusive_group() |
125 | 128 |
|
126 | 129 | url_group.add_argument(
|
127 | 130 | "--base-url", type=str, default=None, help="Server or API base url if not using http host and port."
|
128 | 131 | )
|
129 | 132 |
|
130 |
| - parser.add_argument( |
| 133 | + benchmark_parser.add_argument( |
131 | 134 | "--https-ssl", default=True, help="whether to check for ssl certificate for https endpoints, default is True"
|
132 | 135 | )
|
133 | 136 |
|
134 |
| - parser.add_argument("--endpoint", type=str, default="/v1/completions", help="API endpoint.") |
| 137 | + benchmark_parser.add_argument("--endpoint", type=str, default="/v1/completions", help="API endpoint.") |
135 | 138 |
|
136 |
| - req_group = parser.add_mutually_exclusive_group() |
| 139 | + req_group = benchmark_parser.add_mutually_exclusive_group() |
137 | 140 |
|
138 | 141 | req_group.add_argument("--num-of-req", type=int, default=None, help="Total number of request.")
|
139 | 142 |
|
140 | 143 | req_group.add_argument("--max-time-for-reqs", type=int, default=None, help="Max time for requests in seconds.")
|
141 | 144 |
|
142 |
| - parser.add_argument( |
| 145 | + benchmark_parser.add_argument( |
143 | 146 | "--request-distribution",
|
144 | 147 | nargs="*",
|
145 | 148 | default=["exponential", 1],
|
146 | 149 | help="Request distribution [Distribution_type (inputs to distribution)]",
|
147 | 150 | )
|
148 | 151 |
|
149 |
| - parser.add_argument( |
| 152 | + benchmark_parser.add_argument( |
150 | 153 | "--input-token-distribution",
|
151 | 154 | nargs="*",
|
152 | 155 | default=["uniform", 0, 255],
|
153 | 156 | help="Request distribution [Distribution_type (inputs to distribution)]",
|
154 | 157 | )
|
155 | 158 |
|
156 |
| - parser.add_argument( |
| 159 | + benchmark_parser.add_argument( |
157 | 160 | "--output-token-distribution",
|
158 | 161 | nargs="*",
|
159 | 162 | default=["uniform", 0, 255],
|
160 | 163 | help="Request distribution [Distribution_type (inputs to distribution)]",
|
161 | 164 | )
|
162 | 165 |
|
163 |
| - prefix_group = parser.add_mutually_exclusive_group() |
| 166 | + prefix_group = benchmark_parser.add_mutually_exclusive_group() |
164 | 167 |
|
165 | 168 | prefix_group.add_argument("--prefix-text", type=str, default=None, help="Text to use as prefix for all requests.")
|
166 | 169 |
|
167 | 170 | prefix_group.add_argument("--prefix-len", type=int, default=None, help="Length of prefix to use for all requests.")
|
168 | 171 |
|
169 | 172 | prefix_group.add_argument('--no-prefix', action='store_true', help='No prefix for requests.')
|
170 | 173 |
|
171 |
| - parser.add_argument("--disable-ignore-eos", action="store_true", help="Disables ignoring the eos token") |
| 174 | + benchmark_parser.add_argument("--disable-ignore-eos", action="store_true", help="Disables ignoring the eos token") |
172 | 175 |
|
173 |
| - parser.add_argument("--disable-stream", action="store_true", help="Disable stream response from API") |
| 176 | + benchmark_parser.add_argument("--disable-stream", action="store_true", help="Disable stream response from API") |
174 | 177 |
|
175 |
| - parser.add_argument("--cookies", default={}, help="Insert cookies in the request") |
| 178 | + benchmark_parser.add_argument("--cookies", default={}, help="Insert cookies in the request") |
176 | 179 |
|
177 |
| - parser.add_argument( |
| 180 | + benchmark_parser.add_argument( |
178 | 181 | "--dataset-name",
|
179 | 182 | type=str,
|
180 | 183 | default="random",
|
181 | 184 | choices=["sharegpt", "other", "random"],
|
182 | 185 | help="Name of the dataset to benchmark on.",
|
183 | 186 | )
|
184 | 187 |
|
185 |
| - parser.add_argument("--dataset-path", type=str, default=None, help="Path to the dataset.") |
| 188 | + benchmark_parser.add_argument("--dataset-path", type=str, default=None, help="Path to the dataset.") |
186 | 189 |
|
187 |
| - parser.add_argument("--model", type=str, help="Name of the model.") |
| 190 | + benchmark_parser.add_argument("--model", type=str, help="Name of the model.") |
188 | 191 |
|
189 |
| - parser.add_argument( |
| 192 | + benchmark_parser.add_argument( |
190 | 193 | "--tokenizer", type=str, default=None, help="Name or path of the tokenizer, if not using the default tokenizer."
|
191 | 194 | )
|
192 | 195 |
|
193 |
| - parser.add_argument("--disable-tqdm", action="store_true", help="Specify to disable tqdm progress bar.") |
| 196 | + benchmark_parser.add_argument("--disable-tqdm", action="store_true", help="Specify to disable tqdm progress bar.") |
194 | 197 |
|
195 |
| - parser.add_argument("--best-of", type=int, default=1, help="Number of best completions to return.") |
| 198 | + benchmark_parser.add_argument("--best-of", type=int, default=1, help="Number of best completions to return.") |
196 | 199 |
|
197 |
| - parser.add_argument("--use-beam-search", action="store_true", help="Use beam search for completions.") |
| 200 | + benchmark_parser.add_argument("--use-beam-search", action="store_true", help="Use beam search for completions.") |
198 | 201 |
|
199 |
| - parser.add_argument( |
| 202 | + benchmark_parser.add_argument( |
200 | 203 | "--output-file",
|
201 | 204 | type=str,
|
202 | 205 | default='output-file.json',
|
203 | 206 | required=False,
|
204 | 207 | help="Output json file to save the results.",
|
205 | 208 | )
|
206 | 209 |
|
207 |
| - parser.add_argument("--debug", action="store_true", help="Log debug messages") |
| 210 | + benchmark_parser.add_argument("--debug", action="store_true", help="Log debug messages") |
208 | 211 |
|
209 |
| - parser.add_argument("--verbose", action="store_true", help="Print short description of each request") |
| 212 | + benchmark_parser.add_argument("--verbose", action="store_true", help="Print short description of each request") |
210 | 213 |
|
211 |
| - parser.add_argument("--config-file", default=None, help="configuration file") |
| 214 | + benchmark_parser.add_argument("--config-file", default=None, help="configuration file") |
| 215 | + |
| 216 | + |
| 217 | +def parse_args() -> argparse.Namespace: |
| 218 | + |
| 219 | + parser = argparse.ArgumentParser(description="CentML Inference Benchmark") |
| 220 | + |
| 221 | + subparsers = parser.add_subparsers(title='Subcommands', dest='subcommand') |
| 222 | + |
| 223 | + add_performance_parser(subparsers) |
| 224 | + add_benchmark_subparser(subparsers) |
| 225 | + add_ttft_parser(subparsers) |
| 226 | + add_itl_parser(subparsers) |
212 | 227 |
|
213 | 228 | args = parser.parse_args()
|
214 |
| - if args.config_file: |
215 |
| - with open(args.config_file, 'r') as f: |
216 |
| - parser.set_defaults(**json.load(f)) |
217 |
| - # Reload arguments to override config file values with command line values |
218 |
| - args = parser.parse_args() |
219 |
| - if not (args.prefix_text or args.prefix_len or args.no_prefix): |
220 |
| - parser.error("Please provide either prefix text or prefix length or specify no prefix.") |
221 |
| - if not (args.num_of_req or args.max_time_for_reqs): |
222 |
| - parser.error("Please provide either number of requests or max time for requests.") |
223 |
| - if not args.model: |
224 |
| - parser.error("Please provide the model name.") |
| 229 | + if args.subcommand == 'benchmark': |
| 230 | + if args.config_file: |
| 231 | + with open(args.config_file, 'r') as f: |
| 232 | + file_data = json.load(f) |
| 233 | + for k, v in file_data.items(): |
| 234 | + # Reload arguments to override config file values with command line values |
| 235 | + setattr(args, k, v) |
| 236 | + if not (args.prefix_text or args.prefix_len or args.no_prefix): |
| 237 | + parser.error("Please provide either prefix text or prefix length or specify no prefix.") |
| 238 | + if not (args.num_of_req or args.max_time_for_reqs): |
| 239 | + parser.error("Please provide either number of requests or max time for requests.") |
| 240 | + if not args.model: |
| 241 | + parser.error("Please provide the model name.") |
| 242 | + |
225 | 243 | return args
|
226 | 244 |
|
227 | 245 |
|
228 |
| -def main() -> None: |
229 |
| - args = parse_args() |
| 246 | +def run_main(args: argparse.Namespace) -> None: |
230 | 247 | configure_logging(args)
|
231 | 248 | if args.workload_type:
|
232 | 249 | workload_type = WORKLOADS_TYPES[args.workload_type]()
|
@@ -285,5 +302,23 @@ def main() -> None:
|
285 | 302 | logger.debug(f"{output_list}")
|
286 | 303 |
|
287 | 304 |
|
| 305 | +def main() -> None: |
| 306 | + args = parse_args() |
| 307 | + if args.subcommand == "analyse": |
| 308 | + from flexible_inference_benchmark.data_postprocessors.performance import run |
| 309 | + |
| 310 | + run(args) |
| 311 | + elif args.subcommand == "generate-ttft-plot": |
| 312 | + from flexible_inference_benchmark.data_postprocessors.ttft import run |
| 313 | + |
| 314 | + run(args) |
| 315 | + elif args.subcommand == "generate-itl-plot": |
| 316 | + from flexible_inference_benchmark.data_postprocessors.itl import run |
| 317 | + |
| 318 | + run(args) |
| 319 | + else: |
| 320 | + run_main(args) |
| 321 | + |
| 322 | + |
288 | 323 | if __name__ == '__main__':
|
289 | 324 | main()
|
0 commit comments