1. 代码实现
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""Demonstrate Theano shared variables using `updates` and `givens`."""
import numpy as np
import theano
import theano.tensor as T


def main():
    """Run the shared-variable demo and print the state after each step."""
    # Use an accumulator to exercise a shared variable.
    state = theano.shared(np.array(0, dtype=np.float64), 'state')
    inc = T.scalar('inc', dtype=state.dtype)
    # Each call outputs `state` and then replaces it with `state + inc`.
    accumulator = theano.function(
        [inc],
        state,
        updates=[(state, state + inc)],
    )

    print(state.get_value())
    accumulator(10)
    # Read the value through get_value() rather than printing
    # accumulator(10) directly: per the Theano docs the function's return
    # value is the state *before* the update is applied.
    print(state.get_value())

    # Overwrite the shared variable's value explicitly.
    state.set_value(-1)
    accumulator(3)
    print(state.get_value())

    # Temporarily substitute another variable for the shared variable via
    # `givens`; the expression is evaluated with `a` in place of `state`.
    tmp_function = state * 2 + inc
    a = T.scalar(dtype=state.dtype)
    skip_shared = theano.function(
        [inc, a],
        tmp_function,
        givens=[(state, a)],
    )
    print(skip_shared(2, 3))
    # No `updates` were passed to skip_shared, so this prints the same
    # value as before the call.
    print(state.get_value())


if __name__ == "__main__":
    main()
结果:
/Users/liudaoqiang/PycharmProjects/numpy/venv/bin/python /Users/liudaoqiang/Project/python_project/theano_day3/shared_value.py
0.0
10.0
2.0
8.0
2.0

Process finished with exit code 0