@@ -199,18 +199,18 @@ def disable_kv_cache(model: nn.Module) -> Generator[None, None, None]:
199199 >>> # now temporarily disable caches
200200 >>> with disable_kv_cache(model):
201201 >>> print(model.caches_are_setup())
202- >>> True
202+ True
203203 >>> print(model.caches_are_enabled())
204- >>> False
204+ False
205205 >>> print(model.layers[0].attn.kv_cache)
206- >>> # KVCache()
206+ KVCache()
207207 >>> # caches are now re-enabled, and their state is untouched
208208 >>> print(model.caches_are_setup())
209209 True
210210 >>> print(model.caches_are_enabled())
211211 True
212212 >>> print(model.layers[0].attn.kv_cache)
213- >>> KVCache()
213+ KVCache()
214214
215215 Args:
216216 model (nn.Module): model to disable KV-cacheing for.
@@ -219,7 +219,8 @@ def disable_kv_cache(model: nn.Module) -> Generator[None, None, None]:
219219 None: Returns control to the caller with KV-caches disabled on the given model.
220220
221221 Raises:
222- ValueError: If the model does not have caches setup.
222+ ValueError: If the model does not have caches setup. Use :func:`~torchtune.modules.TransformerDecoder.setup_caches` to
223+ setup caches first.
223224 """
224225 if not model.caches_are_setup():
225226 raise ValueError(
@@ -306,6 +307,7 @@ def local_kv_cache(
306307
307308 Raises:
308309 ValueError: If the model already has caches setup.
310+ You may use :func:`~torchtune.modules.common_utils.delete_kv_caches` to delete existing caches.
309311 """
310312 if model.caches_are_setup():
311313 raise ValueError(
@@ -340,29 +342,31 @@ def delete_kv_caches(model: nn.Module):
340342 >>> dtype=torch.float32,
341343 >>> decoder_max_seq_len=1024)
342344 >>> print(model.caches_are_setup())
343- >>> True
345+ True
344346 >>> print(model.caches_are_enabled())
345- >>> True
347+ True
346348 >>> print(model.layers[0].attn.kv_cache)
347- >>> KVCache()
349+ KVCache()
348350 >>> delete_kv_caches(model)
349351 >>> print(model.caches_are_setup())
350- >>> False
352+ False
351353 >>> print(model.caches_are_enabled())
352- >>> False
354+ False
353355 >>> print(model.layers[0].attn.kv_cache)
354- >>> None
356+ None
357+
355358 Args:
356359 model (nn.Module): model to enable KV-cacheing for.
357360
358361 Raises:
359- ValueError: if ``delete_kv_caches`` is called on a model which does not have
360- caches setup.
362+ ValueError: if this function is called on a model which does not have
363+ caches setup. Use :func:`~torchtune.modules.TransformerDecoder.setup_caches` to
364+ setup caches first.
361365 """
362366 if not model.caches_are_setup():
363367 raise ValueError(
364- "You have tried to delete model caches, but ` model.caches_are_setup()` "
365- "is False!"
368+ "You have tried to delete model caches, but model.caches_are_setup() "
369+ "is False! Please setup caches on the model first. "
366370 )
367371 for module in model.modules():
368372 if hasattr(module, "kv_cache") and callable(module.kv_cache):
0 commit comments