<em>Mac</em>Book项目 2009年学校开始实施<em>Mac</em>Book项目,所有师生配备一本<em>Mac</em>Book,并同步更新了校园无线网络。学校每周进行电脑技术更新,每月发送技术支持资料,极大改变了教学及学习方式。因此2011
2021-06-01 09:32:01
pytorch載入圖片資料集有兩種方法。
1.ImageFolder 適合於分類資料集,並且每一個類別的圖片在同一個資料夾, ImageFolder載入的資料集, 訓練資料為檔案件下的圖片, 訓練標籤是對應的資料夾, 每個資料夾為一個類別
匯入ImageFolder()包 from torchvision.datasets import ImageFolder
在Flower_Orig_dataset資料夾下有flower_orig 和 sunflower這兩個資料夾, 這兩個資料夾下放著同一個類別的圖片。 使用 ImageFolder 載入的圖片, 就會返回圖片資訊和對應的label資訊, 但是label資訊是根據資料夾給出的, 如flower_orig就是標籤0, sunflower就是標籤1。
1. 匯入包和設定transform
import torch from torchvision import transforms, datasets import torch.nn as nn from torch.utils.data import DataLoader transforms = transforms.Compose([ transforms.Resize(256), # 將圖片短邊縮放至256,長寬比保持不變: transforms.CenterCrop(224), #將圖片從中心切剪成3*224*224大小的圖片 transforms.ToTensor() #把圖片進行歸一化,並把資料轉換成Tensor型別 ])
2.載入資料集: 將分類圖片的父目錄作為路徑傳遞給ImageFolder(), 並傳入transform。這樣就有了要載入的資料集, 之後就可以使用DataLoader載入資料, 並構建網路訓練。
path = r'D:資料集Flower_Orig_dataset' data_train = datasets.ImageFolder(path, transform=transforms) data_loader = DataLoader(data_train, batch_size=64, shuffle=True) for i, data in enumerate(data_loader): images, labels = data print(images.shape) print(labels.shape) break
具體步驟:
1. 首先要有一個txt檔案, 這個檔案格式是: 圖片路徑 標籤. 這樣的格式, 所以使用os庫, 遍歷自己的圖片名, 並把標籤和圖片路徑寫入txt檔案。
2. 有了這個txt檔案, 我們就可以在類裡面構造我們的資料集.
2.1 把圖片路徑和圖片標籤分割開, 有兩個列表, 一個列表是圖片路徑名, 一個列表是標籤號, 有一點就是第 i 個圖片列表和 第 i 個標籤是對應的
3. 重寫__len__方法 和 __getitem__方法
3.1 getitem方法中, 獲得對應的圖片路徑,並用PIL庫讀取檔案把圖片transfrom後, 在getitem函數中返回讀取的圖片和標籤即可
4.就可以構建資料集範例和載入資料集.
定義一個用來生成[ 圖片路徑 標籤] 這樣的txt檔案函數
def make_txt(root, file_name, label): path = os.path.join(root, file_name) data = os.listdir(path) f = open(path+'\'+'f.txt', 'w') for line in data: f.write(line+' '+str(label)+'n') f.close() #呼叫函數生成兩個資料夾下的txt檔案 make_txt(path, file_name='flower_orig', label=0) make_txt(path, file_name='sunflower', label=1)
將連個txt檔案合併成一個txt檔案,表示資料集所有的圖片和標籤
def link_txt(file1, file2): txt_list = [] path = r'D:資料集Flower_Orig_datasetdata.txt' f = open(path, 'a') f1 = open(file1, 'r') data1 = f1.readlines() for line in data1: txt_list.append(line) f2 = open(file2, 'r') data2 = f2.readlines() for line in data2: txt_list.append(line) for line in txt_list: f.write(line) f.close() f1.close() f2.close() #呼叫函數, 將兩個資料夾下的txt檔案合併 file1 = r'D:資料集Flower_Orig_datasetflower_origf.txt' file2 = r'D:資料集Flower_Orig_datasetsunflowerf.txt' link_txt(file1=file1, file2=file2)
現在我們已經有了我們製作資料集所需要的txt檔案, 接下來要做的即使繼承Dataset類, 來構建自己的資料集 , 別忘了前面說的 構建資料集步驟, 在__getitem__函數中, 需要拿到圖片路徑和標籤, 並且用PIL庫方法讀取圖片,對圖片進行transform轉換後,返回圖片資訊和標籤資訊
我們讀取圖片的根目錄, 在根目錄下有所有圖片的txt檔案, 拿到txt檔案後, 先讀取txt檔案, 之後遍歷txt檔案中的每一行, 首先去除掉尾部的換行符, 在以空格切分,前半部分是圖片名稱, 後半部分是圖片標籤, 當圖片名稱和根目錄結合,就得到了我們的圖片路徑 class MyDataset(Dataset): def __init__(self, img_path, transform=None): super(MyDataset, self).__init__() self.root = img_path self.txt_root = self.root + 'data.txt' f = open(self.txt_root, 'r') data = f.readlines() imgs = [] labels = [] for line in data: line = line.rstrip() word = line.split() imgs.append(os.path.join(self.root, word[1], word[0])) labels.append(word[1]) self.img = imgs self.label = labels self.transform = transform def __len__(self): return len(self.label) def __getitem__(self, item): img = self.img[item] label = self.label[item] img = Image.open(img).convert('RGB') #此時img是PIL.Image型別 label是str型別 if transforms is not None: img = self.transform(img) label = np.array(label).astype(np.int64) label = torch.from_numpy(label) return img, label
載入我們的資料集:
path = r'D:資料集Flower_Orig_dataset' dataset = MyDataset(path, transform=transform) data_loader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)
接下來我們就可以構建我們的網路架構:
class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3,16,3) self.maxpool = nn.MaxPool2d(2,2) self.conv2 = nn.Conv2d(16,5,3) self.relu = nn.ReLU() self.fc1 = nn.Linear(55*55*5, 1200) self.fc2 = nn.Linear(1200,64) self.fc3 = nn.Linear(64,2) def forward(self,x): x = self.maxpool(self.relu(self.conv1(x))) #113 x = self.maxpool(self.relu(self.conv2(x))) #55 x = x.view(-1, self.num_flat_features(x)) x = self.relu(self.fc1(x)) x = self.relu(self.fc2(x)) x = self.fc3(x) return x def num_flat_features(self, x): size = x.size()[1:] num_features = 1 for s in size: num_features *= s return num_features
訓練我們的網路:
model = Net() criterion = torch.nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.01) epochs = 10 for epoch in range(epochs): running_loss = 0.0 for i, data in enumerate(data_loader): images, label = data out = model(images) loss = criterion(out, label) optimizer.zero_grad() loss.backward() optimizer.step() running_loss += loss.item() if(i+1)%10 == 0: print('[%d %5d] loss: %.3f'%(epoch+1, i+1, running_loss/100)) running_loss = 0.0 print('finished train')
儲存網路模型(這裡不止是儲存引數,還儲存了網路結構)
#儲存模型 torch.save(net, 'model_name.pth') #儲存的是模型, 不止是w和b權重值 # 讀取模型 model = torch.load('model_name.pth')
到此這篇關於pytorch載入自己的圖片資料集的2種方法的文章就介紹到這了,更多相關pytorch載入圖片資料集內容請搜尋it145.com以前的文章或繼續瀏覽下面的相關文章希望大家以後多多支援it145.com!
相關文章
<em>Mac</em>Book项目 2009年学校开始实施<em>Mac</em>Book项目,所有师生配备一本<em>Mac</em>Book,并同步更新了校园无线网络。学校每周进行电脑技术更新,每月发送技术支持资料,极大改变了教学及学习方式。因此2011
2021-06-01 09:32:01
综合看Anker超能充系列的性价比很高,并且与不仅和iPhone12/苹果<em>Mac</em>Book很配,而且适合多设备充电需求的日常使用或差旅场景,不管是安卓还是Switch同样也能用得上它,希望这次分享能给准备购入充电器的小伙伴们有所
2021-06-01 09:31:42
除了L4WUDU与吴亦凡已经多次共事,成为了明面上的厂牌成员,吴亦凡还曾带领20XXCLUB全队参加2020年的一场音乐节,这也是20XXCLUB首次全员合照,王嗣尧Turbo、陈彦希Regi、<em>Mac</em> Ova Seas、林渝植等人全部出场。然而让
2021-06-01 09:31:34
目前应用IPFS的机构:1 谷歌<em>浏览器</em>支持IPFS分布式协议 2 万维网 (历史档案博物馆)数据库 3 火狐<em>浏览器</em>支持 IPFS分布式协议 4 EOS 等数字货币数据存储 5 美国国会图书馆,历史资料永久保存在 IPFS 6 加
2021-06-01 09:31:24
开拓者的车机是兼容苹果和<em>安卓</em>,虽然我不怎么用,但确实兼顾了我家人的很多需求:副驾的门板还配有解锁开关,有的时候老婆开车,下车的时候偶尔会忘记解锁,我在副驾驶可以自己开门:第二排设计很好,不仅配置了一个很大的
2021-06-01 09:30:48
不仅是<em>安卓</em>手机,苹果手机的降价力度也是前所未有了,iPhone12也“跳水价”了,发布价是6799元,如今已经跌至5308元,降价幅度超过1400元,最新定价确认了。iPhone12是苹果首款5G手机,同时也是全球首款5nm芯片的智能机,它
2021-06-01 09:30:45