首頁 > 科技

不怕訓練大模型,TorchShard庫減少GPU記憶體消耗API與PyTorch相同

2021-07-28 03:05:09

選自medium

者:Kaiyu Yue

機器之心編譯

編輯:陳

訓練大模型時,如何優雅地減少 GPU 記憶體消耗?你不妨試試這個 TorchShard 庫,兼具模型並行與資料並行等特點,還具有與 PyTorch 相同的 API 設計。

模型並行效能夠促進視覺任務的效能。但是目前,還沒有一個標準庫可以讓我們像採用混合精度等其他 SOTA 技術那樣輕鬆地採用模型並行性。

最近,馬里蘭大學帕克分校計算機科學系的研究者 Kaiyu Yue 開源了一個工具TorchShard,這是一個輕量級的引擎,用於將 PyTorch 張量切片成並行的 shard。當模型擁有大量的線性層(例如 BERT、GPT)或者很多類(數百萬)時,TorchShard 可以減少 GPU 記憶體並擴展訓練規模,它具有與 PyTorch 相同的 API 設計。

項目地址:https://github.com/KaiyuYue/torchshard

BERT 和 GPT 等超大模型正在成為 NLP 領域應用中的趨勢。然而訓練這種大模型面臨記憶體限制的問題,為了解決這個難題,研究者使用 Megatron-LM 和 PyTorch-Lightning 模型並行性擴大訓練。其中,Megatron-LM 只專注於大規模訓練語言模型,而 PyTorch-Lightning 僅基於 sharded 優化器狀態和梯度,如 DeepSpeed。

在計算機視覺任務中,我們會在訓練基於 Transformer、MLP 模型或在數百萬個類中訓練模型時遇到同樣的問題。TorchShard 的目標是:

  • 建立一個標準的 PyTorch 擴展庫,用於使用模型並行性進行擴展訓練;

  • 以一種簡單、自然的方式使用 PyTorch。

TorchShard 是對模型並行單元(mpu)的徹底重寫,是 Megatron-LM 核心。最重要的是,TorchShard 具有與 PyTorch 相同的 API 設計,這意味著所有的子類和子函數都保持與 PyTorch 相同。例如,如果你想讓原來的線性層 torch.nn. linear 是並行的,只需將 torch 變成 ts,並呼叫帶有 dim 參數的子類 nn.ParallelLinear,如下所示:


