#获取网络中的参数

#1、定义一个网络
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=3, out_channels=5, kernel_size=3)

    def forward(self,x):
        x = self.conv(x)
        return x

#2、随机模拟获得一个图像数据, 按 NCHW 的顺序填入数字
img = torch.randn((1, 3, 5, 5))

#3、创建网络实例
net = Net()

#4、网络输出
y = net(img)

#5、查看网络中的参数及计算网络中参数总量
#网络中的参数的计算方式:参数每一维尺寸相乘,如[5,3,3,3]--- 5*3*3*3
total_params = 0
for param in net.parameters():
    print(param)     #查看网络中每层每个参数的数值
    print('--------------')
    print(param.size()) #查看网络中每层参数的尺寸
    print('***************')

    dims = len(param.size()) #获取参数尺寸的维数
    p = 1
    for i in range(dims):
        p *= param.size(i)
    total_params += p
print('总参数数量为:', total_params)



#6、同时查看参数名称和参数    
for name, param in net.named_parameters():
    print('同时查看参数名称和参数')
    print(name)
    print(param)
    print(param.requires_grad) #查看该参数是否需要进梯度更新

Jupyter上的python代码