PyTorch torch.utils.data

2020-09-15 11:52 更新

原文: PyTorch torch.utils.datal

PyTorch 數(shù)據(jù)加載實用程序的核心是 torch.utils.data.DataLoader 類。 它表示可在數(shù)據(jù)集上迭代的 Python,并支持

這些選項由 DataLoader 的構造函數(shù)參數(shù)配置,該參數(shù)具有簽名:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

以下各節(jié)詳細介紹了這些選項的效果和用法。

數(shù)據(jù)集類型

DataLoader 構造函數(shù)的最重要參數(shù)是dataset,它指示要從中加載數(shù)據(jù)的數(shù)據(jù)集對象。 PyTorch 支持兩種不同類型的數(shù)據(jù)集:

地圖樣式數(shù)據(jù)集

映射樣式數(shù)據(jù)集是一種實現(xiàn)__getitem__()__len__()協(xié)議的數(shù)據(jù)集,它表示從(可能是非整數(shù))索引/關鍵字到數(shù)據(jù)樣本的映射。

例如,當使用dataset[idx]訪問時,此類數(shù)據(jù)集可以從磁盤上的文件夾中讀取第idx張圖像及其對應的標簽。

迭代式數(shù)據(jù)集

可迭代樣式的數(shù)據(jù)集是 IterableDataset 子類的實例,該子類實現(xiàn)了__iter__()協(xié)議,并表示數(shù)據(jù)樣本上的可迭代。 這種類型的數(shù)據(jù)集特別適用于隨機讀取價格昂貴甚至不大可能,并且批處理大小取決于所獲取數(shù)據(jù)的情況。

例如,這種數(shù)據(jù)集稱為iter(dataset)時,可以返回從數(shù)據(jù)庫,遠程服務器甚至實時生成的日志中讀取的數(shù)據(jù)流。

注意

當將 IterableDataset 與一起使用時,多進程數(shù)據(jù)加載。 在每個工作進程上都復制相同的數(shù)據(jù)集對象,因此必須對副本進行不同的配置,以避免重復的數(shù)據(jù)。 有關如何實現(xiàn)此功能的信息,請參見 IterableDataset 文檔。

數(shù)據(jù)加載順序和 Sampler

對于迭代式數(shù)據(jù)集,數(shù)據(jù)加載順序完全由用戶定義的迭代器控制。 這樣可以更輕松地實現(xiàn)塊讀取和動態(tài)批次大小的實現(xiàn)(例如,通過每次生成一個批次的樣本)。

本節(jié)的其余部分涉及地圖樣式數(shù)據(jù)集的情況。 torch.utils.data.Sampler 類用于指定數(shù)據(jù)加載中使用的索引/鍵的順序。 它們代表數(shù)據(jù)集索引上的可迭代對象。 例如,在具有隨機梯度體面(SGD)的常見情況下, Sampler 可以隨機排列一列索引,一次生成每個索引,或者為小批量生成少量索引 新幣。

基于 DataLoadershuffle參數(shù),將自動構建順序采樣或混洗的采樣器。 或者,用戶可以使用sampler參數(shù)指定一個自定義 Sampler 對象,該對象每次都會產生要提取的下一個索引/關鍵字。

可以一次生成批量索引列表的自定義 Sampler 作為batch_sampler參數(shù)傳遞。 也可以通過batch_sizedrop_last參數(shù)啟用自動批處理。 有關更多詳細信息,請參見下一部分的。

Note

samplerbatch_sampler都不與可迭代樣式的數(shù)據(jù)集兼容,因為此類數(shù)據(jù)集沒有鍵或索引的概念。

加載批處理和非批處理數(shù)據(jù)

DataLoader 支持通過參數(shù)batch_size,drop_lastbatch_sampler將各個提取的數(shù)據(jù)樣本自動整理為批次。

自動批處理(默認)

這是最常見的情況,對應于獲取一小批數(shù)據(jù)并將其整理為批處理的樣本,即包含張量,其中一維為批處理維度(通常是第一維)。

batch_size(默認1)不是None時,數(shù)據(jù)加載器將生成批處理的樣本,而不是單個樣本。 batch_sizedrop_last參數(shù)用于指定數(shù)據(jù)加載器如何獲取數(shù)據(jù)集密鑰的批處理。 對于地圖樣式的數(shù)據(jù)集,用戶可以選擇指定batch_sampler,它一次生成一個鍵列表。

Note

