# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# /// script
# dependencies = [
#     "trl[peft]",
#     "trackio",
#     "kernels",
# ]
# ///

"""
Run the ORPO training script with the following command with some example arguments.
In general, the optimal configuration for ORPO will be similar to that of DPO without the need for a reference model:
# regular:
python examples/scripts/orpo.py \
--dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style \
--model_name_or_path gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 8e-6 \
--gradient_accumulation_steps 1 \
--eval_steps 500 \
--output_dir "gpt2-aligned-orpo" \
--warmup_steps 150 \
--logging_first_step \
--no_remove_unused_columns
# peft:
python examples/scripts/orpo.py \
--dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style \
--model_name_or_path gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 8e-5 \
--gradient_accumulation_steps 1 \
--eval_steps 500 \
--output_dir "gpt2-lora-aligned-orpo" \
--optim rmsprop \
--warmup_steps 150 \
--logging_first_step \
--no_remove_unused_columns \
--use_peft \
--lora_r 16 \
--lora_alpha 16
"""

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from trl import ModelConfig, ScriptArguments, get_peft_config
from trl.experimental.orpo import ORPOConfig, ORPOTrainer


if __name__ == "__main__":
    parser = HfArgumentParser((ScriptArguments, ORPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_into_dataclasses()
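    # ScriptArguments holds the dataset options, ORPOConfig the ORPO hyperparameters,
    # and ModelConfig the model/PEFT options parsed from the command line above.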

    ################
    # Model & Tokenizer
    ################
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
    )
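    # Base checkpoints such as gpt2 ship without a pad token; reuse the EOS token so
    # the trainer can pad batches of chosen/rejected pairs.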
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    ################
    # Dataset
    ################
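    # ORPOTrainer expects a preference dataset with "chosen" and "rejected" columns
    # (plus an explicit or implicit "prompt"); the trl-style HH-RLHF dataset used in
    # the examples above is already in this format.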
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    ################
    # Training
    ################
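    # Unlike DPO, ORPO folds an odds-ratio preference penalty into the NLL loss, so no
    # frozen reference model is needed. get_peft_config returns a LoRA config when
    # --use_peft is passed, and None (full fine-tuning) otherwise.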
    trainer = ORPOTrainer(
        model,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        processing_class=tokenizer,
        peft_config=get_peft_config(model_args),
    )

    # Train the model
    trainer.train()

    # Save locally and optionally push to the Hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)