Self-Attention GAN(SAGAN)详解与实现
Self-Attention GAN(SAGAN)详解与实现
0. 前言
自注意力 (Self-Attention) 在计算机视觉(包括分类任务)中得到了广泛采用,自注意力可以帮助我们捕获图像中的重要特征,而无需在较大的有效感受野上使用深层网络。StyleGAN 非常适合生成人脸,但要根据 ImageNet 生成图像会很困难。
从某种意义上说,人脸很容易生成,因为眼睛,鼻子和嘴唇都具有相似的形状,并且在各个面孔上的位置都相似。相比之下,ImageNet 的 1000 类图像包含各种对象(例如,狗,卡车,鱼和枕头)和背景。因此,判别器必须更有效地捕获各种物体的独特特征。这就是自注意力的用武之地了。借助条件批归一化和频谱归一化,我们将实现 Self-Attention GAN (SAGAN) 以基于给定的类别标签生成图像。
1. 频谱归一化
频谱归一化是稳定生成对抗网络 (Generative Adversarial Network, GAN) 训练的重要方法,并且已在许多 GAN 中使用。与批归一化或其他归一化激活方法不同,频谱标准化将权重归一化。频谱归一化的目的是限制权重的增长,因此网络遵守 1-Lipschitz 约束,这已证明可以有效地稳定 GAN 训练。
我们将修改 WGAN,以使我们更好地理解频谱归一化背后的思想。WGAN 判别器(也称为评论家)需要将其预测保持较小数值范围,以满足 1-Lipschtiz 约束。WGAN 通过将权重裁剪到 [-0.01, 0.01] 的范围来做到这一点,但这不是一种可靠的方法,因为我们需要微调裁剪范围,这是一个超参数。我们需要有一种系统的方法可以在不使用超参数的情况下强制执行 1-Lipschitz 约束,而频谱归一化正是我们所需的工具。本质上,频谱归一化通过将权重除以频谱范数来对其进行归一化。
1.1 频谱归一化原理
接下来,我们将研究一些线性代数,以大致解释什么是频谱范数。矩阵理论中的特征值和特征向量通过以下公式定义:
A
v
=
λ
v
Av=\lambda v
Av=λv
其中,
A
A
A 是一个方阵,
v
v
v 是特征向量,而
λ
\lambda
λ 是其特征值。
我们将使用一个简单的示例来尝试理解这些术语。假设
v
v
v 是位置
(
x
,
y
)
(x, y)
(x,y) 的向量,而
A
A
A 是线性变换:
A
=
(
a
b
c
d
)
,
v
=
(
x
y
)
A=\begin{pmatrix} a & b \\ c & d \end{pmatrix},\ \ \ \ v=\begin{pmatrix} x \\ y \end{pmatrix}
A=(acbd), v=(xy)
如果将
A
A
A 乘以
v
v
v,我们将获得一个新的位置,其方向改变如下:
A
v
=
(
a
b
c
d
)
×
(
x
y
)
=
(
a
x
+
b
y
c
x
+
d
y
)
Av=\begin{pmatrix} a & b \\ c & d \end{pmatrix} \times \begin{pmatrix} x \\ y \end{pmatrix}=\begin{pmatrix} ax+by \\ cx+dy \end{pmatrix}
Av=(acbd)×(xy)=(ax+bycx+dy)
特征向量是将
A
A
A 应用于向量时不会改变方向的向量。取而代之的是,它们可以仅通过标量特征值
λ
\lambda
λ 进行缩放。可以有多个特征向量—特征值对,最大特征值的平方根是矩阵的谱范数。对于非方矩阵,我们将需要使用数学算法(例如奇异值分解 (singular value decomposition, SVD) )来计算特征值,这在计算上可能会非常昂贵。
因此,实践中采用幂迭代法来加速计算,使其能适用于神经网络训练。接下来我们将在 TensorFlow 中实现作为权重约束条件的谱归一化。
1.2 实现频谱归一化
频谱归一化数学算法可能看起来很复杂。但是,像往常一样,算法实现比数学上看起来更简单。以下是执行频谱归一化的步骤:
- 卷积层中的权重是一个
4维张量,因此第一步是将其重塑为W的2D矩阵,在这里我们保留权重的最后一个维度。现在,权重的形状为(H×W, C) - 用 N ( 0 , 1 ) N(0,1) N(0,1) 初始化向量 U U U
- 在
for循环中,计算以下内容:- 用矩阵转置和矩阵乘法计算 V = W T U V =W^TU V=WTU
- 用其
L2范数归一化 V V V,即 V = V ∣ ∣ V ∣ ∣ 2 V = \frac {V}{||V||_2} V=∣∣V∣∣2V - 计算 U = W V U = WV U=WV
- 用
L2范数归一化 U U U,即 U = U ∣ ∣ U ∣ ∣ 2 U = \frac U{||U||_2} U=∣∣U∣∣2U
- 计算频谱范数为 U T W V U^TWV UTWV
- 最后,将权重除以频谱范数
完整的代码如下:
class SpectralNorm(tf.keras.constraints.Constraint):
def __init__(self, n_iter=5):
self.n_iter = n_iter
def call(self, input_weights):
w = tf.reshape(input_weights, (-1, input_weights.shape[-1]))
u = tf.random.normal((w.shape[0], 1))
for _ in range(self.n_iter):
v = tf.matmul(w, u, transpose_a=True)
v /= tf.norm(v)
u = tf.matmul(w, v)
u /= tf.norm(u)
spec_norm = tf.matmul(u, tf.matmul(w, v), transpose_a=True)
return input_weights/spec_norm
迭代次数是一个超参数,5 次就足够了。还可以通过设置变量来记忆向量 u,而不是每次从随机值开始计算,这样可以将迭代次数减少至 1 次。现在我们在定义网络层时,可以通过 kernel_constraint 参数应用谱归一化,例如使用 Conv2D(3, 1, kernel_constraint=SpectralNorm()) 的方式来实现。
2. 自注意力模块
自注意力模块随着 Transformer 的引入而变得流行起来。在诸如语言翻译之类的自然语言处理 (Natural Language Processing, NLP) 应用程序中,模型通常需要逐字阅读句子以理解它们,然后再产生输出。Transformer 问世之前使用的神经网络是循环神经网络 (recurrent neural network, RNN) 的某些变体,例如长短期记忆 (long short-term memory, LSTM),RNN 具有内部状态,可以在阅读句子时记住单词。
这样做的一个缺点是,当单词数量增加时,第一个单词的梯度会逐渐消失。也就是说,随着 RNN 读取更多单词,句子开头的单词逐渐变得不那么重要。
Transformer 的处理方式有所不同。它会一次读取所有单词,并权衡每个单词的重要性。因此,模型会对更重要的单词赋予更多关注,因此被称为注意力。自注意力是目前主流 NLP 模型的基石。
2.1 计算机视觉中的自注意力
卷积神经网络 (Convolutional Neural Network, CNN) 主要由卷积层组成。对于核大小为 3×3 的卷积层,它将仅查看输入激活中的 3 × 3 = 9 个特征以计算每个输出特征。它将不会查看超出此范围的像素。为了捕获超出此范围的像素,我们可以将核大小略微增加到 5 × 5 或 7 × 7,但与特征图大小相比仍然很小。
我们必须向下移动一个网络层,以使卷积核的接受域足够大以捕获我们想要的内容。与 RNN 一样,输入特征的相对重要性随着我们在网络层中的移动而下降。因此,我们可以利用自注意力来观察特征图中的每个像素,并进行应注意的工作。
接下来,我们将研究自注意力机制的工作原理。自注意力的第一步是将每个输入要素投影到三个向量中,这些向量称为键 (key),查询 (query) 和值 (value)。下图说明了如何从查询中生成注意力图:

