keras实现LFW测试
部分代码由朋友提供,如有侵权,请及时联系。
1、裁剪图像。
可以自己写人脸检测与裁剪脚本,但有时会出现漏检;也可以直接从网上下载已经裁剪好的数据集,不过需要另写脚本处理图像格式并重新分类。
网上下载地址:http://conradsanderson.id.au/lfwcrop/
2、读取pairs文件,生成自己的label文件,每行包含图像位置信息以及标签(0-不同人,1-同一个人)
读取
def read_pairs(self, pairs_filename):
    """Read a pairs file and return its usable records.

    Every line is split on whitespace. Only lines with exactly 3 or 4
    fields are kept; any other line (for example a header line or a blank
    line) is ignored.

    Args:
        pairs_filename: Path of the pairs text file.

    Returns:
        A list of records, each a list of 3 or 4 strings, in file order.
    """
    pairs = []
    # The context manager guarantees the handle is closed, even on error.
    with open(pairs_filename, 'r') as pairs_file:
        for line in pairs_file:
            fields = line.split()
            # A blank line is skipped instead of being treated as
            # end-of-file, so a stray empty line cannot truncate the list.
            if len(fields) in (3, 4):
                pairs.append(fields)
    return pairs
生成
def get_paths(self, ori_path, pairs,
              label_path='E:/sign_system/execute_system/testcode/labelcrop_3.txt'):
    """Write a label file describing every image pair.

    Each output line has the form ``<left image>\\t<right image>\\t<label>``.
    A 3-field pair ``[name, idx1, idx2]`` produces label ``1`` (same person);
    a 4-field pair ``[name1, idx1, name2, idx2]`` produces label ``0``
    (different people). Image paths follow the pattern
    ``<ori_path>/<name>/<name>_<4-digit index>.jpg``.

    Args:
        ori_path: Root directory of the image set. When empty or ``None``
            the historical default ``'E:/sign_system/lfw/'`` is used.
        pairs: Sequence of pair records as returned by ``read_pairs``.
        label_path: Destination of the generated label file. Defaults to
            the location that was previously hard-coded.

    Returns:
        None. The result is written to ``label_path``.
    """
    # The argument used to be overwritten unconditionally; it is honoured
    # now and the old constant only serves as a fallback.
    if not ori_path:
        ori_path = 'E:/sign_system/lfw/'
    # Normalise the root so both columns are joined the same way (the right
    # column used to receive a doubled slash).
    root = ori_path.rstrip('/')

    def image_path(name, index):
        return '%s/%s/%s_%04d.jpg' % (root, name, name, int(index))

    labellines = []
    for pair in pairs:
        if len(pair) == 3:
            name, first, second = pair
            labellines.append('%s\t%s\t1\n' % (image_path(name, first),
                                               image_path(name, second)))
        elif len(pair) == 4:
            name_a, index_a, name_b, index_b = pair
            labellines.append('%s\t%s\t0\n' % (image_path(name_a, index_a),
                                               image_path(name_b, index_b)))
        else:
            print("error!!!!")

    # Open the file only once all lines were built, and let the context
    # manager close it even if writing fails.
    with open(label_path, 'w') as label_file:
        label_file.writelines(labellines)
3、再次读取文件,生成label文件中同一行的左右图像特征
读取label文件
def readImagelist(self, labelFile):
    """Parse a tab-separated label file into three parallel lists.

    Each non-blank line must contain at least three tab-separated fields:
    left image path, right image path and an integer label.

    Args:
        labelFile: Path of the label file.

    Returns:
        A tuple ``(left, right, labels)`` of equally long lists holding the
        left paths, the right paths and the labels converted to ``int``.

    Raises:
        IndexError: If a non-blank line has fewer than three fields.
        ValueError: If the label field is not an integer.
    """
    left = []
    right = []
    labels = []
    with open(labelFile) as label_file:
        for line in label_file:
            # Blank lines (typically a trailing one) used to crash the
            # parser with IndexError; they carry no data, so skip them.
            if not line.strip():
                continue
            fields = line.strip('\n').split('\t')
            # Parse the whole record before appending so the three lists
            # always grow together; this replaces the former asserts.
            left_path, right_path, label = fields[0], fields[1], int(fields[2])
            left.append(left_path)
            right.append(right_path)
            labels.append(label)
    return left, right, labels
提取特征
提取前需要在前面导入模型
# Instantiate the project's model class first.
# NOTE(review): this object is replaced by the loaded model two lines below;
# presumably the call is only needed for its side effects - confirm.
self.model = Model_half()
# TODO: set this to the location of the saved model file you want to load.
path = ''
self.model = load_model(path)
提取
def extractFeature(self, leftImageList, rightImageList):
    """Run the model on every image of both lists and collect the outputs.

    Each image is read with ``cv2.imread``, passed through ``resize_image``,
    arranged as a one-image batch that matches the Keras image data format,
    converted to ``float32``, scaled by ``1/255`` and fed to
    ``self.model.predict``.

    Args:
        leftImageList: Paths of the left image of every pair.
        rightImageList: Paths of the right image of every pair; must be at
            least as long as ``leftImageList``.

    Returns:
        A tuple ``(leftfeature, rightfeature)`` of lists holding the first
        row of the model output for every left / right image.

    Raises:
        IOError: If an image cannot be read.
    """
    # The data format does not change while the loop runs; query it once.
    channels_first = K.image_data_format() == 'channels_first'

    def embed(image_path):
        """Load one image, normalise it and return its model output."""
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of
            # raising; fail with a message that names the file.
            raise IOError("cannot read image: %s" % image_path)
        # assumes resize_image returns a 224x224x3 array - TODO confirm
        image = resize_image(image)
        if channels_first:
            # Move the channel axis to the front. The former plain reshape
            # produced a (1, 224, 224, 3) batch for this format as well.
            image = image.reshape((224, 224, 3)).transpose((2, 0, 1))
            image = image.reshape((1, 3, 224, 224))
        else:
            image = image.reshape((1, 224, 224, 3))
        # Normalise pixel values before feature extraction.
        image = image.astype('float32')
        image /= 255.0
        return self.model.predict(image, batch_size=128)[0]

    leftfeature = []
    rightfeature = []
    for i, left_path in enumerate(leftImageList):
        if i % 200 == 0:
            print("there are %d images done!" % i)
        leftfeature.append(embed(left_path))
        rightfeature.append(embed(rightImageList[i]))
    return leftfeature, rightfeature
