Run a calculation on a Cloud TPU VM using PyTorch

This document provides a brief introduction to working with PyTorch and Cloud TPU.

Before you begin

Before running the commands in this document, you must create a Google Cloud account, install the Google Cloud CLI, and configure the gcloud command. For more information, see Set up the Cloud TPU environment.

Create a Cloud TPU using gcloud

  1. Define some environment variables to make the commands easier to use.
    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-east5-a
    export ACCELERATOR_TYPE=v5litepod-8
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite

Environment variable descriptions

Variable          Description
PROJECT_ID        Your Google Cloud project ID. Use an existing project or create a new one.
TPU_NAME          The name of the TPU.
ZONE              The zone in which to create the TPU VM. For more information about supported zones, see TPU regions and zones.
ACCELERATOR_TYPE  The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
RUNTIME_VERSION   The Cloud TPU software version.
  2. Create your TPU VM by running the following command:
    $ gcloud compute tpus tpu-vm create $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE \
    --accelerator-type=$ACCELERATOR_TYPE \
    --version=$RUNTIME_VERSION
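If you script these steps, it can help to confirm that every variable from step 1 is set before calling gcloud. The following Python sketch is one way to do that, assuming the variables are exported in the current shell. The names `REQUIRED_VARS`, `missing_vars`, and `build_create_command` are hypothetical helpers, not part of gcloud. The sketch assembles the same `gcloud compute tpus tpu-vm create` command as step 2, but as an argument list, so no shell quoting is involved if you later pass it to `subprocess.run`.

```python
import os
import shlex
import sys

# Hypothetical helper: the variables exported in step 1 (not part of gcloud).
REQUIRED_VARS = ("PROJECT_ID", "TPU_NAME", "ZONE", "ACCELERATOR_TYPE", "RUNTIME_VERSION")


def missing_vars(env):
    """Return the names of required variables that are unset or empty."""
    return [name for name in REQUIRED_VARS if not env.get(name)]


def build_create_command(env):
    """Build the step 2 `gcloud compute tpus tpu-vm create` command as an argv list."""
    missing = missing_vars(env)
    if missing:
        raise ValueError("missing environment variables: " + ", ".join(missing))
    return [
        "gcloud", "compute", "tpus", "tpu-vm", "create", env["TPU_NAME"],
        f"--project={env['PROJECT_ID']}",
        f"--zone={env['ZONE']}",
        f"--accelerator-type={env['ACCELERATOR_TYPE']}",
        f"--version={env['RUNTIME_VERSION']}",
    ]


if __name__ == "__main__":
    missing = missing_vars(os.environ)
    if missing:
        print("Set these variables first: " + ", ".join(missing), file=sys.stderr)
    else:
        # Print the command for review; pass the list to subprocess.run(cmd, check=True) to execute it.
        print(shlex.join(build_create_command(os.environ)))
```

Printing the command rather than running it keeps the sketch side-effect free; creating a TPU VM incurs charges, so reviewing the command first is a deliberate choice.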

Connect to your Cloud TPU VM

Connect to your TPU VM over SSH using the following command:

$ gcloud compute tpus tpu-vm ssh $TPU_NAME \
--project=$PROJECT_ID \
--zone=$ZONE

If you can't connect to your TPU VM over SSH, the TPU VM might not have an external IP address. To access a TPU VM without an external IP address, follow the instructions in Connect to a TPU VM without a public IP address.

Install PyTorch/XLA on your TPU VM

(vm)$ sudo apt-get update
(vm)$ sudo apt-get install libopenblas-dev -y
(vm)$ pip install numpy
(vm)$ pip install torch torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
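PyTorch/XLA releases generally track PyTorch releases by version number (for example, torch_xla 2.5 alongside torch 2.5), so one quick sanity check after installation is to confirm that the installed `torch` and `torch_xla` share a major and minor version. The sketch below assumes that pairing rule and reads the installed versions with the standard-library `importlib.metadata`; `release_prefix` and `same_release` are hypothetical helper names.

```python
from importlib.metadata import PackageNotFoundError, version


def release_prefix(ver):
    """Return (major, minor) from a version string such as '2.5.1+cpu'."""
    public = ver.split("+", 1)[0]  # drop any local suffix such as '+cpu'
    major, minor = public.split(".")[:2]
    return int(major), int(minor)


def same_release(torch_ver, xla_ver):
    """Hypothetical check. Assumption: torch and torch_xla pair by major.minor version."""
    return release_prefix(torch_ver) == release_prefix(xla_ver)


if __name__ == "__main__":
    try:
        torch_ver, xla_ver = version("torch"), version("torch_xla")
    except PackageNotFoundError as err:
        print(f"Package not installed: {err}")
    else:
        status = "OK" if same_release(torch_ver, xla_ver) else "MISMATCH"
        print(f"torch {torch_ver}, torch_xla {xla_ver}: {status}")
```

Run it on the TPU VM with python3. Comparing only major and minor ignores patch releases and local build suffixes, which is a loose check rather than a guarantee of compatibility.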

Verify PyTorch can access TPUs

Use the following command to verify PyTorch can access your TPUs:

(vm)$ PJRT_DEVICE=TPU python3 -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices('TPU'))"

The output from the command should look like the following:

['xla:0', 'xla:1', 'xla:2', 'xla:3', 'xla:4', 'xla:5', 'xla:6', 'xla:7']
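To run this device check from a script, invoking Python with an argument list avoids the careful quoting that shell one-liners with embedded Python string literals require. This sketch runs the same check in a subprocess with `PJRT_DEVICE=TPU` and parses the printed list into device ordinals. `CHECK_CODE`, `parse_devices`, and `run_check` are hypothetical names. The sketch assumes the device list is the last non-empty line on stdout and treats a printed `None` as "no devices".

```python
import ast
import os
import subprocess
import sys

# Hypothetical constant: the same check as the one-liner, passed as argv (no shell quoting).
CHECK_CODE = (
    "import torch_xla.core.xla_model as xm; "
    "print(xm.get_xla_supported_devices('TPU'))"
)


def parse_devices(stdout):
    """Parse printed output such as "['xla:0', 'xla:1']" into integer ordinals."""
    lines = [line for line in stdout.splitlines() if line.strip()]
    if not lines:
        return []
    devices = ast.literal_eval(lines[-1])  # assumption: the list is the last non-empty line
    if not devices:  # covers an empty list or a printed None
        return []
    return [int(dev.split(":", 1)[1]) for dev in devices]


def run_check():
    """Run the check in a child Python process with PJRT_DEVICE=TPU."""
    env = dict(os.environ, PJRT_DEVICE="TPU")
    result = subprocess.run(
        [sys.executable, "-c", CHECK_CODE],
        env=env, capture_output=True, text=True, check=True,
    )
    return parse_devices(result.stdout)


if __name__ == "__main__":
    try:
        ordinals = run_check()
    except subprocess.CalledProcessError as err:
        print("Device check failed:\n" + err.stderr, file=sys.stderr)
    else:
        print(f"{len(ordinals)} TPU devices: {ordinals}")
```

`ast.literal_eval` only evaluates Python literals, so parsing the child's output this way does not execute arbitrary code.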

Perform a basic calculation

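  1. Create a file named tpu-test.py in the current directory and copy and paste the following script into it:
    import torch
    import torch_xla.core.xla_model as xm

    # Get the default XLA device (a TPU device when PJRT_DEVICE=TPU).
    dev = xm.xla_device()
    # Create two random 3x3 tensors directly on that device.
    t1 = torch.randn(3, 3, device=dev)
    t2 = torch.randn(3, 3, device=dev)
    # Printing the sum materializes the result, so XLA compiles and runs the addition.
    print(t1 + t2)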
  2. Run the script:
    (vm)$ PJRT_DEVICE=TPU python3 tpu-test.py
    The output from the script shows the result of the computation. Because the tensors are random, your values will differ:
    tensor([[-0.2121,  1.5589, -0.6951],
            [-0.7886, -0.2022,  0.9242],
            [ 0.8555, -1.8698,  1.4333]], device='xla:1')

Clean up

To avoid incurring charges to your Google Cloud account for the resources used on this page, follow these steps.

  1. Disconnect from the Cloud TPU instance, if you have not already done so:
    (vm)$ exit
    Your prompt should now be username@projectname, showing that you are back in Cloud Shell. If you ran the commands from a local terminal, you return to your local prompt instead.
  2. Delete your Cloud TPU.
    $ gcloud compute tpus tpu-vm delete $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE
  3. Verify the resources have been deleted by running the following command. Make sure your TPU is no longer listed. The deletion might take several minutes.
    $ gcloud compute tpus tpu-vm list \
    --zone=$ZONE
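To script the check in step 3, you can request JSON output with gcloud's global `--format=json` flag and look for your TPU's name. This sketch assumes each entry's `name` field is a full resource path whose last segment is the TPU name (for example, `projects/PROJECT/locations/ZONE/nodes/TPU_NAME`). `tpu_still_listed` is a hypothetical helper.

```python
import json
import os
import subprocess


def tpu_still_listed(list_json, tpu_name):
    """Return True if tpu_name appears in `gcloud ... list --format=json` output.

    Hypothetical helper. Assumption: each entry's "name" is a resource path
    whose last segment is the TPU name.
    """
    nodes = json.loads(list_json or "[]")
    return any(node.get("name", "").rsplit("/", 1)[-1] == tpu_name for node in nodes)


if __name__ == "__main__":
    tpu_name, zone = os.environ.get("TPU_NAME"), os.environ.get("ZONE")
    if tpu_name and zone:
        out = subprocess.run(
            ["gcloud", "compute", "tpus", "tpu-vm", "list", f"--zone={zone}", "--format=json"],
            capture_output=True, text=True, check=True,
        ).stdout
        print("still listed" if tpu_still_listed(out, tpu_name) else "deleted")
    else:
        print("Set TPU_NAME and ZONE first.")
```

Because deletion can take several minutes, you might rerun this check until it reports that the TPU has been deleted.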

What's next

Read more about Cloud TPU VMs: