PyTorch 修剪教程

2020-09-10 14:00 更新
原文: https://pytorch.org/tutorials/intermediate/pruning_tutorial.html

作者: Michela Paganini

最新的深度學(xué)習(xí)技術(shù)依賴于難以部署的過度參數(shù)化模型。 相反,已知生物神經(jīng)網(wǎng)絡(luò)使用有效的稀疏連通性。 為了減少內(nèi)存,電池和硬件消耗,同時(shí)又不犧牲精度,在設(shè)備上部署輕量級模型并通過私有設(shè)備內(nèi)計(jì)算來確保私密性,確定通過減少模型中參數(shù)數(shù)量來壓縮模型的最佳技術(shù)很重要。 在研究方面,修剪用于研究參數(shù)過度配置和參數(shù)不足網(wǎng)絡(luò)之間學(xué)習(xí)動(dòng)態(tài)的差異,以研究幸運(yùn)稀疏子網(wǎng)絡(luò)和初始化(“ 彩票”)作為破壞性對象的作用。 神經(jīng)結(jié)構(gòu)搜索技術(shù)等等。

在本教程中,您將學(xué)習(xí)如何使用torch.nn.utils.prune稀疏神經(jīng)網(wǎng)絡(luò),以及如何擴(kuò)展它以實(shí)現(xiàn)自己的自定義修剪技術(shù)。

要求

"torch>=1.4.0a0+8e8a5e0"

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

建立模型

在本教程中,我們使用 LeCun 等人,1998 年的 LeNet 體系結(jié)構(gòu)。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)


    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = LeNet().to(device=device)

檢查模塊

讓我們檢查一下 LeNet 模型中的(未經(jīng)修剪的)conv1層。 目前它將包含兩個(gè)參數(shù)weightbias,并且沒有緩沖區(qū)。

module = model.conv1
print(list(module.named_parameters()))

得出:

[('weight', Parameter containing:
tensor([[[[ 0.3161, -0.2212,  0.0417],
          [ 0.2488,  0.2415,  0.2071],
          [-0.2412, -0.2400, -0.2016]]],


        [[[ 0.0419,  0.3322, -0.2106],
          [ 0.1776, -0.1845, -0.3134],
          [-0.0708,  0.1921,  0.3095]]],


        [[[-0.2070,  0.0723,  0.2876],
          [ 0.2209,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],


        [[[-0.2799, -0.1527, -0.0388],
          [-0.2043,  0.1220,  0.1032],
          [-0.0755,  0.1281,  0.1077]]],


        [[[ 0.2035,  0.2245, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.3146, -0.2145, -0.1947]]],


        [[[-0.1426,  0.2370, -0.1089],
          [-0.2491,  0.1282,  0.1067],
          [ 0.2159, -0.1725,  0.0723]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021,  0.1425], device='cuda:0',
       requires_grad=True))]
print(list(module.named_buffers()))

得出:

[]

修剪模塊

要修剪模塊(在此示例中,為 LeNet 架構(gòu)的conv1層),請首先從torch.nn.utils.prune中可用的那些技術(shù)中選擇一種修剪技術(shù)(或通過子類化BasePruningMethod實(shí)現(xiàn)您自己的)。 然后,指定模塊和該模塊中要修剪的參數(shù)的名稱。 最后,使用所選修剪技術(shù)所需的適當(dāng)關(guān)鍵字參數(shù),指定修剪參數(shù)。

在此示例中,我們將在conv1層中名為weight的參數(shù)中隨機(jī)修剪 30%的連接。 模塊作為第一個(gè)參數(shù)傳遞給函數(shù); name使用其字符串標(biāo)識符在該模塊內(nèi)標(biāo)識參數(shù); amount表示與修剪的連接百分比(如果它是介于 0 和 1 之間的浮點(diǎn)數(shù)),或者表示與修剪的連接的絕對數(shù)量(如果它是非負(fù)整數(shù))。<

prune.random_unstructured(module, name="weight", amount=0.3)

修剪是通過從參數(shù)中刪除weight并將其替換為名為weight_orig的新參數(shù)(即,將"_orig"附加到初始參數(shù)name)來進(jìn)行的。 weight_orig存儲未修剪的張量版本。 bias未修剪,因此它將保持完整。

print(list(module.named_parameters()))

得出:

[('bias', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021,  0.1425], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.3161, -0.2212,  0.0417],
          [ 0.2488,  0.2415,  0.2071],
          [-0.2412, -0.2400, -0.2016]]],


        [[[ 0.0419,  0.3322, -0.2106],
          [ 0.1776, -0.1845, -0.3134],
          [-0.0708,  0.1921,  0.3095]]],


        [[[-0.2070,  0.0723,  0.2876],
          [ 0.2209,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],


        [[[-0.2799, -0.1527, -0.0388],
          [-0.2043,  0.1220,  0.1032],
          [-0.0755,  0.1281,  0.1077]]],


        [[[ 0.2035,  0.2245, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.3146, -0.2145, -0.1947]]],


        [[[-0.1426,  0.2370, -0.1089],
          [-0.2491,  0.1282,  0.1067],
          [ 0.2159, -0.1725,  0.0723]]]], device='cuda:0', requires_grad=True))]

