风格迁移

风格迁移


1. 什么是风格迁移

1.1 风格迁移任务介绍


简而言之,风格迁移即是将一种图像的画风(颜色、纹理等)融合到另一幅图像中,使得融合后的图像能够以风格图像的色彩风格呈现原始图像的内容(任务、景物等前景)。具体示例可见下方图示。

向风格迁移网络中输入风格图像style和内容图像content,网络便会输出融合stylecontent的新图像style transfer

1.2 风格迁移网络


基于深度学习的风格迁移任务的难点在于如何度量图片的风格,如若找到一种可以度量图像风格的数学模型,便可以基于反向传播训练风格迁移网络,以达到良好的风格迁移效果。

  1. Gram Matrix

2015年Gatys利用Gram Matrix实现了图像风格的定量计算方法,现在大多数图像风格迁移的风格度量均采用Gram Matrix或者Gram Matrix的改进版本。因为本文所用风格度量方法也是基于Gram Matrix的,于是先简单介绍一下什么是Gram Matrix。

gram

当输入一幅图像(2xWxH)进入卷积神经网络时,图像经过卷积层会得到CxWxH的图像特征,在特征的每个通道特征1xWxH上保留了图像的高语义特征(包括内容和风格),但是不同通道之间的学习到的特征内容是不同。比如输入是一幅人脸图像,经过卷积网络之后可能通道1学习到了人脸鼻子特征,而通道2学习到了眼睛特征,它们的特征主体是不一样的,但同时他们来源于同一幅图像,他们的风格应该是一样的,由此Gatys通过计算不同通道特征向量W*H的的协方差矩阵(CxC)定量刻画图像的风格。具体做法是将卷积网络中特征图每个通道的特征reshape成一个向量(1xWxH->1xWH)然后计算不同通道特征的(偏心)协方差矩阵,这个就矩阵就是Gram Matrix。操作示意图如上所示(图源CSDN):

  1. 风格转换网络结构

使用的风格转换网络来源于Cui的《Multi-style Transfer: Generalizing Fast Style Transfer to Several Genres》,网络结构图如下所示。

从图中可以看到风格转换网络由两个网络拼接而成,Image Transform Net用于向陌生的图像融合学习到的风格特征;Loss Network仅用在网络训练阶段,借用Loss Network以得到图像的内容损失和风格损失,以此更新模型参数。具体而言,输入陌生图片$x$(一般情况下也是内容图像),经过Image Transform Net得到融合新风格的图像$\hat{y}$,然后将风格迁移后图像$\hat{y}$、风格图像$y_{s}$和内容图像(也即原始图像)$y_{c}$分别送入Loss NetworkLoss Network是VGG-16的特征提取网络。分别保存三幅图像在前、后卷积层的输出响应(特征图),针对风格迁移后图像$\hat{y}$和风格图像$y_{s}$分别计算其Gram Matrix,然后利用MSELoss()度量风格差异;针对风格迁移后图像$\hat{y}$和内容图像$y_{c}$直接使用MSELoss()度量图像之间特征主体的差异。

2. 基于Pytorch的快速风格迁移实例


环境:windows 10 + RTX3060 + CUDA 11.4

requirements:

1
2
3
4
5
6
matplotlib==3.4.3
numpy==1.21.4
Pillow==8.4.0
torch==1.10.0
torchvision==0.11.1
tqdm==4.62.3

代码主要参考:https://blog.csdn.net/weixin_48866452/article/details/109309245


2.1 数据集构建


对于风格迁移任务来说训练样本不需要太多,最少两张图片便可完成风格转换的训练。本次选取了3张风格图片与6张内容图片构建风格迁移的数据集。风格和内容图片示例如下:

样本图片来自于这个GitHub仓库

具体的内容数据集加载方式如下方代码所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
class Content_Dataset(data.Dataset):
def __init__(self) -> None:
super(Content_Dataset, self).__init__()
r_path = 'dataset/content/' # 存放内容图片的路径
c_img_name = os.listdir(r_path)
self.c_img_path = [r_path+i for i in c_img_name]
self.transforms = T.Compose([
T.Resize((512, 512)),
T.ToTensor(),
])

def __getitem__(self, index):
c_img = Image.open(self.c_img_path[index])

return self.transforms(c_img)

def __len__(self):
return len(self.c_img_path)

2.2 风格迁移网络


迁移网络采用自编码结构,先用卷积将图像尺度缩小并学习高语义特征,再利用上采样将图像尺度放大,保证图像风格变换前后大小不变。迁移网络的Pytorch实现如下:

  1. TransNet
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class ResBlock(nn.Module):
def __init__(self, channels):
super(ResBlock, self).__init__()
self.layer = nn.Sequential(
nn.Conv2d(channels, channels, 3, 1, 1, bias=False),
nn.InstanceNorm2d(channels), # 在图像像素上,对每个通道的HW做归一化。
nn.ReLU(),
nn.Conv2d(channels, channels, 3, 1, 1, bias=False),
nn.InstanceNorm2d(channels)
)
self.relu = nn.ReLU()

