PyTorch 編寫自定義數(shù)據(jù)集,數(shù)據(jù)加載器和轉(zhuǎn)換

2020-09-07 17:25 更新
原文: https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

作者: Sasank Chilamkurthy

解決任何機(jī)器學(xué)習(xí)問題都需要花費(fèi)大量精力來準(zhǔn)備數(shù)據(jù)。 PyTorch 提供了許多工具來簡(jiǎn)化數(shù)據(jù)加載過程,并有望使代碼更具可讀性。 在本教程中,我們將了解如何從非空的數(shù)據(jù)集中加載和預(yù)處理/增強(qiáng)數(shù)據(jù)。

要運(yùn)行本教程,請(qǐng)確保已安裝以下軟件包:

  • scikit-image:用于圖像 io 和變換
  • pandas:用于更輕松的 csv 解析
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils


## Ignore warnings
import warnings
warnings.filterwarnings("ignore")


plt.ion()   # interactive mode

我們要處理的數(shù)據(jù)集是面部姿勢(shì)數(shù)據(jù)集。 這意味著將對(duì)面部進(jìn)行如下注釋:

../_images/landmarked_face2.png

總體上,每個(gè)面孔都標(biāo)注了 68 個(gè)不同的界標(biāo)點(diǎn)。

Note

此處下載數(shù)據(jù)集,將圖像存放于名為“ data / faces /”的目錄中。 該數(shù)據(jù)集實(shí)際上是通過對(duì)來自標(biāo)記為“面部”的  imagenet  上的一些圖像應(yīng)用出色的  dlib 姿態(tài)估計(jì) 生成的。

數(shù)據(jù)集帶有一個(gè)帶注釋的 csv 文件,如下所示:

image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
0805personali01.jpg,27,83,27,98, ... 84,134
1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312

讓我們快速閱讀 CSV 并獲取 (N,2)數(shù)組中的注釋,其中 N 是地標(biāo)數(shù)。

landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')


n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:]
landmarks = np.asarray(landmarks)
landmarks = landmarks.astype('float').reshape(-1, 2)


print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))

輸出:

Image name: person-7.jpg
Landmarks shape: (68, 2)
First 4 Landmarks: [[32. 65.]
 [33. 76.]
 [34. 86.]
 [34. 97.]]

讓我們編寫一個(gè)簡(jiǎn)單的輔助函數(shù)來顯示圖像及其地標(biāo),并使用它來顯示示例。

def show_landmarks(image, landmarks):
    """Show image with landmarks"""
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    plt.pause(0.001)  # pause a bit so that plots are updated


plt.figure()
show_landmarks(io.imread(os.path.join('data/faces/', img_name)),
               landmarks)
plt.show()

../_images/sphx_glr_data_loading_tutorial_001.png

數(shù)據(jù)集類

torch.utils.data.Dataset是代表數(shù)據(jù)集的抽象類。 您的自定義數(shù)據(jù)集應(yīng)繼承Dataset并覆蓋以下方法:

  • __len__,以便 len(dataset)返回?cái)?shù)據(jù)集的大小。
  • __getitem__支持索引,以便可以使用dataset[i]獲取第 個(gè)樣本

讓我們?yōu)槊娌枯喞獢?shù)據(jù)集創(chuàng)建一個(gè)數(shù)據(jù)集類。 我們將在__init__中讀取 csv,但將圖像讀取留給__getitem__。 由于所有圖像不會(huì)立即存儲(chǔ)在內(nèi)存中,而是根據(jù)需要讀取,因此可以提高存儲(chǔ)效率。

我們的數(shù)據(jù)集樣本將是 dict {'image': image, 'landmarks': landmarks}。 我們的數(shù)據(jù)集將使用可選參數(shù)transform,以便可以將任何所需的處理應(yīng)用于樣本。 我們將在下一部分中看到transform的有用性。

