前言
先前筆者在使用 TensorFlow Object Detection API Mask R-CNN 訓(xùn)練自定義圖像分割模型介紹了如何使用 Mask R-CNN 實現(xiàn)圖像分割。不過 Mask R-CNN 網(wǎng)絡(luò)較為復(fù)雜,性能消耗較高,在實際的運(yùn)用中,如果不是很復(fù)雜的分割任務(wù),還有一個較為合適的選擇,那就是本文要講解的 U-NET 。
介紹 U-NET 的文章很多,不過從自定義數(shù)據(jù)集到模型定義、訓(xùn)練、預(yù)測的文章卻寥寥無幾。因此,本文旨在通過 一個 Demo 來覆蓋各個步驟,讓大家快速掌握 U-NET 。

環(huán)境搭建
下載源代碼 https://github.com/CatchZeng/tensorflow-unet-labelme 到本地,并進(jìn)入該目錄下。
如果你使用的是 macOS, 你需要在安裝前先執(zhí)行以下命令。
? brew install pyqt
執(zhí)行以下命令,安裝虛擬環(huán)境 unet。
? conda create -n unet -y python=3.9 && conda activate unet && pip install -r requirements.txt
數(shù)據(jù)集
文本還是以使用 TensorFlow Object Detection API Mask R-CNN 訓(xùn)練自定義圖像分割模型 中的茶杯(cup)、茶壺(teapot)、加濕器(humidifier) 來做案例。
數(shù)據(jù)標(biāo)注
數(shù)據(jù)標(biāo)注已在使用 TensorFlow Object Detection API Mask R-CNN 訓(xùn)練自定義圖像分割模型 闡述,這里就不再贅述。
生成 VOC 數(shù)據(jù)集
U-NET 的數(shù)據(jù)集包含原始圖像(jpg) 和標(biāo)簽(mask)圖像(png),通常使用 VOC 格式來整理數(shù)據(jù)集。
將標(biāo)注好的數(shù)據(jù)存放到 tensorflow-unet-labelme 的 datasets/train 下,并新建 datasets/labels.txt,內(nèi)容為分類名稱,詳見 https://github.com/wkentaro/labelme/tree/main/examples/semantic_segmentation。
? tree -L 3
.
├── Makefile
├── README.md
├── datasets
│ ├── README.md
│ ├── labels.txt
│ ├── train
│ │ ├── 1.jpg
│ │ ├── 1.json
......
│ │ ├── 9.jpg
│ │ └── 9.json
├── labelme2voc.py
├── train.gif
├── unet.ipynb
└── voc_annotation.py
注:可以參考 https://github.com/CatchZeng/tensorflow-unet-labelme/tree/master/datasets。
執(zhí)行以下命令,生成 VOC 格式數(shù)據(jù)集。
? make voc
? tree -L 5
.
├── Makefile
├── README.md
├── datasets
│ ├── README.md
│ ├── labels.txt
│ ├── test
│ │ └── 1.jpg
│ ├── train
│ │ ├── 1.jpg
│ │ ├── 1.json
......
│ │ ├── 9.jpg
│ │ └── 9.json
│ └── train_voc
│ ├── ImageSets
│ │ └── Segmentation
│ │ ├── test.txt
│ │ ├── train.txt # 訓(xùn)練集圖像名稱列表
│ │ ├── trainval.txt
│ │ └── val.txt # 驗證集圖像名稱列表
│ ├── JPEGImages # 原圖
│ │ ├── 1.jpg
......
│ │ └── 9.jpg
│ ├── SegmentationClass
│ │ ├── 1.npy
......
│ │ └── 9.npy
│ ├── SegmentationClassPNG # 標(biāo)簽(mask)圖
│ │ ├── 1.png
......
│ │ └── 9.png
│ ├── SegmentationClassVisualization
│ │ ├── 1.jpg
......
│ │ └── 9.jpg
│ └── class_names.txt
├── labelme2voc.py
├── train.gif
├── unet.ipynb
└── voc_annotation.py
生成的 datasets/train_voc 便是訓(xùn)練用到的數(shù)據(jù)集。
訓(xùn)練
打開 unet.ipynb,并選擇 Python 解釋器為 unet,即可開始訓(xùn)練。
代碼詳解
數(shù)據(jù)集分為訓(xùn)練集和驗證集,分別從 ImageSets/Segmentation/train.txt 和 ImageSets/Segmentation/val.txt 讀取文件。然后,通過 UnetDataset 類構(gòu)建成為 tf.keras.utils.Sequence 對象,方便后面通過 model.fit 直接訓(xùn)練。
dataset_path = 'datasets/train_voc'
# read dataset txt files
with open(os.path.join(dataset_path, "ImageSets/Segmentation/train.txt"),
"r",
encoding="utf8") as f:
train_lines = f.readlines()
with open(os.path.join(dataset_path, "ImageSets/Segmentation/val.txt"),
"r",
encoding="utf8") as f:
val_lines = f.readlines()
train_batches = UnetDataset(train_lines, INPUT_SHAPE, BATCH_SIZE, NUM_CLASSES,
True, dataset_path)
val_batches = UnetDataset(val_lines, INPUT_SHAPE, BATCH_SIZE, NUM_CLASSES,
False, dataset_path)
STEPS_PER_EPOCH = len(train_lines) // BATCH_SIZE
VALIDATION_STEPS = len(val_lines) // BATCH_SIZE // VAL_SUBSPLITS
UnetDataset 類繼承自 tf.keras.utils.Sequence 。通過 __getitem__ 方法返回一組 batch_size 的數(shù)據(jù),其中包含原圖(images)和標(biāo)簽圖(targets)。因為模型有固定的 input shape,因此,在 process_data 方法中做了 resize 操作;在訓(xùn)練過程中,還可以加入數(shù)據(jù)增強(qiáng),這里使用了一個簡單的 flip;
class UnetDataset(tf.keras.utils.Sequence):
def __init__(self, annotation_lines, input_shape, batch_size, num_classes,
train, dataset_path):
self.annotation_lines = annotation_lines
self.length = len(self.annotation_lines)
self.input_shape = input_shape
self.batch_size = batch_size
self.num_classes = num_classes
self.train = train
self.dataset_path = dataset_path
def __len__(self):
return math.ceil(len(self.annotation_lines) / float(self.batch_size))
def __getitem__(self, index):
images = []
targets = []
for i in range(index * self.batch_size, (index + 1) * self.batch_size):
i = i % self.length
name = self.annotation_lines[i].split()[0]
jpg = Image.open(
os.path.join(os.path.join(self.dataset_path, "JPEGImages"),
name + ".jpg"))
png = Image.open(
os.path.join(
os.path.join(self.dataset_path, "SegmentationClassPNG"),
name + ".png"))
jpg, png = self.process_data(jpg,
png,
self.input_shape,
random=self.train)
images.append(jpg)
targets.append(png)
images = np.array(images)
targets = np.array(targets)
return images, targets
def rand(self, a=0, b=1):
return np.random.rand() * (b - a) + a
def process_data(self, image, label, input_shape, random=True):
image = cvtColor(image)
label = Image.fromarray(np.array(label))
h, w, _ = input_shape
# resize
image, _, _ = resize_image(image, (w, h))
label, _, _ = resize_label(label, (w, h))
if random:
# flip
flip = self.rand() < .5
if flip:
image = image.transpose(Image.FLIP_LEFT_RIGHT)
label = label.transpose(Image.FLIP_LEFT_RIGHT)
# np
image = np.array(image, np.float32)
image = normalize(image)
label = np.array(label)
label[label >= self.num_classes] = self.num_classes
return image, label
模型定義部分,比較簡單,跟論文中一樣,主要為下采樣,上采樣,和 concat

這里,筆者參考了 https://www.tensorflow.org/tutorials/images/segmentation ,詳細(xì)的解析大家可以查看。
def unet_model(output_channels: int):
inputs = tf.keras.layers.Input(shape=INPUT_SHAPE)
# Downsampling through the model
skips = down_stack(inputs)
x = skips[-1]
skips = reversed(skips[:-1])
# Upsampling and establishing the skip connections
for up, skip in zip(up_stack, skips):
x = up(x)
concat = tf.keras.layers.Concatenate()
x = concat([x, skip])
# This is the last layer of the model
last = tf.keras.layers.Conv2DTranspose(filters=output_channels,
kernel_size=3,
strides=2,
padding='same') #64x64 -> 128x128
x = last(x)
return tf.keras.Model(inputs=inputs, outputs=x)
Callback 部分,筆者主要使用了 DisplayCallback 和 ModelCheckpointCallback。
DisplayCallback 用于訓(xùn)練完一個 epoch 后,顯示預(yù)測的結(jié)果,便于觀測模型的效果。

ModelCheckpointCallback 用于訓(xùn)練完一個 epoch 后,在 logs 文件夾下 保存權(quán)值(模型),并且記錄每個 epoch 后的準(zhǔn)確率和損失率。這樣,用戶在訓(xùn)練完后,可以挑選效果比較好的模型。
class DisplayCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs=None):
clear_output(wait=True)
show_predictions()
print('\nSample Prediction after epoch {}\n'.format(epoch + 1))
class ModelCheckpointCallback(tf.keras.callbacks.Callback):
def __init__(self,
filepath,
monitor='val_loss',
verbose=0,
save_best_only=False,
save_weights_only=False,
mode='auto',
period=1):
super(ModelCheckpointCallback, self).__init__()
self.monitor = monitor
self.verbose = verbose
self.filepath = filepath
self.save_best_only = save_best_only
self.save_weights_only = save_weights_only
self.period = period
self.epochs_since_last_save = 0
if mode not in ['auto', 'min', 'max']:
warnings.warn(
'ModelCheckpoint mode %s is unknown, '
'fallback to auto mode.' % (mode), RuntimeWarning)
mode = 'auto'
if mode == 'min':
self.monitor_op = np.less
self.best = np.Inf
elif mode == 'max':
self.monitor_op = np.greater
self.best = -np.Inf
else:
if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
self.monitor_op = np.greater
self.best = -np.Inf
else:
self.monitor_op = np.less
self.best = np.Inf
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
self.epochs_since_last_save += 1
if self.epochs_since_last_save >= self.period:
self.epochs_since_last_save = 0
filepath = self.filepath.format(epoch=epoch + 1, **logs)
if self.save_best_only:
current = logs.get(self.monitor)
if current is None:
warnings.warn(
'Can save best model only with %s available, '
'skipping.' % (self.monitor), RuntimeWarning)
else:
if self.monitor_op(current, self.best):
if self.verbose > 0:
print(
'\nEpoch %05d: %s improved from %0.5f to %0.5f,'
' saving model to %s' %
(epoch + 1, self.monitor, self.best, current,
filepath))
self.best = current
if self.save_weights_only:
self.model.save_weights(filepath, overwrite=True)
else:
self.model.save(filepath, overwrite=True)
else:
if self.verbose > 0:
print('\nEpoch %05d: %s did not improve' %
(epoch + 1, self.monitor))
else:
if self.verbose > 0:
print('\nEpoch %05d: saving model to %s' %
(epoch + 1, filepath))
if self.save_weights_only:
self.model.save_weights(filepath, overwrite=True)
else:
self.model.save(filepath, overwrite=True)
預(yù)測模型部分,先將文件 resize 到模型的 INPUT_SHAPE 大小。這里需要注意的是,預(yù)測的圖不一定是與 INPUT_SHAPE 比例相等的。為了不因為比例問題,導(dǎo)致預(yù)測結(jié)果不準(zhǔn)確,筆者這里在 resize 的時候為圖片不在比例的地方,添加了灰色占位邊,如下圖所示。

然后再預(yù)測完之后,再去掉灰邊。

注:本案例效果圖因為是 Demo,所以只是訓(xùn)練了 20 幾個 epoch,準(zhǔn)確度是一般的,大家在實際應(yīng)用中,可以多訓(xùn)練下,提高準(zhǔn)確度。
def detect_image(image_path):
image = Image.open(image_path)
image = cvtColor(image)
old_img = copy.deepcopy(image)
ori_h = np.array(image).shape[0]
ori_w = np.array(image).shape[1]
# resize 并添加灰邊
image_data, nw, nh = resize_image(image, (INPUT_SHAPE[1], INPUT_SHAPE[0]))
image_data = normalize(np.array(image_data, np.float32))
image_data = np.expand_dims(image_data, 0)
pr = model.predict(image_data)[0]
## 去掉灰邊
pr = pr[int((INPUT_SHAPE[0] - nh) // 2) : int((INPUT_SHAPE[0] - nh) // 2 + nh), \
int((INPUT_SHAPE[1] - nw) // 2) : int((INPUT_SHAPE[1] - nw) // 2 + nw)]
pr = cv2.resize(pr, (ori_w, ori_h), interpolation=cv2.INTER_LINEAR)
pr = pr.argmax(axis=-1)
# seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3))
# for c in range(NUM_CLASSES):
# seg_img[:, :, 0] += ((pr[:, :] == c ) * colors[c][0]).astype('uint8')
# seg_img[:, :, 1] += ((pr[:, :] == c ) * colors[c][1]).astype('uint8')
# seg_img[:, :, 2] += ((pr[:, :] == c ) * colors[c][2]).astype('uint8')
seg_img = np.reshape(
np.array(colors, np.uint8)[np.reshape(pr, [-1])], [ori_h, ori_w, -1])
image = Image.fromarray(seg_img)
image = Image.blend(old_img, image, 0.7)
return image
小結(jié)
本文,通過一個 Demo 介紹了,如何從制作數(shù)據(jù)集到訓(xùn)練 U-NET 模型并預(yù)測圖片整個流程。大家可以自己找一個場景,制作一個自定義的數(shù)據(jù)集,然后實踐一遍,以便更好地掌握。本文就到這里了,咱們下一篇見。