首頁 > 軟體

pytorch載入自己的圖片資料集的2種方法詳解

2022-06-11 14:00:37

pytorch載入圖片資料集有兩種方法。

1.ImageFolder 適合於分類資料集,並且每一個類別的圖片在同一個資料夾, ImageFolder載入的資料集, 訓練資料為檔案件下的圖片, 訓練標籤是對應的資料夾, 每個資料夾為一個類別

匯入ImageFolder()包
from torchvision.datasets import ImageFolder

在Flower_Orig_dataset資料夾下有flower_orig 和 sunflower這兩個資料夾, 這兩個資料夾下放著同一個類別的圖片。 使用 ImageFolder 載入的圖片, 就會返回圖片資訊和對應的label資訊, 但是label資訊是根據資料夾給出的, 如flower_orig就是標籤0, sunflower就是標籤1。

ImageFolder 載入資料集

1. 匯入包和設定transform

import torch
from torchvision import transforms, datasets
import torch.nn as nn
from torch.utils.data import DataLoader
 
transforms = transforms.Compose([
    transforms.Resize(256),    # 將圖片短邊縮放至256,長寬比保持不變:
    transforms.CenterCrop(224),   #將圖片從中心切剪成3*224*224大小的圖片
    transforms.ToTensor()          #把圖片進行歸一化,並把資料轉換成Tensor型別
]) 

2.載入資料集: 將分類圖片的父目錄作為路徑傳遞給ImageFolder(), 並傳入transform。這樣就有了要載入的資料集, 之後就可以使用DataLoader載入資料, 並構建網路訓練。

path = r'D:資料集Flower_Orig_dataset'
data_train = datasets.ImageFolder(path, transform=transforms)
data_loader = DataLoader(data_train, batch_size=64, shuffle=True)
for i, data in enumerate(data_loader):
    images, labels = data
    print(images.shape)
    print(labels.shape)
    break

使用pytorch提供的Dataset類建立自己的資料集。

具體步驟:

1.  首先要有一個txt檔案, 這個檔案格式是: 圖片路徑  標籤.  這樣的格式, 所以使用os庫, 遍歷自己的圖片名, 並把標籤和圖片路徑寫入txt檔案。

2. 有了這個txt檔案, 我們就可以在類裡面構造我們的資料集.

2.1    把圖片路徑和圖片標籤分割開, 有兩個列表, 一個列表是圖片路徑名, 一個列表是標籤號, 有一點就是第 i 個圖片列表和 第 i 個標籤是對應的

3. 重寫__len__方法  和  __getitem__方法

3.1 getitem方法中, 獲得對應的圖片路徑,並用PIL庫讀取檔案把圖片transfrom後, 在getitem函數中返回讀取的圖片和標籤即可

4.就可以構建資料集範例和載入資料集.

 定義一個用來生成[ 圖片路徑 標籤] 這樣的txt檔案函數

def make_txt(root, file_name, label):
    path = os.path.join(root, file_name)
    data = os.listdir(path)
    f = open(path+'\'+'f.txt', 'w')
    for line in data:
        f.write(line+' '+str(label)+'n')
    f.close()
#呼叫函數生成兩個資料夾下的txt檔案
make_txt(path, file_name='flower_orig', label=0)
make_txt(path, file_name='sunflower', label=1)

將連個txt檔案合併成一個txt檔案,表示資料集所有的圖片和標籤

def link_txt(file1, file2):
    txt_list = []
    path = r'D:資料集Flower_Orig_datasetdata.txt'
 
    f = open(path, 'a')
 
    f1 = open(file1, 'r')
    data1 = f1.readlines()
    for line in data1:
        txt_list.append(line)
 
    f2 = open(file2, 'r')
    data2 = f2.readlines()
    for line in data2:
        txt_list.append(line)
 
    for line in txt_list:
        f.write(line)
 
    f.close()
    f1.close()
    f2.close()
 
#呼叫函數, 將兩個資料夾下的txt檔案合併
file1 = r'D:資料集Flower_Orig_datasetflower_origf.txt'
file2 = r'D:資料集Flower_Orig_datasetsunflowerf.txt'
link_txt(file1=file1, file2=file2)

現在我們已經有了我們製作資料集所需要的txt檔案, 接下來要做的即使繼承Dataset類, 來構建自己的資料集 , 別忘了前面說的 構建資料集步驟, 在__getitem__函數中, 需要拿到圖片路徑和標籤, 並且用PIL庫方法讀取圖片,對圖片進行transform轉換後,返回圖片資訊和標籤資訊

Dataset載入資料集

我們讀取圖片的根目錄, 在根目錄下有所有圖片的txt檔案, 拿到txt檔案後, 先讀取txt檔案, 之後遍歷txt檔案中的每一行, 首先去除掉尾部的換行符, 在以空格切分,前半部分是圖片名稱, 後半部分是圖片標籤, 當圖片名稱和根目錄結合,就得到了我們的圖片路徑   
class MyDataset(Dataset):
    def __init__(self, img_path, transform=None):
        super(MyDataset, self).__init__()
        self.root = img_path
 
        self.txt_root = self.root + 'data.txt'
        f = open(self.txt_root, 'r')
        data = f.readlines()
 
        imgs = []
        labels = []
        for line in data:
            line = line.rstrip()
            word = line.split()
            imgs.append(os.path.join(self.root, word[1], word[0]))
 
            labels.append(word[1])
        self.img = imgs
        self.label = labels
        self.transform = transform
 
    def __len__(self):
        return len(self.label)
 
    def __getitem__(self, item):
        img = self.img[item]
        label = self.label[item]
 
        img = Image.open(img).convert('RGB')
 
        #此時img是PIL.Image型別   label是str型別
 
        if transforms is not None:
            img = self.transform(img)
 
        label = np.array(label).astype(np.int64)
        label = torch.from_numpy(label)
        
        return img, label

 載入我們的資料集:

path = r'D:資料集Flower_Orig_dataset'
dataset = MyDataset(path, transform=transform)
 
data_loader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)

接下來我們就可以構建我們的網路架構:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3,16,3)
        self.maxpool = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(16,5,3)
 
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(55*55*5, 1200)
        self.fc2 = nn.Linear(1200,64)
        self.fc3 = nn.Linear(64,2)
 
    def forward(self,x):
        x = self.maxpool(self.relu(self.conv1(x)))    #113
        x = self.maxpool(self.relu(self.conv2(x)))    #55
        x = x.view(-1, self.num_flat_features(x))
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    
    def num_flat_features(self, x):
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
 
        return num_features
 

 訓練我們的網路:

model = Net()
 
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
 
 
epochs = 10
for epoch in range(epochs):
    running_loss = 0.0
    for i, data in enumerate(data_loader):
        images, label = data
 
        out = model(images)
 
        loss = criterion(out, label)
 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
        running_loss += loss.item()
        if(i+1)%10 == 0:
            print('[%d  %5d]   loss: %.3f'%(epoch+1, i+1, running_loss/100))
            running_loss = 0.0
 
print('finished train')

 儲存網路模型(這裡不止是儲存引數,還儲存了網路結構)

#儲存模型
torch.save(net, 'model_name.pth')   #儲存的是模型, 不止是w和b權重值
 
# 讀取模型
model = torch.load('model_name.pth')

總結

到此這篇關於pytorch載入自己的圖片資料集的2種方法的文章就介紹到這了,更多相關pytorch載入圖片資料集內容請搜尋it145.com以前的文章或繼續瀏覽下面的相關文章希望大家以後多多支援it145.com!


IT145.com E-mail:sddin#qq.com