PyTorch 并行處理最佳實(shí)踐

2020-09-11 09:36 更新
原文: https://pytorch.org/docs/stable/notes/multiprocessing.html

torch.multiprocessing 是 Python python:multiprocessing模塊的替代品。 它支持完全相同的操作,但是對(duì)其進(jìn)行了擴(kuò)展,以便所有通過(guò)python:multiprocessing.Queue發(fā)送的張量將其數(shù)據(jù)移至共享內(nèi)存中,并且僅將句柄發(fā)送給另一個(gè)進(jìn)程。

注意

將 Tensor 發(fā)送到另一個(gè)進(jìn)程時(shí),將共享 Tensor 數(shù)據(jù)。 如果 torch.Tensor.grad 不是None,則也將其共享。 將沒(méi)有 torch.Tensor.grad 字段的 Tensor 發(fā)送到另一個(gè)進(jìn)程后,它將創(chuàng)建一個(gè)特定于標(biāo)準(zhǔn)進(jìn)程的.grad Tensor 不會(huì)自動(dòng)在所有進(jìn)程之間共享,這與  Tensor 的數(shù)據(jù)共享方式不同。

這允許實(shí)施各種訓(xùn)練方法,例如 Hogwild,A3C 或任何其他需要異步操作的方法。

并行處理中的 CUDA

CUDA 運(yùn)行時(shí)不支持fork啟動(dòng)方法。 但是,Python 2 中的python:multiprocessing只能使用fork創(chuàng)建子進(jìn)程。 因此,需要 Python 3 和spawnforkserver啟動(dòng)方法在子進(jìn)程中使用 CUDA。

Note

可以通過(guò)使用multiprocessing.get_context(...)創(chuàng)建上下文或直接使用multiprocessing.set_start_method(...)來(lái)設(shè)置啟動(dòng)方法。

與 CPU 張量不同,只要接收過(guò)程保留張量的副本,就需要發(fā)送過(guò)程來(lái)保留原始張量。 它是在幕后實(shí)施的,但要求用戶遵循最佳實(shí)踐才能使程序正確運(yùn)行。 例如,只要使用者進(jìn)程具有對(duì)張量的引用,發(fā)送進(jìn)程就必須保持活動(dòng)狀態(tài),并且如果使用者進(jìn)程通過(guò)致命信號(hào)異常退出,則引用計(jì)數(shù)無(wú)法保存您。

最佳做法和提示

避免和消除僵局

產(chǎn)生新進(jìn)程時(shí),很多事情都會(huì)出錯(cuò),死鎖的最常見(jiàn)原因是后臺(tái)線程。 如果有任何持有鎖的線程或?qū)肽K的線程,并且調(diào)用了fork,則子進(jìn)程很可能會(huì)處于損壞狀態(tài),并且將以不同的方式死鎖或失敗。 請(qǐng)注意,即使您不這樣做,內(nèi)置庫(kù)的 Python 也會(huì)這樣做-不需要比python:multiprocessing看起來(lái)更深。 python:multiprocessing.Queue實(shí)際上是一個(gè)非常復(fù)雜的類(lèi),它產(chǎn)生用于序列化,發(fā)送和接收對(duì)象的多個(gè)線程,它們也可能導(dǎo)致上述問(wèn)題。 如果您遇到這種情況,請(qǐng)嘗試使用SimpleQueue,它不使用任何其他線程。

我們正在努力為您提供便利,并確保不會(huì)發(fā)生這些僵局,但有些事情是我們無(wú)法控制的。 如果您有一段時(shí)間無(wú)法解決的問(wèn)題,請(qǐng)嘗試與論壇聯(lián)系,我們將解決是否可以解決的問(wèn)題。

重用通過(guò)隊(duì)列傳遞的緩沖區(qū)

請(qǐng)記住,每次將 Tensor 放入python:multiprocessing.Queue時(shí),都必須將其移到共享內(nèi)存中。 如果已共享,則為空操作,否則將產(chǎn)生額外的內(nèi)存副本,從而減慢整個(gè)過(guò)程。 即使您有一個(gè)將數(shù)據(jù)發(fā)送到單個(gè)進(jìn)程的進(jìn)程池,也要使它將緩沖區(qū)發(fā)送回去-這幾乎是免費(fèi)的,并且可以避免在發(fā)送下一批時(shí)復(fù)制。

異步多進(jìn)程訓(xùn)練(例如 Hogwild)

使用 torch.multiprocessing ,可以異步訓(xùn)練模型,參數(shù)可以始終共享,也可以定期同步。 在第一種情況下,我們建議發(fā)送整個(gè)模型對(duì)象,而在后一種情況下,建議僅發(fā)送 state_dict() 。

我們建議使用python:multiprocessing.Queue在進(jìn)程之間傳遞各種 PyTorch 對(duì)象。 例如 當(dāng)使用fork start 方法時(shí),它會(huì)繼承共享內(nèi)存中已經(jīng)存在的張量和存儲(chǔ),但是,它非常容易出錯(cuò),應(yīng)謹(jǐn)慎使用,并且只有高級(jí)用戶可以使用。 隊(duì)列即使有時(shí)不是很優(yōu)雅的解決方案,也可以在所有情況下正常工作。

警告

您應(yīng)謹(jǐn)慎使用不受if __name__ == '__main__'約束的全局語(yǔ)句。 如果使用與fork不同的啟動(dòng)方法,則將在所有子過(guò)程中執(zhí)行它們。

霍格威爾德

您可以在示例存儲(chǔ)庫(kù)中找到具體的 Hogwild 實(shí)現(xiàn),但為了展示代碼的整體結(jié)構(gòu),下面還有一個(gè)最小的示例:

import torch.multiprocessing as mp
from model import MyModel


def train(model):
    # Construct data_loader, optimizer, etc.
    for data, labels in data_loader:
        optimizer.zero_grad()
        loss_fn(model(data), labels).backward()
        optimizer.step()  # This will update the shared parameters


if __name__ == '__main__':
    num_processes = 4
    model = MyModel()
    # NOTE: this is required for the ``fork`` method to work
    model.share_memory()
    processes = []
    for rank in range(num_processes):
        p = mp.Process(target=train, args=(model,))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()



以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)