TPU support
===========

Lightning supports running on TPUs. At this moment, TPUs are only available
on Google Cloud (GCP). For more information on TPUs,
`watch this video <https://www.youtube.com/watch?v=kPMpmcl_Pyw>`_.

Live demo
---------
Check out this `Google Colab <https://colab.research.google.com/drive/1-_LKx4HwAxl5M6xPJmqAAu444LTDQoa3>`_ to see how to train MNIST on TPUs.

TPU Terminology
---------------
A TPU is a Tensor Processing Unit. Each TPU has 8 cores, and each
core is optimized for 128x128 matrix multiplies. In general, a single
TPU is about as fast as 5 V100 GPUs!

A TPU pod hosts many TPUs on it. Currently, a TPU v3 pod has 2048 cores!
You can request a full pod from Google Cloud or a "slice", which gives you
some subset of those 2048 cores.

How to access TPUs
------------------
There are two main ways to access TPUs:

1. Using Google Colab.
2. Using Google Cloud (GCP).

Colab TPUs
----------
Colab is like a Jupyter notebook with a free GPU or TPU
hosted on GCP.

To get a TPU on Colab, follow these steps:

1. Go to https://colab.research.google.com/.

2. Click "New Notebook" (bottom right of the pop-up).

3. Click "Runtime" > "Change runtime type". Select Python 3,
   and hardware accelerator "TPU". This will give you a TPU with 8 cores.
   You can verify the runtime with the quick check below.

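Optionally, you can sanity-check that the runtime actually has a TPU attached
before installing anything. ``COLAB_TPU_ADDR`` is set by Colab for TPU runtimes
(the install script in the next step relies on it as well).

.. code-block:: python

    import os

    # Colab sets COLAB_TPU_ADDR only when a TPU runtime is selected.
    assert 'COLAB_TPU_ADDR' in os.environ, 'Select Runtime > Change runtime type > TPU first.'
    print('TPU address:', os.environ['COLAB_TPU_ADDR'])
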
4. Next, insert this code into the first cell and execute it. This
   will install the XLA library that interfaces between PyTorch and
   the TPU.

.. code-block:: python

    import collections
    from datetime import datetime, timedelta
    import os
    import requests
    import threading

    _VersionConfig = collections.namedtuple('_VersionConfig', 'wheels,server')
    VERSION = "xrt==1.15.0"  #@param ["xrt==1.15.0", "torch_xla==nightly"]
    CONFIG = {
        'xrt==1.15.0': _VersionConfig('1.15', '1.15.0'),
        'torch_xla==nightly': _VersionConfig('nightly', 'XRT-dev{}'.format(
            (datetime.today() - timedelta(1)).strftime('%Y%m%d'))),
    }[VERSION]
    DIST_BUCKET = 'gs://tpu-pytorch/wheels'
    TORCH_WHEEL = 'torch-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
    TORCH_XLA_WHEEL = 'torch_xla-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
    TORCHVISION_WHEEL = 'torchvision-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)

    # Update the TPU's server-side XRT version to match the wheels above
    def update_server_xrt():
        print('Updating server-side XRT to {} ...'.format(CONFIG.server))
        url = 'http://{TPU_ADDRESS}:8475/requestversion/{XRT_VERSION}'.format(
            TPU_ADDRESS=os.environ['COLAB_TPU_ADDR'].split(':')[0],
            XRT_VERSION=CONFIG.server,
        )
        print('Done updating server-side XRT: {}'.format(requests.post(url)))

    update = threading.Thread(target=update_server_xrt)
    update.start()

    # Install Colab TPU compatible PyTorch/TPU wheels and dependencies
    !pip uninstall -y torch torchvision
    !gsutil cp "$DIST_BUCKET/$TORCH_WHEEL" .
    !gsutil cp "$DIST_BUCKET/$TORCH_XLA_WHEEL" .
    !gsutil cp "$DIST_BUCKET/$TORCHVISION_WHEEL" .
    !pip install "$TORCH_WHEEL"
    !pip install "$TORCH_XLA_WHEEL"
    !pip install "$TORCHVISION_WHEEL"
    !sudo apt-get install libomp5
    update.join()
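
This next cell is optional: a quick sanity check that the wheels installed
correctly and that ``torch_xla`` can see the TPU (assuming the Colab runtime is
picked up automatically via ``COLAB_TPU_ADDR``). If it fails, re-run the install
cell above.

.. code-block:: python

    import torch_xla.core.xla_model as xm

    # Should print an XLA device such as "xla:1" if everything is wired up.
    print(xm.xla_device())
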

5. Once the above is done, install PyTorch Lightning (v0.6.1+).

.. code-block:: bash

    !pip install pytorch-lightning

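
Optionally, confirm the installed version meets the requirement
(``pytorch_lightning.__version__`` reports the installed release):

.. code-block:: python

    import pytorch_lightning as pl

    # Make sure this prints 0.6.1 or later.
    print(pl.__version__)
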
6. Then set up your LightningModule as normal (a minimal sketch is shown below).

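For reference, a minimal module might look like the sketch below; the single
linear layer, optimizer, and learning rate are placeholders, and the
``train_dataloader`` from the next step completes it.

.. code-block:: python

    import torch
    import torch.nn.functional as F
    import pytorch_lightning as pl


    class MyLightningModule(pl.LightningModule):
        """Tiny MNIST classifier, used only to illustrate the structure."""

        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(28 * 28, 10)

        def forward(self, x):
            return self.l1(x.view(x.size(0), -1))

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = F.cross_entropy(self(x), y)
            return {'loss': loss}

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)
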
7. TPUs require a DistributedSampler. That means you should change your
   train_dataloader (and val_dataloader, test_dataloader) code as follows.

.. code-block:: python

    import os

    import torch
    from torch.utils.data import DataLoader
    from torchvision import transforms
    from torchvision.datasets import MNIST

    import pytorch_lightning as pl
    import torch_xla.core.xla_model as xm

    @pl.data_loader
    def train_dataloader(self):
        dataset = MNIST(
            os.getcwd(),
            train=True,
            download=True,
            transform=transforms.ToTensor()
        )

        # required for TPU support
        # (`use_tpu` is a flag you define yourself, e.g. passed in via your hparams)
        sampler = None
        if use_tpu:
            sampler = torch.utils.data.distributed.DistributedSampler(
                dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True
            )

        loader = DataLoader(
            dataset,
            sampler=sampler,
            batch_size=32
        )

        return loader

8. Configure the number of TPU cores in the trainer. You can only choose
   1 or 8. To use a full TPU pod, skip to the TPU pod section.

.. code-block:: python

    import pytorch_lightning as pl

    my_model = MyLightningModule()
    trainer = pl.Trainer(num_tpu_cores=8)
    trainer.fit(my_model)

That's it! Your model will train on all 8 TPU cores.
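
If you just want to smoke-test the pipeline first, a single-core run (the other
allowed value) works too; this is merely a convenience before scaling out to
all 8 cores.

.. code-block:: python

    # Quick single-core run for debugging before scaling to all 8 cores.
    trainer = pl.Trainer(num_tpu_cores=1)
    trainer.fit(my_model)
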

TPU Pod
-------
To train on more than 8 cores, your code actually doesn't change!
All you need to do is submit the following command:

.. code-block:: bash

    $ python -m torch_xla.distributed.xla_dist \
        --tpu=$TPU_POD_NAME \
        --conda-env=torch-xla-nightly \
        -- python /usr/share/torch-xla-0.5/pytorch/xla/test/test_train_imagenet.py --fake_data

16-bit precision
----------------
Lightning also supports training in 16-bit precision with TPUs.
By default, TPU training will use 32-bit precision. To enable 16-bit,
also set the precision flag to 16.

.. code-block:: python

    import pytorch_lightning as pl

    my_model = MyLightningModule()
    trainer = pl.Trainer(num_tpu_cores=8, precision=16)
    trainer.fit(my_model)

Under the hood, the XLA library will use the `bfloat16 type <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_.
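
Roughly, this amounts to flipping torch_xla's bfloat16 switch before training
starts; the sketch below assumes the ``XLA_USE_BF16`` environment variable is
that switch. Lightning sets it for you when you pass ``precision=16``, so this
is shown only to illustrate the mechanism.

.. code-block:: python

    import os

    # torch_xla casts float tensors to bfloat16 when this is set.
    # Lightning does this automatically with precision=16; no need to set it by hand.
    os.environ['XLA_USE_BF16'] = '1'
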

About XLA
----------
XLA is the library that interfaces PyTorch with the TPUs.
For more information check out `XLA <https://github.com/pytorch/xla>`_.