diff --git a/llm/alignment/rl/gsm8k_processor.py b/llm/alignment/rl/gsm8k_processor.py new file mode 100644 index 000000000000..7c9674719b18 --- /dev/null +++ b/llm/alignment/rl/gsm8k_processor.py @@ -0,0 +1,74 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Preprocess the GSM8k dataset to JSONL format +""" + +import argparse +import os +import re + +import datasets + + +def extract_solution(solution_str): + solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str) + assert solution is not None + final_solution = solution.group(0) + final_solution = final_solution.split("#### ")[1].replace(",", "") + return final_solution + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--local_dir", default="./gsm8k") + + args = parser.parse_args() + + data_source = "openai/gsm8k" + + dataset = datasets.load_dataset(data_source, "main") + + train_dataset = dataset["train"] + test_dataset = dataset["test"] + + instruction_following = 'Let\'s think step by step and output the final answer after "####".' + + # build the chat-formatted prompt ("src") and extract the reference answer ("tgt") for each example + def make_map_fn(split): + def process_fn(example, idx): + question_raw = "<|im_start|>user\n" + example.pop("question") + + system_raw = ( + "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.<|im_end|>\n" + ) + question = system_raw + question_raw + " " + instruction_following + "<|im_end|>\n<|im_start|>assistant\n" + + answer_raw = example.pop("answer") + solution = extract_solution(answer_raw) + data = { + "src": question, + "tgt": solution, + } + return data + + return process_fn + + train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True) + test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True) + + local_dir = args.local_dir + + train_dataset.to_json(os.path.join(local_dir, "train.jsonl"), orient="records", lines=True) + test_dataset.to_json(os.path.join(local_dir, "test.jsonl"), orient="records", lines=True) diff --git a/llm/alignment/rl/run_rl.py b/llm/alignment/rl/run_rl.py index 122faf116256..6983dd8dfd7b 100644 --- a/llm/alignment/rl/run_rl.py +++ b/llm/alignment/rl/run_rl.py @@ -42,6 +42,7 @@ from paddlenlp.transformers import ( AutoConfig, AutoModelForCausalLM, + AutoModelForTokenClassification, AutoTokenizer, PretrainedConfig, ) @@ -134,7 +135,6 @@ def create_actor_models( ) if not training_args.autotuner_benchmark: reference_model.set_state_dict(actor_model.state_dict()) - actor_tokenizer = AutoTokenizer.from_pretrained( model_args.actor_model_name_or_path, model_max_length=data_args.max_length, @@ -210,46 +210,43 @@ def create_critic_models( data_args: DataArgument, training_args: TrainingArguments, common_config: Dict, - reward_model, ): with timers_scope_runtimer("Critic model loading time"): - reward_model_config = reward_model.config - if model_args.critic_model_name_or_path is None: - model_args.critic_model_name_or_path = model_args.reward_model_name_or_path - critic_model = AutoModelForScore.from_config( - reward_model_config, - dtype=training_args.model_dtype, - score_type="critic", - do_normalize=False, - clip_range_value=training_args.clip_range_value, - **common_config, + critic_model_config = AutoConfig.from_pretrained( + model_args.critic_model_name_or_path, + tensor_parallel_output=training_args.tensor_parallel_output, + tensor_parallel_degree=training_args.tensor_parallel_degree, + tensor_parallel_rank=training_args.tensor_parallel_rank, + dtype=training_args.model_dtype, + recompute=training_args.critic_recompute, + recompute_granularity=model_args.critic_recompute_granularity, + recompute_use_reentrant=training_args.recompute_use_reentrant, + **common_config, + ) + LlmMetaConfig.set_llm_config(critic_model_config, training_args) + + critic_model_config.max_position_embeddings = data_args.max_length + critic_model_config.use_sparse_head_and_loss_fn = False + critic_model_config.num_labels = 1 + critic_model_config.classifier_dropout = 0.0 + critic_model_config.hidden_dropout = 0.0 + logger.info(f"Loading Critic model with config:\n\t{critic_model_config}\n") + + if not training_args.autotuner_benchmark: + critic_model = AutoModelForTokenClassification.from_pretrained( + model_args.critic_model_name_or_path, + config=critic_model_config, ) - if not training_args.autotuner_benchmark: - critic_model.set_state_dict(reward_model.state_dict()) else: - if not training_args.autotuner_benchmark: - critic_model = AutoModelForScore.from_pretrained( - model_args.critic_model_name_or_path, - config=reward_model_config, - score_type="critic", - do_normalize=False, - clip_range_value=training_args.clip_range_value, - **common_config, - ) - else: - critic_model = AutoModelForScore.from_config( - reward_model_config, - score_type="critic", - do_normalize=False, - 
clip_range_value=training_args.clip_range_value, - **common_config, - ) + critic_model = AutoModelForTokenClassification.from_config( + critic_model_config, + ) critic_tokenizer = AutoTokenizer.from_pretrained( model_args.critic_model_name_or_path, model_max_length=data_args.max_length, padding_side="left", - tokenizer_alpha=model_args.reward_critic_tokenizer_alpha, + tokenizer_alpha=model_args.critic_tokenizer_alpha, use_fast=True, ) if critic_tokenizer.pad_token_id is None: @@ -261,8 +258,8 @@ def create_critic_models( if training_args.eval_mode == "single": config.tensor_parallel_degree = -1 config.tensor_parallel_rank = 0 - with timers_scope_runtimer("Reward critic eval model loading time"): - critic_eval_model = AutoModelForScore.from_config(config) + with timers_scope_runtimer("Critic eval model loading time"): + critic_eval_model = AutoModelForTokenClassification.from_config(config) else: critic_eval_model = None @@ -270,7 +267,7 @@ def create_critic_models( def create_rl_dataset(data_args, training_args, tokenizer): - requires_label = True if training_args.use_rm_server else False + requires_label = True if training_args.use_rm_server or training_args.use_rule_reward else False train_ds = RLHFDataset( dataset_name_or_path=data_args.train_datasets, tokenizer=tokenizer, @@ -333,15 +330,16 @@ def main(): actor_model, actor_eval_model, reference_model, actor_tokenizer = create_actor_models( model_args, data_args, training_args, common_config, reshard_controller ) - - if not training_args.use_rm_server and model_args.reward_model_name_or_path is not None: + if training_args.use_rule_reward: + reward_model, reward_tokenizer = None, actor_tokenizer + elif not training_args.use_rm_server and model_args.reward_model_name_or_path is not None: reward_model, reward_tokenizer = create_reward_models(model_args, data_args, training_args, common_config) else: reward_model, reward_tokenizer = model_args.reward_server, actor_tokenizer if training_args.rl_algorithm == "ppo": critic_model, critic_eval_model, critic_tokenizer = create_critic_models( - model_args, data_args, training_args, common_config, reward_model + model_args, data_args, training_args, common_config ) else: critic_model, critic_eval_model, critic_tokenizer = None, None, None @@ -355,7 +353,8 @@ def main(): offload_tensor_to_cpu((reference_model, "freeze_model")) if training_args.rl_algorithm == "ppo": - offload_tensor_to_cpu((reward_model, "freeze_model")) + if not training_args.use_rm_server and not training_args.use_rule_reward: + offload_tensor_to_cpu((reward_model, "freeze_model")) if critic_eval_model is not None: offload_tensor_to_cpu((critic_eval_model, "freeze_model")) @@ -363,7 +362,14 @@ def main(): paddle.device.cuda.empty_cache() def compute_metrics(eval_preds): - accuracy = (eval_preds.predictions == 3).astype("float32").mean().item() + ''' + If use_rm_server is true, the reward server returns scores in the range -3 to 3, and only a score of 3 means both the format and the result are correct. + If the rule-based matching reward is used (use_rule_reward=True, currently implemented only for the gsm8k dataset), the score is either 0 or 1. 
+ ''' + if training_args.use_rule_reward: + accuracy = (eval_preds.predictions == 1).astype("float32").mean().item() + else: + accuracy = (eval_preds.predictions == 3).astype("float32").mean().item() return {"accuracy": accuracy} try: @@ -389,7 +395,7 @@ def compute_metrics(eval_preds): data_collator=partial( collate_fn, pad_token_id=actor_tokenizer.pad_token_id, - requires_label=True if training_args.use_rm_server else False, + requires_label=True if training_args.use_rm_server or training_args.use_rule_reward else False, max_prompt_len=data_args.max_prompt_len if training_args.balance_batch else None, ), # NOTE: enforce prompt padding to max_prompt_len when using balance_batch compute_metrics=compute_metrics, # TODO: only used for grpo (kk datasets) diff --git a/llm/config/qwen/ppo_argument.yaml b/llm/config/qwen/ppo_argument.yaml new file mode 100644 index 000000000000..e82e4316290c --- /dev/null +++ b/llm/config/qwen/ppo_argument.yaml @@ -0,0 +1,131 @@ +# RL algorithms +rl_algorithm: "ppo" # The reinforcement learning algorithm used, supported: "ppo", "grpo", "reinforce_plus_plus" + +# models +actor_model_name_or_path: "Qwen/Qwen2.5-1.5B-Instruct" # The name or path of the actor model +reward_model_name_or_path: "" # The name or path of the reward model +critic_model_name_or_path: "Qwen/Qwen2.5-1.5B-Instruct" # The name or path of the critic model +use_rm_server: false # Whether to use the reward model server +reward_server: "http://127.0.0.1:8731" # The address of the reward model server +use_rule_reward: true # Whether to use the rule-based reward (currently for the gsm8k dataset); requires use_rm_server: false + +# logging +logging_dir: ppo-logs # Directory for logging +logging_steps: 1 # Number of steps between logging +output_dir: "qwen2.5-1.5b-gsm8k-ppo/checkpoints" # Directory for output checkpoints +report_to: "wandb" # Supported reporting options: "all", "wandb", "tensorboard", "visualdl" (default), "none" +wandb_http_proxy: "http://agent.baidu.com:8188" # HTTP proxy for wandb +run_name: "qwen2.5-1.5b-gsm8k-ppo" # Name of the run + +# data +train_datasets: "gsm8k/train.jsonl" # Path to the training dataset +eval_datasets: "gsm8k/test.jsonl" # Path to the evaluation dataset +prompt_key: "src" # Key for the prompt in the dataset +response_key: "tgt" # Key for the response in the dataset +dataloader_drop_last: true # Whether to drop the last incomplete batch in the DataLoader +balance_batch: true # Whether to balance batch size across dataset_world_size +use_remove_padding: true # Whether to remove padding tokens in the input + +# distributed training args +tensor_parallel_degree: 2 # Degree of tensor parallelism +sequence_parallel: true # Whether to enable sequence parallelism +sharding_parallel_degree: -1 # Degree of sharding parallelism +sharding: "stage1" # Sharding strategy, e.g., "stage1" or "stage2" +sharding_parallel_config: "enable_release_grads" # Configuration for sharding parallelism +pipeline_parallel_degree: 1 # Degree of pipeline parallelism +virtual_pp_degree: 1 # Degree of virtual pipeline parallelism + +# rollout args +max_prompt_len: 1024 # Maximum length of the prompt, exceeding which will be automatically truncated +max_dec_len: 512 # Maximum length of the response +min_dec_len: 32 # Minimum length of the response +top_p: 1.0 # Top-p sampling parameter +temperature: 1.0 # Temperature parameter for sampling +repetition_penalty: 1.0 # Repetition penalty parameter +rollout_max_num_seqs: 1024 # The maximum number of sequences that can be processed in a single inference +rollout_quant_type: "" # 
Quantization type, e.g., "weight_only_int8" + +# training args +do_train: true # Whether to perform training +seed: 42 # Random seed for reproducibility +global_batch_size: 256 # Global batch size for training (rollouts = rollout_n * global_batch_size) +global_gen_batch_size: -1 # Global generation batch size for dynamic sampling +global_mini_batch_size: 64 # Mini-batch size for training, default = (global_batch_size * rollout_n * update_iters) // dataset_world_size +rollout_n: 1 # Number of rollouts, set rollout_n = 1 for 'ppo' +update_iters: 1 # Number of training iterations for rollout samples +per_device_logprob_batch_size: 4 # Log probability batch size per device +per_device_reward_batch_size: 2 # Reward batch size per device +per_device_value_batch_size: 2 # Value batch size per device +per_device_train_batch_size: 2 # Training micro batch size per device +# gradient_accumulation_steps: 4 # Gradient accumulation steps (auto-calculated): global_bz * rollout_n * +num_train_epochs: 5 # Number of training epochs +max_length: 2048 # Maximum length for training, should be larger than max_prompt_len + max_dec_len +adam_beta1: 0.9 # AdamW optimizer beta1 +adam_beta2: 0.999 # AdamW optimizer beta2 +adam_epsilon: 1e-8 # AdamW optimizer epsilon +max_grad_norm: 1.0 # Maximum gradient norm for clipping +max_steps: -1 # Maximum number of training steps +save_steps: 300 # Number of steps between model saves +save_strategy: "steps" # Strategy for saving models +ignore_save_lr_and_optim: true # Whether to skip saving the learning rate and optimizer state +disable_tqdm: true # Whether to disable tqdm progress bar + +# actor training args +learning_rate: 1e-6 # Learning rate for training +min_learning_rate: 1e-6 # Minimum learning rate +lr_scheduler_type: "constant" # Learning rate scheduler type +weight_decay: 1e-2 # Weight decay for the AdamW optimizer +warmup_ratio: 0.0 # Ratio of total training steps used for learning rate warmup + +# critic training args +critic_learning_rate: 1e-5 # Learning rate for critic model +critic_min_learning_rate: 1e-5 # Minimum learning rate for critic model +critic_lr_scheduler_type: "constant" # Learning rate scheduler type for critic model +critic_weight_decay: 1e-2 # Weight decay for the AdamW optimizer of critic model +critic_warmup_ratio: 0.0 # Ratio of total training steps used for learning rate warmup of the critic model + +# RL args +kl_coeff: 0.0 # KL coefficient +kl_loss_coeff: 0.001 # KL loss coefficient +pg_loss_coeff: 1.0 # Policy gradient loss coefficient +entropy_coeff: 0.001 # Entropy coefficient +clip_range_ratio: 0.2 # The clipping range for ratio between the old and new policy. (PPO algorithm) +clip_range_ratio_low: 0.2 # The clipping range for ratio between the old and new policy. (PPO algorithm) +clip_range_ratio_high: 0.2 # The clipping range for ratio between the old and new policy. (PPO algorithm) +clip_range_score: 10.0 # The clipping range for the output of the score model. The reward is clipped into [-clip_range_score, clip_range_score]. +enable_overlong_reward_buffer: false # Whether to enable overlong reward buffer +overlong_reward_buffer: 256 # The length of the overlong reward buffer +overlong_penalty_factor: 1.0 # The penalty factor for overlong reward buffer +clip_range_value: 0.5 # The clipping range for the output of the value model. The value is clipped into [-clip_range_value, clip_range_value]. 
+normalize_reward: false # Whether to normalize reward +normalize_advantage: false # Whether to normalize advantage +dynamic_sampling: false # Whether to use dynamic sampling, which is introduced in the DAPO algorithm https://arxiv.org/abs/2503.14476 +max_gen_batches: 2 # Maximum number of generation batches for dynamic sampling +use_fp32_compute: true # Whether to use fp32 to compute log probs, rewards, advantages and the loss + +# eval args +do_eval: true # Whether to perform evaluation +per_device_eval_batch_size: 1319 # Evaluation batch size per device +evaluation_strategy: "steps" # Evaluation strategy, e.g., "steps" +eval_steps: 10 # Number of steps between evaluations + +# device memory optimization args +use_flash_attention: true # Whether to use fused attention operations +use_fused_rms_norm: true # Whether to use fused RMS norm operations, which needs to install fused_ln in slm/model_zoo/gpt-3/external_ops +use_fused_rope: false # Whether to use fused rope operations +use_fused_head_and_loss_fn: true # Whether to use fused head and loss function +use_fused_linear: true # Whether to use fused linear operations (this parameter appears to be unused) +recompute: false # Whether to enable gradient checkpointing for memory optimization +recompute_use_reentrant: false # Whether to use reentrant recompute +recompute_granularity: "full" # Granularity of recompute +bf16: true # Whether to use mixed precision with bfloat16 +fp16_opt_level: "O2" # Optimization level for fp16 and bf16 training +amp_master_grad: false # Whether to use float32 weight gradients for master weights in amp opt level='O2' +amp_custom_black_list: ["reduce_sum", "softmax_with_cross_entropy", "c_softmax_with_cross_entropy", "elementwise_div", "sin", "cos"] # Custom black list for amp +amp_custom_white_list: ["lookup_table", "lookup_table_v2", "flash_attn", "matmul", "matmul_v2", "fused_gemm_epilogue"] # Custom white list for amp +offload_level: "freeze_model" # Level of model offloading to pinned memory, supported values: freeze_model, train_model, optimizer +release_grads: true # Whether to release gradients +offload_optim: false # Whether to offload optimizer to pinned memory + +# benchmark args +skip_profile_timer: false # Whether to skip profiling time \ No newline at end of file diff --git a/paddlenlp/rl/algos/advantage.py b/paddlenlp/rl/algos/advantage.py index b0bc1891aba1..6985049d40d1 100644 --- a/paddlenlp/rl/algos/advantage.py +++ b/paddlenlp/rl/algos/advantage.py @@ -20,57 +20,32 @@ from ..utils.comm_utils import masked_whiten +@paddle.no_grad() def compute_gae_advantage_return( token_level_rewards: paddle.Tensor, values: paddle.Tensor, sequence_mask: paddle.Tensor, - start: int, gamma: paddle.Tensor, lam: paddle.Tensor, - use_tgt_len_return: bool = True, ) -> Tuple[paddle.Tensor, paddle.Tensor]: """Compute advantages and returns using Generalized Advantage Estimation (GAE).""" # Modified from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py + lastgaelam = 0.0 advantages_reversed = [] gen_len = token_level_rewards.shape[-1] values = values * sequence_mask token_level_rewards = token_level_rewards * sequence_mask - if use_tgt_len_return and start > 0: - # consistent with Beaver - # values length is src+tgt-1, start is src-1, return length is tgt - pass - elif use_tgt_len_return: - # values length is tgt, start is 0, return length is tgt - assert start == 0 - else: - # values length is src+tgt-1, start is src-1, return length is src+tgt-1 - pass - for t in reversed(range(start, gen_len)): # pylint: 
disable=invalid-name + + for t in reversed(range(0, gen_len)): next_values = values[:, t + 1] if t < gen_len - 1 else 0.0 delta = token_level_rewards[:, t] + gamma * next_values - values[:, t] lastgaelam = delta + gamma * lam * lastgaelam advantages_reversed.append(lastgaelam) advantages = paddle.stack(advantages_reversed[::-1], axis=1) - returns = advantages + values[:, start:].contiguous() - - if not use_tgt_len_return: - advantages = paddle.concat( - [ - paddle.zeros([advantages.shape[0], start], dtype=advantages.dtype), - advantages, - ], - axis=-1, - ) - returns = paddle.concat( - [ - paddle.zeros([returns.shape[0], start], dtype=returns.dtype), - returns, - ], - axis=-1, - ) + returns = advantages + values.contiguous() return advantages.detach(), returns @@ -158,6 +133,7 @@ def compute_reinforce_plus_plus_advantages_and_returns( return advantages, returns +@paddle.no_grad() def add_kl_divergence_regularization( prompt: paddle.Tensor, # size = (B, S) # pylint: disable=unused-argument log_probs: paddle.Tensor, # size = (B, L) diff --git a/paddlenlp/rl/models/ppo_model_utils.py b/paddlenlp/rl/models/ppo_model_utils.py index 50b39d3684e2..84e4ae4f2c03 100644 --- a/paddlenlp/rl/models/ppo_model_utils.py +++ b/paddlenlp/rl/models/ppo_model_utils.py @@ -599,7 +599,7 @@ def forward( @merge_fwd_labels class RLHFValueLoss(nn.Layer): - def __init__(self, config, clip_range_value=5.0): + def __init__(self, config, clip_range_value=5.0, use_fp32_compute=False): """ Initializes the `ClipRewardRange` object. @@ -617,6 +617,7 @@ def __init__(self, config, clip_range_value=5.0): super().__init__() self.clip_range_value = clip_range_value self.config = config + self.use_fp32_compute = use_fp32_compute def critic_loss_fn( self, @@ -636,54 +637,48 @@ def critic_loss_fn( vf_loss2 = paddle.square(values_clipped - returns) return 0.5 * paddle.sum(paddle.maximum(vf_loss1, vf_loss2) * mask) / mask.sum() - def forward(self, reward_values, old_reward_values, reward_returns, sequence_mask): - """ - 计算奖励值的损失函数。 - 如果输入的奖励值和旧奖励值的长度相同,则使用给定的序列掩码来确定有效长度。 - 如果输入的奖励值的长度比旧奖励值少一个,则将最后一个元素视为与输入IDs一致的填充,并删除它。 - 否则,奖励值只有tgt长度。 + def forward( + self, + reward_values, + old_reward_values, + reward_returns, + sequence_mask, + response_start=0, + # for varlen flaskmask + pad_size=0, + raw_input_ids=None, + indices=None, + raw_input_shape=None, + input_ids_rmpad_rolled=None, + ): + """ """ + reward_values = reward_values[0].squeeze(0) + if self.config.sequence_parallel: + from paddle.distributed.fleet.utils.sequence_parallel_utils import GatherOp - Args: - reward_values (paddle.Tensor, list of paddle.Tensor or None, optional): 奖励值,可以是单个张量或列表中的多个张量。默认为None。 - old_reward_values (paddle.Tensor, optional): 旧奖励值。 - reward_returns (paddle.Tensor, optional): 奖励返回值。 - sequence_mask (paddle.Tensor, optional): 序列掩码。 + reward_values = GatherOp.apply(reward_values) - Returns: - paddle.Tensor, float32: 奖励值的损失函数。 + use_remove_padding = indices is not None + if use_remove_padding: + if pad_size > 0: + reward_values = reward_values[:-pad_size, :] - Raises: - ValueError: 当奖励值和旧奖励值的长度不匹配时引发。 - """ - reward_values = reward_values if isinstance(reward_values, paddle.Tensor) else reward_values[0] - reward_values = reward_values.squeeze(axis=-1)[:, :-1] - if reward_values.shape[1] == old_reward_values.shape[1]: - # labels (old_reward_values, reward_returns, sequence_mask) has - # src+tgt-1 length, valid length is determined by sequence_mask - pass - elif reward_values.shape[1] < old_reward_values.shape[1]: - # labels (old_reward_values, 
reward_returns, sequence_mask) has - # src+tgt length and the last one is a padding to be consistent - # with input_ids - assert reward_values.shape[1] == old_reward_values.shape[1] - 1 - reward_values = paddle.concat( - [ - reward_values, - paddle.zeros([reward_values.shape[0], 1], dtype=reward_values.dtype), - ], - -1, - ) - else: - # labels (old_reward_values, reward_returns, sequence_mask) has - # tgt length - reward_values = reward_values[:, -old_reward_values.shape[1] :] + from ..utils.bert_padding import pad_input + + reward_values = pad_input( + reward_values.squeeze(0), indices, batch=raw_input_shape[0], seqlen=raw_input_shape[1] + ).contiguous() + + if self.use_fp32_compute and reward_values.dtype != paddle.float32: + reward_values = reward_values.cast(paddle.float32) + + reward_values = reward_values.squeeze(axis=-1)[:, response_start:-1] reward_critic_loss = self.critic_loss_fn( reward_values, old_reward_values, reward_returns, sequence_mask, ) - return reward_critic_loss diff --git a/paddlenlp/rl/models/reward_utils.py b/paddlenlp/rl/models/reward_utils.py new file mode 100644 index 000000000000..e0ed8524c92e --- /dev/null +++ b/paddlenlp/rl/models/reward_utils.py @@ -0,0 +1,41 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + + +def extract_solution(solution_str, method="strict"): + assert method in ["strict", "flexible"] + + if method == "strict": + # this also tests the formatting of the model + solutions = re.findall("#### (\\-?[0-9\\.\\,]+)", solution_str) + if len(solutions) == 0: + final_answer = None + else: + # take the last solution + final_answer = solutions[-1].replace(",", "").replace("$", "") + elif method == "flexible": + answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str) + final_answer = None + if len(answer) == 0: + # no reward is there is no answer + pass + else: + invalid_str = ["", "."] + # find the last number that is not '.' + for final_answer in reversed(answer): + if final_answer not in invalid_str: + break + return final_answer diff --git a/paddlenlp/rl/models/score_model.py b/paddlenlp/rl/models/score_model.py index 841525bc2bc3..c1b36a993419 100644 --- a/paddlenlp/rl/models/score_model.py +++ b/paddlenlp/rl/models/score_model.py @@ -18,6 +18,7 @@ import paddle from paddle import nn +from paddle.distributed.fleet.utils.sequence_parallel_utils import GatherOp from ...transformers import PretrainedConfig, PretrainedModel from ...transformers.auto import AutoModel @@ -40,10 +41,35 @@ def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None: TypeError: If the config is not an instance of `PretrainedConfig`. 
""" super().__init__(config) - self.model = AutoModel.from_config(config) - # config.architectures = [self.__class__.__name__] - self.init_score_head(config, hidden_size=config.hidden_size, **kwargs) + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs): + config = kwargs.pop("config", None) + model = cls(config, **kwargs) + model.config = config + model.model = AutoModel.from_pretrained(pretrained_model_name_or_path, config=config) + model.init_score_head(config, hidden_size=config.hidden_size, **kwargs) + model.head_init_weights() + return model + + def head_init_weights(self): + self.score_head.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.config.initializer_range, + shape=self.score_head.weight.shape, + ) + ) + if hasattr(self.score_head, "bias") and isinstance(self.score_head.bias, paddle.Tensor): + self.score_head.bias.set_value(paddle.zeros_like(self.score_head.bias)) + + @classmethod + def from_config(cls, config, **kwargs): + model = cls(config, **kwargs) + model.model = AutoModel.from_config(config) + model.config = config + model.init_score_head(config, hidden_size=config.hidden_size, **kwargs) + return model def get_input_embeddings(self) -> Optional[nn.Embedding]: """ @@ -91,7 +117,6 @@ def set_decoder(self, decoder: PretrainedModel) -> None: def forward( self, input_ids: paddle.Tensor, - attention_mask: paddle.Tensor, position_ids: paddle.Tensor | None = None, past_key_values: list[paddle.Tensor] | None = None, inputs_embeds: paddle.Tensor | None = None, @@ -99,6 +124,8 @@ def forward( output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, + attn_mask_startend_row_indices: paddle.Tensor | None = None, + **kwargs ) -> tuple[paddle.Tensor, paddle.Tensor] | ScoreModelOutput: """ Forward pass of the sentence. @@ -130,7 +157,7 @@ def forward( AssertionError: Raised when `attention_mask` is not None. 
""" - assert attention_mask is not None + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -139,19 +166,26 @@ def forward( outputs = self.model( input_ids=input_ids, - attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, return_dict=return_dict, ) - hidden_states = outputs[0] # size = (B, L, E) + critic_hidden_states = outputs[0] # size = (B, L, E) + + if self.config.sequence_parallel: + gather_hidden_states = GatherOp.apply(critic_hidden_states) + + hidden_states = paddle.reshape_( + gather_hidden_states, [-1, position_ids.shape[1], gather_hidden_states.shape[-1]] + ) + return self.get_score( hidden_states, - attention_mask=attention_mask, position_ids=position_ids, return_dict=return_dict, ) diff --git a/paddlenlp/rl/models/score_model_utils.py b/paddlenlp/rl/models/score_model_utils.py index 73d8b901cc26..eaedc778bdf3 100644 --- a/paddlenlp/rl/models/score_model_utils.py +++ b/paddlenlp/rl/models/score_model_utils.py @@ -53,6 +53,7 @@ class ScoreModelMixin: normalize_function: NormalizeFunction = "affine" _initialized: bool = False + @classmethod def init_score_head(self, config: PretrainedConfig, hidden_size: int, **kwargs: Any) -> None: """Initialize the score head.""" if self._initialized: @@ -113,11 +114,13 @@ def init_score_head(self, config: PretrainedConfig, hidden_size: int, **kwargs: def get_score( self, hidden_state: paddle.Tensor, # size = (B, L, E) - attention_mask: paddle.Tensor | None = None, # size = (B, L) position_ids: paddle.Tensor | None = None, # size = (B, L) + attn_mask_startend_row_indices: paddle.Tensor | None = None, # size = (B, 1), (B, 2), (B, 3) or (B, 4) return_dict: bool | None = None, + attention_mask: paddle.Tensor | None = None, # size = (B, L) ) -> ScoreModelOutput: """Forward pass of the score model.""" + hidden_state = hidden_state.cast(paddle.float32) scores = self.score_head(hidden_state) # size = (B, L, D) if scores.dtype != hidden_state.dtype: # EB rm cast to float32 scores = scores.cast(hidden_state.dtype) @@ -170,7 +173,6 @@ def get_score( if self.do_normalize: scores = self.normalizer.normalize(scores) - end_score = self.normalizer.normalize(end_score) if not return_dict: return scores, end_score @@ -180,6 +182,8 @@ def get_score( end_scores=end_score, # size = (B, D) ) + return scores + def set_normalize(self, mode: bool = True) -> None: """ Set whether to normalize the input data, default is True. 
diff --git a/paddlenlp/rl/trainer/actor_trainer.py b/paddlenlp/rl/trainer/actor_trainer.py index 9d7938fe064a..c1e03b27939b 100644 --- a/paddlenlp/rl/trainer/actor_trainer.py +++ b/paddlenlp/rl/trainer/actor_trainer.py @@ -164,6 +164,7 @@ def compute_logprob(self, batch: DataProto, key) -> DataProto: {key: paddle.concat(log_probs_list, axis=0)}, meta_info={"temperature": self.args.temperature} ) + @paddle.no_grad() def compute_fused_logprob( self, input_ids: paddle.Tensor, key, position_ids: paddle.Tensor = None, prompt=None, loop_chunk_size=1024 ) -> DataProto: @@ -421,7 +422,7 @@ def generate_sequences(self, prompt_only_batch: DataProto, do_eval=False) -> Lis if repeat_num > 1: input_ids = input_ids.repeat_interleave(repeat_num, axis=0) - if self.args.use_rm_server: + if self.args.use_rm_server or self.args.use_rule_reward: label_ids = prompt_only_batch.batch["label_ids"] if repeat_num > 1: label_ids = label_ids.repeat_interleave(repeat_num, axis=0) @@ -437,7 +438,7 @@ def generate_sequences(self, prompt_only_batch: DataProto, do_eval=False) -> Lis "input_ids": seq, **( {"label_ids": label_ids[idx * len(seq) : (idx + 1) * len(seq)]} - if self.args.use_rm_server + if self.args.use_rm_server or self.args.use_rule_reward else {} ), # tgt response "index": np.array([str(uuid.uuid4())] * len(seq), dtype=object), diff --git a/paddlenlp/rl/trainer/critic_trainer.py b/paddlenlp/rl/trainer/critic_trainer.py index 5574bb3493ca..e803e0e42bc8 100644 --- a/paddlenlp/rl/trainer/critic_trainer.py +++ b/paddlenlp/rl/trainer/critic_trainer.py @@ -15,12 +15,15 @@ from __future__ import annotations +import paddle + from ...datasets.rlhf_datasets.protocol import DataProto from ...transformers import PretrainedTokenizer from ..models.ppo_model_utils import RLHFValueLoss, create_startend_row_indices -from ..utils.comm_utils import CriticStages -from ..utils.offload_utils import reload_and_offload_scope -from ..utils.timer_utils import TimerScope + +# from ..utils.comm_utils import CriticStages +# from ..utils.offload_utils import reload_and_offload_scope +# from ..utils.timer_utils import TimerScope from .rl_trainer import RLTrainer @@ -30,26 +33,77 @@ class CriticTrainer(RLTrainer): # define loss name for logging loss_identifier = lambda self, inputs: "reward_critic_loss" + @paddle.no_grad() def compute_value( self, batch: DataProto, input_ids_tokenizer: PretrainedTokenizer = None, ) -> DataProto: + self.model.eval() + input_ids = batch.batch["input_ids"] position_ids = batch.batch["position_ids"] - # TODO: confirm actor_tokenizer or reward_tokenizer or critic_tokenizer - # need retokenize? 
- attn_mask_startend_row_indices = create_startend_row_indices(input_ids, self.tokenizer.pad_token_id) - reward_value = self.model( - input_ids, - attention_mask=None, - position_ids=position_ids, - attn_mask_startend_row_indices=attn_mask_startend_row_indices, - )[0] - reward_value = reward_value.squeeze(axis=-1) - reward_value = reward_value[:, :-1] - return DataProto.from_single_dict({"reward_value": reward_value}) + values_list = [] + batch_size, sequence_length = input_ids.shape + per_device_value_batch_size = self.args.per_device_value_batch_size + num_batches = (batch_size + per_device_value_batch_size - 1) // per_device_value_batch_size + startend_row_indices = create_startend_row_indices(input_ids, self.tokenizer.pad_token_id) + response_start = batch.batch["prompt"].shape[-1] - 1 if "prompt" in batch.batch else 0 + for i in range(num_batches): + start_index = i * per_device_value_batch_size + end_index = min(start_index + per_device_value_batch_size, batch_size) + + # Extract the current batch + current_input_ids = input_ids[start_index:end_index] + current_startend_row_indices = ( + startend_row_indices[start_index:end_index] if startend_row_indices is not None else None + ) + current_position_ids = position_ids[start_index:end_index] if position_ids is not None else None + if self.args.use_remove_padding: + from ..utils.bert_padding import prepare_flashmask_inputs + + update_inputs = prepare_flashmask_inputs( + current_input_ids, + current_position_ids, + self.tokenizer.pad_token_id, + self.model.config.sequence_parallel, + self.model.config.tensor_parallel_degree, + ) + current_input_ids = update_inputs["input_ids"] + current_position_ids = update_inputs["position_ids"] + current_startend_row_indices = update_inputs["attn_mask_startend_row_indices"] + indices = update_inputs["indices"] + raw_input_shape = update_inputs["raw_input_shape"] + pad_size = update_inputs["pad_size"] + reward_value = self.model( + current_input_ids, + position_ids=current_position_ids, + attn_mask_startend_row_indices=current_startend_row_indices, + use_cache=False, + )[0] + reward_value = reward_value.squeeze(0) + if self.model.config.sequence_parallel: + from paddle.distributed.fleet.utils.sequence_parallel_utils import ( + GatherOp, + ) + + reward_value = GatherOp.apply(reward_value) + + if self.args.use_remove_padding: + from ..utils.bert_padding import pad_input + + if pad_size > 0: + reward_value = reward_value[:-pad_size, :] + reward_value = pad_input( + reward_value.squeeze(0).unsqueeze(-1), indices, batch=raw_input_shape[0], seqlen=raw_input_shape[1] + ).squeeze(-1) + reward_value = reward_value[:, response_start:-1].contiguous() + values_list.append(reward_value.squeeze(-1)) + reward_value = None + paddle.device.cuda.empty_cache() + + return DataProto.from_single_dict({"reward_values": paddle.concat(values_list, axis=0)}) def update_critic(self, rl_batch: DataProto) -> DataProto: """ @@ -66,30 +120,30 @@ def update_critic(self, rl_batch: DataProto) -> DataProto: Returns (Dict[str, Any]): - train_value_loss (float): Training loss of the critic (reward function). 
""" + self.model.train() # Inputs shared by policy and value trainer input_ids = rl_batch.batch["input_ids"].contiguous() # length: src+tgt - attention_mask = rl_batch.batch["attention_mask"] # length: src+tgt position_ids = rl_batch.batch["position_ids"] # length: src+tgt - sequence_mask = rl_batch.batch["sequence_mask"] # length: src+tgt(-1) + sequence_mask = rl_batch.batch["eos_mask"] # length: src+tgt(-1) + if self.args.use_fp32_compute and sequence_mask.dtype != paddle.float32: + sequence_mask = sequence_mask.cast(paddle.float32) # Inputs used by value trainer old_reward_values = rl_batch.batch["reward_values"] # length: src+tgt(-1) reward_returns = rl_batch.batch["reward_returns"] # length: src+tgt(-1) + attn_mask_startend_row_indices = create_startend_row_indices(input_ids, self.tokenizer.pad_token_id) value_trainer_inputs = { "input_ids": input_ids, - "attention_mask": attention_mask, "position_ids": position_ids, "old_reward_values": old_reward_values, "reward_returns": reward_returns, "sequence_mask": sequence_mask, + "response_start": rl_batch.batch["prompt"].shape[-1] - 1, + "attn_mask_startend_row_indices": attn_mask_startend_row_indices, } - with TimerScope( - self.timers, CriticStages.MODEL_ENABLE_DISABLE, minus_names=[CriticStages.CRITIC_TRAINING_STEP] - ): - with reload_and_offload_scope(self, self.model, self.optimizer): - with TimerScope(self.timers, CriticStages.CRITIC_TRAINING_STEP): - reward_critic_loss = self.full_training_step(**value_trainer_inputs) + reward_critic_loss = self.full_training_step(**value_trainer_inputs) - return DataProto(meta_info={"metrics": {"train_value_loss": reward_critic_loss}}) + # return DataProto(meta_info={"metrics": {"train_value_loss": reward_critic_loss}}) + return {"train_value_loss": reward_critic_loss} diff --git a/paddlenlp/rl/trainer/ppo_trainer.py b/paddlenlp/rl/trainer/ppo_trainer.py index f130257111da..2ccd738f44df 100644 --- a/paddlenlp/rl/trainer/ppo_trainer.py +++ b/paddlenlp/rl/trainer/ppo_trainer.py @@ -46,6 +46,7 @@ logger, speed_metrics, ) +from ...trainer.trainer_utils import SchedulerType from ...trainer.utils.helper import broadcast_dataset_rank0_model, distributed_concat from ...transformers import ( CosineAnnealingWithWarmupDecay, @@ -287,6 +288,37 @@ def __init__( preprocess_logits_for_metrics, ) + self.train_dataset = train_dataset + self.eval_dataset = eval_dataset + self._model_config = actor_model.config + self._actor_model_eval = actor_model_eval + self._critic_model_eval = critic_model_eval + + # ##### trainging data and related num setting ##### + with ( + guard_set_args( + args, + {"per_device_train_batch_size": self.args.global_gen_batch_size // self.args.dataset_world_size}, + ), + guard_set_args( + self, + {"train_dataset": self.train_dataset, "data_collator": self.data_collator}, + ), + ): + self.train_dataloader = self.prompt_only_dataloader = self.get_train_dataloader() # 64 + + ( + self.total_train_batch_size, + self.len_dataloader, + self.max_steps, + self.num_train_epochs, + self.num_update_steps_per_epoch, + self.num_examples_, # There is a problem with duplicate names + self.num_train_samples, + ) = self.init_train_num(self.train_dataloader) + + args.max_steps = self.max_steps + self.reshard_controller = reshard_controller trainer_agrs = { # "model": None, @@ -331,11 +363,6 @@ def __init__( **trainer_agrs, ) - self.train_dataset = train_dataset - self.eval_dataset = eval_dataset - self._model_config = actor_model.config - self._actor_model_eval = actor_model_eval - self._critic_model_eval = 
critic_model_eval self.reference_model.eval() if isinstance(reward_model, PretrainedModel): self.reward_model.eval() @@ -350,7 +377,9 @@ def __init__( self.kl_coeff = self.args.kl_coeff self.clip_range_score = self.args.clip_range_score self.gamma = 1.0 - self.gae_lambda = 0.95 + # [gae_lambda] value needs to be set manually! + # On the gsm8k benchmark, this value is 1.0. + self.gae_lambda = 1.0 # for reward norm self.reward_mean = 0.0 @@ -430,6 +459,7 @@ def create_critic_trainer( value_training_args = copy.deepcopy(args) for attr_name in [ "critic_learning_rate", + "critic_min_learning_rate", "critic_weight_decay", "critic_lr_scheduler_type", "critic_warmup_ratio", @@ -547,7 +577,7 @@ def create_reward_trainer( reward_server=model, ) - if not self.args.use_rm_server: + if not self.args.use_rm_server and not self.args.use_rule_reward: if args.pipeline_parallel_degree > 1 or ShardingOption.FULL_SHARD in args.sharding: reward_trainer.init_train_model_opt(100, None, clear_master_weight=True) # dummy max_steps @@ -638,7 +668,7 @@ def get_scheduler(self, args): warmup_steps = args.warmup_ratio * args.max_steps lr_scheduler = None if args.min_learning_rate is not None: - if args.lr_scheduler_type == "cosine": + if args.lr_scheduler_type == SchedulerType.COSINE or args.lr_scheduler_type == "cosine": lr_scheduler = CosineAnnealingWithWarmupDecay( max_lr=args.learning_rate, min_lr=args.min_learning_rate, @@ -646,7 +676,7 @@ def get_scheduler(self, args): decay_step=args.decay_steps, last_epoch=0, ) - elif args.lr_scheduler_type == "linear": + elif args.lr_scheduler_type == SchedulerType.LINEAR or args.lr_scheduler_type == "linear": lr_scheduler = LinearAnnealingWithWarmupDecay( max_lr=args.learning_rate, min_lr=args.min_learning_rate, @@ -694,7 +724,11 @@ def prediction_step( prompt_only_batch = DataProto.from_single_dict( { "input_ids": inputs["input_ids"], - **({"label_ids": inputs["label_ids"]} if self.args.use_rm_server else {}), + **( + {"label_ids": inputs["label_ids"]} + if self.args.use_rm_server or self.args.use_rule_reward + else {} + ), } ) generated_seq = self.actor_trainer.generate_sequences(prompt_only_batch, do_eval=True)[0].batch[ @@ -703,7 +737,16 @@ def prediction_step( if self.reshard_controller is not None: self.reshard_controller.set_train_env("[after prediction_step]") - if not self.args.use_rm_server: + if self.args.use_rule_reward: + prompt_len = inputs["input_ids"].shape[-1] + if "label_ids" not in inputs: + raise ValueError("Rule-based reward needs labels.") + tgt = self.tokenizer.batch_decode(inputs["label_ids"], skip_special_tokens=False) + response = self.tokenizer.batch_decode(generated_seq[:, prompt_len:], skip_special_tokens=False) + ground_truth = [i.replace(self.tokenizer.pad_token, "") for i in tgt] + response_str = [i.replace(self.tokenizer.pad_token, "") for i in response] + reward_score = self.reward_trainer.compute_score(response_str, ground_truth) + elif not self.args.use_rm_server: if self._model_config.sequence_parallel: # pad to max_sequence_length seq = self.tokenizer.pad( @@ -1028,15 +1071,12 @@ def init_train_num( 7. num_train_samples (int) - The total number of samples in the training data. 
""" args = self.args - total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.dataset_world_size len_dataloader = None if not self._is_iterable_dataset(self.train_dataset): len_dataloader = len(train_dataloader) - num_train_sub_steps = ( - len_dataloader * self.args.update_iters * self.args.rollout_n // self.args.per_device_train_batch_size - ) - num_update_steps_per_epoch = num_train_sub_steps // args.gradient_accumulation_steps + num_train_sub_steps = self.args.global_mini_batch_size // args.per_device_train_batch_size + num_update_steps_per_epoch = (num_train_sub_steps // args.gradient_accumulation_steps) * len_dataloader num_examples = len(self.train_dataset) if args.max_steps > 0: max_steps = args.max_steps @@ -1044,14 +1084,16 @@ def init_train_num( args.max_steps % num_update_steps_per_epoch > 0 ) else: - max_steps = int(num_update_steps_per_epoch * args.num_train_epochs) + max_steps = int(args.num_train_epochs * num_update_steps_per_epoch) num_train_epochs = math.ceil(args.num_train_epochs) - num_train_samples = total_train_batch_size * max_steps + num_train_samples = ( + self.args.global_batch_size * self.args.update_iters * self.args.rollout_n + ) * len_dataloader else: assert args.max_steps > 0 max_steps = args.max_steps num_train_epochs = sys.maxsize - num_update_steps_per_epoch = args.max_steps + num_update_steps_per_epoch = max_steps num_examples = total_train_batch_size * args.max_steps num_train_samples = args.max_steps * total_train_batch_size @@ -1113,7 +1155,7 @@ def remove_pad_tokens_after_generate(self, generated_batches: List[DataProto]): for row in batch.batch["input_ids"] ] ) - if self.args.use_rm_server: + if self.args.use_rm_server or self.args.use_rule_reward: label_ids_batches.extend( [ process_row( @@ -1237,42 +1279,18 @@ def train( args = self.args self.is_in_train = True - # ##### trainging data and related num setting ##### - # TODO(guosheng): remove the binding method get_collator of dataset - with ( - guard_set_args( - args, - {"per_device_train_batch_size": self.args.global_gen_batch_size // self.args.dataset_world_size}, - ), - guard_set_args( - self, - {"train_dataset": self.train_dataset, "data_collator": self.data_collator}, - ), - ): - train_dataloader = self.prompt_only_dataloader = self.get_train_dataloader() - - ( - total_train_batch_size, - len_dataloader, - max_steps, - num_train_epochs, - num_update_steps_per_epoch, - num_examples, - num_train_samples, - ) = self.init_train_num(train_dataloader) - # ##### model and optimizer related setting ##### - actor_model, critic_model = self.init_train_model_opt(max_steps, resume_from_checkpoint) + actor_model, critic_model = self.init_train_model_opt(self.max_steps, resume_from_checkpoint) paddle.device.cuda.empty_cache() # ##### traing statistic logging ##### # Number of trainable parameters only account for actor_model self.init_train_log( - num_examples, - num_train_epochs, - total_train_batch_size, - max_steps, - num_train_samples, + self.num_examples_, + self.num_train_epochs, + self.total_train_batch_size, + self.max_steps, + self.num_train_samples, actor_model, ) @@ -1281,20 +1299,19 @@ def train( # correct. Thus, data cannot be resumed perfectly when not breaking at epoch end. 
epochs_trained, steps_trained_in_current_epoch, steps_trained_progress_bar = self.init_train_state( resume_from_checkpoint, - train_dataloader, - max_steps, - num_train_epochs, - num_update_steps_per_epoch, + self.train_dataloader, + self.max_steps, + self.num_train_epochs, + self.num_update_steps_per_epoch, ) - steps_in_epoch = num_update_steps_per_epoch * args.gradient_accumulation_steps - + steps_in_epoch = self.num_update_steps_per_epoch * args.gradient_accumulation_steps # self.callback_handler.model = self.model # self.callback_handler.optimizer = self.optimizer # self.callback_handler.lr_scheduler = self.lr_scheduler # self.callback_handler.train_dataloader = train_dataloader - self.state.max_steps = int(max_steps) - self.state.num_train_epochs = num_train_epochs + self.state.max_steps = int(self.max_steps) + self.state.num_train_epochs = self.num_train_epochs self.state.is_local_process_zero = self.is_local_process_zero() self.state.is_world_process_zero = self.is_world_process_zero() @@ -1321,11 +1338,11 @@ def train( sharding_parallel_group = None data_parallel_group = None - for epoch in range(epochs_trained, num_train_epochs): - if isinstance(train_dataloader, paddle.io.DataLoader) and isinstance( - train_dataloader.batch_sampler, DistributedBatchSampler + for epoch in range(epochs_trained, int(self.num_train_epochs)): + if isinstance(self.train_dataloader, paddle.io.DataLoader) and isinstance( + self.train_dataloader.batch_sampler, DistributedBatchSampler ): - train_dataloader.batch_sampler.set_epoch(epoch) + self.train_dataloader.batch_sampler.set_epoch(epoch) num_gen_batches += 1 self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) @@ -1408,7 +1425,8 @@ def train( return_attention_mask=False, pad_to_multiple_of=pad_to_multiple_of, )["input_ids"] - label_ids = DataProto.pad_batch_data(label_ids_batches, pad_token_id=pad_token_id) + if len(label_ids_batches) > 0: + label_ids = DataProto.pad_batch_data(label_ids_batches, pad_token_id=pad_token_id) position_ids = make_position_ids_from_input_ids(input_ids, pad_token_id=pad_token_id) prompt_len = paddle.full(shape=[expand_prompt.shape[0]], fill_value=expand_prompt.shape[1], dtype=expand_prompt.dtype) # fmt: skip @@ -1424,17 +1442,16 @@ def train( "prompt_len_without_pad": prompt_len_without_pad, "response_len_without_pad": response_len_without_pad, "index": indices, - **({"label_ids": label_ids} if self.args.use_rm_server else {}), + **({"label_ids": label_ids} if self.args.use_rm_server or self.args.use_rule_reward else {}), **( {"raw_label_ids_len": prompt_only_batch_expand.batch["raw_label_ids_len"]} - if self.args.use_rm_server + if self.args.use_rm_server or self.args.use_rule_reward else {} ), } ) batch = data_group_merge(batch, group=data_trans_group) - # step 2-2: balance batches based on batch tokens if self.args.balance_batch: batch = self._balance_batch(batch) @@ -1460,8 +1477,8 @@ def train( ): with reload_and_offload_scope( self, - self.reward_critic_model if self.args.rl_algorithm == "ppo" else None, - self.reward_model if not self.args.use_rm_server else None, + self.critic_model if self.args.rl_algorithm == "ppo" else None, + self.reward_model if not self.args.use_rm_server and not self.args.use_rule_reward else None, ): with TimerScope(self.timers, RolloutStages.ROLLOUT_REWARD_VALUE): reward_tensor = self.reward_trainer.compute_reward( @@ -1563,7 +1580,6 @@ def train( else: local_batch = batch batch = batch - # step 2-3: compute reward normalization batch.batch["ori_rewards"] = 
batch.batch["rewards"].clone() @@ -1572,7 +1588,7 @@ def train( with TimerScope(self.timers, RolloutStages.ROLLOUT_ADVANTAGE): # step 2-4: compute advantage - batch = self.compute_advantage(batch, use_tgt_len_value=args.use_tgt_len_value) + batch = self.compute_advantage(batch) # step 2-5: compute advantage normalization if self.args.normalize_advantage: @@ -1597,6 +1613,7 @@ def train( for micro_step, micro_batch in enumerate(micro_batches * self.args.update_iters): step = 0 if step == -1 else step + with TimerScopeManualLabel( self.timers, get_timer_label(ActorStages.MICRO_STEPS) + f"_{micro_step}" ): @@ -1606,7 +1623,10 @@ def train( if self.args.rl_algorithm == "ppo": train_value_loss = self.critic_trainer.update_critic(micro_batch) - rl_info.union(train_value_loss) + rl_info.meta_info["metrics"].update(train_value_loss) + + paddle.device.cuda.empty_cache() + if self.is_step_end(): self.state.global_step += 1 self.state.epoch = epoch + (step + 1) / steps_in_epoch @@ -1659,7 +1679,7 @@ def train( metrics = speed_metrics( "train", start_time, - num_samples=num_train_samples, + num_samples=self.num_train_samples, num_steps=self.state.max_steps, ) @@ -1762,57 +1782,6 @@ def _maybe_log_save_evaluate(self, tr_loss: DataProto, model, epoch, ignore_keys with guard_set_args(self.control, {"should_log": False}): super()._maybe_log_save_evaluate(tr_loss.meta_info["metrics"], model, epoch, ignore_keys_for_eval) - def get_advantages_and_returns( - self, - values: paddle.Tensor, - rewards: paddle.Tensor, - sequence_mask: paddle.Tensor, - start: int, - use_tgt_len_return: bool = True, - ) -> Tuple[paddle.Tensor, paddle.Tensor]: - """Compute advantages and returns using Generalized Advantage Estimation (GAE).""" - # Modified from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py - last_gae_lambda = 0.0 - advantages_reversed = [] - values = values * sequence_mask - rewards = rewards * sequence_mask - length = rewards.shape[-1] - if use_tgt_len_return and start > 0: - # consistent with Beaver - # values length is src+tgt-1, start is src-1, return length is tgt - pass - elif use_tgt_len_return: - # values length is tgt, start is 0, return length is tgt - assert start == 0 - else: - # values length is src+tgt-1, start is src-1, return length is src+tgt-1 - pass - for t in reversed(range(start, length)): # pylint: disable=invalid-name - next_values = values[:, t + 1] if t < length - 1 else 0.0 - delta = rewards[:, t] + self.gamma * next_values - values[:, t] - last_gae_lambda = delta + self.gamma * self.gae_lambda * last_gae_lambda - advantages_reversed.append(last_gae_lambda) - advantages = paddle.stack(advantages_reversed[::-1], axis=1) - returns = advantages + values[:, start:].contiguous() - - if not use_tgt_len_return: - advantages = paddle.concat( - [ - paddle.zeros([advantages.shape[0], start], dtype=advantages.dtype), - advantages, - ], - axis=-1, - ) - returns = paddle.concat( - [ - paddle.zeros([returns.shape[0], start], dtype=returns.dtype), - returns, - ], - axis=-1, - ) - - return advantages.detach(), returns - @paddle.no_grad() def compute_reward_normalization(self, batch): batch_rewards = batch["rewards"].cast(paddle.float32) @@ -1857,7 +1826,7 @@ def compute_reward_normalization(self, batch): return batch @paddle.no_grad() - def compute_advantage(self, batch: DataProto, use_tgt_len_value) -> DataProto: + def compute_advantage(self, batch: DataProto) -> DataProto: if "log_probs" in batch.batch: old_log_probs = batch.batch["log_probs"] # length: src + 
tgt -1 if "ref_log_probs" in batch.batch: @@ -1868,48 +1837,45 @@ def compute_advantage(self, batch: DataProto, use_tgt_len_value) -> DataProto: if self.args.rl_algorithm == "grpo": eos_mask = batch.batch["eos_mask"] - start = 0 reward_advantages = compute_grpo_advantages( - rewards, batch.non_tensor_batch["index"], eos_mask[:, start:], eos_mask.shape[-1] + rewards, batch.non_tensor_batch["index"], eos_mask, eos_mask.shape[-1] ) elif self.args.rl_algorithm == "ppo": - start = batch.batch["prompt"].shape[-1] - 1 - eos_mask = (batch.batch["input_ids"] != self.tokenizer.pad_token_id)[:, 1:].to(old_log_probs.dtype) + eos_mask = (batch.batch["input_ids"] != self.tokenizer.pad_token_id)[ + :, batch.batch["prompt"].shape[-1] : + ].to(old_log_probs.dtype) rewards_with_kl, kl_rewards = add_kl_divergence_regularization( None, # prompt, old_log_probs, ref_log_probs, rewards, - eos_mask[:, start:], + eos_mask, self.kl_coeff, self.clip_range_score, - ) # length: tgt if use_tgt_len_value src + tgt -1 + ) reward_advantages, reward_returns = compute_gae_advantage_return( rewards_with_kl, old_reward_values, - eos_mask[:, start:], - start=0 if use_tgt_len_value else start, + eos_mask, gamma=self.gamma, lam=self.gae_lambda, - use_tgt_len_return=use_tgt_len_value, - ) # length: tgt if use_tgt_len_value src + tgt -1 + ) elif self.args.rl_algorithm == "reinforce_plus_plus": - start = 0 eos_mask = batch.batch["eos_mask"] rewards_with_kl, kl_rewards = add_kl_divergence_regularization( None, # prompt, old_log_probs, ref_log_probs, rewards, - eos_mask[:, start:], + eos_mask, self.kl_coeff, self.clip_range_score, - ) # length: tgt if use_tgt_len_value src + tgt -1 + ) reward_advantages, reward_returns = compute_reinforce_plus_plus_advantages_and_returns( rewards_with_kl, - eos_mask[:, start:], + eos_mask, self.gamma, - ) # length: tgt if use_tgt_len_value src + tgt -1 + ) else: raise ValueError(f"Unknown rl_algorithm: {self.args.rl_algorithm}") @@ -1917,10 +1883,10 @@ def compute_advantage(self, batch: DataProto, use_tgt_len_value) -> DataProto: { # "log_probs": old_log_probs, "reward_advantages": reward_advantages, - "reward_advantages_clean": reward_advantages[eos_mask[:, start:] != 0], + "reward_advantages_clean": reward_advantages[eos_mask != 0], # "ref_log_probs": ref_log_probs, "rewards": rewards, - "eos_mask": eos_mask[:, start:], + "eos_mask": eos_mask, } ) if self.args.rl_algorithm in ["reinforce_plus_plus", "ppo"]: diff --git a/paddlenlp/rl/trainer/reward_trainer.py b/paddlenlp/rl/trainer/reward_trainer.py index e8e958afbe84..00a84061744e 100644 --- a/paddlenlp/rl/trainer/reward_trainer.py +++ b/paddlenlp/rl/trainer/reward_trainer.py @@ -34,6 +34,7 @@ ) from ...transformers import PretrainedModel, PretrainedTokenizer from ..models.ppo_model_utils import create_startend_row_indices +from ..models.reward_utils import extract_solution from .rl_trainer import RLTrainer from .trainer_utils import batch_retokenize @@ -56,7 +57,10 @@ def __init__( preprocess_logits_for_metrics: Optional[Callable[[paddle.Tensor, paddle.Tensor], paddle.Tensor]] = None, reward_server: str = None, ): - if args.use_rm_server: + if args.use_rule_reward: + self.args = args + self.tokenizer = tokenizer + elif args.use_rm_server: assert isinstance(model, str), "reward trainer need a str (http://xxx:port) for request" self.args = args self.tokenizer = tokenizer @@ -87,7 +91,20 @@ def compute_reward( position_ids = batch.batch["position_ids"] label_ids = batch.batch["label_ids"] prompt = batch.batch["prompt"] - if not 
diff --git a/paddlenlp/rl/trainer/reward_trainer.py b/paddlenlp/rl/trainer/reward_trainer.py
index e8e958afbe84..00a84061744e 100644
--- a/paddlenlp/rl/trainer/reward_trainer.py
+++ b/paddlenlp/rl/trainer/reward_trainer.py
@@ -34,6 +34,7 @@
 )
 from ...transformers import PretrainedModel, PretrainedTokenizer
 from ..models.ppo_model_utils import create_startend_row_indices
+from ..models.reward_utils import extract_solution
 from .rl_trainer import RLTrainer
 from .trainer_utils import batch_retokenize

@@ -56,7 +57,10 @@ def __init__(
        preprocess_logits_for_metrics: Optional[Callable[[paddle.Tensor, paddle.Tensor], paddle.Tensor]] = None,
        reward_server: str = None,
    ):
-        if args.use_rm_server:
+        if args.use_rule_reward:
+            self.args = args
+            self.tokenizer = tokenizer
+        elif args.use_rm_server:
            assert isinstance(model, str), "reward trainer need a str (http://xxx:port) for request"
            self.args = args
            self.tokenizer = tokenizer
@@ -87,7 +91,20 @@ def compute_reward(
        position_ids = batch.batch["position_ids"]
        label_ids = batch.batch["label_ids"]
        prompt = batch.batch["prompt"]
-        if not self.args.use_rm_server:
+
+        if self.args.use_rule_reward:
+            prompt_len = prompt.shape[-1]
+            if label_ids is None:
+                raise ValueError("Rule-based reward needs labels.")
+            tgt = input_ids_tokenizer.batch_decode(label_ids, skip_special_tokens=False)
+            response = input_ids_tokenizer.batch_decode(input_ids[:, prompt_len:], skip_special_tokens=False)
+            ground_truth = [i.replace(self.tokenizer.pad_token, "") for i in tgt]
+            response = [i.replace(self.tokenizer.pad_token, "") for i in response]
+            reward_score = self.compute_score(
+                solution=response,
+                ground_truth=ground_truth,
+            )
+        elif not self.args.use_rm_server:
            if self.tokenizer is not input_ids_tokenizer:
                # right padding
                reward_tokenize_output = batch_retokenize(
@@ -124,8 +141,6 @@
            reward_score = reward_score.squeeze(axis=-1)

        return reward_score
-        # if self.args.rl_algorithm in ["grpo", "reinforce_plus_plus"]:
-        #     return {"rewards": reward_score}

    def request_reward_server(self, src, tgt, response):
        data = {"src": src, "tgt": tgt, "response": response}
@@ -167,4 +182,30 @@ def post():
            paddle.distributed.barrier(tp_group)
            paddle.distributed.broadcast(reward_score, src=tp_group.ranks[0], group=tp_group)
-        return reward_score.unsqueeze(-1)
+        reward_score = reward_score.unsqueeze(-1)
+        return reward_score
+
+    def compute_score(self, solution, ground_truth, method="strict", format_score=0.0, score=1.0):
+        """The scoring function for GSM8k.
+
+        Reference: Trung, Luong, et al. "Reft: Reasoning with reinforced fine-tuning." Proceedings of the 62nd Annual
+        Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2024.
+
+        Args:
+            solution: the solution text
+            ground_truth: the ground truth
+            method: the method to extract the solution, choices are 'strict' and 'flexible'
+            format_score: the score for the format
+            score: the score for the correct answer
+        """
+        reward_tensor = paddle.zeros((len(solution), 1), dtype=paddle.float32)
+        for i in range(len(solution)):
+            answer = extract_solution(solution_str=solution[i], method=method)
+            if answer is None:
+                reward_tensor[i] = 0
+            else:
+                if answer == ground_truth[i]:
+                    reward_tensor[i] = score
+                else:
+                    reward_tensor[i] = format_score
+        return reward_tensor
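compute_score leans on extract_solution from paddlenlp/rl/models/reward_utils.py, which is not part of this diff. As a hypothetical illustration of the contract it relies on (not the actual helper): "strict" only accepts an answer written after "####", "flexible" falls back to the last number in the text, and None means the response is malformed and scores 0 rather than format_score.

import re


def extract_solution_sketch(solution_str, method="strict"):
    # Hypothetical stand-in for reward_utils.extract_solution; details may differ.
    if method == "strict":
        match = re.search(r"#### (\-?[0-9\.\,]+)", solution_str)
        return None if match is None else match.group(1).replace(",", "")
    # "flexible": take the last number-like token, ignoring empty matches
    numbers = [n for n in re.findall(r"\-?[0-9\.\,]+", solution_str) if n not in ("", ".")]
    return numbers[-1].replace(",", "") if numbers else None


# e.g. extract_solution_sketch("... so she earns 10 per day. #### 10") -> "10",
# which compute_score compares with the decoded label to emit score / format_score / 0.
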
diff --git a/paddlenlp/rl/utils/config_utils.py b/paddlenlp/rl/utils/config_utils.py
index 2f407973be2b..9aae90c9001b 100644
--- a/paddlenlp/rl/utils/config_utils.py
+++ b/paddlenlp/rl/utils/config_utils.py
@@ -300,6 +300,7 @@ class TrainingArguments(TrainingArguments):
        metadata={"help": "Whether to use tgt for KL."},
    )
    use_rm_server: bool = field(default=False, metadata={"help": "Use reward server instead of reward model."})
+    use_rule_reward: bool = field(default=False, metadata={"help": "Use rule-based reward only for gsm8k, to date."})
    use_fp32_compute: bool = field(
        default=False, metadata={"help": "Use fp32 to compute xx_log_prob,rewards, advantages and loss."}
    )
@@ -345,7 +346,7 @@ def __post_init__(self):
        self._post_init_parallel_degree()

        if self.global_mini_batch_size < 0:
-            self.global_mini_batch_size = self.global_batch_size
+            self.global_mini_batch_size = self.global_batch_size // self.dataset_world_size

        if (
            self.global_batch_size % self.dataset_world_size != 0
@@ -382,6 +383,7 @@ def __post_init__(self):
                // self.per_device_train_batch_size
                // self.dataset_world_size
            )
+
            if self.gradient_accumulation_steps <= 0:
                logger.warning(
                    f"gradient_accumulation_steps: {self.gradient_accumulation_steps} must be greater than zero!"
@@ -443,14 +445,14 @@ def __post_init__(self):
            self.normalize_advantage = False

        max_per_device_eval_batch_size = (
-            self.global_mini_batch_size * self.rollout_n * self.update_iters // self.dataset_world_size
+            self.global_batch_size * self.rollout_n * self.update_iters // self.dataset_world_size
        )
        if self.per_device_eval_batch_size > max_per_device_eval_batch_size:
            logger.warning(
                f"per_device_eval_batch_size: {self.per_device_eval_batch_size} is larger than "
-                f"global_mini_batch_size: {self.global_mini_batch_size} * rollout_n: "
+                f"global_batch_size: {self.global_batch_size} * rollout_n: "
                f"{self.rollout_n} * update_iters: {self.update_iters}, which may cause infer error. "
-                f"We will set it to global_mini_batch_size * rollout_n * update_iters // dataset_world_size!"
+                f"We will set it to global_batch_size * rollout_n * update_iters // dataset_world_size!"
            )
            self.per_device_eval_batch_size = max_per_device_eval_batch_size
@@ -530,7 +532,7 @@ class ModelArgument:
    )
    actor_tokenizer_alpha: float = field(default=None, metadata={"help": "Tokenizer will tokenize randomly"})
    reward_tokenizer_alpha: float = field(default=None, metadata={"help": "Tokenizer will tokenize randomly"})
-    reward_critic_tokenizer_alpha: float = field(default=None, metadata={"help": "Tokenizer will tokenize randomly"})
+    critic_tokenizer_alpha: float = field(default=None, metadata={"help": "Tokenizer will tokenize randomly"})
    stage: str = field(default="PPO", metadata={"help": "The type of training."})
    critic_recompute_granularity: str = field(
        default="full",
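The batch-size bookkeeping above is easiest to check with concrete numbers (illustrative values, not defaults): global_mini_batch_size now defaults to a per-data-parallel-rank share of the global batch, and the eval cap is derived from global_batch_size instead of global_mini_batch_size, so it no longer shrinks when mini-batches are split across ranks.

# Illustrative values only
global_batch_size = 64
dataset_world_size = 4  # data-parallel degree
rollout_n = 8
update_iters = 1

# new default: one data-parallel shard of the global batch
global_mini_batch_size = global_batch_size // dataset_world_size  # 16 (was 64)

# eval cap now derived from global_batch_size
max_per_device_eval_batch_size = (
    global_batch_size * rollout_n * update_iters // dataset_world_size
)  # 128
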
r""" labels (`paddle.Tensor` of shape `(batch_size,)`, *optional*): @@ -1877,6 +1878,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, ) sequence_output = outputs[0] sequence_output = self.dropout(sequence_output)