Python code
import torch
from transformers import ViTModel, ViTConfig
import collections


def analyze_vit_parameters(model_name='./google/vit-base-patch16-224'):
    """
    Analyze the parameter count of a ViT model.

    Args:
        model_name: ViT model name; defaults to vit-base-patch16-224.
    """
    print(f"Loading model: {model_name}")
    # Load the model and its configuration
    model = ViTModel.from_pretrained(model_name)
    config = ViTConfig.from_pretrained(model_name)
    model.eval()

    print(f"\n{'='*60}")
    print(f"ViT model parameter analysis: {model_name}")
    print(f"{'='*60}")

    # Print the basic model configuration
    print("Model configuration:")
    print(f"  - Image size: {config.image_size}×{config.image_size}")
    print(f"  - Patch size: {config.patch_size}×{config.patch_size}")
    print(f"  - Hidden size: {config.hidden_size}")
    print(f"  - Number of Transformer layers: {config.num_hidden_layers}")
    print(f"  - Number of attention heads: {config.num_attention_heads}")
    print(f"  - MLP expansion ratio: {config.intermediate_size/config.hidden_size:.1f}x")

    total_params = 0
    layer_params = collections.OrderedDict()

    print(f"\n{'Layer name':<25} {'Param type':<15} {'Param count':<12} {'Shape':<30}")
    print("-" * 85)

    # 1. Patch embedding parameters (Conv2d projection: weight + bias)
    patch_embed_params = sum(p.numel() for p in model.embeddings.patch_embeddings.projection.parameters())
    total_params += patch_embed_params
    layer_params['patch_embedding'] = patch_embed_params
    print(f"{'patch_embedding':<25} {'Conv2d':<15} {patch_embed_params:<12,} {str(tuple(model.embeddings.patch_embeddings.projection.weight.shape)):<30}")

    # 2. Position embedding parameters (learnable tensor of shape [1, num_patches + 1, hidden_size])
    pos_embed_params = model.embeddings.position_embeddings.numel()
    total_params += pos_embed_params
    layer_params['position_embedding'] = pos_embed_params
    print(f"{'position_embedding':<25} {'Embedding':<15} {pos_embed_params:<12,} {str(model.embeddings.position_embeddings.shape):<30}")

    # 3. CLS token parameters
    cls_token_params = model.embeddings.cls_token.numel()
    total_params += cls_token_params
    layer_params['cls_token'] = cls_token_params
    print(f"{'cls_token':<25} {'Parameter':<15} {cls_token_params:<12,} {str(model.embeddings.cls_token.shape):<30}")

    # 4. Transformer layer parameters
    transformer_total = 0
    for i in range(config.num_hidden_layers):
        layer = model.encoder.layer[i]
        # Attention parameters (Q, K, V, output projection)
        attention_params = sum(p.numel() for p in layer.attention.parameters())
        # MLP parameters (intermediate + output linear layers)
        mlp_params = sum(p.numel() for p in layer.intermediate.parameters()) + \
                     sum(p.numel() for p in layer.output.parameters())
        # LayerNorm parameters (one before attention, one before the MLP)
        layer_norm_params = sum(p.numel() for p in layer.layernorm_before.parameters()) + \
                            sum(p.numel() for p in layer.layernorm_after.parameters())
        layer_total = attention_params + mlp_params + layer_norm_params
        transformer_total += layer_total
        total_params += layer_total
        layer_params[f'encoder.layer.{i}'] = layer_total
        print(f"{f'encoder.layer.{i}':<25} {'Transformer':<15} {layer_total:<12,} {'-':<30}")
        print(f"{' ├─ attention':<25} {'MSA':<15} {attention_params:<12,} {'-':<30}")
        print(f"{' ├─ mlp':<25} {'FFN':<15} {mlp_params:<12,} {'-':<30}")
        print(f"{' └─ layer_norm':<25} {'LayerNorm':<15} {layer_norm_params:<12,} {'-':<30}")
    layer_params['transformer_total'] = transformer_total

    # 5. Pooler and other parameters (the encoder's final LayerNorm and the pooler are not counted here)
    other_params = 0
    # Add other parameters here if needed...

    print("-" * 85)
    print(f"{'Total':<25} {'-':<15} {total_params:<12,} {'-':<30}")
    return total_params, layer_params, config


def print_parameter_distribution(total_params, layer_params, config):
    """Print the parameter distribution statistics."""
    print(f"\n{'='*60}")
    print("Parameter distribution analysis")
    print(f"{'='*60}")

    # Percentage of the total taken by each module
    patch_embed_pct = layer_params['patch_embedding'] / total_params * 100
    pos_embed_pct = layer_params['position_embedding'] / total_params * 100
    cls_token_pct = layer_params['cls_token'] / total_params * 100
    transformer_pct = layer_params['transformer_total'] / total_params * 100

    print("\nParameter distribution by module:")
    print(f"  - Patch Embedding: {layer_params['patch_embedding']:,} params ({patch_embed_pct:.2f}%)")
    print(f"  - Position Embedding: {layer_params['position_embedding']:,} params ({pos_embed_pct:.2f}%)")
    print(f"  - CLS Token: {layer_params['cls_token']:,} params ({cls_token_pct:.4f}%)")
    print(f"  - Transformer Layers: {layer_params['transformer_total']:,} params ({transformer_pct:.2f}%)")

    # Average parameters per Transformer layer
    avg_layer_params = layer_params['transformer_total'] / config.num_hidden_layers
    print(f"\nAverage parameters per Transformer layer: {avg_layer_params:,.0f}")

    # Total parameter count
    print(f"\nTotal parameters: {total_params:,} ({total_params/1e6:.2f}M)")


# Usage example
if __name__ == "__main__":
    # Different ViT models can be analyzed here
    models_to_analyze = [
        'google/vit-base-patch16-224',
        # 'google/vit-large-patch16-224',
        # 'google/vit-huge-patch14-224-in21k'
    ]
    for model_name in models_to_analyze:
        total_params, layer_params, config = analyze_vit_parameters(model_name)
        print_parameter_distribution(total_params, layer_params, config)
        print("\n" + "="*80 + "\n")
Run results
Loading model: google/vit-base-patch16-224
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
============================================================
ViT model parameter analysis: google/vit-base-patch16-224
============================================================
Model configuration:
  - Image size: 224×224
  - Patch size: 16×16
  - Hidden size: 768
  - Number of Transformer layers: 12
  - Number of attention heads: 12
  - MLP expansion ratio: 4.0x
Layer name                Param type      Param count  Shape
-------------------------------------------------------------------------------------
patch_embedding Conv2d 590,592 (768, 3, 16, 16)
position_embedding Embedding 151,296 torch.Size([1, 197, 768])
cls_token Parameter 768 torch.Size([1, 1, 768])
encoder.layer.0 Transformer 7,087,872 -
├─ attention MSA 2,362,368 -
├─ mlp FFN 4,722,432 -
└─ layer_norm LayerNorm 3,072 -
encoder.layer.1 Transformer 7,087,872 -
├─ attention MSA 2,362,368 -
├─ mlp FFN 4,722,432 -
└─ layer_norm LayerNorm 3,072 -
encoder.layer.2 Transformer 7,087,872 -
├─ attention MSA 2,362,368 -
├─ mlp FFN 4,722,432 -
└─ layer_norm LayerNorm 3,072 -
encoder.layer.3 Transformer 7,087,872 -
├─ attention MSA 2,362,368 -
├─ mlp FFN 4,722,432 -
└─ layer_norm LayerNorm 3,072 -
encoder.layer.4 Transformer 7,087,872 -
├─ attention MSA 2,362,368 -
├─ mlp FFN 4,722,432 -
└─ layer_norm LayerNorm 3,072 -
encoder.layer.5 Transformer 7,087,872 -
├─ attention MSA 2,362,368 -
├─ mlp FFN 4,722,432 -
└─ layer_norm LayerNorm 3,072 -
encoder.layer.6 Transformer 7,087,872 -
├─ attention MSA 2,362,368 -
├─ mlp FFN 4,722,432 -
└─ layer_norm LayerNorm 3,072 -
encoder.layer.7 Transformer 7,087,872 -
├─ attention MSA 2,362,368 -
├─ mlp FFN 4,722,432 -
└─ layer_norm LayerNorm 3,072 -
encoder.layer.8 Transformer 7,087,872 -
├─ attention MSA 2,362,368 -
├─ mlp FFN 4,722,432 -
└─ layer_norm LayerNorm 3,072 -
encoder.layer.9 Transformer 7,087,872 -
├─ attention MSA 2,362,368 -
├─ mlp FFN 4,722,432 -
└─ layer_norm LayerNorm 3,072 -
encoder.layer.10 Transformer 7,087,872 -
├─ attention MSA 2,362,368 -
├─ mlp FFN 4,722,432 -
└─ layer_norm LayerNorm 3,072 -
encoder.layer.11 Transformer 7,087,872 -
├─ attention MSA 2,362,368 -
├─ mlp FFN 4,722,432 -
└─ layer_norm LayerNorm 3,072 -
-------------------------------------------------------------------------------------
Total                     -               85,797,120   -
============================================================
Parameter distribution analysis
============================================================
Parameter distribution by module:
- Patch Embedding: 590,592 params (0.69%)
- Position Embedding: 151,296 params (0.18%)
- CLS Token: 768 params (0.0009%)
- Transformer Layers: 85,054,464 params (99.13%)
Average parameters per Transformer layer: 7,087,872
Total parameters: 85,797,120 (85.80M)
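A note on the total: the scan only sums the embeddings and the 12 encoder blocks, i.e. 742,656 + 12 × 7,087,872 = 85,797,120. ViTModel itself holds two more pieces that are not counted here: the final LayerNorm after the encoder (2 × 768 = 1,536 parameters) and the pooler mentioned in the warning at the top of the log (768 × 768 + 768 = 590,592 parameters). Summing everything inside analyze_vit_parameters would therefore report roughly 86.39M, in line with the ~86M figure usually cited for ViT-Base. A hypothetical two-line check, not part of the script above:

full_total = sum(p.numel() for p in model.parameters())  # includes the final LayerNorm and the pooler
print(f"Full ViTModel parameter count: {full_total:,}")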


Designed by Xiaoyu Linghu