首頁 > 科技

基於微軟開源深度學習演算法,用 Python 實現影象和視訊修復

2021-06-22 03:03:56

作者 | 李秋鍵 責編 | 歐陽姝黎

影象修復是計算機視覺領域的一個重要任務,在數字藝術品修復、公安刑偵面部修復等種種實際場景中被廣泛應用。影象修復的核心挑戰在於為缺失區域合成視覺逼真和語義合理的畫素,要求合成的畫素與原畫素具有一致性。

傳統的影象修復技術有基於結構和紋理兩種方法。基於結構的影象修復演算法具有代表性的是 Bertalmio 等提出的 BSCB 模型和 Shen 等提出的基於曲率擴散的修復模型 CDD。基於紋理的修復演算法中具有代表性的有 Criminisi 等提出的基於 patch 的紋理合成演算法。這兩種傳統的修復演算法可以修復小塊區域的破損,但是在破損區域越來越大時, 修復效果則直線下降, 並且修復結果存在影象模糊、結構扭曲、紋理不清晰和視覺不連貫等問題。

近年來,隨著硬體裝置等計算能力的不斷提升, 以及深度學習技術在影象翻譯、影象超解析度、影象修復等計算機視覺領域的迅速發展, 採用深度學習技術的修復方法能夠捕獲影象的高層語義資訊, 與傳統的修復方法相比, 具有良好的修復效果。故今天我們使用 Python 實現 Bringing Old Photo Back to Life 演算法實現對影象和視訊的修復。得到的模型評估效果如下:

基本介紹

傳統的影象修復技術可以分為基於結構的影象修復技術和基於紋理的影象修復技術兩大類。其中,變分偏微分方程模型是基於結構的影象修復技術的典型代表,由變分模型和偏微分方程模型組成。紋理合成是基於紋理的影象修復技術的典型代表。傳統數字影象修復技術分類如下圖所示。

傳統的影象修復方法結果中存在語義資訊不完整、影象模糊等問題,無法達到目前對影象修復的要求。基於深度學習的影象修復演算法能夠捕獲更多影象的高階特徵,修復結果較好,所以經常用於影象修復。目前基於生成式對抗網路的影象修復是深度學習影象修復領域的一大研究熱點,為影象修復技術的發展奠定了堅實的基礎。而我們使用的演算法就是基於深度學習的微軟開源的 Bringing Old Photo Back to Life 去修復影象。

1.1 環境要求

本次環境使用的是 Python3.6.5+windows 平臺。主要用的庫有:

  • PyTorch 模組。PyTorch 是一個基於 Torch 的 Python 開源機器學習庫,用於自然語言處理等應用程式。它主要由 Facebookd 的人工智慧小組開發,不僅能夠 實現強大的 GPU 加速,同時還支援動態神經網路,這一點是現在很多主流框架如 TensorFlow 都不支援的。PyTorch提供了兩個高階功能:1.具有強大的 GPU 加速的張量計算(如Numpy) 2.包含自動求導系統的深度神經網路 除了 Facebook 之外,Twitter、GMU 和 Salesforce 等機構都採用了 PyTorch。

  • pillow 模組。Pillow 是Python裡的影象處理庫(PIL:Python Image Library),提供了了廣泛的檔案格式支援,強大的影象處理能力,主要包括影象儲存、影象顯示、格式轉換以及基本的影象處理操作等。

  • Numpy 模組。Numpy 是應用 Python 進行科學計算時的基礎模組。它是一個提供多維陣列物件的 Python 庫,除此之外,還包含了多種衍生的物件(比如掩碼式陣列(masked arrays)或矩陣)以及一系列的為快速計算陣列而生的例程,包括數學運算,邏輯運算,形狀操作,排序,選擇,I/O,離散傅立葉變換,基本線性代數,基本統計運算,隨機模擬等等。

  • collections 這個模組實現了特定目標的容器,以提供 Python 標準內建容器 dict、list、set、tuple 的替代選擇。Counter:字典的子類,提供了可雜湊物件的計數功能;defaultdict:字典的子類,提供了一個工廠函數,為字典查詢提供了預設值;OrderedDict:字典的子類,保留了他們被新增的順序;namedtuple:創建命名元組子類的工廠函數;deque:類似列表容器,實現了在兩端快速新增(append)和彈出(pop);ChainMap:類似字典的容器類,將多個對映集合到一個視圖裡面。

修復模型演算法

