首頁 > 軟體

PyTorch Dataset與DataLoader使用超詳細講解

2022-10-16 14:01:47

一、Dataset

Dataset 類提供一種方式去獲取資料及其標籤

主要有兩個目的:

  • 獲取每一個資料及其標籤
  • 獲取資料的總量大小

1. 在控制檯進行操作

Hymenoptera (膜翅目昆蟲)資料集下載地址:

連結: https://pan.baidu.com/s/1XKwXsAtE2yzZW2IsvBDpnw?pwd=8a5t

提取碼: 8a5t 

這是一個螞蟻蜜蜂二分類的資料集,通常資料集有以下三種組織形式(上面的資料集屬於第一種):

  • 不同的類別以資料夾的形式存在,資料夾中是該類別的圖片
  • 圖片與標籤分別儲存,圖片在一個資料夾下,label資訊在另一個資料夾下
  • label直接寫在圖片名稱裡

①獲取圖片的基本資訊

在Pycharm 中,點選下方的PythonConsole進入控制檯進行操作(通過控制檯可以看到變數的詳細資訊)

首先載入圖片,逐行輸入下方程式碼:

from PIL import Image
img_path = "./dataset/hymenoptera_data/train/ants/0013035.jpg"
img = Image.open(img_path)

此時我們就可以在右側看到相關變數的資訊:

點選img變數,可以檢檢視片的詳細資訊。通過控制檯執行程式能夠直觀地獲取後續操作所需的資料:

最後可以通過img.show()開啟圖片檢視:

②獲取檔案的基本資訊

同樣還是在控制檯逐行輸入以下程式碼:

dir_path = "dataset/hymenoptera_data/train/ants"
import os
img_path_list = os.listdir(dir_path)
img_path_list[0]

我們就可以獲取到資料夾下的檔名稱,由於是使用控制檯,我們還可以在右側檢視列表的詳細資訊:

因此在控制檯操作是有很大的優點的,我們可以在控制檯逐行執行已經編寫好的檔案中的語句,通過檢視右側變數的值來判斷程式寫的是否有問題

2. 編寫一個繼承Dataset 的類載入資料

下面的程式碼也可以在控制檯執行(可以多行復制貼上)來檢驗程式是否有誤

①定義 MyData類

匯入所需標頭檔案:

from torch.utils.data import Dataset
from PIL import Image
import os

定義MyData類:

  • __init__:初始化函數
  • __getitem__:返回指定下標的圖片和標籤
  • __len__:返回資料集的大小
class MyData(Dataset):
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path = os.listdir(self.path)
    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label
    def __len__(self):
        return len(self.img_path)

其中os.path.join()可以實現多個路徑的合併且不出錯

②建立類的範例並呼叫

建立 MyData 類的範例:

if __name__ == "__main__":
    root_dir = "../dataset/hymenoptera_data/train"
    ants_label_dir = "ants"
    bees_label_dir = "bees"
    ants_dataset = MyData(root_dir, ants_label_dir)
    bees_dataset = MyData(root_dir, bees_label_dir)

呼叫類中寫好的函數:

    img, label = ants_dataset.__getitem__(3)
    print(ants_dataset.__len__(), label)
    img.show()

同時我們也可以通過下面這種方式用已有的資料集來創造資料集:

train_dataset = ants_dataset + bees_dataset

二、DataLoader

  • DataLoader 類是為後面的網路提供不同的資料形式
  • DataLoader 會根據batch_size的值對資料進行打包
  • 匯入所需的包
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

載入資料:

test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

測試:

img, target = test_data[0]
print(img.shape)
print(target)

進行紀錄檔記錄,開始訓練:

writer = SummaryWriter("dataloader")
for epoch in range(2):
    step = 0
    for data in test_loader:
        imgs, targets = data
        print(imgs.shape)
        print(targets)
        writer.add_images("Epoch: {}".format(epoch), imgs, step)
        step = step + 1
writer.close()

到此這篇關於PyTorch Dataset與DataLoader使用超詳細講解的文章就介紹到這了,更多相關PyTorch Dataset與DataLoader內容請搜尋it145.com以前的文章或繼續瀏覽下面的相關文章希望大家以後多多支援it145.com!


IT145.com E-mail:sddin#qq.com