原文: https://pytorch.org/tutorials/intermediate/pruning_tutorial.html
作者: Michela Paganini
最新的深度學(xué)習(xí)技術(shù)依賴于難以部署的過度參數(shù)化模型。 相反,已知生物神經(jīng)網(wǎng)絡(luò)使用有效的稀疏連通性。 為了減少內(nèi)存,電池和硬件消耗,同時(shí)又不犧牲精度,在設(shè)備上部署輕量級模型并通過私有設(shè)備內(nèi)計(jì)算來確保私密性,確定通過減少模型中參數(shù)數(shù)量來壓縮模型的最佳技術(shù)很重要。 在研究方面,修剪用于研究參數(shù)過度配置和參數(shù)不足網(wǎng)絡(luò)之間學(xué)習(xí)動(dòng)態(tài)的差異,以研究幸運(yùn)稀疏子網(wǎng)絡(luò)和初始化(“ 彩票”)作為破壞性對象的作用。 神經(jīng)結(jié)構(gòu)搜索技術(shù)等等。
在本教程中,您將學(xué)習(xí)如何使用torch.nn.utils.prune
稀疏神經(jīng)網(wǎng)絡(luò),以及如何擴(kuò)展它以實(shí)現(xiàn)自己的自定義修剪技術(shù)。
"torch>=1.4.0a0+8e8a5e0"
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
在本教程中,我們使用 LeCun 等人,1998 年的 LeNet 體系結(jié)構(gòu)。
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
# 1 input image channel, 6 output channels, 3x3 square conv kernel
self.conv1 = nn.Conv2d(1, 6, 3)
self.conv2 = nn.Conv2d(6, 16, 3)
self.fc1 = nn.Linear(16 * 5 * 5, 120) # 5x5 image dimension
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, int(x.nelement() / x.shape[0]))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
model = LeNet().to(device=device)
讓我們檢查一下 LeNet 模型中的(未經(jīng)修剪的)conv1
層。 目前它將包含兩個(gè)參數(shù)weight
和bias
,并且沒有緩沖區(qū)。
module = model.conv1
print(list(module.named_parameters()))
得出:
[('weight', Parameter containing:
tensor([[[[ 0.3161, -0.2212, 0.0417],
[ 0.2488, 0.2415, 0.2071],
[-0.2412, -0.2400, -0.2016]]],
[[[ 0.0419, 0.3322, -0.2106],
[ 0.1776, -0.1845, -0.3134],
[-0.0708, 0.1921, 0.3095]]],
[[[-0.2070, 0.0723, 0.2876],
[ 0.2209, 0.2077, 0.2369],
[ 0.2108, 0.0861, -0.2279]]],
[[[-0.2799, -0.1527, -0.0388],
[-0.2043, 0.1220, 0.1032],
[-0.0755, 0.1281, 0.1077]]],
[[[ 0.2035, 0.2245, -0.1129],
[ 0.3257, -0.0385, -0.0115],
[-0.3146, -0.2145, -0.1947]]],
[[[-0.1426, 0.2370, -0.1089],
[-0.2491, 0.1282, 0.1067],
[ 0.2159, -0.1725, 0.0723]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021, 0.1425], device='cuda:0',
requires_grad=True))]
print(list(module.named_buffers()))
得出:
[]
要修剪模塊(在此示例中,為 LeNet 架構(gòu)的conv1
層),請首先從torch.nn.utils.prune
中可用的那些技術(shù)中選擇一種修剪技術(shù)(或通過子類化BasePruningMethod
實(shí)現(xiàn)您自己的)。
然后,指定模塊和該模塊中要修剪的參數(shù)的名稱。 最后,使用所選修剪技術(shù)所需的適當(dāng)關(guān)鍵字參數(shù),指定修剪參數(shù)。
在此示例中,我們將在conv1
層中名為weight
的參數(shù)中隨機(jī)修剪 30%的連接。 模塊作為第一個(gè)參數(shù)傳遞給函數(shù); name
使用其字符串標(biāo)識符在該模塊內(nèi)標(biāo)識參數(shù); amount
表示與修剪的連接百分比(如果它是介于 0 和 1 之間的浮點(diǎn)數(shù)),或者表示與修剪的連接的絕對數(shù)量(如果它是非負(fù)整數(shù))。<
prune.random_unstructured(module, name="weight", amount=0.3)
修剪是通過從參數(shù)中刪除weight
并將其替換為名為weight_orig
的新參數(shù)(即,將"_orig"
附加到初始參數(shù)name
)來進(jìn)行的。 weight_orig
存儲未修剪的張量版本。 bias
未修剪,因此它將保持完整。
print(list(module.named_parameters()))
得出:
[('bias', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021, 0.1425], device='cuda:0',
requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.3161, -0.2212, 0.0417],
[ 0.2488, 0.2415, 0.2071],
[-0.2412, -0.2400, -0.2016]]],
[[[ 0.0419, 0.3322, -0.2106],
[ 0.1776, -0.1845, -0.3134],
[-0.0708, 0.1921, 0.3095]]],
[[[-0.2070, 0.0723, 0.2876],
[ 0.2209, 0.2077, 0.2369],
[ 0.2108, 0.0861, -0.2279]]],
[[[-0.2799, -0.1527, -0.0388],
[-0.2043, 0.1220, 0.1032],
[-0.0755, 0.1281, 0.1077]]],
[[[ 0.2035, 0.2245, -0.1129],
[ 0.3257, -0.0385, -0.0115],
[-0.3146, -0.2145, -0.1947]]],
[[[-0.1426, 0.2370, -0.1089],
[-0.2491, 0.1282, 0.1067],
[ 0.2159, -0.1725, 0.0723]]]], device='cuda:0', requires_grad=True))]
通過以上選擇的修剪技術(shù)生成的修剪掩碼將保存為名為weight_mask
的模塊緩沖區(qū)(即,將"_mask"
附加到初始參數(shù)name
)。
print(list(module.named_buffers()))
得出:
[('weight_mask', tensor([[[[0., 1., 0.],
[1., 0., 0.],
[1., 1., 1.]]],
[[[1., 0., 1.],
[1., 1., 0.],
[1., 0., 1.]]],
[[[1., 0., 0.],
[0., 1., 1.],
[1., 1., 1.]]],
[[[1., 0., 0.],
[1., 1., 1.],
[1., 1., 1.]]],
[[[1., 0., 1.],
[1., 1., 1.],
[0., 1., 1.]]],
[[[1., 1., 1.],
[1., 1., 0.],
[1., 1., 0.]]]], device='cuda:0'))]
為了使前向通過不更改即可工作,需要存在weight
屬性。 torch.nn.utils.prune
中實(shí)現(xiàn)的修剪技術(shù)計(jì)算權(quán)重的修剪版本(通過將掩碼與原始參數(shù)組合)并將其存儲在屬性weight
中。 注意,這不再是module
的參數(shù),現(xiàn)在只是一個(gè)屬性。
print(module.weight)
得出:
tensor([[[[ 0.0000, -0.2212, 0.0000],
[ 0.2488, 0.0000, 0.0000],
[-0.2412, -0.2400, -0.2016]]],
[[[ 0.0419, 0.0000, -0.2106],
[ 0.1776, -0.1845, -0.0000],
[-0.0708, 0.0000, 0.3095]]],
[[[-0.2070, 0.0000, 0.0000],
[ 0.0000, 0.2077, 0.2369],
[ 0.2108, 0.0861, -0.2279]]],
[[[-0.2799, -0.0000, -0.0000],
[-0.2043, 0.1220, 0.1032],
[-0.0755, 0.1281, 0.1077]]],
[[[ 0.2035, 0.0000, -0.1129],
[ 0.3257, -0.0385, -0.0115],
[-0.0000, -0.2145, -0.1947]]],
[[[-0.1426, 0.2370, -0.1089],
[-0.2491, 0.1282, 0.0000],
[ 0.2159, -0.1725, 0.0000]]]], device='cuda:0',
grad_fn=<MulBackward0>)
最后,使用 PyTorch 的forward_pre_hooks
在每次向前傳遞之前應(yīng)用修剪。 具體來說,當(dāng)修剪module
時(shí)(如我們在此處所做的那樣),它將為與之關(guān)聯(lián)的每個(gè)參數(shù)獲取forward_pre_hook
進(jìn)行修剪。 在這種情況下,由于到目前為止我們只修剪了名稱為weight
的原始參數(shù),因此只會出現(xiàn)一個(gè)鉤子。
print(module._forward_pre_hooks)
得出:
OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f1e6c425400>)])
為了完整起見,我們現(xiàn)在也可以修剪bias
,以查看module
的參數(shù),緩沖區(qū),掛鉤和屬性如何變化。 僅出于嘗試另一種修剪技術(shù)的目的,在此我們按 L1 范數(shù)修剪偏差中的 3 個(gè)最小條目,如l1_unstructured
修剪功能中所實(shí)現(xiàn)的。
prune.l1_unstructured(module, name="bias", amount=3)
現(xiàn)在,我們希望命名的參數(shù)同時(shí)包含weight_orig
(從前)和bias_orig
。 緩沖區(qū)將包括weight_mask
和bias_mask
。 兩個(gè)張量的修剪版本將作為模塊屬性存在,并且該模塊現(xiàn)在將具有兩個(gè)forward_pre_hooks
。
print(list(module.named_parameters()))
得出:
[('weight_orig', Parameter containing:
tensor([[[[ 0.3161, -0.2212, 0.0417],
[ 0.2488, 0.2415, 0.2071],
[-0.2412, -0.2400, -0.2016]]],
[[[ 0.0419, 0.3322, -0.2106],
[ 0.1776, -0.1845, -0.3134],
[-0.0708, 0.1921, 0.3095]]],
[[[-0.2070, 0.0723, 0.2876],
[ 0.2209, 0.2077, 0.2369],
[ 0.2108, 0.0861, -0.2279]]],
[[[-0.2799, -0.1527, -0.0388],
[-0.2043, 0.1220, 0.1032],
[-0.0755, 0.1281, 0.1077]]],
[[[ 0.2035, 0.2245, -0.1129],
[ 0.3257, -0.0385, -0.0115],
[-0.3146, -0.2145, -0.1947]]],
[[[-0.1426, 0.2370, -0.1089],
[-0.2491, 0.1282, 0.1067],
[ 0.2159, -0.1725, 0.0723]]]], device='cuda:0', requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021, 0.1425], device='cuda:0',
requires_grad=True))]
print(list(module.named_buffers()))
得出:
[('weight_mask', tensor([[[[0., 1., 0.],
[1., 0., 0.],
[1., 1., 1.]]],
[[[1., 0., 1.],
[1., 1., 0.],
[1., 0., 1.]]],
[[[1., 0., 0.],
[0., 1., 1.],
[1., 1., 1.]]],
[[[1., 0., 0.],
[1., 1., 1.],
[1., 1., 1.]]],
[[[1., 0., 1.],
[1., 1., 1.],
[0., 1., 1.]]],
[[[1., 1., 1.],
[1., 1., 0.],
[1., 1., 0.]]]], device='cuda:0')), ('bias_mask', tensor([0., 0., 1., 1., 0., 1.], device='cuda:0'))]
print(module.bias)
得出:
tensor([-0.0000, -0.0000, -0.2656, -0.1519, -0.0000, 0.1425], device='cuda:0',
grad_fn=<MulBackward0>)
print(module._forward_pre_hooks)
得出:
OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f1e6c425400>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x7f1e6c425550>)])
一個(gè)模塊中的同一參數(shù)可以被多次修剪,各種修剪調(diào)用的效果等于串聯(lián)應(yīng)用的各種蒙版的組合。 PruningContainer
的compute_mask
方法可處理新遮罩與舊遮罩的組合。
例如,假設(shè)我們現(xiàn)在要進(jìn)一步修剪module.weight
,這一次是使用沿著張量的第 0 軸的結(jié)構(gòu)化修剪(第 0 軸對應(yīng)于卷積層的輸出通道,并且conv1
的維數(shù)為 6) ,基于渠道的 L2 規(guī)范。 這可以通過ln_structured
和n=2
和dim=0
功能來實(shí)現(xiàn)。
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)
## As we can verify, this will zero out all the connections corresponding to
## 50% (3 out of 6) of the channels, while preserving the action of the
## previous mask.
print(module.weight)
得出:
tensor([[[[ 0.0000, -0.2212, 0.0000],
[ 0.2488, 0.0000, 0.0000],
[-0.2412, -0.2400, -0.2016]]],
[[[ 0.0000, 0.0000, -0.0000],
[ 0.0000, -0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000]]],
[[[-0.2070, 0.0000, 0.0000],
[ 0.0000, 0.2077, 0.2369],
[ 0.2108, 0.0861, -0.2279]]],
[[[-0.0000, -0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000],
[-0.0000, 0.0000, 0.0000]]],
[[[ 0.2035, 0.0000, -0.1129],
[ 0.3257, -0.0385, -0.0115],
[-0.0000, -0.2145, -0.1947]]],
[[[-0.0000, 0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000],
[ 0.0000, -0.0000, 0.0000]]]], device='cuda:0',
grad_fn=<MulBackward0>)
現(xiàn)在,對應(yīng)的鉤子將為torch.nn.utils.prune.PruningContainer
類型,并將存儲應(yīng)用于weight
參數(shù)的修剪歷史。
for hook in module._forward_pre_hooks.values():
if hook._tensor_name == "weight": # select out the correct hook
break
print(list(hook)) # pruning history in the container
得出:
[<torch.nn.utils.prune.RandomUnstructured object at 0x7f1e6c425400>, <torch.nn.utils.prune.LnStructured object at 0x7f1e6c4259b0>]
所有相關(guān)的張量,包括掩碼緩沖區(qū)和用于計(jì)算修剪的張量的原始參數(shù),都存儲在模型的state_dict
中,因此可以根據(jù)需要輕松地序列化和保存。
print(model.state_dict().keys())
得出:
odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
要使修剪永久化,請刪除weight_orig
和weight_mask
的重新參數(shù)化,然后刪除forward_pre_hook
,我們可以使用torch.nn.utils.prune的remove
功能。 請注意,這不會撤消修剪,好像從未發(fā)生過。 它只是通過將參數(shù)weight
重新分配為模型參數(shù)(修剪后的版本)來使其永久不變。
刪除重新參數(shù)化之前:
print(list(module.named_parameters()))
得出:
[('weight_orig', Parameter containing:
tensor([[[[ 0.3161, -0.2212, 0.0417],
[ 0.2488, 0.2415, 0.2071],
[-0.2412, -0.2400, -0.2016]]],
[[[ 0.0419, 0.3322, -0.2106],
[ 0.1776, -0.1845, -0.3134],
[-0.0708, 0.1921, 0.3095]]],
[[[-0.2070, 0.0723, 0.2876],
[ 0.2209, 0.2077, 0.2369],
[ 0.2108, 0.0861, -0.2279]]],
[[[-0.2799, -0.1527, -0.0388],
[-0.2043, 0.1220, 0.1032],
[-0.0755, 0.1281, 0.1077]]],
[[[ 0.2035, 0.2245, -0.1129],
[ 0.3257, -0.0385, -0.0115],
[-0.3146, -0.2145, -0.1947]]],
[[[-0.1426, 0.2370, -0.1089],
[-0.2491, 0.1282, 0.1067],
[ 0.2159, -0.1725, 0.0723]]]], device='cuda:0', requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021, 0.1425], device='cuda:0',
requires_grad=True))]
print(list(module.named_buffers()))
得出:
[('weight_mask', tensor([[[[0., 1., 0.],
[1., 0., 0.],
[1., 1., 1.]]],
[[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]],
[[[1., 0., 0.],
[0., 1., 1.],
[1., 1., 1.]]],
[[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]],
[[[1., 0., 1.],
[1., 1., 1.],
[0., 1., 1.]]],
[[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]]], device='cuda:0')), ('bias_mask', tensor([0., 0., 1., 1., 0., 1.], device='cuda:0'))]
print(module.weight)
得出:
tensor([[[[ 0.0000, -0.2212, 0.0000],
[ 0.2488, 0.0000, 0.0000],
[-0.2412, -0.2400, -0.2016]]],
[[[ 0.0000, 0.0000, -0.0000],
[ 0.0000, -0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000]]],
[[[-0.2070, 0.0000, 0.0000],
[ 0.0000, 0.2077, 0.2369],
[ 0.2108, 0.0861, -0.2279]]],
[[[-0.0000, -0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000],
[-0.0000, 0.0000, 0.0000]]],
[[[ 0.2035, 0.0000, -0.1129],
[ 0.3257, -0.0385, -0.0115],
[-0.0000, -0.2145, -0.1947]]],
[[[-0.0000, 0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000],
[ 0.0000, -0.0000, 0.0000]]]], device='cuda:0',
grad_fn=<MulBackward0>)
刪除重新參數(shù)化后:
prune.remove(module, 'weight')
print(list(module.named_parameters()))
得出:
[('bias_orig', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021, 0.1425], device='cuda:0',
requires_grad=True)), ('weight', Parameter containing:
tensor([[[[ 0.0000, -0.2212, 0.0000],
[ 0.2488, 0.0000, 0.0000],
[-0.2412, -0.2400, -0.2016]]],
[[[ 0.0000, 0.0000, -0.0000],
[ 0.0000, -0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000]]],
[[[-0.2070, 0.0000, 0.0000],
[ 0.0000, 0.2077, 0.2369],
[ 0.2108, 0.0861, -0.2279]]],
[[[-0.0000, -0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000],
[-0.0000, 0.0000, 0.0000]]],
[[[ 0.2035, 0.0000, -0.1129],
[ 0.3257, -0.0385, -0.0115],
[-0.0000, -0.2145, -0.1947]]],
[[[-0.0000, 0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000],
[ 0.0000, -0.0000, 0.0000]]]], device='cuda:0', requires_grad=True))]
print(list(module.named_buffers()))
得出:
[('bias_mask', tensor([0., 0., 1., 1., 0., 1.], device='cuda:0'))]
通過指定所需的修剪技術(shù)和參數(shù),我們可以輕松地修剪網(wǎng)絡(luò)中的多個(gè)張量,也許根據(jù)它們的類型,如在本示例中將看到的那樣。
new_model = LeNet()
for name, module in new_model.named_modules():
# prune 20% of connections in all 2D-conv layers
if isinstance(module, torch.nn.Conv2d):
prune.l1_unstructured(module, name='weight', amount=0.2)
# prune 40% of connections in all linear layers
elif isinstance(module, torch.nn.Linear):
prune.l1_unstructured(module, name='weight', amount=0.4)
print(dict(new_model.named_buffers()).keys()) # to verify that all masks exist
得出:
dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])
到目前為止,我們僅研究了通常被稱為“局部”修剪的方法,即通過比較每個(gè)條目的統(tǒng)計(jì)信息(權(quán)重,激活度,梯度等)來逐一修剪模型中的張量的做法。 到該張量中的其他條目。 但是,一種常見且可能更強(qiáng)大的技術(shù)是通過刪除(例如)刪除整個(gè)模型中最低的 20%的連接,而不是刪除每一層中最低的 20%的連接來一次修剪模型。 這很可能導(dǎo)致每個(gè)層的修剪百分比不同。 讓我們看看如何使用torch.nn.utils.prune
中的global_unstructured
進(jìn)行操作。
model = LeNet()
parameters_to_prune = (
(model.conv1, 'weight'),
(model.conv2, 'weight'),
(model.fc1, 'weight'),
(model.fc2, 'weight'),
(model.fc3, 'weight'),
)
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.2,
)
現(xiàn)在,我們可以檢查在每個(gè)修剪參數(shù)中引起的稀疏性,該稀疏性將不等于每層中的 20%。 但是,全球稀疏度將(大約)為 20%。
print(
"Sparsity in conv1.weight: {:.2f}%".format(
100\. * float(torch.sum(model.conv1.weight == 0))
/ float(model.conv1.weight.nelement())
)
)
print(
"Sparsity in conv2.weight: {:.2f}%".format(
100\. * float(torch.sum(model.conv2.weight == 0))
/ float(model.conv2.weight.nelement())
)
)
print(
"Sparsity in fc1.weight: {:.2f}%".format(
100\. * float(torch.sum(model.fc1.weight == 0))
/ float(model.fc1.weight.nelement())
)
)
print(
"Sparsity in fc2.weight: {:.2f}%".format(
100\. * float(torch.sum(model.fc2.weight == 0))
/ float(model.fc2.weight.nelement())
)
)
print(
"Sparsity in fc3.weight: {:.2f}%".format(
100\. * float(torch.sum(model.fc3.weight == 0))
/ float(model.fc3.weight.nelement())
)
)
print(
"Global sparsity: {:.2f}%".format(
100\. * float(
torch.sum(model.conv1.weight == 0)
+ torch.sum(model.conv2.weight == 0)
+ torch.sum(model.fc1.weight == 0)
+ torch.sum(model.fc2.weight == 0)
+ torch.sum(model.fc3.weight == 0)
)
/ float(
model.conv1.weight.nelement()
+ model.conv2.weight.nelement()
+ model.fc1.weight.nelement()
+ model.fc2.weight.nelement()
+ model.fc3.weight.nelement()
)
)
)
得出:
Sparsity in conv1.weight: 7.41%
Sparsity in conv2.weight: 9.49%
Sparsity in fc1.weight: 22.00%
Sparsity in fc2.weight: 12.28%
Sparsity in fc3.weight: 9.76%
Global sparsity: 20.00%
torch.nn.utils.prune
要實(shí)現(xiàn)自己的修剪功能,您可以通過繼承BasePruningMethod
基類來擴(kuò)展nn.utils.prune
模塊,這與所有其他修剪方法一樣。 基類為您實(shí)現(xiàn)以下方法:__call__
,apply_mask
,apply
,prune
和remove
。 除了某些特殊情況外,您不必為新的修剪技術(shù)重新實(shí)現(xiàn)這些方法。 但是,您將必須實(shí)現(xiàn)__init__
(構(gòu)造函數(shù))和compute_mask
(有關(guān)如何根據(jù)修剪技術(shù)的邏輯為給定張量計(jì)算掩碼的說明)。
另外,您將必須指定此技術(shù)實(shí)現(xiàn)的修剪類型(支持的選項(xiàng)為global
,structured
和unstructured
)。 需要確定在迭代應(yīng)用修剪的情況下如何組合蒙版。 換句話說,當(dāng)修剪預(yù)修剪的參數(shù)時(shí),當(dāng)前的修剪技術(shù)應(yīng)作用于參數(shù)的未修剪部分。 指定PRUNING_TYPE
將使PruningContainer
(處理修剪蒙版的迭代應(yīng)用)正確識別要修剪的參數(shù)。
例如,假設(shè)您要實(shí)施一種修剪技術(shù),以修剪張量中的所有其他條目(或者-如果先前已修剪過張量,則在張量的其余未修剪部分中)。 這將是PRUNING_TYPE='unstructured'
,因?yàn)樗饔糜趯又械膯蝹€(gè)連接,而不作用于整個(gè)單元/通道('structured'
),或作用于不同的參數(shù)('global'
)。
class FooBarPruningMethod(prune.BasePruningMethod):
"""Prune every other entry in a tensor
"""
PRUNING_TYPE = 'unstructured'
def compute_mask(self, t, default_mask):
mask = default_mask.clone()
mask.view(-1)[::2] = 0
return mask
現(xiàn)在,要將其應(yīng)用于nn.Module
中的參數(shù),還應(yīng)該提供一個(gè)簡單的函數(shù)來實(shí)例化該方法并將其應(yīng)用。
def foobar_unstructured(module, name):
"""Prunes tensor corresponding to parameter called `name` in `module`
by removing every other entry in the tensors.
Modifies module in place (and also return the modified module)
by:
1) adding a named buffer called `name+'_mask'` corresponding to the
binary mask applied to the parameter `name` by the pruning method.
The parameter `name` is replaced by its pruned version, while the
original (unpruned) parameter is stored in a new parameter named
`name+'_orig'`.
Args:
module (nn.Module): module containing the tensor to prune
name (string): parameter name within `module` on which pruning
will act.
Returns:
module (nn.Module): modified (i.e. pruned) version of the input
module
Examples:
>>> m = nn.Linear(3, 4)
>>> foobar_unstructured(m, name='bias')
"""
FooBarPruningMethod.apply(module, name)
return module
試試吧!
model = LeNet()
foobar_unstructured(model.fc3, name='bias')
print(model.fc3.bias_mask)
得出:
tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])
腳本的總運(yùn)行時(shí)間:(0 分鐘 0.146 秒)
Download Python source code: pruning_tutorial.py
Download Jupyter notebook: pruning_tutorial.ipynb
更多建議: