CPU PJRT Plugin (testing)

This directory contains an experimental implementation of the PJRT CPU client as a plugin. The plugin is for testing only and is not officially supported. For regular use, set PJRT_DEVICE=CPU with any PyTorch/XLA installation to use the built-in CPU support instead.
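
For comparison with the plugin path, here is a minimal sketch of selecting the built-in CPU support from Python. It assumes the environment variable is set before `torch_xla` initializes its runtime. The `find_spec` guard is an addition for illustration only, so the snippet also runs where `torch_xla` is not installed.

```python
import importlib.util
import os

# Select the built-in CPU client. Assumption: this must happen before
# torch_xla initializes its runtime, so set it before importing torch_xla.
os.environ['PJRT_DEVICE'] = 'CPU'

# Illustrative guard: only touch torch_xla if it is actually installed.
if importlib.util.find_spec('torch_xla') is not None:
    import torch_xla.core.xla_model as xm
    print(xm.xla_device())
else:
    print('torch_xla not installed; PJRT_DEVICE =', os.environ['PJRT_DEVICE'])
```

Exporting `PJRT_DEVICE=CPU` in the shell before launching Python should have the same effect.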

The actual implementation of the PJRT C API lives in the main OpenXLA repository (see bazel build command below).

Building

See our contributing guide for build environment setup steps.

# Build wheel
pip wheel plugins/cpu -v
# Or install directly
pip install plugins/cpu -v
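
After installing, a quick way to confirm the package landed in the active environment is to check that its module can be found. This is a sketch using only the standard library. The helper name `plugin_installed` is hypothetical, and the module name `torch_xla_cpu_plugin` is taken from the usage example in this README.

```python
import importlib.util


def plugin_installed(module_name: str = 'torch_xla_cpu_plugin') -> bool:
    """Return True if the module can be located on the import path.

    Hypothetical helper for illustration; it does not import the module.
    """
    return importlib.util.find_spec(module_name) is not None


print('CPU plugin importable:', plugin_installed())
```

If this prints `False`, the install most likely went into a different Python environment than the one running the check.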

Usage

import os

# Log device type
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
os.environ['TF_CPP_VMODULE'] = 'pjrt_registry=5'

from torch_xla.experimental import plugins
import torch_xla_cpu_plugin
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

# Use dynamic plugin instead of built-in CPU support
plugins.use_dynamic_plugins()
plugins.register_plugin('CPU', torch_xla_cpu_plugin.CpuPlugin())
xr.set_device_type('CPU')

print(xm.xla_device())