def forward(self, x):
return self.relu(self.layer(x)+x)


class TransNet(nn.Module):

def __init__(self):
super(TransNet, self).__init__()
self.layer = nn.Sequential(
nn.Conv2d(3, 32, 9, 1, 4, bias=False),
nn.InstanceNorm2d(32),
nn.ReLU(),
nn.Conv2d(32,64,3,2,1, bias=False),
nn.InstanceNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 128, 3, 2, 1, bias=False),
nn.InstanceNorm2d(128),
nn.ReLU(),
ResBlock(128),
ResBlock(128),
ResBlock(128),
ResBlock(128),
ResBlock(128),
nn.Upsample(scale_factor=2, mode='nearest'),
nn.Conv2d(128,64,3,1,1, bias=False),
nn.InstanceNorm2d(64),
nn.ReLU(),
nn.Upsample(scale_factor=2, mode='nearest'),
nn.Conv2d(64, 32, 3, 1, 1, bias=False),
nn.InstanceNorm2d(32),
nn.ReLU(),
nn.Conv2d(32,3,9,1,4),
nn.Sigmoid()
)

def forward(self, x):
return self.layer(x)
  1. LossNet
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
class Vgg16(nn.Module):
def __init__(self) -> None:
super(Vgg16, self).__init__()
# 直接加载预训练的VGG16网络提取图像特征
loss_net = M.vgg16(pretrained=True)
loss_net = loss_net.features
self.feature1 = loss_net[:4]
self.feature2 = loss_net[4:9]
self.feature3 = loss_net[9:16]
self.feature4 = loss_net[16:23]

def forward(self, x):
feature1 = self.feature1(x)
feature2 = self.feature2(feature1)
feature3 = self.feature3(feature2)
feature4 = self.feature4(feature3)

return feature1, feature2, feature3, feature4

2.3 必要的工具函数


因为风格迁移没有定量评估迁移好坏的指标(一般都是通过人眼自行观察风格迁移结果去确定迁移性能),所以这里的工具函数目前包括Gram矩阵的计算函数和训练过程可视化函数。

1
2
3
4
5
6
# 定义gram矩阵
def get_gram_matrix(feature_map):
n, c, h, w = feature_map.shape
feature_map = feature_map.reshape(n*c, h*w)
gram_matrix = t.mm(feature_map, feature_map.t())
return gram_matrix.div(n*c*h*w)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# 训练过程损失可视化
def training_process_visualization(data):
style_loss = data['style_loss']
content_loss = data['content_loss']
total_loss = data['total_loss']

plt.figure(1)
plt.plot(range(len(content_loss)), content_loss)
plt.title('content_loss')
plt.ylabel('content_loss')
plt.xlabel('epoch')
plt.savefig('content_loss.png')

plt.figure(2)
plt.plot(range(len(style_loss)), style_loss)
plt.title('style_loss')
plt.ylabel('style_loss')
plt.xlabel('epoch')
plt.savefig('style_loss.png')

plt.figure(3)
plt.plot(range(len(total_loss)), total_loss)
plt.title('total_loss')
plt.ylabel('total_loss')
plt.xlabel('epoch')
plt.savefig('total_loss.png')
plt.show()

通常在工具函数中还包括了存储网络超参数的config.py文件,本次风格迁移使用的超参数文件如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from torchvision import transforms as T
class Config():
# style_img = 'dataset/style/hosi.jpg'
# style_img = 'dataset/style/la_muse.jpg'
# style_img = 'dataset/style/trial.jpg'
style_img = 'dataset/style/sketch.png'
nw = 0 # 多线程加载数据集(windows多线程加载有问题,所以改成了0)
bs = 3 # batchsize
epochs = 150
lr = 0.001
wc = 1 # 内容损失的权重
ws = 100000 # 风格损失的权重
result_path = 'checkpoints' # 训练结果保存文件夹
save_frequency = 30 # 每30个epoch保存一次
trans = T.ToTensor()

2.4 网络训练及测试


**网络训练:**通过更改config.py文件中的style_img路径更改训练的风格图片,然后完成2.1节所示3种风格的学习。具体的train.py如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from PIL import Image
import torch.nn as nn
import torch as t
import os
from tqdm._tqdm import trange
from torch.utils import data
from torchvision.utils import save_image
from torchvision import transforms as T

from dataset.my_dataset import Content_Dataset
from utils.config import Config
from utils.tools import get_gram_matrix, training_process_visualization
from model.my_net import TransNet, Vgg16


# 初始化超参数实例
opt = Config()
device = t.device('cuda' if t.cuda.is_available() else 'cpu')

# 加载风格图片
img_style = opt.trans(Image.open(opt.style_img)).unsqueeze(0)
img_style = img_style.expand(opt.bs, img_style.shape[1], img_style.shape[2], img_style.shape[3])
img_style = img_style.to(device)

# 创建结果保存文件夹
# 创建文件夹
if not os.path.exists(opt.result_path):
os.mkdir(opt.result_path)

# 初始化内容图片加载器
content_dataset = Content_Dataset()
content_loader = data.DataLoader(content_dataset, opt.bs, num_workers=opt.nw)