4、计算余弦相似度并做归一化
注意:余弦相似度与余弦距离并不相同(余弦距离 = 1 - 余弦相似度),二者的区别可以参考我的文章:https://blog.csdn.net/u010847579/article/details/88893107
求出余弦相似度
# Cosine similarity of each (leftfeature[i], rightfeature[i]) pair.
# paired_distances only evaluates the matching rows, so the full N x N
# distance matrix is no longer built just to read its diagonal.
# NOTE(review): assumes `pw` is sklearn.metrics.pairwise - confirm.
distance = np.asarray(
    1 - pw.paired_distances(leftfeature, rightfeature, metric='cosine'),
    dtype=np.float64,
)
余弦相似度归一化(这一步也可以不做,看自己的需求)
# Min-max normalise the similarity scores into [0, 1].
# The extremes are computed once instead of once per element.
if len(labels) == 0:
    distance_norm = np.empty((0,))
else:
    scores = np.asarray(distance[:len(labels)], dtype=np.float64)
    lowest = scores.min()
    span = scores.max() - lowest
    if span > 0:
        distance_norm = (scores - lowest) / span
    else:
        # All scores are identical: avoid the 0/0 division (which produced
        # NaN) and map every score to 0.
        distance_norm = np.zeros_like(scores)
5、计算不同阈值下的准确率,确定最高准确率及其对应的阈值,并生成tpr、fpr的关系图(ROC曲线)
计算准确率
def calculate_accuracy(self, distance, labels, num):
    """Find the threshold that gives the highest accuracy.

    Thresholds from 0.1 to 0.9 (inclusive) are tried in steps of 0.001. A
    pair is predicted as ``1`` when its score is ``>= threshold`` and as
    ``0`` otherwise; accuracy is the fraction of the first ``num``
    predictions that equal the label. On ties the lowest threshold wins.

    Args:
        distance: Sequence of similarity scores, one per pair.
        labels: Sequence of 0/1 labels, one per pair.
        num: Number of leading pairs to evaluate.

    Returns:
        A tuple ``(highestAccuracy, thres)``: the best accuracy as a float
        and the matching threshold as a string (e.g. ``'0.301'``).

    Raises:
        ValueError: If ``num`` is not positive.
        IndexError: If fewer than ``num`` scores or labels are supplied.
    """
    if num <= 0:
        raise ValueError("num must be a positive number of pairs")
    scores = np.asarray(distance)[:num]
    truth = np.asarray(labels)[:num]
    if len(scores) < num or len(truth) < num:
        raise IndexError("distance and labels must hold at least num items")

    highestAccuracy = -1.0
    best_threshold = None
    # Derive every threshold from an integer step instead of repeatedly
    # adding 0.001: accumulated float error used to yield values such as
    # 0.10100000000000002 and made the inclusion of 0.9 unreliable.
    for step in range(801):
        threshold = round(0.1 + step * 0.001, 3)
        predicted = (scores >= threshold).astype(int)
        current_accuracy = float(np.count_nonzero(predicted == truth)) / num
        # Strict comparison keeps the first (lowest) threshold on ties.
        if current_accuracy > highestAccuracy:
            highestAccuracy = current_accuracy
            best_threshold = threshold
    # The threshold is returned as a string, as before.
    thres = str(best_threshold)
    return highestAccuracy, thres
生成
# Compute the ROC curve from the pair labels and the normalised similarity
# scores; the call yields the false-positive rates, the true-positive rates
# and the thresholds they correspond to.
fpr, tpr, thresholds = sklearn.metrics.roc_curve(labels, distance_norm)
绘制roc
def draw_roc_curve(self, fpr, tpr, title='cosine', save_name='roc_lfw'):
    """Plot a ROC curve, save it to disk and display it.

    Args:
        fpr: False-positive rates (x axis).
        tpr: True-positive rates (y axis).
        title: Name of the metric; appended to the figure title and used as
            the legend entry of the curve.
        save_name: Path (or file name) the figure is saved to.

    Returns:
        None.
    """
    plt.figure()
    # Label the curve so that the legend below has an entry to show.
    plt.plot(fpr, tpr, label=title)
    # Diagonal reference line from (0, 0) to (1, 1).
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.0])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic using: ' + title)
    plt.legend(loc="lower right")
    # Honour the save_name argument; it used to be ignored in favour of a
    # placeholder expression that could not be evaluated.
    plt.savefig(save_name)
    plt.show()
效果,随便拿的一个轻量化模型。
如图,可以看到最高准确率以及对应的阈值。
以上差不多就整体完成了,如有疑问,可以私信留言。