首頁 > 軟體

PyTorch如何建立自己的資料集

2022-11-28 22:01:49

PyTorch建立自己的資料集

圖片檔案在同一的資料夾下

思路是繼承 torch.utils.data.Dataset,並重點重寫其 __getitem__方法,範例程式碼如下:

class ImageFolder(Dataset):
    def __init__(self, folder_path):
        self.files = sorted(glob.glob('%s/*.*' % folder_path))

    def __getitem__(self, index):
        path = self.files[index % len(self.files)]
        img = np.array(Image.open(path))
        h, w, c = img.shape
        pad = ((40, 40), (4, 4), (0, 0))

        # img = np.pad(img, pad, 'constant', constant_values=0) / 255
        img = np.pad(img, pad, mode='edge') / 255.0
        img = torch.from_numpy(img).float()
        patches = np.reshape(img, (3, 10, 128, 11, 128))
        patches = np.transpose(patches, (0, 1, 3, 2, 4))

        return img, patches, path

    def __len__(self):
        return len(self.files)

圖片檔案在不同的資料夾下

比如我們有資料如下:

─── data
├── train
│ ├── 0.jpg
│ └── 1.jpg
├── test
│ ├── 0.jpg
│ └── 1.jpg
└── val
├── 1.jpg
└── 2.jpg

此時我們只需要將以上程式碼稍作修改即可,修改的程式碼如下:

self.files = sorted(glob.glob('%s/**/*.*' % folder_path, recursive=True))

其他程式碼不變。

pytorch常用資料集的使用

對於pytorch資料集的使用,範例程式碼如下:

from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import Compose
from torchvision import transforms
import torchvision
import ssl

ssl._create_default_https_context = ssl._create_unverified_context

dataset_transform = Compose([transforms.ToTensor()])


# 關於官方資料集的使用還是關鍵要看pytorch的官方檔案
train_set = torchvision.datasets.CIFAR10(root="./CIFAR10",train=True,transform=dataset_transform,download=True)
test_set = torchvision.datasets.CIFAR10(root="./CIFAR10",train=False,transform=dataset_transform,download=True)

# 檢視測試資料集中的第一個資料
# print(test_set[0])
# 檢視測試資料集中的分類情況
# print(test_set.classes)
#
# 取出第一個資料中的圖片(img)和分類結果(target)
# img,target = test_set[0]
# 檢檢視片資料的型別
# print(img)
# print(target)
# 輸出類別
# print(test_set.classes[target])
# 檢檢視片
# img.show()

# 使用tensorboard顯示tensor資料型別的圖片
writer = SummaryWriter("logs")
for i in range(10):
	# 取出資料中的圖片(img)和分類結果(target)
    img,target = test_set[i]
    writer.add_image("test_set",img,i)

writer.close()

上述程式碼執行結果在tensorboard視覺化:

程式碼

train_set = torchvision.datasets.CIFAR10(root="./CIFAR10",train=True,transform=dataset_transform,download=True)

常用引數講解

  • root:根目錄,存放資料集的位置
  • train:若為True,則劃分為訓練資料集,若為False,則劃分為測試資料集
  • transform:指定輸入資料集處理方式
  • download:若為True,則會將資料集下載到root指定的目錄下,否則不會下載

官方檔案對引數的解釋:

root (string) – Root directory of dataset where directory cifar-10-batches-py exists or will be saved to if download is set to True.

train (bool, optional) – If True, creates dataset from training set, otherwise creates from test set.

transform (callable, optional) – A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCrop

target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

download (bool, optional) – If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.

注意:

  • 關於官方資料集的使用還是關鍵要看pytorch的官方檔案
  • 下載資料集的細節之處:知道下載連結(下載連結可以在原始碼中檢視)之後可以不用使用程式碼下載了,使用迅雷來下載可能會更快。
  • 要學會使用Pycharm中的ctrl+p和ctrl+alt這兩個快捷鍵
  • pytorch官網
  • pytorch官方資料集(下載資料集方法)

以上為個人經驗,希望能給大家一個參考,也希望大家多多支援it145.com。


IT145.com E-mail:sddin#qq.com