首頁 > 軟體

PyTorch中torch.tensor()和torch.to_tensor()的區別

2023-01-28 18:02:34

前言

在跑模型的時候,遇到如下報錯

UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).

網上查了一下,發現將 torch.tensor() 改寫成 torch.as_tensor() 就可以避免報錯了。

# 如下寫法報錯
 feature = torch.tensor(image, dtype=torch.float32)
 
# 改為
feature = torch.as_tensor(image, dtype=torch.float32)

然後就又仔細研究了下 torch.as_tensor()torch.tensor() 的區別,在此記錄。

1、torch.as_tensor()

new_data = torch.as_tensor(data, dtype=None,device=None)->Tensor

作用:生成一個新的 tensor, 這個新生成的tensor 會根據原資料的實際情況,來決定是進行淺拷貝,還是深拷貝。當然,會優先淺拷貝,淺拷貝會共用記憶體,並共用 autograd 歷史記錄。

情況一:資料型別相同 且 device相同,會進行淺拷貝,共用記憶體

import numpy
import torch

a = numpy.array([1, 2, 3])
t = torch.as_tensor(a)
t[0] = -1

print(a)   # [-1  2  3]
print(a.dtype)   # int64
print(t)   # tensor([-1,  2,  3])
print(t.dtype)   # torch.int64
import numpy
import torch

a = torch.tensor([1, 2, 3], device=torch.device('cuda'))
t = torch.as_tensor(a)
t[0] = -1

print(a)   # tensor([-1,  2,  3], device='cuda:0')
print(t)   # tensor([-1,  2,  3], device='cuda:0')

情況二: 資料型別相同,但是device不同,深拷貝,不再共用記憶體

import numpy
import torch

import numpy
a = numpy.array([1, 2, 3])
t = torch.as_tensor(a, device=torch.device('cuda'))
t[0] = -1

print(a)   # [1 2 3]
print(a.dtype)   # int64
print(t)   # tensor([-1,  2,  3], device='cuda:0')
print(t.dtype)   # torch.int64

情況三:device相同,但資料型別不同,深拷貝,不再共用記憶體

import numpy
import torch

a = numpy.array([1, 2, 3])
t = torch.as_tensor(a, dtype=torch.float32)
t[0] = -1

print(a)   # [1 2 3]
print(a.dtype)   # int64
print(t)   # tensor([-1.,  2.,  3.])
print(t.dtype)   # torch.float32

2、torch.tensor()

torch.tensor() 是深拷貝方式。

torch.tensor(data, dtype=None, device=None, requires_grad=False, pin_memory=False)

深拷貝:會拷貝 資料型別 和 device,不會記錄 autograd 歷史 (also known as a “leaf tensor” 葉子tensor)

重點是:

  • 如果原資料的資料型別是:list, tuple, NumPy ndarray, scalar, and other types,不會 waring
  • 如果原資料的資料型別是:tensor,使用 torch.tensor(data) 就會報waring
# 原資料型別是:tensor 會發出警告
import numpy
import torch

a = torch.tensor([1, 2, 3], device=torch.device('cuda'))
t = torch.tensor(a)
t[0] = -1

print(a)
print(t)

# 輸出:
# tensor([1, 2, 3], device='cuda:0')
# tensor([-1,  2,  3], device='cuda:0')
# /opt/conda/lib/python3.7/site-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
# 原資料型別是:list, tuple, NumPy ndarray, scalar, and other types, 沒警告
import torch
import numpy

a =  numpy.array([1, 2, 3])
t = torch.tensor(a) 

b = [1,2,3]
t= torch.tensor(b)

c = (1,2,3)
t= torch.tensor(c)

結論就是:以後儘量用 torch.as_tensor()

總結

到此這篇關於PyTorch中torch.tensor()和torch.to_tensor()區別的文章就介紹到這了,更多相關torch.tensor()和torch.to_tensor()區別內容請搜尋it145.com以前的文章或繼續瀏覽下面的相關文章希望大家以後多多支援it145.com!


IT145.com E-mail:sddin#qq.com