通過以上選擇的修剪技術(shù)生成的修剪掩碼將保存為名為weight_mask的模塊緩沖區(qū)(即,將"_mask"附加到初始參數(shù)name)。

print(list(module.named_buffers()))

得出:

[('weight_mask', tensor([[[[0., 1., 0.],
          [1., 0., 0.],
          [1., 1., 1.]]],


        [[[1., 0., 1.],
          [1., 1., 0.],
          [1., 0., 1.]]],


        [[[1., 0., 0.],
          [0., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 0., 0.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 0., 1.],
          [1., 1., 1.],
          [0., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 0.],
          [1., 1., 0.]]]], device='cuda:0'))]

為了使前向通過不更改即可工作,需要存在weight屬性。 torch.nn.utils.prune中實(shí)現(xiàn)的修剪技術(shù)計(jì)算權(quán)重的修剪版本(通過將掩碼與原始參數(shù)組合)并將其存儲在屬性weight中。 注意,這不再是module的參數(shù),現(xiàn)在只是一個(gè)屬性。

print(module.weight)

得出:

tensor([[[[ 0.0000, -0.2212,  0.0000],
          [ 0.2488,  0.0000,  0.0000],
          [-0.2412, -0.2400, -0.2016]]],


        [[[ 0.0419,  0.0000, -0.2106],
          [ 0.1776, -0.1845, -0.0000],
          [-0.0708,  0.0000,  0.3095]]],


        [[[-0.2070,  0.0000,  0.0000],
          [ 0.0000,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],


        [[[-0.2799, -0.0000, -0.0000],
          [-0.2043,  0.1220,  0.1032],
          [-0.0755,  0.1281,  0.1077]]],


        [[[ 0.2035,  0.0000, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.0000, -0.2145, -0.1947]]],


        [[[-0.1426,  0.2370, -0.1089],
          [-0.2491,  0.1282,  0.0000],
          [ 0.2159, -0.1725,  0.0000]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

最后,使用 PyTorch 的forward_pre_hooks在每次向前傳遞之前應(yīng)用修剪。 具體來說,當(dāng)修剪module時(shí)(如我們在此處所做的那樣),它將為與之關(guān)聯(lián)的每個(gè)參數(shù)獲取forward_pre_hook進(jìn)行修剪。 在這種情況下,由于到目前為止我們只修剪了名稱為weight的原始參數(shù),因此只會出現(xiàn)一個(gè)鉤子。

print(module._forward_pre_hooks)

得出:

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f1e6c425400>)])

為了完整起見,我們現(xiàn)在也可以修剪bias,以查看module的參數(shù),緩沖區(qū),掛鉤和屬性如何變化。 僅出于嘗試另一種修剪技術(shù)的目的,在此我們按 L1 范數(shù)修剪偏差中的 3 個(gè)最小條目,如l1_unstructured修剪功能中所實(shí)現(xiàn)的。

prune.l1_unstructured(module, name="bias", amount=3)

現(xiàn)在,我們希望命名的參數(shù)同時(shí)包含weight_orig(從前)和bias_orig。 緩沖區(qū)將包括weight_maskbias_mask。 兩個(gè)張量的修剪版本將作為模塊屬性存在,并且該模塊現(xiàn)在將具有兩個(gè)forward_pre_hooks。

print(list(module.named_parameters()))

得出:

[('weight_orig', Parameter containing:
tensor([[[[ 0.3161, -0.2212,  0.0417],
          [ 0.2488,  0.2415,  0.2071],
          [-0.2412, -0.2400, -0.2016]]],


        [[[ 0.0419,  0.3322, -0.2106],
          [ 0.1776, -0.1845, -0.3134],
          [-0.0708,  0.1921,  0.3095]]],


        [[[-0.2070,  0.0723,  0.2876],
          [ 0.2209,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],


        [[[-0.2799, -0.1527, -0.0388],
          [-0.2043,  0.1220,  0.1032],
          [-0.0755,  0.1281,  0.1077]]],


        [[[ 0.2035,  0.2245, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.3146, -0.2145, -0.1947]]],


        [[[-0.1426,  0.2370, -0.1089],
          [-0.2491,  0.1282,  0.1067],
          [ 0.2159, -0.1725,  0.0723]]]], device='cuda:0', requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021,  0.1425], device='cuda:0',
       requires_grad=True))]
print(list(module.named_buffers()))

得出:

[('weight_mask', tensor([[[[0., 1., 0.],
          [1., 0., 0.],
          [1., 1., 1.]]],


        [[[1., 0., 1.],
          [1., 1., 0.],
          [1., 0., 1.]]],


        [[[1., 0., 0.],
          [0., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 0., 0.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 0., 1.],
          [1., 1., 1.],
          [0., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 0.],
          [1., 1., 0.]]]], device='cuda:0')), ('bias_mask', tensor([0., 0., 1., 1., 0., 1.], device='cuda:0'))]
print(module.bias)

得出:

tensor([-0.0000, -0.0000, -0.2656, -0.1519, -0.0000,  0.1425], device='cuda:0',
       grad_fn=<MulBackward0>)
print(module._forward_pre_hooks)

得出:

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f1e6c425400>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x7f1e6c425550>)])

迭代修剪

一個(gè)模塊中的同一參數(shù)可以被多次修剪,各種修剪調(diào)用的效果等于串聯(lián)應(yīng)用的各種蒙版的組合。 PruningContainercompute_mask方法可處理新遮罩與舊遮罩的組合。

例如,假設(shè)我們現(xiàn)在要進(jìn)一步修剪module.weight,這一次是使用沿著張量的第 0 軸的結(jié)構(gòu)化修剪(第 0 軸對應(yīng)于卷積層的輸出通道,并且conv1的維數(shù)為 6) ,基于渠道的 L2 規(guī)范。 這可以通過ln_structuredn=2dim=0功能來實(shí)現(xiàn)。

prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)


