Skip to content

Commit 2cf528b

Browse files
authored
Merge pull request FlagAI-Open#1 from shunxing1234/yzd
Yzd
2 parents 2768ffc + ccc2da3 commit 2cf528b

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed

examples/alm_seq2seq/train.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright © 2022 BAAI. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License")
4+
from flagai.trainer import Trainer
5+
from flagai.model.glm_model import GLMForSeq2Seq
6+
from flagai.data.tokenizer import Tokenizer
7+
from flagai.data.dataset import Seq2SeqDataset
8+
from flagai.test_utils import Seq2SeqCollateArguments
9+
from flagai.data.dataset.superglue.control import DEFAULT_METRICS, CH_TASKS
10+
from flagai.data.dataset import ConstructSeq2seqStrategy
11+
12+
13+
# Compared with original seq2seq, seq2seq dataset is used
14+
# task_name :['cmrc',xxxx]
15+
task_name = "cmrc"
16+
17+
cl_args = Seq2SeqCollateArguments()
18+
trainer = Trainer(env_type='pytorch',
19+
epochs=1,
20+
batch_size=4,
21+
eval_interval=5,
22+
log_interval=50,
23+
experiment_name='glm_large',
24+
pytorch_device='cuda',
25+
load_dir=None,
26+
lr=1e-4)
27+
print("downloading...")
28+
29+
if task_name in CH_TASKS:
30+
model_name = 'ALM-1.0'
31+
else:
32+
model_name = 'GLM-large-en'
33+
34+
tokenizer = Tokenizer.from_pretrained(model_name)
35+
36+
train_dataset = Seq2SeqDataset(task_name=task_name,
37+
data_dir='./datasets/',
38+
dataset_type='train',
39+
tokenizer=tokenizer)
40+
valid_dataset = Seq2SeqDataset(task_name=task_name,
41+
data_dir='./datasets/',
42+
dataset_type='dev',
43+
tokenizer=tokenizer)
44+
collate_fn = ConstructSeq2seqStrategy(cl_args,
45+
tokenizer,
46+
task_name=task_name)
47+
train_dataset.example_list = train_dataset.example_list[:20]
48+
valid_dataset.example_list = valid_dataset.example_list[:20]
49+
50+
model = GLMForSeq2Seq.from_pretrain(model_name=model_name,download_path="/mnt/xw/ALM")
51+
52+
trainer.train(model,
53+
collate_fn=collate_fn,
54+
train_dataset=train_dataset,
55+
valid_dataset=valid_dataset,
56+
metric_methods=[])

0 commit comments

Comments
 (0)