主要讲Vision Transformer的模型结构部分,其中使用到的Encoder会直接调用PyTorch封装好的函数。

整体架构(论文创新点)

先看ViT的整体架构动态流程图:

img

接着还是用论文中的模型结构来分析:

image-20240415202959396

上图就是模型的整体结构。

从左下角开始看。可以看到模型的思想就是将一整张图片,切分为若干个(图中为9个)小方块。切分成若干个小方块后,在进行线性展平(flatten),得到粉色小块就是每个图片线性展平后的向量。之后再拼接紫色的小块,紫色小块为Positional Embedding。这里的借鉴了BERT中的class token,添加了一个cls_token在最前面,表示图片的分类。这样之后就可以直接拿到Transformer Encoder中训练了,训练完成后,经过一个MLP Head(就是先进行LayerNorm,再全连接将输入维度转换为输出类别个数)。虚线右侧就是Transformer Encoder。

论文的创新点

  1. 采用Transformer中的自注意力机制进行图像相关任务
  2. 分块序列化处理,使得图像处理任务可以使用Transformer来完成,并且也给后人提供了一种用NLP的方法做视觉任务的启发思想
  3. 与训练与微调,借鉴了BERT的class token思想,可以通过大规模无标注数据集进行预训练,然后在特定任务上进行微调。这一策略大大提升了模型的泛化能力和迁移学习能力。

浅层分析一点思想

「为什么要切分为小块再展平,将一整张图片展平不行吗?」

因为self-attention在NLP中是会对每个单词进行两两计算相似度求注意力,它的复杂度很高,在Transformer中一般认为的最大维度为512。对于目前常用的图片数据集,图片的size都是224*224或300*300甚至600*600这样展平后就会远超512。切分后,每一个图片都变成比较小的尺寸,例如16*16,对于三通道的图片那么它的维度就为16*16*3=768,这样就可以用到Transformer中了。如果从NLP的概念来理解的话,就类似于将一个长句子切分为多个短句子。

「为什么使用cls_token?」

因为本质还是想要用到图像分类问题上,需要给这整个图片标记一个分类。这里也是借鉴了BERT的思想。

注意,这个标记是可以学习,最终结果就是根据这个标记的值来分类整张图片的。

并且将cls_token放在第0位,即固定位置,也能够避免输出受到位置编码的干扰。

「为什么cls_token和Positional Embedding都是随机初始化?」

论文中,作者说了,随机和不随机的效果是差不多的。

但是随机相比不随机又会有如下优点:

  1. 该token对所有其他token上的信息做汇聚(全局特征聚合),并且由于它本身不基于图像内容,因此可以避免对sequence中某个特定token的偏向性;
  2. 可以放到网络中被训练,能够编码整个数据集的统计特征;
  3. 处理打乱顺序的图片的效果更好了

模型代码复现

Patch Embedding

这一部分其实就用是卷积操作完成的。

IMG_0993

如上图,切分为小块其实就是一步stride=kernel_size=16的卷积操作

但是看了上图,有的人可能就要懵了(我一开始也懵了)。各个变量的含义应该是看懂了,不懂点为:

  1. 不是切分后图片的size为16吗,怎么变成14了

其实这个就是关键,这里是想要用NLP的方法处理图像,不能以传统的图像的思维。Transformer计算的实际上是小图片的像素点。这里利用卷积的方式切分图片,实际上是把图片切分到通道上了。

即这一步卷积之后,通道数才是实际的embed_dim,1414或者展平的196才是“通道数”。*之后会“转置”将这两个维度交换位置。

所以上一步的操作使得图像变换为:(512, 3, 224, 224) => (512, 768, 14, 14)

接下来就是对后两维进行线性展平即可,如图中所示,就是nn.Flatten(2)

第三步就是随机初始化cls_token和Positional Embedding即可:

  • cls_token因为要和图片像素点拼接,因此维度为(1, 1, embed_dim)

  • 从图中可以看出:

    Positional Embedding的维度为(1, num_patches+1, embed_dim)

