PyTorch中查看模型的参数信息的几种方式

时间:2025-03-29 09:47:28

如下 :

# model是实例化的模型对象

print(model)
print('*********************************************************************')

for param_tensor in model.state_dict():  # 字典的遍历默认是遍历 key,所以param_tensor实际上是键值
    print(param_tensor, '\t', model.state_dict()[param_tensor].size())

print('*********************************************************************')

from prettytable import PrettyTable


def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        param = ()
        table.add_row([name, param])
        total_params += param
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params


count_parameters(model)
print('*********************************************************************')

for para in model.named_parameters():  # 返回的每一个元素是一个元组 tuple
    '''
    是一个元组 tuple ,元组的第一个元素是参数所对应的名称,第二个元素就是对应的参数值
    '''
    print(para[0], '\t', para[1].size())

print('*********************************************************************')

# 总参数个数
print('总参数个数 = ',sum(() for p in () if p.requires_grad))

print('*********************************************************************')
params = list(())
# 网络层数
print('网络层数 = ',params.__len__())

相关文章