我们知道可以通过 `model.parameters()` 获取网络的参数,那么这是如何实现的呢?我们先直接看看该函数的代码实现:
def parameters(self):
    r"""Return an iterator over the module's parameters.

    This is typically passed to an optimizer. It simply discards the names
    produced by ``self.named_parameters()``. It therefore yields the same
    parameters, in the same order, with the same de-duplication.

    Yields:
        Parameter: module parameter

    Example::

        >>> for param in model.parameters():
        ...     print(type(param), param.size())
    """
    for _name, param in self.named_parameters():
        yield param
def named_parameters(self, memo=None, prefix=''):
    r"""Return an iterator over module parameters, yielding both the
    name of the parameter and the parameter itself.

    Parameters registered directly on this module come first, in
    ``self._parameters`` order. Then each child from
    ``self.named_children()`` is visited recursively. A parameter object
    that has already been yielded (tracked in ``memo``) is skipped. This
    means a Parameter shared by several submodules is reported only once,
    under the first name it is reached by. ``None`` entries are skipped.

    Args:
        memo (set, optional): parameters already yielded. The same set is
            passed down the recursion and is updated in place.
        prefix (str, optional): qualified name of this module. It is
            joined to each parameter name with ``'.'`` when non-empty.

    Yields:
        (string, Parameter): Tuple containing the name and parameter

    Example::

        >>> for name, param in model.named_parameters():
        ...     if name in ['bias']:
        ...         print(param.size())
    """
    if memo is None:
        memo = set()
    # Parameters registered directly on this module.
    for name, p in self._parameters.items():
        if p is not None and p not in memo:
            # Record it so the same object reached again through another
            # (sub)module is not yielded twice.
            memo.add(p)
            yield prefix + ('.' if prefix else '') + name, p
    # Recurse into child modules, sharing the same memo set.
    for mname, module in self.named_children():
        submodule_prefix = prefix + ('.' if prefix else '') + mname
        for name, p in module.named_parameters(memo, submodule_prefix):
            yield name, p
可以看到,它是通过遍历模块本身及其子模块(类型为 Module 的成员对象)的成员 `_parameters` 来获取参数的,并借助集合 `memo` 避免重复返回同一个参数对象。那么 `_parameters` 是什么呢?先不着急,我们先看看 Module 的一些实现,首先是初始化函数 `__init__`:
def __init__(self):
    """Create the per-instance bookkeeping containers of a Module.

    Every container is an ``OrderedDict``, so iterating one (for example
    ``self._parameters``) follows the order in which entries were added.
    """
    # NOTE(review): ``thnn_backend`` is defined elsewhere in the file (not shown).
    self._backend = thnn_backend
    # Registered Parameter objects, keyed by attribute name.
    self._parameters = OrderedDict()
    # Registered buffers, keyed by attribute name.
    self._buffers = OrderedDict()
    # Hook registries; the code that runs them is not shown here.
    self._backward_hooks = OrderedDict()
    self._forward_hooks = OrderedDict()
    self._forward_pre_hooks = OrderedDict()
    # Child Modules, keyed by attribute name.
    self._modules = OrderedDict()
    # Training-mode flag; presumably toggled by train()/eval() (not shown here).
    self.training = True
可以看到,`_parameters`(也请留意 `_modules`)其实是一个有序字典(`OrderedDict`)。
接着我们看一下函数 `__setattr__(self, name, value)`:
def __setattr__(self, name, value):
    """Route an attribute assignment to the matching registry.

    - ``Parameter`` value: the name is removed from ``__dict__``,
      ``_buffers`` and ``_modules``, then passed to ``register_parameter``.
    - Name already in ``_parameters``: only ``None`` may be assigned.
    - ``Module`` value: the name is removed from ``__dict__``,
      ``_parameters`` and ``_buffers``, then stored in ``_modules``.
    - Name already in ``_modules``: only ``None`` may be assigned.
    - Name already in ``_buffers``: only a ``torch.Tensor`` or ``None``.
    - Anything else falls through to ``object.__setattr__``.

    Raises:
        AttributeError: a Parameter/Module is assigned before
            ``Module.__init__`` created the registries.
        TypeError: a registered parameter, child module or buffer name is
            assigned a value of the wrong type.
    """
    def remove_from(*dicts):
        # Drop ``name`` from every given mapping so it lives in exactly one place.
        for d in dicts:
            if name in d:
                del d[name]

    # Read through __dict__ so this works before __init__ has run.
    params = self.__dict__.get('_parameters')
    # The value is a Parameter: register it as a parameter.
    if isinstance(value, Parameter):
        if params is None:
            raise AttributeError(
                "cannot assign parameters before Module.__init__() call")
        remove_from(self.__dict__, self._buffers, self._modules)
        self.register_parameter(name, value)
    elif params is not None and name in params:
        # An existing parameter slot may only be cleared, not overwritten
        # with a non-Parameter.
        if value is not None:
            raise TypeError("cannot assign '{}' as parameter '{}' "
                            "(torch.nn.Parameter or None expected)"
                            .format(torch.typename(value), name))
        self.register_parameter(name, value)
    else:
        modules = self.__dict__.get('_modules')
        # The value is a Module: record it as a child module.
        if isinstance(value, Module):
            if modules is None:
                raise AttributeError(
                    "cannot assign module before Module.__init__() call")
            remove_from(self.__dict__, self._parameters, self._buffers)
            modules[name] = value
        elif modules is not None and name in modules:
            if value is not None:
                raise TypeError("cannot assign '{}' as child module '{}' "
                                "(torch.nn.Module or None expected)"
                                .format(torch.typename(value), name))
            modules[name] = value
        else:
            buffers = self.__dict__.get('_buffers')
            if buffers is not None and name in buffers:
                if value is not None and not isinstance(value, torch.Tensor):
                    raise TypeError("cannot assign '{}' as buffer '{}' "
                                    "(torch.Tensor or None expected)"
                                    .format(torch.typename(value), name))
                buffers[name] = value
            else:
                # Ordinary attribute: bypass this hook and store in __dict__.
                object.__setattr__(self, name, value)
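我们知道,如果类实现了 `__setattr__`,那么给实例属性赋值(例如 `self.weight = ...`)时就会调用该函数。可以看到,如果所赋的值是 `Parameter` 类型,就会调用函数 `register_parameter` 注册该参数(而值为 `Module` 类型时,则会存入有序字典 `_modules`)。再看 `register_parameter` 的实现,它其实就是把参数添加到有序字典成员 `_parameters` 中: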
def register_parameter(self, name, param):
    r"""Add a parameter to the module.

    The parameter can then be accessed as an attribute using the given name.
    Passing ``None`` records the name with a ``None`` value.

    Args:
        name (string): name of the parameter. The parameter can be accessed
            from this module using the given name
        param (Parameter): parameter to be added to the module, or ``None``.

    Raises:
        AttributeError: called before ``Module.__init__`` created
            ``_parameters``.
        TypeError: ``name`` is not a ``str``, or ``param`` is neither a
            ``Parameter`` nor ``None``.
        KeyError: ``name`` contains ``'.'``, is empty, or already names a
            different attribute.
        ValueError: ``param`` is a non-leaf tensor (it has a ``grad_fn``).
    """
    if '_parameters' not in self.__dict__:
        raise AttributeError(
            "cannot assign parameter before Module.__init__() call")
    # Plain ``str`` check: the private ``torch._six.string_classes`` is gone
    # in current PyTorch. It also admitted ``bytes``, which then crashed on
    # the ``'.' in name`` test below.
    elif not isinstance(name, str):
        raise TypeError("parameter name should be a string. "
                        "Got {}".format(torch.typename(name)))
    elif '.' in name:
        # '.' is reserved as the separator in qualified names
        # such as "fc.weight".
        raise KeyError("parameter name can't contain \".\"")
    elif name == '':
        raise KeyError("parameter name can't be empty string \"\"")
    elif hasattr(self, name) and name not in self._parameters:
        raise KeyError("attribute '{}' already exists".format(name))

    if param is None:
        self._parameters[name] = None
    elif not isinstance(param, Parameter):
        raise TypeError("cannot assign '{}' object to parameter '{}' "
                        "(torch.nn.Parameter or None required)"
                        .format(torch.typename(param), name))
    elif param.grad_fn is not None:
        raise ValueError(
            "Cannot assign non-leaf Tensor to parameter '{0}'. Model "
            "parameters must be created explicitly. To express '{0}' "
            "as a function of another Tensor, compute the value in "
            "the forward() method.".format(name))
    else:
        self._parameters[name] = param
所以,通过调用 `model.parameters()` 获取的网络参数中,有一部分来自类成员中的 `Parameter` 对象。那么这是不是全部呢?我们后面再看。