
Commit 2f3f493

lucasb-eyer, xiaohuazhai, and akolesnikoff committed

Release distillation and scaling ViT projects.

And a bunch of small fixes and improvements we made over time.

Co-authored-by: Xiaohua Zhai <[email protected]>
Co-authored-by: Alexander Kolesnikov <[email protected]>

1 parent e9fb55d · commit 2f3f493

File tree: 18 files changed, +1656 −24 lines changed


README.md

Lines changed: 27 additions & 6 deletions
@@ -33,14 +33,20 @@ codebase:
   Xiaohua Zhai*, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer,
   Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, and Neil Houlsby*
 - [Scaling Vision Transformers](https://arxiv.org/abs/2106.04560), by
-  Xiaohua Zhai*, Alexander Kolesnikov*, Neil Houlsby, and Lucas Beyer*
+  Xiaohua Zhai*, Alexander Kolesnikov*, Neil Houlsby, and Lucas Beyer*\
+  Resources: [config](configs/proj/scaling_laws/train_vit_g.py).
 - [How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers](https://arxiv.org/abs/2106.10270), by
   Andreas Steiner*, Alexander Kolesnikov*, Xiaohua Zhai*, Ross Wightman,
   Jakob Uszkoreit, and Lucas Beyer*
 - [MLP-Mixer: An all-MLP Architecture for Vision](https://arxiv.org/abs/2105.01601), by
   Ilya Tolstikhin*, Neil Houlsby*, Alexander Kolesnikov*, Lucas Beyer*,
   Xiaohua Zhai, Thomas Unterthiner, Jessica Yung, Andreas Steiner,
   Daniel Keysers, Jakob Uszkoreit, Mario Lucic, Alexey Dosovitskiy
+- [Better plain ViT baselines for ImageNet-1k](https://arxiv.org/abs/2205.01580), by
+  Lucas Beyer, Xiaohua Zhai, Alexander Kolesnikov\
+  Resources: [config](big_vision/configs/vit_s16_i1k.py)
+- [UViM: A Unified Modeling Approach for Vision with Learned Guiding Codes](https://arxiv.org/abs/2205.10337), by
+  Alexander Kolesnikov*, André Susano Pinto*, Lucas Beyer*, Xiaohua Zhai*, Jeremiah Harmsen*, Neil Houlsby*

 ### Multimodal research
 - [LiT: Zero-Shot Transfer with Locked-image Text Tuning](https://arxiv.org/abs/2111.07991), by
@@ -50,7 +56,8 @@ codebase:
 ### Knowledge distillation
 - [Knowledge distillation: A good teacher is patient and consistent](https://arxiv.org/abs/2106.05237), by
   Lucas Beyer*, Xiaohua Zhai*, Amélie Royer*, Larisa Markeeva*, Rohan Anil,
-  and Alexander Kolesnikov*
+  and Alexander Kolesnikov*\
+  Resources: [README](big_vision/configs/proj/distill/README.md), [trainer](big_vision/trainers/proj/distill/distill.py), [colab](https://colab.research.google.com/drive/1nMykzUzsfQ_uAxfj3k35DYsATnG_knPl?usp=sharing).

 ### Misc
 - [Are we done with ImageNet?](https://arxiv.org/abs/2006.07159), by
@@ -90,18 +97,26 @@ gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all

 See instructions below for more details on how to use Google Cloud TPUs.

+All runs write checkpoints and logfiles. The logfiles are a list of JSON
+objects, and we provide a short and straightforward [example colab to read
+and display the logs and checkpoints](https://colab.research.google.com/drive/1R_lvV542WUp8Q2y8sbyooZOGCplkn7KI?usp=sharing).
+
 # Current and future contents

 The first release contains the core part of pre-training, transferring, and
 evaluating classification models at scale on Cloud TPU VMs.

+We have since added the following key features and projects:
+- Patient and consistent distillation.
+- Scaling ViT.
+
 Features and projects we plan to release in the near future, in no particular
 order:
 - ImageNet-21k in TFDS.
 - MLP-Mixer.
 - Loading misc public models used in our publications (NFNet, MoCov3, DINO).
 - Contrastive Image-Text model training and evaluation as in LiT and CLIP.
-- "Patient and consistent" distillation.
+- UViM.
 - Memory-efficient Polyak-averaging implementation.
 - Advanced JAX compute and memory profiling. We are using internal tools for
   this, but may eventually add support for the publicly available ones.
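The hunk above notes that all runs write checkpoints and JSON logfiles. A minimal sketch for reading such a logfile without the colab, assuming one JSON object per line (the filename matches the `big_vision_metrics.txt` examples later in this commit, but each record's schema is an assumption):

```python
import json

def read_log(path):
  """Parse a big_vision logfile, assumed to hold one JSON object per line."""
  with open(path) as f:
    return [json.loads(line) for line in f if line.strip()]

# Example: a locally downloaded copy of a run's metrics file.
records = read_log("big_vision_metrics.txt")
print(f"{len(records)} records; first record keys: {sorted(records[0])}")
```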
@@ -154,7 +169,7 @@ dependencies.

 ```
 git clone --branch=master https://github.com/google-research/big_vision
-gcloud alpha compute tpus tpu-vm scp --recurse big_vision/big_vision $NAME: --worker=all --zone=$ZONE
+gcloud alpha compute tpus tpu-vm scp --recurse big_vision/big_vision $NAME: --zone=$ZONE --worker=all
 gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "bash big_vision/run_tpu.sh"
 ```

@@ -165,8 +180,9 @@ also do it on your local machine and copy the result to the cloud bucket. For
 convenience, we provide instructions on how to prepare data using Cloud TPUs.

 Download and prepare TFDS datasets using a single worker. Seven TFDS datasets
-used during evaluations will be generated under `~/tensorflow_datasets/` (should
-take 10-15 minutes in total).
+used during evaluations will be generated under `~/tensorflow_datasets/` (by
+default; this can be overridden by the TFDS_DATA_DIR env variable). This should
+take 10-15 minutes in total.

 ```
 gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=0 --command "bash big_vision/run_tpu.sh big_vision.tools.download_tfds_datasets cifar10 cifar100 oxford_iiit_pet oxford_flowers102 cars196 dtd uc_merced"
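A note on the TFDS_DATA_DIR override added in this hunk: the same redirection works from Python, either via the environment variable or via the `data_dir` argument of `tfds.load` (the explicit argument takes precedence). A small sketch, with a placeholder bucket path:

```python
import os

# TFDS falls back to ~/tensorflow_datasets/ when neither of these is set.
os.environ["TFDS_DATA_DIR"] = "gs://my-bucket/tensorflow_datasets"  # placeholder

import tensorflow_datasets as tfds

# data_dir can also be passed explicitly per call.
ds = tfds.load("oxford_iiit_pet", split="train",
               data_dir=os.environ["TFDS_DATA_DIR"])
```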
@@ -206,6 +222,11 @@ run the following command line.
 gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.train --config big_vision/configs/bit_i1k.py --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'`"
 ```

+## Sometimes useful gcloud commands
+
+- Destroy the TPU machines: `gcloud alpha compute tpus tpu-vm delete $NAME --zone $ZONE`
+- Remove all big_vision-related folders on all hosts: `gcloud alpha compute tpus tpu-vm ssh $NAME --zone $ZONE --worker=all --command 'rm -rf ~/big_vision ~/bv_venv'`
+
 # ViT baseline

 We provide a well-tuned ViT-S/16 baseline in the config file named

big_vision/configs/load_and_eval.py

Lines changed: 22 additions & 0 deletions
@@ -114,6 +114,28 @@ def vit_i1k(config):
   )


+def mlp_mixer_i1k(config):
+  # We could omit init_{shapes,types} if we wanted, as they are the default.
+  config.init_shapes = [(1, 224, 224, 3)]
+  config.init_types = ['float32']
+  config.num_classes = 1000
+
+  config.model_name = 'mlp_mixer'
+  config.model_init = ''  # Will be set in sweep.
+  config.model = dict(variant='L/16')
+
+  config.evals = {}
+  config.evals.fewshot = get_fewshot_lsr()
+  config.evals.val = dict(
+      type='classification',
+      dataset='imagenet2012',
+      split='validation',
+      pp_fn='decode|resize_small(256)|central_crop(224)|value_range(-1, 1)|onehot(1000, key="label", key_result="labels")|keep("image", "labels")',
+      loss_name='softmax_xent',
+      cache_final=False,  # Only run once, on low-mem machine.
+  )
+
+
 def vit_i21k(config):
   # We could omit init_{shapes,types} if we wanted, as they are the default.
   config.init_shapes = [(1, 224, 224, 3)]
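The `mlp_mixer_i1k` helper added above follows the codebase's pattern of mutating an `ml_collections.ConfigDict` in place, relying on the fact that assigning a plain dict to an attribute converts it into a nested ConfigDict. A self-contained sketch of just that pattern (the field names here are illustrative, not from the repo):

```python
import ml_collections as mlc

def toy_model_config(config):
  # Plain dicts assigned to attributes become nested ConfigDicts,
  # so later attribute writes like config.evals.val work.
  config.num_classes = 1000
  config.model = dict(variant='L/16')
  config.evals = {}
  config.evals.val = dict(split='validation')

config = mlc.ConfigDict()
toy_model_config(config)
print(config.model.variant)    # -> L/16
print(config.evals.val.split)  # -> validation
```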
big_vision/configs/proj/distill/README.md

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+# Knowledge distillation: A good teacher is patient and consistent
+*by Lucas Beyer, Xiaohua Zhai, Amélie Royer, Larisa Markeeva, Rohan Anil, Alexander Kolesnikov*
+
+## Introduction
+We publish all teacher models and configurations for the main experiments of
+the paper, as well as training logs and student models.
+
+Please read the main [big_vision README](/README.md) to learn how to run
+configs, and remember that each config file contains an example invocation in
+the top-level comment.
+
+## Results
+
+We provide the following [colab to read and plot the logfiles](https://colab.research.google.com/drive/1nMykzUzsfQ_uAxfj3k35DYsATnG_knPl?usp=sharing)
+of a few runs that we reproduced on Cloud.
+
+### ImageNet-1k
+
+The file [bit_i1k.py](bit_i1k.py) is the configuration which reproduces our
+distillation runs on ImageNet-1k reported in Figures 1 and 5 (left) and the
+first row of Table 1.
+
+We release both student and teacher models:
+
+| Model | Download link | Resolution | ImageNet top-1 acc. (paper) |
+| :--- | :---: | :---: | :---: |
+| BiT-R50x1 | [link](https://storage.googleapis.com/bit_models/distill/R50x1_160.npz) | 160 | 80.5 |
+| BiT-R50x1 | [link](https://storage.googleapis.com/bit_models/distill/R50x1_224.npz) | 224 | 82.8 |
+| BiT-R152x2 | [link](https://storage.googleapis.com/bit_models/distill/R152x2_T_224.npz) | 224 | 83.0 |
+| BiT-R152x2 | [link](https://storage.googleapis.com/bit_models/distill/R152x2_T_384.npz) | 384 | 84.3 |
+
+### Flowers/Pet/Food/Sun
+
+The files [bigsweep_flowers_pet.py](bigsweep_flowers_pet.py) and
+[bigsweep_food_sun.py](bigsweep_food_sun.py) can be used to reproduce the
+distillation runs on these datasets, shown in Figures 3, 4, 9-12, and Table 4.
+
+While our open-source release does not currently support doing hyper-parameter
+sweeps, we still provide an example of the sweeps at the end of the configs
+for reference.
+
+### Teacher models
+Links to all teacher models we used can be found in [common.py](common.py).
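The checkpoints in the table above are plain NumPy `.npz` archives, so they can be inspected without any big_vision code. A minimal loading sketch (the internal parameter names are whatever the repo saved; inspect `params.files` rather than assuming them):

```python
import io
import urllib.request

import numpy as np

url = "https://storage.googleapis.com/bit_models/distill/R50x1_160.npz"
with urllib.request.urlopen(url) as resp:
  buf = io.BytesIO(resp.read())

params = np.load(buf)  # NpzFile: behaves like a dict of named arrays
print(len(params.files), "arrays; first few:", sorted(params.files)[:5])
```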
big_vision/configs/proj/distill/bigsweep_flowers_pet.py

Lines changed: 159 additions & 0 deletions
@@ -0,0 +1,159 @@
+# Copyright 2022 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""Distilling BiT-R152x2 into BiT-R50x1 on Flowers/Pet as in https://arxiv.org/abs/2106.05237
+
+While many epochs are required, this is a small dataset, and thus overall it
+is still fast and possible to run on the relatively small v3-8 TPUs (or GPUs).
+
+This configuration contains the recommended settings from Fig 3/Tab 4 of the
+paper, which can be selected via the fast/medium/long config argument.
+(The best settings were selected on a 10% minival.)
+
+For Flowers:
+- The `fast` variant takes ~1h10m on a v2-8 TPU.
+  Example logs at gs://big_vision/distill/bit_flowers_fast_06-18_2008/big_vision_metrics.txt
+- The `long` variant takes ~25h on a v3-32 TPU.
+  Example logs at gs://big_vision/distill/bit_flowers_long_06-19_0524/big_vision_metrics.txt
+For Pet:
+- The `fast` variant takes ~28min on a v2-8 TPU.
+  Example logs at gs://big_vision/distill/bit_pet_fast_06-16_2338/big_vision_metrics.txt
+- The `long` variant takes ~11h on a v2-8 and ~8h on a v3-32.
+  Example logs at gs://big_vision/distill/bit_pet_long_06-17_0050/big_vision_metrics.txt
+
+big_vision.trainers.proj.distill.distill \
+    --config big_vision/configs/proj/distill/bigsweep_flowers_pet.py:data=flowers,variant=fast \
+    --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \
+"""
+
+import big_vision.configs.common as bvcc
+import big_vision.configs.proj.distill.common as cd
+import ml_collections as mlc
+
+NCLS = dict(flowers=102, pet=37)
+
+
+def get_config(arg=None):
+  """Config for massive hypothesis-test on pet."""
+  arg = bvcc.parse_arg(arg, runlocal=False, data='flowers', variant='medium', crop='inception_crop(128)')
+  config = mlc.ConfigDict()
+
+  config.dataset = dict(flowers='oxford_flowers102', pet='oxford_iiit_pet')[arg.data]
+  config.cache_raw = True
+  config.prefetch_to_device = 4
+  config.train_split = dict(flowers='train', pet='train[:90%]')[arg.data]
+  config.num_classes = NCLS[arg.data]
+
+  config.batch_size = 512
+  config.num_epochs = {
+      'flowers': {'fast': 10_000, 'medium': 100_000, 'long': 1_000_000},
+      'pet': {'fast': 1000, 'medium': 3000, 'long': 30_000},
+  }[arg.data][arg.variant]
+  config.shuffle_buffer_size = 50_000
+
+  config.log_training_steps = 100
+  config.checkpoint_steps = 2500
+
+  # Model section
+  config.student_name = 'bit_paper'
+  config.student = dict(depth=50, width=1)
+
+  config.teachers = ['prof_m']
+  config.prof_m_name = 'bit_paper'
+  config.prof_m_init = cd.inits[f'BiT-M R152x2 {arg.data} rc128']
+  config.prof_m = dict(depth=152, width=2)
+
+  # Preprocessing pipeline for student & teacher.
+  pp_common = (
+      '|value_range(-1, 1)'
+      f'|onehot({config.num_classes}, key="label", key_result="labels")'
+      '|keep("image", "labels")'
+  )
+  config.pp_train = f'decode|{arg.crop}|flip_lr' + pp_common
+  ppv = 'decode|resize_small(160)|central_crop(128)' + pp_common
+
+  config.mixup = dict(p=1.0, n=2)
+
+  # Distillation settings
+  config.distance = 'kl'
+  config.distance_kw = dict(t={
+      'flowers': {'fast': 10., 'medium': 1., 'long': 1.},
+      'pet': {'fast': 5., 'medium': 10., 'long': 2.},
+  }[arg.data][arg.variant])
+
+  # Optimizer section
+  config.grad_clip_norm = 1.0
+  config.optax_name = 'scale_by_adam'
+  config.optax = dict(mu_dtype='bfloat16')
+
+  config.lr = {
+      'flowers': {'fast': 0.003, 'medium': 0.001, 'long': 0.0003},
+      'pet': {'fast': 0.01, 'medium': 0.003, 'long': 0.003},
+  }[arg.data][arg.variant]
+  config.wd = {
+      'flowers': {'fast': 3e-4, 'medium': 1e-4, 'long': 1e-5},
+      'pet': {'fast': 1e-3, 'medium': 3e-4, 'long': 1e-5},
+  }[arg.data][arg.variant]
+  config.schedule = dict(warmup_steps=1500, decay_type='cosine')
+  config.optim_name = 'adam_hp'
+
+  # Eval section
+  minitrain_split = 'train[:512]' if not arg.runlocal else 'train[:16]'
+  if arg.data == 'flowers':
+    val_split = 'validation' if not arg.runlocal else 'validation[:16]'
+    test_split = 'test' if not arg.runlocal else 'test[:16]'
+  elif arg.data == 'pet':
+    val_split = 'train[90%:]' if not arg.runlocal else 'train[:16]'
+    test_split = 'test' if not arg.runlocal else 'test[:16]'
+
+  base = dict(
+      type='classification',
+      pred='student_fwd',
+      dataset=config.dataset,
+      pp_fn=ppv,
+      loss_name='softmax_xent',
+      log_steps=500,
+  )
+  config.evals = {}
+  config.evals.student_train = {**base, 'split': minitrain_split}
+  config.evals.student_val = {**base, 'split': val_split}
+  config.evals.student_test = {**base, 'split': test_split}
+
+  # Teacher is fixed, so rare evals.
+  teacher = dict(log_steps=100_000, pred='prof_m_fwd')
+  config.evals.teacher_train = {**config.evals.student_train, **teacher}
+  config.evals.teacher_val = {**config.evals.student_val, **teacher}
+  config.evals.teacher_test = {**config.evals.student_test, **teacher}
+
+  # Could in principle also look at agreement on other datasets!
+  dist = dict(
+      type='proj.distill.distance',
+      pred='student_prof_m_fwd',
+      dataset=config.dataset,
+      pp_fn=ppv + '|keep("image")',
+      log_steps=1000,
+      distances=({'kind': 'kl'}, {'kind': 'euclidean'},
+                 {'kind': 'agree', 'k': 1}, {'kind': 'agree', 'k': 5}),
+  )
+  config.evals.dist_train = {**dist, 'split': minitrain_split}
+  config.evals.dist_val = {**dist, 'split': val_split}
+  config.evals.dist_test = {**dist, 'split': test_split}
+
+  # Make a few things much smaller for quick local debugging testruns.
+  if arg.runlocal:
+    config.shuffle_buffer_size = 10
+    config.batch_size = 8
+
+  return config
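The `:data=flowers,variant=fast` suffix in the docstring above is parsed by `bvcc.parse_arg` against the defaults given in `get_config`. The repo's implementation lives in big_vision/configs/common.py; the following is only a simplified stand-in sketching the behavior, not that code:

```python
from types import SimpleNamespace

def parse_arg_sketch(spec, **defaults):
  """Parse 'k=v,k=v' overrides against typed defaults (simplified stand-in)."""
  out = dict(defaults)
  for kv in filter(None, (spec or '').split(',')):
    k, v = kv.split('=', 1)
    # Coerce to the default's type; booleans need special handling.
    out[k] = (v == 'True') if isinstance(defaults[k], bool) else type(defaults[k])(v)
  return SimpleNamespace(**out)

arg = parse_arg_sketch('data=flowers,variant=fast',
                       runlocal=False, data='flowers', variant='medium',
                       crop='inception_crop(128)')
print(arg.data, arg.variant)  # -> flowers fast
```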
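The `config.distance = 'kl'` and `config.distance_kw = dict(t=...)` fields select a temperature-scaled KL divergence between teacher and student predictions; the actual loss is implemented in the trainer, big_vision/trainers/proj/distill/distill.py. A generic, self-contained sketch of that kind of loss (not a copy of the trainer's code):

```python
import jax
import jax.numpy as jnp

def kl_distill_loss(student_logits, teacher_logits, t=1.0):
  """Mean KL(teacher || student) over the batch, at temperature t."""
  teacher_logp = jax.nn.log_softmax(teacher_logits / t, axis=-1)
  student_logp = jax.nn.log_softmax(student_logits / t, axis=-1)
  per_example = jnp.sum(jnp.exp(teacher_logp) * (teacher_logp - student_logp),
                        axis=-1)
  return jnp.mean(per_example)

# Toy usage: batch of 8, 37 classes (Pet's class count from NCLS above).
ks, kt = jax.random.split(jax.random.PRNGKey(0))
print(kl_distill_loss(jax.random.normal(ks, (8, 37)),
                      jax.random.normal(kt, (8, 37)), t=5.0))
```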
