Skip to content

Commit 2ce9c92

Browse files
[Example] add unetformer code (#1193)
* add unetformer code * add unetformer code * add unetformer code * add unetformer code * add unetformer code * add unetformer code * add unetformer code * add unetformer code * add unetformer code * Update examples/unetformer/geoseg/losses/__init__.py --------- Co-authored-by: HydrogenSulfate <[email protected]>
1 parent 02ba0d2 commit 2ce9c92

25 files changed

+3147
-6
lines changed

README.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ PaddleScience 是一个基于深度学习框架 PaddlePaddle 开发的科学计
132132

133133
| 问题类型 | 案例名称 | 优化算法 | 模型类型 | 训练方式 | 数据集 | 参考资料 |
134134
|-----|---------|-----|---------|----|---------|---------|
135-
| 天气预报 | [Extformer-MoE 气象预报](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/extformer_moe.md) | 数据驱动 | Transformer | 监督学习 | [enso](https://tianchi.aliyun.com/dataset/98942) | - |
135+
| 天气预报 | [Extformer-MoE 气象预报](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/extformer_moe) | 数据驱动 | Transformer | 监督学习 | [enso](https://tianchi.aliyun.com/dataset/98942) | - |
136136
| 天气预报 | [FourCastNet 气象预报](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/fourcastnet) | 数据驱动 | AFNO | 监督学习 | [ERA5](https://app.globus.org/file-manager?origin_id=945b3c9e-0f8c-11ed-8daf-9f359c660fbd&origin_path=%2F~%2Fdata%2F) | [Paper](https://arxiv.org/pdf/2202.11214.pdf) |
137137
| 天气预报 | [NowCastNet 气象预报](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/nowcastnet) | 数据驱动 | GAN | 监督学习 | [MRMS](https://app.globus.org/file-manager?origin_id=945b3c9e-0f8c-11ed-8daf-9f359c660fbd&origin_path=%2F~%2Fdata%2F) | [Paper](https://www.nature.com/articles/s41586-023-06184-4) |
138138
| 天气预报 | [GraphCast 气象预报](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/graphcast) | 数据驱动 | GNN | 监督学习 | - | [Paper](https://arxiv.org/abs/2212.12794) |
@@ -142,11 +142,11 @@ PaddleScience 是一个基于深度学习框架 PaddlePaddle 开发的科学计
142142
| 天气预报 | [Pangu-Weather 气象预报](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/pangu_weather) | 数据驱动 | Transformer | 监督学习 | - | [Paper](https://arxiv.org/pdf/2211.02556) |
143143
| 大气污染物 | [UNet 污染物扩散](https://aistudio.baidu.com/projectdetail/5663515?channel=0&channelType=0&sUid=438690&shared=1&ts=1698221963752) | 数据驱动 | UNet | 监督学习 | [Data](https://aistudio.baidu.com/datasetdetail/198102) | - |
144144
| 大气污染物 | [STAFNet 污染物浓度预测](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/stafnet) | 数据驱动 | STAFNet | 监督学习 | [Data](https://quotsoft.net/air) | [Paper](https://link.springer.com/chapter/10.1007/978-3-031-78186-5_22) |
145-
| 天气预报 | [DGMR 气象预报](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/dgmr.md) | 数据驱动 | GAN | 监督学习 | [UK dataset](https://huggingface.co/datasets/openclimatefix/nimrod-uk-1km) | [Paper](https://arxiv.org/pdf/2104.00954.pdf) |
146-
| 地震波形反演 | [VelocityGAN 地震波形反演](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/velocity_gan.md) | 数据驱动 | VelocityGAN | 监督学习 | [OpenFWI](https://openfwi-lanl.github.io/docs/data.html#vel) | [Paper](https://arxiv.org/abs/1809.10262v6) |
147-
| 交通预测 | [TGCN 交通流量预测](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/tgcn.md) | 数据驱动 | GCN & CNN | 监督学习 | [PEMSD4 & PEMSD8](https://paddle-org.bj.bcebos.com/paddlescience/datasets/tgcn/tgcn_data.zip) | - |
148-
| 生成模型| [图像生成中的梯度惩罚应用](./zh/examples/wgan_gp.md)|数据驱动|WGAN GP|监督学习|[Data1](https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz)<br>[Data2](http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz)| [Paper](https://github.com/igul222/improved_wgan_training) |
149-
</details>
145+
| 天气预报 | [DGMR 气象预报](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/dgmr) | 数据驱动 | GAN | 监督学习 | [UK dataset](https://huggingface.co/datasets/openclimatefix/nimrod-uk-1km) | [Paper](https://arxiv.org/pdf/2104.00954.pdf) |
146+
| 地震波形反演 | [VelocityGAN 地震波形反演](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/velocity_gan) | 数据驱动 | VelocityGAN | 监督学习 | [OpenFWI](https://openfwi-lanl.github.io/docs/data.html#vel) | [Paper](https://arxiv.org/abs/1809.10262v6) |
147+
| 交通预测 | [TGCN 交通流量预测](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/tgcn) | 数据驱动 | GCN & CNN | 监督学习 | [PEMSD4 & PEMSD8](https://paddle-org.bj.bcebos.com/paddlescience/datasets/tgcn/tgcn_data.zip) | - |
148+
| 遥感图像分割 | [UNetFormer 遥感图像分割](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/unetformer) | 数据驱动 | UNetFormer | 监督学习 | [Vaihingen](https://paperswithcode.com/dataset/isprs-vaihingen) | [Paper](https://github.com/WangLibo1995/GeoSeg) |
149+
| 生成模型| [图像生成中的梯度惩罚应用](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/wgan_gp)|数据驱动|WGAN GP|监督学习|[Data1](https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz)<br>[Data2](http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz)| [Paper](https://github.com/igul222/improved_wgan_training) |
150150

151151
## 🕘最近更新
152152

docs/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@
161161
| 大气污染物 | [STAFNet 污染物浓度预测](./zh/examples/stafnet.md) | 数据驱动 | STAFNet | 监督学习 | [Data](https://quotsoft.net/air) | [Paper](https://link.springer.com/chapter/10.1007/978-3-031-78186-5_22) |
162162
| 天气预报 | [DGMR 气象预报](./zh/examples/dgmr.md) | 数据驱动 | GAN | 监督学习 | [UK dataset](https://huggingface.co/datasets/openclimatefix/nimrod-uk-1km) | [Paper](https://arxiv.org/pdf/2104.00954.pdf) |
163163
| 地震波形反演 | [VelocityGAN 地震波形反演](./zh/examples/velocity_gan.md) | 数据驱动 | VelocityGAN | 监督学习 | [OpenFWI](https://openfwi-lanl.github.io/docs/data.html#vel) | [Paper](https://arxiv.org/abs/1809.10262v6) |
164+
| 遥感图像分割 | [UNetFormer分割图像](./zh/examples/unetformer.md) | 数据驱动 | UNetformer | 监督学习 | [Vaihingen](https://paperswithcode.com/dataset/isprs-vaihingen) | [Paper](https://github.com/WangLibo1995/GeoSeg) |
164165
| 交通预测 | [TGCN 交通流量预测](./zh/examples/tgcn.md) | 数据驱动 | GCN & CNN | 监督学习 | [PEMSD4 & PEMSD8](https://paddle-org.bj.bcebos.com/paddlescience/datasets/tgcn/tgcn_data.zip) | - |
165166
| 生成模型| [图像生成中的梯度惩罚应用](./zh/examples/wgan_gp.md)|数据驱动|WGAN GP|监督学习|[Data1](https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz)<br>[Data2](http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz)| [Paper](https://github.com/igul222/improved_wgan_training) |
166167

docs/zh/examples/unetformer.md

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
# UNetFormer
2+
3+
!!! note
4+
5+
1. 运行之前,建议快速了解一下[数据集](#31)和[数据读取方式](#32-dataset-api)。
6+
2. 将[Vaihingen数据集]下载到`data`目录中对应的子目录(如`data/vaihingen/train_images`)。
7+
3. 运行tools/vaihingen_patch_split.py处理原数据集,得到可供训练的数据。
8+
9+
文件数据集结构如下
10+
```none
11+
airs
12+
├── unetformer(code)
13+
├── model_weights (save the model weights trained on ISPRS vaihingen)
14+
├── fig_results (save the masks predicted by models)
15+
├── lightning_logs (CSV format training logs)
16+
├── data
17+
│ ├── vaihingen
18+
│ │ ├── train_images (original)
19+
│ │ ├── train_masks (original)
20+
│ │ ├── test_images (original)
21+
│ │ ├── test_masks (original)
22+
│ │ ├── test_masks_eroded (original)
23+
│ │ ├── train (processed)
24+
│ │ ├── test (processed)
25+
```
26+
27+
=== "模型训练命令"
28+
29+
``` sh
30+
# 将[Vaihingen数据集]下载到`data`目录中对应的子目录(如`data/vaihingen/train_images`)
31+
# 创建训练数据集
32+
python tools/vaihingen_patch_split.py --img-dir "data/vaihingen/train_images" --mask-dir "data/vaihingen/train_masks" --output-img-dir "data/vaihingen/train/images_1024" --output-mask-dir "data/vaihingen/train/masks_1024" --mode "train" --split-size 1024 --stride 512
33+
# 创建测试数据集
34+
python tools/vaihingen_patch_split.py --img-dir "data/vaihingen/test_images" --mask-dir "data/vaihingen/test_masks_eroded" --output-img-dir "data/vaihingen/test/images_1024" --output-mask-dir "data/vaihingen/test/masks_1024" --mode "val" --split-size 1024 --stride 1024 --eroded
35+
# 创建masks_1024_rgb可视化数据集
36+
python tools/vaihingen_patch_split.py --img-dir "data/vaihingen/test_images" --mask-dir "data/vaihingen/test_masks" --output-img-dir "data/vaihingen/test/images_1024" --output-mask-dir "data/vaihingen/test/masks_1024_rgb" --mode "val" --split-size 1024 --stride 1024 --gt
37+
# 模型训练
38+
python train_supervision.py -c config/vaihingen/unetformer.py
39+
```
40+
41+
=== "模型评估命令"
42+
43+
``` sh
44+
# 下载处理好的[Vaihingen测试数据集](https://paddle-org.bj.bcebos.com/paddlescience/datasets/unetformer/test.zip),并解压。
45+
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/unetformer/test.zip -P ./data/vaihingen/
46+
unzip -q ./data/vaihingen/test.zip -d data/vaihingen/
47+
# 下载预训练模型文件
48+
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/unetformer/unetformer-r18-512-crop-ms-e105_epoch0_best.pdparams -P ./model_weights/vaihingen/unetformer-r18-512-crop-ms-e105/
49+
python vaihingen_test.py -c config/vaihingen/unetformer.py -o fig_results/vaihingen/unetformer --rgb
50+
```
51+
52+
## 1. 背景简介
53+
54+
遥感城市场景图像的语义分割在众多实际应用中具有广泛需求,例如土地覆盖制图、城市变化检测、环境保护和经济评估等领域。在深度学习技术快速发展的推动下,卷积神经网络(CNN)多年来一直主导着语义分割领域。CNN采用分层特征表示方式,展现出强大的局部信息提取能力。然而卷积层的局部特性限制了网络捕获全局上下文信息的能力。近年来,作为计算机视觉领域的热点研究方向,Transformer架构在全局信息建模方面展现出巨大潜力,显著提升了图像分类、目标检测特别是语义分割等视觉相关任务的性能。
55+
56+
本文提出了一种基于Transformer的解码器架构,构建了类UNet结构的Transformer网络(UNetFormer),用于实时城市场景分割。为实现高效分割,UNetFormer选择轻量级ResNet18作为编码器,并在解码器中开发了高效的全局-局部注意力机制,以同时建模全局和局部信息。本文提出的基于Transformer的解码器与Swin Transformer编码器结合后,在Vaihingen数据集上也取得了当前最佳性能(91.3% F1分数和84.1% mIoU)。
57+
58+
## 2. 模型原理
59+
60+
本段落仅简单介绍模型原理,具体细节请阅读[UNetFormer: A UNet-like Transformer for Efficient
61+
Semantic Segmentation of Remote Sensing Urban
62+
Scene Imagery](https://arxiv.org/abs/2109.08937)
63+
64+
### 2.1 模型结构
65+
66+
UNetFormer是一种基于transformer的解码器的深度学习网络,下图显示了模型的整体结构。
67+
68+
![UNetFormer1](https://paddle-org.bj.bcebos.com/paddlescience/docs/unetformer/unetformer.png)
69+
70+
- `ResBlock`是resnet18网络的各个模块。
71+
72+
- `GLTB`由全局-局部注意、MLP、两个batchnorm层和两个加和操作组成。
73+
74+
### 2.2 损失函数
75+
76+
判别器的损失函数由两部分组成,主损失函数$\mathcal{L}_{\text {p }}$为SoftCrossEntropyLoss交叉熵损失函数$\mathcal{L}_{c e}$和DiceLoss损失函数$\mathcal{L}_{\text {dice }}$。其表达式为:
77+
78+
$$
79+
\mathcal{L}_{c e}=-\frac{1}{N} \sum_{n=1}^{N} \sum_{k=1}^{K} y_{k}^{(n)} \log \hat{y}_{k}^{(n)}
80+
$$
81+
82+
$$
83+
\mathcal{L}_{\text {dice }}=1-\frac{2}{N} \sum_{n=1}^{N} \sum_{k=1}^{K} \frac{\hat{y}_{k}^{(n)} y_{k}^{(n)}}{\hat{y}_{k}^{(n)}+y_{k}^{(n)}}
84+
$$
85+
86+
$$
87+
\mathcal{L}_{\text {p }}=\mathcal{L}_{c e}+\mathcal{L}_{\text {dice }}
88+
$$
89+
90+
其中N、K分别表示样本数量和类别数量。$y^{(n)}$和$\hat{y}^{(n)}$表示标签的one-hot编码和相应的softmax输出,$\mathrm{n} \in[1, \ldots, \mathrm{n}]$。
91+
92+
为了更好的结合,我们选择交叉熵函数作为辅助损失函数${L}_{a u x}$,并且乘以系数$\alpha$总损失函数其表达式为:
93+
94+
$$
95+
\mathcal{L}=\mathcal{L}_{p}+\alpha \times \mathcal{L}_{a u x}
96+
$$
97+
98+
其中,$\alpha$默认为0.4。
99+
100+
## 3. 模型构建
101+
以下我们讲解释用PaddleScience构建UnetFormer的关键部分。
102+
103+
### 3.1 数据集介绍
104+
105+
数据集采用了[ISPRS](https://www.isprs.org/)开源的[Vaihingen](https://www.isprs.org/resources/datasets/benchmarks/UrbanSemLab/2d-sem-label-vaihingen.aspx)数据集。
106+
107+
ISPRS提供了城市分类和三维建筑重建测试项目的两个最先进的机载图像数据集。该数据集采用了由高分辨率正交照片和相应的密集图像匹配技术产生的数字地表模型(DSM)。这两个数据集区域都涵盖了城市场景。Vaihingen是一个相对较小的村庄,有许多独立的建筑和小的多层建筑,该数据集包含33幅不同大小的遥感图像,每幅图像都是从一个更大的顶层正射影像图片提取的,图像选择的过程避免了出现没有数据的情况。顶层影像和DSM的空间分辨率为9 cm。遥感图像格式为8位TIFF文件,由近红外、红色和绿色3个波段组成。DSM是单波段的TIFF文件,灰度等级(对应于DSM高度)为32位浮点值编码。
108+
109+
![image-vaihingen](https://paddle-org.bj.bcebos.com/paddlescience/docs/unetformer/overview_tiles.jpg)
110+
111+
每个数据集已手动分类为6个最常见的土地覆盖类别。
112+
113+
①不透水面 (RGB: 255, 255, 255)
114+
115+
②建筑物(RGB: 0, 0, 255)
116+
117+
③低矮植被 (RGB: 0, 255, 255)
118+
119+
④树木 (RGB: 0, 255, 0)
120+
121+
⑤汽车(RGB: 255, 255, 0)
122+
123+
⑥背景 (RGB: 255, 0, 0)
124+
125+
背景类包括水体和与其他已定义类别不同的物体(例如容器、网球场、游泳池),这些物体通常属于城市场景中的不感兴趣的语义对象。
126+
127+
### 3.2 构建dataset API
128+
129+
由于一份数据集由33个超大遥感图片组成组成。为了方便训练,我们自定义一个图像分割程序,将原始图片分割为1024×1024大小的可训练图片,程序代码具体信息在GeoSeg/tools/vaihingen_patch_split.py中可以看到。
130+
131+
### 3.3 模型构建
132+
133+
本案例的模型搭建代码如下
134+
135+
136+
137+
参数配置如下:
138+
``` py linenums="12"
139+
--8<--
140+
examples/unetformer/config/vaihingen/unetformer.py:12:36
141+
--8<--
142+
```
143+
144+
### 3.4 loss函数
145+
146+
UNetFormer的损失函数由SoftCrossEntropyLoss交叉熵损失函数和DiceLoss损失函数组成
147+
148+
#### 3.4.1 SoftCrossEntropyLoss
149+
150+
151+
``` py linenums="13"
152+
--8<--
153+
examples/unetformer/geoseg/losses/soft_ce.py:13:43
154+
--8<--
155+
```
156+
157+
#### 3.4.2 DiceLoss
158+
159+
``` py linenums="36"
160+
--8<--
161+
examples/unetformer/geoseg/losses/dice.py:36:145
162+
--8<--
163+
```
164+
165+
#### 3.4.2 JointLoss
166+
SoftCrossEntropyLoss和DiceLoss将使用JointLoss进行组合
167+
168+
``` py linenums="23"
169+
--8<--
170+
examples/unetformer/geoseg/losses/joint_loss.py:23:40
171+
--8<--
172+
```
173+
#### 3.4.2 UNetFormerLoss
174+
``` py linenums="93"
175+
--8<--
176+
examples/unetformer/geoseg/losses/useful_loss.py:93:114
177+
--8<--
178+
```
179+
180+
### 3.5 优化器构建
181+
182+
UNetFormer使用AdamW优化器,可直接调用`paddle.optimizer.AdamW`构建,代码如下:
183+
184+
``` py linenums="65"
185+
--8<--
186+
examples/unetformer/config/vaihingen/unetformer.py:65:76
187+
--8<--
188+
```
189+
190+
### 3.6 模型训练
191+
192+
``` py linenums="236"
193+
--8<--
194+
examples/unetformer/train_supervision.py:236:300
195+
--8<--
196+
```
197+
198+
199+
### 3.7 模型测试
200+
201+
``` py linenums="61"
202+
--8<--
203+
examples/unetformer/vaihingen_test.py:61:121
204+
--8<--
205+
```
206+
207+
## 4. 结果展示
208+
209+
使用[Vaihingen](https://www.isprs.org/resources/datasets/benchmarks/UrbanSemLab/2d-sem-label-vaihingen.aspx)数据集的训练结果。
210+
211+
| F1 | mIOU | OA |
212+
| :----: | :----: | :----: |
213+
| 0.9062 | 0.8318 | 0.9283 |
214+
215+
![image-vaihingen1](https://paddle-org.bj.bcebos.com/paddlescience/docs/unetformer/top_mosaic_09cm_area38_0_6.tif)
216+
217+
![image-vaihingen2](https://paddle-org.bj.bcebos.com/paddlescience/docs/unetformer/result.png)
218+
219+
两张图片对比可以看出模型已经精确地分割出遥感图片中建筑、树木、汽车等物体的轮廓,并且很好地处理了重叠区域。
220+
## 6. 参考文献
221+
222+
- [UNetFormer: A UNet-like Transformer for Efficient Semantic Segmentation of Remote Sensing Urban Scene Imagery](https://arxiv.org/abs/2109.08937)
223+
- [https://github.com/WangLibo1995/GeoSeg](https://github.com/WangLibo1995/GeoSeg)
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import os
2+
3+
import paddle
4+
from geoseg.datasets.vaihingen_dataset import CLASSES
5+
from geoseg.datasets.vaihingen_dataset import VaihingenDataset
6+
from geoseg.datasets.vaihingen_dataset import train_aug
7+
from geoseg.datasets.vaihingen_dataset import val_aug
8+
from geoseg.losses.useful_loss import UnetFormerLoss
9+
from geoseg.models.UNetFormer import UNetFormer
10+
from tools.utils import process_model_params
11+
12+
max_epoch = 105
13+
ignore_index = len(CLASSES)
14+
train_batch_size = 8
15+
val_batch_size = 8
16+
lr = 0.0006
17+
weight_decay = 0.01
18+
backbone_lr = 6e-05
19+
backbone_weight_decay = 0.01
20+
num_classes = len(CLASSES)
21+
classes = CLASSES
22+
weights_name = "unetformer-r18-512-crop-ms-e105"
23+
weights_path = "model_weights/vaihingen/{}".format(weights_name)
24+
test_weights_name = "unetformer-r18-512-crop-ms-e105_epoch0_best"
25+
log_name = "vaihingen/{}".format(weights_name)
26+
monitor = "val_F1"
27+
monitor_mode = "max"
28+
save_top_k = 1
29+
save_last = True
30+
check_val_every_n_epoch = 1
31+
pretrained_ckpt_path = None
32+
gpus = "auto"
33+
resume_ckpt_path = None
34+
net = UNetFormer(num_classes=num_classes)
35+
loss = UnetFormerLoss(ignore_index=ignore_index)
36+
use_aux_loss = True
37+
os.makedirs("data/vaihingen/train/images_1024", exist_ok=True)
38+
os.makedirs("data/vaihingen/train/masks_1024", exist_ok=True)
39+
if len(os.listdir("data/vaihingen/train/images_1024")) == 0:
40+
pass
41+
else:
42+
train_dataset = VaihingenDataset(
43+
data_root="data/vaihingen/train",
44+
mode="train",
45+
mosaic_ratio=0.25,
46+
transform=train_aug,
47+
)
48+
train_loader = paddle.io.DataLoader(
49+
dataset=train_dataset,
50+
batch_size=train_batch_size,
51+
num_workers=4,
52+
shuffle=True,
53+
drop_last=True,
54+
)
55+
val_dataset = VaihingenDataset(transform=val_aug)
56+
test_dataset = VaihingenDataset(data_root="data/vaihingen/test", transform=val_aug)
57+
58+
val_loader = paddle.io.DataLoader(
59+
dataset=val_dataset,
60+
batch_size=val_batch_size,
61+
num_workers=4,
62+
shuffle=False,
63+
drop_last=False,
64+
)
65+
layerwise_params = {
66+
"backbone.*": dict(lr=backbone_lr, weight_decay=backbone_weight_decay)
67+
}
68+
net_params = process_model_params(net, layerwise_params=layerwise_params)
69+
optimizer = paddle.optimizer.AdamW(
70+
parameters=net_params, learning_rate=lr, weight_decay=weight_decay
71+
)
72+
tmp_lr = paddle.optimizer.lr.CosineAnnealingWarmRestarts(
73+
T_0=15, T_mult=2, learning_rate=optimizer.get_lr()
74+
)
75+
optimizer.set_lr_scheduler(tmp_lr)
76+
lr_scheduler = tmp_lr

examples/unetformer/geoseg/__init__.py

Whitespace-only changes.

examples/unetformer/geoseg/datasets/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)