Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions miles/backends/fsdp_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,10 +553,10 @@ def update_weights(self) -> None: # type: ignore[override]
if self.args.debug_train_only or self.args.debug_rollout_only:
return

rollout_engines, rollout_engine_lock, num_new_engines, engine_gpu_counts, engine_gpu_offsets = ray.get(
rollout_engines, rollout_engine_lock, has_new_engines, engine_gpu_counts, engine_gpu_offsets = ray.get(
self.rollout_manager.get_updatable_engines_and_lock.remote()
)
if num_new_engines > 0:
if has_new_engines:
self.weight_updater.connect_rollout_engines(
rollout_engines,
rollout_engine_lock,
Expand All @@ -565,7 +565,7 @@ def update_weights(self) -> None: # type: ignore[override]
)
dist.barrier(group=get_gloo_group())
if dist.get_rank() == 0:
ray.get(self.rollout_manager.clear_updatable_num_new_engines.remote())
ray.get(self.rollout_manager.clear_updatable_has_new_engines.remote())

self.weight_updater.update_weights()

Expand Down
6 changes: 3 additions & 3 deletions miles/backends/megatron_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,14 +494,14 @@ def update_weights(self) -> None:
ray.get(self.rollout_manager.recover_updatable_engines.remote())
dist.barrier(group=get_gloo_group())

rollout_engines, rollout_engine_lock, num_new_engines, engine_gpu_counts, engine_gpu_offsets = ray.get(
rollout_engines, rollout_engine_lock, has_new_engines, engine_gpu_counts, engine_gpu_offsets = ray.get(
self.rollout_manager.get_updatable_engines_and_lock.remote()
)

if self.args.offload_train:
reload_process_groups()

if num_new_engines > 0:
if has_new_engines:
self.weight_updater.connect_rollout_engines(
rollout_engines,
rollout_engine_lock,
Expand All @@ -510,7 +510,7 @@ def update_weights(self) -> None:
)
dist.barrier(group=get_gloo_group())
if dist.get_rank() == 0:
ray.get(self.rollout_manager.clear_updatable_num_new_engines.remote())
ray.get(self.rollout_manager.clear_updatable_has_new_engines.remote())

if self.args.offload_train and is_lora_enabled(self.args):
# For LoRA, we must resume() to restore GPU memory backing for adapter
Expand Down
12 changes: 6 additions & 6 deletions miles/ray/rollout/rollout_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,17 +192,17 @@ def get_updatable_engines_and_lock(self):
engines = srv.engines if srv else []
gpu_counts = srv.engine_gpu_counts if srv else []
gpu_offsets = srv.engine_gpu_offsets if srv else []
num_new = srv.num_new_engines if srv else 0
return engines, self.rollout_engine_lock, num_new, gpu_counts, gpu_offsets
has_new = srv.has_new_engines if srv else False
return engines, self.rollout_engine_lock, has_new, gpu_counts, gpu_offsets

def clear_updatable_num_new_engines(self):
# when fault tolerance is not enabled, we need to manually clear num_new_engines after update_weights
def clear_updatable_has_new_engines(self):
# when fault tolerance is not enabled, we need to manually clear has_new_engines after update_weights
Comment on lines +198 to +199
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The comment on line 199 still refers to num_new_engines. It should be updated to has_new_engines to reflect the recent changes.

Suggested change
def clear_updatable_has_new_engines(self):
# when fault tolerance is not enabled, we need to manually clear has_new_engines after update_weights
def clear_updatable_has_new_engines(self):
# when fault tolerance is not enabled, we need to manually clear has_new_engines after update_weights

srv = self._get_updatable_server()
if srv:
srv.clear_num_new_engines()
srv.clear_has_new_engines()

async def recover_updatable_engines(self) -> None:
"""Restart any dead rollout engines and update num_new_engines for update_weights detection.
"""Restart any dead rollout engines and update has_new_engines for update_weights detection.

Recovers the updatable model (the one that receives weight
updates from training).
Expand Down
12 changes: 6 additions & 6 deletions miles/ray/rollout/rollout_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def start_rollout_servers(args, pg) -> dict[str, "RolloutServer"]:
pg=pg,
all_engines=[None] * num_engines if group_cfg.worker_type != "placeholder" else [],
num_gpus_per_engine=gpus_per_engine,
num_new_engines=0,
has_new_engines=False,
worker_type=group_cfg.worker_type,
rank_offset=engine_offset,
gpu_offset=gpu_offset,
Expand All @@ -71,7 +71,7 @@ def start_rollout_servers(args, pg) -> dict[str, "RolloutServer"]:
router_port=router_port,
update_weights=model_cfg.update_weights,
)
handles = group.start_engines(port_cursors)
handles, _ = group.start_engines(port_cursors)
all_init_handles.extend(handles)
server_groups.append(group)

Expand Down Expand Up @@ -163,12 +163,12 @@ def all_engines(self):
return [e for g in self.server_groups for e in g.all_engines]

@property
def num_new_engines(self):
return sum(g.num_new_engines for g in self.server_groups)
def has_new_engines(self) -> bool:
return any(g.has_new_engines for g in self.server_groups)

def clear_num_new_engines(self):
def clear_has_new_engines(self):
for g in self.server_groups:
g.num_new_engines = 0
g.has_new_engines = False

@property
def engine_gpu_counts(self) -> list[int]:
Expand Down
27 changes: 15 additions & 12 deletions miles/ray/rollout/server_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ class ServerGroup:
pg: Any # (placement_group, reordered_bundle_indices, reordered_gpu_ids)
all_engines: list
num_gpus_per_engine: int
num_new_engines: int
# NOTE: this may have risk when recovering engines parallelly; may use source of truth (all_engines) later
has_new_engines: bool
worker_type: str = "regular" # "regular", "prefill", or "decode"
rank_offset: int = 0
gpu_offset: int = 0
Expand All @@ -53,15 +54,15 @@ def engines(self):
"""Node-0 engines only (for multi-node serving)."""
return self.all_engines[:: self.nodes_per_engine]

def start_engines(self, port_cursors: PortCursors) -> list:
def start_engines(self, port_cursors: PortCursors) -> tuple[list, int]:
"""Create Ray actors, allocate ports, and fire ``engine.init()`` without waiting.

Returns ``(init_handles, port_cursors)`` where *init_handles* is a list
Returns ``(init_handles, curr_num_new_engines)`` where *init_handles* is a list
of Ray ObjectRefs and *port_cursors* maps node index -> next free port.
"""
Comment on lines +57 to 62
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The docstring for start_engines is outdated. It incorrectly states that the method returns (init_handles, port_cursors), whereas it now returns (init_handles, curr_num_new_engines). Additionally, port_cursors is modified in-place rather than returned.

Suggested change
def start_engines(self, port_cursors: PortCursors) -> tuple[list, int]:
"""Create Ray actors, allocate ports, and fire ``engine.init()`` without waiting.
Returns ``(init_handles, port_cursors)`` where *init_handles* is a list
Returns ``(init_handles, curr_num_new_engines)`` where *init_handles* is a list
of Ray ObjectRefs and *port_cursors* maps node index -> next free port.
"""
def start_engines(self, port_cursors: PortCursors) -> tuple[list, int]:
"""Create Ray actors, allocate ports, and fire engine.init() without waiting.
Returns (init_handles, curr_num_new_engines) where init_handles is a list
of Ray ObjectRefs and curr_num_new_engines is the number of newly started engines.
"""

if self.args.debug_train_only or self.worker_type == "placeholder":
self.num_new_engines = 0
return []
self.has_new_engines = False
return [], 0

num_gpu_per_engine = min(self.num_gpus_per_engine, self.args.num_gpus_per_node)

Expand Down Expand Up @@ -120,10 +121,11 @@ def start_engines(self, port_cursors: PortCursors) -> list:
new_engines.append((global_rank, rollout_engine))
self.all_engines[i] = rollout_engine

self.num_new_engines = len(new_engines)
curr_num_new_engines = len(new_engines)
self.has_new_engines |= curr_num_new_engines > 0

if self.num_new_engines == 0:
return []
if curr_num_new_engines == 0:
return [], 0

if self.args.rollout_external:
addr_and_ports = allocate_rollout_engine_addr_and_ports_external(
Expand All @@ -149,7 +151,7 @@ def start_engines(self, port_cursors: PortCursors) -> list:
)
for index, engine in new_engines
]
return init_handles
return init_handles, curr_num_new_engines

