PyTorch 經(jīng)常問(wèn)的問(wèn)題

2020-09-10 17:00 更新
原文: https://pytorch.org/docs/stable/notes/faq.html

我的模型報(bào)告“ CUDA 運(yùn)行時(shí)錯(cuò)誤(2):內(nèi)存不足”

如錯(cuò)誤消息所暗示,您的 GPU 內(nèi)存已用完。 由于我們經(jīng)常在 PyTorch 中處理大量數(shù)據(jù),因此小錯(cuò)誤可能會(huì)迅速導(dǎo)致您的程序用盡所有 GPU; 幸運(yùn)的是,這些情況下的修復(fù)程序通常很簡(jiǎn)單。 以下是一些常見(jiàn)的檢查事項(xiàng):

不要在整個(gè)訓(xùn)練循環(huán)中累積歷史記錄。 默認(rèn)情況下,涉及需要漸變的變量的計(jì)算將保留歷史記錄。 這意味著您應(yīng)避免在計(jì)算中使用此類變量,這些變量將不受訓(xùn)練循環(huán)的影響,例如在跟蹤統(tǒng)計(jì)信息時(shí)。 相反,您應(yīng)該分離變量或訪問(wèn)其基礎(chǔ)數(shù)據(jù)。

有時(shí),可微變量發(fā)生時(shí)可能不是很明顯。 考慮以下訓(xùn)練循環(huán)(從刪節(jié)):

total_loss = 0
for i in range(10000):
    optimizer.zero_grad()
    output = model(input)
    loss = criterion(output)
    loss.backward()
    optimizer.step()
    total_loss += loss

在這里,total_loss會(huì)在您的訓(xùn)練循環(huán)中累積歷史記錄,因?yàn)?code>loss是具有自動(dòng)分級(jí)歷史記錄的可微變量。 您可以改寫 <cite>total_loss + = float(loss)</cite>來(lái)解決此問(wèn)題。

此問(wèn)題的其他實(shí)例: 1 。

不要使用不需要的張量和變量。 如果將 Tensor 或 Variable 分配給本地,Python 將不會(huì)取消分配,直到本地超出范圍。 您可以使用del x釋放此參考。 同樣,如果將 Tensor 或 Variable 分配給對(duì)象的成員變量,則在對(duì)象超出范圍之前它不會(huì)釋放。 如果不使用不需要的臨時(shí)存儲(chǔ),則將獲得最佳的內(nèi)存使用率。

當(dāng)?shù)厝说姆秶赡軙?huì)超出您的預(yù)期。 例如:

for i in range(5):
    intermediate = f(input[i])
    result += g(intermediate)
output = h(result)
return output

這里,即使h正在執(zhí)行,intermediate仍保持活動(dòng)狀態(tài),因?yàn)樗淖饔糜虺隽搜h(huán)的結(jié)尾。 要提早釋放它,使用完后應(yīng)del intermediate。

不要對(duì)太大的序列運(yùn)行 RNN。 通過(guò) RNN 反向傳播所需的內(nèi)存量與 RNN 輸入的長(zhǎng)度成線性比例; 因此,如果您嘗試向 RNN 輸入過(guò)長(zhǎng)的序列,則會(huì)耗盡內(nèi)存。

這種現(xiàn)象的技術(shù)術(shù)語(yǔ)是到時(shí)間的反向傳播,關(guān)于如何實(shí)現(xiàn)截?cái)?BPTT 的參考很??多,包括字語(yǔ)言模型示例; 截?cái)嘤?a rel="external nofollow" target="_blank" target="_blank">本論壇帖子中所述的repackage功能處理。

請(qǐng)勿使用太大的線性圖層。 線性層nn.Linear(m, n)使用 內(nèi)存:也就是說(shuō),權(quán)重的內(nèi)存要求與要素?cái)?shù)量成正比關(guān)系。 以這種方式穿透內(nèi)存非常容易(請(qǐng)記住,您至少需要權(quán)重大小的兩倍,因?yàn)槟€需要存儲(chǔ)漸變。)

我的 GPU 內(nèi)存未正確釋放

PyTorch 使用緩存內(nèi)存分配器來(lái)加速內(nèi)存分配。 因此,nvidia-smi中顯示的值通常不能反映真實(shí)的內(nèi)存使用情況。 有關(guān) GPU 內(nèi)存管理的更多詳細(xì)信息,請(qǐng)參見(jiàn)內(nèi)存管理。

如果即使在退出 Python 后仍沒(méi)有釋放 GPU 內(nèi)存,則很可能某些 Python 子進(jìn)程仍然存在。 您可以通過(guò)ps -elf | grep python找到它們,然后使用kill -9 [pid]手動(dòng)將其殺死。

我的數(shù)據(jù)加載器工作人員返回相同的隨機(jī)數(shù)

您可能會(huì)使用其他庫(kù)在數(shù)據(jù)集中生成隨機(jī)數(shù)。 例如,當(dāng)通過(guò)fork啟動(dòng)工作程序子流程時(shí),NumPy 的 RNG 被復(fù)制。 請(qǐng)參閱 torch.utils.data.DataLoade 的文檔,以了解如何通過(guò)worker_init_fn選項(xiàng)在工人中正確設(shè)置隨機(jī)種子。

我的經(jīng)常性網(wǎng)絡(luò)無(wú)法使用數(shù)據(jù)并行性

在 Module 與 DataParallel 或 data_parallel() 中使用pack sequence -> recurrent network -> unpack sequence模式是很微妙的。 每個(gè)設(shè)備上每個(gè)forward()的輸入僅是整個(gè)輸入的一部分。 由于默認(rèn)情況下,拆包操作 torch.nn.utils.rnn.pad_packed_sequence() 僅填充其看到的最長(zhǎng)輸入,即該特定設(shè)備上的最長(zhǎng)輸入,因此,將結(jié)果匯總在一起時(shí)會(huì)發(fā)生大小不匹配的情況。 因此,您可以改而利用  pad_packed_sequence() 的total_length自變量來(lái)確保forward()調(diào)用相同長(zhǎng)度的返回序列。 例如,您可以編寫:

from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class MyModule(nn.Module):
    # ... __init__, other methods, etc.


    # padded_input is of shape [B x T x *] (batch_first mode) and contains
    # the sequences sorted by lengths
    #   B is the batch size
    #   T is max sequence length
    def forward(self, padded_input, input_lengths):
        total_length = padded_input.size(1)  # get the max sequence length
        packed_input = pack_padded_sequence(padded_input, input_lengths,
                                            batch_first=True)
        packed_output, _ = self.my_lstm(packed_input)
        output, _ = pad_packed_sequence(packed_output, batch_first=True,
                                        total_length=total_length)
        return output


m = MyModule().cuda()
dp_m = nn.DataParallel(m)

此外,當(dāng)批處理尺寸為1(即batch_first=False)且數(shù)據(jù)平行時(shí),需要格外小心。 在這種情況下,pack_padded_sequence padding_input的第一個(gè)參數(shù)的形狀將為[T x B x *],并且應(yīng)沿昏暗1分散,而第二個(gè)參數(shù)input_lengths的形狀將為[B],并且應(yīng)沿昏暗[[Gate] 0。 將需要額外的代碼來(lái)操縱張量形狀。

以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)