Skip to content

Commit dc1e9a9

Browse files
authored
Merge pull request #3301 from locustio/simplify-custom-messages-match-case-statements-in-runners.py
refactor case statements and update to use 3.10 set syntax
2 parents dbdf85e + d552bf4 commit dc1e9a9

File tree

5 files changed

+76
-75
lines changed

5 files changed

+76
-75
lines changed

locust/opentelemetry.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,8 @@ def setup_opentelemetry() -> bool:
1515
logger.error("OpenTelemetry SDK is not installed, opentelemetry not enabled. Run 'pip install locust[otel]'")
1616
return False
1717

18-
traces_exporters = set(e.strip().lower() for e in os.getenv("OTEL_TRACES_EXPORTER", "otlp").split(",") if e.strip())
19-
metrics_exporters = set(
20-
e.strip().lower() for e in os.getenv("OTEL_METRICS_EXPORTER", "otlp").split(",") if e.strip()
21-
)
18+
traces_exporters = {e.strip().lower() for e in os.getenv("OTEL_TRACES_EXPORTER", "otlp").split(",") if e.strip()}
19+
metrics_exporters = {e.strip().lower() for e in os.getenv("OTEL_METRICS_EXPORTER", "otlp").split(",") if e.strip()}
2220

2321
if traces_exporters == {"none"} and metrics_exporters == {"none"}:
2422
logger.info("No OpenTelemetry exporters configured, opentelemetry not enabled")

locust/runners.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,23 +1149,22 @@ def handle_message(self, client_id: str, msg: Message) -> None:
11491149
self.quit()
11501150
case "exception":
11511151
self.log_exception(msg.node_id, msg.data["msg"], msg.data["traceback"])
1152+
case _ if lc := self.custom_messages.get(msg.type):
1153+
listener, concurrent = lc
1154+
logger.debug(
1155+
f"Received {msg.type} message from worker {msg.node_id} (index {self.get_worker_index(msg.node_id)})"
1156+
)
1157+
try:
1158+
if not concurrent:
1159+
listener(environment=self.environment, msg=msg)
1160+
else:
1161+
gevent.spawn(listener, environment=self.environment, msg=msg)
1162+
except Exception:
1163+
logging.error(f"Uncaught exception in handler for {msg.type}\n{traceback.format_exc()}")
11521164
case _:
1153-
if lc := self.custom_messages.get(msg.type):
1154-
listener, concurrent = lc
1155-
logger.debug(
1156-
f"Received {msg.type} message from worker {msg.node_id} (index {self.get_worker_index(msg.node_id)})"
1157-
)
1158-
try:
1159-
if not concurrent:
1160-
listener(environment=self.environment, msg=msg)
1161-
else:
1162-
gevent.spawn(listener, environment=self.environment, msg=msg)
1163-
except Exception:
1164-
logging.error(f"Uncaught exception in handler for {msg.type}\n{traceback.format_exc()}")
1165-
else:
1166-
logger.warning(
1167-
f"Unknown message type received from worker {msg.node_id} (index {self.get_worker_index(msg.node_id)}): {msg.type}"
1168-
)
1165+
logger.warning(
1166+
f"Unknown message type received from worker {msg.node_id} (index {self.get_worker_index(msg.node_id)}): {msg.type}"
1167+
)
11691168

11701169
self.check_stopped()
11711170

@@ -1423,16 +1422,15 @@ def handle_message(self, msg: Message) -> None:
14231422
case "spawning_complete":
14241423
# master says we have finished spawning (happens only once during a normal rampup)
14251424
self.environment.events.spawning_complete.fire(user_count=msg.data["user_count"])
1426-
case _:
1427-
if lc := self.custom_messages.get(msg.type):
1428-
listener, concurrent = lc
1429-
logger.debug(f"Received {msg.type} message from master")
1430-
if not concurrent:
1431-
listener(environment=self.environment, msg=msg)
1432-
else:
1433-
gevent.spawn(listener, self.environment, msg)
1425+
case _ if lc := self.custom_messages.get(msg.type):
1426+
listener, concurrent = lc
1427+
logger.debug(f"Received {msg.type} message from master")
1428+
if not concurrent:
1429+
listener(environment=self.environment, msg=msg)
14341430
else:
1435-
logger.warning(f"Unknown message type received: {msg.type}")
1431+
gevent.spawn(listener, self.environment, msg)
1432+
case _:
1433+
logger.warning(f"Unknown message type received: {msg.type}")
14361434

14371435
def stats_reporter(self) -> NoReturn:
14381436
while True:

locust/test/test_runners.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3366,7 +3366,11 @@ def my_task(self):
33663366
sleep(0.1)
33673367
server.mocked_send(Message("stats", BAD_MESSAGE, "zeh_fake_client1"))
33683368
messages = server.get_messages()
3369-
self.assertEqual(5, len(messages))
3369+
self.assertEqual(messages[0].type, "ack")
3370+
self.assertEqual(messages[1].type, "spawn")
3371+
self.assertEqual(messages[2].type, "spawning_complete")
3372+
self.assertEqual(messages[3].type, "reconnect")
3373+
self.assertEqual(messages[4].type, "ack")
33703374

