Beam search and its implementation in PyTorch

Date: 2022-09-07 00:23:20

This post records two different versions of beam search.

Version 1

Search in a way similar to level-order traversal, maintained with a queue. Each iteration of the loop expands every node of the current level; each of these nodes contributes its top-k successors as candidates for the next level, and the top k among all candidates are pushed into the queue as the next level.

This is BFS with a width constraint: a form of heuristic search, and a greedy algorithm. If k -> inf, it becomes equivalent to plain BFS.

Starting from the root node (<sos>), select the k most probable tokens among all possibilities (typically tens of thousands, i.e. the vocabulary size) and expand them as the next level of nodes.

Then, among all nodes that these k nodes can expand to (generally k × tens of thousands), again select the k with the highest probability (note that the probability here is cumulative, i.e. the product of probabilities along the path from the root to that node) and expand them. The parents of these k children may be all k nodes of the previous level, only some of them, or even a single one of them. Continue in the same way.

The result is a tree with k nodes per level (except for the root, the final level, and levels where fewer than k candidates are available).

In practice the log of the probability is used, so the accumulated product of probabilities becomes a sum of logs and does not underflow.
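
A minimal illustration of why logs are needed (the per-step probabilities below are made-up toy numbers):

import math

# 200 steps, each with (hypothetical) probability 0.01
probs = [0.01] * 200

product = 1.0
for p in probs:
    product *= p                      # underflows to 0.0 (true value is 1e-400)

log_sum = sum(math.log(p) for p in probs)  # stays a finite, well-behaved number

print(product)   # 0.0
print(log_sum)   # about -921.03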

An example with k = 2:

From <sos>, keep the two most probable tokens: "i": 0.6 and "he": 0.4; all other tokens are negligible.

There are 4 candidate expansions in total: (1) after "i", suppose the most probable tokens are "love": 0.7 and "like": 0.3, everything else negligible; (2) after "he", suppose the most probable tokens are "hates": 0.9 and "loves": 0.1, everything else negligible. Among these 4 possibilities, "i love" has probability 0.6 * 0.7 = 0.42, "i like" has 0.6 * 0.3 = 0.18, "he hates" has 0.4 * 0.9 = 0.36, and "he loves" has 0.4 * 0.1 = 0.04. Keep the two with the highest probability: "i love" and "he hates".

The next level again has 4 expansions: (1) after "i love", suppose the most probable token is "you": 0.9, with all other tokens summing to 0.1; (2) after "he hates", suppose the most probable tokens are "her": 0.8 and "himself": 0.1, with the rest summing to 0.1. Then "i love you" has probability 0.42 * 0.9 = 0.378 and "he hates her" has probability 0.36 * 0.8 = 0.288; nothing else needs to be computed because it is smaller. Again keep the two with the highest probability: "i love you" and "he hates her".

At the next level, "i love you" should expand two children, but "<eos>" turns out to have probability 0.99 with all other tokens summing to 0.01; likewise "he hates her" finds "<eos>" with probability 0.99 and everything else summing to 0.01. So the two most probable hypotheses are "i love you <eos>" and "he hates her <eos>". Since both branches have reached <eos>, the search stops.

Finally, of these two, the one with the highest probability, "i love you <eos>", is selected. Done.
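
This toy example can be reproduced in a few lines of plain Python; the probability tables below are the hypothetical numbers from the example above, not the output of any model:

import math

# Hypothetical next-token distributions from the example, keyed by the current prefix.
# Tokens not listed are treated as negligible.
step_probs = {
    ("<sos>",): {"i": 0.6, "he": 0.4},
    ("<sos>", "i"): {"love": 0.7, "like": 0.3},
    ("<sos>", "he"): {"hates": 0.9, "loves": 0.1},
    ("<sos>", "i", "love"): {"you": 0.9},
    ("<sos>", "he", "hates"): {"her": 0.8, "himself": 0.1},
    ("<sos>", "i", "love", "you"): {"<eos>": 0.99},
    ("<sos>", "he", "hates", "her"): {"<eos>": 0.99},
}

def toy_beam_search(k=2):
    beams = [(0.0, ("<sos>",))]          # (cumulative log prob, prefix)
    finished = []
    while beams:
        candidates = []
        for logp, prefix in beams:
            if prefix[-1] == "<eos>":    # this branch is done
                finished.append((logp, prefix))
                continue
            for tok, p in step_probs.get(prefix, {}).items():
                candidates.append((logp + math.log(p), prefix + (tok,)))
        beams = sorted(candidates, reverse=True)[:k]   # keep the k best prefixes
    return max(finished)

print(toy_beam_search())
# (-0.983..., ('<sos>', 'i', 'love', 'you', '<eos>')), i.e. probability 0.6*0.7*0.9*0.99 ~ 0.374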

The PyTorch implementation below is excerpted from a project; only the key parts are kept:

from queue import Queue  # FIFO queue that holds the current level of the search tree

class Node(object):
    def __init__(self, hidden, previous_node, decoder_input, attn, log_prob, length):
        self.hidden = hidden
        self.previous_node = previous_node
        self.decoder_input = decoder_input
        self.attn = attn
        self.log_prob = log_prob
        self.length = length

def beam_search(beam_width):
    ...
    # hidden, decoder_input, decoder, encoder_input and EOS come from the surrounding project
    root = Node(hidden, None, decoder_input, None, 0, 1)
    q = Queue()
    q.put(root)

    end_nodes = []  # finished nodes, used later for backtracking
    while not q.empty():
        candidates = []  # candidates for the next level: only the top-k children of each parent are needed

        for _ in range(q.qsize()):
            node = q.get()
            decoder_input = node.decoder_input
            hidden = node.hidden

            # termination condition for this branch
            if decoder_input.item() == EOS or node.length >= 50:
                end_nodes.append(node)
                continue

            log_prob, hidden, attn = decoder(
                decoder_input, hidden, encoder_input
            )

            log_prob, indices = log_prob.topk(beam_width)  # top-k children of this parent

            for k in range(beam_width):
                index = indices[k].unsqueeze(0)
                log_p = log_prob[k].item()
                child = Node(hidden, node, index, attn, node.log_prob + log_p, node.length + 1)
                candidates.append((node.log_prob + log_p, child))  # note: the log probability is accumulated

        candidates = sorted(candidates, key=lambda x: x[0], reverse=True)  # sort the candidates by score
        length = min(len(candidates), beam_width)  # keep the top k; if there are fewer than k, keep them all
        for i in range(length):
            q.put(candidates[i][1])
    # backtracking follows, omitted
    ...
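
For completeness, here is a minimal sketch of what the omitted backtracking could look like, assuming the Node fields defined above; the length-normalized ranking is my own choice, not taken from the project:

def backtrack(end_nodes):
    # rank finished hypotheses by length-normalized cumulative log probability
    best = max(end_nodes, key=lambda n: n.log_prob / n.length)
    tokens = []
    node = best
    while node is not None:
        tokens.append(node.decoder_input)  # each node stores the token tensor chosen at that step
        node = node.previous_node
    return tokens[::-1]                    # reverse so the sequence starts at the root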

Version 2

Instead of a level-order traversal, each step pops the node with the best score from a priority queue over all open nodes, expands it, and pushes its top-k children back onto the priority queue. The loop terminates when a popped node completes a hypothesis (reaching <eos> or the length limit) or when the number of expanded nodes exceeds a limit. A conceptual sketch follows, then the full PyTorch code.
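
A conceptual sketch of this best-first loop using heapq; the node attributes (score, is_end) and the expand callback are hypothetical placeholders, not part of the implementation below:

import heapq

def best_first_beam_search(root, expand, beam_width, max_expansions=2000):
    # expand(node) -> children sorted by score, best first (hypothetical signature)
    heap = [(-root.score, 0, root)]   # min-heap on the negated score; the counter breaks ties
    counter = 1
    finished = []
    while heap and counter < max_expansions:
        _, _, node = heapq.heappop(heap)         # pop the globally best open node
        if node.is_end:
            finished.append(node)
            break                                # stop once the best hypothesis is complete
        for child in expand(node)[:beam_width]:  # push only the top-k children
            heapq.heappush(heap, (-child.score, counter, child))
            counter += 1
    return finished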

import operator
import torch
import torch.nn as nn
import torch.nn.functional as F
from queue import PriorityQueue
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SOS_token = 0
EOS_token = 1
MAX_LENGTH = 50
class DecoderRNN(nn.Module):
    def __init__(self, embedding_size, hidden_size, output_size, cell_type, dropout=0.1):
        '''
        Illustrative decoder
        '''
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.cell_type = cell_type
        self.embedding = nn.Embedding(num_embeddings=output_size,
                                      embedding_dim=embedding_size,
                                      )
        # a single-direction, single-layer GRU keeps the sizes consistent with self.out
        # and with the [1, B, H] decoder hidden state expected by beam_decode below
        self.rnn = nn.GRU(embedding_size, hidden_size, batch_first=False)
        self.dropout_rate = dropout
        self.out = nn.Linear(hidden_size, output_size)
    def forward(self, input, hidden, not_used):
        embedded = self.embedding(input).transpose(0, 1)  # [B, 1] -> [1, B, D]
        embedded = F.dropout(embedded, self.dropout_rate)
        output = embedded
        # with batch_first=False, output has shape (seq_len, batch_size, hidden_size) = [1, batch_size, hidden_size]
        output, hidden = self.rnn(output, hidden)
        out = self.out(output.squeeze(0))
        # output: [batch_size, vocab_size]
        # hidden: [num_layers * num_directions, batch_size, hidden_size]
        output = F.log_softmax(out, dim=1)
        return output, hidden
class BeamSearchNode(object):
    def __init__(self, hiddenstate, previousNode, wordId, logProb, length):
        '''
        :param hiddenstate: decoder hidden state at this node
        :param previousNode: parent node, used for backtracking
        :param wordId: token id chosen at this step
        :param logProb: cumulative log probability of the partial hypothesis
        :param length: number of tokens decoded so far
        '''
        self.h = hiddenstate
        self.prevNode = previousNode
        self.wordid = wordId
        self.logp = logProb
        self.leng = length
    def __lt__(self, other):
        # tie-breaker so the PriorityQueue never compares raw nodes when two scores are equal
        return self.leng < other.leng
    def eval(self, alpha=1.0):
        reward = 0
        # Add here a function for shaping a reward
        # length-normalized log probability, so longer hypotheses are not unfairly penalized
        return self.logp / float(self.leng - 1 + 1e-6) + alpha * reward
decoder = DecoderRNN(embedding_size=256, hidden_size=256, output_size=10000, cell_type='gru')  # placeholder sizes for illustration
def beam_decode(target_tensor, decoder_hiddens, encoder_outputs=None):
    '''
    :param target_tensor: target indexes tensor of shape [B, T] where B is the batch size and T is the maximum length of the output sentence
    :param decoder_hiddens: input tensor of shape [1, B, H] for start of the decoding
    :param encoder_outputs: if you are using attention mechanism you can pass encoder outputs, [T, B, H] where T is the maximum length of input sentence
    :return: decoded_batch
    '''
    beam_width = 10
    topk = 1  # how many sentence do you want to generate
    decoded_batch = []
    # decoding goes sentence by sentence
    for idx in range(target_tensor.size(0)):
        if isinstance(decoder_hiddens, tuple):  # LSTM case
            decoder_hidden = (decoder_hiddens[0][:,idx, :].unsqueeze(0),decoder_hiddens[1][:,idx, :].unsqueeze(0))
        else:
            decoder_hidden = decoder_hiddens[:, idx, :].unsqueeze(0)
        encoder_output = encoder_outputs[:,idx, :].unsqueeze(1)
        # Start with the start of the sentence token
        decoder_input = torch.tensor([[SOS_token]], dtype=torch.long, device=device)
        # Number of sentence to generate
        endnodes = []
        number_required = min((topk + 1), topk - len(endnodes))
        # starting node -  hidden vector, previous node, word id, logp, length
        node = BeamSearchNode(decoder_hidden, None, decoder_input, 0, 1)
        nodes = PriorityQueue()
        # start the queue
        nodes.put((-node.eval(), node))
        qsize = 1
        # start beam search
        while True:
            # give up when decoding takes too long
            if qsize > 2000: break
            # fetch the best node
            score, n = nodes.get()
            decoder_input = n.wordid
            decoder_hidden = n.h
            if n.wordid.item() == EOS_token and n.prevNode != None:
                endnodes.append((score, n))
                # if we reached maximum # of sentences required
                if len(endnodes) >= number_required:
                    break
                else:
                    continue
            # decoder_output: [batch_size, vocab_size]
            # decoder_hidden: [num_layers * num_directions, batch_size, hidden_size]
            # decode for one step using decoder
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_output)
            # PUT HERE REAL BEAM SEARCH OF TOP
            # log_prob, indexes: [batch_size, beam_width] = [1, beam_width]
            log_prob, indexes = torch.topk(decoder_output, beam_width, dim=1)
            nextnodes = []
            for new_k in range(beam_width):
                # decoded_t: [1, 1]; view(1, -1) reshapes the selected scalar into a [1, 1] tensor
                decoded_t = indexes[0][new_k].view(1, -1)
                # log_p: Python float; item() extracts the number from the tensor
                log_p = log_prob[0][new_k].item()
                node = BeamSearchNode(decoder_hidden, n, decoded_t, n.logp + log_p, n.leng + 1)
                score = -node.eval()
                nextnodes.append((score, node))
            # put them into queue
            for i in range(len(nextnodes)):
                score, next_node = nextnodes[i]
                nodes.put((score, next_node))
                # increase qsize
            qsize += len(nextnodes) - 1
        # choose nbest paths, back trace them
        if len(endnodes) == 0:
            endnodes = [nodes.get() for _ in range(topk)]
        utterances = []
        for score, n in sorted(endnodes, key=operator.itemgetter(0)):
            utterance = []
            utterance.append(n.wordid)
            # back trace
            while n.prevNode != None:
                n = n.prevNode
                utterance.append(n.wordid)
            utterance = utterance[::-1]
            utterances.append(utterance)
        decoded_batch.append(utterances)
    return decoded_batch
def greedy_decode(decoder_hidden, encoder_outputs, target_tensor):
    '''
    :param target_tensor: target indexes tensor of shape [B, T] where B is the batch size and T is the maximum length of the output sentence
    :param decoder_hidden: input tensor of shape [1, B, H] for start of the decoding
    :param encoder_outputs: if you are using attention mechanism you can pass encoder outputs, [T, B, H] where T is the maximum length of input sentence
    :return: decoded_batch
    '''
    batch_size, seq_len = target_tensor.size()
    decoded_batch = torch.zeros((batch_size, MAX_LENGTH))
    decoder_input = torch.tensor([[SOS_token] for _ in range(batch_size)], dtype=torch.long, device=device)
    for t in range(MAX_LENGTH):
        decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
        topv, topi = decoder_output.data.topk(1)  # get candidates
        topi = topi.view(-1)
        decoded_batch[:, t] = topi
        decoder_input = topi.detach().view(-1, 1)
    return decoded_batch
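
For reference, a minimal smoke test showing how beam_decode might be invoked; all sizes and the random tensors are illustrative assumptions, not part of the original code:

# Illustrative driver for beam_decode (made-up sizes; the decoder is untrained).
vocab_size, hidden_size = 10000, 256
batch_size, tgt_len, src_len = 2, 12, 15

decoder.to(device)  # use the module-level decoder defined above

target_tensor = torch.randint(0, vocab_size, (batch_size, tgt_len), device=device)   # [B, T]
decoder_hiddens = torch.zeros(1, batch_size, hidden_size, device=device)             # [1, B, H]
encoder_outputs = torch.randn(src_len, batch_size, hidden_size, device=device)       # [T_src, B, H]

decoded = beam_decode(target_tensor, decoder_hiddens, encoder_outputs)
print(len(decoded), len(decoded[0]))  # batch_size sentences, topk hypotheses each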

Supplement: a simple beam search example, implemented and explained

Let's look at the code:

from math import log
from numpy import array
from numpy import argmax
# beam search
# note: here a hypothesis is scored by the product of the negative log probabilities
# of its tokens, and candidates are sorted ascending, so a smaller score is better
def beam_search_decoder(data, k):
    sequences = [[list(), 1.0]]
    # walk over each step in sequence
    for row in data:
        all_candidates = list()
        # expand each current candidate
        for i in range(len(sequences)):
            seq, score = sequences[i]
            for j in range(len(row)):
                candidate = [seq + [j], score * -log(row[j])]
                all_candidates.append(candidate)
        # order all candidates by score (ascending)
        ordered = sorted(all_candidates, key=lambda tup: tup[1])
        # select k best
        sequences = ordered[:k]
    return sequences
def greedy_decoder(data):
    # index for largest probability each row
    return [argmax(s) for s in data]
# define a sequence of 10 words over a vocab of 5 words
data = [[0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1]]
data = array(data)
# decode sequence
result = beam_search_decoder(data, 3)
# print result
for seq in result:
    print(seq)

The value of sequences after each of the first three iterations:

[[[4], 0.6931471805599453], [[3], 0.916290731874155], [[2], 1.2039728043259361]]

[[[4, 0], 0.4804530139182014], [[4, 1], 0.6351243373717793], [[3, 0], 0.6351243373717793]]

[[[4, 0, 4], 0.33302465198892944], [[4, 0, 3], 0.4402346437542523], [[4, 1, 4], 0.4402346437542523]]

The final printed result:

[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108]

[[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397]

[[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397]
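
Note that this toy decoder ranks hypotheses by the product of negative log probabilities, which is only a ranking heuristic. A variant closer to the standard formulation sums log probabilities and keeps the k largest; the sketch below is my own rewrite under that convention, not part of the original example:

from math import log

def beam_search_decoder_logsum(data, k):
    # standard scoring: cumulative sum of log probabilities, larger is better
    sequences = [[list(), 0.0]]
    for row in data:
        all_candidates = [[seq + [j], score + log(row[j])]
                          for seq, score in sequences
                          for j in range(len(row))]
        # keep the k sequences with the highest cumulative log probability
        sequences = sorted(all_candidates, key=lambda tup: tup[1], reverse=True)[:k]
    return sequences

for seq in beam_search_decoder_logsum(data, 3):
    print(seq)
# the best hypothesis is again [4, 0, 4, 0, 4, 0, 4, 0, 4, 0], with log probability 10 * log(0.5) ~ -6.93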

The above reflects my personal experience; I hope it provides a useful reference.

Original article: https://blog.csdn.net/u014514939/article/details/95667422