Object Detection API 安装以及使用总结(二)
使用Object Detection API训练自己的数据
迟来的填坑
1.数据标注和存放
一般目标检测的标注工具使用labelImg,每一张图会对应生成一个XML文件。网上博客大多采用PASCAL VOC标准的文件存放方式来放置文件的位置,实际用的时候我觉得比较冗余,这里只需要准备以下数据:
-
JPEGImages
存放原始的训练数据(图片) -
Annotations
存放标记生成的xml文档 -
.pbtxt文件
根据标记的类别名称,手动按如下格式编辑 .pbtxt 文件,列出所有的数据类别
2.数据格式转化
为了训练的高效处理,Object Detection API 需要将训练数据转化成tfrecord
按照上述形式放置好文件后使用如下代码进行格式转换
# create by guanqiuyu
#
# 2018.10.26
#
# Harbin Institute of Technology Visual Technology Laboratory
# ==============================================================
import argparse
import tensorflow as tf
import logging
import os
import io
import PIL.Image
import hashlib
import time
import random
from object_detection.utils import dataset_util
from object_detection.utils import label_map_util
from lxml import etree
# Command-line interface. All three arguments are mandatory; parsing happens
# at import time, so importing this module without them raises SystemExit.
parser = argparse.ArgumentParser()
parser.add_argument('--ratio', required=True, type=float,
                    help='fraction of samples assigned to the val split '
                         '(e.g. 0.1 gives roughly train:val = 9:1)')
parser.add_argument('--path_data', required=True,
                    help='the dir of your training data')
parser.add_argument('--path_label_dict', required=True,
                    help='the path of your .pbtxt label map')  # fixed typo: was ".pbxt"
opt = parser.parse_args()

# Kept for the (currently commented-out) mode check in main().
SETS = ['train', 'val', 'trainval', 'test']
def IsSubString(SubStrList, Str):
    """Return True if every substring in `SubStrList` occurs in `Str`.

    Used to filter file names, e.g. keep only names containing 'xml'.
    An empty `SubStrList` returns True (vacuous truth), exactly like the
    original flag-based loop.
    """
    # all() short-circuits on the first missing substring instead of
    # always scanning the whole list as the original flag variable did.
    return all(sub in Str for sub in SubStrList)
def GetFileList(FindPath, FlagStr=None):
    """Return the sorted list of file names (ids) under `FindPath`.

    Args:
        FindPath: directory to scan (non-recursive).
        FlagStr: optional list of substrings. When given, only file names
            containing *every* substring are kept, and the last 4
            characters (the extension, e.g. '.xml') are stripped from
            each match. When omitted/empty, full names are returned.

    Returns:
        Sorted list of names; empty list for an empty directory.
    """
    # `FlagStr` used to be a mutable default argument ([]); use None to
    # avoid the shared-default pitfall. The redundant function-local
    # `import os` was dropped (os is imported at module level).
    flags = FlagStr or []
    file_list = []
    for fn in os.listdir(FindPath):
        if not flags:
            file_list.append(fn)
        elif all(sub in fn for sub in flags):
            # Assumes a 3-character extension such as '.xml' — strip it.
            file_list.append(fn[:-4])
    # Sorting an empty list is a no-op, so no emptiness guard is needed.
    file_list.sort()
    return file_list
def partial_data_set(path_usr, path, ratio):
    """Randomly split the id list in `path` into train/val lists.

    Args:
        path_usr: output directory; train.txt and val.txt are
            (over)written there.
        path: text file with one image id per line (dataset.txt).
        ratio: probability that a given line lands in val.txt
            (e.g. 0.1 yields roughly a 9:1 train/val split).
    """
    # Context managers replace the manual open/close triple of the
    # original, so all three files are closed even if a write fails.
    with open(path) as src, \
            open(os.path.join(path_usr, 'train.txt'), 'w+') as train_list, \
            open(os.path.join(path_usr, 'val.txt'), 'w+') as val_list:
        for line in src:
            # random.random() is in [0, 1): ratio == 0 sends every
            # line to train.txt, ratio == 1 sends (almost) all to val.
            if random.random() < ratio:
                val_list.write(line)
            else:
                train_list.write(line)
def get_list(path_txt, path_user):
    """Write one annotated-image id per line into `path_txt`.

    Ids are the xml file names (extension stripped) found under
    `path_user`/Annotations/; the directory is created when missing,
    which then yields an empty list file.
    """
    with open(path_txt, 'w+') as list_file:
        ann_dir = '%s/Annotations/' % path_user
        if not os.path.exists(ann_dir):
            os.makedirs(ann_dir)
        for image_id in GetFileList(path_user + '/Annotations/', ['xml']):
            print(image_id)
            list_file.write(image_id + '\n')
