# Copyright 2022 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=line-too-long
r"""A config for training a UViM stage II model for the panoptic task.

This config is expected to reproduce the paper's result and achieve
approximately 43.7 PQ points on the COCO holdout data.

We also provide a low-resource variant of this config, which can be enabled
by adding the `:singlehost` postfix to the config name. This one is expected
to achieve 39.4 PQ points on the COCO holdout data.
"""

import big_vision.configs.common as bvcc
from ml_collections import ConfigDict

VTT_MODELS = {
    'base': dict(num_layers=12, num_heads=12, mlp_dim=3072, emb_dim=768),
    'large': dict(num_layers=24, num_heads=16, mlp_dim=4096, emb_dim=1024),
}

VQVAE_MODELS = {
    'base': dict(enc_depth=6, dec_depth=12, num_heads=12, mlp_dim=3072, width=768),
}
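
# The 'base' and 'large' shapes above match standard ViT-B and ViT-L
# transformer dimensions; the VQ-VAE 'base' uses the same width and heads,
# with a 6-layer encoder and a 12-layer decoder.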

RES = 512
PATCH_SIZE = 16
LABEL_RES = 512
LABEL_PATCH_SIZE = 16
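
# With RES=512 and PATCH_SIZE=16, the stage II encoder operates on a
# (512/16)**2 = 1024-token patch grid, while the oracle compresses the
# LABEL_RES panoptic map into a much shorter discrete code (code_len below).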


def get_config(arg=''):
  """Config for training."""
  arg = bvcc.parse_arg(arg, runlocal=False, singlehost=False)
  config = ConfigDict()
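
  # Preprocessing: decode a COCO panoptic example, merge the semantic and
  # instance maps into a single "labels" tensor, apply the same random flip
  # and Inception-style crop to image and labels, and normalize to [-1, 1].
  # "image_ctx" is the image resized to LABEL_RES; presumably it feeds the
  # oracle's encoder/decoder context (with_encoder_ctx below).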
  config.pp_train = (
      f'decode|coco_panoptic|concat(["semantics","instances"], "labels")|'
      f'randu("fliplr")|det_fliplr(key="image")|det_fliplr(key="labels")|'
      f'inception_box|crop_box(key="image")|crop_box(key="labels")|'
      f'resize({LABEL_RES}, inkey="image", outkey="image_ctx")|'
      f'resize({RES})|resize({LABEL_RES},key="labels",method="nearest")|'
      f'value_range(-1, 1, key="image_ctx")|'
      f'value_range(-1, 1)|make_canonical|keep("image","image_ctx","labels")'
  )
  pp_eval = (
      f'decode|coco_panoptic|concat(["semantics","instances"], "labels")|'
      f'resize({LABEL_RES}, inkey="image", outkey="image_ctx")|'
      f'resize({RES})|resize({LABEL_RES},key="labels",method="nearest")|'
      f'value_range(-1, 1, key="image_ctx")|'
      f'value_range(-1, 1)|make_canonical|keep("image","image_ctx","labels")'
  )
  pp_predict = (
      f'resize({LABEL_RES}, inkey="image", outkey="image_ctx")|resize({RES})|'
      f'value_range(-1, 1, key="image_ctx")|value_range(-1, 1)|'
      f'keep("image","image_ctx","image/id")'  # image/id used for rng seeds.
  )

  config.dataset = 'coco/2017_panoptic'
  config.train_split = 'train[4096:]'
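  # The first 4096 training examples are never trained on; they serve as the
  # holdout split used by the evaluators below.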

  config.batch_size = 512
  config.total_epochs = 200

  config.log_training_steps = 50
  config.shuffle_buffer_size = 50_000
  config.ckpt_steps = 1000
  config.keep_ckpt_steps = 5000
  config.ckpt_timeout = 1
  config.prefetch_to_device = 2
  config.trial = 0

  # Optimizer section
  config.optax_name = 'big_vision.scale_by_adafactor'
  config.optax = dict(beta2_cap=0.95)

  config.lr = 0.001
  config.wd = 0.000001
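  # The encoder is initialized from a pretrained checkpoint (model_init
  # below), so it gets a 10x smaller learning rate than the freshly
  # initialized decoder.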
  config.lr_mults = [
      ('pos_embedding_encoder.*', 0.1),
      ('EmbedPatches.*', 0.1),
      ('encoder.*', 0.1),
      ('decoder.*', 1.0),
  ]
  config.schedule = dict(decay_type='cosine', warmup_steps=4_000)

  # Oracle section
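  # The oracle is the frozen stage I model: a VQ-VAE that encodes a panoptic
  # map into a short sequence of discrete codes. Stage II (this config) trains
  # a model to predict those codes from the image.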
  config.oracle = ConfigDict()
  config.oracle.task = 'proj.uvim.panoptic_task'
  config.oracle.model_init = 'gs://big_vision/uvim/panoptic_stageI_params.npz'
  config.oracle.model_name = 'proj.uvim.vit'
  config.oracle.model = ConfigDict(VQVAE_MODELS['base'])
  config.oracle.model.input_size = (LABEL_RES, LABEL_RES)
  config.oracle.model.patch_size = (LABEL_PATCH_SIZE, LABEL_PATCH_SIZE)
  config.oracle.model.code_len = 256
  config.oracle.model.dict_size = 4096
  config.oracle.model.codeword_dim = 768
  config.oracle.model.with_encoder_ctx = True
  config.oracle.model.with_decoder_ctx = True
  config.oracle.model.code_dropout = 'random'
  config.oracle.model.bottleneck_resize = True
  config.oracle.model.inputs = {
      'semantics': (133 + 1, LABEL_PATCH_SIZE**2),  # +1 for void label
      'instances': (100, LABEL_PATCH_SIZE**2),  # COCO: actually 98 train/78 validation.
  }
  config.oracle.model.outputs = config.oracle.model.inputs

  # Model section
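  # The stage II model (VTT) is a ViT encoder plus an autoregressive decoder
  # over the oracle's code vocabulary. Its output vocabulary is dict_size + 1;
  # the extra token is presumably the decoder's shifted-input/BOS token.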
  config.model_name = 'proj.uvim.vtt'
  # config.model_init = {'encoder': 'howto-i21k-B/8'}
  config.model_init = {'encoder': 'howto-i21k-L/16'}
  config.model = ConfigDict(VTT_MODELS['large'])
  config.model.patches = ConfigDict({'size': (PATCH_SIZE, PATCH_SIZE)})
  config.model.vocab_size = config.oracle.model.get_ref('dict_size') + 1
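  # get_ref returns a lazy FieldReference, so vocab_size and seq_len track any
  # later override of the oracle's dict_size/code_len automatically.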
  config.model.posemb_type = 'learn'
  config.model.input_size = (RES, RES)
  config.model.seq_len = config.oracle.model.get_ref('code_len')

  # Evaluation section
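  # A cheap "val" evaluator runs every 1000 steps on the holdout slice; the
  # expensive panoptic-quality evaluators below run every 10k steps.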
  config.evals = {}
  config.evals.val = ConfigDict()
  config.evals.val.type = 'proj.uvim.compute_mean'
  config.evals.val.pred = 'validation'
  config.evals.val.dataset = config.dataset
  config.evals.val.split = 'train[:4096]'
  config.evals.val.pp_fn = pp_eval
  config.evals.val.log_steps = 1000

  base = {
      'type': 'proj.uvim.coco_panoptic',
      'pp_fn': pp_predict,
      'log_steps': 10_000,
      # Filters objects that occupy less than 0.03^2 fraction of all pixels.
      # 'predict_kwargs': {'min_fraction': 0.03 ** 2},
  }
  config.evals.coco_panoptic_train = dict(**base, split='train[4096:8192]')
  config.evals.coco_panoptic_holdout = dict(**base, split='train[:4096]')
  config.evals.coco_panoptic = dict(**base, split='validation')
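  # coco_panoptic_train measures fit on a slice of the training data,
  # coco_panoptic_holdout is the held-out slice that the docstring's PQ
  # numbers refer to, and coco_panoptic is the official COCO validation split.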
151
+
152
+ # config.evals.save_pred = dict(type='proj.uvim.save_predictions')
153
+ # config.evals.save_pred.pp = pp_eval.replace('decode|', '')
154
+ # config.evals.save_pred.log_steps = 100_000
155
+ # config.evals.save_pred.dataset = config.dataset
156
+ # config.evals.save_pred.split = 'validation[:1024]'
157
+ # config.evals.save_pred.outfile = 'inference.npz'
158
+

  if arg.singlehost:
    config.batch_size = 32
    config.total_epochs = 50
  elif arg.runlocal:
    config.batch_size = 4
    config.shuffle_buffer_size = 10
    config.evals.val.split = 'train[:16]'
  return config