Skip to content

Commit dbeace0

Browse files
committed
tc
1 parent af35d79 commit dbeace0

File tree

1 file changed

+88
-25
lines changed

1 file changed

+88
-25
lines changed

tools/torchci/td/get_reverts_caused_by_td.py

Lines changed: 88 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@
99

1010
import argparse
1111
import re
12+
from collections import defaultdict
1213
from concurrent.futures import ThreadPoolExecutor
13-
from dataclasses import dataclass
14+
from dataclasses import dataclass, field
1415
from functools import lru_cache
15-
from typing import Optional
16+
from typing import Any, Optional
1617

1718
import requests
1819
from torchci.clickhouse import query_clickhouse
@@ -23,6 +24,7 @@
2324
class JobFailure:
2425
torchci_classification_line: str
2526
job_name: str
27+
run_id: int
2628
failed_test: Optional[str] = None
2729

2830

@@ -37,7 +39,7 @@ class CommitInfo:
3739
timestamp_of_merge: int = 0
3840
pr_num: int = 0
3941
last_pr_sha: Optional[str] = None
40-
run_id: Optional[int] = None
42+
run_ids: list[int] = field(default_factory=list)
4143

4244

4345
class IndentPrinter:
@@ -89,14 +91,15 @@ def __exit__(self, exc_type, exc_val, exc_tb):
8991
TORCHCI_CLASSIFICATION_QUERY = """
9092
select
9193
name as job_name,
94+
run_id as run_id,
9295
torchci_classification.line as line,
9396
head_sha
9497
from
9598
default.workflow_job
9699
where
97100
head_sha in {shas: Array(String)}
98101
and conclusion = 'failure'
99-
and workflow_name = 'pull'
102+
and workflow_name in ('pull', 'trunk', 'periodic', 'slow')
100103
"""
101104

102105
WORKFLOW_ID_QUERY = """
@@ -108,7 +111,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
108111
default .workflow_run
109112
where
110113
head_sha in {shas: Array(String) }
111-
and name = 'pull'
114+
and name in ('pull', 'trunk', 'periodic', 'slow')
112115
"""
113116

114117

@@ -164,12 +167,29 @@ def get_full_commit_message(sha: str) -> str:
164167

165168

166169
@lru_cache
167-
def get_td_exclusions(run_id: int) -> dict:
168-
"""Fetches the TD exclusions for a given run_id"""
170+
def get_td_exclusions(run_ids: tuple[int]) -> dict:
171+
"""Fetches the TD exclusions for some run ids."""
172+
exclusions = defaultdict(lambda: defaultdict(list))
173+
for run_id in run_ids:
174+
for i in range(3):
175+
response = requests.get(
176+
f"https://ossci-raw-job-status.s3.amazonaws.com/additional_info/td_exclusions/{run_id}/{i + 1}"
177+
)
178+
if response.status_code == 200:
179+
for build_env, test_configs in response.json().items():
180+
for test_config, tests in test_configs.items():
181+
exclusions[build_env][test_config].extend(tests)
182+
return dict(exclusions)
183+
184+
185+
@lru_cache
186+
def get_failures_additional_test_info(
187+
run_id: int,
188+
) -> dict[str, Any]:
189+
"""Fetches additional test info for failures in the given run_id."""
169190
for i in range(3):
170-
response = requests.get(
171-
f"https://ossci-raw-job-status.s3.amazonaws.com/additional_info/td_exclusions/{run_id}/{i + 1}"
172-
)
191+
url = f"https://ossci-raw-job-status.s3.amazonaws.com/additional_info/reruns/{run_id}/{i + 1}"
192+
response = requests.get(url)
173193
if response.status_code == 200:
174194
return response.json()
175195
return {}
@@ -272,7 +292,11 @@ def process_sha(i: int) -> Optional[CommitInfo]:
272292
alt_last_pr_sha = (row["head_sha"], timestamp)
273293
if alt_last_pr_sha[0] != commit.last_pr_sha and commit.last_pr_sha is not None:
274294
p.print(
275-
f"for commit {commit.id} with pr {commit.pr_num}, found last pr sha != alt, {commit.last_pr_sha} != {alt_last_pr_sha[0]}"
295+
f"commit={commit.id} "
296+
f"pr={commit.pr_num} "
297+
f"merge={commit.merge_commit_sha} "
298+
f"timestamp_of_merge={commit.timestamp_of_merge} "
299+
f"found last pr sha != alt, {commit.last_pr_sha} != {alt_last_pr_sha[0]}"
276300
)
277301
bad += 1
278302
if commit.last_pr_sha is None:
@@ -299,14 +323,13 @@ def process_sha(i: int) -> Optional[CommitInfo]:
299323
commit.last_pr_sha == head_sha
300324
and created_at < commit.timestamp_of_merge
301325
):
302-
commit.run_id = int(run_id)
303-
326+
commit.run_ids.append(int(run_id))
304327
return commits_reverted
305328

306329