33713375
# Expected message order in outbox: ack, spawn, reconnect, ack
33723376
self.assertEqual(

locust/user/markov_taskset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def dfs(task_name):
206206
dfs(dest)
207207

208208
dfs(tasks[0].__name__)
209-
unreachable = set([task.__name__ for task in tasks]) - visited
209+
unreachable = {task.__name__ for task in tasks} - visited
210210

211211
if len(unreachable) > 0:
212212
logging.warning(f"The following markov tasks are unreachable in class {classname}: {unreachable}")

locust/web.py

Lines changed: 45 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -257,51 +257,52 @@ def swarm() -> Response:
257257
user_count = None
258258
spawn_rate = None
259259
for key, value in request.form.items():
260-
if key == "user_count": # if we just renamed this field to "users" we wouldn't need this
261-
user_count = int(value)
262-
parsed_options_dict["users"] = user_count
263-
elif key == "spawn_rate":
264-
spawn_rate = float(value)
265-
parsed_options_dict[key] = spawn_rate
266-
elif key == "host":
267-
# Replace < > to guard against XSS
268-
environment.host = str(request.form["host"]).replace("<", "").replace(">", "")
269-
parsed_options_dict[key] = environment.host
270-
elif key == "user_classes":
271-
# Set environment.parsed_options.user_classes to the selected user_classes
272-
parsed_options_dict[key] = request.form.getlist("user_classes")
273-
elif key == "run_time":
274-
if not value:
275-
continue
276-
try:
277-
run_time = parse_timespan(value)
278-
parsed_options_dict[key] = run_time
279-
except ValueError:
280-
err_msg = "Valid run_time formats are : 20, 20s, 3m, 2h, 1h20m, 3h30m10s, etc."
281-
logger.error(err_msg)
282-
return jsonify({"success": False, "message": err_msg, "host": environment.host})
283-
elif key == "profile":
284-
environment.profile = str(request.form["profile"]) or None
285-
parsed_options_dict[key] = environment.profile
286-
elif key in parsed_options_dict:
287-
# update the value in environment.parsed_options, but dont change the type.
288-
parsed_options_value = parsed_options_dict[key]
289-
290-
if isinstance(parsed_options_value, bool):
291-
parsed_options_dict[key] = value == "true"
292-
elif parsed_options_value is None:
293-
parsed_options_dict[key] = value
294-
elif isinstance(parsed_options_value, list):
295-
if "," in value:
296-
value_as_list = value.split(",")
260+
match key:
261+
case "user_count": # if we just renamed this field to "users" we wouldn't need this
262+
user_count = int(value)
263+
parsed_options_dict["users"] = user_count
264+
case "spawn_rate":
265+
spawn_rate = float(value)
266+
parsed_options_dict[key] = spawn_rate
267+
case "host":
268+
# Replace < > to guard against XSS
269+
environment.host = str(request.form["host"]).replace("<", "").replace(">", "")
270+
parsed_options_dict[key] = environment.host
271+
case "user_classes":
272+
# Set environment.parsed_options.user_classes to the selected user_classes
273+
parsed_options_dict[key] = request.form.getlist("user_classes")
274+
case "run_time":
275+
if not value:
276+
continue
277+
try:
278+
run_time = parse_timespan(value)
279+
parsed_options_dict[key] = run_time
280+
except ValueError:
281+
err_msg = "Valid run_time formats are : 20, 20s, 3m, 2h, 1h20m, 3h30m10s, etc."
282+
logger.error(err_msg)
283+
return jsonify({"success": False, "message": err_msg, "host": environment.host})
284+
case "profile":
285+
environment.profile = str(request.form["profile"]) or None
286+
parsed_options_dict[key] = environment.profile
287+
case _ if key in parsed_options_dict:
288+
# update the value in environment.parsed_options, but dont change the type.
289+
parsed_options_value = parsed_options_dict[key]
290+
291+
if isinstance(parsed_options_value, bool):
292+
parsed_options_dict[key] = value == "true"
293+
elif parsed_options_value is None:
294+
parsed_options_dict[key] = value
295+
elif isinstance(parsed_options_value, list):
296+
if "," in value:
297+
value_as_list = value.split(",")
298+
else:
299+
value_as_list = request.form.getlist(key)
300+
if all(isinstance(x, int) for x in parsed_options_value):
301+
parsed_options_dict[key] = list(map(int, value_as_list))
302+
else:
303+
parsed_options_dict[key] = value_as_list
297304
else:
298-
value_as_list = request.form.getlist(key)
299-
if all(isinstance(x, int) for x in parsed_options_value):
300-
parsed_options_dict[key] = list(map(int, value_as_list))
301-
else:
302-
parsed_options_dict[key] = value_as_list
303-
else:
304-
parsed_options_dict[key] = type(parsed_options_value)(value)
305+
parsed_options_dict[key] = type(parsed_options_value)(value)
305306

306307
if environment.shape_class and environment.runner is not None:
307308
environment.runner.start_shape()

0 commit comments

Comments
 (0)