数据类型参数如何在TensorFlow的Value函数中使用?

时间:2021-10-12 16:28:03

While creating variables in TensorFlow we can specify the data type. However, it looks to me that this argument is just ignored. For example:

在TensorFlow中创建变量时,我们可以指定数据类型。但是,我认为这个论点只是被忽略了。例如:

In [26]: b = tf.Variable([3.0, 4.0, 5.0], tf.float64)

In [27]: b.dtype
Out[27]: tf.float32_ref

In [28]: b = tf.Variable(np.array([3.0, 4.0, 5.0]), tf.float64)

In [29]: b.dtype
Out[29]: tf.float64_ref

In [30]: b = tf.Variable(np.array([3.0, 4.0, 5.0]), tf.float32)

In [31]: b.dtype
Out[31]: tf.float64_ref

So, if I initialize values from a Python list, I get tf.float32_ref as the type (even though I give tf.float64 as the second argument of the Value function). If I use a numpy array for the value initialization, the situation is opposite (I get tf.float64_ref as the data type even if I give tf.float32 as the second argument of the Value function).

所以,如果我从Python列表初始化值,我得到tf.float32_ref作为类型(即使我将tf.float64作为Value函数的第二个参数)。如果我使用numpy数组进行值初始化,情况则相反(即使我将tf.float32作为Value函数的第二个参数,我也将tf.float64_ref作为数据类型)。

I guess that the data-type is taken from the data-type of the object that is used for the value initialization. Which kind of makes sense, but than why do we need the dtype argument in the Value function?

我猜数据类型取自用于值初始化的对象的数据类型。哪种有意义,但为什么我们需要在Value函数中使用dtype参数?

1 个解决方案

#1


1  

I think the problem is just that the 2nd argument of Variable constructor is not dtype:

我认为问题只是变量构造函数的第二个参数不是dtype:

>>>b = tf.Variable(np.array([3.0, 4.0, 5.0]), tf.float32)
>>>b.dtype
tf.float64_ref
>>>b = tf.Variable(np.array([3.0, 4.0, 5.0]), dtype=tf.float32)
>>>b.dtype
tf.float32_ref

You can check the doc here: the 2nd argument is "trainable".

你可以在这里查看文档:第二个参数是“可训练的”。

As a side note: in more recent versions of Tensorflow, it is advised to use tf.get_variable to create variables instead of tf.Variable().

作为旁注:在更新版本的Tensorflow中,建议使用tf.get_variable来创建变量而不是tf.Variable()。

#1


1  

I think the problem is just that the 2nd argument of Variable constructor is not dtype:

我认为问题只是变量构造函数的第二个参数不是dtype:

>>>b = tf.Variable(np.array([3.0, 4.0, 5.0]), tf.float32)
>>>b.dtype
tf.float64_ref
>>>b = tf.Variable(np.array([3.0, 4.0, 5.0]), dtype=tf.float32)
>>>b.dtype
tf.float32_ref

You can check the doc here: the 2nd argument is "trainable".

你可以在这里查看文档:第二个参数是“可训练的”。

As a side note: in more recent versions of Tensorflow, it is advised to use tf.get_variable to create variables instead of tf.Variable().

作为旁注:在更新版本的Tensorflow中,建议使用tf.get_variable来创建变量而不是tf.Variable()。