Pytorch 摳圖算法 Deep Image Matting 模型實(shí)現(xiàn)

????????本文旨在實(shí)現(xiàn)摳圖算法 Semantic Human Matting 的第二階段模型 M-Net,也即 Deep Image Matting。值得說(shuō)明的是,本文實(shí)現(xiàn)的模型與原始論文略有出入,除了模型的輸入層有細(xì)微差別之外,損失函數(shù)也作了簡(jiǎn)化(但無(wú)本質(zhì)差別)。

????????本文完整代碼見(jiàn) GitHub: deep_image_matting_pytorch。Pytorch 需要 1.1.0 或后續(xù)版本。

????????本文 訓(xùn)練數(shù)據(jù) 來(lái)源于 愛(ài)分割 公司開(kāi)源的 數(shù)據(jù)集,總共包含 34426 張圖片和對(duì)應(yīng)的 alpha 通道,數(shù)據(jù)量非常大,能公開(kāi)特別值得點(diǎn)贊。但同時(shí)因?yàn)闃?biāo)注數(shù)據(jù)的 alpha 通道精度不高,導(dǎo)致訓(xùn)練后測(cè)試效果較差。建議使用 Deep Image Matting 的數(shù)據(jù)集訓(xùn)練。

Semantic Human Matting 摳圖模型

????????總的來(lái)說(shuō),Semantic Human Matting 論文提出的自動(dòng)摳圖的思路特別清晰明了(如上圖),對(duì)于一張待摳圖像,首先通過(guò)語(yǔ)義分割模型(即 T-Net)分割出前景F_s、背景B_s和未知區(qū)域U_sF_s + B_s + U_s = 1),然后廣義的認(rèn)為前景(F_s)+ 未知區(qū)域(U_s)組成一個(gè)三分圖( Trimap),此時(shí)再利用 Deep Image Matting(即 M-Net) 即可高質(zhì)量的完成摳圖。完整的模型將在接下來(lái)的幾篇文章逐步實(shí)現(xiàn),本文只關(guān)注該模型的第二階段(M-Net)。

????????M-Net 接受待摳圖像(前景與背景的 RGB 3 通道合成)以及語(yǔ)義分割模型輸出的 3 通道預(yù)測(cè)(F_s, B_s, U_s)拼接而成的 6 通道輸入,經(jīng)過(guò)編碼器提取圖像特征之后,由解碼器得到預(yù)測(cè) \alpha_r。如果語(yǔ)義分割模型分割的精度較高,那么可以認(rèn)為 F_s, B_s 對(duì)應(yīng)的區(qū)域已經(jīng)很好的摳出了大部分的前景和背景,唯一需要提升準(zhǔn)確率的是待摳對(duì)象的邊緣區(qū)域,所以模型的第二階段 M-Net 的目的就是細(xì)化的預(yù)測(cè)邊緣區(qū)域(這正是 Deep Image Matting 要干的事情),兩部分結(jié)合即得到最終的預(yù)測(cè):
F_s + U_s\alpha_r 。

這個(gè)公式可以這樣理解:
\textrm{預(yù)測(cè)的前景} = \textrm{確定區(qū)域上的前景} + \textrm{未知區(qū)域上的前景}

也就是:
P(\textrm{前景}) = P(確定區(qū)域上前景) + P(未知區(qū)域上前景)

根據(jù)全概率公式,用符號(hào)來(lái)表示則是:
\begin{align} \alpha &= P(F) \\ &= P(F|\mathrm{known})P(\mathrm{known}) + P(F|\mathrm{unknown})P(\mathrm{unknow}) \\ &= \frac{F_s}{F_s + B_s}(F_s + B_s) + \alpha_rU_s\\ &= F_s + U_s\alpha_r \end{align}

????????但上述公式存在一個(gè)缺陷,即如果待摳目標(biāo)外有大塊噪聲,則最終的預(yù)測(cè)也消除不了這個(gè)噪聲,如下圖:

語(yǔ)義分割之后的前景帶有外部噪聲(衣服左側(cè)的小照片)