batch_sizedrop_last自變量本質上用于從sampler構造batch_sampler。 對于地圖樣式的數(shù)據(jù)集,sampler由用戶提供或基于shuffle參數(shù)構造。 對于可迭代樣式的數(shù)據(jù)集,sampler是一個虛擬的無限數(shù)據(jù)集。

Note

當從可重復樣式數(shù)據(jù)集進行多重處理提取時,drop_last參數(shù)會刪除每個工作人員數(shù)據(jù)集副本的最后一個非完整批次。

使用來自采樣器的索引獲取樣本列表后,作為collate_fn參數(shù)傳遞的函數(shù)用于將樣本列表整理為批次。

在這種情況下,從地圖樣式數(shù)據(jù)集加載大致等效于:

for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])

從可迭代樣式的數(shù)據(jù)集加載大致等效于:

dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])

自定義collate_fn可用于自定義排序規(guī)則,例如,將順序數(shù)據(jù)填充到批處理的最大長度。

禁用自動批處理

在某些情況下,用戶可能希望以數(shù)據(jù)集代碼手動處理批處理,或僅加載單個樣本。 例如,直接加載批處理的數(shù)據(jù)(例如,從數(shù)據(jù)庫中批量讀取或讀取連續(xù)的內存塊)可能更便宜,或者批處理大小取決于數(shù)據(jù),或者該程序設計為可處理單個樣本。 在這種情況下,最好不要使用自動批處理(其中collate_fn用于整理樣本),而應讓數(shù)據(jù)加載器直接返回dataset對象的每個成員。

batch_sizebatch_sampler均為None時(batch_sampler的默認值已為None),自動批處理被禁用。 從dataset獲得的每個樣本都將作為collate_fn參數(shù)傳遞的函數(shù)進行處理。

禁用自動批處理時,默認值collate_fn僅將 NumPy 數(shù)組轉換為 PyTorch 張量,而其他所有內容均保持不變。

In this case, loading from a map-style dataset is roughly equivalent with:

for index in sampler:
    yield collate_fn(dataset[index])

and loading from an iterable-style dataset is roughly equivalent with:

for data in iter(dataset):
    yield collate_fn(data)

使用collate_fn

啟用或禁用自動批處理時,collate_fn的使用略有不同。

禁用自動批處理時,將對每個單獨的數(shù)據(jù)樣本調用collate_fn,并且從數(shù)據(jù)加載器迭代器產生輸出。 在這種情況下,默認的collate_fn僅轉換 PyTorch 張量中的 NumPy 數(shù)組。

啟用自動批處理時,會每次調用collate_fn并帶有數(shù)據(jù)樣本列表。 期望將輸入樣本整理為一批,以便從數(shù)據(jù)加載器迭代器中獲得收益。 本節(jié)的其余部分描述了這種情況下默認collate_fn的行為。

例如,如果每個數(shù)據(jù)樣本都包含一個 3 通道圖像和一個整體類標簽,即數(shù)據(jù)集的每個元素返回一個元組(image, class_index),則默認值collate_fn將此類元組的列表整理為一個元組 批處理圖像張量和批處理類標簽 Tensor。 特別是,默認collate_fn具有以下屬性:

  • 它始終將新維度添加為批次維度。
  • 它會自動將 NumPy 數(shù)組和 Python 數(shù)值轉換為 PyTorch 張量。
  • 它保留了數(shù)據(jù)結構,例如,如果每個樣本都是一個字典,它將輸出一個具有相同鍵集但將批處理 Tensors 作為值的字典(如果無法將這些值轉換為 Tensors,則將其列出)。 與list,tuple,namedtuple等相同。

用戶可以使用自定義的collate_fn來實現(xiàn)自定義批處理,例如,沿除第一個維度之外的其他維度進行校對,各種長度的填充序列或添加對自定義數(shù)據(jù)類型的支持。

單進程和多進程數(shù)據(jù)加載

默認情況下, DataLoader 使用單進程數(shù)據(jù)加載。

在 Python 進程中,全局解釋器鎖定(GIL)阻止了跨線程真正的完全并行化 Python 代碼。 為了避免在加載數(shù)據(jù)時阻塞計算代碼,PyTorch 提供了一個簡單的開關,只需將參數(shù)num_workers設置為正整數(shù)即可執(zhí)行多進程數(shù)據(jù)加載。

單進程數(shù)據(jù)加載(默認)