根据上述分析,写出如下Patch Embedding模块代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
class PatchEmbedding(nn.Module):
"""
功能: 切分小方块
"""

def __init__(self, in_channels=3, patch_size=16, embed_dim=768, num_patches=14, dropout=0.01):
"""
初始化模型 \n
:param in_channels: 输入图片通道数 \n
:param patch_size: 切完后,小方块的大小 \n
:param embed_dim: 切完后,总的通道数 = (patch_size**2)*in_channels \n
:param num_patches: 切完后,小方块的数目 = (height/patch_size) * (weight/patch_size) \n
:param dropout: 随机丢弃层的 元素归零的概率 \n
"""
super(PatchEmbedding, self).__init__()
self.patcher = nn.Sequential(
# 利用卷积操作将原始图片分割为小方块,主要在于kernel_size = stride = patch_size
# (512, 3, 224, 224) => (512, 768, 14, 14)
nn.Conv2d(in_channels=in_channels, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size),
# (512, 768, 14, 14) => (512, 768, 196)
nn.Flatten(2)
)

# 随机生成cls_token
self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim), requires_grad=True) # (1, 1, 768)

# 随机生成位置编码,Transformer中的Positional Embedding
self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim), requires_grad=True) # (1, 197, 768)

self.dropout = nn.Dropout(p=dropout)

def forward(self, x):
"""
前向传播 \n
:param x: 输入图像 (batch_size, channels, height, width)
"""

# 实例化cls_token,利用expand(),将batch_size对齐,其他维不变
# (1, 1, embed_dim) => (batch_size, 1, embed_dim)
cls_token = self.cls_token.expand(x.shape[0], -1, -1) # (512, 1, 768)

# 执行完分块后,将第二维和第三维交换位置,因为Transformer要对图片像素序列处理
# (batch_size, channels, height, weight) => (batch_size, (height/patch_size)*(weight/patch_size), channels*(patch_size**2))
x = self.patcher(x).permute(0, 2, 1) # (512, 196, 768)

# 按channel拼接cls_token和x
# dim的取值:dim=0按batch_size;dim=1按channel;dim=2按height;dim=3按weight
# (batch_size, (height/patch_size)*(weight/patch_size), channels*(patch_size**2)) => (batch_size, (height/patch_size)*(weight/patch_size)+1, channels*(patch_size**2))
x = torch.cat([cls_token, x], dim=1) # (512, 197, 768)

# x与Positional Embedding相加,注意不是cat()
x = x + self.pos_embed # (512, 197, 768)

x = self.dropout(x) # (512, 197, 768)
return x

去掉注释,精简版:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class PatchEmbedding(nn.Module):
def __init__(self, in_channels=3, patch_size=16, embed_dim=768, num_patches=14, dropout=0.01):
super(PatchEmbedding, self).__init__()
self.patcher = nn.Sequential(
nn.Conv2d(in_channels=in_channels, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size),
nn.Flatten(2)
)
self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim), requires_grad=True)
self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim), requires_grad=True)
self.dropout = nn.Dropout(p=dropout)

def forward(self, x):
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = self.patcher(x).permute(0, 2, 1)
x = torch.cat([cls_token, x], dim=1)
x = x + self.pos_embed
x = self.dropout(x)
return x

ViT

image-20240415202959396

有了上面的PatchEmbedding模块,接下来定义ViT模块,根据模型结构图,可以看出其实就是先经过PatchEmbedding然后再输入到L个Transformer Encoder中,最后经过MLP返回其类别。

那么这里主要就是写MLP,MLP其实就是经过一个LayerNorm,再经过一个线性层,将embed_dim个维度转换为num_classes个维度。

