如下 :
# model是实例化的模型对象
print(model)
print('*********************************************************************')
for param_tensor in model.state_dict(): # 字典的遍历默认是遍历 key,所以param_tensor实际上是键值
print(param_tensor, '\t', model.state_dict()[param_tensor].size())
print('*********************************************************************')
from prettytable import PrettyTable
def count_parameters(model):
table = PrettyTable(["Modules", "Parameters"])
total_params = 0
for name, parameter in model.named_parameters():
if not parameter.requires_grad: continue
param = ()
table.add_row([name, param])
total_params += param
print(table)
print(f"Total Trainable Params: {total_params}")
return total_params
count_parameters(model)
print('*********************************************************************')
for para in model.named_parameters(): # 返回的每一个元素是一个元组 tuple
'''
是一个元组 tuple ,元组的第一个元素是参数所对应的名称,第二个元素就是对应的参数值
'''
print(para[0], '\t', para[1].size())
print('*********************************************************************')
# 总参数个数
print('总参数个数 = ',sum(() for p in () if p.requires_grad))
print('*********************************************************************')
params = list(())
# 网络层数
print('网络层数 = ',params.__len__())