首頁 > 軟體

python中關於CIFAR10資料集的使用

2023-02-03 18:00:42

關於CIFAR10資料集的使用

主要解決了如何把資料集與transforms結合在一起的問題。

CIFAR10的官方解釋

torchvision.datasets.CIFAR10(
root: str, 
train: bool = True, 
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False)

註釋:

  • root (string)存在 cifar-10-batches-py 目錄的資料集的根目錄,如果下載設定為 True,則將儲存到該目錄。
  • train (bool, optional)如果為True,則從訓練集建立資料集, 如果為False,從測試集建立資料集。
  • transform (callable, optional)它接受一個 PIL 影象並返回一個轉換後的版本。 例如,transforms.RandomCrop/transforms.ToTensor
  • target_transform (callable, optional) 接收目標並對其進行轉換的函數/轉換。
  • download (bool, optional)如果為 true,則從 Internet 下載資料集並將其放在根目錄中。 如果資料集已經下載,則不會再次下載。

實戰操作

1.CIAFR10資料集的下載

程式碼如下:

import torchvision   #匯入torchvision這個類

train_set = torchvision.datasets.CIFAR10(root = "./dataset", train = True, 
download= True)  #從訓練集建立資料集
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False,
 download=True)    #從測試集建立資料集

root = "./dataset",將下載的資料集儲存在這個資料夾下;download= True,從 Internet 下載資料集並將其放在根目錄中,這裡就是在相對路徑中,建立dataset資料夾,將資料集儲存在dataset中。

2.檢視下載的CIAFR10資料集

執行程式,開始下載資料集。下載成功後,可以進行一些檢視。程式碼如下:

接著輸入:

print(train_set[0])  #檢視train_set訓練集中的第一個資料
print(train_set.classes)   #檢視train_set訓練集中有多少個類別
 
img, target = train_set[0]
print(img)
print(target)
print(train_set.classes[target])
img.show()  #顯示圖片

輸出結果:

(<PIL.Image.Image image mode=RGB size=32x32 at 0x161E924B8D0>, 6)
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship',
'truck']
<PIL.Image.Image image mode=RGB size=32x32 at 0x161E924B710>
6
frog

註釋:可以看見,train_set資料集中有10個類別,train_set中第0個元素的target是6,也就是說,這個元素是屬於第7個類別frog的。

3.資料轉換

因為這些圖片型別都是PIL Image,如果要供給pytorch使用的話,需要將資料全都轉化成tensor型別。

完整程式碼如下:

import torchvision   #匯入torchvision這個類
from torch.utils.tensorboard import SummaryWriter

from torchvision import transforms
dataset_transforms = transforms.ToTensor()

# dataset_transforms = torchvision.transforms.Compose([
#     torchvision.transforms.ToTensor()
# ])    第3  4 行程式碼可以用compose直接寫
train_set = torchvision.datasets.CIFAR10(root = "./dataset", train = True, transform=dataset_transforms, download= True) #訓練集
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transforms, download=True)   #測試集

writer = SummaryWriter("logs")

# print(train_set[0])  #檢視train_set訓練集中的第一個資料
# print(train_set.classes)   #檢視train_set訓練集中有多少個類別

# img, target = train_set[0]
# print(img)
# print(target)
# print(train_set.classes[target])
# img.show()
for i in range(20):
    img, target = train_set[i]
    writer.add_image("cifar10_test2", img, i)

writer.close()

小結:CIFAR10資料集記憶體很小,只有100多m,下載方便。對我們學習資料集非常友好,練習的時候,我們可以使用SummaryWriter來將資料寫入tensorboard中。

CIFAR-10 資料集簡介

復現程式碼的過程中,簡單瞭解了作者使用的資料集CIFAR-10 dataset ,簡單記錄一下。

CIFAR-10資料集是8000萬微小圖片的標籤子集,它的收集者是:Alex Krizhevsky, Vinod Nair, Geoffrey Hinton。

資料集由6萬張32*32的彩色圖片組成,一共有10個類別。每個類別6000張圖片。其中有5萬張訓練圖片及1萬張測試圖片。

資料集被劃分為5個訓練塊和1個測試塊,每個塊1萬張圖片。

測試塊包含了1000張從每個類別中隨機選擇的圖片。訓練塊包含隨機的剩餘影象,但某些訓練塊可能對於一個類別的包含多於其他類別,訓練塊包含來自各個類別的5000張圖片。

這些類是完全互斥的,及在一個類別中出現的圖片不會出現在其它類中。

資料集版本

作者提供了3個版本的資料集:python version; Matlab version; binary version。

可根據自己的需求選擇。

資料集下載地址:下載連結

資料集佈置

以python version進行介紹,Matlab version與之相同。

下載後獲得檔案 data_batch_1, data_batch_2,…, data_batch_5。測試塊相同。這些檔案中的每一個都是用cPickle生成的python pickled物件。

具體使用方法:

def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

返回字典類,每個塊的檔案包含一個字典類,包含以下元素:

  • data: 一個100003072的numpy陣列(unit8)每個行儲存3232的彩色圖片,3072=1024*3,分別是red, green, blue。儲存方式以行為主。
  • labels:使用0-9進行索引。

資料集包含的另一個檔案batches.meta同樣包含python字典,用於載入label_names。如:label_names[0] == “airplane”, label_names[1] == “automobile”

總結

以上為個人經驗,希望能給大家一個參考,也希望大家多多支援it145.com。


IT145.com E-mail:sddin#qq.com