# Copyright 2022 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=line-too-long
| 16 | +r"""Trains CLIP with Pixels Only (CLIPPO), https://arxiv.org/abs/2212.08045 |
| 17 | +
|
| 18 | +IMPORTANT NOTE: This config uses coco_captions by default for demonstration |
| 19 | +purposes since the TFDS catalog does not provide any large image/alt-text data |
| 20 | +set; the training will not produce a model with useful accuracy. Please |
| 21 | +replace the data set below (marked by a comment) with an appropriate image/ |
| 22 | +alt-text data set wrapped in TFDS (for example LAION-400M) and run the config |
| 23 | +with the suffix `:test_with_coco=False` to train on your data set. Refer to |
| 24 | +the following guide to build a TFDS wrapper for your favorite image/alt-text |
| 25 | +data set: |
| 26 | +https://www.tensorflow.org/datasets/add_dataset |
| 27 | +
|
| 28 | +Also note that evaluation on ImageNet requires manual TFDS setup, see |
| 29 | +https://github.com/google-research/big_vision#preparing-tfds-data |
| 30 | +
|
| 31 | +
|
| 32 | +Example training: |
| 33 | +
|
| 34 | +big_vision.trainers.proj.image_text.contrastive \ |
| 35 | + --config big_vision/configs/proj/clippo/train_clippo.py \ |
| 36 | + --workdir gs://[your_bucket]/big_vision/`date '+%Y-%m-%d_%H%M'` |
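
To train on your own data set instead of the COCO sanity check, the config
arguments defined in get_config below can be passed as a comma-separated
suffix after a colon, e.g. (assuming your data set has been plugged in below
and the ImageNet evaluation data has been prepared):

big_vision.trainers.proj.image_text.contrastive \
    --config big_vision/configs/proj/clippo/train_clippo.py:test_with_coco=False,i1k_eval=True \
    --workdir gs://[your_bucket]/big_vision/`date '+%Y-%m-%d_%H%M'`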
"""

import big_vision.configs.common as bvcc
from big_vision.configs.common_fewshot import get_fewshot_lsr
from big_vision.configs.proj.image_text import common
from ml_collections import ConfigDict


def get_config(arg=None):
  """The base configuration."""
  arg = bvcc.parse_arg(
      arg, res=224, runlocal=False, variant='B/16',
      test_with_coco=True, i1k_eval=False)
  config = ConfigDict()

  config.input = {}
  if arg.test_with_coco:
    # Use COCO Captions for sanity-checking
    config.input.data = dict(name='coco_captions', split='train')
    val_data = dict(config.input.data)
    val_data['split'] = 'val'
    config.input.batch_size = 4000 if not arg.runlocal else 32
    config.input.shuffle_buffer_size = 50_000 if not arg.runlocal else 50
    config.total_steps = 400 if not arg.runlocal else 10
  else:
    # Please add your favorite image/alt-text dataset here
    config.input.data = None
    val_data = None
    assert config.input.data is not None and val_data is not None, (
        config.input.data, val_data)

    # The value in the paper is 10 * 1024, which requires 128 TPUv3 cores or a
    # memory optimized ViT implementation when running on 128 TPUv2 cores.
    config.input.batch_size = 8 * 1024 if not arg.runlocal else 32
    config.input.shuffle_buffer_size = 250_000 if not arg.runlocal else 50
    config.total_steps = 100_000 if not arg.runlocal else 10

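  # CLIPPO has no separate text tower: the "tokenizer" below renders the
  # alt-text into an image using the unifont bitmap font, so image and text
  # are both consumed by the single shared vision encoder.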
  def tokenizer(inkey, outkey='labels'):
    return (f'render_unifont('
            f'inkey="{inkey}", '
            f'outkey="{outkey}", '
            f'image_size={arg.res}, '
            f'lower=True, '
            f'font_size=16, '
            f'text_brightness=0, '
            f'background_brightness=127)|'
            f'value_range(-1, 1, inkey="{outkey}", outkey="{outkey}")')

  pp_image = f'decode|resize({arg.res})|value_range(-1,1)'
  if arg.test_with_coco:
    # Train with augmentation when sanity-checking
    pp_image_aug = (
        f'decode|resize({arg.res})|flip_lr|randaug(2,10)|value_range(-1,1)')
    config.input.pp = pp_eval = (
        f'{pp_image_aug}|flatten|{tokenizer("captions/text")}|'
        f'keep("image", "labels")')
  else:
    config.input.pp = pp_eval = (
        f'{pp_image}|flatten|{tokenizer("text")}|keep("image", "labels")')
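  # After preprocessing, each example carries two image tensors in [-1, 1]:
  # "image" (the input image) and "labels" (the rendered alt-text), which the
  # contrastive trainer matches against each other.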

  config.pp_modules = [
      'ops_general', 'ops_image', 'ops_text', 'proj.clippo.pp_ops']

  config.log_training_steps = 50
  config.ckpt_steps = 1000
  config.keep_ckpt_steps = 5000

  config.loss_use_global_batch = True

  # Define the model
  config.model_name = 'proj.clippo.one_tower'

  config.model = ConfigDict()
  config.model.image_model = 'vit'
  config.model.image = ConfigDict({
      'variant': arg.variant,
      'pool_type': 'map',
      'head_zeroinit': False,
  })

  if arg.test_with_coco:
    # Initialize with ImageNet21k pretrained checkpoint for sanity-checking
    assert arg.variant == 'B/16', arg.variant
    config.model_init = {'image': 'howto-i21k-B/16'}
    config.model_load = {}
    config.model_load['img_load_kw'] = {
        'dont_load': ['^head/.*', '^MAPHead_0/.*', 'cls']}
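    # The classification head, MAP pooling head, and class token are skipped
    # when loading the i21k checkpoint; those parameters keep their fresh
    # initialization for the contrastive objective.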

  config.model.temperature_init = 10.0
  config.model.out_dim = 768

  # Define the optimizer
  config.optax_name = 'big_vision.scale_by_adafactor'
  config.grad_clip_norm = 1.0

  if arg.test_with_coco:
    # Short schedule for sanity-checking
    config.lr = 0.0001
    config.wd = 0.0003
    config.schedule = dict(decay_type='rsqrt',
                           timescale=100,
                           warmup_steps=100 if not arg.runlocal else 5,
                           cooldown_steps=100 if not arg.runlocal else 5)
  else:
    config.lr = 0.001
    config.wd = 0.0001
    config.schedule = dict(decay_type='rsqrt',
                           timescale=10_000,
                           warmup_steps=10_000 if not arg.runlocal else 5,
                           cooldown_steps=10_000 if not arg.runlocal else 5)

  # Eval section (both few-shot and zero-shot)
  eval_common = dict(
      type='proj.image_text.contrastive',
      use_global_batch=config.loss_use_global_batch,
      log_steps=1000 if not arg.runlocal else 5,
  )
  config.evals = {}
  sub = '[:4]' if arg.runlocal else ''
  config.evals.val = {
      **eval_common,
      'data': val_data,
      'pp_fn': pp_eval,
  }
  config.evals.coco = {
      **eval_common,
      'data': dict(name='coco_captions', split=f'val{sub}'),
      'pp_fn': (
          f'{pp_image}|flatten|{tokenizer("captions/text")}|'
          f'keep("image", "labels")'),
  }

  if arg.i1k_eval:
    # Requires manual download, see
    # https://github.com/google-research/big_vision#preparing-tfds-data
    config.evals.imagenet = {
        **eval_common,
        'data': dict(name='imagenet2012', split=f'validation{sub}'),
        'pp_fn': (
            f'{pp_image}|clip_i1k_label_names|'
            f'{tokenizer("labels")}|keep("image", "labels")'),
    }
    config.evals.disclf = dict(
        type='proj.image_text.discriminative_classifier',
        pp_txt=tokenizer('texts', 'labels'),
        prefix='z/0shot/',
        log_steps=5_000 if not arg.runlocal else 5)

  config.evals.retrieval_coco = common.get_coco(
      pp_img=f'resize({arg.res})|value_range(-1, 1)',
      pp_txt=tokenizer('texts'),
      log_steps=5_000 if not arg.runlocal else 5,
  )

  # Few-shot metrics
  config.evals.fewshot = get_fewshot_lsr()
  config.evals.fewshot.log_steps = 5_000 if not arg.runlocal else 5
  config.evals.fewshot.representation_layer = 'img/pre_logits'

  config.seed = 0

  return config