左图是带有点标记的查询 (query) 的图像,接下来的五个图像显示了查询给出的注意力图。顶部的第一个注意力图查询兔子的一只眼睛;注意图的两只眼睛周围有更多白色(指示重要区域),其他区域接近纯黑色(重要性较低)。接下来,我们将逐一介绍键,查询和值术语:
- 值 (
value) 表示输入特征。我们不希望自注意力模块查看每个像素,因为这在计算上过于昂贵且不必要。相反,我们对输入激活的局部区域更感兴趣。因此,值在激活图尺寸(例如,它可以被下采样以具有较小的高度和宽度)和信道的数目方面都减小了来自输入特征的维数。对于卷积层激活,通过使用1x1卷积来减少通道数,并通过最大池化或平均池化来减小空间大小 - 键 (
key) 和查询 (query) 用于计算自注意图中特征的重要性。为了计算位置 x x x 处的输出特征,我们在位置 x x x 处进行查询,并将其与所有位置处的键进行比较。
为了进一步说明这一点,假设我们有一个肖像画。当网络正在处理肖像的一只眼睛时,它将进行查询,该查询具有“眼睛”的语义含义,并使用肖像的其他区域的键进行检查。如果其他区域的键之一是眼睛,那么我们知道我们找到了另一只眼睛,当然这是我们要注意的事,以便我们可以匹配眼睛的颜色。
将其放到方程中,对于特征 0,我们计算向量 q0×k0,q0×k1,q0×k2,依此类推,得出 q0×kN-1。然后使用 softmax 将向量归一化,因此它们的总和为 1.0,这是我们的注意力得分。这用作权重以执行值的逐元素乘法,以提供注意输出。
SAGAN 自注意力模块基于非局部块,其最初是为视频分类而设计的。在确定当前体系结构之前,作者尝试了多种实现自注意力的方法。下图显示了 SAGAN 中的注意力模块,其中
θ
θ
θ,
φ
φ
φ 和
g
g
g 对应于键,查询和值:

深度学习中的大多数计算都是为了提高速度性能而向量化的,而对于自注意力也没有什么不同。如果为简单起见忽略 batch 维,则 1×1 卷积后的激活将具有 (H, W, C) 的形状。第一步是将其重塑为形状为 (H×W, C) 的 2D 矩阵,并使用
θ
θ
θ 和
φ
φ
φ 之间的矩阵相乘来计算注意力图。在 SAGAN 中使用的自注意力模块中,还有另一个 1×1 卷积,用于将通道号恢复到输入通道,然后使用可学习的参数进行缩放。进而,该模块被设计为残差块结构。
2.2 实现自注意力模块
我们将首先在自定义层的 build() 中定义所有 1×1 卷积层和权重。需要注意的是,使用频谱归一化函数作为卷积层的核约束:
class SelfAttention(Layer):
def __init__(self):
super(SelfAttention, self).__init__()
def build(self, input_shape):
n,h,w,c = input_shape
self.conv_theta = Conv2D(c//8, 1, padding='same', kernel_constraint=SpectralNorm(), name='Conv_Theta')
self.conv_phi = Conv2D(c//8, 1, padding='same', kernel_constraint=SpectralNorm(), name='Conv_Phi')
self.conv_g = Conv2D(c//8, 1, padding='same', kernel_constraint=SpectralNorm(), name='Conv_g')
self.conv_attn_g = Conv2D(c//8, 1, padding='same', kernel_constraint=SpectralNorm(), name='Conv_AttnG')
self.sigma = self.add_weight(shape=[1], initializer='zeros', trainable=True, name='sigma')
需要注意的是:
- 内部激活层的维度可以适当缩减以加速计算,
SAGAN作者通过实验确定了最优的维度缩减比例 - 在每个卷积层之后,激活
(H, W, C)被重塑为形状为(H*W, C)的二维矩阵,以便进行矩阵乘法运算
以下是该层的 call() 函数,用于执行自注意力操作。我们将首先计算 theta,phi 和 g:
def call(self, x):
n, h, w, c = x.shape
theta = self.conv_theta(x)
theta = tf.reshape(theta, (-1, self.n_feats, theta.shape[-1]))
phi = self.conv_phi(x)
phi = tf.nn.max_pool2d(phi, ksize=2, strides=2, padding='VALID')
phi = tf.reshape(phi, (-1, self.n_feats//4, phi.shape[-1]))
g = self.conv_g(x)
g = tf.nn.max_pool2d(g, ksize=2, strides=2, padding='VALID')
g = tf.reshape(g, (-1, self.n_feats//4, g.shape[-1]))
然后,将按以下方式计算注意力图:
attn = tf.matmul(theta, phi, transpose_b=True)
attn = tf.nn.softmax(attn)
最后,将注意力图与查询 g 相乘,并继续产生最终输出:
attn_g = tf.matmul(attn, g)
attn_g = tf.reshape(attn_g, (-1, h, w, attn_g.shape[-1]))
attn_g = self.conv_attn_g(attn_g)
output = x + self.sigma * attn_g
return output
编写了频谱归一化和自注意力层之后,接下来,可以使用它们来构建 SAGAN。
3. 实现 SAGAN
SAGAN 具有类似于 DCGAN 的简单体系结构。但是,它是使用类别标签来生成和判别图像的以类别为条件的 GAN。在下图中,每行上的每个图像都是从不同的类别标签生成的:

为了快速实验,在本节中,我们将使用 CIFAR10 数据集,其中包含 10 类图像,分辨率为 32x32。
3.1 SAGAN 生成器
总的来说,SAGAN 生成器与其他 GAN 生成器没有太大区别:以噪声作为输入并经过一个全连接层,然后经过多个上采样和卷积块,以实现目标图像分辨率。我们从 4×4 分辨率开始,并使用三个上采样块来达到 32×32 的最终分辨率:
def build_generator(z_dim, n_class):
DIM = 64
z = layers.Input(shape=(z_dim))
lables = layers.Input(shape=(1), dtype='int32')
x = Dense(4*4*4*DIM)(z)
x = layers.Reshape((4,4,4*DIM))(x)
x = layers.UpSampling2D((2,2))(x)
x = Resblock(4*DIM, n_class)(x, labels)
x = layers.UpSampling2D((2,2))(x)
x = Resblock(2*DIM, n_class)(x, labels)
x = SelfAttention()(x)
x = layers.UpSampling2D((2,2))(x)
x = Resblock(DIM, n_class)(x, labels)
output_image = tanh(Conv2D(3, 3, padding='same')(x))
return Model([z, labels], output_image, name='generator')
尽管在自注意模块中使用了不同的激活尺寸,但其输出与输入的形状相同。因此,可以将其插入卷积层之后的任何位置。但是,当核大小为 3×3 时,在分辨率为 4×4 时使用自注意力层是不必要的。因此,自注意层仅在 SAGAN 生成器空间分辨率较高的阶段中插入一次,以最大程度地利用自注意层。判别器也是如此,将自注意层放置在空间分辨率较高的层。
如果我们要进行无条件的图像生成,那么这就是生成器的全部内容。但是,我们还需要将类标签提供给生成器,以便它可以根据给定的类创建图像。我们已经了解过一些对标签进行条件处理的常用方法,但是 SAGAN 使用了更高级的方法,它在批归一化中将类标签编码为可学习的参数。
3.2 条件批归一化
在 CIFAR10 中,有 10 类图像:6 种是动物(鸟,猫,鹿,狗,青蛙和马),4 种是交通工具(飞机,汽车,轮船和卡车)。显然,它们看起来截然不同——车辆往往具有坚硬而笔直的边缘,而动物倾向于具有弯曲的边缘和较柔和的纹理。
在风格迁移一节中,我们了解到激活的统计数据决定了图像风格。因此,混合批次统计信息可能会创建看上去有点像动物同时也有点像车辆(例如,汽车形状的猫)的图像。这是因为批归一化在由不同类组成的整个批次中仅使用一个 gamma 和一个 beta。如果每种风格(类别)都有一个 gamma 和一个 beta,则该问题得以解决,而这正是条件批归一化的意义所在。每个类别有一个 gamma 和一个 beta,因此 CIFAR10 中的 10 个类别每层有 10 个 beta 和 10 个 gamma。
现在,我们可以构造条件批归一化所需的变量:
- 形状为
(10, C)的gamma和beta,其中C是激活通道数 - 形状为
(1, 1, 1, C)的移动均值和方差。在训练中,均值和方差是从小批次数据计算得出的,推理时则使用训练过程中累积的移动平均值。这样的形状设计使得算术运算能够广播到N、H和W维度。
实现用于条件批归一化:
class ConditionBatchNorm(Layer):
def build(self, input_shape):
self.input_size = input_shape
n, h, w, c = input_shape
self.gamma = self.add_weight(shape=[self.n_class, c],
initializer='zeros', trainable=True, name='gamma')
self.moving_mean = self.add_weight(shape=[1, 1, 1, c],
initializer='zeros', trainable=False, name='moving_mean')
self.moving_var = self.add_weight(shape=[1, 1, 1, c],
initializer='zeros', trainable=False, name='moving_var')
当运行条件批归一化时,为标签检索正确的 beta 和 gamma。这是使用 tf.gather(self.beta, labels) 完成的,它在概念上等效于 beta = self.beta [labels]:
def call(self, x, labels, trainable=False):
beta = tf.gather(self.beta, labels)
beta = tf.expand_dims(beta, 1)
gamma = tf.gather(self.gamma, labels)
gamma = tf.expand_dims(gamma, 1)
除此之外,其余代码与批归一化相同。现在,可以将条件批归一化添加到生成器的残差块中:
class ResBlock(Layer):
def build(self, input_shape):
input_filter = input_shape[-1]
self.conv_1 = Conv2D(self.filters, 3, padding='same', name='conv2d_1')
self.conv_2 = Conv2D(self.filters, 3, padding='same', name='conv2d_2')
self.cbn_1 = ConditionBatchNorm(self.n_class)
self.cbn_2 = ConditionBatchNorm(self.n_class)
self.learned_skip = False
if self.filters != input_filter:
self.learned_skip = True
self.conv_3 = Conv2D(self.filters, 1, padding='same', name='conv2d_3')
self.cbn_3 = ConditionBatchNorm(self.n_class)
以下是用于条件批归一化的前向计算的运行时代码:
def call(self, input_tensor, labels):
x = self.conv_1(input_tensor)
x = self.cbn_1(x, labels)
x = tf.nn.leaky_relu(x, 0.2)
x = self.conv_2(x)
x = tf.cbn_2(x, labels)
x = tf.nn.leaky_relu(x, 0.2)
if self.learned_skip:
skip = self.conv_3(input_tensor)
skip = self.cbn_3(skip, labels)
skip = tf.nn.leaky_relu(skip, 0.2)
else:
skip = input_tensor
output = skip + x
return output
判别器的残差块与生成器的残差块相似,区别在于:
- 没有归一化
- 下采样发生在具有平均池化的残差块内部
3.3 构建判别器
判别器也使用自注意层,并将其放置在输入层附近以捕获大尺度激活图。由于它是条件 GAN,因此我们还将在判别器中使用标签,以确保生成器正在生成与类匹配的正确图像。合并标签信息的一般方法是先将标签投影到嵌入空间中,然后在输入层或内部层使用嵌入。有两种将嵌入与激活合并的常用方法——级联和逐元素乘法。SAGAN 使用的架构类似于投影模型:

首先将标签投影到嵌入空间中,然后在全连接层(图中的 ψ ψ ψ )之前执行激活的逐元素乘法。然后将结果添加到全连接层输出中,以给出最终预测:
def build_discriminator(n_class):
DIM = 64
input_image = Input(shape=IMAGE_SHAPE)
input_labels = Input(shape=(1))
embedding = Embedding(n_class, 4*DIM)(input_labels)
embedding = Flatten()(embedding)
x = ResblockDown(DIM)(input_image) # 16
x = SelfAttention()(x)
x = ResblockDown(2*DIM)(x)
x = ResblockDown(4*DIM, False)(x)
x = tf.reduce_sum(x, (1,2))
embedded_x = tf.reduce_mean(x*embedding, axis=1, keep_dims=True)
output = Dense(1)(x)
output += embedded_x
return Model([input_image, input_labels], output, name='discriminator')
3.4 训练 SAGAN
我们将使用标准的 GAN 训练步骤。损失函数是铰链损失,使用 Adam 优化器。生成器 (1e-4) 和判别器 (4e-4) 使用不同的初始学习率。由于 CIFAR10 尺寸为 32×32 的小图像,因此训练相对稳定且快速。

原文地址:https://blog.csdn.net/LOVEmy134611/article/details/151835424
免责声明:本站文章内容转载自网络资源,如侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!
