<em>Mac</em>Book项目 2009年学校开始实施<em>Mac</em>Book项目,所有师生配备一本<em>Mac</em>Book,并同步更新了校园无线网络。学校每周进行电脑技术更新,每月发送技术支持资料,极大改变了教学及学习方式。因此2011
2021-06-01 09:32:01
官方檔案給出的解釋為:
總共有七個引數,其中只有前三個是必須的。由於大家普遍使用PyTorch的DataLoader來形成批次資料,因此batch_first也比較重要。LSTM的兩個常見的應用場景為文書處理和時序預測,因此下面對每個引數我都會從這兩個方面來進行具體解釋。
關於LSTM的輸入,官方檔案給出的定義為:
可以看到,輸入由兩部分組成:input、(初始的隱狀態h_0,初始的單元狀態c_0)
其中input:
input(seq_len, batch_size, input_size)
(h_0, c_0):
h_0(num_directions * num_layers, batch_size, hidden_size) c_0(num_directions * num_layers, batch_size, hidden_size)
h_0和c_0的shape一致。
關於LSTM的輸出,官方檔案給出的定義為:
可以看到,輸出也由兩部分組成:otput、(隱狀態h_n,單元狀態c_n)
其中output的shape為:
output(seq_len, batch_size, num_directions * hidden_size)
h_n和c_n的shape保持不變,引數解釋見前文。
batch_first
如果在初始化LSTM時令batch_first=True,那麼input和output的shape將由:
input(seq_len, batch_size, input_size) output(seq_len, batch_size, num_directions * hidden_size)
變為:
input(batch_size, seq_len, input_size) output(batch_size, seq_len, num_directions * hidden_size)
即batch_size提前。
簡單搭建一個LSTM如下所示:
class LSTM(nn.Module): def __init__(self, input_size, hidden_size, num_layers, output_size, batch_size): super().__init__() self.input_size = input_size self.hidden_size = hidden_size self.num_layers = num_layers self.output_size = output_size self.num_directions = 1 # 單向LSTM self.batch_size = batch_size self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True) self.linear = nn.Linear(self.hidden_size, self.output_size) def forward(self, input_seq): batch_size, seq_len = input_seq[0], input_seq[1] h_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device) c_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device) # output(batch_size, seq_len, num_directions * hidden_size) output, _ = self.lstm(input_seq, (h_0, c_0)) # output(5, 30, 64) pred = self.linear(output) # (5, 30, 1) pred = pred[:, -1, :] # (5, 1) return pred
其中定義模型的程式碼為:
self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True) self.linear = nn.Linear(self.hidden_size, self.output_size)
我們加上具體的數位:
self.lstm = nn.LSTM(self.input_size=1, self.hidden_size=64, self.num_layers=5, batch_first=True) self.linear = nn.Linear(self.hidden_size=64, self.output_size=1)
再看前向傳播:
def forward(self, input_seq): batch_size, seq_len = input_seq[0], input_seq[1] h_0 = torch.randn(self.num_directions * self.num_layers, batch_size, self.hidden_size).to(device) c_0 = torch.randn(self.num_directions * self.num_layers, batch_size, self.hidden_size).to(device) # input(batch_size, seq_len, input_size) # output(batch_size, seq_len, num_directions * hidden_size) output, _ = self.lstm(input_seq, (h_0, c_0)) # output(5, 30, 64) pred = self.linear(output) # (5, 30, 1) pred = pred[:, -1, :] # (5, 1) return pred
假設用前30個預測下一個,則seq_len=30,batch_size=5,由於設定了batch_first=True,因此,輸入到LSTM中的input的shape應該為:
input(batch_size, seq_len, input_size) = input(5, 30, 1)
經過DataLoader處理後的input_seq為:
input_seq(batch_size, seq_len, input_size) = input_seq(5, 30, 1)
然後將input_seq送入LSTM:
output, _ = self.lstm(input_seq, (h_0, c_0)) # output(5, 30, 64)
根據前文,output的shape為:
output(batch_size, seq_len, num_directions * hidden_size) = output(5, 30, 64)
全連線層的定義為:
self.linear = nn.Linear(self.hidden_size=64, self.output_size=1)
然後將output送入全連線層:
pred = self.linear(output) # pred(5, 30, 1)
得到的預測值shape為(5, 30, 1),由於輸出是輸入右移,我們只需要取pred第二維度(time)中的最後一個資料:
pred = pred[:, -1, :] # (5, 1)
這樣,我們就得到了預測值,然後與label求loss,然後再反向更新引數即可。
到此這篇關於深入學習PyTorch中LSTM的輸入和輸出的文章就介紹到這了,更多相關PyTorch LSTM內容請搜尋it145.com以前的文章或繼續瀏覽下面的相關文章希望大家以後多多支援it145.com!
相關文章
<em>Mac</em>Book项目 2009年学校开始实施<em>Mac</em>Book项目,所有师生配备一本<em>Mac</em>Book,并同步更新了校园无线网络。学校每周进行电脑技术更新,每月发送技术支持资料,极大改变了教学及学习方式。因此2011
2021-06-01 09:32:01
综合看Anker超能充系列的性价比很高,并且与不仅和iPhone12/苹果<em>Mac</em>Book很配,而且适合多设备充电需求的日常使用或差旅场景,不管是安卓还是Switch同样也能用得上它,希望这次分享能给准备购入充电器的小伙伴们有所
2021-06-01 09:31:42
除了L4WUDU与吴亦凡已经多次共事,成为了明面上的厂牌成员,吴亦凡还曾带领20XXCLUB全队参加2020年的一场音乐节,这也是20XXCLUB首次全员合照,王嗣尧Turbo、陈彦希Regi、<em>Mac</em> Ova Seas、林渝植等人全部出场。然而让
2021-06-01 09:31:34
目前应用IPFS的机构:1 谷歌<em>浏览器</em>支持IPFS分布式协议 2 万维网 (历史档案博物馆)数据库 3 火狐<em>浏览器</em>支持 IPFS分布式协议 4 EOS 等数字货币数据存储 5 美国国会图书馆,历史资料永久保存在 IPFS 6 加
2021-06-01 09:31:24
开拓者的车机是兼容苹果和<em>安卓</em>,虽然我不怎么用,但确实兼顾了我家人的很多需求:副驾的门板还配有解锁开关,有的时候老婆开车,下车的时候偶尔会忘记解锁,我在副驾驶可以自己开门:第二排设计很好,不仅配置了一个很大的
2021-06-01 09:30:48
不仅是<em>安卓</em>手机,苹果手机的降价力度也是前所未有了,iPhone12也“跳水价”了,发布价是6799元,如今已经跌至5308元,降价幅度超过1400元,最新定价确认了。iPhone12是苹果首款5G手机,同时也是全球首款5nm芯片的智能机,它
2021-06-01 09:30:45