环境:Ubuntu14.04,tensorflow=1.4(bazel源码安装),Anaconda python=3.6
声明变量主要有两种方法:tf.Variable和 tf.get_variable,二者的最大区别是:
(1) tf.Variable是一个类,自带很多属性函数;而 tf.get_variable是一个函数;
(2) tf.Variable只能生成独一无二的变量,即如果给出的name已经存在,则会自动修改生成新的变量name;
(3) tf.get_variable可以用于生成共享变量。默认情况下,该函数会进行变量名检查,如果有重复则会报错。当在指定变量域中声明可
以变量共享时,可以重复使用该变量(例如RNN中的参数共享)。
下面给出简单的的示例程序:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
|
import tensorflow as tf
with tf.variable_scope( 'scope1' ,reuse = tf.AUTO_REUSE) as scope1:
x1 = tf.Variable(tf.ones([ 1 ]),name = 'x1' )
x2 = tf.Variable(tf.zeros([ 1 ]),name = 'x1' )
y1 = tf.get_variable( 'y1' ,initializer = 1.0 )
y2 = tf.get_variable( 'y1' ,initializer = 0.0 )
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
print (x1.name,x1. eval ())
print (x2.name,x2. eval ())
print (y1.name,y1. eval ())
print (y2.name,y2. eval ())
|
输出结果为:
1
2
3
4
|
scope1 / x1: 0 [ 1. ]
scope1 / x1_1: 0 [ 0. ]
scope1 / y1: 0 1.0
scope1 / y1: 0 1.0
|
1. tf.Variable(…)
tf.Variable(…)使用给定初始值来创建一个新变量,该变量会默认添加到 graph collections listed in collections, which defaults to [GraphKeys.GLOBAL_VARIABLES]。
如果trainable属性被设置为True,该变量同时也会被添加到graph collection GraphKeys.TRAINABLE_VARIABLES.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
|
# tf.Variable
__init__(
initial_value = None ,
trainable = True ,
collections = None ,
validate_shape = True ,
caching_device = None ,
name = None ,
variable_def = None ,
dtype = None ,
expected_shape = None ,
import_scope = None ,
constraint = None
)
|
2. tf.get_variable(…)
tf.get_variable(…)的返回值有两种情形:
使用指定的initializer来创建一个新变量;
当变量重用时,根据变量名搜索返回一个由tf.get_variable创建的已经存在的变量;
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
|
get_variable(
name,
shape = None ,
dtype = None ,
initializer = None ,
regularizer = None ,
trainable = True ,
collections = None ,
caching_device = None ,
partitioner = None ,
validate_shape = True ,
use_resource = None ,
custom_getter = None ,
constraint = None
)
|
3. 根据名称查找变量
在创建变量时,即使我们不指定变量名称,程序也会自动进行命名。于是,我们可以很方便的根据名称来查找变量,这在抓取参数、finetune模型等很多时候都很有用。
示例1:
通过在tf.global_variables()变量列表中,根据变量名进行匹配搜索查找。 该种搜索方式,可以同时找到由tf.Variable或者tf.get_variable创建的变量。
1
2
3
4
5
6
7
|
import tensorflow as tf
x = tf.Variable( 1 ,name = 'x' )
y = tf.get_variable(name = 'y' ,shape = [ 1 , 2 ])
for var in tf.global_variables():
if var.name = = 'x:0' :
print (var)
|
示例2:
利用get_tensor_by_name()同样可以获得由tf.Variable或者tf.get_variable创建的变量。
需要注意的是,此时获得的是Tensor, 而不是Variable,因此 x不等于x1.
1
2
3
4
5
6
7
8
9
|
import tensorflow as tf
x = tf.Variable( 1 ,name = 'x' )
y = tf.get_variable(name = 'y' ,shape = [ 1 , 2 ])
graph = tf.get_default_graph()
x1 = graph.get_tensor_by_name( "x:0" )
y1 = graph.get_tensor_by_name( "y:0" )
|
示例3:
针对tf.get_variable创建的变量,可以利用变量重用来直接获取已经存在的变量。
1
2
3
4
5
6
7
8
9
10
|
with tf.variable_scope( "foo" ):
bar1 = tf.get_variable( "bar" , ( 2 , 3 )) # create
with tf.variable_scope( "foo" , reuse = True ):
bar2 = tf.get_variable( "bar" ) # reuse
with tf.variable_scope("", reuse = True ): # root variable scope
bar3 = tf.get_variable( "foo/bar" ) # reuse (equivalent to the above)
print ((bar1 is bar2) and (bar2 is bar3))
|
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持服务器之家。