ViT Parameters and Visualization

Python example code

    import torch
    from transformers import ViTModel, ViTConfig
    import collections

    def analyze_vit_parameters(model_name='google/vit-base-patch16-224'):
        """
        Analyze the parameter counts of a ViT model.

        Args:
            model_name: ViT model name; defaults to vit-base-patch16-224
        """
        print(f"Loading model: {model_name}")
        
        # Load the model and its configuration
        model = ViTModel.from_pretrained(model_name)
        config = ViTConfig.from_pretrained(model_name)
        model.eval()
        
        print(f"\n{'='*60}")
        print(f"ViT 模型参数分析: {model_name}")
        print(f"{'='*60}")
        
        # 打印模型基本信息
        print(f"模型配置:")
        print(f"  - 图像尺寸: {config.image_size}×{config.image_size}")
        print(f"  - Patch 大小: {config.patch_size}×{config.patch_size}")
        print(f"  - 隐藏层维度: {config.hidden_size}")
        print(f"  - Transformer 层数: {config.num_hidden_layers}")
        print(f"  - 注意力头数: {config.num_attention_heads}")
        print(f"  - MLP 扩展比例: {config.intermediate_size/config.hidden_size:.1f}x")
        
        total_params = 0
        layer_params = collections.OrderedDict()
        
        print(f"\n{'层名称':<25} {'参数类型':<15} {'参数量':<12} {'形状':<30}")
        print("-" * 85)
        
        # 1. Patch embedding parameters (Conv2d projection)
        patch_embed_params = sum(p.numel() for p in model.embeddings.patch_embeddings.projection.parameters())
        total_params += patch_embed_params
        layer_params['patch_embedding'] = patch_embed_params
        print(f"{'patch_embedding':<25} {'Conv2d':<15} {patch_embed_params:<12,} {str(tuple(model.embeddings.patch_embeddings.projection.weight.shape)):<30}")
        
        # 2. Position embedding parameters
        pos_embed_params = model.embeddings.position_embeddings.numel()
        total_params += pos_embed_params
        layer_params['position_embedding'] = pos_embed_params
        print(f"{'position_embedding':<25} {'Embedding':<15} {pos_embed_params:<12,} {str(model.embeddings.position_embeddings.shape):<30}")
        
        # 3. CLS token parameters
        cls_token_params = model.embeddings.cls_token.numel()
        total_params += cls_token_params
        layer_params['cls_token'] = cls_token_params
        print(f"{'cls_token':<25} {'Parameter':<15} {cls_token_params:<12,} {str(model.embeddings.cls_token.shape):<30}")
        
        # 4. Transformer encoder layer parameters
        transformer_total = 0
        for i in range(config.num_hidden_layers):
            layer = model.encoder.layer[i]
            
            # Attention parameters (Q, K, V, and output projection)
            attention_params = sum(p.numel() for p in layer.attention.parameters())
            
            # MLP (FFN) parameters
            mlp_params = sum(p.numel() for p in layer.intermediate.parameters()) + \
                        sum(p.numel() for p in layer.output.parameters())
            
            # LayerNorm parameters (before and after attention)
            layer_norm_params = sum(p.numel() for p in layer.layernorm_before.parameters()) + \
                               sum(p.numel() for p in layer.layernorm_after.parameters())
            
            layer_total = attention_params + mlp_params + layer_norm_params
            transformer_total += layer_total
            total_params += layer_total
            
            layer_params[f'encoder.layer.{i}'] = layer_total
            
            print(f"{f'encoder.layer.{i}':<25} {'Transformer':<15} {layer_total:<12,} {'-':<30}")
            print(f"{'  ├─ attention':<25} {'MSA':<15} {attention_params:<12,} {'-':<30}")
            print(f"{'  ├─ mlp':<25} {'FFN':<15} {mlp_params:<12,} {'-':<30}")
            print(f"{'  └─ layer_norm':<25} {'LayerNorm':<15} {layer_norm_params:<12,} {'-':<30}")
        
        layer_params['transformer_total'] = transformer_total
        
        # 5. The final LayerNorm (model.layernorm) and the pooler (model.pooler)
        # are intentionally not counted, so the total below matches the table above.
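        # A sketch of how to include them as well (adds 1,536 + 590,592 = 592,128
        # parameters for ViT-Base; note model.pooler is None if the model was
        # loaded with add_pooling_layer=False):
        # total_params += sum(p.numel() for p in model.layernorm.parameters())
        # if model.pooler is not None:
        #     total_params += sum(p.numel() for p in model.pooler.parameters())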
        
        print("-" * 85)
        print(f"{'总计':<25} {'-':<15} {total_params:<12,} {'-':<30}")
        
        return total_params, layer_params, config

    def print_parameter_distribution(total_params, layer_params, config):
        """打印参数分布统计"""
        print(f"\n{'='*60}")
        print("参数分布分析")
        print(f"{'='*60}")
        
        # Percentage of total parameters per module
        patch_embed_pct = layer_params['patch_embedding'] / total_params * 100
        pos_embed_pct = layer_params['position_embedding'] / total_params * 100
        cls_token_pct = layer_params['cls_token'] / total_params * 100
        transformer_pct = layer_params['transformer_total'] / total_params * 100
        
        print(f"\n各模块参数分布:")
        print(f"  - Patch Embedding: {layer_params['patch_embedding']:,} params ({patch_embed_pct:.2f}%)")
        print(f"  - Position Embedding: {layer_params['position_embedding']:,} params ({pos_embed_pct:.2f}%)")
        print(f"  - CLS Token: {layer_params['cls_token']:,} params ({cls_token_pct:.4f}%)")
        print(f"  - Transformer Layers: {layer_params['transformer_total']:,} params ({transformer_pct:.2f}%)")
        
        # Average parameters per Transformer layer
        avg_layer_params = layer_params['transformer_total'] / config.num_hidden_layers
        print(f"\nAverage parameters per Transformer layer: {avg_layer_params:,.0f}")
        
        # Total parameter count, also in millions
        print(f"\nTotal parameters: {total_params:,} ({total_params/1e6:.2f}M)")

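    # Optional visualization (a sketch, not part of the original listing; assumes
    # matplotlib is installed): a bar chart of the per-module counts returned by
    # analyze_vit_parameters. A log scale is used because the counts span 768
    # (CLS token) to ~85M (Transformer layers).
    def plot_parameter_distribution(layer_params):
        import matplotlib.pyplot as plt
        modules = ['patch_embedding', 'position_embedding', 'cls_token', 'transformer_total']
        counts = [layer_params[m] for m in modules]
        plt.bar(modules, counts)
        plt.yscale('log')
        plt.ylabel('Parameter count')
        plt.title('ViT parameter distribution by module')
        plt.tight_layout()
        plt.show()
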
    # Usage example
    if __name__ == "__main__":
        # Different ViT variants can be analyzed; uncomment to compare sizes
        models_to_analyze = [
            'google/vit-base-patch16-224',
            # 'google/vit-large-patch16-224',
            # 'google/vit-huge-patch14-224-in21k'
        ]
        
        for model_name in models_to_analyze:
            total_params, layer_params, config = analyze_vit_parameters(model_name)
            print_parameter_distribution(total_params, layer_params, config)
            print("\n" + "="*80 + "\n")
        

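As a sanity check on the output below, the per-layer parameter count for ViT-Base (hidden size 768, MLP size 3072) can be reproduced by hand. A minimal sketch of the arithmetic:

    d, d_ff = 768, 3072                        # ViT-Base hidden size and MLP size
    n_pos = (224 // 16) ** 2 + 1               # 196 patches + 1 CLS token = 197 positions
    attn = 3 * (d * d + d) + (d * d + d)       # Q/K/V projections + output projection
    ffn = (d * d_ff + d_ff) + (d_ff * d + d)   # intermediate + output dense layers
    ln = 2 * (2 * d)                           # two LayerNorms, each with weight and bias
    print(attn, ffn, ln, attn + ffn + ln)      # 2362368 4722432 3072 7087872
    print(n_pos * d)                           # 151296 position-embedding parameters
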
Sample output

    Loading model: google/vit-base-patch16-224
    Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
    You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

    ============================================================
    ViT model parameter analysis: google/vit-base-patch16-224
    ============================================================
    Model configuration:
      - Image size: 224×224
      - Patch size: 16×16
      - Hidden size: 768
      - Transformer layers: 12
      - Attention heads: 12
      - MLP expansion ratio: 4.0x

    Layer                     Type            # Params     Shape
    -------------------------------------------------------------------------------------
    patch_embedding           Conv2d          590,592      (768, 3, 16, 16)              
    position_embedding        Embedding       151,296      torch.Size([1, 197, 768])     
    cls_token                 Parameter       768          torch.Size([1, 1, 768])       
    encoder.layer.0           Transformer     7,087,872    -                             
      ├─ attention            MSA             2,362,368    -                             
      ├─ mlp                  FFN             4,722,432    -                             
      └─ layer_norm           LayerNorm       3,072        -                             
    encoder.layer.1           Transformer     7,087,872    -                             
      ├─ attention            MSA             2,362,368    -                             
      ├─ mlp                  FFN             4,722,432    -                             
      └─ layer_norm           LayerNorm       3,072        -                             
    encoder.layer.2           Transformer     7,087,872    -                             
      ├─ attention            MSA             2,362,368    -                             
      ├─ mlp                  FFN             4,722,432    -                             
      └─ layer_norm           LayerNorm       3,072        -                             
    encoder.layer.3           Transformer     7,087,872    -                             
      ├─ attention            MSA             2,362,368    -                             
      ├─ mlp                  FFN             4,722,432    -                             
      └─ layer_norm           LayerNorm       3,072        -                             
    encoder.layer.4           Transformer     7,087,872    -                             
      ├─ attention            MSA             2,362,368    -                             
      ├─ mlp                  FFN             4,722,432    -                             
      └─ layer_norm           LayerNorm       3,072        -                             
    encoder.layer.5           Transformer     7,087,872    -                             
      ├─ attention            MSA             2,362,368    -                             
      ├─ mlp                  FFN             4,722,432    -                             
      └─ layer_norm           LayerNorm       3,072        -                             
    encoder.layer.6           Transformer     7,087,872    -                             
      ├─ attention            MSA             2,362,368    -                             
      ├─ mlp                  FFN             4,722,432    -                             
      └─ layer_norm           LayerNorm       3,072        -                             
    encoder.layer.7           Transformer     7,087,872    -                             
      ├─ attention            MSA             2,362,368    -                             
      ├─ mlp                  FFN             4,722,432    -                             
      └─ layer_norm           LayerNorm       3,072        -                             
    encoder.layer.8           Transformer     7,087,872    -                             
      ├─ attention            MSA             2,362,368    -                             
      ├─ mlp                  FFN             4,722,432    -                             
      └─ layer_norm           LayerNorm       3,072        -                             
    encoder.layer.9           Transformer     7,087,872    -                             
      ├─ attention            MSA             2,362,368    -                             
      ├─ mlp                  FFN             4,722,432    -                             
      └─ layer_norm           LayerNorm       3,072        -                             
    encoder.layer.10          Transformer     7,087,872    -                             
      ├─ attention            MSA             2,362,368    -                             
      ├─ mlp                  FFN             4,722,432    -                             
      └─ layer_norm           LayerNorm       3,072        -                             
    encoder.layer.11          Transformer     7,087,872    -                             
      ├─ attention            MSA             2,362,368    -                             
      ├─ mlp                  FFN             4,722,432    -                             
      └─ layer_norm           LayerNorm       3,072        -                             
    -------------------------------------------------------------------------------------
    Total                     -               85,797,120   -

    ============================================================
    Parameter distribution analysis
    ============================================================

    Parameter distribution by module:
      - Patch Embedding: 590,592 params (0.69%)
      - Position Embedding: 151,296 params (0.18%)
      - CLS Token: 768 params (0.0009%)
      - Transformer Layers: 85,054,464 params (99.13%)

    Average parameters per Transformer layer: 7,087,872

    Total parameters: 85,797,120 (85.80M)
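
The total above deliberately excludes the final LayerNorm and the pooler, which is why it is lower than the full checkpoint size. A quick cross-check against PyTorch's own count (a minimal sketch, assuming the same environment as above):

    from transformers import ViTModel

    model = ViTModel.from_pretrained('google/vit-base-patch16-224')
    print(f"{sum(p.numel() for p in model.parameters()):,}")
    # 86,389,248 = 85,797,120 + 1,536 (final LayerNorm) + 590,592 (pooler)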