Python激活函数
import torch
import torchvision.models as models
# Load AlexNet with its default pretrained weights and switch it to evaluation
# mode. nn.Module.eval() returns the module itself, so the calls can be chained.
model = models.alexnet(weights=models.AlexNet_Weights.DEFAULT).eval()
def analyze_alexnet(model):
    """Print a per-layer parameter table for an AlexNet-style model.

    Walks ``model.features`` for ``torch.nn.Conv2d`` layers and
    ``model.classifier`` for ``torch.nn.Linear`` layers. For each one it
    prints the layer name, type, parameter count and weight shape, followed
    by a total row.

    Args:
        model: A module exposing iterable ``features`` and ``classifier``
            containers (as ``torchvision.models.alexnet`` does).

    Returns:
        tuple[int, int, int]: ``(total_params, conv_params, fc_params)``.
        ``total_params`` counts only the Conv2d and Linear layers listed in
        the table.
    """
    print(f"{'Layer':<15} {'Type':<15} {'Params':<10} {'Shape':<25}")
    print("-" * 65)

    # Convolutional layers. Number them by their position among Conv2d layers
    # only: the ReLU/MaxPool layers in between are irregularly spaced, so the
    # Sequential index (e.g. ``i // 3``) does not map to the conv number.
    conv_params = 0
    conv_layers = [layer for layer in model.features if isinstance(layer, torch.nn.Conv2d)]
    for idx, layer in enumerate(conv_layers, start=1):
        params = sum(p.numel() for p in layer.parameters())
        conv_params += params
        name = f"conv{idx}"
        print(f"{name:<15} {'Conv2d':<15} {params:<10,} {str(tuple(layer.weight.shape)):<25}")

    # Fully connected layers. Names are generated rather than hard-coded, so
    # classifiers with more than three Linear layers are reported in full.
    fc_params = 0
    fc_layers = [layer for layer in model.classifier if isinstance(layer, torch.nn.Linear)]
    for idx, layer in enumerate(fc_layers, start=1):
        params = sum(p.numel() for p in layer.parameters())
        fc_params += params
        name = f"fc{idx}"
        print(f"{name:<15} {'Linear':<15} {params:<10,} {str(tuple(layer.weight.shape)):<25}")

    total_params = conv_params + fc_params
    print("-" * 65)
    print(f"{'Total':<15} {'-':<15} {total_params:<10,} {'-':<25}")
    return total_params, conv_params, fc_params
# Run the analysis and summarise how parameters split between conv and FC layers.
total, conv, fc = analyze_alexnet(model)
print("\n参数分布:")
print(f"卷积层: {conv:,} ({conv/total*100:.1f}%)")
print(f"全连接层: {fc:,} ({fc/total*100:.1f}%)")
print(f"总参数: {total:,} ({total/1e6:.1f}M)")

# Sanity-check the forward pass with one random 3-channel 224x224 input.
# torch.no_grad() skips building the autograd graph: it is not needed for
# inference and would otherwise keep intermediate activations alive.
test_input = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    output = model(test_input)
print(f"\n输入: {test_input.shape} -> 输出: {output.shape}")
运行结果
Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth
100%|██████████| 233M/233M [36:29<00:00, 112kB/s]
Layer Type Params Shape
-----------------------------------------------------------------
conv1 Conv2d 23,296 (64, 3, 11, 11)
conv2 Conv2d 307,392 (192, 64, 5, 5)
conv3 Conv2d 663,936 (384, 192, 3, 3)
conv3 Conv2d 884,992 (256, 384, 3, 3)
conv4 Conv2d 590,080 (256, 256, 3, 3)
fc1 Linear 37,752,832 (4096, 9216)
fc2 Linear 16,781,312 (4096, 4096)
fc3 Linear 4,097,000 (1000, 4096)
-----------------------------------------------------------------
Total - 61,100,840 -
参数分布:
卷积层: 2,469,696 (4.0%)
全连接层: 58,631,144 (96.0%)
总参数: 61,100,840 (61.1M)
输入: torch.Size([1, 3, 224, 224]) -> 输出: torch.Size([1, 1000])