Status: Open
Labels: good first issue, topic: new feature
Description
Summary
This pattern is very common and can be implemented generically.
The only case where it would need to change is when a subclass has to spoof its actual size, which is uncommon; NJT (nested jagged tensor) is the only example I can think of.
from typing import Callable

def _apply_fn_to_data(self, fn: Callable):
    """Applies fn to every tensor component stored on this class."""
    # __tensor_flatten__ returns the attribute names of the inner tensors
    # plus the non-tensor metadata (ctx) needed to rebuild the subclass
    tensor_names, ctx = self.__tensor_flatten__()
    # Apply the function to each inner tensor component
    new_tensors = {name: fn(getattr(self, name)) for name in tensor_names}
    return self.__class__.__tensor_unflatten__(
        new_tensors,
        ctx,
        None,  # outer_size: fine unless the subclass spoofs its size (e.g. NJT)
        None,  # outer_stride
    )
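To illustrate writing this once and reusing it, here is a minimal, torch-free sketch. A hypothetical `ApplyFnToDataMixin` base class holds the generic method, and a toy `QuantizedPair` class (also hypothetical) only implements the flatten/unflatten pair. Plain Python lists stand in for tensors so the sketch runs without PyTorch; the flatten/unflatten signatures are assumed to mirror PyTorch's traceable wrapper-subclass protocol.

```python
from typing import Any, Callable, Dict, List, Tuple


class ApplyFnToDataMixin:
    """Hypothetical mixin providing one generic _apply_fn_to_data.

    Assumes subclasses implement __tensor_flatten__ / __tensor_unflatten__
    with the same shape as PyTorch's wrapper-subclass protocol.
    """

    def _apply_fn_to_data(self, fn: Callable[[Any], Any]):
        tensor_names, ctx = self.__tensor_flatten__()
        # Apply fn to each inner component, keyed by attribute name
        new_tensors = {name: fn(getattr(self, name)) for name in tensor_names}
        # outer_size / outer_stride are passed as None, as in the issue
        return type(self).__tensor_unflatten__(new_tensors, ctx, None, None)


class QuantizedPair(ApplyFnToDataMixin):
    """Toy stand-in for a tensor subclass: two components plus metadata.

    Plain lists replace real tensors so the example runs without torch.
    """

    def __init__(self, int_data: List[int], scale: List[float], block_size: int):
        self.int_data = int_data
        self.scale = scale
        self.block_size = block_size

    def __tensor_flatten__(self) -> Tuple[List[str], Dict[str, Any]]:
        # Names of the "tensor" attributes, plus non-tensor metadata
        return ["int_data", "scale"], {"block_size": self.block_size}

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        # This toy class derives everything from its components and ignores
        # outer_size / outer_stride
        return QuantizedPair(
            inner_tensors["int_data"], inner_tensors["scale"], ctx["block_size"]
        )


# Usage: apply a function (here, doubling every element) to each component
q = QuantizedPair([1, 2, 3], [0.5, 0.25], block_size=32)
doubled = q._apply_fn_to_data(lambda xs: [2 * x for x in xs])
print(doubled.int_data, doubled.scale, doubled.block_size)
# prints [2, 4, 6] [1.0, 0.5] 32
```

Because the mixin relies only on the flatten/unflatten pair, a subclass that must report a size its components don't imply (the NJT case) could override `_apply_fn_to_data` and pass its real outer size and stride instead of `None`.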