
Commit 12cbc4b

felipemello1 and Felipe Mello authored
[doc][modules] Update to modules documentation (#1079)
Co-authored-by: Felipe Mello <[email protected]>
1 parent 2fe9a70 · commit 12cbc4b

File tree: 9 files changed (+59, -40 lines)

docs/source/api_ref_modules.rst

Lines changed: 13 additions & 1 deletion
@@ -29,6 +29,7 @@ Tokenizers

     tokenizers.SentencePieceTokenizer
     tokenizers.TikTokenTokenizer
+    tokenizers.Tokenizer

 PEFT Components
 ---------------
@@ -41,7 +42,9 @@ PEFT Components
     peft.AdapterModule
     peft.get_adapter_params
     peft.set_trainable_params
-    peft.validate_state_dict_for_lora
+    peft.validate_missing_and_unexpected_for_lora
+    peft.validate_state_dict_for_lora
+    peft.disable_adapter

 Module Utilities
 ------------------
@@ -52,3 +55,12 @@ These are utilities that are common to and can be used by all modules.
     :nosignatures:

     common_utils.reparametrize_as_dtype_state_dict_post_hook
+
+Loss
+------------------
+
+.. autosummary::
+    :toctree: generated/
+    :nosignatures:
+
+    loss.DPOLoss
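
The newly documented loss.DPOLoss implements Direct Preference Optimization. For orientation only, here is a minimal sketch of the standard DPO objective in plain PyTorch; the function name and argument layout are illustrative, not torchtune's exact API:

    import torch
    import torch.nn.functional as F

    def dpo_loss_sketch(
        policy_chosen_logps: torch.Tensor,
        policy_rejected_logps: torch.Tensor,
        reference_chosen_logps: torch.Tensor,
        reference_rejected_logps: torch.Tensor,
        beta: float = 0.1,
    ) -> torch.Tensor:
        # Log-ratios of policy vs. reference for the chosen and rejected responses.
        chosen_logratios = policy_chosen_logps - reference_chosen_logps
        rejected_logratios = policy_rejected_logps - reference_rejected_logps
        # DPO pushes the chosen log-ratio above the rejected one, scaled by beta.
        logits = beta * (chosen_logratios - rejected_logratios)
        return -F.logsigmoid(logits).mean()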

torchtune/modules/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -22,6 +22,5 @@
     "RMSNorm",
     "TransformerDecoder",
     "TransformerDecoderLayer",
-    "TransformerClassifier",
     "reparametrize_as_dtype_state_dict_post_hook",
 ]

torchtune/modules/kv_cache.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def reset(self) -> None:
     def update(
         self, input_pos: Tensor, k_val: Tensor, v_val: Tensor
     ) -> Tuple[Tensor, Tensor]:
-        """Update KV cache and return the updated cache.
+        """Update KV cache with the new k_val, v_val and return the updated cache.

         Args:
             input_pos (Tensor): Current position tensor with shape [S]
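
A KV cache of this kind writes the newly computed key/value tensors into preallocated buffers at the positions given by input_pos and returns the full cache. A minimal standalone sketch; the buffer names and shapes are assumptions, not torchtune's implementation:

    import torch
    from torch import Tensor
    from typing import Tuple

    class TinyKVCache:
        def __init__(self, batch_size: int, num_heads: int, max_seq_len: int, head_dim: int):
            # Preallocated buffers of shape [batch_size, num_heads, max_seq_len, head_dim].
            self.k_cache = torch.zeros(batch_size, num_heads, max_seq_len, head_dim)
            self.v_cache = torch.zeros(batch_size, num_heads, max_seq_len, head_dim)

        def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor) -> Tuple[Tensor, Tensor]:
            # Write the new keys/values at the S positions in input_pos, then return the
            # full caches so attention can attend over every cached position.
            self.k_cache[:, :, input_pos] = k_val
            self.v_cache[:, :, input_pos] = v_val
            return self.k_cache, self.v_cache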

torchtune/modules/lr_schedulers.py

Lines changed: 5 additions & 1 deletion
@@ -38,12 +38,16 @@ def get_cosine_schedule_with_warmup(
         torch.optim.lr_scheduler.LambdaLR with the appropriate schedule.
     """

-    def lr_lambda(current_step):
+    def lr_lambda(current_step: int) -> float:
+        # linear warmup phase
         if current_step < num_warmup_steps:
             return current_step / max(1, num_warmup_steps)
+
+        # cosine
         progress = (current_step - num_warmup_steps) / max(
             1, num_training_steps - num_warmup_steps
         )
+
         cosine_lr_multiple = 0.5 * (
             1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)
         )
torchtune/modules/peft/peft_utils.py

Lines changed: 31 additions & 11 deletions
@@ -241,23 +241,43 @@ def get_merged_lora_ckpt(

 @contextlib.contextmanager
 def disable_adapter(model: nn.Module) -> Generator[None, None, None]:
-    for _, v in model.named_modules():
+    """
+    Temporarily disable the adapters in a neural network model. This can be used,
+    for example, in DPO for treating the lora adapters as the policy model
+    and disabling it to treat the base model as the reference model.
+
+    This context manager goes through all modules in the provided neural network model,
+    and if a module has an 'adapter_params' attribute that is callable and a 'disabled' attribute,
+    it sets 'disabled' to True. Then, the control is given back to caller. Once that finalizes,
+    it sets 'disabled' back to False for all modules that were temporarily disabled.
+
+    Args:
+        model (nn.Module): The neural network model whose adapters are to be temporarily disabled.
+    Yields:
+        None: This function yields control back to the caller, with the adapters disabled.
+    Example:
+        >>> with disable_adapter(model):
+        ...     # Perform operations with adapters disabled
+        ...     pass
+
+    """
+    for _, module in model.named_modules():
         if (
-            hasattr(v, "adapter_params")
-            and callable(v.adapter_params)
-            and hasattr(v, "disabled")
+            hasattr(module, "adapter_params")
+            and callable(module.adapter_params)
+            and hasattr(module, "disabled")
         ):
-            v.disabled = True
+            module.disabled = True
     try:
         yield
     finally:
-        for _, v in model.named_modules():
+        for _, module in model.named_modules():
             if (
-                hasattr(v, "adapter_params")
-                and callable(v.adapter_params)
-                and hasattr(v, "disabled")
+                hasattr(module, "adapter_params")
+                and callable(module.adapter_params)
+                and hasattr(module, "disabled")
             ):
-                v.disabled = False
+                module.disabled = False


 def validate_missing_and_unexpected_for_lora(

@@ -272,7 +292,7 @@ def validate_missing_and_unexpected_for_lora(
     """
     A more memory-efficient way to validate that LoRA state dict loading was done properly.

-    Similar to validate_state_dict_for_lora, this function uses a model's LoRA config to
+    Similar to :func:`validate_state_dict_for_lora`, this function uses a model's LoRA config to
     check that LoRA and/or base model weights are loaded into the full model correctly.
     Unlike that function, this method relies only on the values of missing and unexpected
     as returned by the load_state_dict API with strict=False. This allows us to do the
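
The DPO use case described in the new docstring looks roughly like this in practice. A sketch, assuming model is a LoRA-wrapped torchtune decoder that returns logits and tokens is a batch of token ids; only the disable_adapter context manager comes from the code above:

    import torch
    from torchtune.modules.peft import disable_adapter

    def policy_and_reference_logits(model: torch.nn.Module, tokens: torch.Tensor):
        # Adapters active: the LoRA-augmented model acts as the DPO policy.
        policy_logits = model(tokens)

        # Adapters temporarily disabled: the same weights act as the frozen reference model.
        with disable_adapter(model), torch.no_grad():
            reference_logits = model(tokens)

        return policy_logits, reference_logits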

torchtune/modules/tokenizers/_sentencepiece.py

Lines changed: 2 additions & 2 deletions
@@ -60,8 +60,8 @@ def encode(

         Args:
             text (str): The input text to be encoded, unbatched.
-            add_bos (bool): Whether to prepend BOS to the input, defaults to True.
-            add_eos (bool): Whether to append EOS to the input, defaults to True.
+            add_bos (bool): Whether to prepend BOS special token (Beginning of Sentence) to the input, defaults to True.
+            add_eos (bool): Whether to append EOS special token (End of Sentence) to the input, defaults to True.
             trim_leading_whitespace (bool): Whether to trim leading whitespace from
                 underlying sentencepiece tokenization. Sentencepiece normally prepends
                 whitespace to any tokenized text, which can cause differences where
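
In practice the two flags control the special tokens wrapped around the encoded ids. A sketch, assuming a SentencePieceTokenizer built from a placeholder model path and a tokenizer that exposes bos_id/eos_id attributes:

    from torchtune.modules.tokenizers import SentencePieceTokenizer

    tokenizer = SentencePieceTokenizer("/path/to/tokenizer.model")  # placeholder path

    with_specials = tokenizer.encode("hello world")                        # BOS and EOS added by default
    plain = tokenizer.encode("hello world", add_bos=False, add_eos=False)  # raw sentencepiece ids only

    assert with_specials[0] == tokenizer.bos_id
    assert with_specials[-1] == tokenizer.eos_id
    assert len(with_specials) == len(plain) + 2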

torchtune/modules/tokenizers/_tiktoken.py

Lines changed: 4 additions & 4 deletions
@@ -290,11 +290,11 @@ def decode(
         """
         if truncate_at_eos:
             try:
-                k = token_ids.index(self.eos_id)
+                idx_eos = token_ids.index(self.eos_id)
             except ValueError:
-                k = None
-            if k:
-                token_ids = token_ids[:k]
+                idx_eos = None
+            if idx_eos:
+                token_ids = token_ids[:idx_eos]
         token_ids = [token_id for token_id in token_ids if token_id != self.bos_id]
         return self.tt_model.decode(token_ids)
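
The rename is behavior-preserving; the truncate_at_eos branch works like this standalone rendition, where eos_id and the token ids are arbitrary example values:

    eos_id = 2
    token_ids = [5, 17, 42, 2, 9, 9]

    try:
        idx_eos = token_ids.index(eos_id)
    except ValueError:
        idx_eos = None
    if idx_eos:
        token_ids = token_ids[:idx_eos]

    print(token_ids)  # [5, 17, 42] -- the first EOS and everything after it is dropped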

torchtune/modules/tokenizers/_utils.py

Lines changed: 1 addition & 12 deletions
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Iterator, List, Protocol, Set, Union
+from typing import Iterator, List, Protocol, Set

 from torchtune.data._types import Message

@@ -37,17 +37,6 @@ def tokenize_messages(self, token_ids: List[Message], **kwargs):
         pass


-def truncate(
-    tokens: List[int],
-    max_seq_len: int,
-    eos_id: Union[int, bool],
-):
-    tokens_truncated = tokens[:max_seq_len]
-    if tokens_truncated[-1] != eos_id:
-        tokens_truncated[-1] = eos_id
-    return tokens_truncated
-
-
 def _split_long_repetitions(s: str, max_consecutive_slice_len: int) -> Iterator[str]:
     """
     Split the string `s` so that each substring contains no more than `max_consecutive_slice_len`
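
For reference, the deleted truncate helper simply clipped the sequence to max_seq_len and forced the final kept token to be the EOS id. A standalone copy with a worked example, using illustrative values:

    from typing import List

    def truncate(tokens: List[int], max_seq_len: int, eos_id: int) -> List[int]:
        # Clip to max_seq_len, then make sure the last kept token is the EOS id.
        tokens_truncated = tokens[:max_seq_len]
        if tokens_truncated[-1] != eos_id:
            tokens_truncated[-1] = eos_id
        return tokens_truncated

    print(truncate([11, 12, 13, 14, 15], max_seq_len=3, eos_id=2))  # [11, 12, 2]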

torchtune/modules/transformer.py

Lines changed: 2 additions & 7 deletions
@@ -62,11 +62,6 @@ def forward(
             Tensor: output tensor with same shape as input
                 [batch_size x seq_length x embed_dim]

-        Notation used for tensor shapes:
-            - b: batch size
-            - s: sequence length
-            - d: embed dim
-
         TODO:
             - Make position of norm configurable
         """

@@ -75,13 +70,13 @@ def forward(
         # Norm applied before self-attention
         attn_out = self.attn(self.sa_norm(x), mask=mask, input_pos=input_pos)

-        # Residual connection; shape: [b, s, d]
+        # Residual connection; shape: [batch_size, seq_length, embed_dim]
         h = attn_out + x

         # Norm applied before the feedforward layer
         mlp_out = self.mlp(self.mlp_norm(h))

-        # Residual connection; shape: [b, s, d]
+        # Residual connection; shape: [batch_size, seq_length, embed_dim]
         out = h + mlp_out
         return out
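
The comments describe a standard pre-norm decoder block: normalize, attend, add the residual, then normalize, feed forward, add the residual again. A compact standalone sketch; LayerNorm stands in for torchtune's RMSNorm and the attention/MLP modules are placeholders, not TransformerDecoderLayer itself:

    import torch
    from torch import nn, Tensor

    class PreNormBlockSketch(nn.Module):
        def __init__(self, attn: nn.Module, mlp: nn.Module, embed_dim: int):
            super().__init__()
            self.attn, self.mlp = attn, mlp
            self.sa_norm = nn.LayerNorm(embed_dim)
            self.mlp_norm = nn.LayerNorm(embed_dim)

        def forward(self, x: Tensor) -> Tensor:
            # Norm before self-attention, then residual connection;
            # shape: [batch_size, seq_length, embed_dim]
            h = self.attn(self.sa_norm(x)) + x
            # Norm before the feedforward layer, then residual connection;
            # shape: [batch_size, seq_length, embed_dim]
            return self.mlp(self.mlp_norm(h)) + h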
