CUDA PJRT plugin (experimental)

This directory contains an experimental implementation of the PJRT GPU client as a plugin. The PJRT C API implementation itself lives in the main OpenXLA repository; this directory packages it as a Python plugin for PyTorch/XLA.

Building

See our contributing guide for build environment setup steps.

```bash
# Run from the repository root (the paths below are relative to it)
# Build a wheel
pip wheel plugins/cuda -v
# Or install directly
pip install plugins/cuda -v
```
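Before running the usage example, you may want to confirm that the `torch_xla_cuda_plugin` module installed by the commands above can be found. The following is a minimal sketch that uses only the standard library; `is_module_available` is a hypothetical helper name, not part of the plugin.

```python
import importlib.util


def is_module_available(name: str) -> bool:
    """Return True if a top-level module can be found, without importing it."""
    # find_spec returns None when no module with this top-level name exists.
    return importlib.util.find_spec(name) is not None


if __name__ == "__main__":
    # After `pip install plugins/cuda`, this should print True.
    print(is_module_available("torch_xla_cuda_plugin"))
```

Using `find_spec` rather than a bare `import` avoids initializing the plugin (and the CUDA runtime) just to check the install.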

Usage

```python
import os

# Enable verbose logging from the PJRT plugin registry so the device type
# is logged. Set these before importing torch_xla.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
os.environ['TF_CPP_VMODULE'] = 'pjrt_registry=5'

from torch_xla.experimental import plugins
import torch_xla_cuda_plugin  # Installed by the wheel built above
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

# Use the dynamic plugin instead of the built-in CUDA support
plugins.use_dynamic_plugins()
plugins.register_plugin('CUDA', torch_xla_cuda_plugin.CudaPlugin())
xr.set_device_type('CUDA')

# Print the default XLA device (e.g. xla:0)
print(xm.xla_device())
```
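If you set the logging variables from a larger script, it can help to wrap them in a function that also checks import order. This is a minimal sketch using only the standard library. `configure_pjrt_registry_logging` and its `vlog_level` parameter are hypothetical names. The assumption that the variables must be set before `torch_xla` is imported follows the ordering in the example above and is not a documented guarantee.

```python
import os
import sys
import warnings


def configure_pjrt_registry_logging(env=None, vlog_level=5):
    """Set the logging variables used in the usage example (hypothetical helper).

    Warns if torch_xla is already imported, because the settings may then be
    read too late to take effect (assumption based on the example's ordering).
    """
    if env is None:
        env = os.environ
    if "torch_xla" in sys.modules:
        warnings.warn(
            "torch_xla is already imported; logging settings may not take effect"
        )
    env["TF_CPP_MIN_LOG_LEVEL"] = "0"  # log INFO and above, i.e. all levels
    # Enable verbose (VLOG) output for the pjrt_registry module only.
    env["TF_CPP_VMODULE"] = f"pjrt_registry={vlog_level}"
    return env
```

Call it at the very top of the script, before the `torch_xla` imports from the usage example.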