<em>Mac</em>Book项目 2009年学校开始实施<em>Mac</em>Book项目,所有师生配备一本<em>Mac</em>Book,并同步更新了校园无线网络。学校每周进行电脑技术更新,每月发送技术支持资料,极大改变了教学及学习方式。因此2011
2021-06-01 09:32:01
ResNet由一系列堆疊的殘差塊組成,其主要作用是通過無限制地增加網路深度,從而使其更加強大。在建立ResNet模型之前,讓我們先定義4個層,每個層由多個殘差塊組成。這些層的目的是降低空間尺寸,同時增加通道數量。
以ResNet50為例,我們可以使用以下程式碼來定義ResNet網路:
class ResNet(nn.Module): def __init__(self, num_classes=1000): super().__init__() self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU(inplace (續) 即模型需要在輸入層加入一些 normalization 和啟用層。 ```python import torch.nn.init as init class Flatten(nn.Module): def __init__(self): super().__init__() def forward(self, x): return x.view(x.size(0), -1) class ResNet(nn.Module): def __init__(self, num_classes=1000): super().__init__() self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU(inplace=True) self.layer1 = nn.Sequential( ResidualBlock(64, 256, stride=1), *[ResidualBlock(256, 256) for _ in range(1, 3)] ) self.layer2 = nn.Sequential( ResidualBlock(256, 512, stride=2), *[ResidualBlock(512, 512) for _ in range(1, 4)] ) self.layer3 = nn.Sequential( ResidualBlock(512, 1024, stride=2), *[ResidualBlock(1024, 1024) for _ in range(1, 6)] ) self.layer4 = nn.Sequential( ResidualBlock(1024, 2048, stride=2), *[ResidualBlock(2048, 2048) for _ in range(1, 3)] ) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.flatten = Flatten() self.fc = nn.Linear(2048, num_classes) for m in self.modules(): if isinstance(m, nn.Conv2d): init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): init.constant_(m.weight, 1) init.constant_(m.bias, 0) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.avgpool(x) x = self.flatten(x) x = self.fc(x) return x
改進點如下:
nn.Sequential
元件,將多個殘差塊組合成一個功能塊(layer)。這樣可以方便地修改網路深度,並將其與其他層分離九更容易上手,例如遷移學習中重新訓練頂部分類器時。nn.Conv2d
和批標準化層等神經網路元件,我們使用了PyTorch中的內建初始化函數。它們會自動為我們設定好每層的引數。Flatten
層,將4維輸出展平為2維張量,以便通過接下來的全連線層進行分類。我們現在已經實現了ResNet50模型,接下來我們將解釋如何訓練和測試該模型。
首先我們需要定義損失函數和優化器。在這裡,我們使用交叉熵損失函數,以及Adam優化器。
import torch.optim as optim device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = ResNet(num_classes=1000).to(device) criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=0.001)
在使用PyTorch進行訓練時,我們通常會建立一個迴圈,為每個批次的輸入資料計算損失並對模型引數進行更新。以下是該回圈的程式碼:
def train(model, optimizer, criterion, train_loader, device): model.train() train_loss = 0 correct = 0 total = 0 for batch_idx, (inputs, targets) in enumerate(train_loader): inputs, targets = inputs.to(device), targets.to(device) optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, targets) loss.backward() optimizer.step() train_loss += loss.item() _, predicted = outputs.max(1) total += targets.size(0) correct += predicted.eq(targets).sum().item() acc = 100 * correct / total avg_loss = train_loss / len(train_loader) return acc, avg_loss
在上面的訓練迴圈中,我們首先通過model.train()
代表進入訓練模式。然後使用optimizer.zero_grad()
清除
以上就是利用Pytorch實現ResNet網路構建及模型訓練的詳細內容,更多關於Pytorch ResNet構建網路模型訓練的資料請關注it145.com其它相關文章!
相關文章
<em>Mac</em>Book项目 2009年学校开始实施<em>Mac</em>Book项目,所有师生配备一本<em>Mac</em>Book,并同步更新了校园无线网络。学校每周进行电脑技术更新,每月发送技术支持资料,极大改变了教学及学习方式。因此2011
2021-06-01 09:32:01
综合看Anker超能充系列的性价比很高,并且与不仅和iPhone12/苹果<em>Mac</em>Book很配,而且适合多设备充电需求的日常使用或差旅场景,不管是安卓还是Switch同样也能用得上它,希望这次分享能给准备购入充电器的小伙伴们有所
2021-06-01 09:31:42
除了L4WUDU与吴亦凡已经多次共事,成为了明面上的厂牌成员,吴亦凡还曾带领20XXCLUB全队参加2020年的一场音乐节,这也是20XXCLUB首次全员合照,王嗣尧Turbo、陈彦希Regi、<em>Mac</em> Ova Seas、林渝植等人全部出场。然而让
2021-06-01 09:31:34
目前应用IPFS的机构:1 谷歌<em>浏览器</em>支持IPFS分布式协议 2 万维网 (历史档案博物馆)数据库 3 火狐<em>浏览器</em>支持 IPFS分布式协议 4 EOS 等数字货币数据存储 5 美国国会图书馆,历史资料永久保存在 IPFS 6 加
2021-06-01 09:31:24
开拓者的车机是兼容苹果和<em>安卓</em>,虽然我不怎么用,但确实兼顾了我家人的很多需求:副驾的门板还配有解锁开关,有的时候老婆开车,下车的时候偶尔会忘记解锁,我在副驾驶可以自己开门:第二排设计很好,不仅配置了一个很大的
2021-06-01 09:30:48
不仅是<em>安卓</em>手机,苹果手机的降价力度也是前所未有了,iPhone12也“跳水价”了,发布价是6799元,如今已经跌至5308元,降价幅度超过1400元,最新定价确认了。iPhone12是苹果首款5G手机,同时也是全球首款5nm芯片的智能机,它
2021-06-01 09:30:45