ViT中的Patch Embedding

题目分析

给定图像尺寸 img_size、块大小 patch_size、通道数 channels 和嵌入维度 embedding_dim,计算 Vision Transformer 中 Patch Embedding 层输出的形状。

思路

数学推导

ViT 将输入图像划分为等大的正方形 patch,每个 patch 通过线性映射得到一个嵌入向量。在所有 patch token 前面还会拼接一个分类 token(CLS token)。

具体计算步骤:

  1. 每条边上的 patch 数量为
  2. 总 patch 数量为
  3. 加上 1 个 CLS token,最终 token 总数为
  4. 每个 token 的维度就是 embedding_dim,保持不变。

以样例为例:,加上 CLS token 得 ,嵌入维度为 ,输出 145 512

注意 channels 在计算输出形状时不影响结果,它只决定每个 patch 展平后的输入维度(),经过线性层后映射到 embedding_dim

代码

img_size, patch_size, channels, embedding_dim = map(int, input().split())
num_patches = (img_size // patch_size) ** 2
print(num_patches + 1, embedding_dim)

复杂度分析

  • 时间复杂度,只需常数次算术运算。
  • 空间复杂度,只使用常数个变量。