|
| 1 | +# Copyright © 2022 BAAI. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License") |
| 4 | +from flagai.trainer import Trainer |
| 5 | +from flagai.model.glm_model import GLMForSeq2Seq |
| 6 | +from flagai.data.tokenizer import Tokenizer |
| 7 | +from flagai.data.dataset import Seq2SeqDataset |
| 8 | +from flagai.test_utils import Seq2SeqCollateArguments |
| 9 | +from flagai.data.dataset.superglue.control import DEFAULT_METRICS, CH_TASKS |
| 10 | +from flagai.data.dataset import ConstructSeq2seqStrategy |
| 11 | + |
| 12 | + |
# Unlike the original seq2seq example, this script builds its data with
# Seq2SeqDataset.
# task_name options: ['cmrc', ...]
task_name = "cmrc"

# Directory handed to GLMForSeq2Seq.from_pretrain as `download_path`.
# A relative path (like the './datasets/' data_dir below) keeps the script
# runnable on any machine instead of depending on one developer's mount point.
DOWNLOAD_PATH = "./checkpoints/"

# Number of leading examples kept from each split.
# NOTE(review): presumably a cap for a quick smoke run -- confirm before
# using this script for a real fine-tuning job.
MAX_EXAMPLES = 20

cl_args = Seq2SeqCollateArguments()
trainer = Trainer(
    env_type='pytorch',
    epochs=1,
    batch_size=4,
    eval_interval=5,
    log_interval=50,
    experiment_name='glm_large',
    pytorch_device='cuda',
    load_dir=None,
    lr=1e-4,
)
print("downloading...")

# Tasks listed in CH_TASKS use the 'ALM-1.0' checkpoint; every other task
# falls back to 'GLM-large-en'.
if task_name in CH_TASKS:
    model_name = 'ALM-1.0'
else:
    model_name = 'GLM-large-en'

tokenizer = Tokenizer.from_pretrained(model_name)

train_dataset = Seq2SeqDataset(
    task_name=task_name,
    data_dir='./datasets/',
    dataset_type='train',
    tokenizer=tokenizer,
)
valid_dataset = Seq2SeqDataset(
    task_name=task_name,
    data_dir='./datasets/',
    dataset_type='dev',
    tokenizer=tokenizer,
)
collate_fn = ConstructSeq2seqStrategy(
    cl_args,
    tokenizer,
    task_name=task_name,
)

# Keep only the first MAX_EXAMPLES examples of each split.
train_dataset.example_list = train_dataset.example_list[:MAX_EXAMPLES]
valid_dataset.example_list = valid_dataset.example_list[:MAX_EXAMPLES]

model = GLMForSeq2Seq.from_pretrain(
    model_name=model_name,
    download_path=DOWNLOAD_PATH,
)

# No metric functions are passed to the trainer (metric_methods is empty).
trainer.train(
    model,
    collate_fn=collate_fn,
    train_dataset=train_dataset,
    valid_dataset=valid_dataset,
    metric_methods=[],
)
0 commit comments