在此模式下,以與初始化 DataLoader 相同的過程完成數(shù)據(jù)提取。 因此,數(shù)據(jù)加載可能會阻止計算。 然而,當用于在進程之間共享數(shù)據(jù)的資源(例如,共享存儲器,文件描述符)受到限制時,或者當整個數(shù)據(jù)集很小并且可以完全加載到存儲器中時,該模式可能是優(yōu)選的。 此外,單進程加載通常顯示更多可讀的錯誤跟蹤,因此對于調試很有用。

多進程數(shù)據(jù)加載

將參數(shù)num_workers設置為正整數(shù)將打開具有指定數(shù)量的加載程序工作進程的多進程數(shù)據(jù)加載。

在此模式下,每次創(chuàng)建 DataLoader 的迭代器時(例如,當您調用enumerate(dataloader)時),都會創(chuàng)建num_workers工作進程。 此時,datasetcollate_fnworker_init_fn被傳遞給每個工作程序,在這里它們被用來初始化和獲取數(shù)據(jù)。 這意味著數(shù)據(jù)集訪問及其內部 IO 轉換(包括collate_fn)在工作進程中運行。

torch.utils.data.get_worker_info() 在工作進程中返回各種有用的信息(包括工作 ID,數(shù)據(jù)集副本,初始種子等),并在主進程中返回None。 用戶可以在數(shù)據(jù)集代碼和/或worker_init_fn中使用此功能來分別配置每個數(shù)據(jù)集副本,并確定代碼是否正在工作進程中運行。 例如,這在分片數(shù)據(jù)集時特別有用。

對于地圖樣式的數(shù)據(jù)集,主過程使用sampler生成索引并將其發(fā)送給工作人員。 因此,任何隨機播放都是在主過程中完成的,該過程通過為索引分配索引來引導加載。

對于可迭代樣式的數(shù)據(jù)集,由于每個工作進程都獲得dataset對象的副本,因此幼稚的多進程加載通常會導致數(shù)據(jù)重復。 用戶可以使用 torch.utils.data.get_worker_info() 和/或worker_init_fn獨立配置每個副本。 (有關如何實現(xiàn)此操作的信息,請參見 IterableDataset 文檔。)出于類似的原因,在多進程加載中,drop_last參數(shù)刪除每個工作程序的可迭代樣式數(shù)據(jù)集副本的最后一個非完整批次。

一旦迭代結束或迭代器被垃圾回收,工作器將關閉。

警告

通常不建議在多進程加載中返回 CUDA 張量,因為在使用 CUDA 和在并行處理中共享 CUDA 張量時存在很多微妙之處(請參見在并行處理中的 CUDA)。 相反,我們建議使用自動內存固定(即,設置pin_memory=True),該功能可以將數(shù)據(jù)快速傳輸?shù)街С?CUDA 的 GPU。

平臺特定的行為

由于工作程序依賴于 Python multiprocessing,因此與 Unix 相比,Windows 上的工作程序啟動行為有所不同。

  • 在 Unix 上,fork()是默認的multiprocessing啟動方法。 使用fork(),童工通??梢灾苯油ㄟ^克隆的地址空間訪問dataset和 Python 參數(shù)函數(shù)。
  • 在 Windows 上,spawn()是默認的multiprocessing啟動方法。 使用spawn()啟動另一個解釋器,該解釋器運行您的主腳本,然后運行內部工作程序函數(shù),該函數(shù)通過序列化pickle接收dataset,collate_fn和其他參數(shù)。

這種獨立的序列化意味著您應該采取兩個步驟來確保在使用多進程數(shù)據(jù)加載時與 Windows 兼容:

  • 將您的大部分主腳本代碼包裝在if __name__ == '__main__':塊中,以確保在啟動每個工作進程時,該腳本不會再次運行(很可能會產生錯誤)。 您可以在此處放置數(shù)據(jù)集和 DataLoader 實例創(chuàng)建邏輯,因為它不需要在 worker 中重新執(zhí)行。
  • 確保在__main__檢查之外將任何自定義collate_fnworker_init_fndataset代碼聲明為頂級定義。 這樣可以確保它們在工作進程中可用。 (這是必需的,因為將函數(shù)僅作為引用而不是bytecode進行腌制。)

多進程數(shù)據(jù)加載中的隨機性

默認情況下,每個工作人員的 PyTorch 種子將設置為base_seed + worker_id,其中base_seed是主進程使用其 RNG 生成的長整數(shù)(因此,強制使用 RNG 狀態(tài))。 但是,初始化工作程序(例如 NumPy)時,可能會復制其他庫的種子,導致每個工作程序返回相同的隨機數(shù)。

