PyTorch ------图像风格迁移学习原理
图像风格迁移学习原理
图像风格迁移学习介绍
- 利用算法将一张图片的风格样式,应用到另一张图画上的技术亦可以称为Neural-Style或者Neural-Transfer
- 该算法获取三张图片,即输入图片、内容图片和样式图片,然后更改输入以使其类似于内容图像的内容和样式图像的风格
基本原理
- 原理很简单我们定义两个表示距离的变量,一个表示输入图片和内容图片的距离(Dc),一个表示输入图片和样式图片的距离(Ds).即Dc测量输入和内容图片的内容差异的距离,Ds则测量输入和样式图片之间样式的差异距离.
- 最后我们将优化Dc和Ds使之最小,即完成图像风格转移
相关知识
Gram matrix
- Gram矩阵和协方差矩阵相似,差异在于Gram矩阵没有白化,直接使用两变量做内积
- Gram矩阵和相关系数矩阵叶相似,差异在于,没有白化,也没有标准化
- 总结上面说来就是Gram 矩阵相对于协方差矩阵和相关关系矩阵来说比较粗糙简单,但亦能表达其意思.
- 不了解协方差和相关关系的同学可以参考传送门,知乎赞最多的一篇文章????????????????????????????????????
VGG
- 提取图像风格和图像内容的图像是VGG19神经网络模型
- 这个可以参考上一片文章传送门????????????????????????????????????
- 对于VGG模型一般来说,越靠近输入层的卷积层输出越容易抽取图像的细节信息例如浅层的conv1_1,conv1_2,提取的特征通常是比较简单的线,角,靠近输出的卷积层输出的是全局的信息,特征比较复杂也可以认为是整体的信息
- 下图为VGG19的特征提取的结构
下面上代码时间
import time
import torch
import torch.nn.functional as F
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2 as cv
import os
import sys
import platform
#检测设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
print(torch.__version__)
Base_file_path = "./dataset/"
content_file_pathpng = "rainier.png"
style_file_pathpng = "autumn_oak.png"
content_file_path = "rainier.jpg"
style_file_path = "autumn_oak.jpg"
content_image = Image.open(os.path.join(Base_file_path,content_file_path))
plt.imshow(content_image)
plt.show()
style_image = Image.open(os.path.join(Base_file_path,style_file_path))
plt.imshow(style_image)
rgb_mean = np.array([0.485,0.456,0.406])
rgb_std = np.array([0.229,0.224,0.225])
def preprocess(PIL_image,image_shape):
process = torchvision.transforms.Compose([
torchvision.transforms.Resize(size = image_shape),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean=rgb_mean,std=rgb_std)
])
return process(PIL_image).unsqueeze(dim=0)
def postprocess(img_tensor):
inv_normalize = torchvision.transforms.Normalize(
mean=-rgb_mean/rgb_std,
std=1/rgb_std
)
to_PIL_image = torchvision.transforms.ToPILImage()
return to_PIL_image(inv_normalize(img_tensor[0].cpu()).clamp(0,1))
#VGG19
pretrained_net = torchvision.models.vgg19(pretrained=True)
print(pretrained_net)
"""
为了抽取图像的内容特征和样式特征,我们可以选择VGG网络中某些层的输出。
一般来说,越靠近输入层的输出越容易抽取图像的细节信息,反之则越容易抽取图像的全局信息。
为了避免合成图像过多的保留内容图像的细节,我们选择VGG较靠近输出的层,也称为内容层,来输出图像的内容特征。
我们还从VGG中选择不同层的输出来匹配局部和全局的样式,这些层也叫样式层。
"""
#style layers 每个Block的第一个卷积层
"""
指定的特征层 可以优化
或许其他层的 提取样式 内容 会更好 也可以 更换模型 来对比 提取的样式内容
"""
style_layers,content_layers = [0,5,10,19,28],[25]
#提取特征
net_list = []
for i in range(max(content_layers + style_layers) + 1):
# 将 预训练 模型的 指定的层的特征 提取出来
net_list.append(pretrained_net.features[i])
#重新组成一个模型
net = torch.nn.Sequential(*net_list)
"""
给定输入X,如果简单调用前向计算net(X),只能获得最后一层的输出。
由于我们还需要中间层的输出,因此这里我们逐层计算,并保留内容层和样式层的输出。
"""
def extract_features(X,content_layers,style_layers):
contents = []
styles = []
for i in range(len(net)):
X = net[i](X)
if i in style_layers:
styles.append(X)
if i in content_layers:
contents.append(X)
return contents,styles
#将内容图片 的内容特征 提取出来
def get_contents(image_shape,device):
#将内容图片 经过处理后的矩阵
content_X = preprocess(content_image,image_shape).to(device)
#将内容 图片的 内容特征 提取出来
contents_Y ,_ = extract_features(content_X,content_layers,style_layers)
return content_X,contents_Y
#将样式图片 的样式特征 提取出来
def get_styles(image_shape,device):
style_X = preprocess(style_image,image_shape).to(device)
# 将样式图片 的样式特征 提取出来
_,styles_Y = extract_features(style_X,content_layers,style_layers)
return style_X,styles_Y
tmp_content, contents_Y = get_contents(1360, device)
result = postprocess(tmp_content)
plt.imshow(result)
plt.show()
#定义损失函数
"""
样式迁移的损失函数,它由内容损失,样式损失和总变差损失3部分组成
"""
"""
内容损失
与线性回归中的损失函数类似,内容损失通过平方差误差函数衡量合成图像与内容特征上的差异。
平方误差函数的两个输入均为 extract——features 函数计算所得到的内容层的输出
"""
#内容损失函数 比较的是 合成图像内容 和 提供 内容的图像的loss function
def content_loss(Y_hat,Y):
return F.mse_loss(Y_hat,Y)
"""
样式损失
样式损失也一样通过平方差误差函数衡量合成图像与样式图像在样式上的差异,为了表达样式层输出的样式,
我们先通过extract-features函数计算样式层的输出。假设该输出的样本数为1,通道数为C 高和宽分别为h和w,
我们可以把输出变换成c行hw列的矩阵X。矩阵X可以看作是由C个长度为hw的向量X1,。。。Xc
"""
"""
gram metrax 体现的是图片中 各个特征通道的相关性
"""
def gram(X):
num_channels,n = X.shape[1],X.shape[2] * X.shape[3]
X = X.view(num_channels,n)
return torch.matmul(X,X.t())/(num_channels*n)
#样式的损失函数 比较的 是 合成图像的样式 和样式图片提供的样式的 loss function
def style_loss(Y_hat,gram_Y):
return F.mse_loss(gram(Y_hat),gram_Y)
#总变差 损失
"""
合成图像里面有大量的高频噪点,即有特别亮或者特别暗的颗粒像素 。
一种常用的降噪方法是总变差降噪 降低总变差损失
"""
def tv_loss(Y_hat):
return 0.5 *(F.l1_loss(Y_hat[:,:,1:,:],Y_hat[:,:,:-1,:]) +
F.l1_loss(Y_hat[:,:,:,1:],Y_hat[:,:,:,:-1]))
content_weight,style_weight,tv_weight = 1,1e3,10
"""
样式迁移的损失函数 即 内容损失、样式损失和总变差损失函数的加权和
通过调节这些权值超参数我们可以权衡合成图像在保留内容、迁移样式以及降噪三方面的相对重要性
"""
def compute_loss(X,contents_Y_hat,styles_Y_hat,contents_Y,styles_Y_gram):
# 计算内容损失
contents_l = [content_loss(Y_hat,Y) * content_weight for Y_hat,Y in zip(contents_Y_hat,contents_Y)]
#计算 样式损失
styles_l = [style_loss(Y_hat,Y) * style_weight for Y_hat,Y in zip(styles_Y_hat,styles_Y_gram)]
# 计算 总变差损失
tv_l = tv_loss(X) * tv_weight
l = sum(styles_l) + sum(contents_l) + tv_l
return contents_l,styles_l,tv_l,l
class GeneratedImage(torch.nn.Module):
def __init__(self,image_shape):
super(GeneratedImage,self).__init__()
self.weight = torch.nn.Parameter(torch.rand(* image_shape))
def forward(self):
print("into here forward Generate Image")
return self.weight
def get_inits(X,device,lr,styles_Y):
gen_image = GeneratedImage(X.shape).to(device)
gen_image.weight.data = X.data
optimizer = torch.optim.Adam(gen_image.parameters(),lr = lr)
styles_Y_gram = [gram(Y) for Y in styles_Y]
return gen_image(),styles_Y_gram,optimizer
def train(X, contents_Y, styles_Y, device, lr, max_epochs, lr_decay_epoch):
print("training on ", device)
X, styles_Y_gram, optimizer = get_inits(X, device, lr, styles_Y)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, lr_decay_epoch, gamma=0.1)
for i in range(max_epochs):
start = time.time()
print( " epoch ",i)
XCopy = np.array(X.data)
print("equal metrix",(XCopy == np.array(X.data)).all())
contents_Y_hat, styles_Y_hat = extract_features(X, content_layers, style_layers)
contents_l, styles_l, tv_l, l = compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)
print("zero grad....")
optimizer.zero_grad()
l.backward(retain_graph=True)
print("backward....")
optimizer.step()
print("update ....")
print("equal metrix", (XCopy == np.array(X.data)).all())
scheduler.step()
if i % 50 == 0 and i != 0:
print('epoch %3d, content loss %.2f, style loss %.2f, '
'TV loss %.2f, %.2f sec'
% (i, sum(contents_l).item(), sum(styles_l).item(), tv_l.item(),
time.time() - start))
return X.detach()
#创建合成 图片的尺寸
image_shape = (150, 224)
#将模型 转化为 当前设备的 数据类型
net = net.to(device)
#contentX 合成图片的内容 载体 将 内容图片大小 内容设置为 合成图片尺寸
#contentY 是将 内容图片的 特征 提取出来
content_X, contents_Y = get_contents(image_shape, device)
#style_X 为合成图片的 样式载体 将 样式图片大小 内容 设置为 合成图片大小
#content_Y 为将样式图片的 特征 提取出来
style_X, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.01, 50, 20)
print(“well done”)
解答阶段
- 最后解答一下同学的问题:
- 有同学问:最后输出图片尺寸大于内容图片尺寸,那最后输出是不是不准确
- 答:在开始的时候,对图片做了Resize根据设置的大小对图片大小做处理,大小设置太大,图片可能会失真
- 同学问:gram matrix 为什么要归一化
- 答归一化的原因 ATA内积产生的数值过大这些较大的值将导致第一层在梯度下降期间具有较大的影响,可以使模型更深,所以归一化至关重要