
Commit d4a31f0

williamFalcon, Borda, luiscape, akshaykulkarni07, and ethanwharris authored
Enable TPU support (#868)
* added tpu docs
* added tpu flags
* add tpu docs + init training call
* amp (several commits)
* optimizer step
* added auto data transfer to TPU (repeated over many commits)
* fix test pkg create (#873)
* added test return and print (repeated)
* Update pytorch_lightning/trainer/trainer.py (Co-Authored-By: Luis Capelo <[email protected]>)
* Fix segmentation example (#876): removed torchvision model and added custom model; minor fix; fixed relative imports issue
* Fix/typo (#880): Update greetings.yml
* Changelog (#869): Create CHANGELOG.md; Update CHANGELOG.md; Update PULL_REQUEST_TEMPLATE.md; Add PR links to Version 0.6.0 and Unreleased in CHANGELOG.md
* Fixing Function Signatures (#871)
* (the TPU-related messages above repeat once more in the squashed history)

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Luis Capelo <[email protected]>
Co-authored-by: Akshay Kulkarni <[email protected]>
Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: Shikhar Chauhan <[email protected]>
1 parent e38b18e commit d4a31f0

File tree

14 files changed: +489 -48 lines changed


docs/source/apex.rst

Lines changed: 24 additions & 6 deletions
@@ -1,13 +1,18 @@
 16-bit training
 =================
+Lightning offers 16-bit training for CPUs, GPUs and TPUs.
+
+GPU 16-bit
+-----------
 Lightning uses NVIDIA apex to handle 16-bit precision training.
 
 To use 16-bit precision, do two things:
+
 1. Install Apex
-2. Set the amp trainer flag.
+2. Set the "precision" trainer flag.
 
 Install apex
------------------------------------------------
+^^^^^^^^^^^^
 .. code-block:: bash
 
     $ git clone https://github.com/NVIDIA/apex
@@ -31,12 +36,25 @@ Install apex
 
 
 Enable 16-bit
---------------
+^^^^^^^^^^^^^
 
 .. code-block:: python
 
-    # DEFAULT
-    trainer = Trainer(amp_level='O1', use_amp=False)
+    # turn on 16-bit
+    trainer = Trainer(amp_level='O1', precision=16)
 
 If you need to configure the apex init for your particular use case or want to use a different way of doing
-16-bit training, override :meth:`pytorch_lightning.core.LightningModule.configure_apex`.
+16-bit training, override :meth:`pytorch_lightning.core.LightningModule.configure_apex`.
+
+TPU 16-bit
+----------
+16-bit on TPUs is much simpler. To use 16-bit with TPUs, set precision to 16 when using the TPU flag.
+
+.. code-block:: python
+
+    # DEFAULT
+    trainer = Trainer(num_tpu_cores=8, precision=32)
+
+    # turn on 16-bit
+    trainer = Trainer(num_tpu_cores=8, precision=16)
+
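
For quick reference, the flag rename above boils down to the following usage. This is a minimal sketch, assuming a pytorch-lightning build that already includes this commit, plus apex for GPU 16-bit or torch_xla for TPUs:

    from pytorch_lightning import Trainer

    # Before this commit, GPU 16-bit was requested with a boolean flag:
    #     trainer = Trainer(amp_level='O1', use_amp=True)

    # After this commit, the same request uses `precision` (needs a CUDA GPU + apex):
    trainer = Trainer(amp_level='O1', precision=16)

    # On TPUs the same flag switches between 32-bit and bfloat16 (needs torch_xla):
    trainer = Trainer(num_tpu_cores=8, precision=16)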

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -55,6 +55,7 @@ PyTorch-Lightning Documentation
    single_gpu
    sequences
    training_tricks
+   tpu
    test_set
    optimizers
    profiler

docs/source/new-project.rst

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ Then you could do rapid research by switching between these two and using the sa
     else:
         model = CoolerNotBERT()
 
-    trainer = Trainer(gpus=4, use_amp=True)
+    trainer = Trainer(gpus=4, precision=16)
     trainer.fit(model)
 
 

docs/source/tpu.rst

Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
+TPU support
+===========
+
+Lightning supports running on TPUs. At this moment, TPUs are only available
+on Google Cloud (GCP). For more information on TPUs
+`watch this video <https://www.youtube.com/watch?v=kPMpmcl_Pyw>`_.
+
+Live demo
+----------
+Check out this `Google Colab <https://colab.research.google.com/drive/1-_LKx4HwAxl5M6xPJmqAAu444LTDQoa3>`_ to see how to train MNIST on TPUs.
+
+TPU Terminology
+---------------
+A TPU is a Tensor processing unit. Each TPU has 8 cores where each
+core is optimized for 128x128 matrix multiplies. In general, a single
+TPU is about as fast as 5 V100 GPUs!
+
+A TPU pod hosts many TPUs on it. Currently, TPU pod v2 has 2048 cores!
+You can request a full pod from Google Cloud or a "slice" which gives you
+some subset of those 2048 cores.
+
+How to access TPUs
+-------------------
+To access TPUs there are two main ways.
+
+1. Using Google Colab.
+
+2. Using Google Cloud (GCP).
+
+Colab TPUs
+-----------
+Colab is like a Jupyter notebook with a free GPU or TPU
+hosted on GCP.
+
+To get a TPU on Colab, follow these steps:
+
+1. Go to https://colab.research.google.com/.
+
+2. Click "new notebook" (bottom right of pop-up).
+
+3. Click runtime > change runtime settings. Select Python 3,
+   and hardware accelerator "TPU". This will give you a TPU with 8 cores.
+
+4. Next, insert this code into the first cell and execute. This
+   will install the xla library that interfaces between PyTorch and
+   the TPU.
+
+.. code-block:: python
+
+    import collections
+    from datetime import datetime, timedelta
+    import os
+    import requests
+    import threading
+
+    _VersionConfig = collections.namedtuple('_VersionConfig', 'wheels,server')
+    VERSION = "xrt==1.15.0"  #@param ["xrt==1.15.0", "torch_xla==nightly"]
+    CONFIG = {
+        'xrt==1.15.0': _VersionConfig('1.15', '1.15.0'),
+        'torch_xla==nightly': _VersionConfig('nightly', 'XRT-dev{}'.format(
+            (datetime.today() - timedelta(1)).strftime('%Y%m%d'))),
+    }[VERSION]
+    DIST_BUCKET = 'gs://tpu-pytorch/wheels'
+    TORCH_WHEEL = 'torch-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
+    TORCH_XLA_WHEEL = 'torch_xla-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
+    TORCHVISION_WHEEL = 'torchvision-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
+
+    # Update TPU XRT version
+    def update_server_xrt():
+        print('Updating server-side XRT to {} ...'.format(CONFIG.server))
+        url = 'http://{TPU_ADDRESS}:8475/requestversion/{XRT_VERSION}'.format(
+            TPU_ADDRESS=os.environ['COLAB_TPU_ADDR'].split(':')[0],
+            XRT_VERSION=CONFIG.server,
+        )
+        print('Done updating server-side XRT: {}'.format(requests.post(url)))
+
+    update = threading.Thread(target=update_server_xrt)
+    update.start()
+
+    # Install Colab TPU compat PyTorch/TPU wheels and dependencies
+    !pip uninstall -y torch torchvision
+    !gsutil cp "$DIST_BUCKET/$TORCH_WHEEL" .
+    !gsutil cp "$DIST_BUCKET/$TORCH_XLA_WHEEL" .
+    !gsutil cp "$DIST_BUCKET/$TORCHVISION_WHEEL" .
+    !pip install "$TORCH_WHEEL"
+    !pip install "$TORCH_XLA_WHEEL"
+    !pip install "$TORCHVISION_WHEEL"
+    !sudo apt-get install libomp5
+    update.join()
+5. Once the above is done, install PyTorch Lightning (v 0.6.1+).
+
+.. code-block::
+
+    ! pip install pytorch-lightning
+
+6. Then set up your LightningModule as normal.
+
+7. TPUs require a DistributedSampler. That means you should change your
+   train_dataloader (and val, test) code as follows.
+
+.. code-block:: python
+
+    import torch_xla.core.xla_model as xm
+
+    @pl.data_loader
+    def train_dataloader(self):
+        dataset = MNIST(
+            os.getcwd(),
+            train=True,
+            download=True,
+            transform=transforms.ToTensor()
+        )
+
+        # required for TPU support
+        sampler = None
+        if use_tpu:
+            sampler = torch.utils.data.distributed.DistributedSampler(
+                dataset,
+                num_replicas=xm.xrt_world_size(),
+                rank=xm.get_ordinal(),
+                shuffle=True
+            )
+
+        loader = DataLoader(
+            dataset,
+            sampler=sampler,
+            batch_size=32
+        )
+
+        return loader
+
+8. Configure the number of TPU cores in the trainer. You can only choose
+   1 or 8. To use a full TPU pod skip to the TPU pod section.
+
+.. code-block:: python
+
+    import pytorch_lightning as pl
+
+    my_model = MyLightningModule()
+    trainer = pl.Trainer(num_tpu_cores=8)
+    trainer.fit(my_model)
+
+That's it! Your model will train on all 8 TPU cores.
+
+TPU Pod
+--------
+To train on more than 8 cores, your code actually doesn't change!
+All you need to do is submit the following command:
+
+.. code-block:: bash
+    $ python -m torch_xla.distributed.xla_dist
+      --tpu=$TPU_POD_NAME
+      --conda-env=torch-xla-nightly
+      -- python /usr/share/torch-xla-0.5/pytorch/xla/test/test_train_imagenet.py --fake_data
+
+16 bit precision
+-----------------
+Lightning also supports training in 16-bit precision with TPUs.
+By default, TPU training will use 32-bit precision. To enable 16-bit, also
+set the 16-bit flag.
+
+.. code-block:: python
+
+    import pytorch_lightning as pl
+
+    my_model = MyLightningModule()
+    trainer = pl.Trainer(num_tpu_cores=8, precision=16)
+    trainer.fit(my_model)
+
+Under the hood the xla library will use the `bfloat16 type <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_.
+
+
+About XLA
+----------
+XLA is the library that interfaces PyTorch with the TPUs.
+For more information check out `XLA <https://github.com/pytorch/xla>`_.
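
The new tpu.rst only spells out `train_dataloader`, while step 7 says the val and test loaders need the same treatment. Below is a sketch of a matching `val_dataloader` under the same assumptions as the doc's example (torch_xla installed, a `use_tpu` flag set by your own code, the method living on your LightningModule); this method is illustrative and not part of the commit:

    import os

    import torch
    from torch.utils.data import DataLoader
    from torchvision import transforms
    from torchvision.datasets import MNIST

    import pytorch_lightning as pl
    import torch_xla.core.xla_model as xm

    use_tpu = True  # illustrative flag, mirroring the doc's train_dataloader example


    class MNISTOnTPU(pl.LightningModule):
        # training_step / validation_step / configure_optimizers omitted for brevity

        @pl.data_loader
        def val_dataloader(self):
            dataset = MNIST(
                os.getcwd(),
                train=False,  # validation split
                download=True,
                transform=transforms.ToTensor()
            )

            # same DistributedSampler recipe as the doc's train_dataloader,
            # but without shuffling for validation
            sampler = None
            if use_tpu:
                sampler = torch.utils.data.distributed.DistributedSampler(
                    dataset,
                    num_replicas=xm.xrt_world_size(),
                    rank=xm.get_ordinal(),
                    shuffle=False
                )

            return DataLoader(dataset, sampler=sampler, batch_size=32)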

pytorch_lightning/core/hooks.py

Lines changed: 8 additions & 5 deletions
@@ -113,10 +113,10 @@ def on_after_backward(self):
 
         """
 
-    def backward(self, use_amp, loss, optimizer, optimizer_idx):
+    def backward(self, trainer, loss, optimizer, optimizer_idx):
         """Override backward with your own implementation if you need to
 
-        :param use_amp: Whether amp was requested or not
+        :param trainer: Pointer to the trainer
         :param loss: Loss is already scaled by accumulated grads
         :param optimizer: Current optimizer being used
         :param optimizer_idx: Index of the current optimizer being used
@@ -137,8 +137,11 @@ def backward(self, use_amp, loss, optimizer):
             loss.backward()
 
         """
-        if use_amp:
-            with amp.scale_loss(loss, optimizer) as scaled_loss:
-                scaled_loss.backward()
+        if trainer.precision == 16:
+
+            # .backward is not special on 16-bit with TPUs
+            if not trainer.on_tpu:
+                with amp.scale_loss(loss, optimizer) as scaled_loss:
+                    scaled_loss.backward()
         else:
             loss.backward()
pytorch_lightning/trainer/auto_mix_precision.py

Lines changed: 3 additions & 0 deletions
@@ -12,6 +12,9 @@
 
 
 class TrainerAMPMixin(ABC):
 
+    def __init__(self):
+        self.use_amp = None
+
     def init_amp(self, use_amp):
         self.use_amp = use_amp and APEX_AVAILABLE
         if self.use_amp:
pytorch_lightning/trainer/data_loading.py

Lines changed: 19 additions & 6 deletions
@@ -36,6 +36,8 @@ def __init__(self):
         self.use_ddp2 = None
         self.shown_warnings = None
         self.val_check_interval = None
+        self.use_tpu = None
+        self.tpu_local_core_rank = None
 
     def _percent_range_check(self, name):
         value = getattr(self, name)
@@ -80,9 +82,10 @@ def init_train_dataloader(self, model):
             self.val_check_batch = max(1, self.val_check_batch)
 
         on_ddp = self.use_ddp or self.use_ddp2
-        if on_ddp and not isinstance(self.get_train_dataloader().sampler, DistributedSampler):
+        needs_sampler = on_ddp or self.use_tpu
+        if needs_sampler and not isinstance(self.get_train_dataloader().sampler, DistributedSampler):
             msg = """
-            You're using multiple gpus and multiple nodes without using a DistributedSampler
+            You're using multiple gpus and multiple nodes, or TPUs without using a
             to assign a subset of your data to each process. To silence this warning, pass a
             DistributedSampler to your DataLoader.
 
@@ -119,13 +122,14 @@ def init_val_dataloader(self, model):
             self.num_val_batches = int(self.num_val_batches * self.val_percent_check)
 
         on_ddp = self.use_ddp or self.use_ddp2
-        if on_ddp and self.get_val_dataloaders() is not None:
+        needs_sampler = on_ddp or self.use_tpu
+        if needs_sampler and self.get_val_dataloaders() is not None:
             for dataloader in self.get_val_dataloaders():
                 if not isinstance(dataloader.sampler, DistributedSampler):
                     msg = """
                     Your val_dataloader(s) don't use DistributedSampler.
 
-                    You're using multiple gpus and multiple nodes without using a
+                    You're using multiple gpus and multiple nodes, or TPUs without using a
                     DistributedSampler to assign a subset of your data to each process.
                     To silence this warning, pass a DistributedSampler to your DataLoader.
 
@@ -162,13 +166,14 @@ def init_test_dataloader(self, model):
             self.num_test_batches = int(self.num_test_batches * self.test_percent_check)
 
         on_ddp = self.use_ddp or self.use_ddp2
-        if on_ddp and self.get_test_dataloaders() is not None:
+        needs_sampler = on_ddp or self.use_tpu
+        if needs_sampler and self.get_test_dataloaders() is not None:
             for dataloader in self.get_test_dataloaders():
                 if not isinstance(dataloader.sampler, DistributedSampler):
                     msg = """
                     Your `test_dataloader(s)` don't use DistributedSampler.
 
-                    You're using multiple gpus and multiple nodes without using a
+                    You're using multiple gpus and multiple nodes, or TPUs without using a
                     DistributedSampler to assign a subset of your data to each process.
                     To silence this warning, pass a DistributedSampler to your DataLoader.
 
@@ -210,6 +215,14 @@ def get_dataloaders(self, model):
         self.get_test_dataloaders()
         self.get_val_dataloaders()
 
+        # on TPUs load each dataloader only on process 0
+        # this will trigger the data downloads
+        if self.use_tpu:
+            if self.tpu_local_core_rank == 0:
+                self.get_train_dataloader()
+                self.get_test_dataloaders()
+                self.get_val_dataloaders()
+
         # support IterableDataset for train data
         self.is_iterable_train_dataloader = (
             EXIST_ITER_DATASET and isinstance(self.get_train_dataloader().dataset, IterableDataset))
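
The `needs_sampler` checks above boil down to an `isinstance(loader.sampler, DistributedSampler)` test. A standalone sketch of that check from the user side (the helper name and toy dataset are illustrative, not part of the commit):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler


    def uses_distributed_sampler(loader: DataLoader) -> bool:
        # the same check the trainer performs when DDP or TPU is active
        return isinstance(loader.sampler, DistributedSampler)


    if __name__ == '__main__':
        dataset = TensorDataset(torch.randn(100, 4))

        plain_loader = DataLoader(dataset, batch_size=10)
        print(uses_distributed_sampler(plain_loader))    # False -> would trigger the warning

        # num_replicas / rank passed explicitly, so no process group is needed here
        sampler = DistributedSampler(dataset, num_replicas=8, rank=0, shuffle=True)
        sharded_loader = DataLoader(dataset, sampler=sampler, batch_size=10)
        print(uses_distributed_sampler(sharded_loader))  # True -> no warning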

pytorch_lightning/trainer/distrib_data_parallel.py

Lines changed: 8 additions & 0 deletions
@@ -144,6 +144,7 @@ def __init__(self):
         self.distributed_backend = None
         self.use_amp = None
         self.amp_level = None
+        self.use_tpu = None
 
     @abstractmethod
     def copy_trainer_model_properties(self, model):
@@ -160,6 +161,13 @@ def init_optimizers(self, optimizers):
         # this is just empty shell for code from other class
         pass
 
+    def init_tpu(self):
+        # turn off all the GPU stuff
+        self.distributed_backend = None
+
+        # enable tpu
+        self.use_tpu = True
+
     def set_distributed_mode(self, distributed_backend, num_gpu_nodes):
         # skip for CPU
         if self.num_gpus == 0:
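
To make the effect of the new `init_tpu()` concrete, here is a toy stand-in reproducing just the state change it performs (not the real TrainerDDPMixin, and not how the trainer actually wires it up):

    class ToyDDPMixin:
        def __init__(self):
            self.distributed_backend = 'ddp'
            self.use_amp = None
            self.amp_level = 'O1'
            self.use_tpu = None

        def init_tpu(self):
            # turn off all the GPU stuff
            self.distributed_backend = None

            # enable tpu
            self.use_tpu = True


    if __name__ == '__main__':
        t = ToyDDPMixin()
        t.init_tpu()
        assert t.distributed_backend is None and t.use_tpu is True
        print('TPU mode: distributed_backend cleared, use_tpu set')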
