【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search

时间:2024-04-03 19:14:41

本文用了强化学习,在知识图谱上游走,寻找目标节点。

一、简介

大概意思就是,在知识图谱【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search上,给出一个起始节点和查询(query)【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search,然后找到目标节点【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search

【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search 图G包含节点【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search和边【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search

如下图,给出起始节点Obama,query:citizenship,目标节点是USA。

 

【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search

 

我们要学习一个方法【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search来预测【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search

我们我们将f作为强化学习力的agent。他要学习搜索策略(search policy)

训练的时候,我们给出【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search,让f自己学习路径,如果他走到【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search,就给他一个正的reward,或者0分。学完后只给出【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search,预测【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search

 

所以设计了一个神经网络的agent,叫M-walk。用RNN将历史路径转化为一个向量,用来学policy和Q function 。reward稀疏,所以用带蒙特卡洛树搜索的RNN,生成路径。

 

二、用马尔科夫决策过程来进行图的游走

(S,A,R,P) s是state,a是action,r是reward function,p是state transition probability

【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search

【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search

初始状态s0和下一个状态的表示,如上图所示。

【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search是连接点nt的所有边,【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search是nt的所有邻居节点。

si由1)该节点和连接它的边、它的邻居 2)t-1时刻的动作 3)初始query q构成。

 

集合S由所有可能出现的st构成。

在状态st,agent有以下动作可以选择:1)选择【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search中的一条边,他连接到点【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search 2)选择STOP,则【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search就是要预测的【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search

动作集合由下图表示

【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search

输出【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search

 

如果输出是【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search(即输出了正确的答案),则reward=1,否则为0.

这可以看出来,reward是非常稀疏的,只有走到正确的位置才有reward。但是由于图是已知静态确定的,所以如果确定了上一个状态和动作,那么下一个状态时确定的。(文中说这有助于解决reward稀疏。)

 

【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search

π是policy(给出状态s,选择动作a),Q是Q function(在状态s下选择动作a,它的Q value是多少,即之后的长期收益是多少)

 

三、M-walk agent

3.1π和Q的神经网路结构

用RNN获得当前状态st的表达ht

ht分为三个部分:

1)【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search 将上个时间的状态、动作、当前节点,综合。

2)【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search综合了nt的邻居n'节点,以及nt和n'之间的边e,代表第n'个候选动作(包括STOP动作)

3)【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search  综合了【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search,用来判断STOP的概率。

 

所以π和Q的计算。

【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search

【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search

u0是将hst,hAt通过一个full-connected neural network。(这里没说这两个h要怎么整合到一起,可能是拼接吧)

un'是hst和hn't做内积(即点乘,对应位相乘,求和)

u0(STOP的分数),un'(邻居的分数)都是一个数字

Q是对每个数字做sigmoid

π是做温度参数为τ的softmax

关于温度参数

【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search

 

3.2 训练算法

传统的使用蒙特卡罗方法的REINFORCE,需要sample一个完整的序列,sample的效率很低,而且reward稀疏。所以sample的时候使用PUCT算法的变体。

 

【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search

π是上面提到的策略分数(softmax算的),c和β用来控制探索的程度。N是visit count。W是走(s-a)这条边上的蒙特卡罗树的total action reward。

PUCT算法最开始倾向于选择在状态s下出现少的action(式子的前半部分), 后来倾向于选择分数高的(式子的后半部分)。

当PUCT算法选择了STOP,或者到达了最大探索数(应该是强行选择STOP),则停止。使用【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search

用下面的式子,更新上一个式子中的N和W。γ是衰减因子(discount factor).

【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search

【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search

主要目标就是多生成reward为正的路径。

然后用DQN网络,寻找更好的π就是max Q

【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search

【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search 

 3.3预测算法

已知(ns,q)求nT。利用π在G上寻找nT。

我们利用上面已经生成好的蒙特卡罗树。但是可能有多路径到达同一个节点n。走不同路径,就有不同的

这些路径上各个叶子状态sT。怎么比较选择哪个n(n需要综合多条路径),需要算一个分数,排序。

【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search

N是蒙特卡罗树的总模拟数量

求和是在所有有关同一个节点n上的子状态sT,是对于同一个候选n的平均权重。

在所有的候选节点中,我们选择score最大的。【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search

 

3.4 RNN encoder

【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search

【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search qt约等于右边的(因为s0的原因)

【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search所以st大约可以写成

st由两部分组成 1)【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search  【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search代表候选动作(包括STOP) 2)qt代表历史

所以用两个不同的神经网络去编码他们

前面说过,ht分为三个部分:

1)【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search 将上个时间的状态、动作、当前节点,综合。

2)【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search综合了nt的邻居n'节点,以及nt和n'之间的边e,代表第n'个候选动作(包括STOP动作)

3)【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search  综合了【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search,用来判断STOP的概率。

求 2)的方法很简单,就是边和点的表达通过full-connected neural network

求 3) 的方法,就是max 2)的结果,因为每一次的节点数可能都不一样,这样可以得到统一的结果

求1) 就是编码qt 使用gru的思想

【论文笔记】M-Walk: Learning to Walk over Graphs using Monte Carlo Tree Search