選自medium作者:Kaiyu Yue機器之心編譯編輯:陳訓練大模型時,如何優雅地減少 GPU 記憶體消耗?你不妨試試這個 TorchShard 庫,兼具模型並行與資料並行等特點,還具有與 PyTorch 相同
2021-07-28 03:05:09
作者:Kaiyu Yue 機器之心編譯 編輯:陳
訓練大模型時,如何優雅地減少 GPU 記憶體消耗?你不妨試試這個 TorchShard 庫,兼具模型並行與資料並行等特點,還具有與 PyTorch 相同的 API 設計。
建立一個標準的 PyTorch 擴展庫,用於使用模型並行性進行擴展訓練;
以一種簡單、自然的方式使用 PyTorch。
import torchshard as ts
ts.init_process_group(group_size=2) # init parallel groups
m = torch.nn.Sequential(
torch.nn.Linear(20, 30, bias=True),
ts.nn.ParallelLinear(30, 30, bias=True, dim=None), # equal to nn.Linear()
ts.nn.ParallelLinear(30, 30, bias=True, dim=0), # parallel in row dimension
ts.nn.ParallelLinear(30, 30, bias=True, dim=1), # parallel in column dimension
).cuda()
x = m(x) # forward
loss = ts.nn.functional.parallel_cross_entropy(x, y) # parallel loss function
loss.backward() # backward
torch.save(
ts.collect_state_dict(m, m.state_dict()), 'm.pt') # save model state
torchshard 包含必要的功能和操作,如 torch 包;
torchshard.nn 包含圖形的基本構建塊,如 torch.nn 包;
torchshard.nn.functional 包含 torchshard.nn 的相應功能操作,如 torch.nn.functional 包;
torchshard.distributed 包含處理分散式張量和組的基本功能,如 torch.distributed 包更容易使用。
pip install torchshard
import torchshard as ts
ts.distributed.init_process_group(group_size=args.world_size)
import resnetmodel = resnet.__dict__[args.arch](pretrained=args.pretrained)ts.nn.ParallelLinear.convert_parallel_linear( model, dim=args.model_parallel_dim)print("=> paralleling model'{}'".format(args.arch))
criterion = ts.nn.ParallelCrossEntropyLoss().cuda(args.gpu)
x = ts.distributed.gather(x, dim=0) # gather input along the dim of batch size x = self.fc(x)
output = model(images)if args.enable_model_parallel: target = ts.distributed.gather(target, dim=0)loss = criterion(output, target)
state_dict = model.state_dict()# collect states across all ranksstate_dict = ts.collect_state_dict(model, state_dict)if ts.distributed.get_rank() == 0: torch.save(state_dict, 'resnet50.pt') # save as before
if ts.distributed.get_rank() == 0: state_dict = torch.load('resnet50.pt')# relocate state_dict() for all ranksstate_dict = ts.relocate_state_dict(model, state_dict)model.load_state_dict(state_dict) # load as before
# gradscaler
scaler = torch.cuda.amp.GradScaler(enabled=args.enable_amp_mode)
with torch.cuda.amp.autocast(enabled=args.enable_amp_mode): # compute output
output = model(images)
if args.enable_model_parallel:
target = ts.distributed.gather(target, dim=0)
loss = criterion(output, target)
# compute gradient and do SGD step
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
from torch.distributed.optim import ZeroRedundancyOptimizer
if args.enable_zero_optim:
print('=> using ZeroRedundancyOptimizer')
optimizer = torch.distributed.optim.ZeroRedundancyOptimizer(
model.parameters(),
optimizer_class=torch.optim.SGD,
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay)
else:
optimizer = torch.optim.SGD(model.parameters(), args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay)
相關文章
選自medium作者:Kaiyu Yue機器之心編譯編輯:陳訓練大模型時,如何優雅地減少 GPU 記憶體消耗?你不妨試試這個 TorchShard 庫,兼具模型並行與資料並行等特點,還具有與 PyTorch 相同
2021-07-28 03:05:09
博雯 發自 凹非寺量子位 報道 | 公眾號 QbitAI這是一座非典型的物流倉庫。走進大門,首先看到的是在空無一人的地面上不斷運動的機器人:他們的日常工作就是在偌大的倉庫裡來回
2021-07-28 03:03:31
彪彪覺得上次贏過鍵鍵,逢人就說自己如何刻苦練習,說鍵鍵如何操作不方便,不值得花時間在鍵鍵身上。這不,練習了幾天,就又來找鍵鍵了約戰。鍵鍵見彪彪來了,很無奈地說:你都已經勝了,你
2021-07-28 03:03:20
最近,Win10使用者在安裝軟體時收到了軟體,收到了「無效驅動器」,這導致軟體安裝失敗,發生了什麼?由於系統臨時目錄被移動到D磁碟或安裝路徑選擇,這可能會導致錯誤。讓我們來看看解
2021-07-28 03:03:12
2021年7月27日,OPPO「超能代表釋出會」正式釋出新一代全智慧手錶旗艦——OPPO Watch 2系列。它支援獨立eSIM、50+APP、100+運動模式,並提供多種專業健康監測功能。新系列採用
2021-07-28 03:03:00
7 月 31 日至 8 月 1 日,GOTC 全球開源技術峰會深圳站將在深圳會展中心召開。由貝殼技術 VP,PHP 開發組核心成員惠新宸(鳥哥,Laruence)擔任出品人的「程式語言藝術」專題論壇將在
2021-07-28 03:02:54