Faster R-CNN: batch testing with saved results / all classes drawn together / ground truth overlaid
The original demo only handled a single image and drew the detection boxes one class per figure. After the modification, the script runs over the whole test list and saves the results, draws the detections of all classes on a single image, and overlays the ground-truth boxes on the same image.
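The full modified script follows. One practical note first: the output directory that plt.savefig() writes to is hard-coded and must already exist, so a minimal sketch like the one below (same path as in the script; adjust it to your own installation) can be run once beforehand.

import os

# Create the result directory used by plt.savefig() in demo() if it is missing.
# The path mirrors the one hard-coded in the script below.
result_dir = '/home/omnisky/q/tf-faster-rcnn-master/data/result/'
if not os.path.exists(result_dir):
    os.makedirs(result_dir)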
#!/usr/bin/env python
# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Xinlei Chen, based on code from Ross Girshick
#qhy
#2018.10.26
# --------------------------------------------------------
"""
Demo script showing detections in sample images.
See README.md for installation instructions before running.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import xml.etree.ElementTree as ET
import _init_paths
from model.config import cfg
from model.test import im_detect
from model.nms_wrapper import nms
from utils.timer import Timer
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import os, cv2
import argparse
from nets.vgg16 import vgg16
from nets.resnet_v1 import resnetv1
CLASSES = ('__background__',  # always index 0
           'normal bolt', 'normal bolt-2', 'normal bolt-3',
           'nut losing', 'nut losing-2', 'nut directly loosening',
           'pin closing', 'visible pin losing', 'visible pin losing-2',
           'invisible pin losing', 'invisible pin losing-2')

NETS = {'vgg16': ('vgg16_faster_rcnn_iter_70000.ckpt',),
        'res101': ('res101_faster_rcnn_iter_100000.ckpt',)}
DATASETS = {'pascal_voc': ('voc_2007_trainval',),
            'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}

def vis_detections(image_name, im, class_name, dets, thresh=0.5):
    """Draw detected bounding boxes (original per-class version, no longer called)."""
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return

    im = im[:, :, (2, 1, 0)]
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')
    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]

        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          edgecolor='red', linewidth=1.5)
            )
        ax.text(bbox[0], bbox[1] - 2,
                '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')

    ax.set_title(('{} detections with '
                  'p({} | box) >= {:.1f}').format(class_name, class_name,
                                                  thresh),
                 fontsize=14)
    # plt.axis('off')
    # plt.tight_layout()
    # plt.draw()
    # image_name = image_name.replace('jpg', 'png')
    # plt.savefig('/home/omnisky/q/tf-faster-rcnn-master/data/result/' + image_name)
    # print("save image to /home/omnisky/q/tf-faster-rcnn-master/data/result/{}".format(image_name))

def demo(image_name, sess, net):
    """Detect object classes in an image using pre-computed object proposals."""
    # Load the demo image
    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
    im = cv2.imread(im_file)

    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(sess, net, im)
    timer.toc()
    print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time, boxes.shape[0]))

    # Visualize detections for each class
    CONF_THRESH = 0.7
    thresh = 0.7
    NMS_THRESH = 0.3

    im = im[:, :, (2, 1, 0)]  # BGR -> RGB for matplotlib
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal', alpha=0.75)

    # Draw the ground-truth boxes parsed from the VOC2007 XML annotation
    # (green rectangles, class names on a yellow background)
    xml_path = '/home/omnisky/q/tf-faster-rcnn-master/data/VOCdevkit2007/VOC2007/Annotations/'
    xml_name = image_name.replace('jpg', 'xml')
    xml_file = os.path.join(xml_path, xml_name)
    tree = ET.parse(xml_file)
    root = tree.getroot()
    for obj in root.findall('object'):
        a = []
        a.append(obj.find('name').text)
        a.append(int(obj.find('bndbox').find('xmin').text))
        a.append(int(obj.find('bndbox').find('ymin').text))
        a.append(int(obj.find('bndbox').find('xmax').text))
        a.append(int(obj.find('bndbox').find('ymax').text))
        ax.add_patch(
            plt.Rectangle((a[1], a[2]), a[3] - a[1], a[4] - a[2],
                          fill=False, edgecolor='g', linewidth=2.5)
            )
        ax.text(a[1], a[2] - 5, '{:s}'.format(a[0]),
                bbox=dict(facecolor='yellow', alpha=0.5),
                fontsize=14, color='black')

    # Draw the detections of every class on the same axes
    # (red rectangles, label and score on a blue background)
    for cls_ind, cls in enumerate(CLASSES[1:]):
        cls_ind += 1  # because we skipped background
        cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]
        # vis_detections(image_name, im, cls, dets, thresh=CONF_THRESH)
        inds = np.where(dets[:, -1] >= thresh)[0]
        if len(inds) == 0:
            continue
        for i in inds:
            bbox = dets[i, :4]
            score = dets[i, -1]
            ax.add_patch(
                plt.Rectangle((bbox[0], bbox[1]),
                              bbox[2] - bbox[0],
                              bbox[3] - bbox[1], fill=False,
                              edgecolor='red', linewidth=2.5)
                )
            ax.text(bbox[0], bbox[1] - 2,
                    '{:s} {:.3f}'.format(cls, score),
                    bbox=dict(facecolor='blue', alpha=0.5),
                    fontsize=14, color='white')

    plt.axis('off')
    plt.tight_layout()
    plt.draw()
    image_name = image_name.replace('jpg', 'png')
    plt.savefig('/home/omnisky/q/tf-faster-rcnn-master/data/result/' + image_name)
    print("save image to /home/omnisky/q/tf-faster-rcnn-master/data/result/{}".format(image_name))
    plt.close(fig)  # release the figure so memory does not grow over the batch

def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo')
    parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
                        choices=NETS.keys(), default='res101')
    parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]',
                        choices=DATASETS.keys(), default='pascal_voc_0712')
    args = parser.parse_args()
    return args

if __name__ == '__main__':
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals
    args = parse_args()

    # model path
    demonet = args.demo_net
    dataset = args.dataset
    tfmodel = '/home/omnisky/q/tf-faster-rcnn-master/output/res101/voc_2007_trainval/default/res101_faster_rcnn_iter_100000.ckpt'

    if not os.path.isfile(tfmodel + '.meta'):
        raise IOError(('{:s} not found.\nDid you download the proper networks from '
                       'our server and place them properly?').format(tfmodel + '.meta'))

    # set config
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True

    # init session
    sess = tf.Session(config=tfconfig)
    # load network
    if demonet == 'vgg16':
        net = vgg16()
    elif demonet == 'res101':
        net = resnetv1(num_layers=101)
    else:
        raise NotImplementedError
    # 12 = 11 object classes + 1 background
    net.create_architecture("TEST", 12,
                            tag='default', anchor_scales=[8, 16, 32])
    saver = tf.train.Saver()
    saver.restore(sess, tfmodel)

    print('Loaded network {:s}'.format(tfmodel))

    # Read the test image list and run detection on every image in it
    fi = open('/home/omnisky/q/tf-faster-rcnn-master/data/VOCdevkit2007/VOC2007/ImageSets/Main/test.txt')
    txt = fi.readlines()
    im_names = []
    for line in txt:
        line = line.strip('\n')
        line = line.replace('\r', '')
        line = line + '.jpg'
        im_names.append(line)
    print(im_names)
    fi.close()

    for im_name in im_names:
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print('Demo for data/demo/{}'.format(im_name))
        demo(im_name, sess, net)

    # plt.show()

The green boxes are the ground-truth annotations; the ground-truth class names are shown on a yellow background. The red boxes with blue labels are the detections.
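To run the batch test, invoke the script the same way as the stock demo, for example `python tools/demo.py --net res101 --dataset pascal_voc` (the filename tools/demo.py is an assumption; use whatever name the modified script was saved under). Each image listed in test.txt is processed in turn; note that the images are still loaded from data/demo/, so the test images must be present in that folder, and each annotated result is saved as a .png under data/result/.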