# Learning PyTorch: ResNet-34 for CIFAR10 classification
# Adapted by hand from the example code in Liao Xingyu's "Deep Learning with PyTorch"; for personal study only.
# ResNet for CIFAR10 classification
from datetime import datetime
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable
from torch import optim
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
# Basic ResNet building block: the residual module
def conv3x3(in_channel, out_channel, stride=1):
    """Return a bias-free 3x3 convolution with padding 1.

    Padding 1 with a 3x3 kernel preserves the spatial size when
    stride is 1; bias is omitted because a BatchNorm layer follows.
    """
    return nn.Conv2d(in_channel, out_channel, 3,
                     stride=stride, padding=1, bias=False)
#Residual Block
class residual_block(nn.Module):
    """Basic ResNet residual block: two 3x3 conv/BN pairs plus a shortcut.

    The shortcut (identity, or the supplied ``downsample`` module when the
    input shape must change) is added to the output of the second conv/BN
    pair, and the final ReLU is applied *after* the addition, as in
    He et al., "Deep Residual Learning for Image Recognition".

    Args:
        in_channel: number of input feature channels.
        out_channel: number of output feature channels.
        stride: stride of the first convolution (spatial downsampling).
        downsample: optional module mapping the input to the shape of the
            block's output; required when stride != 1 or the channel
            counts differ, otherwise leave as None for an identity skip.
    """

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super(residual_block, self).__init__()
        self.conv1 = conv3x3(in_channel, out_channel, stride)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.conv2 = conv3x3(out_channel, out_channel)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        # Bug fix: the original applied a ReLU here, before the residual
        # addition; the standard block adds the shortcut first and only
        # then activates, so negative pre-activations can still cancel.
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            residual = self.downsample(x)
        out = out + residual
        return F.relu(out, inplace=True)
class ResNet(nn.Module):
    """ResNet-34-style network adapted for 32x32 CIFAR10 images.

    The stage layout follows ResNet-34 (3, 4, 6, 3 residual blocks), but
    NOTE(review): every stage here uses stride 1, so the only spatial
    reduction comes from the stem max-pool (32 -> 16) and the final
    average pool (16 -> 2); the classifier therefore sees a
    64 * 2 * 2 = 256-dim vector, matching ``self.fc``.

    Args:
        num_classes: size of the final classification layer (default 1000;
            pass 10 for CIFAR10).
    """

    def __init__(self, num_classes=1000):
        super(ResNet, self).__init__()
        # Stem: 3 -> 16 channels, then halve the spatial size via max-pool.
        self.pre = nn.Sequential(
            nn.Conv2d(3, 16, 3, 1, 1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1))
        # Four stages with 3, 4, 6, 3 residual blocks respectively.
        self.layer1 = self._make_layer(16, 16, 3)
        self.layer2 = self._make_layer(16, 32, 4, stride=1)
        self.layer3 = self._make_layer(32, 64, 6, stride=1)
        self.layer4 = self._make_layer(64, 64, 3, stride=1)
        # 64 channels * 2 * 2 spatial = 256 features after the avg-pool.
        self.fc = nn.Linear(256, num_classes)

    def _make_layer(self, inchannel, outchannel, block_num, stride=1):
        """Stack ``block_num`` residual blocks; only the first may reshape.

        Bug fix: the original unconditionally inserted a 1x1 projection
        shortcut in every stage, even when the input already matches the
        output shape (e.g. 16 -> 16, stride 1). Standard ResNet uses an
        identity shortcut there; the projection is only built when the
        shape actually changes.
        """
        if stride != 1 or inchannel != outchannel:
            shortcut = nn.Sequential(
                nn.Conv2d(inchannel, outchannel, 1, stride, bias=False),
                nn.BatchNorm2d(outchannel))
        else:
            shortcut = None
        layers = [residual_block(inchannel, outchannel, stride, shortcut)]
        # Remaining blocks keep the shape, so they use identity shortcuts.
        for _ in range(1, block_num):
            layers.append(residual_block(outchannel, outchannel))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.pre(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = F.avg_pool2d(x, 7)        # 16x16 feature map -> 2x2
        x = x.view(x.size(0), -1)     # flatten to (batch, 256)
        return self.fc(x)
# Hyperparameters
batch_size = 64        # samples per mini-batch
learning_rate = 1e-2   # SGD step size
num_epoches = 20       # number of passes over the training set
if __name__ == '__main__':
    # Preprocessing: convert to tensors and normalize RGB channels to [-1, 1].
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    # Download/load the CIFAR10 train and test splits.
    # Bug fix: the loaders now honor the module-level batch_size (the
    # original defined batch_size = 64 but hardcoded 4 here).
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Bug fix: CIFAR10 has 10 classes; the original kept the default 1000.
    model = ResNet(num_classes=10).to(device)

    # Loss and optimizer: cross-entropy with plain SGD.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    # Training.
    # Bug fix: the original made a single pass over the data and counted
    # each *batch* as an "epoch", leaving num_epoches unused; run the
    # configured number of full epochs instead.
    for epoch in range(num_epoches):
        model.train()
        since = datetime.now()
        for step, (img, label) in enumerate(train_loader, 1):
            img = img.to(device)
            label = label.to(device)
            # Forward pass
            out = model(img)
            loss = criterion(out, label)
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step % 100 == 0:
                print('*' * 10)
                print('epoch {}, step {}'.format(epoch + 1, step))
                print('loss is {:.4f}'.format(loss.item()))
        print('epoch {} finished in {}'.format(epoch + 1, datetime.now() - since))

    # Evaluation: eval mode (BatchNorm uses running stats) and no autograd.
    model.eval()
    eval_loss = 0.0
    eval_acc = 0
    with torch.no_grad():
        for img, label in test_loader:
            img = img.to(device)
            label = label.to(device)
            out = model(img)
            loss = criterion(out, label)
            eval_loss += loss.item() * label.size(0)
            _, pred = torch.max(out, 1)
            eval_acc += (pred == label).sum().item()
    print('Test Loss:{:.6f}, Acc:{:.6f}'.format(eval_loss / len(test_dataset), eval_acc / len(test_dataset)))