基于keras的农业病害图像分类-lenet实现
实验流程:
- 问题分析
- 整体框架
- 数据准备
- 代码实现
- 实验结果
a. 问题分析
首先我们针对农业病害识别问题做一个简单分析,我们希望做出来的效果是:用户输入一张患病的植物图片,得到该病害所属类型,同时输出预测准确率accuracy和loss。
b.整体框架如下图:
c.数据准备
我们使用的数据均来自网络搜集,本次实验我们只选择了四种玉米病害图像进行实验,数据已经经过预处理,分布如下表:
真实数据是下面这样的:
d.代码实现
我们使用Keras编写的代码结构,实验中加入了Keras ImageDataGenerator 增强我们的数据。数据集读取方式直接采用Keras自带的,训练集和测试集数据各放在4个文件夹下,类别标签名分别设置为0、1、2、3.。全部代码如下:
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense
from keras import backend as K
# dimensions of our images.
img_width, img_height = 28, 28
train_data_dir = './train'
validation_data_dir = './test'
nb_train_samples = 2500
nb_validation_samples = 350
epochs = 150
batch_size = 20
if K.image_data_format() == 'channels_first':
input_shape = (3, img_width, img_height)
else:
input_shape = (img_width, img_height, 3)
# create model
model=Sequential()
model.add(Conv2D(filters=6, kernel_size=(5,5), padding='valid', input_shape=input_shape, activation='tanh'))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Conv2D(filters=16, kernel_size=(5,5), padding='valid', activation='tanh'))
model.add(MaxPooling2D(pool_size=(2,2)))
#池化后变成16个4x4的矩阵,然后把矩阵压平变成一维的,一共256个单元。
model.add(Flatten())
#下面就是全连接层了
model.add(Dense(120, activation='tanh'))
model.add(Dense(84, activation='tanh'))
model.add(Dense(4, activation='softmax'))
model.compile(loss='categorical_crossentropy',
optimizer='sgd',
metrics=['accuracy'])
# this is the augmentation configuration we will use for training
train_datagen = ImageDataGenerator(
rescale=1. / 255,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True)
# this is the augmentation configuration we will use for testing:
# only rescaling
test_datagen = ImageDataGenerator(rescale=1. / 255)
train_generator = train_datagen.flow_from_directory(
train_data_dir,
target_size=(img_width, img_height),
batch_size=batch_size,
class_mode='categorical') #多分类
validation_generator = test_datagen.flow_from_directory(
validation_data_dir,
target_size=(img_width, img_height),
batch_size=batch_size,
class_mode='categorical') #多分类
model.fit_generator(
train_generator,
steps_per_epoch=nb_train_samples // batch_size,
epochs=epochs,
validation_data=validation_generator,
validation_steps=nb_validation_samples // batch_size)
e.实验结果
我们使用GPU训练了150次,每次放入20张图片训练,最终得出了82%左右准确率,loss最终停在0.45左右,可以基本上完成对4类玉米病害的识别与检测。结尾放一张我的其他模型跑出来的图,可以达到96%左右正确率。