本文内容源自百度强化学习 7 日入门课程学习整理
感谢百度 PARL 团队李科浇老师的课程讲解
文章目录
一、为什么要引入神经网络
Q 表只能解决少量状态的问题,如果状态数量上涨,那我们面对的可能性呈现指数上涨,这样的话Q表格就没有这个处理能力了
比如:
- 国际象棋:种状态
- 围棋:种状态
- 连续操作的问题:不可数状态(不如弯曲角度)
- (整个宇宙的原子数量预估:)
Q表格不行的时候,我们可以采用:值函数(Q函数)近似
Q表格的作用在于:输入状态和动作,输出Q值
那我们可以用一个 “带参数” 的 Q 函数来进行替代:
- 多项式函数
- 神经网络
不同的近似方式:
- 输入状态 s 和动作 a,输出一个q 值
- 输入状态 s,输出多个 q 值(不同动作所对应的 q 值)
Q表格方法的缺点:
- 表格占用大量内存
- 表格大的时候,查表效率低
值函数近似的优点:
- 仅需存储有限数量的参数
- 状态泛化,相似的状态可以输出一样
神经网络可以逼近任意连续的函数
- 比如 CNN 在强化学习中引入后,可以让强化学习算法根据图片做出决策(输入图片,输出动作)
- 神经网络的原理在于,定义 cost 为真实值和预测值之间的差距,然后用梯度下降来最小化 cost
二、DQN 算法
DQN 是使用神经网络解决强化学习问题最经典的算法
该算法由谷歌的 DeepMind 团队在 2015 年提出
《Human-level control through deep reinforcement learning》这篇论文被发表在了 Nature 杂志上
通过高维度的输入信息(像素级别的图像),使用了神经网络的 DQN 在 49 个 Atari 游戏中,有 30 个超越了人类水平
使用神经网络代替Q表格以后:
- 输入可以是一个向量,包含各种值(比如四轴飞行器的高度,角度,转速等)
- 输入可以是一个图片,包含各个像素点的信息
- 输出直接是对应的动作
2.1 DQN 约等于 Q-learning + 神经网络
- 输入 状态 s
- 输出 q 向量,如果一个状态下有 5 种动作,那 q 就是 5 维的
- 然后根据我们具体的动作选择,确定 q 值
- 然后要让输出的 q 值,逼近 目标 q 值(target_q)
- target_q 的计算公式就是 Q-learning 的方法:
- 神经网络输出的预测值:
- 计算预测值和目标值的均方差(即 loss):
- 使用优化器,最小化 loss
2.2 DQN 的两大创新
神经网络中由于引入了非线形函数,比如 “relu”
所以在理论上,无法证明训练之后一定会收敛
于是 DQN 提出两大创新,使得训练更有效率,也更稳定
2.2.1 经验回放 Experience replay
作用:
- 解决序列决策的样本关联性问题
- 解决样本利用率低的问题
问题来源:
- 在监督学习中,训练样本是独立的
- 但是在强化学习中,输入的是状态值,每一个状态都是连续发生,前后状态相互关联,所以样本之间具有关联性
解决方案:
- 需要打乱,或者切断输入样本之间的联系
- 这里用到了 Q-learning 的 Off-Policy 特点
- 先存储一批经验数据
- 然后打乱
- 从中随机选取一个小的 batch 来更新网络
- 这样就打破了样本间的相关性,同时使得网络更有效率
Off-Policy 在经验回放中的作用:
- 设置经验池:是一个固定长度的队列
- 一条经验指的是:一组 ,,,
- 每拿到一条经验就往经验池进行存储
- 满了以后,弹出旧的经验
- 从经验池中随机抽取一个 batch
- 去更新 Q 值(这里就是更新神经网络的系数)
优点:
- 由于经验池中的数据有可能被重复抽取到,所以相当于经验可以重复利用,即提高了样本的利用率
- 另外由于是随机抽取,所以打乱了样本间的相关性
2.2.2 固定 Q 目标 Fixed Q target
作用:
- 解决算法训练不稳定的问题
问题来源:
- 监督学习中,我们预测值要去逼近真实值,而真实值是固定不变的
- 但是在 DQN 中,输入状态输出预测的Q,要逼近的是目标Q
- 其中 也是神经网络的输出,而神经网络权重系数一旦更新以后,这个值也会发生变化
- 所以只要我们更新一次神经网络,那目标 Q 值也就会不断变化
解决方法:
- 我们要想办法把 Q-target 值固定住
- 也就是我们要把输出 Q-target 的神经网络参数固定一段时间
- 然后过一段时间以后,再用最新的学习后的神经网络参数,刷新这个神经网络
2.3 DQN 流程框架图
Model:
- 代替了 Q 表
- 输入 S 输出 不同动作对应的 Q(预测值)给 Agent
- 同时设定一个固定一段时间的神经网络用于输出 Q_target
- 过一段时间更新该固定网络参数
引入神经网络的问题解决:
- 经验回放
- 固定目标值
Agent:
- 和环境交互
- 交互数据(经验)存储到经验池
- 提取经验池数据,更新 Model 参数(利用最小化 预测值和目标值之间的 loss)——DQN最核心部分
2.4 PARL 的 DQN 框架
分为 model,algorithm,agent 这 3 个部分
- model:用来定义神经网络部分的网络结构,同时实现模型复制
- algorithm:实现具体算法,如何定义损失函数,更新 model,主要包含了 predict() 和 learn() 两个函数
- agent:负责和环境做交互,数据预处理,构建计算图
总体抽象来说:
- Agent 包含了 Algorithm 和 Model
- Algorithm 包含了 Model
PARL 常用的 API:
- agent.save():保存模型
- agent.restore():加载模型
- model.sync_weights_to():把当前模型的参数同步到另一个模型去
- model.parameters():返回一个 list,包含模型所有参数的名称
- model.get_weights():返回一个 list,包含模型的所有参数
- model.set_weights():设置模型参数
PARL 里面打印日志的工具:
- parl.utils.logger:打印日志,涵盖时间,代码所在文件及行数,方便记录训练时间
PARL 的 API 文档地址:
https://parl.readthedocs.io/en/latest/model.html