## As we can verify, this will zero out all the connections corresponding to
## 50% (3 out of 6) of the channels, while preserving the action of the
## previous mask.
print(module.weight)

得出:

tensor([[[[ 0.0000, -0.2212,  0.0000],
          [ 0.2488,  0.0000,  0.0000],
          [-0.2412, -0.2400, -0.2016]]],


        [[[ 0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000]]],


        [[[-0.2070,  0.0000,  0.0000],
          [ 0.0000,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],


        [[[-0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000]]],


        [[[ 0.2035,  0.0000, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.0000, -0.2145, -0.1947]]],


        [[[-0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

現(xiàn)在,對應(yīng)的鉤子將為torch.nn.utils.prune.PruningContainer類型,并將存儲應(yīng)用于weight參數(shù)的修剪歷史。

for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == "weight":  # select out the correct hook
        break


print(list(hook))  # pruning history in the container

得出:

[<torch.nn.utils.prune.RandomUnstructured object at 0x7f1e6c425400>, <torch.nn.utils.prune.LnStructured object at 0x7f1e6c4259b0>]

序列化修剪的模型

所有相關(guān)的張量,包括掩碼緩沖區(qū)和用于計(jì)算修剪的張量的原始參數(shù),都存儲在模型的state_dict中,因此可以根據(jù)需要輕松地序列化和保存。

print(model.state_dict().keys())

得出:

odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])

刪除修剪重新參數(shù)化

要使修剪永久化,請刪除weight_origweight_mask的重新參數(shù)化,然后刪除forward_pre_hook,我們可以使用torch.nn.utils.prune的remove功能。 請注意,這不會撤消修剪,好像從未發(fā)生過。 它只是通過將參數(shù)weight重新分配為模型參數(shù)(修剪后的版本)來使其永久不變。

刪除重新參數(shù)化之前:

print(list(module.named_parameters()))

得出:

[('weight_orig', Parameter containing:
tensor([[[[ 0.3161, -0.2212,  0.0417],
          [ 0.2488,  0.2415,  0.2071],
          [-0.2412, -0.2400, -0.2016]]],


        [[[ 0.0419,  0.3322, -0.2106],
          [ 0.1776, -0.1845, -0.3134],
          [-0.0708,  0.1921,  0.3095]]],


        [[[-0.2070,  0.0723,  0.2876],
          [ 0.2209,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],


        [[[-0.2799, -0.1527, -0.0388],
          [-0.2043,  0.1220,  0.1032],
          [-0.0755,  0.1281,  0.1077]]],


        [[[ 0.2035,  0.2245, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.3146, -0.2145, -0.1947]]],


        [[[-0.1426,  0.2370, -0.1089],
          [-0.2491,  0.1282,  0.1067],
          [ 0.2159, -0.1725,  0.0723]]]], device='cuda:0', requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021,  0.1425], device='cuda:0',
       requires_grad=True))]
print(list(module.named_buffers()))

得出:

[('weight_mask', tensor([[[[0., 1., 0.],
          [1., 0., 0.],
          [1., 1., 1.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[1., 0., 0.],
          [0., 1., 1.],
          [1., 1., 1.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[1., 0., 1.],
          [1., 1., 1.],
          [0., 1., 1.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]]], device='cuda:0')), ('bias_mask', tensor([0., 0., 1., 1., 0., 1.], device='cuda:0'))]
print(module.weight)

得出:

tensor([[[[ 0.0000, -0.2212,  0.0000],
          [ 0.2488,  0.0000,  0.0000],
          [-0.2412, -0.2400, -0.2016]]],


        [[[ 0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000]]],


        [[[-0.2070,  0.0000,  0.0000],
          [ 0.0000,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],


        [[[-0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000]]],


        [[[ 0.2035,  0.0000, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.0000, -0.2145, -0.1947]]],


        [[[-0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

刪除重新參數(shù)化后:

prune.remove(module, 'weight')
print(list(module.named_parameters()))

得出:

[('bias_orig', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021,  0.1425], device='cuda:0',
       requires_grad=True)), ('weight', Parameter containing:
tensor([[[[ 0.0000, -0.2212,  0.0000],
          [ 0.2488,  0.0000,  0.0000],
          [-0.2412, -0.2400, -0.2016]]],


        [[[ 0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000]]],


        [[[-0.2070,  0.0000,  0.0000],
          [ 0.0000,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],


        [[[-0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000]]],


        [[[ 0.2035,  0.0000, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.0000, -0.2145, -0.1947]]],


        [[[-0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000]]]], device='cuda:0', requires_grad=True))]
print(list(module.named_buffers()))

得出:

[('bias_mask', tensor([0., 0., 1., 1., 0., 1.], device='cuda:0'))]

修剪模型中的多個(gè)參數(shù)

通過指定所需的修剪技術(shù)和參數(shù),我們可以輕松地修剪網(wǎng)絡(luò)中的多個(gè)張量,也許根據(jù)它們的類型,如在本示例中將看到的那樣。

new_model = LeNet()
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)


print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist

得出:

dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])

全球修剪

到目前為止,我們僅研究了通常被稱為“局部”修剪的方法,即通過比較每個(gè)條目的統(tǒng)計(jì)信息(權(quán)重,激活度,梯度等)來逐一修剪模型中的張量的做法。 到該張量中的其他條目。 但是,一種常見且可能更強(qiáng)大的技術(shù)是通過刪除(例如)刪除整個(gè)模型中最低的 20%的連接,而不是刪除每一層中最低的 20%的連接來一次修剪模型。 這很可能導(dǎo)致每個(gè)層的修剪百分比不同。 讓我們看看如何使用torch.nn.utils.prune中的global_unstructured進(jìn)行操作。

model = LeNet()


parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)


prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

現(xiàn)在,我們可以檢查在每個(gè)修剪參數(shù)中引起的稀疏性,該稀疏性將不等于每層中的 20%。 但是,全球稀疏度將(大約)為 20%。

print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100\. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Sparsity in conv2.weight: {:.2f}%".format(
        100\. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)
print(
    "Sparsity in fc1.weight: {:.2f}%".format(
        100\. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)
print(
    "Sparsity in fc2.weight: {:.2f}%".format(
        100\. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)
print(
    "Sparsity in fc3.weight: {:.2f}%".format(
        100\. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100\. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)

得出:

Sparsity in conv1.weight: 7.41%
Sparsity in conv2.weight: 9.49%
Sparsity in fc1.weight: 22.00%
Sparsity in fc2.weight: 12.28%
Sparsity in fc3.weight: 9.76%
Global sparsity: 20.00%

使用自定義修剪功能擴(kuò)展torch.nn.utils.prune

要實(shí)現(xiàn)自己的修剪功能,您可以通過繼承BasePruningMethod基類來擴(kuò)展nn.utils.prune模塊,這與所有其他修剪方法一樣。 基類為您實(shí)現(xiàn)以下方法:__call__,apply_maskapply,pruneremove。 除了某些特殊情況外,您不必為新的修剪技術(shù)重新實(shí)現(xiàn)這些方法。 但是,您將必須實(shí)現(xiàn)__init__(構(gòu)造函數(shù))和compute_mask(有關(guān)如何根據(jù)修剪技術(shù)的邏輯為給定張量計(jì)算掩碼的說明)。 另外,您將必須指定此技術(shù)實(shí)現(xiàn)的修剪類型(支持的選項(xiàng)為global,structuredunstructured)。 需要確定在迭代應(yīng)用修剪的情況下如何組合蒙版。 換句話說,當(dāng)修剪預(yù)修剪的參數(shù)時(shí),當(dāng)前的修剪技術(shù)應(yīng)作用于參數(shù)的未修剪部分。 指定PRUNING_TYPE將使PruningContainer(處理修剪蒙版的迭代應(yīng)用)正確識別要修剪的參數(shù)。

例如,假設(shè)您要實(shí)施一種修剪技術(shù),以修剪張量中的所有其他條目(或者-如果先前已修剪過張量,則在張量的其余未修剪部分中)。 這將是PRUNING_TYPE='unstructured',因?yàn)樗饔糜趯又械膯蝹€(gè)連接,而不作用于整個(gè)單元/通道('structured'),或作用于不同的參數(shù)('global')。

class FooBarPruningMethod(prune.BasePruningMethod):
    """Prune every other entry in a tensor
    """
    PRUNING_TYPE = 'unstructured'


    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask

現(xiàn)在,要將其應(yīng)用于nn.Module中的參數(shù),還應(yīng)該提供一個(gè)簡單的函數(shù)來實(shí)例化該方法并將其應(yīng)用。

def foobar_unstructured(module, name):
    """Prunes tensor corresponding to parameter called `name` in `module`
    by removing every other entry in the tensors.
    Modifies module in place (and also return the modified module)
    by:
    1) adding a named buffer called `name+'_mask'` corresponding to the
    binary mask applied to the parameter `name` by the pruning method.
    The parameter `name` is replaced by its pruned version, while the
    original (unpruned) parameter is stored in a new parameter named
    `name+'_orig'`.


    Args:
        module (nn.Module): module containing the tensor to prune
        name (string): parameter name within `module` on which pruning
                will act.


    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input
            module


    Examples:
        >>> m = nn.Linear(3, 4)
        >>> foobar_unstructured(m, name='bias')
    """
    FooBarPruningMethod.apply(module, name)
    return module

試試吧!

model = LeNet()
foobar_unstructured(model.fc3, name='bias')


print(model.fc3.bias_mask)

得出:

tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])

腳本的總運(yùn)行時(shí)間:(0 分鐘 0.146 秒)

Download Python source code: pruning_tutorial.py Download Jupyter notebook: pruning_tutorial.ipynb



以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號