@@ -272,6 +272,12 @@ class GenState(eqx.Module):
272272 cache : PageCache
273273 decode_state : DecodeState
274274
def reset(self):
    """Return a ``GenState`` assembled from the reset cache and decode state.

    ``reset()`` is called on ``self.cache`` first and on ``self.decode_state``
    second; the two results are handed to the ``GenState`` constructor.
    Nothing is assigned on ``self`` here.
    """
    fresh_cache = self.cache.reset()
    fresh_decode_state = self.decode_state.reset()
    return GenState(cache=fresh_cache, decode_state=fresh_decode_state)
280+
275281 def clone_sequence (
276282 self , parent_local_id : int , child_local_id : int | None = None , seq_params : SeqDecodingParams | None = None
277283 ) -> tuple ["GenState" , int ]:
@@ -797,22 +803,6 @@ def __init__(
797803 # Results by request id -> choice -> DecodeResult
798804 self .results : dict [int , dict [int , DecodeResult ]] = {}
799805
800- def _verify_free_slot_view (self , * , context : str ) -> None :
801- """Ensure host free-list matches the device page-table used mask."""
802-
803- used_mask = np .asarray (jax .device_get (self .gen_state .decode_state .sequences .used_mask .array )).astype (bool )
804- free_set = set (self .free_slots )
805-
806- for slot_id , is_used in enumerate (used_mask ):
807- if is_used and slot_id in free_set :
808- raise RuntimeError (
809- f"[free slot invariant] slot { slot_id } marked used but present in free list during { context } "
810- )
811- if not is_used and slot_id not in free_set :
812- raise RuntimeError (
813- f"[free slot invariant] slot { slot_id } free in page table but missing from free list during { context } "
814- )
815-
816806 @classmethod
817807 def from_model_with_config (
818808 cls ,
@@ -853,26 +843,12 @@ def reset(self) -> None:
853843
854844 Keeps the KV cache memory allocated. Reuses current `PageTable` object with pages freed.
855845 """
856- decode_state = self .gen_state .decode_state
857- page_table = decode_state .page_table
858- sequences = decode_state .sequences
859- for slot_id in range (page_table .max_seqs ):
860- sequences , page_table = sequences .free_pages (page_table , slot_id )
861-
862- new_decode_state = DecodeState .init (
863- page_table ,
864- max_stop_seqs = self .config .max_stop_seqs ,
865- max_stop_tokens = self .config .max_stop_tokens ,
866- max_queued_tokens = self .config .max_queued_tokens ,
867- )
868- self .gen_state = dataclasses .replace (self .gen_state , decode_state = new_decode_state )
869- self .free_slots = list (range (int (page_table .max_seqs )))
846+ self .gen_state = self .gen_state .reset ()
847+ self .free_slots = list (range (int (self .gen_state .decode_state .max_seqs )))
870848 self .local_map .clear ()
871849 self .sequences .clear ()
872850 self .results = {}
873851
874- self ._verify_free_slot_view (context = "reset" )
875-
876852 def _prefill_batch (self , batch : Sequence [Request ]) -> _DecodeOutputs | None :
877853 """Admit a batch from the head of the queue that fits in free slots/pages.
878854
0 commit comments