# 定义模型、优化器、损失函数
trans_net = TransNet().to(device)
loss_net = Vgg16().to(device).eval()


optimizer = t.optim.AdamW(trans_net.parameters(), lr=opt.lr)
loss_func = nn.MSELoss().to(device)


# 利用Vgg16和Gram Matrix度量风格
# 此处不用管风格图片的大小因为最后生成的gram matrix大小均为NCxNC
styles = []
features = loss_net(img_style)
for f in features:
# 这里一定得加上detach()截断风格输入,不然会保留的风格图片的梯度,导致BP失败
styles.append(get_gram_matrix(f).detach())

# 用于保存训练过程中的损失
style_loss = []
content_loss = []
total_loss = []

# 更改为tqdm模块内的trange函数以了解训练时间
for epoch in trange(1, opt.epochs+1):
trans_net.train()
for i, image in enumerate(content_loader):
image_c = image.to(device)
# print(image_c.shape)
image_g = trans_net(image_c)

# 计算风格损失
loss_s = 0.0
outs = loss_net(image_g)
for out, style in zip(outs, styles):
loss_s += loss_func(get_gram_matrix(out), style)

# 计算内容损失
contents = loss_net(image_c)
loss_c2 = loss_func(outs[1], contents[1])

# 总损失
loss = loss_c2 * opt.wc + loss_s * opt.ws
# print(epoch, loss.item(), loss_c2.item(), loss_s.item())
style_loss.append(loss_s.item())
content_loss.append(loss_c2.item())
total_loss.append(loss.item())

optimizer.zero_grad()
loss.backward()
optimizer.step()
# lr_adjust.step()


# 中途保存模型
if epoch % opt.save_frequency == 0:
all_data = dict(
optimizer=optimizer.state_dict(),
model=trans_net.state_dict(),
info=u'模型和优化器的所有参数'
)
p = opt.style_img.split('/')[-1]
p = opt.result_path + "/" + p.split('.')[0]
t.save(all_data, '{}_trans_net_{}.pth'.format(p, epoch))
img_path = '{}_fuse_{}.jpg'.format(p, epoch)
_, _, h, w = img_style.size()
trans = T.Resize((h, w))

save_image([img_style.clone()[0],trans(image_c[0]), trans(image_g[0])],
img_path, padding=0, normalize=True, range=(0, 1))

data = {'style_loss':style_loss, 'content_loss':content_loss, 'total_loss':total_loss}
training_process_visualization(data)

训练过程的损失变化图如下:

content_lossstyle_losstotal_loss

**网络测试:**构建test.py文件,加载训练时保存的模型参数,传入需要进行风格迁移的图片,得到风格迁移结果。具体的测试文件如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33

from os import name
import torch as t
import glob
from PIL import Image
from torch.utils.data.dataset import T
from utils.config import Config
from model.my_net import TransNet
from torchvision.utils import save_image

opt = Config()
device = t.device('cuda' if t.cuda.is_available() else 'cpu')
trans_net = TransNet().to(device)

images = glob.glob('dataset/content/**')
styles = ['hosi', 'la_muse', 'sketch', 'trial']
# 遍历所有content图片
for img_p in images:
img = opt.trans(Image.open(img_p)).unsqueeze(0)
img = img.to(device)
# 遍历风格
for style in styles:
pred_models = ['checkpoints/{}_trans_net_{}.pth'.format(style, i) for i in range(30, 151, 30)]
outs = []
outs.append(img[0])
# 遍历风格的不同epoch模型
for m in pred_models:
all_data = t.load(m)
trans_net.load_state_dict(all_data['model'])
trans_net.eval() # 设置为推理模式
outs.append(trans_net(img)[0])
name = style + '_' + img_p[16:]
save_image(outs, name, padding=0, normalize=True, range=(0, 1))

2.5 迁移结果展示


最后得到的风格迁移结果如下:

综合以上三种风格的迁移可视化结果可以看出,(从纵轴上观察)网络在处理人物或者动物图像时,过度的看重了图像的主体特征(内容),忽略了图像风格,这样的情况应该可以通过候选不断调整内容损失和风格损失的权重取得两者之间的平衡去解决,但限于时间和算力,并未做过多的尝试。(从横轴上观察)随着网络训练的加深,迁移后图像的风格也更加自然(虽然我觉得其实epoch=30的时候更像是图像迁移的最终目的😂)。

3. 总结

通过这次风格迁移的学习(当然也只是最基础的部分),也让我大致了解风格迁移任务的整体流程,目前完成的效果因为训练时间的限制并不是太好,整体来看有点儿像是风格滤镜那种,不像第一节介绍那种实现图片风格的完全卡通化,后面如果有时间再来好好研究一下。这次任务也让我了解一些新的东西比如:利用Gram Matrix去定量衡量图片风格,自编码网络结构的搭建,利用glob库直接生成目录内文件路径以及不用重复造轮子直接利用Pytorch官方的save_img()保存tensor为图片等等。再接再厉,继续学习🐛🐛🐛!