class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""


    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform


    def __len__(self):
        return len(self.landmarks_frame)


    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()


        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:]
        landmarks = np.array([landmarks])
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}


        if self.transform:
            sample = self.transform(sample)


        return sample

讓我們實(shí)例化該類并遍歷數(shù)據(jù)樣本。 我們將打印前 4 個(gè)樣本的大小并顯示其地標(biāo)。

face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                    root_dir='data/faces/')


fig = plt.figure()


for i in range(len(face_dataset)):
    sample = face_dataset[i]


    print(i, sample['image'].shape, sample['landmarks'].shape)


    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    show_landmarks(**sample)


    if i == 3:
        plt.show()
        break

../_images/sphx_glr_data_loading_tutorial_002.png

輸出:

0 (324, 215, 3) (68, 2)
1 (500, 333, 3) (68, 2)
2 (250, 258, 3) (68, 2)
3 (434, 290, 3) (68, 2)

Transforms 變換

從上面可以看到的一個(gè)問題是樣本的大小不同。 大多數(shù)神經(jīng)網(wǎng)絡(luò)期望圖像的大小固定。 因此,我們將需要編寫一些預(yù)處理代碼。 讓我們創(chuàng)建三個(gè)轉(zhuǎn)換:

  • Rescale:縮放圖像
  • RandomCrop:從圖像中隨機(jī)裁剪。 這是數(shù)據(jù)增強(qiáng)。
  • ToTensor:將 numpy 圖像轉(zhuǎn)換為 torch 圖像(我們需要交換軸)。

我們會(huì)將它們編寫為可調(diào)用的類,而不是簡(jiǎn)單的函數(shù),這樣就不必每次調(diào)用轉(zhuǎn)換時(shí)都傳遞其參數(shù)。 為此,我們只需要實(shí)現(xiàn)__call__方法,如果需要,還可以實(shí)現(xiàn)__init__方法。 然后我們可以使用這樣的變換:

tsfm = Transform(params)
transformed_sample = tsfm(sample)

在下面觀察如何將這些變換同時(shí)應(yīng)用于圖像和地標(biāo)。

class Rescale(object):
    """Rescale the image in a sample to a given size.


    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """


    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size


    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']


        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size


        new_h, new_w = int(new_h), int(new_w)


        img = transform.resize(image, (new_h, new_w))


        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w / w, new_h / h]


        return {'image': img, 'landmarks': landmarks}


class RandomCrop(object):
    """Crop randomly the image in a sample.


    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """


    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size


    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']


        h, w = image.shape[:2]
        new_h, new_w = self.output_size


        top = np.random.randint(0, h - new_h)
        left = np.random.randint(0, w - new_w)


        image = image[top: top + new_h,
                      left: left + new_w]


        landmarks = landmarks - [left, top]


        return {'image': image, 'landmarks': landmarks}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""


    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']


        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}

撰寫變換

現(xiàn)在,我們將轉(zhuǎn)換應(yīng)用于樣本。

假設(shè)我們要將圖片的較短邊重新縮放為 256,然后從中隨機(jī)裁剪一個(gè)尺寸為 224 的正方形。 也就是說,我們要組成RescaleRandomCrop轉(zhuǎn)換。 torchvision.transforms.Compose是一個(gè)簡(jiǎn)單的可調(diào)用類,它使我們可以執(zhí)行此操作。

scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),
                               RandomCrop(224)])


## Apply each of the above transforms on sample.
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
    transformed_sample = tsfrm(sample)


    ax = plt.subplot(1, 3, i + 1)
    plt.tight_layout()
    ax.set_title(type(tsfrm).__name__)
    show_landmarks(**transformed_sample)


plt.show()

../_images/sphx_glr_data_loading_tutorial_003.png

遍歷數(shù)據(jù)集

讓我們將所有這些放在一起,以創(chuàng)建具有組合轉(zhuǎn)換的數(shù)據(jù)集。 總而言之,每次采樣此數(shù)據(jù)集時(shí):

  • 從文件中即時(shí)讀取圖像
  • 轉(zhuǎn)換應(yīng)用于讀取的圖像
  • 由于其中一種轉(zhuǎn)換是隨機(jī)的,因此數(shù)據(jù)是在采樣時(shí)進(jìn)行增強(qiáng)

