首頁 > 軟體

Python如何載入模型並檢視網路

2022-07-15 14:04:24

載入模型並檢視網路

載入模型,以vgg19為例。

開啟終端

> python
Python 3.7.2 (tags/v3.7.2:9a3ffc0492, Dec 23 2018, 23:09:28) [MSC v.1916 64 bit
(AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> from torchvision import models
>>> model = models.vgg19(pretrained=True) #此時如果是第一次載入會開始下載模型的pth檔案
>>> print(model.model)

結果:

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace)
    (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace)
    (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (24): ReLU(inplace)
    (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (26): ReLU(inplace)
    (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): ReLU(inplace)
    (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (33): ReLU(inplace)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): ReLU(inplace)
    (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

注意,直接列印模型是沒有辦法看到模型結構的,只能看到帶模型引數的pth檔案內容;需要列印model.model才可以看到模型本身。

神經網路_模型的儲存,模型的載入

模型的儲存(torch.save)

方式1(模型結構+模型引數)

引數:儲存位置

# 建立模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 儲存方式1——模型結構+模型引數
torch.save(vgg16, "vgg16_method1.pth")

方式2(模型引數)

# 儲存方式2  模型引數(官方推薦)。儲存成字典,只儲存網路模型中的一些引數
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

模型的載入(torch.load)

對應儲存方式1

引數:模型路徑

# 方式1 --》 儲存方式1
model1 = torch.load("vgg16_method1.pth")

對應儲存方式2

vgg16.load_state_dict("vgg16_method2.pth")

輸出為字典形式。若要回復網路,採用以下形式:

model2 = torch.load("vgg16_method2.pth")  #輸出是字典形式
# 恢復網路結構
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(model2)

方式1儲存,載入時需注意事項

新建自己的網路:

class test(nn.Module):
    def __init__(self):
        super(lh, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x

儲存自己的網路:

Test = test()
# 儲存自己定義的網路
torch.save(Test, "Test_method1.pth")

載入自己的網路:

model3 = torch.load("Test_method1.pth")

會報錯!!!!!!

解決辦法(需要注意):

將定義的網路複製到載入的python檔案中:

class test(nn.Module):
    def __init__(self):
        super(test, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x
model3 = torch.load("Test_method1.pth")

以上為個人經驗,希望能給大家一個參考,也希望大家多多支援it145.com。


IT145.com E-mail:sddin#qq.com