240712_昇思学习打卡-Day24-LSTM+CRF序列标注(3)

时间:2024-07-16 22:24:16

240712_昇思学习打卡-Day24-LSTM+CRF序列标注(3)

今天做LSTM+CRF序列标注第三部分,同样,仅作简单记录及注释,最近确实太忙了。

Viterbi算法

在完成前向训练部分后,需要实现解码部分。这里我们选择适合求解序列最优路径的Viterbi算法。与计算Normalizer类似,使用动态规划求解所有可能的预测序列得分。不同的是在解码时同时需要将第????个Token对应的score取值最大的标签保存,供后续使用Viterbi算法求解最优预测序列使用。

取得最大概率得分ScoreScore,以及每个Token对应的标签历史HistoryHistory后,根据Viterbi算法可以得到公式:

请添加图片描述

从第0个至第????个Token对应概率最大的序列,只需要考虑从第0个至第????−1个Token对应概率最大的序列,以及从第????个至第????−1个概率最大的标签即可。因此我们逆序求解每一个概率最大的标签,构成最佳的预测序列。

由于静态图语法限制,我们将Viterbi算法求解最佳预测序列的部分作为后处理函数,不纳入后续CRF层的实现。

# 定义维特比解码算法,用于找出具有最大概率的标签序列
def viterbi_decode(emissions, mask, trans, start_trans, end_trans):
    # emissions: (seq_length, batch_size, num_tags) 发射概率矩阵
    # mask: (seq_length, batch_size) 序列掩码,用于标记有效序列长度
    # trans: 转移概率矩阵
    # start_trans: 初始状态转移概率向量
    # end_trans: 终止状态转移概率向量

    seq_length = mask.shape[0]  # 获取序列长度

    # 初始化分数矩阵,等于初始状态转移概率加上第一个发射概率
    score = start_trans + emissions[0]
    history = ()  # 初始化历史路径记录

    # 遍历序列中的每个时间步
    for i in range(1, seq_length):
        # 扩展维度以便广播运算
        broadcast_score = score.expand_dims(2)
        broadcast_emission = emissions[i].expand_dims(1)
        
        # 计算所有可能的转移分数
        next_score = broadcast_score + trans + broadcast_emission

        # 找出当前Token对应的最大分数标签,并保存
        indices = next_score.argmax(axis=1)
        history += (indices,)  # 保存历史路径信息

        # 取出最大分数
        next_score = next_score.max(axis=1)
        
        # 更新分数矩阵,只更新mask为True的部分
        score = mnp.where(mask[i].expand_dims(1), next_score, score)

    # 加上终止状态转移概率
    score += end_trans

    # 返回最终的分数矩阵和历史路径信息
    return score, history


# 根据解码过程中的得分和历史路径信息,重构最优标签序列
def post_decode(score, history, seq_length):
    # score: 最终得分矩阵
    # history: 历史路径信息
    # seq_length: 每个样本的实际序列长度

    batch_size = seq_length.shape[0]  # 获取批次大小
    seq_ends = seq_length - 1  # 计算每个样本的最后一个Token位置
    
    # 初始化最佳标签序列列表
    best_tags_list = []

    # 对批次中的每个样本进行解码
    for idx in range(batch_size):
        # 找出使最后一个Token对应的预测概率最大的标签
        best_last_tag = score[idx].argmax(axis=0)
        best_tags = [int(best_last_tag.asnumpy())]  # 添加最佳标签到序列

        # 从历史路径信息中反向追踪,找到每个Token的最佳标签
        for hist in reversed(history[:seq_ends[idx]]):
            best_last_tag = hist[idx][best_tags[-1]]
            best_tags.append(int(best_last_tag.asnumpy()))

        # 将逆序的标签序列反转,得到正序的最优标签序列
        best_tags.reverse()
        best_tags_list.append(best_tags)  # 添加到结果列表

    # 返回最优标签序列列表
    return best_tags_list

CRF层

完成上述前向训练和解码部分的代码后,将其组装完整的CRF层。考虑到输入序列可能存在Padding的情况,CRF的输入需要考虑输入序列的真实长度,因此除发射矩阵和标签外,加入seq_length参数传入序列Padding前的长度,并实现生成mask矩阵的sequence_mask方法。

综合上述代码,使用nn.Cell进行封装,最后实现完整的CRF层如下:

# 导入MindSpore相关模块
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.numpy as mnp
from mindspore.common.initializer import initializer, Uniform

