Training Mask R-CNN on Your Own Data

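Besides classic object-detection networks such as the YOLO series, SSD and Faster R-CNN, the pixel-level segmentation network Mask R-CNN can also be used for object detection. I have been working on document layout analysis recently and saw that someone had used Mask R-CNN for ticket and receipt recognition with good results: it could accurately locate the key information fields on a ticket. So I looked into it.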

Mask R-CNN GitHub repository: https://github.com/matterport/Mask_RCNN

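Requirement: I have 20 images of train tickets and receipts in total (photographed with my phone), and I want to extract the key information fields I care about from them.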

Solution: use Mask R-CNN.

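Step 1: data preparation. I annotated my data with labelme (GitHub: https://github.com/wkentaro/labelme#anaconda). I am on a Mac, so I followed the macOS installation instructions, which took just two commands. After installation, open a terminal and run labelme; the following interface appears: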

[Screenshot: the labelme annotation interface]

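This is the annotation interface. Click the "Open Dir" button on the left, choose the folder containing your images, and start annotating. I annotated only the fields I was interested in, labeling them as classes 1 to 7.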
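As far as I can tell from the conversion and training code, every distinct label in an image ends up as one mask, so two shapes with the same label in one image would be merged. It is therefore worth checking the annotations before converting them. The sketch below is my own addition, not part of the original workflow: it assumes the labelme JSON layout with a top-level "shapes" list whose items carry a "label" field, and the class names "1" to "7" used in this post (EXPECTED_LABELS is a hypothetical name). It counts how often each label occurs in a file and reports labels outside the expected set.

```python
import json
import sys
from collections import Counter

# Assumption: the class names used in this post, "1" to "7".
EXPECTED_LABELS = {str(i) for i in range(1, 8)}


def count_labels(json_path):
    """Count how many shapes carry each label in one labelme JSON file."""
    with open(json_path, encoding="utf-8") as f:
        data = json.load(f)
    # labelme stores every polygon/rectangle under "shapes", each with a "label".
    return Counter(shape["label"] for shape in data.get("shapes", []))


def unexpected_labels(json_path, expected=EXPECTED_LABELS):
    """Return the labels in the file that are not among the expected classes."""
    return sorted(set(count_labels(json_path)) - set(expected))


if __name__ == "__main__":
    # Usage: python check_labels.py a.json b.json ...
    for path in sys.argv[1:]:
        print(path, dict(count_labels(path)), "unexpected:", unexpected_labels(path))
```

A count above 1 for any label, or a non-empty "unexpected" list, points to an annotation worth fixing before conversion.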

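After annotation, a JSON file with the same name as each image appears in the data folder. Move these JSON files into a separate folder, then create a file named test.sh with the following contents: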

#!/bin/bash
for((i=33;i<57;i++))
do
labelme_json_to_dataset   /Users/yjf/work/datasets/labelme/WechatIMG124${i}.json
done

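This script converts the JSON files into data that can be used for training. Adjust the second line (the loop range) and the fourth line (the labelme_json_to_dataset call) to match your own files: my files are numbered, which is why the loop looks like this, but all that matters is that the argument on line 4 is the path of one of your JSON files. Then run chmod 777 test.sh followed by sh test.sh in a terminal to start the conversion. (The for ((...)) loop is bash syntax, so if sh on your system is not bash, run bash test.sh instead.)

The conversion produces one folder per JSON file containing the annotation result. Each folder holds five files, including the original image and the label image; have a look at them after converting. Put all of these generated folders into a single directory: that is the directory used for training below. Mine is /Users/yjf/work/datasets/labelme. Note: if the directory contains a .DS_Store file, delete it, because the training script treats every entry in this directory as a sample.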
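If you would rather not depend on numbered file names, the same conversion can be driven from Python. This is a sketch of my own, not the author's script: it assumes labelme_json_to_dataset is on your PATH and is called with a single JSON path, exactly as in test.sh. It also includes a helper that lists only the converted sub-folders containing img.png, label.png and info.yaml (the three files the training script reads), which sidesteps the .DS_Store problem. Whether info.yaml is written depends on your labelme version, and the file name convert_labelme.py is hypothetical.

```python
import glob
import os
import subprocess
import sys

# Files the training script reads from every converted folder.
# Assumption: your labelme version writes info.yaml (older releases do).
REQUIRED_FILES = ("img.png", "label.png", "info.yaml")


def conversion_commands(json_dir):
    """One labelme_json_to_dataset command per *.json file, in sorted order."""
    json_files = sorted(glob.glob(os.path.join(json_dir, "*.json")))
    return [["labelme_json_to_dataset", path] for path in json_files]


def convert_all(json_dir, dry_run=True):
    """Print the commands; also run them when dry_run is False."""
    for cmd in conversion_commands(json_dir):
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops at the first file that fails to convert.
            subprocess.run(cmd, check=True)


def sample_dirs(root):
    """Sub-folders of root that contain every file in REQUIRED_FILES.

    Plain files such as .DS_Store or leftover .json files are skipped.
    """
    names = []
    for name in sorted(os.listdir(root)):
        folder = os.path.join(root, name)
        if os.path.isdir(folder) and all(
                os.path.isfile(os.path.join(folder, f)) for f in REQUIRED_FILES):
            names.append(name)
    return names


if __name__ == "__main__":
    # Usage: python convert_labelme.py /path/to/json_dir
    if len(sys.argv) > 1:
        convert_all(sys.argv[1], dry_run=False)
```

Running convert_all with the default dry_run=True first lets you inspect the commands before anything is written.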

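Step 2: download Mask R-CNN from https://github.com/matterport/Mask_RCNN. Then download the COCO pre-trained weights from https://github.com/matterport/Mask_RCNN/releases/download/v2.0/mask_rcnn_coco.h5 and put the file in the root of the downloaded Mask_RCNN directory (the same directory that contains mrcnn).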

Step 3: adapt the network. Create a yjf_test directory under samples, and in it a file named railway_test.py with the following contents:


import os
import sys
import random
import yaml
import math
import re
import time
import numpy as np
import cv2
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt

# Root directory of the project
ROOT_DIR = os.path.abspath("./")

# Import Mask RCNN
sys.path.append(ROOT_DIR)  # To find local version of the library
from mrcnn.config import Config
from mrcnn import utils
import mrcnn.model as modellib
from mrcnn import visualize
from mrcnn.model import log

# %matplotlib inline 

# Directory to save logs and trained model
MODEL_DIR = os.path.join(ROOT_DIR, "logs_")

# Local path to trained weights file
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")
# Download COCO trained weights from Releases if needed
if not os.path.exists(COCO_MODEL_PATH):
    utils.download_trained_weights(COCO_MODEL_PATH)

############################################################
#  Configurations
############################################################
class ShapesConfig(Config):
    """Configuration for training on the ticket dataset.
    Derives from the base Config class and overrides values specific
    to this dataset (adapted from the toy shapes sample).
    """
    # Give the configuration a recognizable name
    NAME = "shapes"

    # Train on 1 GPU with 1 image per GPU, so the batch size is
    # GPU_COUNT * IMAGES_PER_GPU = 1.
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1

    # Number of classes (including background).
    # Change this to your own number of classes: I have 7, plus 1 for background.
    NUM_CLASSES = 1 + 7

    # Use small images for faster training. IMAGE_MIN_DIM and IMAGE_MAX_DIM
    # limit the small and the large side, which determines the image shape.
    # Define the image size here.
    IMAGE_MIN_DIM = 256
    IMAGE_MAX_DIM = 256

    # Anchor side lengths in pixels: the shapes sample's (8, ..., 128) scaled by 6.
    RPN_ANCHOR_SCALES = (8*6, 16*6, 32*6, 64*6, 128*6)

    # Reduce training ROIs per image because the images are small and have
    # few objects. Aim to allow ROI sampling to pick 33% positive ROIs.
    TRAIN_ROIS_PER_IMAGE = 32

    # Use a small epoch since the data is simple
    STEPS_PER_EPOCH = 50

    # use small validation steps since the epoch is small
    VALIDATION_STEPS = 10
    
config = ShapesConfig()
config.display()


class ShapesDataset(utils.Dataset):
    """Dataset of labelme-annotated ticket images.
    Each sample is a folder produced by labelme_json_to_dataset that
    contains img.png, label.png and info.yaml.
    """

    def get_obj_index(self, image):
        n = np.max(image)
        return n

    def from_yaml_get_class(self, image_id):
        info = self.image_info[image_id]
        with open(info['yaml_path']) as f:
            temp = yaml.load(f.read(), Loader=yaml.FullLoader)
            labels = temp['label_names']

            del labels[0]
        return labels


    def draw_mask(self, num_obj, mask, image, image_id):
        # Channel k of the mask is 1 wherever label.png has the value k + 1.
        # Comparing the whole array at once is much faster than calling
        # getpixel() on every pixel. (image_id is kept for compatibility.)
        label = np.array(image)
        for index in range(num_obj):
            mask[:, :, index] = (label == index + 1)
        return mask


    def load_shapes(self, count, img_floder, imglist):
        """Register the classes and the first `count` samples of imglist.
        count: number of samples to load.
        img_floder: directory holding the folders produced by labelme_json_to_dataset.
        imglist: names of those folders.
        """
        # Add classes: one add_class call per class in your own data.
        self.add_class("shapes", 1, "1")
        self.add_class("shapes", 2, "2")
        self.add_class("shapes", 3, "3")
        self.add_class("shapes", 4, "4")
        self.add_class("shapes", 5, "5")
        self.add_class("shapes", 6, "6")
        self.add_class("shapes", 7, "7")
        # Add images
        for i in range(count):
            filestr = imglist[i].split(".")[0]
            img_path = os.path.join(img_floder, filestr, "img.png")
            mask_path = os.path.join(img_floder, filestr, "label.png")
            yaml_path = os.path.join(img_floder, filestr, "info.yaml")
            # Print the paths to check that the data paths are correct
            print(img_path, 'img_path')
            print(mask_path)
            print(yaml_path)
            cv_img = cv2.imread(img_path)
            self.add_image("shapes", image_id=i, path=img_path,
                           width=cv_img.shape[1], height=cv_img.shape[0],
                           mask_path=mask_path, yaml_path=yaml_path)
        
    def load_mask(self, image_id):
        """Load the instance masks and class IDs for the given image ID."""
        print("image_id", image_id)
        info = self.image_info[image_id]
        img = Image.open(info['mask_path'])
        num_obj = self.get_obj_index(img)
        mask = np.zeros([info['height'], info['width'], num_obj], dtype=np.uint8)
        mask = self.draw_mask(num_obj, mask, img, image_id)
        labels = self.from_yaml_get_class(image_id)
        # Map each labelme label name to one of our class names.
        # List one entry per class here: I have 7 classes, so 7 entries.
        class_list = ["1", "2", "3", "4", "5", "6", "7"]
        labels_form = []
        for label in labels:
            for name in class_list:
                if label.find(name) != -1:
                    labels_form.append(name)
                    break
        class_ids = np.array([self.class_names.index(s) for s in labels_form])
        return mask, class_ids.astype(np.int32)


def get_ax(rows=1, cols=1, size=8):
    """Return a Matplotlib Axes array to be used in
    all visualizations in the notebook. Provide a
    central point to control graph sizes.
 
    Change the default size attribute to control the size
    of rendered images
    """
    _, ax = plt.subplots(rows, cols, figsize=(size * cols, size * rows))
    return ax

############################################################
#  Training
############################################################


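dataset_root_path = "/Users/yjf/work/datasets/labelme/"  # directory holding the training data
img_floder = dataset_root_path

# Keep only sub-folders, so stray files such as .DS_Store or .json are skipped
imglist = sorted(d for d in os.listdir(img_floder)
                 if os.path.isdir(os.path.join(img_floder, d)))
count = len(imglist)

dataset_train = ShapesDataset()
dataset_train.load_shapes(count, dataset_root_path, imglist)
dataset_train.prepare()

# Validation set: the first 2 samples. These are also in the training set,
# so the validation loss will look optimistic.
dataset_val = ShapesDataset()
dataset_val.load_shapes(2, dataset_root_path, imglist)
dataset_val.prepare()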
 

# Create model in training mode
model = modellib.MaskRCNN(mode="training", config=config,
                          model_dir=MODEL_DIR)
 
# Which weights to start with?
init_with = "coco"  # imagenet, coco, or last
 
if init_with == "imagenet":
    model.load_weights(model.get_imagenet_weights(), by_name=True)
elif init_with == "coco":
    # Load weights trained on MS COCO, but skip layers that
    # are different due to the different number of classes
    # See README for instructions to download the COCO weights
    # print(COCO_MODEL_PATH)
    model.load_weights(COCO_MODEL_PATH, by_name=True,
                       exclude=["mrcnn_class_logits", "mrcnn_bbox_fc",
                                "mrcnn_bbox", "mrcnn_mask"])
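elif init_with == "last":
    # Load the last model you trained and continue training.
    # Current versions of the repository return the checkpoint path from
    # find_last(); older versions returned a (log_dir, path) tuple, in
    # which case use model.find_last()[1] instead.
    model.load_weights(model.find_last(), by_name=True)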
 
# Train the head branches
# Passing layers="heads" freezes all layers except the head
# layers. You can also pass a regular expression to select
# which layers to train by name pattern.
model.train(dataset_train, dataset_val,
            learning_rate=config.LEARNING_RATE,
            epochs=30,
            layers='heads')
 

model.train(dataset_train, dataset_val,
            learning_rate=config.LEARNING_RATE / 10,
            epochs=60,
            layers="all")

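The places flagged by my comments in the code above (the number of classes, the class list, the class-name mapping in load_mask, the image size and the data directory) are the ones you need to adapt to your own data; the rest you can just read through.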
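The core of load_mask is turning label.png, in which each pixel stores a label index (0 for background), into one binary mask per label. The NumPy sketch below shows that idea in isolation; it is an illustration, not a drop-in replacement, and class_ids_for assumes label names that exactly match the class names, whereas the script above uses a substring search.

```python
import numpy as np


def label_png_to_masks(label):
    """Split a label image (H x W ints, 0 = background) into H x W x N masks.

    Channel k is True where the pixel value equals k + 1, the same rule
    draw_mask() uses, with N = the largest value in the image.
    """
    label = np.asarray(label)
    num_obj = int(label.max())
    values = np.arange(1, num_obj + 1)
    # Broadcasting (H, W, 1) against (N,) yields an (H, W, N) boolean array.
    return label[:, :, None] == values


def class_ids_for(label_names, class_names):
    """Map label names (background already removed) to dataset class IDs."""
    return np.array([class_names.index(name) for name in label_names],
                    dtype=np.int32)
```

Inside load_mask this would become something like mask = label_png_to_masks(Image.open(info['mask_path'])).astype(np.uint8), since np.asarray on a palette-mode PNG returns the stored indices.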

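After making these changes, run python3 samples/yjf_test/railway_test.py from the repository root to start training. Two problems I ran into:

1. Before training, change every keepdims in mrcnn/model.py to keep_dims; otherwise you get an error telling you to do so. There are only two places to change.
2. An error mentioning (op: 'Assign') with input shapes: [1024, 32], [1024, 324] is raised while loading the pre-trained COCO weights. The two shapes are the mrcnn_bbox_fc weights for 8 classes (8 × 4 = 32 box outputs) and for COCO's 81 classes (81 × 4 = 324), so a layer that should have been excluded is being loaded; in my case the cause was the TensorFlow and Keras versions. TensorFlow 1.3 or 1.4 both work, but Keras has to be 2.1.0: newer Keras versions did not work for me.

Any other error you hit is usually easy to fix with a quick search; I ran into very few.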
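The keepdims fix can also be scripted. This is a hedged sketch that does exactly what the manual instruction says, renaming every whole-word keepdims to keep_dims, and keeps a .bak copy of the original file. It is only meant for the old TensorFlow 1.3/1.4 setup described here; a blanket rename would break any call to an API that expects keepdims, so check the result afterwards.

```python
import re
import shutil


def patch_keepdims(source):
    """Return (patched_source, n_replacements) with keepdims -> keep_dims."""
    # \b makes sure only the whole word is replaced; keep_dims is left alone.
    return re.subn(r"\bkeepdims\b", "keep_dims", source)


def patch_file(path):
    """Patch a file in place, saving the original as path + '.bak'."""
    with open(path, encoding="utf-8") as f:
        source = f.read()
    patched, n = patch_keepdims(source)
    if n:
        shutil.copyfile(path, path + ".bak")
        with open(path, "w", encoding="utf-8") as f:
            f.write(patched)
    return n
```

For example, patch_file("mrcnn/model.py") run from the repository root returns the number of replacements it made.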

Training then starts, as shown below:

[Screenshot: training in progress]

When training finishes, the logs_ directory in the repository root contains the trained .h5 model files.
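The run folder and file name needed in the prediction script (for example shapes20190415T2121/mask_rcnn_shapes_0001.h5) follow a fixed pattern: one folder per run named after NAME plus a timestamp, and one file per saved epoch. The library's own model.find_last() locates the last checkpoint, but it needs a model object. The standalone sketch below, assuming that naming pattern, returns the highest-epoch file of the newest run that contains one; latest_checkpoint is a hypothetical helper name.

```python
import os
import re


def latest_checkpoint(model_dir, name="shapes"):
    """Path of the highest-epoch mask_rcnn_<name>_<epoch>.h5 in the newest run.

    Assumed layout: model_dir/<name><YYYYMMDDTHHMM>/mask_rcnn_<name>_<epoch>.h5,
    e.g. logs_/shapes20190415T2121/mask_rcnn_shapes_0001.h5. Runs without any
    checkpoint are skipped; returns None if nothing is found.
    """
    runs = sorted(d for d in os.listdir(model_dir)
                  if d.startswith(name) and os.path.isdir(os.path.join(model_dir, d)))
    pattern = re.compile(r"mask_rcnn_%s_(\d+)\.h5$" % re.escape(name))
    # The timestamps sort chronologically as strings, so walk newest first.
    for run in reversed(runs):
        run_dir = os.path.join(model_dir, run)
        epochs = []
        for fname in os.listdir(run_dir):
            match = pattern.match(fname)
            if match:
                epochs.append((int(match.group(1)), fname))
        if epochs:
            return os.path.join(run_dir, max(epochs)[1])
    return None
```

In the prediction script this could replace the hard-coded path, e.g. latest_checkpoint(os.path.join(ROOT_DIR, "logs_")).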

Step 4: prediction. Create a file named railway_predict.py in samples/yjf_test with the following code:

# -*- coding: utf-8 -*-
import os
import sys
import random
import math
import numpy as np
import skimage.io
import matplotlib
import matplotlib.pyplot as plt
import cv2
import time

from datetime import datetime 
# Root directory of the project
ROOT_DIR = os.path.abspath("./")

# Import Mask RCNN
sys.path.append(ROOT_DIR)  # To find local version of the library
from mrcnn import utils
from mrcnn.config import Config
import mrcnn.model as modellib
from mrcnn import visualize

# Directory to save logs and trained model
MODEL_DIR = os.path.join(ROOT_DIR, "logs_")
# Path and file name of your own trained model: change these to yours
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "logs_", "shapes20190415T2121", "mask_rcnn_shapes_0001.h5")

# Directory of test images
IMAGE_DIR = "/Users/yjf/work/datasets/tmp/pic"

class ShapesConfig(Config):
    """Same configuration as in railway_test.py; keep the two in sync."""
    # Give the configuration a recognizable name
    NAME = "shapes"

    # 1 GPU with 1 image per GPU, so the batch size is 1.
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1

    # Number of classes (including background)
    NUM_CLASSES = 1 + 7  # background + 7 classes

    # Use small images for faster training. Set the limits of the small side
    # and the large side, which determines the image shape.
    IMAGE_MIN_DIM = 256  # must match the training configuration
    IMAGE_MAX_DIM = 256

    # Anchor side lengths in pixels (same as in training)
    RPN_ANCHOR_SCALES = (8 * 6, 16 * 6, 32 * 6, 64 * 6, 128 * 6)

    # Reduce training ROIs per image because the images are small and have
    # few objects. Aim to allow ROI sampling to pick 33% positive ROIs.
    TRAIN_ROIS_PER_IMAGE =32

    # Use a small epoch since the data is simple
    STEPS_PER_EPOCH = 50

    # use small validation steps since the epoch is small
    VALIDATION_STEPS = 10


config = ShapesConfig()
config.display()

class InferenceConfig(ShapesConfig):
    # Set batch size to 1 since we'll be running inference on
    # one image at a time. Batch size = GPU_COUNT * IMAGES_PER_GPU
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1

config = InferenceConfig()

# Create model object in inference mode.
model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR, config=config)

# Load the weights you trained yourself
model.load_weights(COCO_MODEL_PATH, by_name=True)

# Class names: the index in the list is the class ID, so keep the same
# order as in training ('BG' first, then your own classes)
class_names = ['BG', '1','2','3','4','5','6','7']
# All file names in IMAGE_DIR (not used below; handy for looping over images)
file_names = next(os.walk(IMAGE_DIR))[2]
# Specify which image to predict
image = skimage.io.imread(os.path.join(IMAGE_DIR, "WechatIMG12435.jpeg"))

a = datetime.now()
# Run detection
results = model.detect([image], verbose=1)
b = datetime.now()
print("Detection time (s):", (b - a).total_seconds())
# Visualize results
r = results[0]
visualize.display_instances(image, r['rois'], r['masks'], r['class_ids'], 
                            class_names, r['scores'])
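
Change the places flagged in the comments (model path, test image directory, image name and class names) to your own.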
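A common mistake when adapting the prediction script is letting it drift from the training script. The sketch below, which is my own addition, compares a few settings between two config classes and checks that class_names has one entry per class including 'BG'. The choice of keys in CHECKED_KEYS is an assumption; extend it as you see fit.

```python
# Assumption: settings that must agree between training and inference.
CHECKED_KEYS = ("NUM_CLASSES", "IMAGE_MIN_DIM", "IMAGE_MAX_DIM", "RPN_ANCHOR_SCALES")


def consistency_problems(train_cfg, infer_cfg, class_names, keys=CHECKED_KEYS):
    """Return human-readable mismatches between training and inference setup."""
    problems = []
    for key in keys:
        a, b = getattr(train_cfg, key), getattr(infer_cfg, key)
        if a != b:
            problems.append("%s: training=%r, inference=%r" % (key, a, b))
    if len(class_names) != infer_cfg.NUM_CLASSES:
        problems.append("class_names has %d entries but NUM_CLASSES is %d"
                        % (len(class_names), infer_cfg.NUM_CLASSES))
    if class_names and class_names[0] != "BG":
        problems.append("class_names[0] should be 'BG' (background)")
    return problems


if __name__ == "__main__":
    # Hypothetical stand-ins for the training and inference config classes.
    class Train:
        NUM_CLASSES = 8
        IMAGE_MIN_DIM = 256
        IMAGE_MAX_DIM = 256
        RPN_ANCHOR_SCALES = (48, 96, 192, 384, 768)

    class Infer(Train):
        IMAGE_MAX_DIM = 512

    print(consistency_problems(Train, Infer, ["BG", "1", "2", "3", "4", "5", "6", "7"]))
```

With the real classes you would pass config instances, e.g. consistency_problems(ShapesConfig(), InferenceConfig(), class_names); an empty list means nothing was flagged.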

Then run python3 samples/yjf_test/railway_predict.py. Here are the predictions on a few of my images:

 

[Figures: prediction results on my ticket images]

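On these images the results look quite good: all the fields I care about are detected. The training and prediction code is on Baidu Netdisk: https://pan.baidu.com/s/1SRYwODWLYfpLsPQREVIMtg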
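Since the goal was to extract the key fields, a natural next step after detection is to cut each detected region out of the image, for example to pass it to an OCR engine. A minimal sketch, assuming r['rois'] holds [y1, x1, y2, x2] pixel boxes as returned by detect() in this repository; crop_detections is a hypothetical helper, and the score threshold and top-to-bottom ordering are my own choices.

```python
import numpy as np


def crop_detections(image, rois, class_ids, class_names, scores=None, min_score=0.0):
    """Cut each detected box out of the image.

    rois: (N, 4) array of [y1, x1, y2, x2] boxes; class_ids index class_names.
    Detections scoring below min_score are dropped when scores is given.
    Returns (class_name, crop) tuples ordered top-to-bottom, then left-to-right.
    """
    crops = []
    for i, (y1, x1, y2, x2) in enumerate(np.asarray(rois, dtype=int)):
        if scores is not None and scores[i] < min_score:
            continue
        crops.append((y1, x1, class_names[class_ids[i]], image[y1:y2, x1:x2]))
    # Sort on the box's top-left corner only, so crops are never compared.
    crops.sort(key=lambda c: (c[0], c[1]))
    return [(name, crop) for _, _, name, crop in crops]
```

Usage after detection: fields = crop_detections(image, r['rois'], r['class_ids'], class_names, r['scores'], min_score=0.8).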