ViT中的Patch Embedding
题目分析
给定图像尺寸 img_size、块大小 patch_size、通道数 channels 和嵌入维度 embedding_dim,计算 Vision Transformer 中 Patch Embedding 层输出的形状。
思路
数学推导
ViT 将输入图像划分为等大的正方形 patch,每个 patch 通过线性映射得到一个嵌入向量。在所有 patch token 前面还会拼接一个分类 token(CLS token)。
具体计算步骤:
- 每条边上的 patch 数量为
。
- 总 patch 数量为
。
- 加上 1 个 CLS token,最终 token 总数为
。
- 每个 token 的维度就是
embedding_dim,保持不变。
以样例为例:,
,加上 CLS token 得
,嵌入维度为
,输出
145 512。
注意 channels 在计算输出形状时不影响结果,它只决定每个 patch 展平后的输入维度(),经过线性层后映射到
embedding_dim。
代码
img_size, patch_size, channels, embedding_dim = map(int, input().split())
num_patches = (img_size // patch_size) ** 2
print(num_patches + 1, embedding_dim)
复杂度分析
- 时间复杂度:
,只需常数次算术运算。
- 空间复杂度:
,只使用常数个变量。

京公网安备 11010502036488号