diff --git a/onnxscript/_framework_apis/torch_2_5.py b/onnxscript/_framework_apis/torch_2_5.py index 162faf4b75..5bbb64af88 100644 --- a/onnxscript/_framework_apis/torch_2_5.py +++ b/onnxscript/_framework_apis/torch_2_5.py @@ -13,6 +13,7 @@ ] import dataclasses +import importlib.util import os import pathlib from typing import Callable @@ -63,7 +64,9 @@ def check_model(model: ir.Model) -> None: del model # Unused yet -def save_model_with_external_data(model: ir.Model, model_path: str | os.PathLike) -> None: +def save_model_with_external_data( + model: ir.Model, model_path: str | os.PathLike, verbose: bool = False +) -> None: """Save the model with external data. The model is unchanged after saving.""" # TODO(#1835): Decide if we want to externalize large attributes as well @@ -78,7 +81,31 @@ def save_model_with_external_data(model: ir.Model, model_path: str | os.PathLike destination_path = pathlib.Path(model_path) data_path = f"{destination_path.name}.data" - ir.save(model, model_path, external_data=data_path) + # Show a progress bar if verbose is True and tqdm is installed + use_tqdm = verbose and importlib.util.find_spec("tqdm") is not None + + if use_tqdm: + import tqdm # pylint: disable=import-outside-toplevel + + with tqdm.tqdm() as pbar: + total_set = False + + def callback( + tensor: ir.TensorProtocol, metadata: ir.external_data.CallbackInfo + ) -> None: + nonlocal total_set + if not total_set: + pbar.total = metadata.total + total_set = True + + pbar.update() + pbar.set_description( + f"Saving {tensor.name} ({tensor.dtype.short_name()}, {tensor.shape}) at offset {metadata.offset}" + ) + + ir.save(model, model_path, external_data=data_path, callback=callback) + else: + ir.save(model, model_path, external_data=data_path) def get_torchlib_ops() -> list[_OnnxFunctionMeta]: