8,训练模型

进入 models/research路径

python deeplab/train.py \
    --logtostderr \
    --training_number_of_steps=1000 \
    --train_split="train" \
    --model_variant="xception_65" \
    --atrous_rates=6 \
    --atrous_rates=12 \
    --atrous_rates=18 \
    --output_stride=16 \
    --decoder_output_stride=4 \
    --train_crop_size="513,513" \
    --train_batch_size=2 \
    --fine_tune_batch_norm=false \
    --dataset="mydata" \
    --tf_initial_checkpoint='/home/lw/data/cityscapes/deeplabv3_cityscapes_train/model.ckpt' \
    --train_logdir='/home/lw/data/mydata/train' \
    --dataset_dir='/home/lw/data/mydata/tfrecord'

模型存在于/home/lw/data/mydata/train

训练过程很烧GPU,发热,注意散热

9.验证模型

python deeplab/eval.py \
  --logtostderr \
  --eval_split="val" \
  --model_variant="xception_65" \
  --atrous_rates=6 \
  --atrous_rates=12 \
  --atrous_rates=18 \
  --output_stride=16 \
  --decoder_output_stride=4 \
  --dataset="mydata" \
  --checkpoint_dir='/home/lw/data/mydata/train' \
  --eval_logdir='/home/lw/data/mydata/eval' \
  --dataset_dir='/home/lw/data/mydata/tfrecord' 

默认只有miou评价标准的值,读者可自行加入其他评价指标。比如,accuracy,precision,recall,f1_score的值

运行结果如下:

 

10, 预测模型

python deeplab/vis.py \
  --logtostderr \
  --vis_split="val" \
  --model_variant="xception_65" \
  --atrous_rates=6 \
  --atrous_rates=12 \
  --atrous_rates=18 \
  --output_stride=16 \
  --decoder_output_stride=4 \
  --dataset="mydata" \
  --checkpoint_dir='/home/lw/data/mydata/train' \
  --vis_logdir="/home/lw/data/mydata/vis" \
  --dataset_dir="/home/lw/data/mydata/tfrecord"

11.模型导出.pd格式,并用前面博客显示代码显示出来

python deeplab/export_model.py \
    --logtostderr \
    --checkpoint_path="/home/lw/data/mydata/train/model.ckpt-1000" \
    --export_path="/home/lw/data/mydata/pb/frozen_inference_graph.pb"  \
    --model_variant="xception_65"  \
    --atrous_rates=6  \
    --atrous_rates=12  \
    --atrous_rates=18   \
    --output_stride=16  \
    --decoder_output_stride=4  \
    --num_classes=2 \
    --inference_scales=1.0

我们发现问题:预测结果和实际位置不匹配。

import os  
import cv2
file_path = "/home/lw/data/mydata/mask/"
list_path = os.listdir(file_path)
for i in range(0, len(list_path)):
    path = os.path.join(file_path, list_path[i])
    if os.path.isfile(path) & path.endswith('.png'):
        image = cv2.imread(path, -1)  
        image = cv2.resize(image, (512, 512) )  
        cv2.imwrite(path, image) 

把原始图像image和标签图像mask重新整理成大小一样的,再重新生成tfrecord,训练模型,预测,再做一遍。

第一张比较准,是因为这是训练集中的数据进行测试。

第二张不准,说明模型训练的还不好,可能的原因是数据量太小了,总共才10张训练数据。后序完善数据。

12,边缘提取

import cv2
import numpy as np

img = cv2.imread("/home/lw/data/mydata/vis/segmentation_results/000000_prediction.png", 0)

# #(3, 3)表示高斯矩阵的长与宽都是3,标准差取0
# img = cv2.GaussianBlur(img,(3,3),0)

#image:源图像     threshold1:阈值1    threshold2:阈值2
#其中,较大的阈值2用于检测图像中明显的边缘,但一般情况下检测的效果不会那么完美,边缘检测出来是断断续续的。所以这时候用较小的第一个阈值用于将这些间断的边缘连接起来。
canny = cv2.Canny(img, 0, 150)
cv2.imshow('Canny', canny)
cv2.waitKey(0)
cv2.destroyAllWindows()

参考博客https://www.jianshu.com/p/4f33821d28ba

 

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