
Python Machine Learning: Custom Data Loaders in PyTorch

2022-10-14 14:01:30


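Code for processing data samples can become messy and hard to maintain. Ideally, dataset code is decoupled from model training code for better readability and modularity. PyTorch provides two data primitives, torch.utils.data.DataLoader and torch.utils.data.Dataset, that let us use pre-loaded datasets as well as our own data. Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset to give easy access to the samples.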
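
As a quick illustration of how the two primitives relate, here is a minimal sketch that uses synthetic tensors instead of a real dataset. The tensor contents and the variable names are made up for this example; TensorDataset is a ready-made Dataset from torch.utils.data.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in data: 10 samples with 3 features each, plus integer labels.
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10)

# TensorDataset pairs the tensors along their first dimension.
dataset = TensorDataset(features, labels)

# A Dataset supports len() and integer indexing; each item is one (sample, label) pair.
first_sample, first_label = dataset[0]

# A DataLoader wraps the Dataset in an iterable that yields batches.
loader = DataLoader(dataset, batch_size=4)
batch_features, batch_labels = next(iter(loader))

print(len(dataset), first_sample.shape, batch_features.shape)
```

Indexing the Dataset returns a single sample, while iterating the DataLoader returns stacked batches with a leading batch dimension.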

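The PyTorch domain libraries provide a number of pre-loaded datasets (such as FashionMNIST) that subclass torch.utils.data.Dataset and implement functions specific to the particular data. They can be used to prototype and benchmark your model. The domain libraries cover image datasets, text datasets, and audio datasets.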

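1. Loading a dataset

Below is an example of how to load the Fashion-MNIST dataset from TorchVision. Fashion-MNIST is a dataset of Zalando's article images consisting of 60,000 training examples and 10,000 test examples. Each example comprises a 28×28 grayscale image and an associated label from one of 10 classes.

We load the FashionMNIST dataset with the following parameters:

  • root is the path where the training/test data is stored,
  • train specifies whether to load the training or the test dataset,
  • download=True downloads the data from the Internet if it is not available at root,
  • transform and target_transform specify the feature and label transformations.
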
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
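
The transform and target_transform parameters described above accept any callable. The sketch below shows what such callables might look like, applied to a random stand-in sample rather than to a downloaded image. The function names scale_to_unit_range and one_hot_label are hypothetical and are not part of torchvision.

```python
import torch

NUM_CLASSES = 10  # Fashion-MNIST has 10 classes


def scale_to_unit_range(image_uint8):
    """Feature transform: convert a uint8 image tensor to float32 in [0, 1]."""
    return image_uint8.to(torch.float32) / 255.0


def one_hot_label(label):
    """Label transform: turn an integer class index into a one-hot float vector."""
    encoded = torch.zeros(NUM_CLASSES, dtype=torch.float32)
    encoded[label] = 1.0
    return encoded


# Stand-in for one raw sample: a random 1x28x28 uint8 image with class index 3.
raw_image = torch.randint(0, 256, (1, 28, 28), dtype=torch.uint8)
image = scale_to_unit_range(raw_image)
target = one_hot_label(3)

print(image.dtype, target)
```

A function like one_hot_label could be passed as target_transform when constructing the dataset. scale_to_unit_range expects a uint8 tensor, such as the one read_image returns in the CustomImageDataset of section 3; FashionMNIST hands its transform a PIL image instead, which is why the loading code uses ToTensor().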

2. Iterating over and visualizing the dataset

We can index Datasets manually, just like a Python list, for example:

training_data[index]

We use matplotlib to visualize a few samples from the training data.

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

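3. Creating a custom dataset

A custom Dataset class must implement three functions: __init__, __len__, and __getitem__.

In the example below, the FashionMNIST images are stored in a directory img_dir, and their labels are stored separately in a CSV file annotations_file.

In the following sections, we break down what happens in each of these functions.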

import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
    def __len__(self):
        return len(self.img_labels)
    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

3.1 __init__

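The __init__ function runs once, when the Dataset object is instantiated. We initialize the directory containing the images, the annotations file, and both transforms (covered in more detail in the next section).

The labels.csv file looks like this: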

tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file)
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform
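
One detail worth checking: the labels.csv sample shown above has no header row, and with its default settings pd.read_csv treats the first line of a file as column names. If your file really has no header, a sketch like the following keeps the first image as data. The header=None and skipinitialspace=True arguments, the column names, and the in-memory CSV text are assumptions made for this illustration; the class in this article calls pd.read_csv with defaults.

```python
import io

import pandas as pd

# Stand-in for labels.csv, using the same layout as the sample above (no header row).
csv_text = "tshirt1.jpg, 0\ntshirt2.jpg, 0\nankleboot999.jpg, 9\n"

# header=None keeps the first row as data; skipinitialspace drops the blank after each comma.
img_labels = pd.read_csv(
    io.StringIO(csv_text),
    header=None,
    names=["file_name", "label"],  # hypothetical column names
    skipinitialspace=True,
)

# Positional access, as in __getitem__: column 0 is the file name, column 1 the label.
first_name = img_labels.iloc[0, 0]
last_label = int(img_labels.iloc[2, 1])
print(len(img_labels), first_name, last_label)
```

If the annotations file does start with a header line, the default call used in the class is the right one.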

3.2 __len__

The __len__ function returns the number of samples in our dataset.

Example:

def __len__(self):
    return len(self.img_labels)

3.3 __getitem__

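The __getitem__ function loads and returns a sample from the dataset at the given index idx. Based on the index, it identifies the image's location on disk, converts it to a tensor using read_image, retrieves the corresponding label from the CSV data in self.img_labels, calls the transform functions on them (if applicable), and returns the tensor image and the corresponding label as a tuple.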

def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = read_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    return image, label
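
To try the three methods together without image files on disk, the following minimal sketch keeps the data in memory. The class name InMemoryImageDataset, the blank images, and the two lambda transforms are hypothetical stand-ins; only the method structure mirrors CustomImageDataset.

```python
import torch
from torch.utils.data import Dataset


class InMemoryImageDataset(Dataset):
    """Hypothetical dataset holding images and labels in memory instead of on disk."""

    def __init__(self, images, labels, transform=None, target_transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        # Number of samples in the dataset.
        return len(self.labels)

    def __getitem__(self, idx):
        # Fetch one sample, then apply the optional transforms.
        image = self.images[idx]
        label = int(self.labels[idx])
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label


# Five blank 1x28x28 uint8 "images" with labels 0..4 as stand-in data.
images = torch.zeros(5, 1, 28, 28, dtype=torch.uint8)
labels = torch.arange(5)
dataset = InMemoryImageDataset(
    images,
    labels,
    transform=lambda img: img.float() / 255.0,  # uint8 -> float32 in [0, 1]
    target_transform=lambda y: y + 100,  # arbitrary label shift, to make the call visible
)
image, label = dataset[2]
print(len(dataset), image.shape, image.dtype, label)
```

Because the class implements __len__ and __getitem__, it can be passed to a DataLoader in the same way as the built-in datasets.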

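4. Preparing your data for training with DataLoaders

The Dataset retrieves our dataset's features and labels one sample at a time. While training a model, we typically want to pass samples in minibatches, reshuffle the data at every epoch to reduce model overfitting, and use Python's multiprocessing to speed up data retrieval.

DataLoader is an iterable that hides this complexity behind a simple API.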

from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
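
To make the effect of batch_size concrete, here is a small sketch on a synthetic 10-sample dataset (the data and variable names are made up for this example). It shows that the last batch is smaller when the dataset size is not a multiple of batch_size, and that the drop_last argument of DataLoader discards that incomplete batch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic dataset with 10 samples.
dataset = TensorDataset(torch.arange(10).float().unsqueeze(1), torch.arange(10))

# 10 samples with batch_size=4 give batches of 4, 4, and 2.
loader = DataLoader(dataset, batch_size=4, shuffle=True)
batch_sizes = [len(batch_labels) for _, batch_labels in loader]

# drop_last=True discards the final incomplete batch.
loader_full_only = DataLoader(dataset, batch_size=4, shuffle=True, drop_last=True)
full_batch_sizes = [len(batch_labels) for _, batch_labels in loader_full_only]

print(len(loader), batch_sizes)  # 3 [4, 4, 2]
print(len(loader_full_only), full_batch_sizes)  # 2 [4, 4]
```

Calling len() on a DataLoader returns the number of batches, not the number of samples.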

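5. Iterating through the DataLoader

We have loaded the dataset into the DataLoader and can iterate through it as needed. Each iteration below returns a batch of train_features and train_labels (containing batch_size=64 features and labels respectively). Because we specified shuffle=True, the data is shuffled after we iterate over all batches (for finer-grained control over the data loading order, take a look at Samplers).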

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 4
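
The reshuffling behavior of shuffle=True can be observed on a tiny synthetic dataset. In this sketch, the 8-element dataset and the seed value are arbitrary choices for illustration; passing a seeded torch.Generator through the generator argument of DataLoader makes the shuffling reproducible from run to run.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic dataset whose samples are simply the integers 0..7.
dataset = TensorDataset(torch.arange(8))

# A seeded generator makes the shuffling reproducible from run to run.
generator = torch.Generator().manual_seed(0)
loader = DataLoader(dataset, batch_size=4, shuffle=True, generator=generator)

epochs = []
for epoch in range(2):
    order = []
    # Each full pass over the loader is one epoch and draws a new permutation.
    for (batch,) in loader:
        order.extend(batch.tolist())
    epochs.append(order)
    print(f"epoch {epoch}: {order}")
```

Each epoch visits every sample exactly once, but the order is drawn anew every time the loader is iterated.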


