Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion TraceLens/EventReplay/batched_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import argparse
import warnings
import torch
from utils import TensorCfg, build_tensor, benchmark_func, summarize_tensor, dict_profile2torchdtype
from utils import TensorCfg, build_tensor, benchmark_func

def _get_args_kwargs_from_ir(event_replay_IR: dict[str, any], device: str = 'cuda') -> tuple[list[any], dict[str, any]]:
# (Copy the implementation of _get_args_kwargs_from_ir from Step 1 here)
Expand Down
21 changes: 10 additions & 11 deletions TraceLens/EventReplay/event_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,10 @@
import re
import warnings
import time
try:
import torch
except ImportError:
raise ImportError("PyTorch is required for EventReplayer")

from .utils import (TensorCfg, build_tensor, benchmark_func,
list_profile_tensor_types, dict_profile2torchdtype)
from .utils import (_get_torch_or_raise, TensorCfg, build_tensor,
list_profile_tensor_types)

class EventReplayer:
def __init__(self, event: Dict[str, Any], device: str = 'cuda', lazy: bool = False, verbose: bool = False):
"""
Expand Down Expand Up @@ -65,7 +62,8 @@ def _setup(self):
def replay(self):
"""
Replay the event using the matched schema and event replay IR.
"""
"""
torch = _get_torch_or_raise()
# Get the function from the schema
func, _ = torch._C._jit_get_operation(self.event['name'])

Expand All @@ -80,7 +78,8 @@ def replay(self):


@staticmethod
def _search_schema(event: Dict[str, Any], verbose: bool = False) -> Optional[torch._C.FunctionSchema]:
def _search_schema(event: Dict[str, Any], verbose: bool = False) -> Optional['torch._C.FunctionSchema']:
torch = _get_torch_or_raise()
all_schemas = torch._C._jit_get_all_schemas()
op_schemas = [s for s in all_schemas if s.name == event['name']]
# print each schema in separate line
Expand All @@ -105,7 +104,7 @@ def _search_schema(event: Dict[str, Any], verbose: bool = False) -> Optional[tor
raise ValueError(f"Cannot find matching schema for {event['name']}. Please check the event data and schema.")

@staticmethod
def _is_schema_match(event: Dict[str, Any], schema: torch._C.FunctionSchema, verbose: bool = False) -> bool:
def _is_schema_match(event: Dict[str, Any], schema: 'torch._C.FunctionSchema', verbose: bool = False) -> bool:
"""
Check if the event matches the schema.

Expand Down Expand Up @@ -196,7 +195,7 @@ def _is_schema_match(event: Dict[str, Any], schema: torch._C.FunctionSchema, ver


@staticmethod
def _get_event_replay_IR(event: Dict[str, Any], schema: torch._C.FunctionSchema, verbose: bool = False) -> Dict[str, Any]:
def _get_event_replay_IR(event: Dict[str, Any], schema: 'torch._C.FunctionSchema', verbose: bool = False) -> Dict[str, Any]:
"""
Get the event replay IR from the event and schema.

Expand Down Expand Up @@ -276,7 +275,7 @@ def _get_event_replay_IR(event: Dict[str, Any], schema: torch._C.FunctionSchema,


@staticmethod
def _get_args_kwargs(event_replay_IR: Dict[str, Any], device: str = 'cuda') -> tuple[List[torch.Tensor], Dict[str, Any]]:
def _get_args_kwargs(event_replay_IR: Dict[str, Any], device: str = 'cuda') -> tuple[List['torch.Tensor'], Dict[str, Any]]:
"""
Get the arguments and keyword arguments from the event replay IR.

Expand Down
35 changes: 26 additions & 9 deletions TraceLens/EventReplay/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
from typing import List, Dict, Tuple
from typing import List, Dict, Tuple, Any
import time
import torch

# Cache for the lazily imported torch module; stays None until the first
# successful call to _get_torch_or_raise().
_torch_module = None


def _get_torch_or_raise() -> Any:
    """Lazily import and return the ``torch`` module.

    The import is deferred to the first call so that importing this module
    does not itself require PyTorch. The imported module is cached in the
    module-level ``_torch_module`` and reused on later calls.

    Returns:
        The imported ``torch`` module (annotated as ``Any`` because torch is
        not imported at module level).

    Raises:
        ImportError: If ``import torch`` fails. The original import error is
            attached as ``__cause__`` so the underlying reason is preserved.
    """
    global _torch_module
    if _torch_module is None:
        # Keep the try body to the import alone so only import failures are
        # translated into the friendlier message below.
        try:
            import torch
        except ImportError as err:
            raise ImportError(
                "PyTorch is required for EventReplayer functionality that is being used. "
                "Please install PyTorch."
            ) from err
        _torch_module = torch
    return _torch_module

list_profile_tensor_types = ['float', 'c10::Half', 'c10::BFloat16']
dict_profile2torchdtype = {
'float': torch.float32,
'c10::Half': torch.float16,
'c10::BFloat16': torch.bfloat16,
}

from dataclasses import dataclass
@dataclass
Expand All @@ -20,7 +28,15 @@ class TensorCfg:
dtype: str
strides: List[int]

def build_tensor(cfg: TensorCfg, device: str='cuda') -> torch.Tensor:

def build_tensor(cfg: TensorCfg, device: str='cuda') -> 'torch.Tensor':

torch = _get_torch_or_raise()
dict_profile2torchdtype = {
'float': torch.float32,
'c10::Half': torch.float16,
'c10::BFloat16': torch.bfloat16,
}
dtype = dict_profile2torchdtype[cfg.dtype]
size = cfg.shape
stride = cfg.strides
Expand All @@ -29,7 +45,7 @@ def build_tensor(cfg: TensorCfg, device: str='cuda') -> torch.Tensor:
t.normal_() # or whatever init you like
return t

def summarize_tensor(tensor: torch.Tensor) -> str:
def summarize_tensor(tensor: 'torch.Tensor') -> str:
"""
Summarize the tensor information.

Expand All @@ -52,6 +68,7 @@ def benchmark_func(func, device, warmup=50, avg_steps=100):
Returns:
float: Average time taken per iteration in microseconds.
"""
torch = _get_torch_or_raise()
# Warmup phase
for _ in range(warmup):
func()
Expand Down