-
Notifications
You must be signed in to change notification settings - Fork 127
Add OneDiffCheckpointLoader #457
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0055a42
35f57ae
05db45a
8a811ad
9e396e2
22ef3e3
0f491f5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -56,10 +56,12 @@ def _align_tensor(torch_module, oneflow_module): | |
| [x for x, _ in oneflow_module.named_parameters()] | ||
| + [x for x, _ in oneflow_module.named_buffers()] | ||
| ) | ||
| for name, tensor in chain.from_iterable([ | ||
| torch_module.named_parameters(), | ||
| torch_module.named_buffers(), | ||
| ]): | ||
| for name, tensor in chain.from_iterable( | ||
| [ | ||
| torch_module.named_parameters(), | ||
| torch_module.named_buffers(), | ||
| ] | ||
| ): | ||
| if name not in oneflow_tensor_list: | ||
| tensor.data = tensor.to(*args, **kwargs) | ||
| else: | ||
|
|
@@ -76,7 +78,6 @@ def _align_tensor(torch_module, oneflow_module): | |
| else: | ||
| _align_tensor(module, self._oneflow_module.get_submodule(name)) | ||
|
|
||
|
|
||
| def __getattr__(self, name): | ||
| if name == "_torch_module": | ||
| return self._modules[name] | ||
|
|
@@ -124,7 +125,11 @@ def __init__(self, torch_modules, oneflow_modules): | |
| for torch_module, oneflow_module in zip( | ||
| self._torch_modules, self._oneflow_modules | ||
| ): | ||
| dual_modules.append(get_mixed_dual_module(torch_module.__class__)(torch_module, oneflow_module)) | ||
| dual_modules.append( | ||
| get_mixed_dual_module(torch_module.__class__)( | ||
| torch_module, oneflow_module | ||
| ) | ||
| ) | ||
| # clear self._modules since `self._torch_modules = torch_modules` will append a module to self._modules | ||
| self._modules.clear() | ||
| self += dual_modules | ||
|
|
@@ -147,6 +152,7 @@ def __setattr__(self, key, value): | |
| setattr(self._oneflow_modules, key, value) | ||
| return object.__setattr__(self, key, value) | ||
|
|
||
|
|
||
| def get_mixed_dual_module(module_cls): | ||
| class MixedDualModule(DualModule, module_cls): | ||
| def __init__(self, torch_module, oneflow_module): | ||
|
|
@@ -155,32 +161,81 @@ def __init__(self, torch_module, oneflow_module): | |
| return MixedDualModule | ||
|
|
||
|
|
||
| def load_graph_from_config(self: "DeployableModule"): | ||
| try: | ||
| if self._graph_config is not None: | ||
| graph_path = self._graph_config[0] | ||
| if not os.path.exists(graph_path): | ||
| logger.warning( | ||
| f"Graph file {graph_path} not exists, will generate graph." | ||
| ) | ||
| return | ||
| graph_device = torch2oflow(self._graph_config[1]) | ||
| self.load_graph(graph_path, graph_device) | ||
| self._graph_config = None | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. _graph_config 这个命名改准确一点吧,太宽泛了 |
||
| except Exception as e: | ||
| logger.error(f"Exception in load_graph_from_config: {e=}") | ||
|
|
||
|
|
||
| def save_graph_to_config(self: "DeployableModule"): | ||
| try: | ||
| if self._graph_config is not None: | ||
| graph_file = self._graph_config[0] | ||
| os.makedirs(os.path.dirname(graph_file), exist_ok=True) | ||
| self.save_graph(self._graph_config[0]) | ||
| logger.info(f"Save graph to {self._graph_config[0]} done!") | ||
| except Exception as e: | ||
| logger.error(f"Exception in save_graph_to_config: {e=}") | ||
| finally: | ||
| self._graph_config = None | ||
|
|
||
|
|
||
| def handle_deployable_exception(func): | ||
| @wraps(func) | ||
| def wrapper(self, *args, **kwargs): | ||
| def _run_func(): | ||
| load_graph_from_config(self) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 这个功能单独写个 decorator? 放到 handle_deployable_exception 感觉不合适
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 已修改: #460 |
||
| result = func(self, *args, **kwargs) | ||
| save_graph_to_config(self) | ||
| return result | ||
|
|
||
| if transform_mgr.debug_mode: | ||
| return func(self, *args, **kwargs) | ||
| return _run_func() | ||
| else: | ||
| try: | ||
| return func(self, *args, **kwargs) | ||
| return _run_func() | ||
| except Exception as e: | ||
| logger.error(f"Exception in {func.__name__}: {e=}") | ||
| logger.warning("Recompile oneflow module ...") | ||
| del self._deployable_module_model.oneflow_module | ||
| self._deployable_module_dpl_graph = None | ||
| return func(self, *args, **kwargs) | ||
| return _run_func() | ||
|
|
||
| return wrapper | ||
|
|
||
|
|
||
| class DeployableModule(torch.nn.Module): | ||
| def __init__(self, torch_module, oneflow_module, use_graph=True, options={}): | ||
| def __init__( | ||
| self, | ||
| torch_module, | ||
| oneflow_module, | ||
| use_graph=True, | ||
| options={}, | ||
| graph_path=None, | ||
| graph_device=None, | ||
| ): | ||
| torch.nn.Module.__init__(self) | ||
| self._deployable_module_model = get_mixed_dual_module(torch_module.__class__)(torch_module, oneflow_module) | ||
| self._deployable_module_model = get_mixed_dual_module(torch_module.__class__)( | ||
| torch_module, oneflow_module | ||
| ) | ||
| self._deployable_module_use_graph = use_graph | ||
| self._deployable_module_options = options | ||
| self._deployable_module_dpl_graph = None | ||
| self._is_raw_deployable_module = True | ||
| if graph_path is not None: | ||
ccssu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| self._graph_config = (graph_path, graph_device) | ||
| else: | ||
| self._graph_config = None | ||
|
|
||
| @classmethod | ||
| def from_existing(cls, existing_module, use_graph=None, options=None): | ||
|
|
@@ -190,6 +245,7 @@ def from_existing(cls, existing_module, use_graph=None, options=None): | |
| instance._deployable_module_dpl_graph = ( | ||
| existing_module._deployable_module_dpl_graph if use_graph else None | ||
| ) | ||
| instance._graph_config = existing_module._graph_config | ||
| return instance | ||
|
|
||
| def get_graph(self): | ||
|
|
@@ -245,7 +301,10 @@ def to(self, *args, **kwargs): | |
|
|
||
| # assert the target device is same as graph device | ||
| target_device = parse_device(args, kwargs) | ||
| if target_device is not None and len(self._deployable_module_dpl_graph._blocks) > 0: | ||
| if ( | ||
| target_device is not None | ||
| and len(self._deployable_module_dpl_graph._blocks) > 0 | ||
| ): | ||
| current_device = next(self._deployable_module_dpl_graph._state()).device | ||
| if not check_device(current_device, target_device): | ||
| raise RuntimeError( | ||
|
|
@@ -357,11 +416,25 @@ def state_dict_hook(module, state_dict, prefix, local_metadata): | |
|
|
||
|
|
||
| # Return a DeployableModule that uses module_cls as its parent class. | ||
| def get_mixed_deployable_module(module_cls): | ||
| def get_mixed_deployable_module(module_cls, graph_path=None, graph_device=None): | ||
| class MixedDeployableModule(DeployableModule, module_cls): | ||
| def __init__(self, torch_module, oneflow_module, use_graph=True, options={}): | ||
| def __init__( | ||
| self, | ||
| torch_module, | ||
| oneflow_module, | ||
| use_graph=True, | ||
| options={}, | ||
| graph_path=None, | ||
| graph_device=None, | ||
| ): | ||
| DeployableModule.__init__( | ||
| self, torch_module, oneflow_module, use_graph, options | ||
| self, | ||
| torch_module, | ||
| oneflow_module, | ||
| use_graph, | ||
| options, | ||
| graph_path, | ||
| graph_device, | ||
| ) | ||
| self._is_raw_deployable_module = False | ||
|
|
||
|
|
@@ -378,7 +451,14 @@ def from_existing(cls, existing_module, use_graph=None, options=None): | |
| return MixedDeployableModule | ||
|
|
||
|
|
||
| def oneflow_compile(torch_module: torch.nn.Module, *, use_graph=True, options={}): | ||
| def oneflow_compile( | ||
| torch_module: torch.nn.Module, | ||
| *, | ||
| use_graph=True, | ||
| options={}, | ||
| graph_path=None, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 这些放到 options 里面吧,oneflow_compile 要谨慎扩展参数
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 好的就是传入 options = { "graph_config": (graph_path, graph_device)} 这种吗 还是 options = { "graph_path":graph_path, "graph_device":graph_device}
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. graph_file, graph_file_device |
||
| graph_device=None, | ||
| ): | ||
| set_default_registry() | ||
|
|
||
| def wrap_module(module): | ||
|
|
@@ -387,7 +467,7 @@ def wrap_module(module): | |
| return module.__class__.from_existing(module, use_graph, options) | ||
| else: | ||
| return get_mixed_deployable_module(module.__class__)( | ||
| module, None, use_graph, options | ||
| module, None, use_graph, options, graph_path, graph_device | ||
| ) | ||
|
|
||
| model = wrap_module(torch_module) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.