【语义分割】Tensorflow deeplabv3+训练自己的数据集

时间:2024-04-03 13:38:12

一、制作语义分割数据集

按照【语义分割】用labelme制作语义分割数据集的方法制作训练数据集。

这里我就在网上下载了30张道路和车辆的图片。

【语义分割】Tensorflow deeplabv3+训练自己的数据集

制作好的label看起来是全黑的,但是其实是有值的,它的像素值就是类别标签,因为我这里只有background,road,car这三类,所以像素值只有0, 1, 2这三个数字,非常小,肉眼难以分辨,但是读到matlab可以看到它们的值。

【语义分割】Tensorflow deeplabv3+训练自己的数据集

然后需要生成训练集的train.txt和val.txt。

import os
import random

trainfilepath = 'train'
txtsavepath = 'txt'
train_file = os.listdir(trainfilepath)

num=len(train_file)
list = range(num)

train = random.sample(list, num)  

os.chdir(txtsavepath)   

ftrain = open('train.txt', 'w')  

for i in list :
  name =train_file[i][:-4] + '\n'
  ftrain.write(name)
ftrain.close()

按照同样的方式生成val.txt。我的训练集有22张图片,验证集5张,3张用做最后的测试。

按照PASCAL VOC 2012数据集的格式整理一下数据。

segtest #数据集文件夹名称
   + JPEGImages # 原始图片
   + Segmentation # 存放txt文件
      + train.txt
      + val.txt
   + SegmentationClassRaw #标注文件,8位的png图
   + tfrecord # 存放转换的tfrecord文件,用于训练
  

下面可以生成tfrecord文件了。

用models/research/deeplab/datasets/build_voc2012_data.py的来生成,因为我的图片比较少,所以将81行的_NUM_SHARDS = 2,默认是4。

python datasets/build_voc2012_data.py \
  --image_folder="/home/cc/dataset/segtest/JPEGImages" \
  --semantic_segmentation_folder="/home/cc/dataset/segtest/SegmentationClassRaw" \
  --list_folder="/home/cc/dataset/segtest/Segmentation" \
  --image_format="jpg" \
  --output_dir="/home/cc/dataset/segtest/tfrecord"

生成了tfrecord如下。

【语义分割】Tensorflow deeplabv3+训练自己的数据集

二、修改代码

首先要注册自己的数据集,在models/research/deeplab/datasets/segmentation_dataset.py中110行的地方添加如下内容。

_SEGTEST_INFORMATION = DatasetDescriptor(
    splits_to_sizes={
        'train': 22,  # num of samples in images/training
        'val': 5,  # num of samples in images/validation
    },
    num_classes=4,
    ignore_label=255,
)

这里的num_classes是background+road+car+ignore_label。ignore_label是不参与计算loss的。在mask中将ignore_label的灰度值标记为255

ignore_label就像下面这样

【语义分割】Tensorflow deeplabv3+训练自己的数据集

PASCAL VOC 2012的数据集有这种描边的ignore_label,我的数据集里并没有,但是不这样设置会出错。

在120行把自己的数据集注册进去。

_DATASETS_INFORMATION = {
    'cityscapes': _CITYSCAPES_INFORMATION,
    'pascal_voc_seg': _PASCAL_VOC_SEG_INFORMATION,
    'ade20k': _ADE20K_INFORMATION,
    'segtest': _SEGTEST_INFORMATION
}

三、训练

DATASET_DIR="/home/cc/dataset/segtest"


# Set up the working directories.
SEG_FOLDER="mysegtest"
EXP_FOLDER="exp/train_on_trainval_set"
INIT_FOLDER="${DATASET_DIR}/${SEG_FOLDER}/init_models"
TRAIN_LOGDIR="${DATASET_DIR}/${SEG_FOLDER}/${EXP_FOLDER}/train"
EVAL_LOGDIR="${DATASET_DIR}/${SEG_FOLDER}/${EXP_FOLDER}/eval"
VIS_LOGDIR="${DATASET_DIR}/${SEG_FOLDER}/${EXP_FOLDER}/vis"
EXPORT_DIR="${DATASET_DIR}/${SEG_FOLDER}/${EXP_FOLDER}/export"
mkdir -p "${INIT_FOLDER}"
mkdir -p "${TRAIN_LOGDIR}"
mkdir -p "${EVAL_LOGDIR}"
mkdir -p "${VIS_LOGDIR}"
mkdir -p "${EXPORT_DIR}"


SEG_DATASET="${DATASET_DIR}/tfrecord"

