W3Cschool
恭喜您成為首批注冊(cè)用戶
獲得88經(jīng)驗(yàn)值獎(jiǎng)勵(lì)
原文: https://pytorch.org/docs/stable/notes/multiprocessing.html
torch.multiprocessing
是 Python python:multiprocessing
模塊的替代品。 它支持完全相同的操作,但是對(duì)其進(jìn)行了擴(kuò)展,以便所有通過(guò)python:multiprocessing.Queue
發(fā)送的張量將其數(shù)據(jù)移至共享內(nèi)存中,并且僅將句柄發(fā)送給另一個(gè)進(jìn)程。
注意
將 Tensor
發(fā)送到另一個(gè)進(jìn)程時(shí),將共享 Tensor
數(shù)據(jù)。 如果 torch.Tensor.grad
不是None,則也將其共享。 將沒(méi)有 torch.Tensor.grad
字段的 Tensor
發(fā)送到另一個(gè)進(jìn)程后,它將創(chuàng)建一個(gè)特定于標(biāo)準(zhǔn)進(jìn)程的.grad Tensor
不會(huì)自動(dòng)在所有進(jìn)程之間共享,這與 Tensor
的數(shù)據(jù)共享方式不同。
這允許實(shí)施各種訓(xùn)練方法,例如 Hogwild,A3C 或任何其他需要異步操作的方法。
CUDA 運(yùn)行時(shí)不支持fork
啟動(dòng)方法。 但是,Python 2 中的python:multiprocessing
只能使用fork
創(chuàng)建子進(jìn)程。 因此,需要 Python 3 和spawn
或forkserver
啟動(dòng)方法在子進(jìn)程中使用 CUDA。
Note
可以通過(guò)使用multiprocessing.get_context(...)
創(chuàng)建上下文或直接使用multiprocessing.set_start_method(...)
來(lái)設(shè)置啟動(dòng)方法。
與 CPU 張量不同,只要接收過(guò)程保留張量的副本,就需要發(fā)送過(guò)程來(lái)保留原始張量。 它是在幕后實(shí)施的,但要求用戶遵循最佳實(shí)踐才能使程序正確運(yùn)行。 例如,只要使用者進(jìn)程具有對(duì)張量的引用,發(fā)送進(jìn)程就必須保持活動(dòng)狀態(tài),并且如果使用者進(jìn)程通過(guò)致命信號(hào)異常退出,則引用計(jì)數(shù)無(wú)法保存您。
產(chǎn)生新進(jìn)程時(shí),很多事情都會(huì)出錯(cuò),死鎖的最常見(jiàn)原因是后臺(tái)線程。 如果有任何持有鎖的線程或?qū)肽K的線程,并且調(diào)用了fork
,則子進(jìn)程很可能會(huì)處于損壞狀態(tài),并且將以不同的方式死鎖或失敗。 請(qǐng)注意,即使您不這樣做,內(nèi)置庫(kù)的 Python 也會(huì)這樣做-不需要比python:multiprocessing
看起來(lái)更深。 python:multiprocessing.Queue
實(shí)際上是一個(gè)非常復(fù)雜的類(lèi),它產(chǎn)生用于序列化,發(fā)送和接收對(duì)象的多個(gè)線程,它們也可能導(dǎo)致上述問(wèn)題。
如果您遇到這種情況,請(qǐng)嘗試使用SimpleQueue
,它不使用任何其他線程。
我們正在努力為您提供便利,并確保不會(huì)發(fā)生這些僵局,但有些事情是我們無(wú)法控制的。 如果您有一段時(shí)間無(wú)法解決的問(wèn)題,請(qǐng)嘗試與論壇聯(lián)系,我們將解決是否可以解決的問(wèn)題。
請(qǐng)記住,每次將 Tensor
放入python:multiprocessing.Queue
時(shí),都必須將其移到共享內(nèi)存中。 如果已共享,則為空操作,否則將產(chǎn)生額外的內(nèi)存副本,從而減慢整個(gè)過(guò)程。 即使您有一個(gè)將數(shù)據(jù)發(fā)送到單個(gè)進(jìn)程的進(jìn)程池,也要使它將緩沖區(qū)發(fā)送回去-這幾乎是免費(fèi)的,并且可以避免在發(fā)送下一批時(shí)復(fù)制。
使用 torch.multiprocessing
,可以異步訓(xùn)練模型,參數(shù)可以始終共享,也可以定期同步。 在第一種情況下,我們建議發(fā)送整個(gè)模型對(duì)象,而在后一種情況下,建議僅發(fā)送 state_dict() 。
我們建議使用python:multiprocessing.Queue
在進(jìn)程之間傳遞各種 PyTorch 對(duì)象。 例如 當(dāng)使用fork
start 方法時(shí),它會(huì)繼承共享內(nèi)存中已經(jīng)存在的張量和存儲(chǔ),但是,它非常容易出錯(cuò),應(yīng)謹(jǐn)慎使用,并且只有高級(jí)用戶可以使用。 隊(duì)列即使有時(shí)不是很優(yōu)雅的解決方案,也可以在所有情況下正常工作。
警告
您應(yīng)謹(jǐn)慎使用不受if __name__ == '__main__'
約束的全局語(yǔ)句。 如果使用與fork不同的啟動(dòng)方法,則將在所有子過(guò)程中執(zhí)行它們。
您可以在示例存儲(chǔ)庫(kù)中找到具體的 Hogwild 實(shí)現(xiàn),但為了展示代碼的整體結(jié)構(gòu),下面還有一個(gè)最小的示例:
import torch.multiprocessing as mp
from model import MyModel
def train(model):
# Construct data_loader, optimizer, etc.
for data, labels in data_loader:
optimizer.zero_grad()
loss_fn(model(data), labels).backward()
optimizer.step() # This will update the shared parameters
if __name__ == '__main__':
num_processes = 4
model = MyModel()
# NOTE: this is required for the ``fork`` method to work
model.share_memory()
processes = []
for rank in range(num_processes):
p = mp.Process(target=train, args=(model,))
p.start()
processes.append(p)
for p in processes:
p.join()
Copyright©2021 w3cschool編程獅|閩ICP備15016281號(hào)-3|閩公網(wǎng)安備35020302033924號(hào)
違法和不良信息舉報(bào)電話:173-0602-2364|舉報(bào)郵箱:jubao@eeedong.com
掃描二維碼
下載編程獅App
編程獅公眾號(hào)
聯(lián)系方式:
更多建議: