首頁 > 軟體

pytorch中forwod函數在父類別中的呼叫方式解讀

2023-02-20 06:00:44

pytorch forwod函數在父類別中的呼叫

問題背景

最近在研究Detetron2的程式碼結構時,發現有些網路程式碼裡面沒有forward函數,卻照樣可以推理,深入挖掘之後,發現其將forword函數都寫在了同一個父類別裡面。

這就牽涉到了下面這個問題,子類中沒有forward函數,只有父類別中有forward函數,這樣能不能正常呼叫網路。

import torch.nn as nn

class Network1(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self,x):
        return x

class Network2(Network1):
    def __init__(self):
        super().__init__()


data = [1,2,3]
model = Network2().eval()
output = model(data)
print(output)

輸出結果如下:

[1,2,3]

pytorch forward方法呼叫原理

在使用Pytorch自定義網路模型的時候,我們需要繼承nn.Module這個類,然後定義forward方法來實現前向轉播。

如下圖的一個自定義的網路模型

首先該網路模型的初始化方法__init__需要繼承父類別nn.Module的初始化方法,用語句super().init()實現。

並在初始化方法裡面,定義了折積、BN、啟用函數等。接下來定義forward方法,將整個網路連線起來。

有了上面的定義,我們可以範例化一個物件,例如:

fire2 = Fire(96, 128,16,64,64)

實現前向傳播,使用 y= fire2(x) 其中x是該網路的輸入,y是輸出,實現了forward方法的額功能。

這裡就會有人感到奇怪,forward作為Fire這個類的方法,使用的時候不應該是 y= fire2.forward(x)嗎。

這裡為什麼一個類的範例可以當做方法直接使用?這是因為這個Fire類繼承的父類別nn.Module裡面定義了__call__方法。

一個類如果定義了__call__方法,則該類的範例就可以作為一個方法那樣直接使用。

例如下列程式碼[1]

class A():
    def __call__(self):
        print('i can be called like a function')
 
a = A()
a()

就會執行print函數,列印其中搞的文字。這裡需要區別的是,範例化的時候,類的名稱後面括號可以傳遞引數,例如前面範例化Fire的時候,傳遞in_channel,out_channel等引數。

但是要利用__call__的特性,是在範例名後面的括號中傳遞引數,例如上面的例子a(),這裡雖然沒有引數,但是也可以改變__call__的定義使之可以傳遞引數。

回到網路模型的內容上來。翻看nn.Module的部分原始碼[2],可以發現,nn.Module裡面果然定義了__call__,並且傳遞了引數*input。在__call__的定義中國,呼叫了self.forward。

這裡其實還有一個點值得注意。其實nn.Module裡面並沒有定義forward,但他卻呼叫self.forward,嚴格來說,他是“想要”呼叫self.forward。

如果我們沒有定義一個類,例如Fire,來繼承nn.Module,並且在這個類裡面定義forward,那麼nn.Module中__call__下面的self.forward就是無效的。

這意味著,父類別中__call__下面呼叫的函數,可以在繼承他的子類中定義

下面給出一個簡單的例子。

class father():
    def __call__(self):
        self.forward()
        print('I''m the father!')

class child(father):
    def forward(self):
        print('Forward!')
F=father()
C=child()

這裡定義了父類別father,並定義了繼承他的一個子類child。此外還進行了他們的範例化。

顯然,在father的__call__方法下面,呼叫了self.forward,但是沒有定義。child在繼承了father之後,定義了forward。

首先,這段程式碼不會報錯,即使father的__call__下面的self.forward並沒有定義,這也是前面我說的,雖然沒有定義forward,但是可以理解為他“想要”呼叫self.forward。

那麼在child記成了father之後,進行了forward的定義,這使得child本身可以呼叫forward。

在上面這段程式碼的基礎上,如果我們執行F(),彙報下面這一段錯誤,這解釋了forward沒有定義,只是“想要”呼叫self.forward。

如果我們執行C(),則如下圖輸出。

顯然,在child中補充了forward的定義,就可以成功呼叫。

總結

以上為個人經驗,希望能給大家一個參考,也希望大家多多支援it145.com。


IT145.com E-mail:sddin#qq.com