worker_init_fn中,您可以使用 torch.utils.data.get_worker_info().seedtorch.initial_seed() 訪問每個工作人員的 PyTorch 種子集,并在加載數(shù)據(jù)之前使用它為其他庫提供種子。

內存固定

主機到 GPU 副本源自固定(頁面鎖定)內存時,速度要快得多。 有關通常何時以及如何使用固定內存的更多詳細信息,請參見使用固定內存緩沖區(qū)。

對于數(shù)據(jù)加載,將pin_memory=True傳遞到 DataLoader 將自動將獲取的數(shù)據(jù)張量放置在固定內存中,從而更快地將數(shù)據(jù)傳輸?shù)絾⒂?CUDA 的 GPU。

默認的內存固定邏輯僅識別張量以及包含張量的映射和可迭代對象。 默認情況下,如果固定邏輯看到一個自定義類型的批處理(如果您有一個collate_fn返回自定義批處理類型,則會發(fā)生),或者如果該批處理的每個元素都是自定義類型,則固定邏輯將 無法識別它們,它將返回該批處理(或那些元素)而不固定內存。 要為自定義批處理或數(shù)據(jù)類型啟用內存固定,請在自定義類型上定義pin_memory()方法。

請參見下面的示例。

例:

class SimpleCustomBatch:
    def __init__(self, data):
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)


    # custom memory pinning method on custom type
    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self


def collate_wrapper(batch):
    return SimpleCustomBatch(batch)


inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)


loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                    pin_memory=True)


for batch_ndx, sample in enumerate(loader):
    print(sample.inp.is_pinned())
    print(sample.tgt.is_pinned())

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)?

數(shù)據(jù)加載器。 組合數(shù)據(jù)集和采樣器,并在給定的數(shù)據(jù)集上提供可迭代的。

DataLoader 支持地圖樣式和可迭代樣式的數(shù)據(jù)集,具有單進程或多進程加載,自定義加載順序以及可選的自動批處理(歸類)和內存固定。

有關更多詳細信息,請參見 torch.utils.data 文檔頁面。

