首頁 > 軟體

pytorch載入自己的資料集原始碼分享

2022-08-16 14:02:46

一、標準的資料集流程梳理

分為幾個步驟
資料準備以及載入資料庫–>資料載入器的呼叫或者設計–>批次呼叫進行訓練或者其他作用

資料來源

直接讀取了x和y的資料變數,對比後面的就從把對應的路徑寫進了文字檔案中,通過載入器進行讀取

x = torch.linspace(1, 10, 10)   # 訓練資料 linspace返回一個一維的張量,(最小值,最大值,多少個數)
print(x)
y = torch.linspace(10, 1, 10)   # 標籤
print(y)

將資料載入進資料庫

輸出的結果是<torch.utils.data.dataset.TensorDataset object at 0x00000145BD93F1C0>,需要使用載入器進行載入,才能迭代遍歷

import torch.utils.data as Data
torch_dataset = Data.TensorDataset(x, y)  # 對給定的 tensor 資料,將他們包裝成 dataset
#輸出的結果是<torch.utils.data.dataset.TensorDataset object at 0x00000145BD93F1C0>,需要使用載入器進行載入,才能迭代遍歷
print(torch_dataset)

所以要想看裡面的內容,就需要用迭代進行操作或者檢視。

BATCH_SIZE=5
loader = Data.DataLoader(#使用支援的預設的資料集載入的方式
    # 從資料庫中每次抽出batch size個樣本
    dataset=torch_dataset,       # torch TensorDataset format   載入資料集
    batch_size=BATCH_SIZE,       # mini batch size 5
    shuffle=False,                # 要不要打亂資料 (打亂比較好)
    num_workers=2,               # 多執行緒來讀資料
)
 
def show_batch():
    for epoch in range(3):
        for step, (batch_x, batch_y) in enumerate(loader): #載入資料集的時候起的作用很奇怪
            # training
            print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))
            print("*"*100)
if __name__ == '__main__':
    show_batch()

二、實現載入自己的資料集

實現自己的資料集就需要完成對dataset類的過載。這個類的過載完成幾個函數的作用

  • 初始化資料集中的資料以及標籤__init__()
  • 返回資料和對應標籤__getitem__
  • 返回資料集的大小__len__

基本的資料集的方法就是完成以上步驟,但是可以想想資料集通常是一些圖片和標籤組成,而這些資料集以及標籤是儲存在計算機上,具有相對應的位置,那麼直接存取對應的位置因為是在資料夾下需要進行遍歷等一系列操作,而且這就顯得和dataset類沒有解耦,因為有時候在這些位置的操作可能會有一些特殊操作,所以如果能夠將其位置儲存在文字檔案中可能就會方便很多,所以就採取儲存文字檔案的方式。

# 自定義資料集類
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, *args):
        super().__init__()
        # 初始化資料集包含的資料和標籤
        pass
        
    def __getitem__(self, index):
        # 根據索引index從檔案中讀取一個資料
        # 對資料預處理
        # 返回資料和對應標籤
        pass
    
    def __len__(self):
        # 返回資料集的大小
        return len()

1. 儲存在txt檔案中(生成訓練集和測試集,其實這裡的訓練集以及測試集也都是用文字檔案的形式儲存下來的)

所以這裡新建一個資料庫就是新建了兩個文字檔案,然後載入器通過文字檔案就將圖片以及label載入進去了。而標準的資料集操作是使用了自帶的資料集介面,在載入的時候也不用再去實現相關的__getitem__方法

  • 陣列定義
  • 將絕對路徑載入進陣列中
  • 陣列定義
  • 將絕對路徑載入進陣列中
  • 通過os.walk操作
  • os.walk可以獲得根路徑、資料夾以及檔案,並會一直進行迭代遍歷下去,直至只有檔案才會結束
  • 將陣列的內容打亂順序
  • 分別將絕對路徑對應的陣列內容寫進文字檔案裡,那麼這裡的文字檔案就是儲存的資料庫,其實資料就是一個儲存相關資訊或者其內容的檔案,而標準也是將將其資料儲存在了一個地方,然後對應到標準介面就可以載入了(Data.TensorDataset以及Data.DataLoader)

以下程式碼用於生成對應的train.txt val.txt

'''
生成訓練集和測試集,儲存在txt檔案中
'''
import os
import random


