@@ -9,10 +9,11 @@
 
 import argparse
 import re
+from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import lru_cache
-from typing import Optional
+from typing import Any, Optional
 
 import requests
 from torchci.clickhouse import query_clickhouse
@@ -23,6 +24,7 @@
 class JobFailure:
     torchci_classification_line: str
     job_name: str
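+    # Workflow run that produced this failure; used later for the TD-exclusion and rerun lookups.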
+    run_id: int
     failed_test: Optional[str] = None
 
 
@@ -37,7 +39,7 @@ class CommitInfo:
     timestamp_of_merge: int = 0
     pr_num: int = 0
     last_pr_sha: Optional[str] = None
-    run_id: Optional[int] = None
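+    # A commit may now have several qualifying workflow runs (pull, trunk, periodic, slow), so keep them all.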
+    run_ids: list[int] = field(default_factory=list)
 
 
 class IndentPrinter:
@@ -89,14 +91,15 @@ def __exit__(self, exc_type, exc_val, exc_tb):
 TORCHCI_CLASSIFICATION_QUERY = """
 select
     name as job_name,
+    run_id as run_id,
     torchci_classification.line as line,
     head_sha
 from
     default.workflow_job
 where
     head_sha in {shas: Array(String)}
     and conclusion = 'failure'
-    and workflow_name = 'pull'
+    and workflow_name in ('pull', 'trunk', 'periodic', 'slow')
 """
 
 WORKFLOW_ID_QUERY = """
@@ -108,7 +111,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
     default.workflow_run
 where
     head_sha in {shas: Array(String)}
-    and name = 'pull'
+    and name in ('pull', 'trunk', 'periodic', 'slow')
 """
 
 
@@ -164,12 +167,29 @@ def get_full_commit_message(sha: str) -> str:
 
 
 @lru_cache
-def get_td_exclusions(run_id: int) -> dict:
-    """Fetches the TD exclusions for a given run_id"""
+def get_td_exclusions(run_ids: tuple[int, ...]) -> dict:
+    """Fetches the TD exclusions for some run ids."""
+    exclusions = defaultdict(lambda: defaultdict(list))
+    for run_id in run_ids:
+        for i in range(3):
+            response = requests.get(
+                f"https://ossci-raw-job-status.s3.amazonaws.com/additional_info/td_exclusions/{run_id}/{i + 1}"
+            )
+            if response.status_code == 200:
+                for build_env, test_configs in response.json().items():
+                    for test_config, tests in test_configs.items():
+                        exclusions[build_env][test_config].extend(tests)
+    return dict(exclusions)
+
+
+@lru_cache
+def get_failures_additional_test_info(
+    run_id: int,
+) -> dict[str, Any]:
+    """Fetches additional test info for failures in the given run_id."""
     for i in range(3):
-        response = requests.get(
-            f"https://ossci-raw-job-status.s3.amazonaws.com/additional_info/td_exclusions/{run_id}/{i + 1}"
-        )
+        url = f"https://ossci-raw-job-status.s3.amazonaws.com/additional_info/reruns/{run_id}/{i + 1}"
+        response = requests.get(url)
         if response.status_code == 200:
             return response.json()
     return {}
@@ -272,7 +292,11 @@ def process_sha(i: int) -> Optional[CommitInfo]:
             alt_last_pr_sha = (row["head_sha"], timestamp)
     if alt_last_pr_sha[0] != commit.last_pr_sha and commit.last_pr_sha is not None:
         p.print(
-            f"for commit {commit.id} with pr {commit.pr_num}, found last pr sha != alt, {commit.last_pr_sha} != {alt_last_pr_sha[0]}"
+            f"commit={commit.id} "
+            f"pr={commit.pr_num} "
+            f"merge={commit.merge_commit_sha} "
+            f"timestamp_of_merge={commit.timestamp_of_merge} "
+            f"found last pr sha != alt, {commit.last_pr_sha} != {alt_last_pr_sha[0]}"
         )
         bad += 1
     if commit.last_pr_sha is None:
@@ -299,14 +323,13 @@ def process_sha(i: int) -> Optional[CommitInfo]:
                 commit.last_pr_sha == head_sha
                 and created_at < commit.timestamp_of_merge
             ):
-                commit.run_id = int(run_id)
-
+                commit.run_ids.append(int(run_id))
     return commits_reverted
 
 
 def get_job_failures(shas: list[str]) -> dict[str, list[JobFailure]]:
     """Fetches job failures for the given SHAs."""
-    # Need to batch b/c too many shas
+    # Need to batch in case there are too many SHAs
     batch_size = 500
     failures_dict: dict[str, list[JobFailure]] = {}
     with ThreadPoolExecutor(max_workers=8) as executor:
@@ -325,20 +348,60 @@ def get_job_failures(shas: list[str]) -> dict[str, list[JobFailure]]:
         for row in job_failures:
             head_sha = row["head_sha"]
             job_name = row["job_name"]
+            run_id = row["run_id"]
             line = row["line"]
             if head_sha not in failures_dict:
                 failures_dict[head_sha] = []
             failures_dict[head_sha].append(
                 JobFailure(
                     torchci_classification_line=line,
                     job_name=job_name,
+                    run_id=int(run_id),
                     failed_test=get_test_file(line),
                 )
            )
+    del futures
+
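+    # Second pass: for every run id seen above, fetch the rerun-based failures
+    # and merge them into the per-SHA failure lists.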
+    futures2 = []
+    with ThreadPoolExecutor(max_workers=8) as executor:
+        for sha, failures in failures_dict.items():
+            run_ids = set(f.run_id for f in failures if f.run_id is not None)
+            for run_id in run_ids:
+                futures2.append((sha, executor.submit(get_failures_for_run_id, run_id)))
+    for sha, future in futures2:
+        additional_failures = future.result()
+        failures_dict[sha].extend(additional_failures)
     return failures_dict
 
 
-def check_failure_in_td_exclusion(f: JobFailure, run_id: int) -> bool:
+@lru_cache
+def get_failures_for_run_id(run_id: int) -> list[JobFailure]:
+    """Fetches the failures for the given run_id."""
+    failures = get_failures_additional_test_info(run_id)
+    job_failures = []
+    for build, d in failures.items():
+        for test_config, dd in d.items():
+            for test_file, ddd in dd.items():
+                for test_class, dddd in ddd.items():
+                    for test_name, info in dddd.items():
390
+ failed = False
391
+ if failed :
392
+ job_failures .append (
393
+ JobFailure (
394
+ torchci_classification_line = f"{ test_file } ::{ test_class } ::{ test_name } " ,
395
+ job_name = f"{ build } / test ({ test_config } , 1, 1, runner)" ,
396
+ run_id = run_id ,
397
+ failed_test = f"{ test_file } " ,
398
+ )
399
+ )
400
+
401
+ return job_failures
402
+
403
+
404
+ def check_failure_in_td_exclusion (f : JobFailure , run_ids : list [int ]) -> bool :
342
405
"""True if the commit is bad (excluded in TD)"""
343
406
x = re .search (JOB_NAME_REGEX , f .job_name )
344
407
if x is None :
@@ -347,26 +410,26 @@ def check_failure_in_td_exclusion(f: JobFailure, run_id: int) -> bool:
347
410
)
348
411
return False
349
412
350
- td_exclusions = get_td_exclusions (run_id )
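+    # tuple() so the lru_cache-memoized get_td_exclusions can hash its argument.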
+    td_exclusions = get_td_exclusions(tuple(run_ids))
     build_env = x.group(1)
     test_config = x.group(2)
     p.print(
         f"Build environment: {build_env}, Test config: {test_config}, len(td_exclusions): {len(td_exclusions)}"
     )
     if len(td_exclusions) == 0:
-        p.print(f"No TD exclusions found for run {run_id}")
+        p.print(f"No TD exclusions found for runs {run_ids}")
         return False
     if build_env not in td_exclusions:
         p.print(
-            f"Build environment {build_env} not found in TD exclusions for run {run_id}"
+            f"Build environment {build_env} not found in TD exclusions for runs {run_ids}"
         )
     elif test_config not in td_exclusions[build_env]:
-        p.print(f"Test {test_config} not found in TD exclusions for run {run_id}")
+        p.print(f"Test {test_config} not found in TD exclusions for runs {run_ids}")
     elif f.failed_test in td_exclusions[build_env][test_config]:
-        p.print(f"Test {f.failed_test} is excluded in TD for run {run_id}")
+        p.print(f"Test {f.failed_test} is excluded in TD for runs {run_ids}")
         return True
     else:
-        p.print(f"Test {f.failed_test} is not excluded in TD for run {run_id}")
+        p.print(f"Test {f.failed_test} is not excluded in TD for runs {run_ids}")
     return False
 
 
@@ -410,8 +473,8 @@ def main() -> None:
         p.print(f"Merge commit: {s.merge_commit_sha}")
         p.print(f"Merge commit prev: {s.merge_commit_sha_prev}")
         p.print(f"Last PR sha: {s.last_pr_sha}")
-        p.print(f"Run ID: {s.run_id}")
-        if s.run_id is None:
+        p.print(f"Run IDs: {s.run_ids}")
+        if len(s.run_ids) == 0:
             p.print(f"Run ID is None for commit {s.last_pr_sha}, skipping")
             unable_to_check += 1
             continue
@@ -443,11 +506,11 @@ def main() -> None:
                )
                continue
 
-            any_bad |= check_failure_in_td_exclusion(f, s.run_id)
+            any_bad |= check_failure_in_td_exclusion(f, s.run_ids)
         if any_bad:
             caused_by_bad_td.append(s)
             p.print(
-                f"Commit {s.last_pr_sha} with run_id {s.run_id} is caused by bad TD"
+                f"Commit {s.last_pr_sha} with run_ids {s.run_ids} is caused by bad TD"
             )
         p.print(
             f"CAUSED BY BAD TD: {len(caused_by_bad_td)} / {i + 1} = {len(caused_by_bad_td) / (i + 1):.2%}"