307330
def get_job_failures(shas: list[str]) -> dict[str, list[JobFailure]]:
308331
"""Fetches job failures for the given SHAs."""
309-
# Need to batch b/c too many shas
332+
# Need to batch in case too many SHAs
310333
batch_size = 500
311334
failures_dict: dict[str, list[JobFailure]] = {}
312335
with ThreadPoolExecutor(max_workers=8) as executor:
@@ -325,20 +348,60 @@ def get_job_failures(shas: list[str]) -> dict[str, list[JobFailure]]:
325348
for row in job_failures:
326349
head_sha = row["head_sha"]
327350
job_name = row["job_name"]
351+
run_id = row["run_id"]
328352
line = row["line"]
329353
if head_sha not in failures_dict:
330354
failures_dict[head_sha] = []
331355
failures_dict[head_sha].append(
332356
JobFailure(
333357
torchci_classification_line=line,
334358
job_name=job_name,
359+
run_id=int(run_id),
335360
failed_test=get_test_file(line),
336361
)
337362
)
363+
del futures
364+
365+
futures2 = []
366+
with ThreadPoolExecutor(max_workers=8) as executor:
367+
for sha, failures in failures_dict.items():
368+
run_ids = set(f.run_id for f in failures if f.run_id is not None)
369+
for run_id in run_ids:
370+
futures2.append((sha, executor.submit(get_failures_for_run_id, run_id)))
371+
for sha, future in futures2:
372+
additional_failures = future.result()
373+
failures_dict[sha].extend(additional_failures)
338374
return failures_dict
339375

340376

341-
def check_failure_in_td_exclusion(f: JobFailure, run_id: int) -> bool:
377+
@lru_cache
378+
def get_failures_for_run_id(run_id: int) -> list[JobFailure]:
379+
"""Fetches the failures for the given run_id."""
380+
failures = get_failures_additional_test_info(run_id)
381+
job_failures = []
382+
for build, d in failures.items():
383+
for test_config, dd in d.items():
384+
for test_file, ddd in dd.items():
385+
for test_class, dddd in ddd.items():
386+
for test_name, info in dddd.items():
387+
failed = True
388+
for i in info:
389+
if "failure" not in i:
390+
failed = False
391+
if failed:
392+
job_failures.append(
393+
JobFailure(
394+
torchci_classification_line=f"{test_file}::{test_class}::{test_name}",
395+
job_name=f"{build} / test ({test_config}, 1, 1, runner)",
396+
run_id=run_id,
397+
failed_test=f"{test_file}",
398+
)
399+
)
400+
401+
return job_failures
402+
403+
404+
def check_failure_in_td_exclusion(f: JobFailure, run_ids: list[int]) -> bool:
342405
"""True if the commit is bad (excluded in TD)"""
343406
x = re.search(JOB_NAME_REGEX, f.job_name)
344407
if x is None:
@@ -347,26 +410,26 @@ def check_failure_in_td_exclusion(f: JobFailure, run_id: int) -> bool:
347410
)
348411
return False
349412

350-
td_exclusions = get_td_exclusions(run_id)
413+
td_exclusions = get_td_exclusions(tuple(run_ids))
351414
build_env = x.group(1)
352415
test_config = x.group(2)
353416
p.print(
354417
f"Build environment: {build_env}, Test config: {test_config}, len(td_exclusions): {len(td_exclusions)}"
355418
)
356419
if len(td_exclusions) == 0:
357-
p.print(f"No TD exclusions found for run {run_id}")
420+
p.print(f"No TD exclusions found for run {run_ids}")
358421
return False
359422
if build_env not in td_exclusions:
360423
p.print(
361-
f"Build environment {build_env} not found in TD exclusions for run {run_id}"
424+
f"Build environment {build_env} not found in TD exclusions for run {run_ids}"
362425
)
363426
elif test_config not in td_exclusions[build_env]:
364-
p.print(f"Test {test_config} not found in TD exclusions for run {run_id}")
427+
p.print(f"Test {test_config} not found in TD exclusions for run {run_ids}")
365428
elif f.failed_test in td_exclusions[build_env][test_config]:
366-
p.print(f"Test {f.failed_test} is excluded in TD for run {run_id}")
429+
p.print(f"Test {f.failed_test} is excluded in TD for run {run_ids}")
367430
return True
368431
else:
369-
p.print(f"Test {f.failed_test} is not excluded in TD for run {run_id}")
432+
p.print(f"Test {f.failed_test} is not excluded in TD for run {run_ids}")
370433
return False
371434

372435

@@ -410,8 +473,8 @@ def main() -> None:
410473
p.print(f"Merge commit: {s.merge_commit_sha}")
411474
p.print(f"Merge commit prev: {s.merge_commit_sha_prev}")
412475
p.print(f"Last PR sha: {s.last_pr_sha}")
413-
p.print(f"Run ID: {s.run_id}")
414-
if s.run_id is None:
476+
p.print(f"Run ID: {s.run_ids}")
477+
if len(s.run_ids) == 0:
415478
p.print(f"Run ID is None for commit {s.last_pr_sha}, skipping")
416479
unable_to_check += 1
417480
continue
@@ -443,11 +506,11 @@ def main() -> None:
443506
)
444507
continue
445508

446-
any_bad |= check_failure_in_td_exclusion(f, s.run_id)
509+
any_bad |= check_failure_in_td_exclusion(f, s.run_ids)
447510
if any_bad:
448511
caused_by_bad_td.append(s)
449512
p.print(
450-
f"Commit {s.last_pr_sha} with run_id {s.run_id} is caused by bad TD"
513+
f"Commit {s.last_pr_sha} with run_id {s.run_ids} is caused by bad TD"
451514
)
452515
p.print(
453516
f"CAUSED BY BAD TD: {len(caused_by_bad_td)} / {i + 1} = {len(caused_by_bad_td) / (i + 1):.2%}"

0 commit comments

Comments
 (0)