# Train 10 iterations.
NUM_ITERATIONS=500
python "${WORK_DIR}"/train.py \
  --logtostderr \
  --train_split="train" \
  --model_variant="xception_65" \
  --atrous_rates=6 \
  --atrous_rates=12 \
  --atrous_rates=18 \
  --output_stride=16 \
  --decoder_output_stride=4 \
  --train_crop_size=513 \
  --train_crop_size=513 \
  --train_batch_size=4 \
  --training_number_of_steps="${NUM_ITERATIONS}" \
  --fine_tune_batch_norm=true \
  --dataset="segtest" \
  --tf_initial_checkpoint="/home/cc/models/research/deeplab/backbone/deeplabv3_cityscapes_train/model.ckpt" \
  --train_logdir="${TRAIN_LOGDIR}" \
  --dataset_dir="${SEG_DATASET}" \
  --initialize_last_layer=False \
  --last_layers_contain_logits_only=True

这里的预训练权重我使用的是Cityscapes训练的xception-65,下载地址预权重下载

initialize_last_layer=False和last_layers_contain_logits_only=True的设置是根据3730#issuecomment-380168917

【语义分割】Tensorflow deeplabv3+训练自己的数据集

训练完成后再TRAIN_LOGDIR下生成了如下文件

【语义分割】Tensorflow deeplabv3+训练自己的数据集

四、验证

python "${WORK_DIR}"/eval.py \
  --logtostderr \
  --eval_split="val" \
  --model_variant="xception_65" \
  --atrous_rates=6 \
  --atrous_rates=12 \
  --atrous_rates=18 \
  --output_stride=16 \
  --decoder_output_stride=4 \
  --eval_crop_size=513 \
  --eval_crop_size=513 \
  --dataset="segtest" \
  --checkpoint_dir="${TRAIN_LOGDIR}" \
  --eval_logdir="${EVAL_LOGDIR}" \
  --dataset_dir="${SEG_DATASET}" \
  --max_number_of_evaluations=1

【语义分割】Tensorflow deeplabv3+训练自己的数据集

五、可视化

python "${WORK_DIR}"/vis.py \
  --logtostderr \
  --vis_split="val" \
  --model_variant="xception_65" \
  --atrous_rates=6 \
  --atrous_rates=12 \
  --atrous_rates=18 \
  --output_stride=16 \
  --decoder_output_stride=4 \
  --vis_crop_size=513 \
  --vis_crop_size=513 \
  --dataset="segtest" \
  --checkpoint_dir="${TRAIN_LOGDIR}" \
  --vis_logdir="${VIS_LOGDIR}" \
  --dataset_dir="${SEG_DATASET}" \
  --max_number_of_iterations=1

【语义分割】Tensorflow deeplabv3+训练自己的数据集

在VIS_LOGDIR下的segmentation_result目录下就生成了预测的可视化结果。

【语义分割】Tensorflow deeplabv3+训练自己的数据集

原图

【语义分割】Tensorflow deeplabv3+训练自己的数据集

预测结果可视化

【语义分割】Tensorflow deeplabv3+训练自己的数据集

因为只是简单的说明整个流程,数据很少,效果不太好。

六、导出模型

CKPT_PATH="${TRAIN_LOGDIR}/model.ckpt-${NUM_ITERATIONS}"
EXPORT_PATH="${EXPORT_DIR}/frozen_inference_graph.pb"

python "${WORK_DIR}"/export_model.py \
  --logtostderr \
  --checkpoint_path="${CKPT_PATH}" \
  --export_path="${EXPORT_PATH}" \
  --model_variant="xception_65" \
  --atrous_rates=6 \
  --atrous_rates=12 \
  --atrous_rates=18 \
  --output_stride=16 \
  --decoder_output_stride=4 \
  --num_classes=4 \
  --crop_size=513 \
  --crop_size=513 \
  --inference_scales=1.0

【语义分割】Tensorflow deeplabv3+训练自己的数据集

在EXPORT_DIR目录下生成了pb文件

【语义分割】Tensorflow deeplabv3+训练自己的数据集

七、可能存在的一些问题

7.1 [`predictions` out of bound]错误

将models/research/deeplab/eval.py143行修改为如下内容

metric_map = {}
indices = tf.squeeze(tf.where(tf.less_equal(
        labels, dataset.num_classes - 1)), 1)
    labels = tf.cast(tf.gather(labels, indices), tf.int32)
    predictions = tf.gather(predictions, indices)
    # end of insert

    metric_map[predictions_tag] = tf.metrics.mean_iou(
        predictions, labels, dataset.num_classes, weights=weights)

7.2 数据集不平衡

参考链接3730#issuecomment-387100419

7.3 预测图片全黑

如果以上步骤都做了,预测可视化的图片还是全黑的话,检查一下数据集的标注图片是否有colormap,如果是,要用models/research/deeplab/datasets/remove_gt_colormap.py将colormap去掉。