參數(shù)

  • 數(shù)據(jù)集 (數(shù)據(jù)集)–要從中加載數(shù)據(jù)的數(shù)據(jù)集。
  • batch_size (python:int , 可選)–每批次要加載多少個樣本(默認值:1)。
  • 隨機播放 (bool , 可選)–設置為True以使數(shù)據(jù)在每個時間段都重新隨機播放(默認值:False )。
  • 采樣器 (采樣器 , 可選)–定義了從數(shù)據(jù)集中抽取樣本的策略。 如果指定,則shuffle必須為False。
  • batch_sampler (采樣器 可選)–類似sampler,但在 時間。 與batch_sizeshuffle,samplerdrop_last互斥。
  • num_workers (python:int , 可選)–多少個子進程用于數(shù)據(jù)加載。 0表示將在主進程中加載數(shù)據(jù)。 (默認:0
  • collate_fn (可調用的, 可選)–合并樣本列表以形成張量的小批量。 在從地圖樣式數(shù)據(jù)集中使用批量加載時使用。
  • pin_memory (bool , 可選)–如果True,則數(shù)據(jù)加載器將張量復制到 CUDA 固定的內存中,然后返回。 如果您的數(shù)據(jù)元素是自定義類型,或者您的collate_fn返回的是自定義類型的批次,請參見下面的示例。
  • drop_last (布爾 , 可選)–設置為True以刪除最后不完整的批次,如果數(shù)據(jù)集大小不可分割 按批次大小。 如果False并且數(shù)據(jù)集的大小不能被批次大小整除,那么最后一批將較小。 (默認:False
  • 超時(數(shù)字 可選)–如果為正,則表示從工作人員處收集批次的超時值。 應始終為非負數(shù)。 (默認:0
  • worker_init_fn (可調用 , 可選)–如果不是None,則將在每個具有工作人員 ID (在播種之后和數(shù)據(jù)加載之前,將[0, num_workers - 1]中的 int 作為輸入。 (默認:None

Warning

如果使用spawn啟動方法,則worker_init_fn不能是不可拾取的對象,例如 lambda 函數(shù)。 有關 PyTorch 中與并行處理有關的更多詳細信息,請參見并行處理最佳實踐。

Note

len(dataloader)啟發(fā)式方法基于所用采樣器的長度。 當datasetIterableDataset 時,無論多進程加載配置如何,都將返回len(dataset)(如果實現(xiàn)),因為 PyTorch 信任用戶dataset代碼可以正確處理多進程加載 避免重復數(shù)據(jù)。 有關這兩種類型的數(shù)據(jù)集以及 IterableDataset 如何與多進程數(shù)據(jù)加載交互的更多詳細信息,請參見數(shù)據(jù)集類型。

class torch.utils.data.Dataset?

表示 Dataset 的抽象類。

代表從鍵到數(shù)據(jù)樣本的映射的所有數(shù)據(jù)集都應將其子類化。 所有子類都應該覆蓋__getitem__(),支持為給定鍵獲取數(shù)據(jù)樣本。 子類還可以選擇覆蓋__len__(),它有望通過許多 Sampler 實現(xiàn)以及 DataLoader 的默認選項返回數(shù)據(jù)集的大小。

Note

默認情況下, DataLoader 構造一個索引采樣器,該采樣器產生整數(shù)索引。 要使其與具有非整數(shù)索引/鍵的地圖樣式數(shù)據(jù)集一起使用,必須提供自定義采樣器。

class torch.utils.data.IterableDataset?

可迭代的數(shù)據(jù)集。

代表可迭代數(shù)據(jù)樣本的所有數(shù)據(jù)集都應將其子類化。 當數(shù)據(jù)來自流時,這種形式的數(shù)據(jù)集特別有用。

所有子類都應覆蓋__iter__(),這將返回此數(shù)據(jù)集中的樣本迭代器。

當子類與 DataLoader 一起使用時,數(shù)據(jù)集中的每個項目都將由 DataLoader 迭代器產生。 當num_workers > 0時,每個工作進程將具有數(shù)據(jù)集對象的不同副本,因此通常需要獨立配置每個副本,以避免從工作進程返回重復的數(shù)據(jù)。 get_worker_info() 在工作程序進程中調用時,返回有關工作程序的信息。 可以在數(shù)據(jù)集的__iter__()方法或 DataLoaderworker_init_fn選項中使用它來修改每個副本的行為。

示例 1:在__iter__()中將工作負載分配給所有工作人員:

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         worker_info = torch.utils.data.get_worker_info()
...         if worker_info is None:  # single-process data loading, return the full iterator
...             iter_start = self.start
...             iter_end = self.end
...         else:  # in a worker process
...             # split workload
...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
...             worker_id = worker_info.id
...             iter_start = self.start + worker_id * per_worker
...             iter_end = min(iter_start + per_worker, self.end)
...         return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)


>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]


>>> # Mult-process loading with two worker processes
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 5, 4, 6]


>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
[3, 4, 5, 6]

示例 2:使用worker_init_fn在所有工作人員之間分配工作量:

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)


>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>>
>>> # Directly doing multi-process loading yields duplicate data
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 3, 4, 4, 5, 5, 6, 6]


>>> # Define a `worker_init_fn` that configures each dataset  differently
>>> def worker_init_fn(worker_id):
...     worker_info = torch.utils.data.get_worker_info()
...     dataset = worker_info.dataset  # the dataset copy in this worker process
...     overall_start = dataset.start
...     overall_end = dataset.end
...     # configure the dataset to only process the split workload
...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
...     worker_id = worker_info.id
...     dataset.start = overall_start + worker_id * per_worker
...     dataset.end = min(dataset.start + per_worker, overall_end)
...


>>> # Mult-process loading with the custom `worker_init_fn`
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[3, 5, 4, 6]


>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
[3, 4, 5, 6]

class torch.utils.data.TensorDataset(*tensors)?

數(shù)據(jù)集包裝張量。

每個樣本將通過沿第一維索引張量來檢索。

Parameters

張量 (tensor*)–具有與第一維相同大小的張量。

class torch.utils.data.ConcatDataset(datasets)?

數(shù)據(jù)集是多個數(shù)據(jù)集的串聯(lián)。

此類對于組裝不同的現(xiàn)有數(shù)據(jù)集很有用。

Parameters

數(shù)據(jù)集(序列)–要連接的數(shù)據(jù)集列表

class torch.utils.data.ChainDataset(datasets)?

用于鏈接多個 IterableDataset 的數(shù)據(jù)集。

此類對于組裝不同的現(xiàn)有數(shù)據(jù)集流很有用。 鏈接操作是即時完成的,因此將大型數(shù)據(jù)集與此類連接起來將非常有效。

Parameters

數(shù)據(jù)集(IterableDataset 的可迭代)–鏈接在一起的數(shù)據(jù)集

