QANet 是在2018年提出的一种新型机器阅读理解模型,其显著特点是不依赖传统的循环神经网络(RNN),而是完全基于卷积和自注意力机制。循环神经网络由于其顺序处理的特性在训练和推理时速度较慢,而QANet通过卷积捕获文本的局部结构,通过自注意力机制学习全局词对的交互,从而提高了效率。
本文将详细讲解QANet模型的构建过程,展示如何实现它的核心模块,包括深度可分离卷积、多头自注意力、上下文-查询注意力等,并在SQuAD数据集上进行训练。
1. 加载预处理数据
与之前的BiDAF模型类似,QANet也需要处理上下文和问题对的嵌入。我们可以直接加载之前保存的预处理过的数据,包括单词和字符的索引。
import pickle
import pandas as pd
with open('bidafw2id.pickle', 'rb') as handle:
word2idx = pickle.load(handle)
with open('bidafc2id.pickle', 'rb') as handle:
char2idx = pickle.load(handle)
train_df = pd.read_pickle('bidaftrain.pkl')
valid_df = pd.read_pickle('bidafvalid.pkl')
idx2word = {v: k for k, v in word2idx.items()}
2. 数据加载器
QANet的数据加载器与BiDAF类似,用于动态生成批次数据并进行适当的填充,以便模型能够处理变长的上下文和问题输入。
class SquadDataset:
def __init__(self, data, batch_size):
self.batch_size = batch_size
data = [data[i:i+self.batch_size] for i in range(0, len(data), self.batch_size)]
self.data = data
def make_char_vector(self, max_sent_len, sentence, max_word_len=16):
char_vec = torch.zeros(max_sent_len, max_word_len).type(torch.LongTensor)
for i, word in enumerate(nlp(sentence, disable=['parser','ner'])):
for j, ch in enumerate(word.text):
if j == max_word_len:
break
char_vec[i][j] = char2idx.get(ch, 0)
return char_vec
def __iter__(self):
for batch in self.data:
max_context_len = max([len(ctx) for ctx in batch.context_ids])
padded_context = to