Weaken num_new_engines into has_new_engines#938
Weaken num_new_engines into has_new_engines#938fzyzcjy wants to merge 4 commits intorollout_ft/20from
Conversation
There was a problem hiding this comment.
Code Review
This pull request refactors the rollout engine tracking logic by replacing the integer-based num_new_engines with a boolean has_new_engines across the FSDP, Megatron, and Ray rollout modules. The start_engines method was also modified to return a tuple containing initialization handles and the count of new engines. Review feedback identifies several instances where docstrings, comments, and assertion messages were not updated to reflect these variable name and return type changes.
| logger.info(f"Recovered {self.num_new_engines} dead rollout engines (worker_type={self.worker_type})") | ||
| assert self.num_new_engines == len(dead_indices), "num_new_engines does not match dead_indices length" | ||
| logger.info(f"Recovered {curr_num_new_engines} dead rollout engines (worker_type={self.worker_type})") | ||
| assert curr_num_new_engines == len(dead_indices), "curr_num_new_engines does not match dead_indices length" |
There was a problem hiding this comment.
The assertion message still refers to num_new_engines. It should be updated to match the new variable name curr_num_new_engines for consistency.
| assert curr_num_new_engines == len(dead_indices), "curr_num_new_engines does not match dead_indices length" | |
| assert curr_num_new_engines == len(dead_indices), "curr_num_new_engines does not match dead_indices length" |
| def clear_updatable_has_new_engines(self): | ||
| # when fault tolerance is not enabled, we need to manually clear has_new_engines after update_weights |
There was a problem hiding this comment.
The comment on line 199 still refers to num_new_engines. It should be updated to has_new_engines to reflect the recent changes.
| def clear_updatable_has_new_engines(self): | |
| # when fault tolerance is not enabled, we need to manually clear has_new_engines after update_weights | |
| def clear_updatable_has_new_engines(self): | |
| # when fault tolerance is not enabled, we need to manually clear has_new_engines after update_weights |
| def start_engines(self, port_cursors: PortCursors) -> tuple[list, int]: | ||
| """Create Ray actors, allocate ports, and fire ``engine.init()`` without waiting. | ||
|
|
||
| Returns ``(init_handles, port_cursors)`` where *init_handles* is a list | ||
| Returns ``(init_handles, curr_num_new_engines)`` where *init_handles* is a list | ||
| of Ray ObjectRefs and *port_cursors* maps node index -> next free port. | ||
| """ |
There was a problem hiding this comment.
The docstring for start_engines is outdated. It incorrectly states that the method returns (init_handles, port_cursors), whereas it now returns (init_handles, curr_num_new_engines). Additionally, port_cursors is modified in-place rather than returned.
| def start_engines(self, port_cursors: PortCursors) -> tuple[list, int]: | |
| """Create Ray actors, allocate ports, and fire ``engine.init()`` without waiting. | |
| Returns ``(init_handles, port_cursors)`` where *init_handles* is a list | |
| Returns ``(init_handles, curr_num_new_engines)`` where *init_handles* is a list | |
| of Ray ObjectRefs and *port_cursors* maps node index -> next free port. | |
| """ | |
| def start_engines(self, port_cursors: PortCursors) -> tuple[list, int]: | |
| """Create Ray actors, allocate ports, and fire engine.init() without waiting. | |
| Returns (init_handles, curr_num_new_engines) where init_handles is a list | |
| of Ray ObjectRefs and curr_num_new_engines is the number of newly started engines. | |
| """ |
to prepare for multi start() in one step