首頁 > 軟體

Pytorch技法之繼承Subset類完成自定義資料拆分

2022-02-20 19:00:06

我們在 《torch.utils.data.DataLoader與迭代器轉換操作》 中介紹瞭如何使用Pytorch內建的資料集進行論文實驗,如 torchvision.datasets 。下面是載入內建訓練資料集的常見操作:

from torchvision.datasets import FashionMNIST
from torchvision.transforms import Compose, ToTensor, Normalize
RAW_DATA_PATH = './rawdata'
transform = Compose(
        [ToTensor(),
         Normalize((0.1307,), (0.3081,))
         ]
    )
train_data = FashionMNIST(
        root=RAW_DATA_PATH,
        download=True,
        train=True,
        transform=transform
    )

這裡的train_data 做為 dataset 物件,它擁有許多熟悉,我們可以通過以下方法獲取樣本資料的分類類別集合、樣本的特徵維度、樣本的標籤集合等資訊。

classes = train_data.classes
num_features = train_data.data[0].shape[0]
train_labels = train_data.targets

print(classes)
print(num_features)
print(train_labels)

輸出如下:

['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
28
tensor([9, 0, 0,  ..., 3, 0, 5])

但是,我們常常會在訓練集的基礎上拆分出驗證集(或者只用部分資料來進行訓練)。我們想到的第一個方法是使用 torch.utils.data.random_splitdataset 進行劃分,下面我們假設劃分10000個樣本做為訓練集,其餘樣本做為驗證集:

from torch.utils.data import random_split
k = 10000
train_data, valid_data = random_split(train_data, [k, len(train_data)-k])

注意我們如果列印 train_data 和 valid_data 的型別,可以看到顯示:

<class 'torch.utils.data.dataset.Subset'>

已經不再是torchvision.datasets.mnist.FashionMNIST 物件,而是一個所謂的 Subset 物件!此時 Subset 物件雖然仍然還存有 data 屬性,但是內建的 target classes 屬性已經不復存在,

比如如果我們強行存取 valid_data 的 target 屬性:

valid_target = valid_data.target

就會報如下錯誤:

'Subset' object has no attribute 'target'

但如果我們在後續的程式碼中常常會將拆分後的資料集也預設為 dataset 物件,那麼該如何做到程式碼的一致性呢?

這裡有一個trick,那就是以繼承 SubSet 類的方式的方式定義一個新的 CustomSubSet 類,使新類在保持 SubSet 類的基本屬性的基礎上,擁有和原本資料集類相似的屬性,如 targets classes 等:

from torch.utils.data import Subset
class CustomSubset(Subset):
    '''A custom subset class'''
    def __init__(self, dataset, indices):
        super().__init__(dataset, indices)
        self.targets = dataset.targets # 保留targets屬性
        self.classes = dataset.classes # 保留classes屬性

    def __getitem__(self, idx): #同時支援索引存取操作
        x, y = self.dataset[self.indices[idx]]      
        return x, y 

    def __len__(self): # 同時支援取長度操作
        return len(self.indices)

然後就引出了第二種劃分方法,即通過初始化 CustomSubset 物件的方式直接對資料集進行劃分(這裡為了簡化省略了shuffle的步驟):

import numpy as np
from copy import deepcopy
origin_data = deepcopy(train_data)
train_data = CustomSubset(origin_data, np.arange(k))
valid_data = CustomSubset(origin_data, np.arange(k, len(origin_data))-k)

注意: CustomSubset 類的初始化方法的第二個引數 indices 為樣本索引,我們可以通過 np.arange() 的方法來建立。

然後,我們再存取 valid_data 對應的 classes 和 targes 屬性:

print(valid_data.classes)
print(valid_data.targets)

此時,我們發現可以成功存取這些屬性了:

['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
tensor([9, 0, 0,  ..., 3, 0, 5])

當然, CustomSubset 的作用並不只是新增資料集的屬性,我們還可以自定義一些資料預處理操作。

我們將類的結構修改如下:

class CustomSubset(Subset):
    '''A custom subset class with customizable data transformation'''
    def __init__(self, dataset, indices, subset_transform=None):
        super().__init__(dataset, indices)
        self.targets = dataset.targets
        self.classes = dataset.classes
        self.subset_transform = subset_transform

    def __getitem__(self, idx):
        x, y = self.dataset[self.indices[idx]]
        
        if self.subset_transform:
            x = self.subset_transform(x)
      
        return x, y   
    
    def __len__(self): 
        return len(self.indices)

我們可以在使用樣本前設定好資料預處理運算元:

from torchvision import transforms
valid_data.subset_transform = transforms.Compose(
    [transforms.RandomRotation((180,180))])

這樣,我們再像下列這樣用索引存取取出資料集樣本時,就會自動呼叫運算元完成預處理操作:

print(valid_data[0])

列印結果縮略如下:

(tensor([[[-0.4242, -0.4242, -0.4242, ......-0.4242, -0.4242, -0.4242, -0.4242, -0.4242]]]), 9)

 到此這篇關於Pytorch技法之繼承Subset類完成自定義資料拆分的文章就介紹到這了,更多相關繼承Subset類完成自定義資料拆分內容請搜尋it145.com以前的文章或繼續瀏覽下面的相關文章希望大家以後多多支援it145.com!


IT145.com E-mail:sddin#qq.com