2 changes: 1 addition & 1 deletion .github/workflows/regtest.yml
@@ -44,5 +44,5 @@ jobs:
         run: pip install .

       - name: Run tests via test.py
-        run: ./pyro/test.py
+        run: ./pyro/test.py --nproc 0

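In CI this now runs the whole suite with one worker per available core (--nproc 0 means "use all cores"). The same flag added to pyro/test.py below works for local runs too; for example (the report file name is only an illustration):

    ./pyro/test.py --nproc 0
    ./pyro/test.py --nproc 4 --outfile test-report.out

The second form caps the run at four worker processes and also writes the summary report to a file via the new --outfile option.
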
105 changes: 81 additions & 24 deletions pyro/test.py
@@ -2,9 +2,13 @@


 import argparse
+import contextlib
 import datetime
+import io
 import os
 import sys
+from multiprocessing import Pool
+from pathlib import Path

 import pyro.pyro_sim as pyro
 from pyro.multigrid.examples import (mg_test_general_inhomogeneous,
@@ -23,9 +27,64 @@ def __str__(self):
         return f"{self.solver}-{self.problem}"


+@contextlib.contextmanager
+def avoid_interleaved_output(nproc):
+    """Collect all the printed output and print it all at once to avoid interleaving."""
+    if nproc == 1:
+        # not running in parallel, so we don't have to worry about interleaving
+        yield
+    else:
+        output_buffer = io.StringIO()
+        try:
+            with contextlib.redirect_stdout(output_buffer), \
+                 contextlib.redirect_stderr(output_buffer):
+                yield
+        finally:
+            # a single print call probably won't get interleaved
+            print(output_buffer.getvalue(), end="", flush=True)
+
+
+def run_test(t, reset_fails, store_all_benchmarks, rtol, nproc):
+    orig_cwd = Path.cwd()
+    # run each test in its own directory, since some of the output file names
+    # overlap between tests, and h5py needs exclusive access when writing
+    test_dir = orig_cwd / f"test_outputs/{t}"
+    test_dir.mkdir(parents=True, exist_ok=True)
+    try:
+        os.chdir(test_dir)
+        with avoid_interleaved_output(nproc):
+            p = pyro.PyroBenchmark(t.solver, comp_bench=True,
+                                   reset_bench_on_fail=reset_fails,
+                                   make_bench=store_all_benchmarks)
+            p.initialize_problem(t.problem, t.inputs, t.options)
+            start_n = p.sim.n
+            err = p.run_sim(rtol)
+    finally:
+        os.chdir(orig_cwd)
+    if err == 0:
+        # the test passed; clean up the output files for developer use
+        basename = p.rp.get_param("io.basename")
+        (test_dir / f"{basename}{start_n:04d}.h5").unlink()
+        (test_dir / f"{basename}{p.sim.n:04d}.h5").unlink()
+        (test_dir / "inputs.auto").unlink()
+        test_dir.rmdir()
+        # try removing the top-level output directory
+        try:
+            test_dir.parent.rmdir()
+        except OSError:
+            pass
+
+    return str(t), err
+
+
+def run_test_star(args):
+    """multiprocessing doesn't like lambdas, so this needs to be a full function"""
+    return run_test(*args)
+
+
 def do_tests(out_file,
              reset_fails=False, store_all_benchmarks=False,
-             single=None, solver=None, rtol=1e-12):
+             single=None, solver=None, rtol=1e-12, nproc=1):

     opts = {"driver.verbose": 0, "vis.dovis": 0, "io.do_io": 0}

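Two standard-library behaviours drive the helpers added above: contextlib.redirect_stdout/redirect_stderr temporarily swap the process-wide streams for any file-like object, and multiprocessing workers can only be handed picklable, module-level callables, which is why run_test_star exists instead of a lambda. A minimal, self-contained sketch of the same pattern (the function names here are illustrative, not part of the PR):

    import contextlib
    import io
    from multiprocessing import Pool


    def noisy_task(label):
        # several small prints that could interleave with other workers' output
        print(f"[{label}] step 1")
        print(f"[{label}] step 2")


    def run_buffered(label):
        # mirror avoid_interleaved_output: buffer everything, emit it in one call
        buf = io.StringIO()
        with contextlib.redirect_stdout(buf), contextlib.redirect_stderr(buf):
            noisy_task(label)
        print(buf.getvalue(), end="", flush=True)


    if __name__ == "__main__":
        with Pool(processes=2) as pool:
            # run_buffered is defined at module level, so it pickles cleanly
            pool.map(run_buffered, ["a", "b"])
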
@@ -59,13 +118,16 @@ def do_tests(out_file,
     else:
         tests_to_run = tests

-    for t in tests_to_run:
-        p = pyro.PyroBenchmark(t.solver, comp_bench=True,
-                               reset_bench_on_fail=reset_fails, make_bench=store_all_benchmarks)
-        p.initialize_problem(t.problem, t.inputs, t.options)
-        err = p.run_sim(rtol)
-
-        results[str(t)] = err
+    if nproc == 0:
+        nproc = os.cpu_count()
+    # don't create more processes than needed
+    nproc = min(nproc, len(tests_to_run))
+    with Pool(processes=nproc) as pool:
+        tasks = ((t, reset_fails, store_all_benchmarks, rtol, nproc) for t in tests_to_run)
+        imap_it = pool.imap_unordered(run_test_star, tasks)
+        # collect run results
+        for name, err in imap_it:
+            results[name] = err

     # standalone tests
     if single is None:
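
Pool.imap_unordered yields results in completion order rather than submission order, which is why run_test returns (str(t), err) and the loop above keys the results dict by test name. A tiny sketch of that pattern, independent of the PR (square is a stand-in for run_test_star):

    from multiprocessing import Pool


    def square(n):
        # stand-in for run_test_star: return a (key, result) pair
        return n, n * n


    if __name__ == "__main__":
        results = {}
        with Pool(processes=2) as pool:
            # results may arrive in any order, so store them by key
            for key, value in pool.imap_unordered(square, range(5)):
                results[key] = value
        print(results)
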
@@ -120,9 +182,9 @@ def do_tests(out_file,

     p = argparse.ArgumentParser()

-    p.add_argument("-o",
-                   help="name of file to output the report to (otherwise output to the screen",
-                   type=str, nargs=1)
+    p.add_argument("--outfile", "-o",
+                   help="name of file to output the report to (in addition to the screen)",
+                   type=str, default=None)

     p.add_argument("--single",
                    help="name of a single test (solver-problem) to run",
@@ -142,23 +204,18 @@

     p.add_argument("--rtol",
                    help="relative tolerance to use when comparing data to benchmarks",
-                   type=float, nargs=1)
+                   type=float, default=1.e-12)

-    args = p.parse_args()
-
-    try:
-        outfile = args.o[0]
-    except TypeError:
-        outfile = None
+    p.add_argument("--nproc", "-n",
+                   help="maximum number of parallel processes to run, or 0 to use all cores",
+                   type=int, default=1)

-    try:
-        rtol = args.rtol[0]
-    except TypeError:
-        rtol = 1.e-12
+    args = p.parse_args()

-    failed = do_tests(outfile,
+    failed = do_tests(args.outfile,
                       reset_fails=args.reset_failures,
                       store_all_benchmarks=args.store_all_benchmarks,
-                      single=args.single, solver=args.solver, rtol=rtol)
+                      single=args.single, solver=args.solver, rtol=args.rtol,
+                      nproc=args.nproc)

     sys.exit(failed)