原文: http://pytorch.org/xla/
PyTorch 使用 torch_xla 軟件包在 XPU 設(shè)備(如 TPU)上運(yùn)行。 本文檔介紹了如何在這些設(shè)備上運(yùn)行模型。
PyTorch / XLA 向 PyTorch 添加了新的xla設(shè)備類型。 此設(shè)備類型的工作方式與其他 PyTorch 設(shè)備類型一樣。 例如,以下是創(chuàng)建和打印 XLA 張量的方法:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)
此代碼應(yīng)該看起來很熟悉。 PyTorch / XLA 使用與常規(guī) PyTorch 相同的界面,但有一些附加功能。 導(dǎo)入torch_xla會(huì)初始化 PyTorch / XLA,xm.xla_device()
會(huì)返回當(dāng)前的 XLA 設(shè)備。 根據(jù)您的環(huán)境,這可能是 CPU 或 TPU。
可以像 CPU 或 CUDA 張量一樣在 XLA 張量上執(zhí)行 PyTorch 操作。
例如,可以將 XLA 張量添加在一起:
t0 = torch.randn(2, 2, device=xm.xla_device())
t1 = torch.randn(2, 2, device=xm.xla_device())
print(t0 + t1)
或乘以矩陣:
print(t0.mm(t1))
或與神經(jīng)網(wǎng)絡(luò)模塊一起使用:
l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20).to(xm.xla_device())
l_out = linear(l_in)
print(l_out)
與其他設(shè)備類型一樣,XLA 張量僅可與同一設(shè)備上的其他 XLA 張量一起使用。 所以代碼像
l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20)
l_out = linear(l_in)
print(l_out)
## Input tensor is not an XLA tensor: torch.FloatTensor
由于 torch.nn.Linear 模塊在 CPU 上,因此將引發(fā)錯(cuò)誤。
建立新的 PyTorch 網(wǎng)絡(luò)或轉(zhuǎn)換現(xiàn)有網(wǎng)絡(luò)以在 XLA 設(shè)備上運(yùn)行僅需要幾行 XLA 專用代碼。 以下代碼片段突出顯示了在單個(gè)設(shè)備,具有 XLA 并行處理功能的多個(gè)設(shè)備或具有 XLA 多線程的多個(gè)線程上運(yùn)行時(shí)的這些行。
以下代碼片段顯示了單個(gè) XLA 設(shè)備上的網(wǎng)絡(luò)訓(xùn)練:
import torch_xla.core.xla_model as xm
device = xm.xla_device()
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
for data, target in train_loader:
optimizer.zero_grad()
data = data.to(device)
target = target.to(device)
output = model(data)
loss = loss_fn(output, target)
loss.backward()
xm.optimizer_step(optimizer, barrier=True)
此代碼段突出顯示了切換模型以在 XLA 上運(yùn)行非常容易。 模型定義,數(shù)據(jù)加載器,優(yōu)化器和訓(xùn)練循環(huán)可在任何設(shè)備上運(yùn)行。 唯一的 XLA 特定代碼是幾行代碼,這些代碼獲取 XLA 設(shè)備并以屏障進(jìn)入優(yōu)化程序。 在每次訓(xùn)練迭代結(jié)束時(shí)調(diào)用xm.optimizer_step(optimizer, barrier=True)
都會(huì)使 XLA 執(zhí)行其當(dāng)前圖形并更新模型的參數(shù)。 有關(guān) XLA 如何創(chuàng)建圖形和運(yùn)行操作的更多信息,請(qǐng)參見 XLA Tensor Deep Dive 。
通過在多個(gè) XLA 設(shè)備上運(yùn)行,PyTorch / XLA 可以輕松加速訓(xùn)練。 以下代碼段顯示了如何:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
def _mp_fn(index):
device = xm.xla_device()
para_loader = pl.ParallelLoader(train_loader, [device])
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
for data, target in para_loader.per_device_loader(device):
optimizer.zero_grad()
output = model(data)
loss = loss_fn(output, target)
loss.backward()
xm.optimizer_step(optimizer)
if __name__ == '__main__':
xmp.spawn(_mp_fn, args=())
此多設(shè)備代碼段和先前的單設(shè)備代碼段之間存在三個(gè)區(qū)別:
xmp.spawn()
創(chuàng)建分別運(yùn)行 XLA 設(shè)備的進(jìn)程。ParallelLoader
將訓(xùn)練數(shù)據(jù)加載到每個(gè)設(shè)備上。xm.optimizer_step(optimizer)
不再需要障礙。 ParallelLoader 自動(dòng)創(chuàng)建用于評(píng)估圖形的 XLA 障礙。模型定義,優(yōu)化器定義和訓(xùn)練循環(huán)保持不變。
請(qǐng)參閱完整的并行處理示例,以獲取更多關(guān)于在具有并行處理功能的多個(gè) XLA 設(shè)備上訓(xùn)練網(wǎng)絡(luò)的信息。
使用進(jìn)程(請(qǐng)參見上文)在多個(gè) XLA 設(shè)備上運(yùn)行比使用線程更可取。 但是,如果您想使用線程,則 PyTorch / XLA 具有DataParallel
接口。 以下代碼片段顯示了具有多個(gè)線程的相同網(wǎng)絡(luò)訓(xùn)練:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.data_parallel as dp
devices = xm.get_xla_supported_devices()
model_parallel = dp.DataParallel(MNIST, device_ids=devices)
def train_loop_fn(model, loader, device, context):
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
model.train()
for _, (data, target) in loader:
optimizer.zero_grad()
output = model(data)
loss = loss_fn(output, target)
loss.backward()
xm.optimizer_step(optimizer)
for epoch in range(1, num_epochs + 1):
model_parallel(train_loop_fn, train_loader)
多線程和并行處理代碼之間的唯一區(qū)別是:
xm.get_xla_supported_devices()
在同一過程中獲取多個(gè)設(shè)備。dp.DataParallel
中,并通過了訓(xùn)練循環(huán)和數(shù)據(jù)加載器。有關(guān)在多 XLA 設(shè)備上使用多線程訓(xùn)練網(wǎng)絡(luò)的更多信息,請(qǐng)參見完整的多線程示例。
使用 XLA 張量和設(shè)備僅需要更改幾行代碼。 但是,即使 XLA 張量的行為很像 CPU 和 CUDA 張量,其內(nèi)部結(jié)構(gòu)也不同。 本節(jié)描述了 XLA 張量獨(dú)特的原因。
CPU 和 CUDA 張量立即啟動(dòng)操作或急切啟動(dòng)。 另一方面,XLA 張量是惰性。 他們將操作記錄在圖形中,直到需要結(jié)果為止。 這樣推遲執(zhí)行,XLA 可以對(duì)其進(jìn)行優(yōu)化。 例如,多個(gè)單獨(dú)操作的圖形可能會(huì)融合為一個(gè)優(yōu)化操作。
懶惰執(zhí)行通常對(duì)調(diào)用者不可見。 當(dāng)在 XLA 設(shè)備和 CPU 之間復(fù)制數(shù)據(jù)時(shí),PyTorch / XLA 自動(dòng)構(gòu)建圖形,將它們發(fā)送到 XLA 設(shè)備,并進(jìn)行同步。 采取優(yōu)化程序步驟時(shí)插入屏障會(huì)顯式同步 CPU 和 XLA 設(shè)備。
當(dāng)在 TPU 上運(yùn)行時(shí),PyTorch / XLA 可以使用 bfloat16 數(shù)據(jù)類型。 實(shí)際上,PyTorch / XLA 在 TPU 上處理浮點(diǎn)類型(torch.float
和torch.double
)的方式有所不同。 此行為由XLA_USE_BF16
環(huán)境變量控制:
torch.float
和torch.double
均為torch.float
。XLA_USE_BF16
,則 TPU 上的torch.float
和torch.double
均為bfloat16
。TPU 上的 XLA 張量將始終報(bào)告其 PyTorch 數(shù)據(jù)類型,而不管其使用的實(shí)際數(shù)據(jù)類型是什么。 這種轉(zhuǎn)換是自動(dòng)且不透明的。 如果將 TPU 上的 XLA 張量移回 CPU,它將從其實(shí)際數(shù)據(jù)類型轉(zhuǎn)換為其 PyTorch 數(shù)據(jù)類型。
XLA 張量的內(nèi)部數(shù)據(jù)表示對(duì)于用戶而言是不透明的。 它們不公開其存儲(chǔ),并且它們總是看起來是連續(xù)的,這與 CPU 和 CUDA 張量不同。 這使 XLA 可以調(diào)整張量的內(nèi)存布局以獲得更好的性能。
XLA 張量可以從 CPU 移到 XLA 設(shè)備,也可以從 XLA 設(shè)備移到 CPU。 如果移動(dòng)了視圖,則其視圖的數(shù)據(jù)將被復(fù)制到另一臺(tái)設(shè)備,并且不會(huì)保留視圖關(guān)系。 換句話說,將數(shù)據(jù)復(fù)制到另一設(shè)備后,它與先前的設(shè)備或其上的任何張量都沒有關(guān)系。
在保存之前,應(yīng)將 XLA 張量移至 CPU,如以下代碼段所示:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
device = xm.xla_device()
t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)
tensors = (t0.cpu(), t1.cpu())
torch.save(tensors, 'tensors.pt')
tensors = torch.load('tensors.pt')
t0 = tensors[0].to(device)
t1 = tensors[1].to(device)
這使您可以將加載的張量放置在任何可用設(shè)備上。
根據(jù)以上有關(guān)將 XLA 張量移至 CPU 的說明,使用視圖時(shí)必須格外小心。 建議不要在保存張量并將其移至目標(biāo)設(shè)備后重新創(chuàng)建視圖,而不必保存視圖。
可以直接保存 XLA 張量,但不建議這樣做。 XLA 張量始終會(huì)加載回保存它們的設(shè)備,如果該設(shè)備不可用,加載將會(huì)失敗。 與所有 PyTorch 一樣,PyTorch / XLA 正在積極開發(fā)中,這種行為將來可能會(huì)改變。
其他文檔可在 PyTorch / XLA 存儲(chǔ)庫中找到。 在此處可以找到在 TPU 上運(yùn)行網(wǎng)絡(luò)的更多示例。
torch_xla.core.xla_model.xla_device(n=None, devkind=None)?
返回 XLA 設(shè)備的給定實(shí)例。
參數(shù)
退貨
具有所請(qǐng)求實(shí)例的<cite>torch設(shè)備</cite>。
torch_xla.core.xla_model.get_xla_supported_devices(devkind=None, max_devices=None)?
返回給定類型的受支持設(shè)備的列表。
Parameters
Returns
設(shè)備字符串列表。
torch_xla.core.xla_model.xrt_world_size(defval=1)?
檢索參與復(fù)制的設(shè)備數(shù)。
Parameters
defval (python:int , 可選)–如果沒有可用的復(fù)制信息,則返回默認(rèn)值。 默認(rèn)值:1
Returns
參與復(fù)制的設(shè)備數(shù)。
torch_xla.core.xla_model.get_ordinal(defval=0)?
檢索當(dāng)前進(jìn)程的復(fù)制序號(hào)。
序數(shù)范圍從 0 到 <cite>xrt_world_size()</cite>減 1。
Parameters
defval (python:int , 可選)–如果沒有可用的復(fù)制信息,則返回默認(rèn)值。 默認(rèn)值:0
Returns
當(dāng)前進(jìn)程的復(fù)制序號(hào)。
torch_xla.core.xla_model.is_master_ordinal()?
檢查當(dāng)前進(jìn)程是否為主序(0)。
Returns
一個(gè)布爾值,指示當(dāng)前進(jìn)程是否是主序。
torch_xla.core.xla_model.optimizer_step(optimizer, barrier=False, optimizer_args={})?
運(yùn)行提供的優(yōu)化器步驟并發(fā)出 XLA 設(shè)備步驟計(jì)算。
Parameters
torch.Optimizer
)–需要調(diào)用其 <cite>step()</cite>函數(shù)的<cite>torch.optim器</cite>實(shí)例。 <cite>step()</cite>函數(shù)將使用名為 <cite>optimizer_args</cite> 的參數(shù)調(diào)用。Returns
<cite>Optimizer.step()</cite>調(diào)用返回的值相同。
class torch_xla.distributed.parallel_loader.ParallelLoader(loader, devices, batchdim=0, fixed_batch_size=False, loader_prefetch_size=8, device_prefetch_size=4)?
使用背景數(shù)據(jù)上傳包裝現(xiàn)有的 PyTorch DataLoader。
Parameters
torch.utils.data.DataLoader
)–要包裝的 PyTorch DataLoader。per_device_loader(device)?
檢索給定設(shè)備的加載程序?qū)ο蟆?/p>
Parameters
設(shè)備(<cite>torch設(shè)備</cite>)–正在請(qǐng)求設(shè)備整個(gè)裝載程序。
Returns
<cite>設(shè)備</cite>的數(shù)據(jù)加載器。
class torch_xla.distributed.data_parallel.DataParallel(network, device_ids=None)?
使用線程以復(fù)制模式啟用模型網(wǎng)絡(luò)的執(zhí)行。
Parameters
torch.nn.Module
或可調(diào)用)–模型的網(wǎng)絡(luò)。 <cite>torch.nn.Module</cite> 的子類,或者是返回 <cite>torch.nn.Module</cite> 子類的可調(diào)用對(duì)象。torch.device
…)–應(yīng)在其上進(jìn)行復(fù)制的設(shè)備的列表。 如果列表為空,則網(wǎng)絡(luò)將在 PyTorch CPU 設(shè)備上運(yùn)行。__call__(loop_fn, loader, fixed_batch_size=False, batchdim=0)?
進(jìn)行一次 EPOCH 訓(xùn)練/測(cè)試。
Parameters
Returns
每個(gè)設(shè)備上 <cite>loop_fn</cite> 返回的值的列表。
torch_xla.distributed.xla_multiprocessing.spawn(fn, args=(), nprocs=None, join=True, daemon=False)?
啟用基于并行處理的復(fù)制。
Parameters
Returns
<cite>torch.multiprocessing.spawn</cite> API 返回的同一對(duì)象。
class torch_xla.utils.utils.SampleGenerator(data, sample_count)?
迭代器,它返回給定輸入數(shù)據(jù)的多個(gè)樣本。
可以代替 PyTorch <cite>DataLoader</cite> 生成合成數(shù)據(jù)。
Parameters
更多建議: