Skip to content

Commit ca37c59

Browse files
Caching doc nits (#1876)
1 parent 73aa126 commit ca37c59

File tree

3 files changed

+29
-21
lines changed

3 files changed

+29
-21
lines changed

torchtune/modules/common_utils.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -199,18 +199,18 @@ def disable_kv_cache(model: nn.Module) -> Generator[None, None, None]:
199199
>>> # now temporarily disable caches
200200
>>> with disable_kv_cache(model):
201201
>>> print(model.caches_are_setup())
202-
>>> True
202+
True
203203
>>> print(model.caches_are_enabled())
204-
>>> False
204+
False
205205
>>> print(model.layers[0].attn.kv_cache)
206-
>>> # KVCache()
206+
KVCache()
207207
>>> # caches are now re-enabled, and their state is untouched
208208
>>> print(model.caches_are_setup())
209209
True
210210
>>> print(model.caches_are_enabled())
211211
True
212212
>>> print(model.layers[0].attn.kv_cache)
213-
>>> KVCache()
213+
KVCache()
214214
215215
Args:
216216
model (nn.Module): model to disable KV-caching for.
@@ -219,7 +219,8 @@ def disable_kv_cache(model: nn.Module) -> Generator[None, None, None]:
219219
None: Returns control to the caller with KV-caches disabled on the given model.
220220
221221
Raises:
222-
ValueError: If the model does not have caches setup.
222+
ValueError: If the model does not have caches setup. Use :func:`~torchtune.modules.TransformerDecoder.setup_caches` to
223+
setup caches first.
223224
"""
224225
if not model.caches_are_setup():
225226
raise ValueError(
@@ -306,6 +307,7 @@ def local_kv_cache(
306307
307308
Raises:
308309
ValueError: If the model already has caches setup.
310+
You may use :func:`~torchtune.modules.common_utils.delete_kv_caches` to delete existing caches.
309311
"""
310312
if model.caches_are_setup():
311313
raise ValueError(
@@ -340,29 +342,31 @@ def delete_kv_caches(model: nn.Module):
340342
>>> dtype=torch.float32,
341343
>>> decoder_max_seq_len=1024)
342344
>>> print(model.caches_are_setup())
343-
>>> True
345+
True
344346
>>> print(model.caches_are_enabled())
345-
>>> True
347+
True
346348
>>> print(model.layers[0].attn.kv_cache)
347-
>>> KVCache()
349+
KVCache()
348350
>>> delete_kv_caches(model)
349351
>>> print(model.caches_are_setup())
350-
>>> False
352+
False
351353
>>> print(model.caches_are_enabled())
352-
>>> False
354+
False
353355
>>> print(model.layers[0].attn.kv_cache)
354-
>>> None
356+
None
357+
355358
Args:
356359
model (nn.Module): model to delete KV-caches for.
357360
358361
Raises:
359-
ValueError: if ``delete_kv_caches`` is called on a model which does not have
360-
caches setup.
362+
ValueError: if this function is called on a model which does not have
363+
caches setup. Use :func:`~torchtune.modules.TransformerDecoder.setup_caches` to
364+
setup caches first.
361365
"""
362366
if not model.caches_are_setup():
363367
raise ValueError(
364-
"You have tried to delete model caches, but `model.caches_are_setup()` "
365-
"is False!"
368+
"You have tried to delete model caches, but model.caches_are_setup() "
369+
"is False! Please setup caches on the model first."
366370
)
367371
for module in model.modules():
368372
if hasattr(module, "kv_cache") and callable(module.kv_cache):

torchtune/modules/model_fusion/_fusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ def caches_are_enabled(self) -> bool:
405405
Checks if the key value caches are enabled. Once KV-caches have been setup, the relevant
406406
attention modules will be "enabled" and all forward passes will update the caches. This behaviour
407407
can be disabled without altering the state of the KV-caches by "disabling" the KV-caches
408-
using ``torchtune.modules.disable_kv_cache``, upon which ``caches_are_enabled`` would return False.
408+
using :func:`~torchtune.modules.common_utils.disable_kv_cache`, upon which ``caches_are_enabled`` would return False.
409409
"""
410410
return self.decoder.caches_are_enabled()
411411

torchtune/modules/transformer.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -410,9 +410,9 @@ def setup_caches(
410410
):
411411
"""
412412
Sets up key-value attention caches for inference. For each layer in ``self.layers``:
413-
- :class:`~torchtune.modules.TransformerSelfAttentionLayer` will use ``decoder_max_seq_len``.
414-
- :class:`~torchtune.modules.TransformerCrossAttentionLayer` will use ``encoder_max_seq_len``.
415-
- :class:`~torchtune.modules.fusion.FusionLayer` will use both ``decoder_max_seq_len`` and ``encoder_max_seq_len``.
413+
- :class:`~torchtune.modules.TransformerSelfAttentionLayer` will use ``decoder_max_seq_len``.
414+
- :class:`~torchtune.modules.TransformerCrossAttentionLayer` will use ``encoder_max_seq_len``.
415+
- :class:`~torchtune.modules.model_fusion.FusionLayer` will use ``decoder_max_seq_len`` and ``encoder_max_seq_len``.
416416
417417
Args:
418418
batch_size (int): batch size for the caches.
@@ -460,18 +460,22 @@ def caches_are_enabled(self) -> bool:
460460
Checks if the key value caches are enabled. Once KV-caches have been setup, the relevant
461461
attention modules will be "enabled" and all forward passes will update the caches. This behaviour
462462
can be disabled without altering the state of the KV-caches by "disabling" the KV-caches
463-
using ``torchtune.modules.disable_kv_cache``, upon which ``caches_are_enabled`` would return False.
463+
using :func:`~torchtune.modules.common_utils.disable_kv_cache`, upon which ``caches_are_enabled`` would return False.
464464
"""
465465
return self.layers[0].caches_are_enabled()
466466

467467
def reset_caches(self):
468468
"""
469469
Resets KV-cache buffers on relevant attention modules to zero, and reset cache positions to zero,
470470
without deleting or reallocating cache tensors.
471+
472+
Raises:
473+
RuntimeError: if KV-caches are not setup. Use :func:`~torchtune.modules.TransformerDecoder.setup_caches` to
474+
setup caches first.
471475
"""
472476
if not self.caches_are_enabled():
473477
raise RuntimeError(
474-
"Key value caches are not setup. Call ``setup_caches()`` first."
478+
"Key value caches are not setup. Call model.setup_caches first."
475479
)
476480

477481
for layer in self.layers:

0 commit comments

Comments
 (0)