Update ppo #10912

Merged
19 commits merged on Aug 7, 2025
74 changes: 74 additions & 0 deletions llm/alignment/rl/gsm8k_processor.py
@@ -0,0 +1,74 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the GSM8k dataset to parquet format
"""

import argparse
import os
import re

import datasets


def extract_solution(solution_str):
solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
assert solution is not None
final_solution = solution.group(0)
final_solution = final_solution.split("#### ")[1].replace(",", "")
return final_solution


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--local_dir", default="./gsm8k")

args = parser.parse_args()

data_source = "openai/gsm8k"

dataset = datasets.load_dataset(data_source, "main")

train_dataset = dataset["train"]
test_dataset = dataset["test"]

instruction_following = 'Let\'s think step by step and output the final answer after "####".'

# build the chat-formatted prompt ("src") and the extracted numeric answer ("tgt") for each example
def make_map_fn(split):
def process_fn(example, idx):
question_raw = "<|im_start|>user\n" + example.pop("question")

system_raw = (
"<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
)
question = system_raw + question_raw + " " + instruction_following + "<|im_end|>\n<|im_start|>assistant\n"

answer_raw = example.pop("answer")
solution = extract_solution(answer_raw)
data = {
"src": question,
"tgt": solution,
}
return data

return process_fn

train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)

local_dir = args.local_dir

train_dataset.to_json(os.path.join(local_dir, "train.jsonl"), orient="records", lines=True)
test_dataset.to_json(os.path.join(local_dir, "test.jsonl"), orient="records", lines=True)
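
For context, a minimal sketch of what this preprocessor emits (the answer strings below are hypothetical but follow the GSM8K "#### <answer>" convention handled by extract_solution, and the "src"/"tgt" fields match the records written above):

# Hypothetical usage of extract_solution (not part of this PR)
from gsm8k_processor import extract_solution  # assumes this file is importable

sample_answer = "She sold 48/2 = 24 clips in May.\nIn total: 48 + 24 = 72.\n#### 72"
assert extract_solution(sample_answer) == "72"
assert extract_solution("#### 1,080") == "1080"  # thousands separators are stripped
# Each JSONL record pairs the chat-formatted prompt with the numeric answer, e.g.
# {"src": "<|im_start|>system\n...<|im_start|>assistant\n", "tgt": "72"}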
88 changes: 47 additions & 41 deletions llm/alignment/rl/run_rl.py
@@ -42,6 +42,7 @@
from paddlenlp.transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForTokenClassification,
AutoTokenizer,
PretrainedConfig,
)
@@ -134,7 +135,6 @@ def create_actor_models(
)
if not training_args.autotuner_benchmark:
reference_model.set_state_dict(actor_model.state_dict())

actor_tokenizer = AutoTokenizer.from_pretrained(
model_args.actor_model_name_or_path,
model_max_length=data_args.max_length,
@@ -210,46 +210,43 @@ def create_critic_models(
data_args: DataArgument,
training_args: TrainingArguments,
common_config: Dict,
reward_model,
):
with timers_scope_runtimer("Critic model loading time"):
reward_model_config = reward_model.config
if model_args.critic_model_name_or_path is None:
model_args.critic_model_name_or_path = model_args.reward_model_name_or_path
critic_model = AutoModelForScore.from_config(
reward_model_config,
dtype=training_args.model_dtype,
score_type="critic",
do_normalize=False,
clip_range_value=training_args.clip_range_value,
**common_config,
critic_model_config = AutoConfig.from_pretrained(
model_args.critic_model_name_or_path,
tensor_parallel_output=training_args.tensor_parallel_output,
tensor_parallel_degree=training_args.tensor_parallel_degree,
tensor_parallel_rank=training_args.tensor_parallel_rank,
dtype=training_args.model_dtype,
recompute=training_args.critic_recompute,
recompute_granularity=model_args.critic_recompute_granularity,
recompute_use_reentrant=training_args.recompute_use_reentrant,
**common_config,
)
LlmMetaConfig.set_llm_config(critic_model_config, training_args)

critic_model_config.max_position_embeddings = data_args.max_length
critic_model_config.use_sparse_head_and_loss_fn = False
critic_model_config.num_labels = 1
critic_model_config.classifier_dropout = 0.0
critic_model_config.hidden_dropout = 0.0
logger.info(f"Loading Critic model with config:\n\t{critic_model_config}\n")

if not training_args.autotuner_benchmark:
critic_model = AutoModelForTokenClassification.from_pretrained(
model_args.critic_model_name_or_path,
config=critic_model_config,
)
if not training_args.autotuner_benchmark:
critic_model.set_state_dict(reward_model.state_dict())
else:
if not training_args.autotuner_benchmark:
critic_model = AutoModelForScore.from_pretrained(
model_args.critic_model_name_or_path,
config=reward_model_config,
score_type="critic",
do_normalize=False,
clip_range_value=training_args.clip_range_value,
**common_config,
)
else:
critic_model = AutoModelForScore.from_config(
reward_model_config,
score_type="critic",
do_normalize=False,
clip_range_value=training_args.clip_range_value,
**common_config,
)
critic_model = AutoModelForTokenClassification.from_config(
critic_model_config,
)

critic_tokenizer = AutoTokenizer.from_pretrained(
model_args.critic_model_name_or_path,
model_max_length=data_args.max_length,
padding_side="left",
tokenizer_alpha=model_args.reward_critic_tokenizer_alpha,
tokenizer_alpha=model_args.critic_tokenizer_alpha,
use_fast=True,
)
if critic_tokenizer.pad_token_id is None:
@@ -261,16 +258,16 @@ def create_critic_models(
if training_args.eval_mode == "single":
config.tensor_parallel_degree = -1
config.tensor_parallel_rank = 0
with timers_scope_runtimer("Reward critic eval model loading time"):
critic_eval_model = AutoModelForScore.from_config(config)
with timers_scope_runtimer("Critic eval model loading time"):
critic_eval_model = AutoModelForTokenClassification.from_config(config)
else:
critic_eval_model = None

return critic_model, critic_eval_model, critic_tokenizer
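# NOTE (illustration only, not part of this diff): building the critic with
# AutoModelForTokenClassification and num_labels = 1 gives one scalar output per
# token, which is what PPO's per-token value estimates need. A minimal sketch of
# that shape contract, with hypothetical sizes:
import paddle  # sketch only

hidden_states = paddle.randn([2, 8, 1024])      # [batch, seq_len, hidden_size]
value_head = paddle.nn.Linear(1024, 1)          # num_labels = 1
values = value_head(hidden_states).squeeze(-1)  # [batch, seq_len] per-token values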


def create_rl_dataset(data_args, training_args, tokenizer):
requires_label = True if training_args.use_rm_server else False
requires_label = True if training_args.use_rm_server or training_args.use_rule_reward else False
train_ds = RLHFDataset(
dataset_name_or_path=data_args.train_datasets,
tokenizer=tokenizer,
@@ -333,15 +330,16 @@ def main():
actor_model, actor_eval_model, reference_model, actor_tokenizer = create_actor_models(
model_args, data_args, training_args, common_config, reshard_controller
)

if not training_args.use_rm_server and model_args.reward_model_name_or_path is not None:
if training_args.use_rule_reward:
reward_model, reward_tokenizer = None, actor_tokenizer
elif not training_args.use_rm_server and model_args.reward_model_name_or_path is not None:
reward_model, reward_tokenizer = create_reward_models(model_args, data_args, training_args, common_config)
else:
reward_model, reward_tokenizer = model_args.reward_server, actor_tokenizer

if training_args.rl_algorithm == "ppo":
critic_model, critic_eval_model, critic_tokenizer = create_critic_models(
model_args, data_args, training_args, common_config, reward_model
model_args, data_args, training_args, common_config
)
else:
critic_model, critic_eval_model, critic_tokenizer = None, None, None
@@ -355,15 +353,23 @@ def main():
offload_tensor_to_cpu((reference_model, "freeze_model"))

if training_args.rl_algorithm == "ppo":
offload_tensor_to_cpu((reward_model, "freeze_model"))
if not training_args.use_rm_server and not training_args.use_rule_reward:
offload_tensor_to_cpu((reward_model, "freeze_model"))
if critic_eval_model is not None:
offload_tensor_to_cpu((critic_eval_model, "freeze_model"))

# NOTE(gongenlei): release memory_reserved_size to equal to memory_allocated_size
paddle.device.cuda.empty_cache()

def compute_metrics(eval_preds):
accuracy = (eval_preds.predictions == 3).astype("float32").mean().item()
'''
If use_rm_server is true, the score ranges from -3 to 3, and 3 is the only score counted as correct (format + result).
If the rule-based (regex) matching reward is used (use_rule_reward=True, currently implemented only for the gsm8k dataset), the score is either 0 or 1.
'''
if training_args.use_rule_reward:
accuracy = (eval_preds.predictions == 1).astype("float32").mean().item()
else:
accuracy = (eval_preds.predictions == 3).astype("float32").mean().item()
return {"accuracy": accuracy}
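# NOTE (illustration only, not part of this diff): the rule-based GSM8K reward itself
# is not shown in this changeset. Per the docstring above, use_rule_reward yields a
# 0/1 score; a hypothetical regex-based scorer consistent with gsm8k_processor.py
# could look like this (the function name and signature are assumptions):
import re  # required by the sketch below

def rule_based_gsm8k_reward(response: str, ground_truth: str) -> float:
    # Take the last "#### <number>" occurrence in the generated response.
    matches = re.findall(r"#### (\-?[0-9\.\,]+)", response)
    if not matches:
        return 0.0
    predicted = matches[-1].replace(",", "")
    # Score 1.0 only when the extracted number equals the ground-truth answer.
    return 1.0 if predicted == ground_truth.replace(",", "") else 0.0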

try:
Expand All @@ -389,7 +395,7 @@ def compute_metrics(eval_preds):
data_collator=partial(
collate_fn,
pad_token_id=actor_tokenizer.pad_token_id,
requires_label=True if training_args.use_rm_server else False,
requires_label=True if training_args.use_rm_server or training_args.use_rule_reward else False,
max_prompt_len=data_args.max_prompt_len if training_args.balance_batch else None,
), # NOTE: enforce prompt padding to max_prompt_len when using balance_batch
compute_metrics=compute_metrics, # TODO: only used for grpo (kk datasets)
131 changes: 131 additions & 0 deletions llm/config/qwen/ppo_argument.yaml
@@ -0,0 +1,131 @@
# RL algorithms
rl_algorithm: "ppo" # The reinforcement learning algorithm used, supported: "ppo", "grpo", "reinforce_plus_plus"

# models
actor_model_name_or_path: "Qwen/Qwen2.5-1.5B-Instruct" # The name or path of the actor model
reward_model_name_or_path: "" # The name or path of the reward model
critic_model_name_or_path: "Qwen/Qwen2.5-1.5B-Instruct" # The name or path of the critic model
use_rm_server: false # Whether to use the reward model server
reward_server: "http://127.0.0.1:8731" # The address of the reward model server
use_rule_reward: true # Whether to use the rule-based (regex) reward for the gsm8k dataset; if use_rule_reward is true, use_rm_server must be false

# logging
logging_dir: ppo-logs # Directory for logging
logging_steps: 1 # Number of steps between logging
output_dir: "qwen2.5-1.5b-gsm8k-ppo/checkpoints" # Directory for output ckpts
report_to: "wandb" # Supported reporting options: "all", "wandb", "tensorboard", "visualdl"(default), "none"
wandb_http_proxy: "http://agent.baidu.com:8188" # HTTP proxy for wandb
run_name: "qwen2.5-1.5b-gsm8k-ppo" # Name of the run

# data
train_datasets: "gsm8k/train.jsonl" # Path to the training dataset
eval_datasets: "gsm8k/test.jsonl" # Path to the evaluation dataset
prompt_key: "src" # Key for the prompt in the dataset
response_key: "tgt" # Key for the response in the dataset
dataloader_drop_last: true # Whether to drop the last incomplete batch in the DataLoader
balance_batch: true # Whether to balance batch size across dataset_world_size
use_remove_padding: true # Whether to remove padding tokens in the input

# distributed training args
tensor_parallel_degree: 2 # Degree of tensor parallelism
sequence_parallel: true # Whether to enable sequence parallelism
sharding_parallel_degree: -1 # Degree of sharding parallelism
sharding: "stage1" # Sharding strategy, e.g., "stage1" or "stage2"
sharding_parallel_config: "enable_release_grads" # Configuration for sharding parallelism
pipeline_parallel_degree: 1 # Degree of pipeline parallelism
virtual_pp_degree: 1 # Degree of virtual pipeline parallelism

# rollout args
max_prompt_len: 1024 # Maximum length of the prompt, exceeding which will be automatically truncated
max_dec_len: 512 # Maximum length of the response
min_dec_len: 32 # Minimum length of the response
top_p: 1.0 # Top-p sampling parameter
temperature: 1.0 # Temperature parameter for sampling
repetition_penalty: 1.0 # Repetition penalty parameter
rollout_max_num_seqs: 1024 # The maximum number of sequences that can be processed in a single inference
rollout_quant_type: "" # Quantization type, e.g., "weight_only_int8"

# training args
do_train: true # Whether to perform training
seed: 42 # Random seed for reproducibility
global_batch_size: 256 # Global batch size for training (rollouts = rollout_n * global_batch_size)
global_gen_batch_size: -1 # Global generation batch size for dynamic sampling
global_mini_batch_size: 64 # Mini-batch size for training, default = (global_batch_size * rollout_n * update_iters) // dataset_world_size
rollout_n: 1 # Number of rollouts, set rollout_n = 1 for 'ppo'
update_iters: 1 # Number of training iterations for rollout samples
per_device_logprob_batch_size: 4 # Log probability batch size per device
per_device_reward_batch_size: 2 # Reward batch size per device
per_device_value_batch_size: 2 # Value batch size per device
per_device_train_batch_size: 2 # Training micro batch size per device
# gradient_accumulation_steps: 4 # Gradient accumulation steps (auto-calculated): global_bz * rollout_n *
num_train_epochs: 5 # Number of training epochs
max_length: 2048 # Maximum length for training, should be larger than max_prompt_len + max_dec_len
adam_beta1: 0.9 # AdamW optimizer beta1
adam_beta2: 0.999 # AdamW optimizer beta2
adam_epsilon: 1e-8 # AdamW optimizer epsilon
max_grad_norm: 1.0 # Maximum gradient norm for clipping
max_steps: -1 # Maximum number of training steps
save_steps: 300 # Number of steps between model saves
save_strategy: "steps" # Strategy for saving models
ignore_save_lr_and_optim: true # Whether to skip saving the learning rate and optimizer state
disable_tqdm: true # Whether to disable tqdm progress bar

# actor training args
learning_rate: 1e-6 # Learning rate for training
min_learning_rate: 1e-6 # Minimum learning rate
lr_scheduler_type: "constant" # Learning rate scheduler type
weight_decay: 1e-2 # Weight decay for the AdamW optimizer
warmup_ratio: 0.0 # Warmup ratio for the learning rate schedule

# critic training args
critic_learning_rate: 1e-5 # Learning rate for critic model
critic_min_learning_rate: 1e-5 # Minimum learning rate for critic model
critic_lr_scheduler_type: "constant" # Learning rate scheduler type for critic model
critic_weight_decay: 1e-2 # Weight decay for the AdamW optimizer of critic model
critic_warmup_ratio: 0.0 # Warmup ratio for the critic model's learning rate schedule

# RL args
kl_coeff: 0.0 # KL coefficient
kl_loss_coeff: 0.001 # KL loss coefficient
pg_loss_coeff: 1.0 # Policy gradient loss coefficient
entropy_coeff: 0.001 # Entropy coefficient
clip_range_ratio: 0.2 # The clipping range for the ratio between the new and old policy (PPO algorithm)
clip_range_ratio_low: 0.2 # Lower-bound offset for ratio clipping: the ratio is clipped below at 1 - clip_range_ratio_low
clip_range_ratio_high: 0.2 # Upper-bound offset for ratio clipping: the ratio is clipped above at 1 + clip_range_ratio_high
clip_range_score: 10.0 # The clipping range for the output of the score model. The reward is clipped into [-clip_range_score, clip_range_score].
enable_overlong_reward_buffer: false # Whether to enable overlong reward buffer
overlong_reward_buffer: 256 # The length of the overlong reward buffer
overlong_penalty_factor: 1.0 # The penalty factor for overlong reward buffer
clip_range_value: 0.5 # The clipping range for the output of the value model. The value is clipped into [-clip_range_value, clip_range_value].
normalize_reward: false # Whether to normalize reward
normalize_advantage: false # Whether to normalize advantage
dynamic_sampling: false # Whether to use dynamic sampling, introduced in the DAPO algorithm (https://arxiv.org/abs/2503.14476)
max_gen_batches: 2 # Maximum number of generation batches for dynamic sampling
use_fp32_compute: true # Whether to use fp32 to compute log probs, rewards, advantages, and the loss

# eval args
do_eval: true # Whether to perform evaluation
per_device_eval_batch_size: 1319 # Evaluation batch size per device
evaluation_strategy: "steps" # Evaluation strategy, e.g., "steps"
eval_steps: 10 # Number of steps between evaluations

# device memory optimization args
use_flash_attention: true # Whether to use fused attention operations
use_fused_rms_norm: true # Whether to use fused RMS norm operations, which requires installing fused_ln from slm/model_zoo/gpt-3/external_ops
use_fused_rope: false # Whether to use fused rope operations
use_fused_head_and_loss_fn: true # Whether to use fused head and loss function
use_fused_linear: true # Whether to use fused linear operations (appears to be an unused parameter)
recompute: false # Whether to enable gradient checkpointing for memory optimization
recompute_use_reentrant: false # Whether to use reentrant recompute
recompute_granularity: "full" # Granularity of recompute
bf16: true # Whether to use mixed precision with bfloat16
fp16_opt_level: "O2" # Optimization level for fp16 and bf16 training
amp_master_grad: false # Whether to use float32 weight gradients for master weights in amp opt level='O2'
amp_custom_black_list: ["reduce_sum", "softmax_with_cross_entropy", "c_softmax_with_cross_entropy", "elementwise_div", "sin", "cos"] # Custom black list for amp
amp_custom_white_list: ["lookup_table", "lookup_table_v2", "flash_attn", "matmul", "matmul_v2", "fused_gemm_epilogue"] # Custom white list for amp
offload_level: "freeze_model" # Level of model offloading to pinned memory, supported values: freeze_model, train_model, optimizer
release_grads: true # Whether to release gradients
offload_optim: false # Whether to offload optimizer to pinned memory

# benchmark args
skip_profile_timer: false # Whether to skip profiling time
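
The clip_range_ratio_low / clip_range_ratio_high entries above parameterize the (possibly asymmetric) PPO ratio clipping. For reference only, a minimal sketch of the clipped surrogate policy loss they control, written in Paddle with hypothetical tensor names (log_prob_new, log_prob_old, advantages); this is not code from this repository:

import paddle

def clipped_ppo_policy_loss(log_prob_new, log_prob_old, advantages,
                            clip_range_ratio_low=0.2, clip_range_ratio_high=0.2):
    # Importance ratio between the new and old policy for each sampled token.
    ratio = paddle.exp(log_prob_new - log_prob_old)
    # Clip into [1 - low, 1 + high]; with low == high this is standard symmetric PPO clipping.
    clipped = paddle.clip(ratio, min=1.0 - clip_range_ratio_low, max=1.0 + clip_range_ratio_high)
    # Pessimistic (min) combination of unclipped and clipped surrogates, negated for minimization.
    return -paddle.minimum(ratio * advantages, clipped * advantages).mean()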