train_ratio = 0.6


test_ratio = 1-train_ratio

rootdata = r"dataset"

#陣列定義
train_list, test_list = [],[]
data_list = []

class_flag = -1
# 將絕對路徑載入進陣列中
for a,b,c in os.walk(rootdata):#os.walk可以獲得根路徑、資料夾以及檔案,並會一直進行迭代遍歷下去,直至只有檔案才會結束
    print(a)
    for i in range(len(c)):
        data_list.append(os.path.join(a,c[i]))

    for i in range(0,int(len(c)*train_ratio)):
        train_data = os.path.join(a, c[i])+'t'+str(class_flag)+'n' #class_flag表示分類的類別
        train_list.append(train_data)

    for i in range(int(len(c) * train_ratio),len(c)):
        test_data = os.path.join(a, c[i]) + 't' + str(class_flag)+'n'
        test_list.append(test_data)

    class_flag += 1 

print(train_list)
# 將陣列的內容打亂順序
random.shuffle(train_list)
random.shuffle(test_list)

#分別將絕對路徑對應的陣列內容寫進文字檔案裡
with open('train.txt','w',encoding='UTF-8') as f:
    for train_img in train_list:
        f.write(str(train_img))

with open('test.txt','w',encoding='UTF-8') as f:
    for test_img in test_list:
        f.write(test_img)

2. 在繼承dataset類LoadData的三個函數裡呼叫train.txt以及test.txt實現相關功能

初始化資料集中的資料以及標籤、相關變數__init__()

def __init__(self, txt_path, train_flag=True):
     #初始化圖片對應的變數imgs_info以及一些相關變數
     self.imgs_info = self.get_images(txt_path) #imgs_info儲存了圖片以及標籤
     self.train_flag = train_flag

     self.train_tf = transforms.Compose([#對訓練集的圖片進行預處理
             transforms.Resize(224),
             transforms.RandomHorizontalFlip(),
             transforms.RandomVerticalFlip(),
             transforms.ToTensor(),
             transform_BZ
         ])
     self.val_tf = transforms.Compose([#對測試集的圖片進行預處理
             transforms.Resize(224),
             transforms.ToTensor(),
             transform_BZ
         ])

返回資料對應標籤__getitem__

def __getitem__(self, index):
     img_path, label = self.imgs_info[index]
     #開啟圖片,並將RGBA轉換為RGB,這裡是通過PIL庫開啟圖片的
     img = Image.open(img_path)
     img = img.convert('RGB')
     img = self.padding_black(img) #將圖片新增上黑邊的
     if self.train_flag: #選擇是訓練集還是測試集
         img = self.train_tf(img)
     else:
         img = self.val_tf(img)
     label = int(label)

     return img, label

返回資料集的大小__len__

def __len__(self):
     return len(self.imgs_info)

由於前面已經對整合dataset的類進行了實現三種方法,那麼就可以在載入器中進行載入,將載入後的資料傳入到train函數或者test函數都可以

  • train_dataloader = DataLoader(dataset=train_data, num_workers=4, pin_memory=True, batch_size=batch_size, shuffle=True):使用載入器載入資料
  • train(train_dataloader, model, loss_fn, optimizer) test(test_dataloader, model):將資料傳入train或者test中進行訓練或者測試
  • 注意:LoadData是繼承了dataset的類
if __name__=='__main__':
    batch_size = 16

    # # 給訓練集和測試集分別建立一個資料集載入器
    train_data = LoadData("train.txt", True)
    valid_data = LoadData("test.txt", False)


    train_dataloader = DataLoader(dataset=train_data, num_workers=4, pin_memory=True, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(dataset=valid_data, num_workers=4, pin_memory=True, batch_size=batch_size)

    for X, y in test_dataloader:
        print("Shape of X [N, C, H, W]: ", X.shape)
        print("Shape of y: ", y.shape, y.dtype)
        break

三、原始碼

連結: https://pan.baidu.com/s/19Oo87gbcm9e8zvYGkBi95A 提取碼: 2tss 

到此這篇關於pytorch載入自己的資料集原始碼分享的文章就介紹到這了,更多相關pytorch載入自己的資料集內容請搜尋it145.com以前的文章或繼續瀏覽下面的相關文章希望大家以後多多支援it145.com!


IT145.com E-mail:sddin#qq.com