為了消除第一階段可能包含的外部噪聲,本文的在實(shí)現(xiàn) M-Net 的時(shí)候做了一個(gè)小的改動(dòng):第二階段的輸入改為由待摳圖像 + F_s 組成的 4 通道圖片(此時(shí),相當(dāng)于將 F_s 看成是三分圖 trimap),并且將第二階段的預(yù)測(cè)則作為最終的預(yù)測(cè)。另外,第二階段損失函數(shù)簡(jiǎn)化為只用 alpha 通道的損失更好。

一、模型實(shí)現(xiàn)

????????Deep Image Matting 原文的模型如下:

Deep Image Matting 論文模型

模型先通過(guò)一個(gè)編碼器提取特征,之后經(jīng)過(guò)一個(gè)解碼器預(yù)測(cè)一個(gè)初始的 alpha 通道,這個(gè)預(yù)測(cè)值效果已經(jīng)很好,但作者為了進(jìn)一步提升摳圖的精度,又額外的接了幾層細(xì)化的小網(wǎng)絡(luò),然后將細(xì)化后的輸出作為整個(gè)模型的最終輸出。具體來(lái)說(shuō),首先將待摳圖像(3 通道)以及事先準(zhǔn)備好的三分圖(trimap)合成一個(gè) 4 通道圖像,然后經(jīng)過(guò) VGG16 的前 13 個(gè)卷積層以及之后的 1 個(gè)全連接層(看成是 1x1 的卷積層),總共 14 個(gè)卷積層提取圖像特征(此時(shí)已做了 5 個(gè)最大池化,因此圖像分辨率下降了 32 倍,如果輸入是 320x320,那么特征映射的分辨率就變成了 10x10),這是模型的編碼器階段。接下來(lái)對(duì)圖像特征進(jìn)行解碼,即開(kāi)始解碼器階段。解碼器使用 6 個(gè)卷積層(5x5 的卷積核)和 5 個(gè) 反池化層,每個(gè)反池化層將特征映射的分辨率提升 2 倍,因此解碼器的輸出與模型輸入的大小一樣。這里,使用反池化層的效果要比直接使用轉(zhuǎn)置卷積(deconvolution)的效果要好。雖然他們都是為了提升圖像分辨率,但使用轉(zhuǎn)置卷積并不能很好的摳出細(xì)節(jié),而使用反池化層卻可以摳圖頭發(fā)絲等非常細(xì)的前景。為了最求極致效果,作者又接了一個(gè)小網(wǎng)絡(luò),將待摳圖像和編碼器預(yù)測(cè)的 alpha 通道合成一個(gè) 4 通道圖像,然后通過(guò) 4 個(gè) 3x3 的卷積層得到細(xì)化后的 alpha 通道預(yù)測(cè),作為最后的輸出。

Deep Image Matting 效果圖:(a) 原圖;(b) 編碼器-解碼器階段結(jié)果;(c)細(xì)化階段結(jié)果

????????損失方面,總共用了 3 個(gè)分損失來(lái)合成網(wǎng)絡(luò)的損失:

  • 編碼器階段預(yù)測(cè)的 alpha 通道和真實(shí)的 alpha 通道的損失;
  • 編碼器階段使用預(yù)測(cè)的 alpha 合成的圖像和真實(shí)的 alpha 合成的圖像的損失;
  • 細(xì)化階段預(yù)測(cè)的 alpha 通道和真實(shí)的 alpha 通道的損失。

這些損失都是逐點(diǎn)損失,即平方和誤差:

預(yù)測(cè)與真實(shí) alpha 通道之間的損失
由前景、背景和 alpha 通道合成圖像之間的損失

三個(gè)損失使用加權(quán)和形成整個(gè)網(wǎng)絡(luò)最后反向傳播的總損失。

????????Deep Image Matting 雖然論文上報(bào)告的效果很驚人,但實(shí)際實(shí)現(xiàn)時(shí)(在個(gè)人應(yīng)用數(shù)據(jù)集上)泛化性能不夠理想。

????????Semantic Human Matting(SHM)這篇論文的 M-Net 在以上基礎(chǔ)上做了一些簡(jiǎn)化和修改。首先,為了防止網(wǎng)絡(luò)容量太大造成過(guò)擬合,SHM 只使用 VGG16 的前 13 個(gè)卷積層及 4 個(gè)最大池化層來(lái)作為編碼器,相應(yīng)的,解碼器階段也就少一個(gè)反池化層。另外,為了加速網(wǎng)絡(luò)收斂,所有的卷積層(編碼器以及解碼器的)都帶批標(biāo)準(zhǔn)化(Batch Normalization)處理。其次,網(wǎng)絡(luò)的輸入由 4 通道變成了 6 通道,這樣做一方面沒(méi)有影響網(wǎng)絡(luò)性能(論文 4.2 節(jié)),另一方面也是為了方便與 T-Net 對(duì)接,因?yàn)?T-Net 輸出 前景、背景、未知 3 個(gè)預(yù)測(cè)通道,與待摳圖像的 3 通道直接合成即得到 6 通道輸入。最后,SHM 直接去掉了 Deep Image Matting 網(wǎng)絡(luò)的細(xì)化階段,因此損失也相應(yīng)的減少為 2 個(gè)分損失。

????????本文基本忠實(shí)的實(shí)現(xiàn)了 SHM 的 M-Net 結(jié)構(gòu),但如本文開(kāi)始時(shí)候說(shuō)的那樣,將 6 通道的輸入改成了 4 通道,且為了完全引入 VGG16 的預(yù)訓(xùn)練模型,直接在 VGG16 的最前面接了一個(gè)輸入為 6 通道、輸出為 4 通道的卷積層。此外,本文將 M-Net 的預(yù)測(cè)作為最終的輸出,以及訓(xùn)練時(shí)不再求合成圖像的損失(以下模型實(shí)現(xiàn)時(shí),loss 函數(shù)是支持合成圖像損失的)。

????????總的來(lái)說(shuō),Deep Image Matting (或 M-Net)網(wǎng)絡(luò)是非常清晰明了的,實(shí)現(xiàn)也很簡(jiǎn)單,模型文件 model.py 如下:

# -*- coding: utf-8 -*-
"""
Created on Sun Jul 21 07:08:58 2019

@author: shirhe-lyh

Implementation of paper:
    Deep Image Matting, Ning Xu, eta., arxiv:1703.03872
"""

import torch
import torchvision as tv

VGG16_BN_MODEL_URL = 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'

VGG16_BN_CONFIGS = {
    '13conv':
        [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 
         'M', 512, 512, 512],
    '10conv':
        [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512]
    }


def make_layers(cfg, batch_norm=False):
    """Copy from: torchvision/models/vgg.
    
    Changs retrue_indices in MaxPool2d from False to True.
    """
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [torch.nn.MaxPool2d(kernel_size=2, stride=2, 
                                          return_indices=True)]
        else:
            conv2d = torch.nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, torch.nn.BatchNorm2d(v), 
                           torch.nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, torch.nn.ReLU(inplace=True)]
            in_channels = v
    return torch.nn.Sequential(*layers)


class VGGFeatureExtractor(torch.nn.Module):
    """Feature extractor by VGG network."""
    
    def __init__(self, config=None, batch_norm=True):
        """Constructor.
        
        Args:
            config: The convolutional architecture of VGG network.
            batch_norm: A boolean indicating whether the architecture 
                include Batch Normalization layers or not.
        """
        super(VGGFeatureExtractor, self).__init__()
        self._config = config
        if self._config is None:
            self._config = VGG16_BN_CONFIGS.get('10conv')
        self.features = make_layers(self._config, batch_norm=batch_norm)
        self._indices = None
        
    def forward(self, x):
        self._indices = []
        for layer in self.features:
            if isinstance(layer, torch.nn.modules.pooling.MaxPool2d):
                x, indices = layer(x)
                self._indices.append(indices)
            else:
                x = layer(x)
        return x
    
    
def vgg16_bn_feature_extractor(config=None, pretrained=True, progress=True):
    model = VGGFeatureExtractor(config, batch_norm=True)
    if pretrained:
        state_dict = tv.models.utils.load_state_dict_from_url(
            VGG16_BN_MODEL_URL, progress=progress)
        model.load_state_dict(state_dict, strict=False)
    return model


