Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions rllm/experimental/fully_async/fully_async_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from omegaconf import OmegaConf
from tqdm import tqdm
from verl import DataProto
from verl.experimental.fully_async_policy.ray_trainer import FullyAsyncRayPPOTrainer
from verl.experimental.separation.ray_trainer import SeparateRayPPOTrainer
from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
from verl.trainer.ppo import core_algos
from verl.trainer.ppo.core_algos import agg_loss
Expand All @@ -40,7 +40,7 @@


@ray.remote(num_cpus=10)
class FullyAsyncTrainer(FullyAsyncRayPPOTrainer):
class FullyAsyncTrainer(SeparateRayPPOTrainer):
"""
A fully asynchronous PPO trainer that obtains samples from a MessageQueue for training.
Based on an improved implementation of OneStepOffRayTrainer
Expand Down
10 changes: 5 additions & 5 deletions rllm/experimental/fully_async/inference_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
import subprocess

import ray
from verl.experimental.fully_async_policy.ray_trainer import FullyAsyncRayPPOTrainer
from verl.experimental.separation.ray_trainer import SeparateRayPPOTrainer
from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
from verl.trainer.ppo.ray_trainer import ResourcePoolManager
from verl.trainer.ppo.utils import Role, WorkerType
from verl.workers.rollout.utils import get_free_port
from verl.utils.net_utils import get_free_port


@ray.remote(num_cpus=10, max_concurrency=100)
class InferenceManager(FullyAsyncRayPPOTrainer):
class InferenceManager(SeparateRayPPOTrainer):
"""
Manages SGLang inference servers for async training.
Responsible for:
Expand Down Expand Up @@ -120,10 +120,10 @@ def _init_models(self):
async def _init_async_rollout_manager(self):
# create async rollout manager and request scheduler
assert self.config.actor_rollout_ref.rollout.mode == "async"
from verl.experimental.fully_async_policy.agent_loop import FullyAsyncAgentLoopManager
from verl.experimental.agent_loop import AgentLoopManager

self.async_rollout_mode = True
self.async_rollout_manager = await FullyAsyncAgentLoopManager.create(
self.async_rollout_manager = await AgentLoopManager.create(
config=self.config,
worker_group=self.rollout_wg,
)
Expand Down
2 changes: 1 addition & 1 deletion rllm/experimental/fully_async/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import ray
from omegaconf import OmegaConf
from verl.experimental.fully_async_policy.fully_async_main import create_resource_pool_manager, create_role_worker_mapping
from verl.experimental.separation.utils import create_resource_pool_manager, create_role_worker_mapping
from verl.trainer.ppo.utils import Role
from verl.utils.fs import copy_to_local

Expand Down
50 changes: 50 additions & 0 deletions tests/test_verl_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Test that verl import paths are compatible with verl 0.7.1+.

Regression test for https://github.com/rllm-org/rllm/issues/470
After verl 0.7.1 restructured module paths, several imports in the
fully_async module broke with ModuleNotFoundError.
"""

import importlib

import pytest


@pytest.mark.parametrize(
"module_path,names",
[
(
"rllm.experimental.fully_async.runner",
["AsyncAgentTrainer", "FullyAsyncTaskRunner"],
),
(
"rllm.experimental.fully_async.fully_async_trainer",
["FullyAsyncTrainer"],
),
(
"rllm.experimental.fully_async.inference_manager",
["InferenceManager"],
),
],
)
def test_fully_async_imports(module_path: str, names: list[str]) -> None:
"""Verify fully_async modules can be imported without ModuleNotFoundError."""
mod = importlib.import_module(module_path)
for name in names:
assert hasattr(mod, name), f"{module_path} is missing attribute {name}"


@pytest.mark.parametrize(
"module_path,name",
[
("verl.experimental.separation.ray_trainer", "SeparateRayPPOTrainer"),
("verl.experimental.separation.utils", "create_resource_pool_manager"),
("verl.experimental.separation.utils", "create_role_worker_mapping"),
("verl.experimental.agent_loop", "AgentLoopManager"),
("verl.utils.net_utils", "get_free_port"),
],
)
def test_verl_module_paths_exist(module_path: str, name: str) -> None:
"""Verify the verl module paths used by rllm exist in the installed verl package."""
mod = importlib.import_module(module_path)
assert hasattr(mod, name), f"{module_path} is missing attribute {name}"
Loading