import torchshard as ts
ts.init_process_group(group_size=2) # init parallel groups
m = torch.nn.Sequential( torch.nn.Linear(20, 30, bias=True), ts.nn.ParallelLinear(30, 30, bias=True, dim=None), # equal to nn.Linear() ts.nn.ParallelLinear(30, 30, bias=True, dim=0), # parallel in row dimension ts.nn.ParallelLinear(30, 30, bias=True, dim=1), # parallel in column dimension).cuda()
x = m(x) # forwardloss = ts.nn.functional.parallel_cross_entropy(x, y) # parallel loss functionloss.backward() # backward
torch.save( ts.collect_state_dict(m, m.state_dict()), 'm.pt') # save model state

除此之外,TorchShard 還支援與 DDP 一起使用時的各種特性,儲存和載入 shard checkpoints,初始化 shard 參數,以及跨多臺機器和 GPU 處理張量。具體如下:

  • torchshard 包含必要的功能和操作,如 torch 包;

  • torchshard.nn 包含圖形的基本構建塊,如 torch.nn 包;

  • torchshard.nn.functional 包含 torchshard.nn 的相應功能操作,如 torch.nn.functional 包;

  • torchshard.distributed 包含處理分散式張量和組的基本功能,如 torch.distributed 包更容易使用。

如何開始 TorchShard?

安裝要求:Python 版本 3.6 以上(含)以及 PyTorch 版本 1.9.0 以上(含)。通過 pip 安裝 TorchShard 庫:


pip install torchshard

這裡以 ImageNet 上訓練 ResNet-50 為例,展示僅需幾行程式碼就能在項目中使用 TorchShard。通常 ResNet-50 設計正規化包含兩部分:卷積塊和全連線層,如下圖 1 所示。一般來說,由於大量的類依賴於資料集,最後的線性層比卷積塊有更多的參數。所以我們切片最後一個線性層來檢查其最大尺寸。

圖 1:DDP 以及 DDP + TorchShard 前向訓練流。

在上圖 1 中,左邊展示了傳統的 DDP 訓練正規化。假設我們有兩個等級,DDP 將強制每個等級有重複的模型參數。然而,TorchShard 會將層級參數切片到不同的等級,從而減少整個 GPU 記憶體。現在向 ImageNet 官方訓練指令碼新增一些程式碼,修改後的版本已經成為 TorchShard 項目的一部分。

首先將 torchshard import 進來:


import torchshard as ts

然後需要初始化模型並行的程序組,就像初始化 DDP 程序組的方法一樣。只需要設定一個功能參數來告訴 torchshard 應該從目標層中切片出多少個 shard。


ts.distributed.init_process_group(group_size=args.world_size)

接下來將模型轉換為並行版本,其中可以直接將整個模型輸入到轉換輔助函數中,無需特殊處理。


import resnetmodel = resnet.__dict__[args.arch](pretrained=args.pretrained)ts.nn.ParallelLinear.convert_parallel_linear(    model, dim=args.model_parallel_dim)print("=> paralleling model'{}'".format(args.arch))

此外,不要忘記損失函數 torchshard.nn.ParallelCrossEntropy ,該損失函數可以根據輸入張量在原始 PyTorch 版本和並行版本之間切換運行模式。例如,如果輸入張量是由 torchshard 並行層產生的,torchshard.nn.ParallelCrossEntropy 將以並行方式計算損失值。


criterion = ts.nn.ParallelCrossEntropyLoss().cuda(args.gpu)

當模型並行模式(TorchShard)和資料並行模式(DDP)一起工作時,我們需要處理並行層的輸入。每個等級中的參數和訓練資料都不同。因此,我們在 ResNet forward 中的並行線性層之前收集輸入張量。


x = ts.distributed.gather(x, dim=0) # gather input along the dim of batch size x = self.fc(x)

同樣地,我們在計算損失值之前收集目標張量。


output = model(images)if args.enable_model_parallel:    target = ts.distributed.gather(target, dim=0)loss = criterion(output, target)

最後,使用 TorchShard 函數儲存和載入 checkpoints 非常簡單。TorchShard 提供了名為 torchshard.collect_state_dict 基本函數用於儲存 checkpoints,torchshard.relocate_state_dict 用於載入 checkpoints。

儲存檢查點:


state_dict = model.state_dict()# collect states across all ranksstate_dict = ts.collect_state_dict(model, state_dict)if ts.distributed.get_rank() == 0:    torch.save(state_dict, 'resnet50.pt') # save as before

載入檢查點:


if ts.distributed.get_rank() == 0:     state_dict = torch.load('resnet50.pt')# relocate state_dict() for all ranksstate_dict = ts.relocate_state_dict(model, state_dict)model.load_state_dict(state_dict) # load as before

現在我們已經完成了在 ImageNet 上為 shard 訓練新增程式碼, 然後可以通過增加類的數量來擴展它,即最後一個線性層的輸出特徵維度。訓練指令碼可以在 torchshard/project/imagenet 中找到。下圖展示了在 8 個 NVIDIA TITAN-XP (12196 MiB) GPU 、類數 ≤ 1000000 上和 16 個 GPU 、類數為 2000000 上訓練 ResNet-50 擴展能力。

圖 2:在不同並行策略下使用標準 ResNet 訓練設定(即輸入大小 224 和批量大小 256)的 GPU 記憶體成本。

使用 AMP 與 ZeRO

TorchShard 以簡單自然的 PyTorch 方式與其他技術(例如自動混合精度 AMP 以及 ZeRO)一起混合使用。


# gradscalerscaler = torch.cuda.amp.GradScaler(enabled=args.enable_amp_mode)

with torch.cuda.amp.autocast(enabled=args.enable_amp_mode): # compute output output = model(images) if args.enable_model_parallel: target = ts.distributed.gather(target, dim=0) loss = criterion(output, target)

# compute gradient and do SGD stepscaler.scale(loss).backward()scaler.step(optimizer)scaler.update()optimizer.zero_grad()

圖 3:在不同並行策略以及 AMP 下,使用標準的 ResNet 訓練設定時(輸入尺寸 224,batch 大小 256),使用 GPU 記憶體的成本。

ZeRO 是 DeepSpeed 的核心,與 PyTorch >= 1.9.0 一起使用。如果你想測試一個函數,請安裝最新版本的指令碼來運行,程式碼如下:


from torch.distributed.optim import ZeroRedundancyOptimizer

if args.enable_zero_optim: print('=> using ZeroRedundancyOptimizer') optimizer = torch.distributed.optim.ZeroRedundancyOptimizer( model.parameters(), optimizer_class=torch.optim.SGD, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)else: optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

圖 4:在不同的並行策略和 ZeRO 優化器下,在標準 ResNet 訓練設定(輸入大小 224 和批大小 256) GPU 記憶體成本。

此外,TorchShard 還提供了基本的 Python API 以及和相應的模板檔案,以簡化自定義並行層的實現。

研究者將持續開發 TorchShard,如 TorchShard 下一個特性是新的資料取樣器 torchshard.utils.data.DistributedGroupSampler,它的命名遵循 torch.utils.data.DistributedSampler。該取樣器旨在幫助使用者構建 M-way 資料並行、N-way 模型並行,使得其就像 DDP 中的 DistributedSampler 一樣簡單。使用者唯一要做的就是設定模型並行組號,然後 DistributedGroupSampler 來確保同一模型並行組中的模組具有相同的訓練資料。


IT145.com E-mail:sddin#qq.com