Skip to content

Commit 437e924

Browse files
authored
add subparsers (#63)
Co-authored-by: John Calderon <[email protected]>
1 parent c583ef1 commit 437e924

File tree

7 files changed

+105
-78
lines changed

7 files changed

+105
-78
lines changed

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ dependencies = [
1414
"transformers",
1515
"sentencepiece",
1616
"aiohttp",
17-
"pydantic"
17+
"pydantic",
18+
"matplotlib"
1819
]
1920

2021
classifiers = [
@@ -28,4 +29,4 @@ classifiers = [
2829
package-dir = {"" = "src"}
2930

3031
[project.scripts]
31-
inference-benchmark = "flexible_inference_benchmark.main:main"
32+
fib = "flexible_inference_benchmark.main:main"

scripts/lint/format.sh

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,4 @@ python -m black \
1313
--exclude=".*pb2.*" \
1414
--line-length 120 \
1515
$additional_opts \
16-
../../src/flexible_inference_benchmark \
17-
../../data_postprocessors
16+
../../src/flexible_inference_benchmark

scripts/lint/mypy.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ strict = True
44
follow_imports = silent
55
no_warn_unused_ignores = True
66
allow_redefinition = True
7-
7+
exclude = data_postprocessors

data_postprocessors/itl.py renamed to src/flexible_inference_benchmark/data_postprocessors/itl.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,11 @@
77
import matplotlib.pyplot as plt
88

99

10-
def parse_args():
11-
parser = argparse.ArgumentParser()
12-
parser.add_argument("--datapath", type=str, required=True, help="Path to the data file")
13-
parser.add_argument("--output", type=str, required=False, help="Path to save the plot")
14-
parser.add_argument('--request-num', type=int, default=0, help='Request number to plot')
15-
return parser.parse_args()
10+
def add_itl_parser(subparsers: argparse._SubParsersAction):
11+
itl_parser = subparsers.add_parser("generate-itl-plot")
12+
itl_parser.add_argument("--datapath", type=str, required=True, help="Path to the data file")
13+
itl_parser.add_argument("--output", type=str, required=False, help="Path to save the plot")
14+
itl_parser.add_argument('--request-num', type=int, default=0, help='Request number to plot')
1615

1716

1817
def plot_itl(data, idx, output):
@@ -26,12 +25,7 @@ def plot_itl(data, idx, output):
2625
plt.show()
2726

2827

29-
def main():
30-
args = parse_args()
28+
def run(args: argparse.Namespace):
3129
with open(args.datapath, 'r') as f:
3230
data = json.load(f)
3331
plot_itl(data, args.request_num, args.output)
34-
35-
36-
if __name__ == "__main__":
37-
main()

data_postprocessors/performance.py renamed to src/flexible_inference_benchmark/data_postprocessors/performance.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,13 @@
44

55
import json
66
import argparse
7-
87
import numpy as np
98
from transformers import AutoTokenizer
109

1110

12-
def parse_args():
13-
parser = argparse.ArgumentParser()
14-
parser.add_argument("--datapath", type=str, required=True, help="Path to the data file")
15-
return parser.parse_args()
11+
def add_performance_parser(subparsers: argparse._SubParsersAction) -> None:
12+
performance_parser = subparsers.add_parser('analyse')
13+
performance_parser.add_argument("--datapath", type=str, required=True, help='Path to the json file')
1614

1715

1816
def calculate_metrics(input_requests, outputs, benchmark_duration, tokenizer, stream):
@@ -74,13 +72,8 @@ def calculate_metrics(input_requests, outputs, benchmark_duration, tokenizer, st
7472
print("=" * 50)
7573

7674

77-
def main():
78-
args = parse_args()
75+
def run(args: argparse.Namespace):
7976
with open(args.datapath, 'r') as f:
8077
data = json.load(f)
8178
tokenizer = AutoTokenizer.from_pretrained(data["tokenizer"])
8279
calculate_metrics(data["inputs"], data["outputs"], data["time"], tokenizer, data["stream"])
83-
84-
85-
if __name__ == "__main__":
86-
main()

data_postprocessors/ttft.py renamed to src/flexible_inference_benchmark/data_postprocessors/ttft.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
"""
2+
Simple example of a data postprocessor script with minimal error checking and typing that shows a plot of TTFT.
3+
"""
4+
15
import argparse
26
import json
37
import matplotlib.pyplot as plt
@@ -17,14 +21,13 @@ def color_scheme_generator(num_colors):
1721

1822

1923
def generate_plot(name, data, color, axis):
20-
axis.set_ylabel('time (sec)')
24+
axis.ecdf(data, orientation="horizontal", color=color)
2125
axis.set_xlabel('CDF')
22-
axis.hist(data, orientation="horizontal", bins=len(data) // 2, fill=False, edgecolor=color, label=name)
23-
# axis.legend()
24-
axis.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05), fancybox=True, shadow=True, ncol=5)
26+
axis.set_ylabel('time (sec)')
2527

2628
ax2 = axis.twiny()
27-
ax2.ecdf(data, orientation="horizontal", color=color)
29+
ax2.hist(data, orientation="horizontal", bins=len(data) // 2, fill=False, edgecolor=color, label=name)
30+
ax2.set_xticks([])
2831

2932

3033
def plot_ttft(files, color_scheme):
@@ -36,14 +39,16 @@ def plot_ttft(files, color_scheme):
3639
generate_plot(data["backend"], ttft_arr, color_scheme[i], ax1)
3740

3841
fig.tight_layout()
39-
plt.title('TTFS')
42+
plt.title('TTFT')
4043
plt.tight_layout()
4144
plt.savefig("ttft.pdf")
4245

4346

44-
if __name__ == '__main__':
45-
parser = argparse.ArgumentParser()
46-
parser.add_argument("--files", nargs="+", help="list of json files")
47-
args = parser.parse_args()
47+
def add_ttft_parser(subparsers: argparse._SubParsersAction):
48+
ttft_parser = subparsers.add_parser("generate-ttft-plot")
49+
ttft_parser.add_argument("--files", nargs="+", help="list of json files")
50+
51+
52+
def run(args: argparse.Namespace):
4853
color_scheme = color_scheme_generator(len(args.files))
4954
plot_ttft(args.files, color_scheme)

src/flexible_inference_benchmark/main.py

Lines changed: 75 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
from flexible_inference_benchmark.engine.client import Client
1515
from flexible_inference_benchmark.engine.backend_functions import ASYNC_REQUEST_FUNCS
1616
from flexible_inference_benchmark.engine.workloads import WORKLOADS_TYPES
17+
from flexible_inference_benchmark.data_postprocessors.performance import add_performance_parser
18+
from flexible_inference_benchmark.data_postprocessors.ttft import add_ttft_parser
19+
from flexible_inference_benchmark.data_postprocessors.itl import add_itl_parser
1720

1821
logger = logging.getLogger(__name__)
1922

@@ -99,134 +102,148 @@ def send_requests(
99102
return asyncio.run(client.benchmark(requests_prompts, requests_times))
100103

101104

102-
def parse_args() -> argparse.Namespace:
105+
def add_benchmark_subparser(subparsers: argparse._SubParsersAction) -> None: # type: ignore [type-arg]
103106

104-
parser = argparse.ArgumentParser(description="CentML Inference Benchmark")
107+
benchmark_parser = subparsers.add_parser('benchmark')
105108

106-
parser.add_argument("--seed", type=int, default=None, help="seed for reproducibility")
109+
benchmark_parser.add_argument("--seed", type=int, default=None, help="seed for reproducibility")
107110

108-
parser.add_argument(
111+
benchmark_parser.add_argument(
109112
"--backend",
110113
type=str,
111114
default='cserve',
112115
choices=list(ASYNC_REQUEST_FUNCS.keys()),
113116
help="Backend inference engine.",
114117
)
115118

116-
parser.add_argument(
119+
benchmark_parser.add_argument(
117120
"--workload-type",
118121
type=str,
119122
default=None,
120123
choices=list(WORKLOADS_TYPES.keys()),
121124
help="choose a workload type, this will overwrite some arguments",
122125
)
123126

124-
url_group = parser.add_mutually_exclusive_group()
127+
url_group = benchmark_parser.add_mutually_exclusive_group()
125128

126129
url_group.add_argument(
127130
"--base-url", type=str, default=None, help="Server or API base url if not using http host and port."
128131
)
129132

130-
parser.add_argument(
133+
benchmark_parser.add_argument(
131134
"--https-ssl", default=True, help="whether to check for ssl certificate for https endpoints, default is True"
132135
)
133136

134-
parser.add_argument("--endpoint", type=str, default="/v1/completions", help="API endpoint.")
137+
benchmark_parser.add_argument("--endpoint", type=str, default="/v1/completions", help="API endpoint.")
135138

136-
req_group = parser.add_mutually_exclusive_group()
139+
req_group = benchmark_parser.add_mutually_exclusive_group()
137140

138141
req_group.add_argument("--num-of-req", type=int, default=None, help="Total number of request.")
139142

140143
req_group.add_argument("--max-time-for-reqs", type=int, default=None, help="Max time for requests in seconds.")
141144

142-
parser.add_argument(
145+
benchmark_parser.add_argument(
143146
"--request-distribution",
144147
nargs="*",
145148
default=["exponential", 1],
146149
help="Request distribution [Distribution_type (inputs to distribution)]",
147150
)
148151

149-
parser.add_argument(
152+
benchmark_parser.add_argument(
150153
"--input-token-distribution",
151154
nargs="*",
152155
default=["uniform", 0, 255],
153156
help="Request distribution [Distribution_type (inputs to distribution)]",
154157
)
155158

156-
parser.add_argument(
159+
benchmark_parser.add_argument(
157160
"--output-token-distribution",
158161
nargs="*",
159162
default=["uniform", 0, 255],
160163
help="Request distribution [Distribution_type (inputs to distribution)]",
161164
)
162165

163-
prefix_group = parser.add_mutually_exclusive_group()
166+
prefix_group = benchmark_parser.add_mutually_exclusive_group()
164167

165168
prefix_group.add_argument("--prefix-text", type=str, default=None, help="Text to use as prefix for all requests.")
166169

167170
prefix_group.add_argument("--prefix-len", type=int, default=None, help="Length of prefix to use for all requests.")
168171

169172
prefix_group.add_argument('--no-prefix', action='store_true', help='No prefix for requests.')
170173

171-
parser.add_argument("--disable-ignore-eos", action="store_true", help="Disables ignoring the eos token")
174+
benchmark_parser.add_argument("--disable-ignore-eos", action="store_true", help="Disables ignoring the eos token")
172175

173-
parser.add_argument("--disable-stream", action="store_true", help="Disable stream response from API")
176+
benchmark_parser.add_argument("--disable-stream", action="store_true", help="Disable stream response from API")
174177

175-
parser.add_argument("--cookies", default={}, help="Insert cookies in the request")
178+
benchmark_parser.add_argument("--cookies", default={}, help="Insert cookies in the request")
176179

177-
parser.add_argument(
180+
benchmark_parser.add_argument(
178181
"--dataset-name",
179182
type=str,
180183
default="random",
181184
choices=["sharegpt", "other", "random"],
182185
help="Name of the dataset to benchmark on.",
183186
)
184187

185-
parser.add_argument("--dataset-path", type=str, default=None, help="Path to the dataset.")
188+
benchmark_parser.add_argument("--dataset-path", type=str, default=None, help="Path to the dataset.")
186189

187-
parser.add_argument("--model", type=str, help="Name of the model.")
190+
benchmark_parser.add_argument("--model", type=str, help="Name of the model.")
188191

189-
parser.add_argument(
192+
benchmark_parser.add_argument(
190193
"--tokenizer", type=str, default=None, help="Name or path of the tokenizer, if not using the default tokenizer."
191194
)
192195

193-
parser.add_argument("--disable-tqdm", action="store_true", help="Specify to disable tqdm progress bar.")
196+
benchmark_parser.add_argument("--disable-tqdm", action="store_true", help="Specify to disable tqdm progress bar.")
194197

195-
parser.add_argument("--best-of", type=int, default=1, help="Number of best completions to return.")
198+
benchmark_parser.add_argument("--best-of", type=int, default=1, help="Number of best completions to return.")
196199

197-
parser.add_argument("--use-beam-search", action="store_true", help="Use beam search for completions.")
200+
benchmark_parser.add_argument("--use-beam-search", action="store_true", help="Use beam search for completions.")
198201

199-
parser.add_argument(
202+
benchmark_parser.add_argument(
200203
"--output-file",
201204
type=str,
202205
default='output-file.json',
203206
required=False,
204207
help="Output json file to save the results.",
205208
)
206209

207-
parser.add_argument("--debug", action="store_true", help="Log debug messages")
210+
benchmark_parser.add_argument("--debug", action="store_true", help="Log debug messages")
208211

209-
parser.add_argument("--verbose", action="store_true", help="Print short description of each request")
212+
benchmark_parser.add_argument("--verbose", action="store_true", help="Print short description of each request")
210213

211-
parser.add_argument("--config-file", default=None, help="configuration file")
214+
benchmark_parser.add_argument("--config-file", default=None, help="configuration file")
215+
216+
217+
def parse_args() -> argparse.Namespace:
218+
219+
parser = argparse.ArgumentParser(description="CentML Inference Benchmark")
220+
221+
subparsers = parser.add_subparsers(title='Subcommands', dest='subcommand')
222+
223+
add_performance_parser(subparsers)
224+
add_benchmark_subparser(subparsers)
225+
add_ttft_parser(subparsers)
226+
add_itl_parser(subparsers)
212227

213228
args = parser.parse_args()
214-
if args.config_file:
215-
with open(args.config_file, 'r') as f:
216-
parser.set_defaults(**json.load(f))
217-
# Reload arguments to override config file values with command line values
218-
args = parser.parse_args()
219-
if not (args.prefix_text or args.prefix_len or args.no_prefix):
220-
parser.error("Please provide either prefix text or prefix length or specify no prefix.")
221-
if not (args.num_of_req or args.max_time_for_reqs):
222-
parser.error("Please provide either number of requests or max time for requests.")
223-
if not args.model:
224-
parser.error("Please provide the model name.")
229+
if args.subcommand == 'benchmark':
230+
if args.config_file:
231+
with open(args.config_file, 'r') as f:
232+
file_data = json.load(f)
233+
for k, v in file_data.items():
234+
# Reload arguments to override config file values with command line values
235+
setattr(args, k, v)
236+
if not (args.prefix_text or args.prefix_len or args.no_prefix):
237+
parser.error("Please provide either prefix text or prefix length or specify no prefix.")
238+
if not (args.num_of_req or args.max_time_for_reqs):
239+
parser.error("Please provide either number of requests or max time for requests.")
240+
if not args.model:
241+
parser.error("Please provide the model name.")
242+
225243
return args
226244

227245

228-
def main() -> None:
229-
args = parse_args()
246+
def run_main(args: argparse.Namespace) -> None:
230247
configure_logging(args)
231248
if args.workload_type:
232249
workload_type = WORKLOADS_TYPES[args.workload_type]()
@@ -285,5 +302,23 @@ def main() -> None:
285302
logger.debug(f"{output_list}")
286303

287304

305+
def main() -> None:
306+
args = parse_args()
307+
if args.subcommand == "analyse":
308+
from flexible_inference_benchmark.data_postprocessors.performance import run
309+
310+
run(args)
311+
elif args.subcommand == "generate-ttft-plot":
312+
from flexible_inference_benchmark.data_postprocessors.ttft import run
313+
314+
run(args)
315+
elif args.subcommand == "generate-itl-plot":
316+
from flexible_inference_benchmark.data_postprocessors.itl import run
317+
318+
run(args)
319+
else:
320+
run_main(args)
321+
322+
288323
if __name__ == '__main__':
289324
main()

0 commit comments

Comments (0)