class torch.utils.data.Subset(dataset, indices)?

指定索引處的數(shù)據(jù)集子集。

Parameters

  • 數(shù)據(jù)集 (數(shù)據(jù)集)–整個數(shù)據(jù)集
  • 索引(序列)–為子集選擇的整個集合中的索引

torch.utils.data.get_worker_info()?

返回有關當前 DataLoader 迭代器工作進程的信息。

在工作線程中調用時,此方法返回一個保證具有以下屬性的對象:

  • id:當前工作人員 ID。
  • num_workers:工人總數(shù)。
  • seed:當前工作程序的隨機種子集。 該值由主進程 RNG 和工作程序 ID 確定。 有關更多詳細信息,請參見 DataLoader 的文檔。
  • dataset:此流程在中的數(shù)據(jù)集對象的副本。 請注意,在不同的過程中,這將是與主過程中的對象不同的對象。

在主進程中調用時,將返回None

Note

在傳遞給 DataLoaderworker_init_fn中使用時,此方法可用于不同地設置每個工作進程,例如,使用worker_iddataset對象配置為僅讀取 分片數(shù)據(jù)集的特定部分,或使用seed播種數(shù)據(jù)集代碼中使用的其他庫(例如 NumPy)。

torch.utils.data.random_split(dataset, lengths)?

將數(shù)據(jù)集隨機拆分為給定長度的不重疊的新數(shù)據(jù)集。

Parameters

  • 數(shù)據(jù)集 (數(shù)據(jù)集)–要拆分的數(shù)據(jù)集
  • 長度(序列)–要產生的分割的長度

class torch.utils.data.Sampler(data_source)?

所有采樣器的基類。

每個 Sampler 子類都必須提供__iter__()方法(提供一種對數(shù)據(jù)集元素的索引進行迭代的方法)和__len__()方法,該方法返回返回的迭代器的長度。

Note

DataLoader 并非嚴格要求__len__()方法,但在涉及 DataLoader 長度的任何計算中都應采用。

class torch.utils.data.SequentialSampler(data_source)?

始終以相同順序順序采樣元素。

Parameters

data_source (數(shù)據(jù)集)–要從中采樣的數(shù)據(jù)集

class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None)?

隨機采樣元素。 如果不進行替換,則從經過改組的數(shù)據(jù)集中采樣。 如果要更換,則用戶可以指定num_samples進行繪制。

Parameters

  • data_source (Dataset) – dataset to sample from
  • 替換 (bool )–如果True為默認值,則替換為True
  • num_samples (python:int )–要繪制的樣本數(shù),默認為 len(dataset)。 僅當<cite>替換</cite>為True時才應指定此參數(shù)。

class torch.utils.data.SubsetRandomSampler(indices)?

從給定的索引列表中隨機抽樣元素,而無需替換。

Parameters

索引(序列)–索引序列

class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)?

以給定的概率(權重)從[0,..,len(weights)-1]中采樣元素。

Parameters

  • 權重(序列)–權重序列,不必累加一個
  • num_samples (python:int )–要繪制的樣本數(shù)
  • 替代品 (bool )–如果True,則抽取替代品抽取樣品。 如果沒有,則它們將被替換而不會被繪制,這意味著當為一行繪制樣本索引時,無法為該行再次繪制它。

>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
[0, 0, 0, 1, 0]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
[0, 1, 4, 3, 2]

class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)?

包裝另一個采樣器以產生一個小批量的索引。

Parameters

  • 采樣器 (采樣器)–基本采樣器。
  • batch_size (python:int )–迷你批量的大小。
  • drop_last (bool )–如果為True,則采樣器將丟棄最后一批,如果其大小小于batch_size

Example

>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]

class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True)?

將數(shù)據(jù)加載限制為數(shù)據(jù)集子集的采樣器。

torch.nn.parallel.DistributedDataParallel 結合使用時特別有用。 在這種情況下,每個進程都可以將 DistributedSampler 實例作為 DataLoader 采樣器傳遞,并加載原始數(shù)據(jù)集的專有子集。

Note

假定數(shù)據(jù)集大小恒定。

Parameters

  • 數(shù)據(jù)集 –用于采樣的數(shù)據(jù)集。
  • num_replicas (可選)–參與分布式訓練的進程數(shù)。
  • 等級(可選)–當前進程在 num_replicas 中的等級。
  • 隨機播放(可選)–如果為 true(默認值),采樣器將隨機播放索引
以上內容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號