PyTorch 數(shù)據(jù)加載實用程序的核心是 torch.utils.data.DataLoader
類。 它表示可在數(shù)據(jù)集上迭代的 Python,并支持
這些選項由 DataLoader
的構造函數(shù)參數(shù)配置,該參數(shù)具有簽名:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)
以下各節(jié)詳細介紹了這些選項的效果和用法。
DataLoader
構造函數(shù)的最重要參數(shù)是dataset
,它指示要從中加載數(shù)據(jù)的數(shù)據(jù)集對象。 PyTorch 支持兩種不同類型的數(shù)據(jù)集:
映射樣式數(shù)據(jù)集是一種實現(xiàn)__getitem__()
和__len__()
協(xié)議的數(shù)據(jù)集,它表示從(可能是非整數(shù))索引/關鍵字到數(shù)據(jù)樣本的映射。
例如,當使用dataset[idx]
訪問時,此類數(shù)據(jù)集可以從磁盤上的文件夾中讀取第idx
張圖像及其對應的標簽。
可迭代樣式的數(shù)據(jù)集是 IterableDataset
子類的實例,該子類實現(xiàn)了__iter__()
協(xié)議,并表示數(shù)據(jù)樣本上的可迭代。 這種類型的數(shù)據(jù)集特別適用于隨機讀取價格昂貴甚至不大可能,并且批處理大小取決于所獲取數(shù)據(jù)的情況。
例如,這種數(shù)據(jù)集稱為iter(dataset)
時,可以返回從數(shù)據(jù)庫,遠程服務器甚至實時生成的日志中讀取的數(shù)據(jù)流。
注意
當將 IterableDataset
與一起使用時,多進程數(shù)據(jù)加載。 在每個工作進程上都復制相同的數(shù)據(jù)集對象,因此必須對副本進行不同的配置,以避免重復的數(shù)據(jù)。 有關如何實現(xiàn)此功能的信息,請參見 IterableDataset
文檔。
Sampler
對于迭代式數(shù)據(jù)集,數(shù)據(jù)加載順序完全由用戶定義的迭代器控制。 這樣可以更輕松地實現(xiàn)塊讀取和動態(tài)批次大小的實現(xiàn)(例如,通過每次生成一個批次的樣本)。
本節(jié)的其余部分涉及地圖樣式數(shù)據(jù)集的情況。 torch.utils.data.Sampler
類用于指定數(shù)據(jù)加載中使用的索引/鍵的順序。 它們代表數(shù)據(jù)集索引上的可迭代對象。 例如,在具有隨機梯度體面(SGD)的常見情況下, Sampler
可以隨機排列一列索引,一次生成每個索引,或者為小批量生成少量索引 新幣。
基于 DataLoader
的shuffle
參數(shù),將自動構建順序采樣或混洗的采樣器。 或者,用戶可以使用sampler
參數(shù)指定一個自定義 Sampler
對象,該對象每次都會產生要提取的下一個索引/關鍵字。
可以一次生成批量索引列表的自定義 Sampler
作為batch_sampler
參數(shù)傳遞。 也可以通過batch_size
和drop_last
參數(shù)啟用自動批處理。 有關更多詳細信息,請參見下一部分的。
Note
sampler
和batch_sampler
都不與可迭代樣式的數(shù)據(jù)集兼容,因為此類數(shù)據(jù)集沒有鍵或索引的概念。
DataLoader
支持通過參數(shù)batch_size
,drop_last
和batch_sampler
將各個提取的數(shù)據(jù)樣本自動整理為批次。
這是最常見的情況,對應于獲取一小批數(shù)據(jù)并將其整理為批處理的樣本,即包含張量,其中一維為批處理維度(通常是第一維)。
當batch_size
(默認1
)不是None
時,數(shù)據(jù)加載器將生成批處理的樣本,而不是單個樣本。 batch_size
和drop_last
參數(shù)用于指定數(shù)據(jù)加載器如何獲取數(shù)據(jù)集密鑰的批處理。 對于地圖樣式的數(shù)據(jù)集,用戶可以選擇指定batch_sampler
,它一次生成一個鍵列表。
Note
batch_size
和drop_last
自變量本質上用于從sampler
構造batch_sampler
。 對于地圖樣式的數(shù)據(jù)集,sampler
由用戶提供或基于shuffle
參數(shù)構造。 對于可迭代樣式的數(shù)據(jù)集,sampler
是一個虛擬的無限數(shù)據(jù)集。
Note
當從可重復樣式數(shù)據(jù)集進行多重處理提取時,drop_last
參數(shù)會刪除每個工作人員數(shù)據(jù)集副本的最后一個非完整批次。
使用來自采樣器的索引獲取樣本列表后,作為collate_fn
參數(shù)傳遞的函數(shù)用于將樣本列表整理為批次。
在這種情況下,從地圖樣式數(shù)據(jù)集加載大致等效于:
for indices in batch_sampler:
yield collate_fn([dataset[i] for i in indices])
從可迭代樣式的數(shù)據(jù)集加載大致等效于:
dataset_iter = iter(dataset)
for indices in batch_sampler:
yield collate_fn([next(dataset_iter) for _ in indices])
自定義collate_fn
可用于自定義排序規(guī)則,例如,將順序數(shù)據(jù)填充到批處理的最大長度。
在某些情況下,用戶可能希望以數(shù)據(jù)集代碼手動處理批處理,或僅加載單個樣本。 例如,直接加載批處理的數(shù)據(jù)(例如,從數(shù)據(jù)庫中批量讀取或讀取連續(xù)的內存塊)可能更便宜,或者批處理大小取決于數(shù)據(jù),或者該程序設計為可處理單個樣本。 在這種情況下,最好不要使用自動批處理(其中collate_fn
用于整理樣本),而應讓數(shù)據(jù)加載器直接返回dataset
對象的每個成員。
當batch_size
和batch_sampler
均為None
時(batch_sampler
的默認值已為None
),自動批處理被禁用。 從dataset
獲得的每個樣本都將作為collate_fn
參數(shù)傳遞的函數(shù)進行處理。
禁用自動批處理時,默認值collate_fn
僅將 NumPy 數(shù)組轉換為 PyTorch 張量,而其他所有內容均保持不變。
In this case, loading from a map-style dataset is roughly equivalent with:
for index in sampler:
yield collate_fn(dataset[index])
and loading from an iterable-style dataset is roughly equivalent with:
for data in iter(dataset):
yield collate_fn(data)
collate_fn
啟用或禁用自動批處理時,collate_fn
的使用略有不同。
禁用自動批處理時,將對每個單獨的數(shù)據(jù)樣本調用collate_fn
,并且從數(shù)據(jù)加載器迭代器產生輸出。 在這種情況下,默認的collate_fn
僅轉換 PyTorch 張量中的 NumPy 數(shù)組。
啟用自動批處理時,會每次調用collate_fn
并帶有數(shù)據(jù)樣本列表。 期望將輸入樣本整理為一批,以便從數(shù)據(jù)加載器迭代器中獲得收益。 本節(jié)的其余部分描述了這種情況下默認collate_fn
的行為。
例如,如果每個數(shù)據(jù)樣本都包含一個 3 通道圖像和一個整體類標簽,即數(shù)據(jù)集的每個元素返回一個元組(image, class_index)
,則默認值collate_fn
將此類元組的列表整理為一個元組 批處理圖像張量和批處理類標簽 Tensor。 特別是,默認collate_fn
具有以下屬性:
list
,tuple
,namedtuple
等相同。
用戶可以使用自定義的collate_fn
來實現(xiàn)自定義批處理,例如,沿除第一個維度之外的其他維度進行校對,各種長度的填充序列或添加對自定義數(shù)據(jù)類型的支持。
默認情況下, DataLoader
使用單進程數(shù)據(jù)加載。
在 Python 進程中,全局解釋器鎖定(GIL)阻止了跨線程真正的完全并行化 Python 代碼。 為了避免在加載數(shù)據(jù)時阻塞計算代碼,PyTorch 提供了一個簡單的開關,只需將參數(shù)num_workers
設置為正整數(shù)即可執(zhí)行多進程數(shù)據(jù)加載。
在此模式下,以與初始化 DataLoader
相同的過程完成數(shù)據(jù)提取。 因此,數(shù)據(jù)加載可能會阻止計算。 然而,當用于在進程之間共享數(shù)據(jù)的資源(例如,共享存儲器,文件描述符)受到限制時,或者當整個數(shù)據(jù)集很小并且可以完全加載到存儲器中時,該模式可能是優(yōu)選的。 此外,單進程加載通常顯示更多可讀的錯誤跟蹤,因此對于調試很有用。
將參數(shù)num_workers
設置為正整數(shù)將打開具有指定數(shù)量的加載程序工作進程的多進程數(shù)據(jù)加載。
在此模式下,每次創(chuàng)建 DataLoader
的迭代器時(例如,當您調用enumerate(dataloader)
時),都會創(chuàng)建num_workers
工作進程。 此時,dataset
,collate_fn
和worker_init_fn
被傳遞給每個工作程序,在這里它們被用來初始化和獲取數(shù)據(jù)。 這意味著數(shù)據(jù)集訪問及其內部 IO 轉換(包括collate_fn
)在工作進程中運行。
torch.utils.data.get_worker_info()
在工作進程中返回各種有用的信息(包括工作 ID,數(shù)據(jù)集副本,初始種子等),并在主進程中返回None
。 用戶可以在數(shù)據(jù)集代碼和/或worker_init_fn
中使用此功能來分別配置每個數(shù)據(jù)集副本,并確定代碼是否正在工作進程中運行。 例如,這在分片數(shù)據(jù)集時特別有用。
對于地圖樣式的數(shù)據(jù)集,主過程使用sampler
生成索引并將其發(fā)送給工作人員。 因此,任何隨機播放都是在主過程中完成的,該過程通過為索引分配索引來引導加載。
對于可迭代樣式的數(shù)據(jù)集,由于每個工作進程都獲得dataset
對象的副本,因此幼稚的多進程加載通常會導致數(shù)據(jù)重復。 用戶可以使用 torch.utils.data.get_worker_info()
和/或worker_init_fn
獨立配置每個副本。 (有關如何實現(xiàn)此操作的信息,請參見 IterableDataset
文檔。)出于類似的原因,在多進程加載中,drop_last
參數(shù)刪除每個工作程序的可迭代樣式數(shù)據(jù)集副本的最后一個非完整批次。
一旦迭代結束或迭代器被垃圾回收,工作器將關閉。
警告
通常不建議在多進程加載中返回 CUDA 張量,因為在使用 CUDA 和在并行處理中共享 CUDA 張量時存在很多微妙之處(請參見在并行處理中的 CUDA)。 相反,我們建議使用自動內存固定(即,設置pin_memory=True
),該功能可以將數(shù)據(jù)快速傳輸?shù)街С?CUDA 的 GPU。
由于工作程序依賴于 Python multiprocessing
,因此與 Unix 相比,Windows 上的工作程序啟動行為有所不同。
fork()
是默認的multiprocessing
啟動方法。 使用fork()
,童工通??梢灾苯油ㄟ^克隆的地址空間訪問dataset
和 Python 參數(shù)函數(shù)。spawn()
是默認的multiprocessing
啟動方法。 使用spawn()
啟動另一個解釋器,該解釋器運行您的主腳本,然后運行內部工作程序函數(shù),該函數(shù)通過序列化pickle
接收dataset
,collate_fn
和其他參數(shù)。這種獨立的序列化意味著您應該采取兩個步驟來確保在使用多進程數(shù)據(jù)加載時與 Windows 兼容:
if __name__ == '__main__':
塊中,以確保在啟動每個工作進程時,該腳本不會再次運行(很可能會產生錯誤)。 您可以在此處放置數(shù)據(jù)集和 DataLoader
實例創(chuàng)建邏輯,因為它不需要在 worker 中重新執(zhí)行。__main__
檢查之外將任何自定義collate_fn
,worker_init_fn
或dataset
代碼聲明為頂級定義。 這樣可以確保它們在工作進程中可用。 (這是必需的,因為將函數(shù)僅作為引用而不是bytecode
進行腌制。)
默認情況下,每個工作人員的 PyTorch 種子將設置為base_seed + worker_id
,其中base_seed
是主進程使用其 RNG 生成的長整數(shù)(因此,強制使用 RNG 狀態(tài))。 但是,初始化工作程序(例如 NumPy)時,可能會復制其他庫的種子,導致每個工作程序返回相同的隨機數(shù)。
在worker_init_fn
中,您可以使用 torch.utils.data.get_worker_info().seed
或 torch.initial_seed()
訪問每個工作人員的 PyTorch 種子集,并在加載數(shù)據(jù)之前使用它為其他庫提供種子。
主機到 GPU 副本源自固定(頁面鎖定)內存時,速度要快得多。 有關通常何時以及如何使用固定內存的更多詳細信息,請參見使用固定內存緩沖區(qū)。
對于數(shù)據(jù)加載,將pin_memory=True
傳遞到 DataLoader
將自動將獲取的數(shù)據(jù)張量放置在固定內存中,從而更快地將數(shù)據(jù)傳輸?shù)絾⒂?CUDA 的 GPU。
默認的內存固定邏輯僅識別張量以及包含張量的映射和可迭代對象。 默認情況下,如果固定邏輯看到一個自定義類型的批處理(如果您有一個collate_fn
返回自定義批處理類型,則會發(fā)生),或者如果該批處理的每個元素都是自定義類型,則固定邏輯將 無法識別它們,它將返回該批處理(或那些元素)而不固定內存。 要為自定義批處理或數(shù)據(jù)類型啟用內存固定,請在自定義類型上定義pin_memory()
方法。
請參見下面的示例。
例:
class SimpleCustomBatch:
def __init__(self, data):
transposed_data = list(zip(*data))
self.inp = torch.stack(transposed_data[0], 0)
self.tgt = torch.stack(transposed_data[1], 0)
# custom memory pinning method on custom type
def pin_memory(self):
self.inp = self.inp.pin_memory()
self.tgt = self.tgt.pin_memory()
return self
def collate_wrapper(batch):
return SimpleCustomBatch(batch)
inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)
loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
pin_memory=True)
for batch_ndx, sample in enumerate(loader):
print(sample.inp.is_pinned())
print(sample.tgt.is_pinned())
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)?
數(shù)據(jù)加載器。 組合數(shù)據(jù)集和采樣器,并在給定的數(shù)據(jù)集上提供可迭代的。
DataLoader
支持地圖樣式和可迭代樣式的數(shù)據(jù)集,具有單進程或多進程加載,自定義加載順序以及可選的自動批處理(歸類)和內存固定。
有關更多詳細信息,請參見 torch.utils.data
文檔頁面。
參數(shù)
1
)。True
以使數(shù)據(jù)在每個時間段都重新隨機播放(默認值:False
)。shuffle
必須為False
。sampler
,但在 時間。 與batch_size
,shuffle
,sampler
和drop_last
互斥。0
表示將在主進程中加載數(shù)據(jù)。 (默認:0
)True
,則數(shù)據(jù)加載器將張量復制到 CUDA 固定的內存中,然后返回。 如果您的數(shù)據(jù)元素是自定義類型,或者您的collate_fn
返回的是自定義類型的批次,請參見下面的示例。True
以刪除最后不完整的批次,如果數(shù)據(jù)集大小不可分割 按批次大小。 如果False
并且數(shù)據(jù)集的大小不能被批次大小整除,那么最后一批將較小。 (默認:False
)0
)None
,則將在每個具有工作人員 ID (在播種之后和數(shù)據(jù)加載之前,將[0, num_workers - 1]
中的 int 作為輸入。 (默認:None
)Warning
如果使用spawn
啟動方法,則worker_init_fn
不能是不可拾取的對象,例如 lambda 函數(shù)。 有關 PyTorch 中與并行處理有關的更多詳細信息,請參見并行處理最佳實踐。
Note
len(dataloader)
啟發(fā)式方法基于所用采樣器的長度。 當dataset
是 IterableDataset
時,無論多進程加載配置如何,都將返回len(dataset)
(如果實現(xiàn)),因為 PyTorch 信任用戶dataset
代碼可以正確處理多進程加載 避免重復數(shù)據(jù)。 有關這兩種類型的數(shù)據(jù)集以及 IterableDataset
如何與多進程數(shù)據(jù)加載交互的更多詳細信息,請參見數(shù)據(jù)集類型。
class torch.utils.data.Dataset?
表示 Dataset
的抽象類。
代表從鍵到數(shù)據(jù)樣本的映射的所有數(shù)據(jù)集都應將其子類化。 所有子類都應該覆蓋__getitem__()
,支持為給定鍵獲取數(shù)據(jù)樣本。 子類還可以選擇覆蓋__len__()
,它有望通過許多 Sampler
實現(xiàn)以及 DataLoader
的默認選項返回數(shù)據(jù)集的大小。
Note
默認情況下, DataLoader
構造一個索引采樣器,該采樣器產生整數(shù)索引。 要使其與具有非整數(shù)索引/鍵的地圖樣式數(shù)據(jù)集一起使用,必須提供自定義采樣器。
class torch.utils.data.IterableDataset?
可迭代的數(shù)據(jù)集。
代表可迭代數(shù)據(jù)樣本的所有數(shù)據(jù)集都應將其子類化。 當數(shù)據(jù)來自流時,這種形式的數(shù)據(jù)集特別有用。
所有子類都應覆蓋__iter__()
,這將返回此數(shù)據(jù)集中的樣本迭代器。
當子類與 DataLoader
一起使用時,數(shù)據(jù)集中的每個項目都將由 DataLoader
迭代器產生。 當num_workers > 0
時,每個工作進程將具有數(shù)據(jù)集對象的不同副本,因此通常需要獨立配置每個副本,以避免從工作進程返回重復的數(shù)據(jù)。 get_worker_info()
在工作程序進程中調用時,返回有關工作程序的信息。 可以在數(shù)據(jù)集的__iter__()
方法或 DataLoader
的worker_init_fn
選項中使用它來修改每個副本的行為。
示例 1:在__iter__()
中將工作負載分配給所有工作人員:
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
... def __init__(self, start, end):
... super(MyIterableDataset).__init__()
... assert end > start, "this example code only works with end >= start"
... self.start = start
... self.end = end
...
... def __iter__(self):
... worker_info = torch.utils.data.get_worker_info()
... if worker_info is None: # single-process data loading, return the full iterator
... iter_start = self.start
... iter_end = self.end
... else: # in a worker process
... # split workload
... per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
... worker_id = worker_info.id
... iter_start = self.start + worker_id * per_worker
... iter_end = min(iter_start + per_worker, self.end)
... return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)
>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>> # Mult-process loading with two worker processes
>>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 5, 4, 6]
>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
[3, 4, 5, 6]
示例 2:使用worker_init_fn
在所有工作人員之間分配工作量:
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
... def __init__(self, start, end):
... super(MyIterableDataset).__init__()
... assert end > start, "this example code only works with end >= start"
... self.start = start
... self.end = end
...
... def __iter__(self):
... return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)
>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>>
>>> # Directly doing multi-process loading yields duplicate data
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 3, 4, 4, 5, 5, 6, 6]
>>> # Define a `worker_init_fn` that configures each dataset differently
>>> def worker_init_fn(worker_id):
... worker_info = torch.utils.data.get_worker_info()
... dataset = worker_info.dataset # the dataset copy in this worker process
... overall_start = dataset.start
... overall_end = dataset.end
... # configure the dataset to only process the split workload
... per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
... worker_id = worker_info.id
... dataset.start = overall_start + worker_id * per_worker
... dataset.end = min(dataset.start + per_worker, overall_end)
...
>>> # Mult-process loading with the custom `worker_init_fn`
>>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[3, 5, 4, 6]
>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
[3, 4, 5, 6]
class torch.utils.data.TensorDataset(*tensors)?
數(shù)據(jù)集包裝張量。
每個樣本將通過沿第一維索引張量來檢索。
Parameters
張量 (tensor*)–具有與第一維相同大小的張量。
class torch.utils.data.ConcatDataset(datasets)?
數(shù)據(jù)集是多個數(shù)據(jù)集的串聯(lián)。
此類對于組裝不同的現(xiàn)有數(shù)據(jù)集很有用。
Parameters
數(shù)據(jù)集(序列)–要連接的數(shù)據(jù)集列表
class torch.utils.data.ChainDataset(datasets)?
用于鏈接多個 IterableDataset
的數(shù)據(jù)集。
此類對于組裝不同的現(xiàn)有數(shù)據(jù)集流很有用。 鏈接操作是即時完成的,因此將大型數(shù)據(jù)集與此類連接起來將非常有效。
Parameters
數(shù)據(jù)集(IterableDataset 的可迭代)–鏈接在一起的數(shù)據(jù)集
class torch.utils.data.Subset(dataset, indices)?
指定索引處的數(shù)據(jù)集子集。
Parameters
torch.utils.data.get_worker_info()?
返回有關當前 DataLoader
迭代器工作進程的信息。
在工作線程中調用時,此方法返回一個保證具有以下屬性的對象:
id
:當前工作人員 ID。num_workers
:工人總數(shù)。seed
:當前工作程序的隨機種子集。 該值由主進程 RNG 和工作程序 ID 確定。 有關更多詳細信息,請參見 DataLoader
的文檔。dataset
:此流程在中的數(shù)據(jù)集對象的副本。 請注意,在不同的過程中,這將是與主過程中的對象不同的對象。
在主進程中調用時,將返回None
。
Note
在傳遞給 DataLoader
的worker_init_fn
中使用時,此方法可用于不同地設置每個工作進程,例如,使用worker_id
將dataset
對象配置為僅讀取 分片數(shù)據(jù)集的特定部分,或使用seed
播種數(shù)據(jù)集代碼中使用的其他庫(例如 NumPy)。
torch.utils.data.random_split(dataset, lengths)?
將數(shù)據(jù)集隨機拆分為給定長度的不重疊的新數(shù)據(jù)集。
Parameters
class torch.utils.data.Sampler(data_source)?
所有采樣器的基類。
每個 Sampler 子類都必須提供__iter__()
方法(提供一種對數(shù)據(jù)集元素的索引進行迭代的方法)和__len__()
方法,該方法返回返回的迭代器的長度。
Note
DataLoader
并非嚴格要求__len__()
方法,但在涉及 DataLoader
長度的任何計算中都應采用。
class torch.utils.data.SequentialSampler(data_source)?
始終以相同順序順序采樣元素。
Parameters
data_source (數(shù)據(jù)集)–要從中采樣的數(shù)據(jù)集
class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None)?
隨機采樣元素。 如果不進行替換,則從經過改組的數(shù)據(jù)集中采樣。 如果要更換,則用戶可以指定num_samples
進行繪制。
Parameters
True
為默認值,則替換為True
True
時才應指定此參數(shù)。class torch.utils.data.SubsetRandomSampler(indices)?
從給定的索引列表中隨機抽樣元素,而無需替換。
Parameters
索引(序列)–索引序列
class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)?
以給定的概率(權重)從[0,..,len(weights)-1]
中采樣元素。
Parameters
True
,則抽取替代品抽取樣品。 如果沒有,則它們將被替換而不會被繪制,這意味著當為一行繪制樣本索引時,無法為該行再次繪制它。例
>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
[0, 0, 0, 1, 0]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
[0, 1, 4, 3, 2]
class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)?
包裝另一個采樣器以產生一個小批量的索引。
Parameters
True
,則采樣器將丟棄最后一批,如果其大小小于batch_size
Example
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True)?
將數(shù)據(jù)加載限制為數(shù)據(jù)集子集的采樣器。
與 torch.nn.parallel.DistributedDataParallel
結合使用時特別有用。 在這種情況下,每個進程都可以將 DistributedSampler 實例作為 DataLoader 采樣器傳遞,并加載原始數(shù)據(jù)集的專有子集。
Note
假定數(shù)據(jù)集大小恒定。
Parameters
更多建議: