tensorflow 中的 name_scope 与 variable_scope

时间:2022-06-18 13:32:58

0. 为什么需要

  1. 共享变量:
    减少需要训练的参数的个数
    多机多卡并行化训练

  2. 避免变量名和操作名重复

1. tf.Variable 和 tf.get_variable

tf.Variable:

tf.name_scope 配合使用,用于创建一个新变量,在同一个name_scope下面,可以创建相同名字的变量,底层实现会自动引入别名机制。

tf.get_variable:

tf.variable_scope 配合使用,不受name_scope的约束。查找全称为 当前variable_scope + name 的变量:

  • 如果变量不存在,则自动创建一个变量。
  • 如果有重名的变量:
    • 变量已经设置为共享:返回同名变量
    • 变量名没有设置为共享变量时:报错

2. tf.name_scope, tf.variable_scope

tf.name_scope

主要用于管理一个图里面的各种op,返回的是一个以scope_name命名的context manager。一个graph会维护一个name_space的堆,每一个namespace下面可以定义各种op或者子namespace,实现一种层次化有条理的管理,避免各个op之间命名冲突。

tf.variable_scope

一般与tf.name_scope()配合使用,用于管理变量的名字,避免变量之间的命名冲突,可以嵌套使用。允许在一个variable_scope下面共享变量。

3. 注意

  • 在 variable_scope 里面的 variable_scope 会继承上层的 reuse 值,即上面一层开启了 reuse ,则下面的也跟着开启。但是不能人为的设置 reuse 为 false ,只有退出 variable_scope 才能让 reuse 变为 false。
with tf.variable_scope("root"):  
# At start, the scope is not reusing.
assert tf.get_variable_scope().reuse == False
with tf.variable_scope("foo"):
# Opened a sub-scope, still not reusing.
assert tf.get_variable_scope().reuse == False
with tf.variable_scope("foo", reuse=True):
# Explicitly opened a reusing scope.
assert tf.get_variable_scope().reuse == True
with tf.variable_scope("bar"):
# Now sub-scope inherits the reuse flag.
assert tf.get_variable_scope().reuse == True
# Exited the reusing scope, back to a non-reusing one.
assert tf.get_variable_scope().reuse == False
  • 当在某一 variable_scope 内使用别的 scope 的名字时,此时不再受这里的等级关系束缚,直接与使用的 scope 的名字一样:
with tf.variable_scope("foo") as foo_scope:  
assert foo_scope.name == "foo"
with tf.variable_scope("bar")
with tf.variable_scope("baz") as other_scope:
assert other_scope.name == "bar/baz"
with tf.variable_scope(foo_scope) as foo_scope2:
assert foo_scope2.name == "foo" # Not changed.
  • name_scope 只会影响 ops 和 tf.Variable() 的名字,而并不会影响 get_variable() 的名字。当我们用with tf.variable_scope(“name”)时,这就间接地开启了一个tf.name_scope(“name”)
with tf.variable_scope("foo"):  
with tf.name_scope("bar"):
v = tf.get_variable("v", [1])
x = 1.0 + v
assert v.name == "foo/v:0"
assert x.op.name == "foo/bar/add"