参考博客:
大部分的大模型(LLM)采用左填充(left-padding)的原因
注:文章主要内容参考以上博客,及其评论区,如有侵权,联系删除。
最近在看大模型相关内容的时候,突然想到我实习时候一直一知半解的问题,大模型采用padding left还是right。以上内容,根据我自身的理解,如有错误,请大家批评指正。这里的padding 的位置,仅仅考虑推理时候的left or right。
为什么需要padding?
因为输入的序列长度不一致,为了能够在Batch内进行数据推理,所以需要增加padding,使输入序列的长度是一致的。
为什么要考虑padding left or right?
在BERT时代,通常padding的方式为right,即在右侧进行padding,因为BERT在初始位置有个特殊token,[CLS]左侧进行padding,不好操作。
在大模型时代,可能更偏向于左侧padding, 为什么进行左侧padding,我理解主要原因可能是为了更好也是为了更好操作。
最直观的想法,如果右侧进行padding,生成的序列中间会存在padding token,还需要进一步处理padding token。如下图所示:
如果采用左侧的padding 的方式则是比较方便处理或者操作。在进行batch推理的时候左侧,进行操作,非常的方便, 如下图所示:
摘取一些比较好的解释
大模型batch推理时只能padding left?
大模型在推理时候,同样可以采用padding right,只不过需要增加一些步骤,没有padding left这么直观。
由于只找到LLama和Gemma的推理代码,所以仅仅参考这两个代码进行解释。
参考代码:
LLama
Gemma
下面是Gemma推理代码:
def generate(
self,
prompts: Union[str, Sequence[str]],
device: Any,
output_len: int = 100,
temperature: Union[float, None] = 0.95,
top_p: float = 1.0,
top_k: int = 100,
) -> Union[str, Sequence[str]]:
"""Generates responses for given prompts using Gemma model."""
# If a single prompt is provided, treat it as a batch of 1.
is_str_prompt = isinstance(prompts, str)
if is_str_prompt:
prompts = [prompts]
batch_size = len(prompts)
prompt_tokens = [self.tokenizer.encode(prompt) for prompt in prompts]
min_prompt_len = min(len(p) for p in prompt_tokens)
max_prompt_len = max(len(p) for p in prompt_tokens)
max_seq_len = max_prompt_len + output_len
assert max_seq_len <= self.config.max_position_embeddings
# build KV caches
kv_caches = []
for _ in range(self.config.num_hidden_layers):
size = (batch_size, max_seq_len, self.config.num_key_value_heads,
self.config.head_dim)
dtype = self.config.get_dtype()
k_cache = torch.zeros(size=size, dtype=dtype, device=device)
v_cache = torch.zeros(size=size, dtype=dtype, device=device)
kv_caches.append((k_cache, v_cache))
# prepare inputs
token_ids_tensor = torch.full((batch_size, max_seq_len),
self.tokenizer.pad_id, dtype=torch.int64)
input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
self.tokenizer.pad_id,
dtype=torch.int64)
for i, p in enumerate(prompt_tokens):
token_ids_tensor[i, :len(p)] = torch.tensor(p)
input_token_ids_tensor[i, :min_prompt_len] = torch.tensor(
p[:min_prompt_len])
token_ids_tensor = token_ids_tensor.to(device)
input_token_ids_tensor = input_token_ids_tensor.to(device)
prompt_mask_tensor = token_ids_tensor != self.tokenizer.pad_id
input_positions_tensor = torch.arange(0, min_prompt_len,
dtype=torch.int64).to(device)
mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len),
-2.3819763e38).to(torch.float)
mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device)
curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(
device)
temperatures_tensor = None if not temperature else torch.FloatTensor(
[temperature] * batch_size).to(device)
top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(
device)
# Prefill up to min_prompt_len tokens, then treat other prefill as
# decode and ignore output.
for i in range(max_seq_len - min_prompt_len):
next_token_ids = self(
input_token_ids=input_token_ids_tensor,
input_positions=input_positions_tensor,
kv_write_indices=None,
kv_caches=kv_caches,
mask=curr_mask_tensor,
output_positions=output_positions_tensor,
temperatures=temperatures_tensor,
top_ps=top_ps_tensor,
top_ks=top_ks_tensor,
)
curr_prompt_mask = prompt_mask_tensor.index_select(
1, output_index).squeeze(dim=1)
curr_token_ids = token_ids_tensor.index_select(
1, output_index).squeeze(dim=1)
output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
next_token_ids).unsqueeze(dim=1)
token_ids_tensor.index_copy_(1, output_index, output_token_ids)
input_token_ids_tensor = output_token_ids
input_positions_tensor = output_index.unsqueeze(dim=-1)
curr_mask_tensor = mask_tensor.index_select(2,
input_positions_tensor)
output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(
device)
output_index = output_index + 1
# Detokenization.
token_ids = token_ids_tensor.tolist()
results = []
for i, tokens in enumerate(token_ids):
trimmed_output = tokens[len(prompt_tokens[i]):len(prompt_tokens[i])
+ output_len]
if self.tokenizer.eos_id in trimmed_output:
eos_index = trimmed_output.index(self.tokenizer.eos_id)
trimmed_output = trimmed_output[:eos_index]
results.append(self.tokenizer.decode(trimmed_output))
# If a string was provided as input, return a string as output.
return results[0] if is_str_prompt else results
以下面的图为例子进行讲解:
在推理的时候,先进行右侧padding,使长度一致。选择最短的长度同时进行处理,上图为1, 那么我们同时处理batch min(即1), 然后开始逐个token进行推理,怎么避免下图的形式呢
核心代码为下面内容:output_token_ids = torch.where(curr_prompt_mask, curr_token_ids, next_token_ids).unsqueeze(dim=1),主要的思路就是,如何当前的token为padding token 则填充上一步预测的token的结果,否则填充当前的token。
例如:
当前位置的token不为pading,则token还是3,不为2这个位置预测的token。
当前token为padding,则为2这个位置预测的token。
以上就是大模型采用right padding的方法。
总结
感觉pading left or right, 其实无所谓,主要就是为了方便。根据实际情况的具体需求,进行使用,用的正确,方便即可。