首頁 > 軟體

PyTorch中torch.manual_seed()的用法範例詳解

2022-06-11 18:00:46

一、torch.manual_seed(seed) 介紹

torch.manual_seed(seed) 功能描述

設定 CPU 生成亂數的 種子 ,方便下次復現實驗結果。

為 CPU 設定 種子 用於生成亂數,以使得結果是確定的。

當你設定一個隨機種子時,接下來的隨機演演算法生成數根據當前的隨機種子按照一定規律生成。
隨機種子作用域是在設定時到下一次設定時。要想重複實驗結果,設定同樣隨機種子即可。

語法

torch.manual_seed(seed) → torch._C.Generator

引數

seed,int型別,是種子 – CPU生成亂數的種子。取值範圍為 [-0x8000000000000000, 0xffffffffffffffff] ,十進位制是 [-9223372036854775808, 18446744073709551615] ,超出該範圍將觸發 RuntimeError 報錯。

返回

返回一個torch.Generator物件。

二、類似函數的功能

為CPU中設定種子,生成亂數:

torch.manual_seed(number)

為特定GPU設定種子,生成亂數:

torch.cuda.manual_seed(number)

為所有GPU設定種子,生成亂數:

# 如果使用多個GPU,應該使用torch.cuda.manual_seed_all()為所有的GPU設定種子。
torch.cuda.manual_seed_all(number)

使用原因 :

在需要生成亂資料的實驗中,每次實驗都需要生成資料。設定隨機種子是為了確保每次生成固定的亂數,這就使得每次實驗結果顯示一致了,有利於實驗的比較和改進。使得每次執行該 .py 檔案時生成的亂數相同。

三、範例

範例 1 :不設隨機種子,生成亂數

# test.py
import torch
print(torch.rand(1)) # 返回一個張量,包含了從區間[0, 1)的均勻分佈中抽取的一組亂數

每次執行test.py的輸出結果都不相同:

tensor([0.4351])

tensor([0.3651])

tensor([0.7465])

範例 2 :設定隨機種子,使得每次執行程式碼生成的亂數都一樣

# test.py
import torch
# 設定隨機種子
torch.manual_seed(0)
# 生成亂數
print(torch.rand(1)) # 返回一個張量,包含了從區間[0, 1)的均勻分佈中抽取的一組亂數

每次執行 test.py 的輸出結果都是一樣:

tensor([0.4963])

範例 3 :不同的隨機種子生成不同的值

改變隨機種子的值,設為 1 :

# test.py
import torch
torch.manual_seed(1)
print(torch.rand(1)) # 返回一個張量,包含了從區間[0, 1)的均勻分佈中抽取的一組亂數

每次執行 test.py,輸出結果都是:

tensor([0.7576])

改變隨機種子的值,設為 5 :

# test.py
import torch
torch.manual_seed(5)
print(torch.rand(1)) # 返回一個張量,包含了從區間[0, 1)的均勻分佈中抽取的一組亂數

每次執行 test.py,輸出結果都是:

tensor([0.8303])

可見不同的隨機種子能夠生成不同的亂數。

但只要隨機種子一樣,每次執行程式碼都會生成該種子下的亂數。

範例 4 :設定隨機種子後,是每次執行test.py檔案的輸出結果都一樣,而不是每次隨機函數生成的結果一樣
# test.py
import torch
torch.manual_seed(0)
print(torch.rand(1))
print(torch.rand(1))

輸出結果:

tensor([0.4963])
tensor([0.7682])

可以看到兩次列印 torch.rand(1) 函數生成的結果是不一樣的,但如果你再執行test.py,還是會列印:

tensor([0.4963])
tensor([0.7682])

範例 5 :如果你就是想要每次執行隨機函數生成的結果都一樣,那你可以在每個隨機函數前都設定一模一樣的隨機種子

# test.py
import torch
torch.manual_seed(0)
print(torch.rand(1))
torch.manual_seed(0)
print(torch.rand(1))

輸出結果:

tensor([0.4963])
tensor([0.4963])

參考連結

【pytorch】torch.manual_seed()用法詳解

總結

到此這篇關於PyTorch中torch.manual_seed()的法的文章就介紹到這了,更多相關PyTorch中torch.manual_seed()內容請搜尋it145.com以前的文章或繼續瀏覽下面的相關文章希望大家以後多多支援it145.com!


IT145.com E-mail:sddin#qq.com