
Commit fd2d3bd

Add CLIPPO model, pp_ops, config, and readme. Also update proj/image_text trainer and evaluators. (#27)
1 parent b00544b commit fd2d3bd

13 files changed: +702 -123 lines changed


README.md

Lines changed: 3 additions & 0 deletions
@@ -58,6 +58,9 @@ codebase:
   Xiaohua Zhai*, Xiao Wang*, Basil Mustafa*, Andreas Steiner*, Daniel Keysers,
   Alexander Kolesnikov, and Lucas Beyer*\
   Resources: [trainer](big_vision/trainers/proj/image_text/contrastive.py), [config](big_vision/configs/proj/image_text/lit_coco.py), [colab](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/image_text/lit.ipynb).
+- [Image-and-Language Understanding from Pixels Only](https://arxiv.org/abs/2212.08045), by
+  Michael Tschannen, Basil Mustafa, Neil Houlsby\
+  Resources: [readme](big_vision/configs/proj/clippo/README.md), [config](big_vision/configs/proj/clippo/train_clippo.py)
 
 ### Knowledge distillation

big_vision/configs/proj/clippo/README.md

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
## Image-and-Language Understanding from Pixels Only

*by Michael Tschannen, Basil Mustafa, Neil Houlsby* [[arxiv]](https://arxiv.org/abs/2212.08045)

We provide code to train CLIP with Pixels Only (CLIPPO) models on image/alt-text data sets.

To train your own CLIPPO model, please follow the setup instructions in the [`big_vision` main README](https://github.com/google-research/big_vision#cloud-tpu-vm-setup). Below, we provide the CLIPPO-specific commands required in addition to that setup, and assume you are using the Google Cloud TPU setup (potentially with an adapted TPU configuration, see the table below). If you are using GPUs, set up your machine directly and only execute the `--command` portions of the commands below from the `big_vision` repository root.

The text-rendering preprocessing function requires a manual download of the Unifont .hex files from [Unifoundry](https://unifoundry.com/unifont/) (please follow the link for the license):

```bash
gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all \
  --command "bash big_vision/pp/proj/clippo/download_unifont.sh"
```
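
To make the rendering idea concrete, here is a minimal, self-contained sketch in the spirit of the `render_unifont` preprocessing op. It is an illustration only, not the actual op: glyph decoding follows the standard Unifont `.hex` format, the default brightness values mirror the config defaults, and the line-wrapping logic is a simplified assumption.

```python
import numpy as np

def parse_hex_glyph(hex_line):
  """Decodes one 'CODEPOINT:HEXDATA' Unifont line into a (16, 8|16) 0/1 bitmap."""
  _, data = hex_line.strip().split(':')
  width = 8 if len(data) == 32 else 16          # 32 hex chars -> 8 px wide, 64 -> 16 px
  step = width // 4                             # hex chars per glyph row
  rows = [int(data[i:i + step], 16) for i in range(0, len(data), step)]
  return np.array([[(r >> (width - 1 - b)) & 1 for b in range(width)]
                   for r in rows], dtype=np.uint8)

def render_text(text, glyphs, image_size=224, text_brightness=0,
                background_brightness=127):
  """Paints `text` onto a square uint8 canvas; `glyphs` maps char -> bitmap."""
  canvas = np.full((image_size, image_size), background_brightness, np.uint8)
  x = y = 0
  for ch in text.lower():
    bitmap = glyphs.get(ch)
    if bitmap is None:                          # skip characters without a glyph
      continue
    h, w = bitmap.shape
    if x + w > image_size:                      # naive line wrap
      x, y = 0, y + h
    if y + h > image_size:                      # canvas full, stop rendering
      break
    canvas[y:y + h, x:x + w][bitmap == 1] = text_brightness
    x += w
  return canvas
```

In the actual pipeline the rendered canvas is subsequently mapped to [-1, 1] by the `value_range(-1, 1)` op, matching the image tower's input range (see the `tokenizer` helper in the config).
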
Launch the training by running

```bash
gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all \
  --command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.trainers.proj.image_text.contrastive --config big_vision/configs/proj/clippo/train_clippo.py --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'`"
```
*Important note:* The input pipeline relies on [TensorFlow Datasets (TFDS)](https://www.tensorflow.org/datasets), which does not provide automatic integration with large image/alt-text data sets out of the box. The above config therefore trains by default on MS-COCO Captions, which can be downloaded automatically via TFDS, and additionally initializes the CLIPPO ViT backbone with weights pretrained on ImageNet21k. This setup is not meant to produce good accuracy, but to give the user a way to sanity-check their setup. If you want to train on a large data set such as [`LAION-400M`](https://arxiv.org/abs/2111.02114) or [`YFCC100M`](https://arxiv.org/abs/1503.01817), please follow [these instructions](https://www.tensorflow.org/datasets/add_dataset) to wrap your data set using TFDS (see the sketch below), and update the data set in the config accordingly. Also note that the ImageNet1k evaluations require a manual download of the data, see [these instructions](https://github.com/google-research/big_vision#preparing-tfds-data). To train with your own data set and with ImageNet1k-based evaluations, use `--config big_vision/configs/proj/clippo/train_clippo.py:test_with_coco=False,i1k_eval=True` in the command above.
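
As a rough starting point, a TFDS wrapper for an image/alt-text data set can look like the following minimal sketch. All names, paths, and the storage layout here are placeholder assumptions; the only real constraint is that the feature keys match what the config's preprocessing string reads (the default non-COCO pipeline in `train_clippo.py` expects a `"text"` field next to `"image"`).

```python
import tensorflow_datasets as tfds

class MyImageAltText(tfds.core.GeneratorBasedBuilder):
  """Hypothetical image/alt-text pairs stored as (image path, caption) rows."""
  VERSION = tfds.core.Version('1.0.0')

  def _info(self):
    return tfds.core.DatasetInfo(
        builder=self,
        features=tfds.features.FeaturesDict({
            'image': tfds.features.Image(),
            'text': tfds.features.Text(),   # key read by the default pp string
        }),
    )

  def _split_generators(self, dl_manager):
    # Assumes a manually prepared TSV of "image_path<TAB>caption" lines.
    return {'train': self._generate_examples(dl_manager.manual_dir / 'pairs.tsv')}

  def _generate_examples(self, path):
    with path.open() as f:
      for i, line in enumerate(f):
        image_path, caption = line.rstrip('\n').split('\t')
        yield i, {'image': image_path, 'text': caption}
```

Once the builder is registered (e.g. via `tfds build`), point `config.input.data` in `train_clippo.py` at the new data set; the spot is marked by a comment in the config.
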
#### Expected results

| train dataset | batch size | #steps | TPU chips | ImageNet 0-shot | MS-COCO I→T | MS-COCO T→I | Config `arg` |
| :--- | ---: | ---: | ---: | :---: | :---: | :---: | :--- |
| *MS-COCO (sanity check)* | 4000 | 400 | 32 v3 | 4.2 | 12.6 | 8.6 | `i1k_eval=True` |
| LAION-400M | 8192 | 100k | 128 v2 | 51.5 | 44.8 | 29.3 | `test_with_coco=False,i1k_eval=True` |
| LAION-400M | 10240\* | 100k | 128 v3 | 53.6 | 46.7 | 30.3 | `test_with_coco=False,i1k_eval=True` |

\* The experiments in the paper use a batch size of 10240, which requires a memory-optimized ViT implementation to run on 128 TPU v2 chips, or 128 TPU v3 chips (in which case the TPU memory capacity allows increasing the batch size beyond 10240).
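
For orientation, the per-chip batch sizes implied by the table work out as follows (a back-of-the-envelope check, not a figure from the paper):

```python
# Global batch divided across TPU chips, per the table above.
for global_batch, chips in [(4000, 32), (8192, 128), (10240, 128)]:
  print(f'{global_batch} examples / {chips} chips = {global_batch // chips} per chip')
# 4000 examples / 32 chips = 125 per chip
# 8192 examples / 128 chips = 64 per chip
# 10240 examples / 128 chips = 80 per chip
```
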
#### Citation

```
@article{tschannen2022image,
  title={Image-and-Language Understanding from Pixels Only},
  author={Tschannen, Michael and Mustafa, Basil and Houlsby, Neil},
  journal={arXiv preprint arXiv:2212.08045},
  year={2022}
}
```
big_vision/configs/proj/clippo/train_clippo.py

Lines changed: 199 additions & 0 deletions
@@ -0,0 +1,199 @@
# Copyright 2022 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=line-too-long
r"""Trains CLIP with Pixels Only (CLIPPO), https://arxiv.org/abs/2212.08045

IMPORTANT NOTE: This config uses coco_captions by default for demonstration
purposes since the TFDS catalog does not provide any large image/alt-text data
set; the training will not produce a model with useful accuracy. Please
replace the data set below (marked by a comment) with an appropriate image/
alt-text data set wrapped in TFDS (for example LAION-400M) and run the config
with the suffix `:test_with_coco=False` to train on your data set. Refer to
the following guide to build a TFDS wrapper for your favorite image/alt-text
data set:
https://www.tensorflow.org/datasets/add_dataset

Also note that evaluation on ImageNet requires manual TFDS setup, see
https://github.com/google-research/big_vision#preparing-tfds-data


Example training:

big_vision.trainers.proj.image_text.contrastive \
    --config big_vision/configs/proj/clippo/train_clippo.py \
    --workdir gs://[your_bucket]/big_vision/`date '+%Y-%m-%d_%H%M'`

"""

import big_vision.configs.common as bvcc
from big_vision.configs.common_fewshot import get_fewshot_lsr
from big_vision.configs.proj.image_text import common
from ml_collections import ConfigDict


def get_config(arg=None):
  """The base configuration."""
  arg = bvcc.parse_arg(
      arg, res=224, runlocal=False, variant='B/16',
      test_with_coco=True, i1k_eval=False)
  config = ConfigDict()

  config.input = {}
  if arg.test_with_coco:
    # Use COCO Captions for sanity-checking
    config.input.data = dict(name='coco_captions', split='train')
    val_data = dict(config.input.data)
    val_data['split'] = 'val'
    config.input.batch_size = 4000 if not arg.runlocal else 32
    config.input.shuffle_buffer_size = 50_000 if not arg.runlocal else 50
    config.total_steps = 400 if not arg.runlocal else 10
  else:
    # Please add your favorite image/alt-text dataset here
    config.input.data = None
    val_data = None
    assert config.input.data is not None and val_data is not None, (
        config.input.data, val_data)

    # The value in the paper is 10 * 1024, which requires 128 TPUv3 cores or a
    # memory optimized ViT implementation when running on 128 TPUv2 cores.
    config.input.batch_size = 8 * 1024 if not arg.runlocal else 32
    config.input.shuffle_buffer_size = 250_000 if not arg.runlocal else 50
    config.total_steps = 100_000 if not arg.runlocal else 10

  def tokenizer(inkey, outkey='labels'):
    return (f'render_unifont('
            f'inkey="{inkey}", '
            f'outkey="{outkey}", '
            f'image_size={arg.res}, '
            f'lower=True, '
            f'font_size=16, '
            f'text_brightness=0, '
            f'background_brightness=127)|'
            f'value_range(-1, 1, inkey="{outkey}", outkey="{outkey}")')

  pp_image = f'decode|resize({arg.res})|value_range(-1,1)'
  if arg.test_with_coco:
    # Train with augmentation when sanity-checking
    pp_image_aug = (
        f'decode|resize({arg.res})|flip_lr|randaug(2,10)|value_range(-1,1)')
    config.input.pp = pp_eval = (
        f'{pp_image_aug}|flatten|{tokenizer("captions/text")}|'
        f'keep("image", "labels")')
  else:
    config.input.pp = pp_eval = (
        f'{pp_image}|flatten|{tokenizer("text")}|keep("image", "labels")')

  config.pp_modules = [
      'ops_general', 'ops_image', 'ops_text', 'proj.clippo.pp_ops']

  config.log_training_steps = 50
  config.ckpt_steps = 1000
  config.keep_ckpt_steps = 5000

  config.loss_use_global_batch = True

  # Define the model
  config.model_name = 'proj.clippo.one_tower'

  config.model = ConfigDict()
  config.model.image_model = 'vit'
  config.model.image = ConfigDict({
      'variant': arg.variant,
      'pool_type': 'map',
      'head_zeroinit': False,
  })

  if arg.test_with_coco:
    # Initialize with ImageNet21k pretrained checkpoint for sanity-checking
    assert arg.variant == 'B/16', arg.variant
    config.model_init = {'image': 'howto-i21k-B/16'}
    config.model_load = {}
    config.model_load['img_load_kw'] = {
        'dont_load': ['^head/.*', '^MAPHead_0/.*', 'cls']}

  config.model.temperature_init = 10.0
  config.model.out_dim = 768

  # Define the optimizer
  config.optax_name = 'big_vision.scale_by_adafactor'
  config.grad_clip_norm = 1.0

  if arg.test_with_coco:
    # Short schedule for sanity-checking
    config.lr = 0.0001
    config.wd = 0.0003
    config.schedule = dict(decay_type='rsqrt',
                           timescale=100,
                           warmup_steps=100 if not arg.runlocal else 5,
                           cooldown_steps=100 if not arg.runlocal else 5)
  else:
    config.lr = 0.001
    config.wd = 0.0001
    config.schedule = dict(decay_type='rsqrt',
                           timescale=10_000,
                           warmup_steps=10_000 if not arg.runlocal else 5,
                           cooldown_steps=10_000 if not arg.runlocal else 5)

  # Eval section (Both few-shot and zero-shot)
  eval_common = dict(
      type='proj.image_text.contrastive',
      use_global_batch=config.loss_use_global_batch,
      log_steps=1000 if not arg.runlocal else 5,
  )
  config.evals = {}
  sub = '[:4]' if arg.runlocal else ''
  config.evals.val = {
      **eval_common,
      'data': val_data,
      'pp_fn': pp_eval,
  }
  config.evals.coco = {
      **eval_common,
      'data': dict(name='coco_captions', split=f'val{sub}'),
      'pp_fn': (
          f'{pp_image}|flatten|{tokenizer("captions/text")}|'
          f'keep("image", "labels")'),
  }

  if arg.i1k_eval:
    # Requires manual download, see
    # https://github.com/google-research/big_vision#preparing-tfds-data
    config.evals.imagenet = {
        **eval_common,
        'data': dict(name='imagenet2012', split=f'validation{sub}'),
        'pp_fn': (
            f'{pp_image}|clip_i1k_label_names|'
            f'{tokenizer("labels")}|keep("image", "labels")'),
    }
    config.evals.disclf = dict(
        type='proj.image_text.discriminative_classifier',
        pp_txt=tokenizer('texts', 'labels'),
        prefix='z/0shot/',
        log_steps=5_000 if not arg.runlocal else 5)

  config.evals.retrieval_coco = common.get_coco(
      pp_img=f'resize({arg.res})|value_range(-1, 1)',
      pp_txt=tokenizer('texts'),
      log_steps=5_000 if not arg.runlocal else 5,
  )

  # Few-shot metrics
  config.evals.fewshot = get_fewshot_lsr()
  config.evals.fewshot.log_steps = 5_000 if not arg.runlocal else 5
  config.evals.fewshot.representation_layer = 'img/pre_logits'

  config.seed = 0

  return config
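
For a quick sanity check of what this config produces, it can be instantiated directly in Python (assuming `big_vision` and its dependencies are importable; the printed values below follow from the defaults in `get_config`):

```python
from big_vision.configs.proj.clippo import train_clippo

config = train_clippo.get_config()       # defaults: test_with_coco=True, i1k_eval=False
print(config.model_name)                 # proj.clippo.one_tower
print(config.input.batch_size)           # 4000
print(config.input.pp.split('|')[:3])    # ['decode', 'resize(224)', 'flip_lr']
```

On the command line, the same knobs are set via the `:key=value,...` suffix shown in the README, e.g. `train_clippo.py:test_with_coco=False,i1k_eval=True`.
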

big_vision/configs/proj/image_text/common.py

Lines changed: 70 additions & 4 deletions
@@ -32,30 +32,96 @@
 # pylint: enable=line-too-long
 
 
+def _square875(sz):
+  return f'resize({int(sz/0.875)})|central_crop({sz})|value_range(-1,1)'
+
+
+def _aspect75(sz):
+  return f'resize_small({int(sz/0.75)})|central_crop({sz})|value_range(-1,1)'
+
+
+def _drop_no_real_label(f):
+  return len(f['real_label']) > 0
+
+
+def _drop_no_imagenet(f):
+  return len(f['labels_imagenet']) > 0
+
+
+DISCLF_DATASET_OVERRIDES = {
+    'imagenet2012': {'class_names': 'clip', 'split': 'validation'},
+    'imagenet2012_minival': {
+        'dataset_name': 'imagenet2012',
+        'class_names': 'clip',
+        'split': 'train[99%:]',
+    },
+    'imagenet2012_real': {
+        'split': 'validation',
+        'class_names': 'clip',
+        'class_names_dataset_name': 'imagenet2012',
+        'pp_img': lambda sz: (
+            _square875(sz) + '|pad_to_shape(inkey="real_label", outkey="label", shape=[10], pad_value=-1)|keep("label", "image")'),  # pylint: disable=line-too-long
+        'filter_fn': _drop_no_real_label,
+    },
+    'imagenet_v2': {'class_names': 'clip'},
+    'imagenet_a': {
+        'class_names': 'clip',
+        'pp_img': lambda sz: _aspect75(sz) + '|map("i1k_i1ka")',
+    },
+    'imagenet_r': {
+        'class_names': 'clip',
+        'pp_img': lambda sz: _square875(sz) + '|map("i1k_i1kr")',
+    },
+}
+
+
+def get_disclf(sz, *, log_steps, pp_txt=None, dataset_names=('imagenet2012',)):
+  """Returns config for discriminative_classifier of specified datasets."""
+  config = ml_collections.ConfigDict(dict(
+      dataset_names=list(dataset_names),
+      type='proj.image_text.discriminative_classifier',
+      prefix='z/0shot/',
+      pp_img=_square875(sz),
+      dataset_overrides={},
+      log_steps=log_steps,
+      cache_final=True,
+  ))
+  if pp_txt:
+    config.pp_txt = pp_txt
+  for name in dataset_names:
+    if name in DISCLF_DATASET_OVERRIDES:
+      config.dataset_overrides[name] = {**DISCLF_DATASET_OVERRIDES[name]}
+      d = config.dataset_overrides[name]
+      if 'pp_img' in d and callable(d['pp_img']):
+        with d.ignore_type():
+          d['pp_img'] = d['pp_img'](sz)
+  return config
+
+
 def get_coco(
     *,
+    log_steps,
     pp_img='resize(224)|value_range(-1, 1)',
     pp_txt='tokenize(max_len=16, inkey="texts", eos="sticky", pad_value=1)',
-    prefix='z/retr/coco_',
-    log_steps):
+    prefix='z/retr/coco_'):
   """Returns config for mscoco retrieval zero-shot.
 
   Args:
+    log_steps: How often the evaluators should be run.
     pp_img: Pre-processing string for "image" feature.
     pp_txt: Pre-processing string for texts (expected to tokenize "texts" to
       "labels").
     prefix: Prefix to use for metrics.
-    log_steps: How often the evaluators should be run.
 
   Returns:
     `ConfigDict` that can be used as a retrieval evaluator configuration.
   """
   return ml_collections.ConfigDict({
       'type': 'proj.image_text.retrieval',
-      'log_steps': log_steps,
       'pp_txt': pp_txt,
       'pp_img': pp_img,
       'prefix': prefix,
       'dataset': 'coco_captions',
       'txt_name': ('captions', 'text'),
+      'log_steps': log_steps,
   })
big_vision/evaluators/proj/image_text/contrastive.py

Lines changed: 12 additions & 4 deletions
@@ -12,7 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Evaluator for the contrastive task."""
+"""Evaluator for the contrastive task.
+
+DON'T COMPARE ACROSS RUNS, use for training health monitoring only.
+
+Note that this evaluator's `ncorrect_minibatch` is only a rough proxy for
+training progress and does not report the actual `ncorrect`: when the same
+labels are found multiple times in a batch, the reported value is biased
+towards lower values.
+
+Also note that `ncorrect_minibatch` is a function of batch size (it's a lot
+easier to find correct values in small batches).
+"""
 import functools
 
 from big_vision import input_pipeline
@@ -38,9 +49,6 @@ def get_eval_fn(predict_fn, use_global_batch):
 
   @functools.partial(jax.pmap, axis_name="batch")
   def _eval_fn(params, images, labels, mask):
-
-    # Ignore the entries with all zero labels for evaluation.
-    mask *= jnp.clip(labels.max(axis=1), 0, 1)
     zimg, ztxt, extras = predict_fn(params, images, labels)
 
     if use_global_batch: