易 AI - 使用 TensorFlow 和 Labelme 訓(xùn)練自定義 U-NET 圖像分割模型

原文:https://makeoptim.com/deep-learning/yiai-unet

前言

先前筆者在使用 TensorFlow Object Detection API Mask R-CNN 訓(xùn)練自定義圖像分割模型介紹了如何使用 Mask R-CNN 實現(xiàn)圖像分割。不過 Mask R-CNN 網(wǎng)絡(luò)較為復(fù)雜,性能消耗較高,在實際的運(yùn)用中,如果不是很復(fù)雜的分割任務(wù),還有一個較為合適的選擇,那就是本文要講解的 U-NET 。

介紹 U-NET 的文章很多,不過從自定義數(shù)據(jù)集模型定義、訓(xùn)練預(yù)測的文章卻寥寥無幾。因此,本文旨在通過 一個 Demo 來覆蓋各個步驟,讓大家快速掌握 U-NET

環(huán)境搭建

下載源代碼 https://github.com/CatchZeng/tensorflow-unet-labelme 到本地,并進(jìn)入該目錄下。

如果你使用的是 macOS, 你需要在安裝前先執(zhí)行以下命令。

? brew install pyqt

執(zhí)行以下命令,安裝虛擬環(huán)境 unet。

? conda create -n unet -y python=3.9 && conda activate unet && pip install -r requirements.txt

數(shù)據(jù)集

文本還是以使用 TensorFlow Object Detection API Mask R-CNN 訓(xùn)練自定義圖像分割模型 中的茶杯(cup)、茶壺(teapot)、加濕器(humidifier) 來做案例。

數(shù)據(jù)標(biāo)注

數(shù)據(jù)標(biāo)注已在使用 TensorFlow Object Detection API Mask R-CNN 訓(xùn)練自定義圖像分割模型 闡述,這里就不再贅述。

生成 VOC 數(shù)據(jù)集

U-NET 的數(shù)據(jù)集包含原始圖像(jpg) 和標(biāo)簽(mask)圖像(png),通常使用 VOC 格式來整理數(shù)據(jù)集。

將標(biāo)注好的數(shù)據(jù)存放到 tensorflow-unet-labelmedatasets/train 下,并新建 datasets/labels.txt,內(nèi)容為分類名稱,詳見 https://github.com/wkentaro/labelme/tree/main/examples/semantic_segmentation

? tree -L 3
.
├── Makefile
├── README.md
├── datasets
│   ├── README.md
│   ├── labels.txt
│   ├── train
│   │   ├── 1.jpg
│   │   ├── 1.json
......
│   │   ├── 9.jpg
│   │   └── 9.json
├── labelme2voc.py
├── train.gif
├── unet.ipynb
└── voc_annotation.py

注:可以參考 https://github.com/CatchZeng/tensorflow-unet-labelme/tree/master/datasets。

執(zhí)行以下命令,生成 VOC 格式數(shù)據(jù)集。

? make voc
? tree -L 5
.
├── Makefile
├── README.md
├── datasets
│   ├── README.md
│   ├── labels.txt
│   ├── test
│   │   └── 1.jpg
│   ├── train
│   │   ├── 1.jpg
│   │   ├── 1.json
......
│   │   ├── 9.jpg
│   │   └── 9.json
│   └── train_voc
│       ├── ImageSets
│       │   └── Segmentation
│       │       ├── test.txt
│       │       ├── train.txt   # 訓(xùn)練集圖像名稱列表
│       │       ├── trainval.txt
│       │       └── val.txt  # 驗證集圖像名稱列表
│       ├── JPEGImages # 原圖
│       │   ├── 1.jpg
......
│       │   └── 9.jpg
│       ├── SegmentationClass
│       │   ├── 1.npy
......
│       │   └── 9.npy
│       ├── SegmentationClassPNG # 標(biāo)簽(mask)圖
│       │   ├── 1.png
......
│       │   └── 9.png
│       ├── SegmentationClassVisualization
│       │   ├── 1.jpg
......
│       │   └── 9.jpg
│       └── class_names.txt
├── labelme2voc.py
├── train.gif
├── unet.ipynb
└── voc_annotation.py

生成的 datasets/train_voc 便是訓(xùn)練用到的數(shù)據(jù)集。

訓(xùn)練

打開 unet.ipynb,并選擇 Python 解釋器為 unet,即可開始訓(xùn)練。

代碼詳解

數(shù)據(jù)集分為訓(xùn)練集驗證集,分別從 ImageSets/Segmentation/train.txtImageSets/Segmentation/val.txt 讀取文件。然后,通過 UnetDataset 類構(gòu)建成為 tf.keras.utils.Sequence 對象,方便后面通過 model.fit 直接訓(xùn)練。

dataset_path = 'datasets/train_voc'

# read dataset txt files
with open(os.path.join(dataset_path, "ImageSets/Segmentation/train.txt"),
          "r",
          encoding="utf8") as f:
    train_lines = f.readlines()

with open(os.path.join(dataset_path, "ImageSets/Segmentation/val.txt"),
          "r",
          encoding="utf8") as f:
    val_lines = f.readlines()

train_batches = UnetDataset(train_lines, INPUT_SHAPE, BATCH_SIZE, NUM_CLASSES,
                            True, dataset_path)
val_batches = UnetDataset(val_lines, INPUT_SHAPE, BATCH_SIZE, NUM_CLASSES,
                          False, dataset_path)

STEPS_PER_EPOCH = len(train_lines) // BATCH_SIZE
VALIDATION_STEPS = len(val_lines) // BATCH_SIZE // VAL_SUBSPLITS

UnetDataset 類繼承自 tf.keras.utils.Sequence 。通過 __getitem__ 方法返回一組 batch_size 的數(shù)據(jù),其中包含原圖(images)和標(biāo)簽圖(targets)。因為模型有固定的 input shape,因此,在 process_data 方法中做了 resize 操作;在訓(xùn)練過程中,還可以加入數(shù)據(jù)增強(qiáng),這里使用了一個簡單的 flip

class UnetDataset(tf.keras.utils.Sequence):

    def __init__(self, annotation_lines, input_shape, batch_size, num_classes,
                 train, dataset_path):
        self.annotation_lines = annotation_lines
        self.length = len(self.annotation_lines)
        self.input_shape = input_shape
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.train = train
        self.dataset_path = dataset_path

    def __len__(self):
        return math.ceil(len(self.annotation_lines) / float(self.batch_size))

    def __getitem__(self, index):
        images = []
        targets = []
        for i in range(index * self.batch_size, (index + 1) * self.batch_size):
            i = i % self.length
            name = self.annotation_lines[i].split()[0]
            jpg = Image.open(
                os.path.join(os.path.join(self.dataset_path, "JPEGImages"),
                             name + ".jpg"))
            png = Image.open(
                os.path.join(
                    os.path.join(self.dataset_path, "SegmentationClassPNG"),
                    name + ".png"))

            jpg, png = self.process_data(jpg,
                                         png,
                                         self.input_shape,
                                         random=self.train)

            images.append(jpg)
            targets.append(png)

        images = np.array(images)
        targets = np.array(targets)
        return images, targets

    def rand(self, a=0, b=1):
        return np.random.rand() * (b - a) + a

    def process_data(self, image, label, input_shape, random=True):
        image = cvtColor(image)
        label = Image.fromarray(np.array(label))
        h, w, _ = input_shape

        # resize
        image, _, _ = resize_image(image, (w, h))
        label, _, _ = resize_label(label, (w, h))

        if random:
            # flip
            flip = self.rand() < .5
            if flip:
                image = image.transpose(Image.FLIP_LEFT_RIGHT)
                label = label.transpose(Image.FLIP_LEFT_RIGHT)

        # np
        image = np.array(image, np.float32)
        image = normalize(image)

        label = np.array(label)
        label[label >= self.num_classes] = self.num_classes

        return image, label

模型定義部分,比較簡單,跟論文中一樣,主要為下采樣,上采樣,和 concat

這里,筆者參考了 https://www.tensorflow.org/tutorials/images/segmentation ,詳細(xì)的解析大家可以查看。

def unet_model(output_channels: int):
    inputs = tf.keras.layers.Input(shape=INPUT_SHAPE)

    # Downsampling through the model
    skips = down_stack(inputs)
    x = skips[-1]
    skips = reversed(skips[:-1])

    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        concat = tf.keras.layers.Concatenate()
        x = concat([x, skip])

    # This is the last layer of the model
    last = tf.keras.layers.Conv2DTranspose(filters=output_channels,
                                           kernel_size=3,
                                           strides=2,
                                           padding='same')  #64x64 -> 128x128

    x = last(x)

    return tf.keras.Model(inputs=inputs, outputs=x)

Callback 部分,筆者主要使用了 DisplayCallbackModelCheckpointCallback。

DisplayCallback 用于訓(xùn)練完一個 epoch 后,顯示預(yù)測的結(jié)果,便于觀測模型的效果

ModelCheckpointCallback 用于訓(xùn)練完一個 epoch 后,在 logs 文件夾下 保存權(quán)值(模型),并且記錄每個 epoch 后的準(zhǔn)確率和損失率。這樣,用戶在訓(xùn)練完后,可以挑選效果比較好的模型。

class DisplayCallback(tf.keras.callbacks.Callback):

    def on_epoch_end(self, epoch, logs=None):
        clear_output(wait=True)
        show_predictions()
        print('\nSample Prediction after epoch {}\n'.format(epoch + 1))


