Keras Image Style Transfer



Style transfer: make the result match the base image as closely as possible in content, and the style image as closely as possible in style.

  • 1. Use a pre-trained VGG19 network to extract features.
  • 2. The first loss term is the "content loss": the L2 distance between the features of the generated image and those of the base image, which keeps the content of the generated image consistent with the base image.
  • 3. The second loss term is the "style loss": the difference between the Gram matrices of the generated image's features and the style image's features, which keeps the style of the generated image consistent with the style image.
  • 4. The third loss term is the "variation loss": the difference between neighboring local features of the generated image, which keeps the generated image locally coherent so it looks natural overall. (The three terms are written out as formulas right after this list.)
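For reference, a sketch of the three terms as implemented in the code below, where $C$ is the number of channels (hard-coded to 3 in the code) and $N = \text{width} \times \text{height}$:

$$L_{\text{content}} = \sum \left(F^{\text{comb}} - F^{\text{base}}\right)^2$$

$$L_{\text{style}}^{(l)} = \frac{1}{4\,C^{2}N^{2}} \sum \left(G^{\text{style}} - G^{\text{comb}}\right)^{2}, \qquad G = F F^{\top}\ \text{(Gram matrix of the flattened feature channels)}$$

$$L_{\text{tv}} = \sum_{i,j} \left[ (x_{i,j} - x_{i+1,j})^{2} + (x_{i,j} - x_{i,j+1})^{2} \right]^{1.25}$$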

 

Code implementation based on Keras:

# coding: utf-8
from __future__ import print_function
from keras.preprocessing.image import load_img, img_to_array
import numpy as np
from scipy.optimize import fmin_l_bfgs_b
import time
import argparse
from keras.applications import vgg19
from keras import backend as K
from PIL import Image, ImageEnhance

# Command-line arguments
parser = argparse.ArgumentParser(description='Image style transfer based on Keras.')  # argument parser
parser.add_argument('--style_reference_image_path', metavar='ref', type=str, default='./style.jpg',
                    help='Path to the style reference image')
parser.add_argument('--base_image_path', metavar='ref', type=str, default='./base.jpg',
                    help='Path to the base image')
parser.add_argument('--iter', type=int, default=25, required=False,
                    help='Number of iterations')
parser.add_argument('--pictrue_size', type=int, default=500, required=False,
                    help='Image size.')

# Read the arguments
args = parser.parse_args()
base_image_path = args.base_image_path
style_reference_image_path = args.style_reference_image_path
iterations = args.iter
pictrue_size = args.pictrue_size


source_image = Image.open(base_image_path)
source_image = source_image.resize((pictrue_size, pictrue_size))

width, height = pictrue_size, pictrue_size


def save_img(fname, image, image_enhance=True):  # optionally enhance the image before saving
    image = Image.fromarray(image)
    if image_enhance:
        # boost brightness
        enh_bri = ImageEnhance.Brightness(image)
        brightness = 1.2
        image = enh_bri.enhance(brightness)

        # boost color saturation
        enh_col = ImageEnhance.Color(image)
        color = 1.2
        image = enh_col.enhance(color)

        # boost sharpness
        enh_sha = ImageEnhance.Sharpness(image)
        sharpness = 1.2
        image = enh_sha.enhance(sharpness)
    image.save(fname)  # save with PIL (scipy.misc.imsave was removed in SciPy 1.2)
    return


# util function to resize and format a picture into the appropriate tensor
def preprocess_image(image):
    """
    Preprocess an image: resize it to (width, height), add a batch dimension,
    and apply VGG19 preprocessing (RGB->BGR plus mean subtraction).
    :param image: an input image
    :return: the preprocessed image
    """
    image = image.resize((width, height))
    image = img_to_array(image)
    image = np.expand_dims(image, axis=0)  # (width, height, 3) -> (1, width, height, 3)
    image = vgg19.preprocess_input(image)  # RGB->BGR, subtract the ImageNet channel means
    return image

def deprocess_image(x):
    """
    Invert the VGG19 preprocessing and return a displayable image.
    :param x: preprocessed image data
    :return: an image with values in 0-255
    """
    x = x.reshape((width, height, 3))
    # add back the ImageNet channel means (BGR order)
    x[:, :, 0] += 103.939
    x[:, :, 1] += 116.779
    x[:, :, 2] += 123.68
    # 'BGR'->'RGB'
    x = x[:, :, ::-1]
    x = np.clip(x, 0, 255).astype('uint8')  # clip to guard against values outside 0-255
    return x


def gram_matrix(x):  # Gram matrix: inner products between the flattened feature channels
    assert K.ndim(x) == 3
    if K.image_data_format() == 'channels_first':
        features = K.batch_flatten(x)
    else:
        features = K.batch_flatten(K.permute_dimensions(x, (2, 0, 1)))
    gram = K.dot(features, K.transpose(features))
    return gram

# Style loss: squared difference between the Gram matrices of the style image and the generated image, summed over all elements
def style_loss(style, combination):
    assert K.ndim(style) == 3
    assert K.ndim(combination) == 3
    S = gram_matrix(style)
    C = gram_matrix(combination)
    S_C = S-C
    channels = 3
    size = height * width
    return K.sum(K.square(S_C)) / (4. * (channels ** 2) * (size ** 2))
    #return K.sum(K.pow(S_C,4)) / (4. * (channels ** 2) * (size ** 2))  # surprisingly no different from the squared version
    #return K.sum(K.pow(S_C,4)+K.pow(S_C,2)) / (4. * (channels ** 2) * (size ** 2))  # also works; leaves appeared behind the flowers


def eval_loss_and_grads(x):  # given x, return the loss and the gradient w.r.t. x
    if K.image_data_format() == 'channels_first':
        x = x.reshape((1, 3, height, width))
    else:
        x = x.reshape((1, height, width, 3))
    outs = f_outputs([x])  # feed x through the graph and collect the outputs
    loss_value = outs[0]
    if len(outs[1:]) == 1:
        grad_values = outs[1].flatten().astype('float64')
    else:
        grad_values = np.array(outs[1:]).flatten().astype('float64')
    return loss_value, grad_values

# an auxiliary loss function
# designed to maintain the "content" of the
# base image in the generated image
def content_loss(base, combination):
    return K.sum(K.square(combination - base))

# the 3rd loss function, total variation loss,
# designed to keep the generated image locally coherent
def total_variation_loss(x, img_nrows=width, img_ncols=height):
    assert K.ndim(x) == 4
    if K.image_data_format() == 'channels_first':
        a = K.square(x[:, :, :img_nrows - 1, :img_ncols - 1] - x[:, :, 1:, :img_ncols - 1])
        b = K.square(x[:, :, :img_nrows - 1, :img_ncols - 1] - x[:, :, :img_nrows - 1, 1:])
    else:
        a = K.square(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, 1:, :img_ncols - 1, :])
        b = K.square(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, :img_nrows - 1, 1:, :])
    return K.sum(K.pow(a + b, 1.25))


# The Evaluator gets all gradients and the loss from a single pass: scipy's L-BFGS
# requests them through separate callbacks, so the gradients are cached when the loss is computed
class Evaluator(object):
    def __init__(self):
        self.loss_value = None
        self.grads_values = None

    def loss(self, x):
        assert self.loss_value is None
        loss_value, grad_values = eval_loss_and_grads(x)
        self.loss_value = loss_value
        self.grad_values = grad_values
        return self.loss_value

    def grads(self, x):
        assert self.loss_value is not None
        grad_values = np.copy(self.grad_values)
        self.loss_value = None
        self.grad_values = None
        return grad_values


# Build the data to process as Keras tensors: a (3, width, height, 3) batch holding
# the base image, the style image, and the generated (combination) image
base_image = K.variable(preprocess_image(source_image))   # base image
style_reference_image = K.variable(preprocess_image(load_img(style_reference_image_path)))
if K.image_data_format() == 'channels_first':
    combination_image = K.placeholder((1, 3, width, height))
else:
    combination_image = K.placeholder((1, width, height, 3))

# Concatenate the 3 images into a single Keras input tensor
input_tensor = K.concatenate([base_image, style_reference_image, combination_image], axis=0)

# Use the pre-trained VGG19 network shipped with Keras, without the 3 fully connected layers
model = vgg19.VGG19(input_tensor=input_tensor, weights='imagenet', include_top=False)
model.summary()  # print a summary of the model
'''
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, None, None, 3)     0
_________________________________________________________________
block1_conv1 (Conv2D)        (None, None, None, 64)    1792             A
_________________________________________________________________
block1_conv2 (Conv2D)        (None, None, None, 64)    36928
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, None, None, 64)    0
_________________________________________________________________
block2_conv1 (Conv2D)        (None, None, None, 128)   73856            B
_________________________________________________________________
block2_conv2 (Conv2D)        (None, None, None, 128)   147584
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, None, None, 128)   0
_________________________________________________________________
block3_conv1 (Conv2D)        (None, None, None, 256)   295168           C
_________________________________________________________________
block3_conv2 (Conv2D)        (None, None, None, 256)   590080
_________________________________________________________________
block3_conv3 (Conv2D)        (None, None, None, 256)   590080
_________________________________________________________________
block3_conv4 (Conv2D)        (None, None, None, 256)   590080
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, None, None, 256)   0
_________________________________________________________________
block4_conv1 (Conv2D)        (None, None, None, 512)   1180160          D
_________________________________________________________________
block4_conv2 (Conv2D)        (None, None, None, 512)   2359808
_________________________________________________________________
block4_conv3 (Conv2D)        (None, None, None, 512)   2359808
_________________________________________________________________
block4_conv4 (Conv2D)        (None, None, None, 512)   2359808
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, None, None, 512)   0
_________________________________________________________________
block5_conv1 (Conv2D)        (None, None, None, 512)   2359808          E
_________________________________________________________________
block5_conv2 (Conv2D)        (None, None, None, 512)   2359808
_________________________________________________________________
block5_conv3 (Conv2D)        (None, None, None, 512)   2359808
_________________________________________________________________
block5_conv4 (Conv2D)        (None, None, None, 512)   2359808          F
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, None, None, 512)   0
=================================================================
'''
# Map each VGG19 layer name to its output tensor, for later lookup
outputs_dict = dict([(layer.name, layer.output) for layer in model.layers])

loss = K.variable(0.)

layer_features = outputs_dict['block5_conv2']
base_image_features = layer_features[0, :, :, :]
combination_features = layer_features[2, :, :, :]
content_weight = 0.08
loss += content_weight * content_loss(base_image_features,
                                      combination_features)

feature_layers = ['block1_conv1','block2_conv1','block3_conv1','block4_conv1','block5_conv1']
feature_layers_w = [0.1,0.1,0.4,0.3,0.1]
# feature_layers = ['block5_conv1']
# feature_layers_w = [1]
for i in range(len(feature_layers)):
    # weight and features of each chosen layer
    layer_name, w = feature_layers[i], feature_layers_w[i]
    layer_features = outputs_dict[layer_name]  # features of this layer

    style_reference_features = layer_features[1, :, :, :]  # style image features at layer i of VGG
    combination_features = layer_features[2, :, :, :]      # generated image features at layer i of VGG

    loss += w * style_loss(style_reference_features, combination_features)  # difference between the style image's and generated image's features as the loss

loss += total_variation_loss(combination_image)


# Gradient of the loss w.r.t. combination_image; each iteration adjusts combination_image along the gradient direction
grads = K.gradients(loss, combination_image)

outputs = [loss]
if isinstance(grads, (list, tuple)):
    outputs += grads
else:
    outputs.append(grads)

f_outputs = K.function([combination_image], outputs)

evaluator = Evaluator()
x = preprocess_image(source_image)
img = deprocess_image(x.copy())
fname = 'original.png'
save_img(fname, img)

# start iterating
for i in range(iterations):
    start_time = time.time()
    print('Iteration', i, end="   ")
    x, min_val, info = fmin_l_bfgs_b(evaluator.loss, x.flatten(), fprime=evaluator.grads, maxfun=20, epsilon=1e-7)
    # one step of scipy's L-BFGS optimizer
    print('current loss:', min_val, end="  ")
    # save the generated image
    img = deprocess_image(x.copy())

    fname = 'result_%d.png' % i
    end_time = time.time()
    print('took %.2f s' % (end_time - start_time))

    if i % 5 == 0 or i == iterations - 1:
        save_img(fname, img, image_enhance=True)
        print('Saved as', fname)

Base image:

Style image:

Synthesized artistic image:

During training, the overall loss is the sum of the 3 losses, and each loss has its own coefficient; adjusting the coefficients produces different effects. The sketch below restates how the coefficients enter the script.
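A condensed, hedged restatement of the loss assembly in the script above; style_weight and tv_weight are hypothetical names (the script folds the style coefficient into feature_layers_w and adds the total variation term unweighted):

content_weight = 0.08  # as in the script above
style_weight = 1.0     # hypothetical: folded into feature_layers_w in the script
tv_weight = 1.0        # hypothetical: the script adds the TV term with no explicit weight

loss = K.variable(0.)
loss += content_weight * content_loss(base_image_features, combination_features)
for layer_name, w in zip(feature_layers, feature_layers_w):
    layer_features = outputs_dict[layer_name]
    loss += style_weight * w * style_loss(layer_features[1, :, :, :], layer_features[2, :, :, :])
loss += tv_weight * total_variation_loss(combination_image)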

 

"Content loss"

The following images correspond to content loss coefficients of 0.1, 1, 5, and 10:

As the content loss coefficient increases, the optimization puts more weight on matching the content, so the generated image grows closer and closer to the base image.
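A hedged sketch of how such a sweep might be scripted on top of the code above; the loop and re-initialization are assumptions, not part of the original:

# Hypothetical sweep over the content loss coefficient. Each setting rebuilds
# the loss tensor and reruns the whole optimization from the base image.
for content_weight in [0.1, 1.0, 5.0, 10.0]:
    loss = K.variable(0.)
    loss += content_weight * content_loss(base_image_features, combination_features)
    # ... add the style and total variation terms, rebuild f_outputs,
    # then run the L-BFGS iterations and save the result for this setting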

 

"Style loss"

The style loss blends features from 5 convolutional layers of the VGG network, so simply increasing the overall style loss coefficient has little effect on the final style. Below is a comparison between coefficients of 1 and 100:

The coefficients differ by a factor of 100, yet the style of the image shows no obvious change. Adjusting the relative weights of the 5 convolutional features may be more effective.

Below are the results of using only the 1st, 2nd, 3rd, 4th, or 5th convolutional layer's features:

Among the 5 convolutional layers, the 3rd and 4th have the greatest influence on the image's style.

Below, the coefficients of the 3rd and 4th layers are adjusted, comparing the 5 coefficients in the ratios 1:1:1:1:1 and 0.5:0.5:0.4:0.4:1 (see the sketch after this paragraph).

After increasing the proportion of the 3rd and 4th layers, the image's style is noticeably closer to the style image.
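A minimal sketch of how these experiments map onto feature_layers_w in the script above (the single-layer variant mirrors the commented-out lines there):

# Use only one convolutional layer's features for the style loss:
feature_layers = ['block3_conv1']
feature_layers_w = [1]

# Or re-weight all five layers; the ratios below are the two settings compared above:
feature_layers = ['block1_conv1', 'block2_conv1', 'block3_conv1', 'block4_conv1', 'block5_conv1']
feature_layers_w = [1, 1, 1, 1, 1]             # ratio 1:1:1:1:1
# feature_layers_w = [0.5, 0.5, 0.4, 0.4, 1]   # ratio 0.5:0.5:0.4:0.4:1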

 

 

"Variation loss" (total variation loss)

The variation loss measures the differences between neighboring local features of the image itself. The larger its coefficient, the more similar neighboring regions become, which shows up as smoother, more natural pixel-to-pixel transitions. Below are the results for coefficients of 1, 5, and 10:
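In the script above, the total variation term is added with an implicit coefficient of 1; a minimal sketch of exposing it as a knob (tv_weight is a hypothetical name):

tv_weight = 5.0  # hypothetical coefficient; larger values give smoother local transitions
loss += tv_weight * total_variation_loss(combination_image)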

 

That's all.

