Simplify the SplitState application for optimizers TBE SSD #4492

Closed · wants to merge 1 commit

51 changes: 50 additions & 1 deletion fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py
@@ -8,10 +8,16 @@
 # pyre-strict
 
 import enum
-from typing import Any, Dict  # noqa: F401
+import itertools
+from typing import Any, Dict, List, Tuple  # noqa: F401
 
 import torch
 
+from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
+    EmbeddingLocation,
+    SplitState,
+)
+
 
 @enum.unique
 class EmbOptimType(enum.Enum):
@@ -68,6 +74,49 @@ def state_size_nbytes(
         else:
             return 0
 
+    def ssd_state_splits(
+        self,
+        embedding_specs: List[Tuple[int, int]],  # Tuple of (rows, dims)
+        optimizer_state_dtypes: Dict[str, "SparseType"] = {},  # noqa: B006
+        enable_optimizer_offloading: bool = False,
+    ) -> List[Tuple[SplitState, str, torch.dtype]]:
+        """
+        Returns the split planning for the optimizer states
+        """
+        (rows, _) = zip(*embedding_specs)
+        T_ = len(embedding_specs)
+
+        # This is the cumulative row counts for rowwise states
+        row_count_cumsum: List[int] = [0] + list(itertools.accumulate(rows))
+        # This is the cumulative element counts for elementwise states
+        table_size_cumsum: List[int] = [0] + list(
+            itertools.accumulate([r * d for r, d in embedding_specs])
+        )
+
+        if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
+            params = {"momentum1": row_count_cumsum}
+        elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
+            params = {"momentum1": table_size_cumsum, "momentum2": row_count_cumsum}
+        else:
+            params = {}
+
+        return [
+            (
+                SplitState(
+                    dev_size=(
+                        cumsum_table[-1] if not enable_optimizer_offloading else 0
+                    ),
+                    host_size=0,
+                    uvm_size=0,
+                    placements=[EmbeddingLocation.DEVICE for _ in range(T_)],
+                    offsets=cumsum_table[:-1],
+                ),
+                name,
+                self._extract_dtype(optimizer_state_dtypes, name),
+            )
+            for (name, cumsum_table) in params.items()
+        ]
+
     def dtype(self) -> torch.dtype:
         """
         Returns the dtype of the optimizer state
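
To make the split planning concrete, here is a minimal usage sketch (not part of the PR; it assumes only the import path, enum members, and SplitState fields visible in the diff above) for two tables with (rows, dims) of (100, 4) and (50, 8):

# Hedged sketch: exercising ssd_state_splits outside the TBE module.
from fbgemm_gpu.split_embedding_configs import EmbOptimType

specs = [(100, 4), (50, 8)]  # (rows, dims) per table

# EXACT_ROWWISE_ADAGRAD plans a single rowwise state, "momentum1":
# offsets are cumulative row counts and dev_size is the total row count.
(state, name, _dtype) = EmbOptimType.EXACT_ROWWISE_ADAGRAD.ssd_state_splits(specs)[0]
assert name == "momentum1"
assert list(state.offsets) == [0, 100]  # row-count cumsum, last entry dropped
assert state.dev_size == 150            # 100 + 50 rows

# PARTIAL_ROWWISE_ADAM plans an elementwise momentum1 (rows * dims)
# and a rowwise momentum2.
states = {
    name: st
    for (st, name, _dt) in EmbOptimType.PARTIAL_ROWWISE_ADAM.ssd_state_splits(specs)
}
assert states["momentum1"].dev_size == 800  # 100 * 4 + 50 * 8 elements
assert states["momentum2"].dev_size == 150  # total rows
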
21 changes: 7 additions & 14 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -852,21 +852,14 @@ def __init__(
                 dtype=table_embedding_dtype,
             )
 
-        momentum1_offsets = [0] + list(itertools.accumulate(rows))
-        self._apply_split(
-            SplitState(
-                dev_size=(
-                    self.total_hash_size if not self.enable_optimizer_offloading else 0
-                ),
-                host_size=0,
-                uvm_size=0,
-                placements=[EmbeddingLocation.DEVICE for _ in range(T_)],
-                offsets=momentum1_offsets[:-1],
-            ),
-            "momentum1",
-            # pyre-fixme[6]: For 3rd argument expected `Type[dtype]` but got `dtype`.
-            dtype=torch.float32,
-        )
+        # Create the optimizer state tensors
+        for template in self.optimizer.ssd_state_splits(
+            self.embedding_specs,
+            self.optimizer_state_dtypes,
+            self.enable_optimizer_offloading,
+        ):
+            self._apply_split(*template)
 
         # For storing current iteration data
         self.current_iter_data: Optional[IterData] = None
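
Each template returned by ssd_state_splits is a (SplitState, name, dtype) triple, so the starred call passes the same three positional arguments the removed block spelled out by hand; as a sketch (variable names here are illustrative, not the actual parameter names):

# What one iteration of the loop above amounts to:
(split_state, name, dtype) = template
self._apply_split(split_state, name, dtype)

With the planning centralized in EmbOptimType, supporting a new optimizer state in the SSD TBE only requires extending ssd_state_splits rather than hand-writing another SplitState in the constructor; it also drops the pyre-fixme, since the dtype now arrives already typed as torch.dtype.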