def stop_engines(self, rollout_engine_id: int):
logger.info(f"Killing server group {rollout_engine_id}...")
Expand All @@ -173,12 +175,13 @@ def stop_engines(self, rollout_engine_id: int):
async def recover(self, port_cursors: PortCursors):
dead_indices = [i for i, engine in enumerate(self.all_engines) if engine is None]

await asyncio.gather(*self.start_engines(port_cursors))
handles, curr_num_new_engines = self.start_engines(port_cursors)
await asyncio.gather(*handles)

release_handles = []
all_resume_engines = []
logger.info(f"Recovered {self.num_new_engines} dead rollout engines (worker_type={self.worker_type})")
assert self.num_new_engines == len(dead_indices), "num_new_engines does not match dead_indices length"
logger.info(f"Recovered {curr_num_new_engines} dead rollout engines (worker_type={self.worker_type})")
assert curr_num_new_engines == len(dead_indices), "curr_num_new_engines does not match dead_indices length"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The assertion message still refers to num_new_engines. It should be updated to match the new variable name curr_num_new_engines for consistency.

Suggested change
assert curr_num_new_engines == len(dead_indices), "curr_num_new_engines does not match dead_indices length"
assert curr_num_new_engines == len(dead_indices), "curr_num_new_engines does not match dead_indices length"

if self.needs_offload and dead_indices:
new_engines = [self.all_engines[i] for i in dead_indices]
release_handles.extend(engine.release_memory_occupation.remote() for engine in new_engines)
Expand Down
Loading