首頁 > 軟體

pytorch SummaryWriter儲存紀錄檔的方法

2023-08-28 18:05:19

在pytorch框架中,關於紀錄檔的儲存,其中一種方式就是借鑑使用了tensorboard的庫。所以我們需要在環境中安裝tensorboard庫,然後再在工程中進行該庫的呼叫

1 安裝與匯入

安裝:conda install tensorboardX 或者 pip install tensorboardX
匯入
from tensorboardX import SummaryWriter
 writer = SummaryWriter(logPath)
 ...
 writer.close()

2 新增需要儲存標量資料

  • add_scalar(tag, scalar_value, global_step=None) 從原始碼中我們能看到核心的三個引數為前三個。通俗的講分別代表
    • tag:圖的標籤名,唯一標識
    • scalar_value:y軸資料,標量資料的具體數值
    • global_step:x軸資料,要記錄的全域性步長值
  • add_scalars(main_tag, tag_scalar_dit)多項標題記錄方法,其中:
    • main_tag —— 該圖的標籤
    • tag_salar_dict —— 字典形式的tag-scalar_value對

原始碼中也有例子:

from tensorboardX import SummaryWriter
import numpy as np

writer = SummaryWriter('run/logs')

max_epoch = 100
for x in range(max_epoch):

    writer.add_scalar('t/y=2x', x * 2, x)    #x*2為y軸資料,x為x軸資料
    writer.add_scalar('t/y=pow_2_x', 2^x, x)
    writer.add_scalars('scalar_group', {"xsinx": x * np.sin(x),
                                     "xcosx": x * np.cos(x)}, x)
    writer.close()

執行完該指令碼後,執行tensorboard命令:tensorboard --logdir=./run/

在瀏覽器中開啟連結:【http://localhost:6006/】

3 新增需要儲存圖片資料

從原始碼中我們能看到add_image的主要引數如下。通俗的講分別代表

  • tag:曲線圖名字,唯一標識
  • img_tensor:圖片資料,型別要求為 tensor/numpy/string 等
  • global_step:要記錄的全域性步長值
  • dataformats:圖片輸入的預設維度。注意是"CHW"
from tensorboardX import SummaryWriter
import numpy as np
img = np.zeros((3, 100, 100))
img[0] = np.arange(0, 10000).reshape(100, 100) / 10000
img[1] = 1 - np.arange(0, 10000).reshape(100, 100) / 10000

img_HWC = np.zeros((100, 100, 3))
img_HWC[:, :, 0] = np.arange(0, 10000).reshape(100, 100) / 10000
img_HWC[:, :, 1] = 1 - np.arange(0, 10000).reshape(100, 100) / 10000

writer = SummaryWriter('run/logs')
writer.add_image('my_image', img, 0)

# If you have non-default dimension setting, set the dataformats argument.
writer.add_image('my_image_HWC', img_HWC, 0, dataformats='HWC')
writer.close()

4 直方圖的記錄

畫直方圖主要為了看引數的分佈狀態,使用add_histogram(tag, values, global_step=None, bins=’tensorflow’, walltime=None),其中tag, value, global_step的含義同上,範例如下:

# 每個epoch,記錄梯度,權值
for name, param in net.named_parameters():
    writer.add_histogram(name + '_grad', param.grad, epoch)
    writer.add_histogram(name + '_data', param, epoch)

5 網路結構的記錄

展示結構圖使用add_graph(model, input_to_model=None, verbose=False)

writer = SummaryWriter(comment='test_your_comment', filename_suffix="_test_your_filename_suffix")
# 模型
fake_img = torch.randn(1, 3, 32, 32)
yolo = Yolo(classes=2)
writer.add_graph(yolo, fake_img)
writer.close()

到此這篇關於pytorch SummaryWriter儲存紀錄檔的文章就介紹到這了,更多相關pytorch 儲存紀錄檔內容請搜尋it145.com以前的文章或繼續瀏覽下面的相關文章希望大家以後多多支援it145.com!


IT145.com E-mail:sddin#qq.com