def covert_data_to_tfrecord(data, label_dict, path_data, pathname='JPEGImages'):
    """Build a tf.train.Example from one parsed VOC annotation dict.

    Args:
        data: dict parsed from a PASCAL-VOC xml ('annotation' subtree).
        label_dict: class name -> integer label id (from the .pbtxt map).
        path_data: dataset root containing the image folder.
        pathname: image folder name under `path_data`.

    Returns:
        A tf.train.Example holding the encoded image bytes plus the
        size-normalized bounding boxes and class labels.
    """
    image_path = os.path.join(path_data, pathname, data['filename'])
    with open(image_path, 'rb') as img_file:
        encoded_image = img_file.read()
    # sha256 digest of the raw image bytes, stored alongside the image.
    key = hashlib.sha256(encoded_image).hexdigest()

    width = int(data['size']['width'])
    height = int(data['size']['height'])

    # Images without any labeled object produce empty feature lists.
    objects = data['object'] if 'object' in data else []
    # Box coordinates are normalized to [0, 1] by the image size.
    xmin = [float(o['bndbox']['xmin']) / width for o in objects]
    ymin = [float(o['bndbox']['ymin']) / height for o in objects]
    xmax = [float(o['bndbox']['xmax']) / width for o in objects]
    ymax = [float(o['bndbox']['ymax']) / height for o in objects]
    classes_text = [o['name'].encode('utf8') for o in objects]
    classes = [label_dict[o['name']] for o in objects]
    difficult_obj = [int(bool(int(o['difficult']))) for o in objects]
    truncated = [int(o['truncated']) for o in objects]
    poses = [o['pose'].encode('utf8') for o in objects]

    feature = {
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(
            data['filename'].encode('utf8')),
        'image/source_id': dataset_util.bytes_feature(
            data['filename'].encode('utf8')),
        'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
        'image/encoded': dataset_util.bytes_feature(encoded_image),
        'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
        'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
        'image/object/truncated': dataset_util.int64_list_feature(truncated),
        'image/object/view': dataset_util.bytes_list_feature(poses),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))
def main():
    """End-to-end conversion: xml annotations -> train/val tfrecords.

    Reads everything from the module-level `opt` (argparse) and writes
    dataset.txt, train.txt, val.txt, train.tfrecords and val.tfrecords
    into opt.path_data.
    """
    path_data = opt.path_data
    # 1. List every annotated image id into dataset.txt.
    path_dataset_txt = os.path.join(path_data, 'dataset.txt')
    get_list(path_dataset_txt, path_data)
    # 2. Randomly split dataset.txt into train.txt / val.txt.
    partial_data_set(path_data, path_dataset_txt, opt.ratio)
    # 3. Map class names to integer ids from the .pbtxt label map.
    label_dict = label_map_util.get_label_map_dict(opt.path_label_dict)
    for i_split, split in enumerate(['train', 'val']):
        # `split` replaces the original loop variable `data`, which was
        # shadowed by the per-image annotation dict inside the inner loop.
        print(i_split, '------', split)
        logging.info('open the ' + split + '.txt')
        example_ids = dataset_util.read_examples_list(
            os.path.join(path_data, split + '.txt'))
        writer = tf.python_io.TFRecordWriter(
            os.path.join(path_data, split + '.tfrecords'))
        try:
            for i, example_id in enumerate(example_ids):
                if i % 100 == 0:
                    logging.info('image %d ----total %d', i, len(example_ids))
                xml_path = os.path.join(
                    path_data, 'Annotations', example_id + '.xml')
                # Read bytes: lxml rejects str input that carries an
                # XML encoding declaration.
                with open(xml_path, 'rb') as f:
                    xml_content = etree.fromstring(f.read())
                data = dataset_util.recursive_parse_xml_to_dict(
                    xml_content)['annotation']
                tf_example = covert_data_to_tfrecord(data, label_dict, path_data)
                writer.write(tf_example.SerializeToString())
        finally:
            # Close (and flush) the record file even if one xml fails.
            writer.close()
# Script entry point; note that argument parsing already happened at
# module import time (see the argparse block near the top).
if __name__ == '__main__':
    main()
执行命令python xml_to_record.py --ratio #R --path_data #D --path_label_dict #T.pbtxt
- #R —— 分割数据集的比例 ,如训练集:验证集 = 9:1 ,则设为0.1
- #D —— 存放Annotations 和 JPEGImages的文件目录
- #T —— .pbtxt文件的地址
比如我的设置为python xml_to_record.py --ratio 0.1 --path_data /home/stargrain/data/train --path_label_dict /home/stargrain/data/train/label_map.pbtxt
执行不报错,会生成train.tfrecords 和 val.tfrecords 两个文件
第二步完成
3.下载模型并训练
模型下载地址:TensorFlow 官方的 detection model zoo
在 model zoo 中选择需要的网络模型,下载并解压,更改pipeline.config文件
num_classes 改为自己数据的类别数
最下面的路径,finetune的权重、.pbtxt文件、两个tfrecord文件 的路径修改为自己的
最后cd 到models/research/object_detection/legacy目录下找到train.py文件
使用python train.py --logtostderr --train_dir=输出目录地址 --pipeline_config_path=刚刚修改的config文件地址
PS.
注意这里输出目录千万不要和加载的权重文件在同一个目录下,否则会报错!!!
注意这里输出目录千万不要和加载的权重文件在同一个目录下,否则会报错!!!
注意这里输出目录千万不要和加载的权重文件在同一个目录下,否则会报错!!!
以上