本文所使用的 Bringing Old Photo Back to Life 演算法流程分別為全局修復、臉部檢測、臉部特徵加強和特徵融合。其中隱空間修復網路採用局部-全局視野融合,其中全局支路採用 nonlocal 模組大大增強處理視野。我們對局部破損圖片建立了資料集,訓練網路預測破損區域,該破損區域顯式的送入 nonlocal 模組,並設定模組感受野為非破損區域

2.1 全局視野修復

本文的模型主要由三個部分組成兩個變分自編碼器(variational-autoencoder,VAE)和一個 latent space 對映網路,每個部分都可以看作是單獨的一個模組。下面將介紹網路設計的思想和不同部分的作用。

模型使用了兩個 VAE:

第一個 VAE 用於將合成的老照片(模糊、磨損)進行編碼到隱空間。

第二個 VAE 用於將對應的乾淨的老照片進行編碼。

然後,在隱空間學習從汙損的老照片到乾淨照片的對映。

就這樣,實現了一個老照片的修復演算法。

這個有點像在學習控制圖片清晰、磨損的一個特徵表示,通過控制這個特徵,可以達到修復破損照片的目的。

關鍵程式碼如下:

model = networks.UNet(in_channels=1, out_channels=1, depth=4, conv_num=2, wf=6, padding=True, batch_norm=True, up_mode="upsample",with_tanh=False, sync_bn=True, antialiasing=True,)for image_name in imagelist:idx += 1print("processing", image_name)results = []scratch_image = Image.open(os.path.join(config.test_path, image_name)).convert("RGB")w, h = scratch_image.sizetransformed_image_PIL = data_transforms(scratch_image, config.input_size)scratch_image = transformed_image_PIL.convert("L")scratch_image = tv.transforms.ToTensor()(scratch_image)scratch_image = tv.transforms.Normalize([0.5], [0.5])(scratch_image)scratch_image = torch.unsqueeze(scratch_image, 0)scratch_image = scratch_image.to(config.GPU)P = torch.sigmoid(model(scratch_image))P = P.data.cpu()tv.utils.save_image((P >= 0.4).float(),os.path.join(output_dir, image_name[:-4] + ".png",),nrow=1,padding=0,normalize=True,)    transformed_image_PIL.save(os.path.join(input_dir, image_name[:-4] + ".png"))

2.2 局部臉部修復加強

臉部特徵的加強使用pixpix2模型對臉部二次修復。其中, Pix2Pix模型由Isola等於2017年提出, 它由U-Net和PatchGAN組成, 分別充當Pix2Pix模型中的生成器和判別器。該模型使使用者只需提供一個草圖便能生成一個與之對應的高質量影象; 對應到影象著色工作中, 網路接收真實影象的亮度資訊, 對亮度資訊進行特徵提取並預測影象顏色值。

關鍵程式碼:

def create_optimizers(self, opt):    G_params = list(self.netG.parameters())if opt.use_vae:        G_params += list(self.netE.parameters())if opt.isTrain:        D_params = list(self.netD.parameters())    beta1, beta2 = opt.beta1, opt.beta2if opt.no_TTUR:        G_lr, D_lr = opt.lr, opt.lrelse:        G_lr, D_lr = opt.lr / 2, opt.lr * 2    optimizer_G = torch.optim.Adam(G_params, lr=G_lr, betas=(beta1, beta2))    optimizer_D = torch.optim.Adam(D_params, lr=D_lr, betas=(beta1, beta2))return optimizer_G, optimizer_Ddef generate_fake(self, input_semantics, degraded_image, real_image, compute_kld_loss=False):    z = None    KLD_loss = Noneif self.opt.use_vae:        z, mu, logvar = self.encode_z(real_image)if compute_kld_loss:            KLD_loss = self.KLDLoss(mu, logvar) * self.opt.lambda_kld    fake_image = self.netG(input_semantics, degraded_image, z=z)    assert (not compute_kld_loss    ) or self.opt.use_vae, "You cannot compute KLD loss if opt.use_vae == False"return fake_image, KLD_lossdef discriminate(self, input_semantics, fake_image, real_image):if self.opt.no_parsing_map:        fake_concat = fake_image        real_concat = real_imageelse:        fake_concat = torch.cat([input_semantics, fake_image], dim=1)        real_concat = torch.cat([input_semantics, real_image], dim=1)    fake_and_real = torch.cat([fake_concat, real_concat], dim=0)    discriminator_out = self.netD(fake_and_real)    pred_fake, pred_real = self.divide_pred(discriminator_out)return pred_fake, pred_real

原始碼:https://pan.baidu.com/s/1lAzmWvAEyxi6RFsLpA5l_Q

提取碼:osuh


IT145.com E-mail:sddin#qq.com