#获取网络中的参数
#1、定义一个网络
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(in_channels=3, out_channels=5, kernel_size=3)
def forward(self,x):
x = self.conv(x)
return x
#2、随机模拟获得一个图像数据, 按 NCHW 的顺序填入数字
img = torch.randn((1, 3, 5, 5))
#3、创建网络实例
net = Net()
#4、网络输出
y = net(img)
#5、查看网络中的参数及计算网络中参数总量
#网络中的参数的计算方式:参数每一维尺寸相乘,如[5,3,3,3]--- 5*3*3*3
total_params = 0
for param in net.parameters():
print(param) #查看网络中每层每个参数的数值
print('--------------')
print(param.size()) #查看网络中每层参数的尺寸
print('***************')
dims = len(param.size()) #获取参数尺寸的维数
p = 1
for i in range(dims):
p *= param.size(i)
total_params += p
print('总参数数量为:', total_params)
#6、同时查看参数名称和参数
for name, param in net.named_parameters():
print('同时查看参数名称和参数')
print(name)
print(param)
print(param.requires_grad) #查看该参数是否需要进梯度更新
Jupyter上的python代码