-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathmain_ppo.py
More file actions
200 lines (165 loc) · 8.28 KB
/
main_ppo.py
File metadata and controls
200 lines (165 loc) · 8.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note that this entry point is deliberately kept separate from ray_trainer, because ray_trainer is also used by other main scripts.
"""
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
import os
import ray
import hydra
def get_custom_reward_fn(config):
    """Load the user-supplied reward function described by ``config``.

    Reads the ``custom_reward_function`` section of ``config`` (keys ``path``,
    ``name`` and the optional ``reward_kwargs``), executes the Python file at
    ``path`` as a module registered under ``sys.modules["custom_module"]`` and
    wraps the attribute called ``name``.

    Args:
        config: Mapping-like object exposing ``get`` (a plain ``dict`` works;
            presumably an OmegaConf ``DictConfig`` in production).

    Returns:
        ``None`` when no ``path`` is configured. Otherwise a callable that
        forwards its positional and keyword arguments to the loaded function
        and adds the configured ``reward_kwargs`` as extra keyword arguments.

    Raises:
        FileNotFoundError: ``path`` does not exist.
        RuntimeError: the file cannot be imported (unsupported file type) or
            raises while being executed; the original error is chained as
            ``__cause__``.
        ValueError: no function ``name`` is configured.
        AttributeError: the loaded module does not define ``name``.
    """
    import importlib.util
    import sys

    reward_fn_config = config.get("custom_reward_function") or {}
    file_path = reward_fn_config.get("path")
    if not file_path:
        return None

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Reward function file '{file_path}' not found.")

    spec = importlib.util.spec_from_file_location("custom_module", file_path)
    # spec_from_file_location returns None when no loader matches the file
    # suffix; fail with a readable error instead of an AttributeError later.
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Error loading module from '{file_path}': unsupported file type")

    module = importlib.util.module_from_spec(spec)
    try:
        # Register before executing so the module can be found by name while
        # its own top-level code runs.
        sys.modules["custom_module"] = module
        spec.loader.exec_module(module)
    except Exception as e:
        # Do not leave a half-initialised module behind in sys.modules.
        sys.modules.pop("custom_module", None)
        raise RuntimeError(f"Error loading module from '{file_path}': {e}") from e

    function_name = reward_fn_config.get("name")
    if not function_name:
        raise ValueError(f"No reward function name configured for '{file_path}'.")
    if not hasattr(module, function_name):
        raise AttributeError(f"Reward function '{function_name}' not found in '{file_path}'.")

    print(f"using customized reward function '{function_name}' from '{file_path}'")
    raw_fn = getattr(module, function_name)

    # `or {}` tolerates an explicit null/None value for reward_kwargs.
    reward_kwargs = dict(reward_fn_config.get("reward_kwargs") or {})

    def wrapped_fn(*args, **kwargs):
        # A call-time keyword that is also present in reward_kwargs raises
        # TypeError (duplicate keyword argument), as in the original code.
        return raw_fn(*args, **kwargs, **reward_kwargs)

    return wrapped_fn
@hydra.main(
    config_path='./verl/verl/trainer/config',
    config_name='ppo_trainer',
    version_base=None,
)
def main(config):
    """Hydra entry point: pass the composed ``ppo_trainer`` config to ``run_ppo``."""
    run_ppo(config)
def run_ppo(config) -> None:
    """Run PPO training for ``config`` inside a remote ``TaskRunner``.

    Mirrors ``CUDA_VISIBLE_DEVICES`` into ``ENSURE_CUDA_VISIBLE_DEVICES``,
    initialises Ray when it is not already initialised, then creates a
    ``TaskRunner`` actor and waits (via ``ray.get``) for its ``run`` call.

    Args:
        config: Training configuration forwarded unchanged to
            ``TaskRunner.run``.
    """
    # TODO(linjunrong.ocss884): this ENV is left for resolving SGLang conflict with ray devices
    # isolation, will solve in the future
    visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES', '')
    os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = visible_devices

    if not ray.is_initialized():
        # this is for local ray cluster
        worker_env_vars = {
            'TOKENIZERS_PARALLELISM': 'true',
            'NCCL_DEBUG': 'WARN',
            'VLLM_LOGGING_LEVEL': 'WARN',
        }
        ray.init(runtime_env={'env_vars': worker_env_vars})

    task_runner = TaskRunner.remote()
    pending_run = task_runner.run.remote(config)
    ray.get(pending_run)
@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head
class TaskRunner:
    """Ray actor that assembles a ``RayPPOTrainer`` from a config and runs it."""

    def run(self, config):
        """Build every component of the PPO job from ``config`` and train.

        Steps, in order: print and resolve the config, copy the model locally,
        create tokenizer/processor, pick worker classes by strategy, register
        the roles on a single shared resource pool, construct the training and
        validation reward managers, then initialise workers and call ``fit``.

        Args:
            config: OmegaConf configuration (it is passed to
                ``OmegaConf.to_container`` / ``OmegaConf.resolve``).

        Raises:
            NotImplementedError: unknown actor strategy, reward-model strategy
                or reward manager name.
            AssertionError: actor and critic strategies differ.
        """
        from verl.utils.fs import copy_to_local
        from pprint import pprint
        from omegaconf import OmegaConf

        # print initial config; resolve=True will eval symbol values
        printable_config = OmegaConf.to_container(config, resolve=True)
        pprint(printable_config)
        OmegaConf.resolve(config)

        # download the checkpoint from hdfs
        model_dir = copy_to_local(config.actor_rollout_ref.model.path)

        # instantiate tokenizer
        from verl.utils import hf_tokenizer, hf_processor
        allow_remote_code = config.data.get('trust_remote_code', False)
        tokenizer = hf_tokenizer(model_dir, trust_remote_code=allow_remote_code)
        # used for multimodal LLM, could be none
        processor = hf_processor(model_dir, use_fast=True)

        # define worker classes
        actor_strategy = config.actor_rollout_ref.actor.strategy
        if actor_strategy not in ('fsdp', 'megatron'):
            raise NotImplementedError
        # actor and critic must share one strategy
        assert actor_strategy == config.critic.strategy
        if actor_strategy == 'fsdp':
            from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
            from verl.single_controller.ray import RayWorkerGroup as ray_worker_group_cls
        else:
            from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
            from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup as ray_worker_group_cls

        from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role

        role_worker_mapping = {
            Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
            Role.Critic: ray.remote(CriticWorker),
        }

        # A single pool with one entry (GPU count) per node.
        global_pool_id = 'global_pool'
        gpus_by_node = [config.trainer.n_gpus_per_node] * config.trainer.nnodes
        resource_pool_spec = {global_pool_id: gpus_by_node}

        # we should adopt a multi-source reward function here
        # - for rule-based rm, we directly call a reward score
        # - for model-based rm, we call a model
        # - for code related prompt, we send to a sandbox if there are test cases
        # - finally, we combine all the rewards together
        # - The reward type depends on the tag of the data
        if config.reward_model.enable:
            rm_strategy = config.reward_model.strategy
            if rm_strategy == 'fsdp':
                from verl.workers.fsdp_workers import RewardModelWorker
            elif rm_strategy == 'megatron':
                from verl.workers.megatron_workers import RewardModelWorker
            elif rm_strategy == 'verifier':
                from verifier import RewardModelWorker
            else:
                raise NotImplementedError
            role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)

        # use reference model
        wants_ref_policy = (config.algorithm.use_kl_in_reward
                            or config.actor_rollout_ref.actor.use_kl_loss)
        if wants_ref_policy:
            role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)

        # Every registered role is placed on the same global pool.
        mapping = {role: global_pool_id for role in role_worker_mapping}

        reward_manager_name = config.reward_model.get("reward_manager", "naive")
        if reward_manager_name == 'naive':
            from verl.workers.reward_manager import NaiveRewardManager as reward_manager_cls
        elif reward_manager_name == 'prime':
            from verl.workers.reward_manager import PrimeRewardManager as reward_manager_cls
        elif reward_manager_name == 'batch':
            from verl.workers.reward_manager import BatchRewardManager as reward_manager_cls
        elif reward_manager_name == 'dapo':
            from verl.workers.reward_manager import DAPORewardManager as reward_manager_cls
        else:
            raise NotImplementedError

        compute_score = get_custom_reward_fn(config)
        reward_kwargs = dict(config.reward_model.get("reward_kwargs", {}))

        def make_reward_manager(num_examine, **extra_kwargs):
            # Shared construction for the training and validation managers.
            return reward_manager_cls(tokenizer=tokenizer,
                                      num_examine=num_examine,
                                      compute_score=compute_score,
                                      reward_fn_key=config.data.reward_fn_key,
                                      **extra_kwargs)

        reward_fn = make_reward_manager(0, **reward_kwargs)

        # Note that we always use function-based RM for validation
        # (the validation manager is built without reward_kwargs)
        val_reward_fn = make_reward_manager(1)

        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec,
                                                    mapping=mapping)

        trainer = RayPPOTrainer(config=config,
                                tokenizer=tokenizer,
                                processor=processor,
                                role_worker_mapping=role_worker_mapping,
                                resource_pool_manager=resource_pool_manager,
                                ray_worker_group_cls=ray_worker_group_cls,
                                reward_fn=reward_fn,
                                val_reward_fn=val_reward_fn)
        trainer.init_workers()
        trainer.fit()
# Script entry point: only run training when executed directly, not on import.
if __name__ == '__main__':
    main()