class ModelCheckpointCallback(tf.keras.callbacks.Callback):

    def __init__(self,
                 filepath,
                 monitor='val_loss',
                 verbose=0,
                 save_best_only=False,
                 save_weights_only=False,
                 mode='auto',
                 period=1):
        super(ModelCheckpointCallback, self).__init__()
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.period = period
        self.epochs_since_last_save = 0

        if mode not in ['auto', 'min', 'max']:
            warnings.warn(
                'ModelCheckpoint mode %s is unknown, '
                'fallback to auto mode.' % (mode), RuntimeWarning)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
            self.best = np.Inf
        elif mode == 'max':
            self.monitor_op = np.greater
            self.best = -np.Inf
        else:
            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
                self.monitor_op = np.greater
                self.best = -np.Inf
            else:
                self.monitor_op = np.less
                self.best = np.Inf

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save >= self.period:
            self.epochs_since_last_save = 0
            filepath = self.filepath.format(epoch=epoch + 1, **logs)
            if self.save_best_only:
                current = logs.get(self.monitor)
                if current is None:
                    warnings.warn(
                        'Can save best model only with %s available, '
                        'skipping.' % (self.monitor), RuntimeWarning)
                else:
                    if self.monitor_op(current, self.best):
                        if self.verbose > 0:
                            print(
                                '\nEpoch %05d: %s improved from %0.5f to %0.5f,'
                                ' saving model to %s' %
                                (epoch + 1, self.monitor, self.best, current,
                                 filepath))
                        self.best = current
                        if self.save_weights_only:
                            self.model.save_weights(filepath, overwrite=True)
                        else:
                            self.model.save(filepath, overwrite=True)
                    else:
                        if self.verbose > 0:
                            print('\nEpoch %05d: %s did not improve' %
                                  (epoch + 1, self.monitor))
            else:
                if self.verbose > 0:
                    print('\nEpoch %05d: saving model to %s' %
                          (epoch + 1, filepath))
                if self.save_weights_only:
                    self.model.save_weights(filepath, overwrite=True)
                else:
                    self.model.save(filepath, overwrite=True)

預(yù)測模型部分,先將文件 resize 到模型的 INPUT_SHAPE 大小。這里需要注意的是,預(yù)測的圖不一定是與 INPUT_SHAPE 比例相等的。為了不因為比例問題,導(dǎo)致預(yù)測結(jié)果不準(zhǔn)確,筆者這里在 resize 的時候為圖片不在比例的地方,添加了灰色占位邊,如下圖所示。

然后再預(yù)測完之后,再去掉灰邊。

注:本案例效果圖因為是 Demo,所以只是訓(xùn)練了 20 幾個 epoch,準(zhǔn)確度是一般的,大家在實際應(yīng)用中,可以多訓(xùn)練下,提高準(zhǔn)確度。

def detect_image(image_path):
    image = Image.open(image_path)
    image = cvtColor(image)

    old_img = copy.deepcopy(image)
    ori_h = np.array(image).shape[0]
    ori_w = np.array(image).shape[1]

   # resize 并添加灰邊
    image_data, nw, nh = resize_image(image, (INPUT_SHAPE[1], INPUT_SHAPE[0]))

    image_data = normalize(np.array(image_data, np.float32))

    image_data = np.expand_dims(image_data, 0)

    pr = model.predict(image_data)[0]

   ## 去掉灰邊
    pr = pr[int((INPUT_SHAPE[0] - nh) // 2) : int((INPUT_SHAPE[0] - nh) // 2 + nh), \
            int((INPUT_SHAPE[1] - nw) // 2) : int((INPUT_SHAPE[1] - nw) // 2 + nw)]

    pr = cv2.resize(pr, (ori_w, ori_h), interpolation=cv2.INTER_LINEAR)

    pr = pr.argmax(axis=-1)

    # seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3))
    # for c in range(NUM_CLASSES):
    #     seg_img[:, :, 0] += ((pr[:, :] == c ) * colors[c][0]).astype('uint8')
    #     seg_img[:, :, 1] += ((pr[:, :] == c ) * colors[c][1]).astype('uint8')
    #     seg_img[:, :, 2] += ((pr[:, :] == c ) * colors[c][2]).astype('uint8')
    seg_img = np.reshape(
        np.array(colors, np.uint8)[np.reshape(pr, [-1])], [ori_h, ori_w, -1])

    image = Image.fromarray(seg_img)
    image = Image.blend(old_img, image, 0.7)

    return image

小結(jié)

本文,通過一個 Demo 介紹了,如何從制作數(shù)據(jù)集到訓(xùn)練 U-NET 模型并預(yù)測圖片整個流程。大家可以自己找一個場景,制作一個自定義的數(shù)據(jù)集,然后實踐一遍,以便更好地掌握。本文就到這里了,咱們下一篇見。

延伸閱讀

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容