整体实现代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class ViT(nn.Module):
"""
定义ViT网络结构\n
"""
def __init__(self, in_channels, patch_size, embed_dim, num_patches, dropout, num_heads, activation, num_encoders, num_classes):
"""
初始化模型 \n
:param in_channels: 输入的通道数 \n
:param patch_size: 切完后,小方块的大小 \n
:param embed_dim: 切完后,总的通道数 = (patch_size**2)*in_channels \n
:param num_patches: 切完后,小方块的数目 = (height/patch_size) * (weight/patch_size) \n
:param dropout: 随机丢弃层的 元素归零的概率 \n
:param num_heads: Transformer的Multi_Head个数 \n
:param activation: 激活函数 \n
:param num_encoders:encoder的个数\n
:param num_classes: 类别个数\n
"""
super(ViT, self).__init__()
# 使用上面定义的PatchEmbedding模块
self.patch_embedding = PatchEmbedding(in_channels, patch_size, embed_dim, num_patches, dropout) # (512, 197, 768)

# 设置encoder layers
encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dropout=dropout, activation=activation,batch_first=True, norm_first=True)
self.encoder_layers = nn.TransformerEncoder(encoder_layer, num_layers=num_encoders) # (512, 197, 768)

# 定义MLP
self.MLP = nn.Sequential(
nn.LayerNorm(normalized_shape=embed_dim), # (512, 197, 768)
nn.Linear(in_features=embed_dim, out_features=num_classes) # (512, 197, num_classes)
)

def forward(self, x):
"""
前向传播 \n
:param x: 输入图像 (batch_size, channels, height, width)
"""
x = self.patch_embedding(x)
x = self.encoder_layers(x)
x = self.MLP(x[:, 0, :])
return x

无注释,精简版:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class ViT(nn.Module):
def __init__(self, in_channels, patch_size, embed_dim, num_patches, dropout, num_heads, activation, num_encoders, num_classes):
super(ViT, self).__init__()
self.patch_embedding = PatchEmbedding(in_channels, patch_size, embed_dim, num_patches, dropout)
encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dropout=dropout, activation=activation,batch_first=True, norm_first=True)
self.encoder_layers = nn.TransformerEncoder(encoder_layer, num_layers=num_encoders)
self.MLP = nn.Sequential(
nn.LayerNorm(normalized_shape=embed_dim),
nn.Linear(in_features=embed_dim, out_features=num_classes)

def forward(self, x):
x = self.patch_embedding(x)
x = self.encoder_layers(x)
x = self.MLP(x[:, 0, :])
return x

整体代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class PatchEmbedding(nn.Module):
def __init__(self, in_channels=3, patch_size=16, embed_dim=768, num_patches=14, dropout=0.01):
super(PatchEmbedding, self).__init__()
self.patcher = nn.Sequential(
nn.Conv2d(in_channels=in_channels, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size),
nn.Flatten(2)
)
self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim), requires_grad=True)
self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim), requires_grad=True)
self.dropout = nn.Dropout(p=dropout)

def forward(self, x):
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = self.patcher(x).permute(0, 2, 1)
x = torch.cat([cls_token, x], dim=1)
x = x + self.pos_embed
x = self.dropout(x)
return x


class ViT(nn.Module):
def __init__(self, in_channels, patch_size, embed_dim, num_patches, dropout, num_heads, activation, num_encoders, num_classes):
super(ViT, self).__init__()
self.patch_embedding = PatchEmbedding(in_channels, patch_size, embed_dim, num_patches, dropout)
encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dropout=dropout, activation=activation,batch_first=True, norm_first=True)
self.encoder_layers = nn.TransformerEncoder(encoder_layer, num_layers=num_encoders)
self.MLP = nn.Sequential(
nn.LayerNorm(normalized_shape=embed_dim),
nn.Linear(in_features=embed_dim, out_features=num_classes)

def forward(self, x):
x = self.patch_embedding(x)
x = self.encoder_layers(x)
x = self.MLP(x[:, 0, :])
return x

然后利用训练的常规套路,训练模型即可。