# 定义序列掩码生成函数
def sequence_mask(seq_length, max_length, batch_first=False):
    """
    根据序列的实际长度和最大长度生成mask矩阵。
    
    参数:
    seq_length: 实际序列长度张量。
    max_length: 序列的最大长度。
    batch_first: 是否将批次放在第一维度。
    
    返回:
    mask矩阵,形状为(batch_size, max_length),其中True表示有效位置,False表示填充位置。
    """
    # 生成从0到max_length的范围向量
    range_vector = mnp.arange(0, max_length, 1, seq_length.dtype)
    # 创建mask矩阵,shape为(seq_length.shape + (1,))
    result = range_vector < seq_length.view(seq_length.shape + (1,))
    # 转换数据类型并根据batch_first参数调整维度顺序
    if batch_first:
        return result.astype(ms.int64)
    return result.astype(ms.int64).swapaxes(0, 1)


# 定义条件随机场(CRF)模型类
class CRF(nn.Cell):
    def __init__(self, num_tags: int, batch_first: bool = False, reduction: str = 'sum') -> None:
        """
        初始化CRF模型。
        
        参数:
        num_tags: 标签数量。
        batch_first: 是否将批次放在第一维度。
        reduction: 损失函数的缩减方式。
        """
        # 检查标签数量是否有效
        if num_tags <= 0:
            raise ValueError(f'无效的标签数量: {num_tags}')
        super().__init__()
        # 检查reduction参数是否有效
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'无效的缩减方式: {reduction}')
        self.num_tags = num_tags  # 标签数量
        self.batch_first = batch_first  # 批次是否在第一维度
        self.reduction = reduction  # 损失函数缩减方式
        # 初始化起始和结束状态转移权重
        self.start_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='start_transitions')
        self.end_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='end_transitions')
        # 初始化状态间转移权重
        self.transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags, num_tags)), name='transitions')

    def construct(self, emissions, tags=None, seq_length=None):
        """
        CRF模型的前向传播方法。
        
        参数:
        emissions: 发射概率张量。
        tags: 真实标签张量。
        seq_length: 序列长度张量。
        
        返回:
        如果tags为None,则返回解码结果;否则返回损失值。
        """
        if tags is None:
            return self._decode(emissions, seq_length)
        return self._forward(emissions, tags, seq_length)

    def _forward(self, emissions, tags=None, seq_length=None):
        """
        计算损失值。
        
        参数:
        emissions: 发射概率张量。
        tags: 真实标签张量。
        seq_length: 序列长度张量。
        
        返回:
        损失值。
        """
        # 根据batch_first参数调整emissions和tags的维度顺序
        if self.batch_first:
            batch_size, max_length = tags.shape
            emissions = emissions.swapaxes(0, 1)
            tags = tags.swapaxes(0, 1)
        else:
            max_length, batch_size = tags.shape
        
        # 如果seq_length未给出,则假设所有序列都是最大长度
        if seq_length is None:
            seq_length = mnp.full((batch_size,), max_length, ms.int64)
        
        # 生成mask矩阵
        mask = sequence_mask(seq_length, max_length)
        
        # 计算分子部分(真实路径的得分)
        numerator = compute_score(emissions, tags, seq_length-1, mask, self.transitions, self.start_transitions, self.end_transitions)
        # 计算分母部分(所有可能路径的得分总和)
        denominator = compute_normalizer(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)
        # 计算对数似然比
        llh = denominator - numerator
        
        # 根据reduction参数选择损失值的缩减方式
        if self.reduction == 'none':
            return llh
        elif self.reduction == 'sum':
            return llh.sum()
        elif self.reduction == 'mean':
            return llh.mean()
        return llh.sum() / mask.astype(emissions.dtype).sum()

    def _decode(self, emissions, seq_length=None):
        """
        解码方法,用于预测最优标签序列。
        
        参数:
        emissions: 发射概率张量。
        seq_length: 序列长度张量。
        
        返回:
        最优标签序列。
        """
        # 根据batch_first参数调整emissions的维度顺序
        if self.batch_first:
            batch_size, max_length = emissions.shape[:2]
            emissions = emissions.swapaxes(0, 1)
        else:
            batch_size, max_length = emissions.shape[:2]
        
        # 如果seq_length未给出,则假设所有序列都是最大长度
        if seq_length is None:
            seq_length = mnp.full((batch_size,), max_length, ms.int64)
        
        # 生成mask矩阵
        mask = sequence_mask(seq_length, max_length)
        
        # 使用维特比算法解码最优路径
        return viterbi_decode(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)

打卡图片:

请添加图片描述