class DIM(torch.nn.Module):
    """Deep Image Matting."""
    
    def __init__(self, feature_extractor):
        """Constructor.
        
        Args:
            feature_extractor: Feature extractor, such as VGGFeatureExtractor.
        """
        super(DIM, self).__init__()
        # Head convolution layer, number of channels: 4 -> 3
        self._head_conv = torch.nn.Conv2d(in_channels=4, out_channels=3,
                                          kernel_size=5, padding=2)
        # Encoder
        self._feature_extractor = feature_extractor
        self._feature_extract_config = self._feature_extractor._config
        # Decoder
        self._decode_layers = self.decode_layers()
        # Prediction
        self._final_conv = torch.nn.Conv2d(self._feature_extract_config[0], 1,
                                           kernel_size=5, padding=2)
        self._sigmoid = torch.nn.Sigmoid()
        
    def forward(self, x):
        x = self._head_conv(x)
        x = self._feature_extractor(x)
        indices = self._feature_extractor._indices[::-1]
        index = 0
        for layer in self._decode_layers:
            if isinstance(layer, torch.nn.modules.pooling.MaxUnpool2d):
                x = layer(x, indices[index])
                index += 1
            else:
                x = layer(x)
        x = self._final_conv(x)
        x = self._sigmoid(x)
        return x
    
    def decode_layers(self):
        layers = []
        strides = [1]
        channels = []
        config_reversed = self._feature_extract_config[::-1]
        for i, v in enumerate(config_reversed):
            if v == 'M':
                strides.append(2)
                channels.append(config_reversed[i+1])
        channels.append(channels[-1])
        in_channels = self._feature_extract_config[-1]
        for stride, out_channels in zip(strides, channels):
            if stride == 2:
                layers += [torch.nn.MaxUnpool2d(kernel_size=2, stride=2)]
            layers += [torch.nn.Conv2d(in_channels, out_channels,
                                       kernel_size=5, padding=2),
                       torch.nn.BatchNorm2d(num_features=out_channels),
                       torch.nn.ReLU(inplace=True)]
            in_channels = out_channels
        return torch.nn.Sequential(*layers)
    
    
def loss(alphas_pred, alphas_gt, images=None, epsilon=1e-12):
    losses = torch.sqrt(
        torch.mul(alphas_pred - alphas_gt, alphas_pred - alphas_gt) + 
        epsilon)
    loss = torch.mean(losses)
    if images is not None:
        images_fg_gt = torch.mul(images, alphas_gt)
        images_fg_pred = torch.mul(images, alphas_pred)
        images_fg_error = images_fg_pred - images_fg_gt
        losses_image = torch.sqrt(
            torch.mul(images_fg_error, images_fg_error) + epsilon)
        loss += torch.mean(losses_image)
    return loss

????????Pytorch 的官方是帶有 VGG 系列模型的,使用也很方便,比如使用帶批標(biāo)準(zhǔn)化層的 VGG16 直接寫(xiě)為:

vgg = torchvision.models.vgg16_bn(pretrained=True)

其中 pretrained=True 表示導(dǎo)入在 ImageNet 上預(yù)訓(xùn)練的參數(shù)。但因?yàn)?,我們是只使?VGG16 的前 13 個(gè)卷積層,而不需要后面的全連接層,因此,不會(huì)像上面那樣直接使用,而是要從 torchvision 的官方實(shí)現(xiàn)中截取卷積層的部分(官方實(shí)現(xiàn)見(jiàn)文件 Python 安裝路徑下的 site-packages/torchvision/models/vgg.py)。我們主要復(fù)制該文件中的 make_layers 函數(shù),但因?yàn)楹竺娼獯a器階段要用反池化,所以還要做一些修改:要把

layers += [nn.MaxPool2d(kernel_size=2, stride=2)]

改為

layers += [torch.nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)]

之所以要加上 return_indices=True,是因?yàn)楹竺娣闯鼗瘜右玫竭@些池化層的池化過(guò)程中的最大值的下標(biāo)(從而要記下來(lái))。也正因?yàn)槌鼗瘜佣喾祷亓艘粋€(gè)值(同時(shí)返回特征映射和最大值下標(biāo)張量),因此在重載 forward 函數(shù)時(shí)要進(jìn)行如下的區(qū)別對(duì)待:

def forward(self, x):
    self._indices = []
    for layer in self.features:
        if isinstance(layer, torch.nn.modules.pooling.MaxPool2d):
            x, indices = layer(x)
            self._indices.append(indices)
        else:
            x = layer(x)
    return x

除此之外,編碼器階段都是非常簡(jiǎn)單的,無(wú)需贅言。

