# Copyright 2022 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=line-too-long
r"""Distilling BiT-R152x2 into BiT-R50x1 on Flowers/Pet as in https://arxiv.org/abs/2106.05237

While many epochs are required, this is a small dataset, so overall training
is still fast and can be run on relatively small hardware such as a v3-8 TPU (or GPUs).

This configuration contains the recommended settings from Fig. 3/Tab. 4 of the
paper, which can be selected via the fast/medium/long config argument.
(The best settings were selected on a 10% minival split.)

For Flowers:
- The `fast` variant takes ~1h10m on a v2-8 TPU.
  Example logs at gs://big_vision/distill/bit_flowers_fast_06-18_2008/big_vision_metrics.txt
- The `long` variant takes ~25h on a v3-32 TPU.
  Example logs at gs://big_vision/distill/bit_flowers_long_06-19_0524/big_vision_metrics.txt
For Pet:
- The `fast` variant takes ~28min on a v2-8 TPU.
  Example logs at gs://big_vision/distill/bit_pet_fast_06-16_2338/big_vision_metrics.txt
- The `long` variant takes ~11h on a v2-8 and ~8h on a v3-32.
  Example logs at gs://big_vision/distill/bit_pet_long_06-17_0050/big_vision_metrics.txt

big_vision.trainers.proj.distill.distill \
    --config big_vision/configs/proj/distill/bigsweep_flowers_pet.py:data=flowers,variant=fast \
    --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'`
"""

import big_vision.configs.common as bvcc
import big_vision.configs.proj.distill.common as cd
import ml_collections as mlc

NCLS = dict(flowers=102, pet=37)


def get_config(arg=None):
  """Config for massive hypothesis-test on pet."""
  arg = bvcc.parse_arg(arg, runlocal=False, data='flowers', variant='medium', crop='inception_crop(128)')
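  # The part after the colon in `--config .../bigsweep_flowers_pet.py:data=flowers,variant=fast`
  # is parsed into `arg`; keys not given on the command line keep the defaults above
  # (e.g. passing only `data=pet` keeps variant='medium' and crop='inception_crop(128)').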
  config = mlc.ConfigDict()

  config.dataset = dict(flowers='oxford_flowers102', pet='oxford_iiit_pet')[arg.data]
  config.cache_raw = True
  config.prefetch_to_device = 4
  config.train_split = dict(flowers='train', pet='train[:90%]')[arg.data]
  config.num_classes = NCLS[arg.data]

  config.batch_size = 512
  config.num_epochs = {
      'flowers': {'fast': 10_000, 'medium': 100_000, 'long': 1_000_000},
      'pet': {'fast': 1000, 'medium': 3000, 'long': 30_000},
  }[arg.data][arg.variant]
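  # Note that these are epochs, not steps: the train splits are tiny (on the
  # order of 1k-3k images), so a single "epoch" is only a handful of 512-sized
  # batches, which is why such large epoch counts are used for the longer variants.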
  config.shuffle_buffer_size = 50_000

  config.log_training_steps = 100
  config.checkpoint_steps = 2500

  # Model section
  config.student_name = 'bit_paper'
  config.student = dict(depth=50, width=1)

  config.teachers = ['prof_m']
  config.prof_m_name = 'bit_paper'
  config.prof_m_init = cd.inits[f'BiT-M R152x2 {arg.data} rc128']
  config.prof_m = dict(depth=152, width=2)

  # Preprocessing pipeline for student & teacher.
  pp_common = (
      '|value_range(-1, 1)'
      f'|onehot({config.num_classes}, key="label", key_result="labels")'
      '|keep("image", "labels")'
  )
  config.pp_train = f'decode|{arg.crop}|flip_lr' + pp_common
  ppv = 'decode|resize_small(160)|central_crop(128)' + pp_common
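  # With the default arguments (data=flowers, crop='inception_crop(128)'), pp_train
  # expands to:
  #   decode|inception_crop(128)|flip_lr|value_range(-1, 1)
  #     |onehot(102, key="label", key_result="labels")|keep("image", "labels")
  # while ppv (used by the evaluators below) replaces the random crop/flip with a
  # deterministic resize_small(160)|central_crop(128).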

  config.mixup = dict(p=1.0, n=2)
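  # Rough sketch of the effect (an assumption, not the trainer's exact code): with
  # n=2 and p=1.0, two examples are mixed with a coefficient lam drawn roughly
  # uniformly (Beta(1, 1)), i.e.
  #   image  = lam * image_a  + (1 - lam) * image_b
  #   labels = lam * labels_a + (1 - lam) * labels_b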

  # Distillation settings
  config.distance = 'kl'
  config.distance_kw = dict(t={
      'flowers': {'fast': 10., 'medium': 1., 'long': 1.},
      'pet': {'fast': 5., 'medium': 10., 'long': 2.},
  }[arg.data][arg.variant])
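  # For intuition (a sketch, not the trainer's implementation): with distance='kl'
  # and temperature t, the per-example distillation loss is roughly
  #   KL(softmax(teacher_logits / t) || softmax(student_logits / t)),
  # so the student matches the teacher's softened predictions; larger t softens
  # both distributions.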

  # Optimizer section
  config.grad_clip_norm = 1.0
  config.optax_name = 'scale_by_adam'
  config.optax = dict(mu_dtype='bfloat16')

  config.lr = {
      'flowers': {'fast': 0.003, 'medium': 0.001, 'long': 0.0003},
      'pet': {'fast': 0.01, 'medium': 0.003, 'long': 0.003},
  }[arg.data][arg.variant]
  config.wd = {
      'flowers': {'fast': 3e-4, 'medium': 1e-4, 'long': 1e-5},
      'pet': {'fast': 1e-3, 'medium': 3e-4, 'long': 1e-5},
  }[arg.data][arg.variant]
  config.schedule = dict(warmup_steps=1500, decay_type='cosine')
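  # Schedule sketch (assuming the usual warmup+cosine behaviour): the learning
  # rate ramps up linearly from 0 to config.lr over the first 1500 steps, then
  # follows a cosine decay towards ~0 over the remaining training steps.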
  config.optim_name = 'adam_hp'

  # Eval section
  minitrain_split = 'train[:512]' if not arg.runlocal else 'train[:16]'
  if arg.data == 'flowers':
    val_split = 'validation' if not arg.runlocal else 'validation[:16]'
    test_split = 'test' if not arg.runlocal else 'test[:16]'
  elif arg.data == 'pet':
    val_split = 'train[90%:]' if not arg.runlocal else 'train[:16]'
    test_split = 'test' if not arg.runlocal else 'test[:16]'

  base = dict(
      type='classification',
      pred='student_fwd',
      dataset=config.dataset,
      pp_fn=ppv,
      loss_name='softmax_xent',
      log_steps=500,
  )
  config.evals = {}
  config.evals.student_train = {**base, 'split': minitrain_split}
  config.evals.student_val = {**base, 'split': val_split}
  config.evals.student_test = {**base, 'split': test_split}

  # Teacher is fixed, so rare evals.
  teacher = dict(log_steps=100_000, pred='prof_m_fwd')
  config.evals.teacher_train = {**config.evals.student_train, **teacher}
  config.evals.teacher_val = {**config.evals.student_val, **teacher}
  config.evals.teacher_test = {**config.evals.student_test, **teacher}

  # Could in principle also look at agreement on other datasets!
  dist = dict(
      type='proj.distill.distance',
      pred='student_prof_m_fwd',
      dataset=config.dataset,
      pp_fn=ppv + '|keep("image")',
      log_steps=1000,
      distances=({'kind': 'kl'}, {'kind': 'euclidean'},
                 {'kind': 'agree', 'k': 1}, {'kind': 'agree', 'k': 5}),
  )
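  # The 'agree' entries report top-k agreement between student and teacher
  # predictions (k=1 and k=5), alongside the KL and Euclidean distances between
  # their outputs.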
  config.evals.dist_train = {**dist, 'split': minitrain_split}
  config.evals.dist_val = {**dist, 'split': val_split}
  config.evals.dist_test = {**dist, 'split': test_split}

  # Make a few things much smaller for quick local debugging testruns.
  if arg.runlocal:
    config.shuffle_buffer_size = 10
    config.batch_size = 8

  return config