LLM padding left or right

时间:2024-04-20 07:22:53

参考博客:
大部分的大模型(LLM)采用左填充(left-padding)的原因
注:文章主要内容参考以上博客,及其评论区,如有侵权,联系删除。

最近在看大模型相关内容的时候,突然想到我实习时候一直一知半解的问题,大模型采用padding left还是right。以上内容,根据我自身的理解,如有错误,请大家批评指正。这里的padding 的位置,仅仅考虑推理时候的left or right。

为什么需要padding?

因为输入的序列长度不一致,为了能够在Batch内进行数据推理,所以需要增加padding,使输入序列的长度是一致的。

为什么要考虑padding left or right?

在BERT时代,通常padding的方式为right,即在右侧进行padding,因为BERT在初始位置有个特殊token,[CLS]左侧进行padding,不好操作。
在大模型时代,可能更偏向于左侧padding, 为什么进行左侧padding,我理解主要原因可能是为了更好也是为了更好操作。
最直观的想法,如果右侧进行padding,生成的序列中间会存在padding token,还需要进一步处理padding token。如下图所示:
在这里插入图片描述

如果采用左侧的padding 的方式则是比较方便处理或者操作。在进行batch推理的时候左侧,进行操作,非常的方便, 如下图所示:
在这里插入图片描述

摘取一些比较好的解释

在这里插入图片描述

大模型batch推理时只能padding left?

大模型在推理时候,同样可以采用padding right,只不过需要增加一些步骤,没有padding left这么直观。
由于只找到LLama和Gemma的推理代码,所以仅仅参考这两个代码进行解释。
参考代码:
LLama
Gemma
下面是Gemma推理代码:

 def generate(
        self,
        prompts: Union[str, Sequence[str]],
        device: Any,
        output_len: int = 100,
        temperature: Union[float, None] = 0.95,
        top_p: float = 1.0,
        top_k: int = 100,
    ) -> Union[str, Sequence[str]]:
        """Generates responses for given prompts using Gemma model."""
        # If a single prompt is provided, treat it as a batch of 1.
        is_str_prompt = isinstance(prompts, str)
        if is_str_prompt:
            prompts = [prompts]

        batch_size = len(prompts)
        prompt_tokens = [self.tokenizer.encode(prompt) for prompt in prompts]
        min_prompt_len = min(len(p) for p in prompt_tokens)
        max_prompt_len = max(len(p) for p in prompt_tokens)
        max_seq_len = max_prompt_len + output_len
        assert max_seq_len <= self.config.max_position_embeddings

        # build KV caches
        kv_caches = []
        for _ in range(self.config.num_hidden_layers):
            size = (batch_size, max_seq_len, self.config.num_key_value_heads,
                    self.config.head_dim)
            dtype = self.config.get_dtype()
            k_cache = torch.zeros(size=size, dtype=dtype, device=device)
            v_cache = torch.zeros(size=size, dtype=dtype, device=device)
            kv_caches.append((k_cache, v_cache))

        # prepare inputs
        token_ids_tensor = torch.full((batch_size, max_seq_len),
                                      self.tokenizer.pad_id, dtype=torch.int64)
        input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
                                            self.tokenizer.pad_id,
                                            dtype=torch.int64)
        for i, p in enumerate(prompt_tokens):
            token_ids_tensor[i, :len(p)] = torch.tensor(p)
            input_token_ids_tensor[i, :min_prompt_len] = torch.tensor(
                p[:min_prompt_len])
        token_ids_tensor = token_ids_tensor.to(device)
        input_token_ids_tensor = input_token_ids_tensor.to(device)
        prompt_mask_tensor = token_ids_tensor != self.tokenizer.pad_id
        input_positions_tensor = torch.arange(0, min_prompt_len,
                                              dtype=torch.int64).to(device)
        mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len),
                                 -2.3819763e38).to(torch.float)
        mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device)
        curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
        output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(
            device)
        temperatures_tensor = None if not temperature else torch.FloatTensor(
            [temperature] * batch_size).to(device)
        top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
        top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
        output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(
            device)

        # Prefill up to min_prompt_len tokens, then treat other prefill as
        # decode and ignore output.
        for i in range(max_seq_len - min_prompt_len):
            next_token_ids = self(
                input_token_ids=input_token_ids_tensor,
                input_positions=input_positions_tensor,
                kv_write_indices=None,
                kv_caches=kv_caches,
                mask=curr_mask_tensor,
                output_positions=output_positions_tensor,
                temperatures=temperatures_tensor,
                top_ps=top_ps_tensor,
                top_ks=top_ks_tensor,
            )

            curr_prompt_mask = prompt_mask_tensor.index_select(
                1, output_index).squeeze(dim=1)
            curr_token_ids = token_ids_tensor.index_select(
                1, output_index).squeeze(dim=1)
            output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
                                           next_token_ids).unsqueeze(dim=1)
            token_ids_tensor.index_copy_(1, output_index, output_token_ids)

            input_token_ids_tensor = output_token_ids
            input_positions_tensor = output_index.unsqueeze(dim=-1)
            curr_mask_tensor = mask_tensor.index_select(2,
                                                        input_positions_tensor)
            output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(
                device)
            output_index = output_index + 1

        # Detokenization.
        token_ids = token_ids_tensor.tolist()
        results = []
        for i, tokens in enumerate(token_ids):
            trimmed_output = tokens[len(prompt_tokens[i]):len(prompt_tokens[i])
                                    + output_len]
            if self.tokenizer.eos_id in trimmed_output:
                eos_index = trimmed_output.index(self.tokenizer.eos_id)
                trimmed_output = trimmed_output[:eos_index]
            results.append(self.tokenizer.decode(trimmed_output))

        # If a string was provided as input, return a string as output.
        return results[0] if is_str_prompt else results

以下面的图为例子进行讲解:
在这里插入图片描述
在推理的时候,先进行右侧padding,使长度一致。选择最短的长度同时进行处理,上图为1, 那么我们同时处理batch min(即1), 然后开始逐个token进行推理,怎么避免下图的形式呢在这里插入图片描述
核心代码为下面内容:output_token_ids = torch.where(curr_prompt_mask, curr_token_ids, next_token_ids).unsqueeze(dim=1),主要的思路就是,如何当前的token为padding token 则填充上一步预测的token的结果,否则填充当前的token。
例如:
在这里插入图片描述
当前位置的token不为pading,则token还是3,不为2这个位置预测的token。
在这里插入图片描述
当前token为padding,则为2这个位置预测的token。
以上就是大模型采用right padding的方法。

总结

感觉pading left or right, 其实无所谓,主要就是为了方便。根据实际情况的具体需求,进行使用,用的正确,方便即可。