????????來(lái)看解碼器階段。只需要重點(diǎn)關(guān)注一下反池化層(接后續(xù)的卷積層)的實(shí)現(xiàn)即可。具體也很簡(jiǎn)單:

unpool = torch.nn.MaxUnpool2d(kernel_size=2, stride=2)
conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=5, padding=2)

即,先用反池化操作提升 2 倍分辨率,然后再接一個(gè)普通卷積層(可選操作:批標(biāo)準(zhǔn)化、整流線性單元)。實(shí)際前向傳播時(shí)的計(jì)算如下:

x = unpool(x, indices)
x = conv(x)

其中,indices 是編碼器階段對(duì)應(yīng)的池化層返回的最大池化操作返回的最大值下標(biāo)。反池化層與轉(zhuǎn)置卷積層都是可以訓(xùn)練的(都帶有參數(shù)),作用也幾乎相同(提升分辨率),但對(duì)于摳圖這個(gè)任務(wù)來(lái)說(shuō),最關(guān)心的就是目標(biāo)的邊界區(qū)域,而這些邊界因?yàn)槎际乔熬?、背景的交界區(qū),因此表現(xiàn)在特征映射的響應(yīng)上,就基本都是局部極大值,從而在池化操作時(shí),返回的最大值下標(biāo)就基本完整的記錄下了待摳目標(biāo)的邊界,反池化操作因?yàn)闀?huì)重點(diǎn)關(guān)注這些區(qū)域,所以效果較好。

????????DIM 類的 decode_layers 函數(shù)就是整個(gè)的解碼器的定義。它看起來(lái)有點(diǎn)晦澀,其實(shí)不難理解。我們看編碼器階段的配置:

[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512]

其中的 M 表示最大池化層,其它的數(shù)字就是全部的 13 個(gè)卷積層對(duì)應(yīng)的輸出通道數(shù)。解碼器執(zhí)行的操作基本上就是以上操作的逆過(guò)程:

[512, 'U', 512, 'U', 256, 'U', 128, 'U', 64]

其中的 U 表示 torch.nn.MaxUnpool2d 反池化操作。

二、訓(xùn)練過(guò)程

????????本節(jié)訓(xùn)練數(shù)據(jù)來(lái)源于愛(ài)分割開(kāi)源的 數(shù)據(jù)集。該數(shù)據(jù)集內(nèi)的所有 34426 張圖片都類似于如下的上身模特圖:

愛(ài)分割開(kāi)源數(shù)據(jù)集實(shí)例圖片

可以明顯看到標(biāo)注的 alpha 通道是非常粗陋的,遠(yuǎn)達(dá)不到頭發(fā)絲的精度。因?yàn)檫@個(gè)數(shù)據(jù)集直接給出了原始圖像,所以不需要前景、背景圖片的合成。

數(shù)據(jù)準(zhǔn)備

????????當(dāng)你下載好愛(ài)分割開(kāi)源數(shù)據(jù)集(并解壓)之后,我們需要一次性將所有圖片的掩碼(mask)都準(zhǔn)備好,因此你需要打開(kāi) data/retrieve.py 文件,將 root_dir 改成你的 Matting_Human_Half 文件夾的路徑,然后執(zhí)行 retrieve.py 等待生成所有圖片的 alpha 和 mask(在 Matting_Human_Half 文件夾內(nèi)),以及用于訓(xùn)練的 train.txtval.txt(在 data 文件夾內(nèi),其中默認(rèn)隨機(jī)選擇 100 張圖像用于驗(yàn)證)。假如,你訓(xùn)練時(shí)不再改動(dòng) Matting_Human_Half 文件夾的路徑,那么你不需要再做其它處理了。如果你訓(xùn)練時(shí),Matting_Human_Half 與以上制作 train.txt 和 val.txt 時(shí)指定的 root_dir 路徑不一致了,那么你可以使用諸如 Notepad ++ 之類的工具,將 root_dir 替換為空,形成如下的形式:

去掉 root_dir 的標(biāo)注文件

????????train.txt 和 val.txt 分別記錄了訓(xùn)練和驗(yàn)證圖像的路徑,每一行對(duì)應(yīng)一張圖像的 4 個(gè)路徑,分別是 原圖像路徑(3 通道)、透明圖路徑(4 通道)、alpha 通道圖像路徑、mask 路徑,它們通過(guò) @ 符號(hào)分隔。

訓(xùn)練

????????直接在命令行執(zhí)行:

python3 train.py --root_dir "xxx/Matting_Human_Half" [--gpu_indices 0 1 ...]

開(kāi)始訓(xùn)練,如果你從制作數(shù)據(jù)時(shí)開(kāi)始, Matting_Human_Half 這個(gè)文件夾的路徑始終沒(méi)有改動(dòng)過(guò),那么 root_dir 這個(gè)參數(shù)也可以不指定(指定也無(wú)妨)。后面的 [--gpu_indices ...] 表示需要根據(jù)實(shí)際情況,可選的指定可用的 GPU 下標(biāo),這里默認(rèn)是使用 0,1,2,3 共 4 塊 GPU,如果你使用一塊 GPU,則指定

--gpu_indices 0

如果使用多塊,比如使用 第 1 塊和第 3 塊 GPU,則指定

--gpu_indices 1 3

即可。其它類似。訓(xùn)練過(guò)程中的所有超參數(shù)都在 train.py 文件的開(kāi)頭部分,可以直接修改默認(rèn)值或通過(guò)命令行指定。

????????訓(xùn)練開(kāi)始幾分鐘后,你在項(xiàng)目路徑下執(zhí)行:

tensorboard --logdir ./models/logs

可以打開(kāi)瀏覽器查看訓(xùn)練的學(xué)習(xí)率、損失曲線,和訓(xùn)練過(guò)程中的分割結(jié)果圖像。這里使用的是 Pytorch 自帶的類:from torch.utils.tensorboard import SummaryWriter 來(lái)調(diào)用 tensorboard,因此需要 Pytorch 1.1.0 以及之后的版本才可以。(但好像瀏覽器刷新不了新結(jié)果,需要不斷重開(kāi) tensorboard 才可以觀看訓(xùn)練進(jìn)展)

????????訓(xùn)練結(jié)束后(默認(rèn)訓(xùn)練 30 個(gè) epoch),在 models 文件夾中保存了訓(xùn)練過(guò)程中的模型參數(shù)文件(模型使用參考 predict.py)。直接執(zhí)行:

python3 predict.py

將在 test 文件夾里生成測(cè)試圖片的摳圖結(jié)果。

其它數(shù)據(jù)集上訓(xùn)練

????????訓(xùn)練 Pytorch 模型時(shí),需要重載 torch.utils.data.Dataset 類,用來(lái)提供數(shù)據(jù)的批量生成。重載時(shí),只需要實(shí)現(xiàn) __ init __, __ getitem __, __ len __ 這三個(gè)函數(shù)。在這個(gè)項(xiàng)目里,我們使用的是 dataset.py 的重載類 MattingDataset。讀者可以按照自己的方式依據(jù)自己的標(biāo)注格式來(lái)重載,也可以依照 MattingDataset 來(lái)改寫(xiě)。

????????對(duì)于只提供前、背景分離的數(shù)據(jù),建議先一次性提前合成好合成圖像,和制作好 alpha 通道圖像。此時(shí)你就可以適當(dāng)修改一下 get_image_mask_paths 函數(shù)即可。這個(gè)函數(shù)需要返回一個(gè)如下格式的列表

[[image_path, alpha_path],
 [image_path, alpha_path],
...
 [image_path, alpha_path]]

另外,__ getitem __ 函數(shù)數(shù)據(jù)增強(qiáng)的方式裁剪、縮放、水平翻轉(zhuǎn)(以及 alpha 通道隨機(jī)膨脹腐蝕),如果你還有其它的處理方式請(qǐng)自行添加或刪減。另外,這里指定的裁剪尺寸:

crop_sizes = [320, 480, 600, 800]

是根據(jù)愛(ài)分割提供的數(shù)據(jù)來(lái)劃定的,里面所有圖片都是 600x800 的分辨率。一般來(lái)說(shuō),根據(jù) Deep Image Matting 論文,是從 320 開(kāi)始,每隔 160 像素的尺寸裁剪,最后統(tǒng)一縮放到 320 即可。

三、Deep Image Matting 數(shù)據(jù)集上的復(fù)現(xiàn)

????????本節(jié)將在 Deep Image Matting 數(shù)據(jù)集上進(jìn)行訓(xùn)練(訓(xùn)練數(shù)據(jù)可聯(lián)系論文作者獲?。虿糠謪?shù)未仔細(xì)調(diào)整,訓(xùn)練結(jié)果并非最優(yōu)。使用的背景圖像集是 COCO/train2017。

數(shù)據(jù)準(zhǔn)備

????????我們將前景圖像背景圖像通過(guò) alpha 通道合成訓(xùn)練集。假設(shè)你已經(jīng)獲取了 DIM 數(shù)據(jù)集,那么進(jìn)入 data_dim 文件夾,打開(kāi) composition.py,將 root_dir 替換成你保持 Combined_Dataset 文件夾的路徑,bg_image_root_dir 填寫(xiě)背景圖像的文件夾路徑,output_dir 填寫(xiě)合成圖像的保持文件夾路徑。num_bg_images_per_fg 表示一張前景圖像對(duì)應(yīng)多少?gòu)埍尘皥D像,論文里這個(gè)值取 100(這里為了減少訓(xùn)練時(shí)間,我取的是 50,讀者根據(jù)具體情況修改)。當(dāng)這些值都確認(rèn)無(wú)誤后,執(zhí)行 composition.py,將花費(fèi)很長(zhǎng)一段時(shí)間來(lái)合成圖片。合成圖像結(jié)束后,在當(dāng)前路徑下生成 train.txtval.txt 兩個(gè)標(biāo)注文件。train.txt 文件里每一行對(duì)應(yīng) 4 個(gè)路徑,分別是合成圖像路徑、前景圖像路徑、alpha 通道路徑、背景圖像路徑,val.txt 里除了缺少合成圖像路徑之外,其它順序一致。

????????訓(xùn)練時(shí),我們只需要 train.txt 中的第 1 個(gè)和第 3 個(gè)路徑,因此和訓(xùn)練愛(ài)分割開(kāi)源數(shù)據(jù)集的標(biāo)注文件格式一致,從而可以共用同一個(gè) dataset.py,只需要確保 dataset.py 里面的函數(shù) __ getitem __ 中的

crop_sizes = [320, 480, 640]

即可。

訓(xùn)練過(guò)程

????????執(zhí)行

python3 train.py --annotation_path "./data_dim/train.txt" [gpu_indices 0 1 2 ...]

開(kāi)始訓(xùn)練,期間可通過(guò) tensorboard 查看訓(xùn)練進(jìn)程。訓(xùn)練的所有超參數(shù)可在 train.py 內(nèi)修改,也可以通過(guò)命令行直接指定。

訓(xùn)練期間的 tensorboard 摳圖結(jié)果展示
訓(xùn)練期間的 tensorboard 學(xué)習(xí)率和損失曲線

????????本次訓(xùn)練,超參數(shù)采用的都是 train.py 中的默認(rèn)值,由損失曲線可以看到,如果再繼續(xù)訓(xùn)練(已訓(xùn)練 200 epoch),損失會(huì)進(jìn)一步下降,摳圖效果會(huì)更好。

結(jié)果展示

????????訓(xùn)練結(jié)束后,執(zhí)行(如果模型保存路徑是默認(rèn)的 ./models,否則需要修改一下 ckpt_path):

python3 predict_trimap.py

會(huì)在 data_dim/test 文件夾里生成預(yù)測(cè)結(jié)果(見(jiàn)文件夾 preds):

合成圖片、trimap、摳圖結(jié)果、GT alpha
合成圖片、trimap、摳圖結(jié)果、GT alpha

????????從以上摳圖結(jié)果可看出,在某些細(xì)節(jié)上效果還不理想。如果你想獲得更好的結(jié)果,一方面可以在合成訓(xùn)練圖片時(shí)提升 1 張前景對(duì)應(yīng)的背景圖片數(shù)(我取的是 1:50,論文是 1:100);其次,仔細(xì)調(diào)整學(xué)習(xí)率及其衰減;再次,增加訓(xùn)練的輪次(epoch,我訓(xùn)練了 200 個(gè) epoch)等。

附錄
快速上手 Pytorch 的資料:PyTorch Tutorial for Deep Learning Researchers。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容