我們可以像以前一樣使用 for i in range 循環(huán)遍歷創(chuàng)建的數(shù)據(jù)集。

transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                           root_dir='data/faces/',
                                           transform=transforms.Compose([
                                               Rescale(256),
                                               RandomCrop(224),
                                               ToTensor()
                                           ]))


for i in range(len(transformed_dataset)):
    sample = transformed_dataset[i]


    print(i, sample['image'].size(), sample['landmarks'].size())


    if i == 3:
        break

輸出:

0 torch.Size([3, 224, 224]) torch.Size([68, 2])
1 torch.Size([3, 224, 224]) torch.Size([68, 2])
2 torch.Size([3, 224, 224]) torch.Size([68, 2])
3 torch.Size([3, 224, 224]) torch.Size([68, 2])

但是,通過使用簡(jiǎn)單的for循環(huán)迭代數(shù)據(jù),我們失去了很多功能。 特別是,我們錯(cuò)過了:

  • 批量處理數(shù)據(jù)
  • 打亂數(shù)據(jù)
  • 使用multiprocessing工作程序并行加載數(shù)據(jù)。

torch.utils.data.DataLoader是提供所有這些功能的迭代器。 下面使用的參數(shù)應(yīng)該清楚。 感興趣的一個(gè)參數(shù)是collate_fn。 您可以使用collate_fn指定需要如何精確地分批樣品。 但是,默認(rèn)精度在大多數(shù)情況下都可以正常工作。

dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=4)


## Helper function to show a batch
def show_landmarks_batch(sample_batched):
    """Show image with landmarks for a batch of samples."""
    images_batch, landmarks_batch = \
            sample_batched['image'], sample_batched['landmarks']
    batch_size = len(images_batch)
    im_size = images_batch.size(2)
    grid_border_size = 2


    grid = utils.make_grid(images_batch)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))


    for i in range(batch_size):
        plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) * grid_border_size,
                    landmarks_batch[i, :, 1].numpy() + grid_border_size,
                    s=10, marker='.', c='r')


        plt.title('Batch from dataloader')


for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size(),
          sample_batched['landmarks'].size())


    # observe 4th batch and stop.
    if i_batch == 3:
        plt.figure()
        show_landmarks_batch(sample_batched)
        plt.axis('off')
        plt.ioff()
        plt.show()
        break

../_images/sphx_glr_data_loading_tutorial_004.png

輸出:

0 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
1 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
2 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
3 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])

后記:torchvision

在本教程中,我們已經(jīng)看到了如何編寫和使用數(shù)據(jù)集,轉(zhuǎn)換和數(shù)據(jù)加載器。 torchvision包提供了一些常見的數(shù)據(jù)集和轉(zhuǎn)換。 您甚至不必編寫自定義類。 Torchvision  中可用的更通用的數(shù)據(jù)集之一是ImageFolder。 假定圖像的組織方式如下:

root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png

其中“螞蟻”,“蜜蜂”等是類別標(biāo)簽。 同樣也可以使用對(duì)PIL.Image,ScalePIL.Image進(jìn)行操作的通用轉(zhuǎn)換。 您可以使用以下代碼編寫數(shù)據(jù)加載器,如下所示:

import torch
from torchvision import transforms, datasets


data_transform = transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=4)

有關(guān)訓(xùn)練代碼的示例,請(qǐng)參見計(jì)算機(jī)視覺轉(zhuǎn)換學(xué)習(xí)教程。

腳本的總運(yùn)行時(shí)間:(0 分鐘 58.611 秒)

Download Python source code: data_loading_tutorial.py Download Jupyter notebook: data_loading_tutorial.ipynb

由獅身人面像畫廊生成的畫廊


以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)