With all the components above in place, we can start training. First, define the model structure; here we use a three-layer configuration.
```python
from config import CustomConfig
from modeling_custom import CustomForCausalLM

config = CustomConfig(
    vocab_size=len(tokenizer.get_vocab()),
    max_position_embeddings=2048,
    hidden_size=4096,
    intermediate_size=16384,
    num_hidden_layers=3,
    pad_token_id=tokenizer.pad_token_id,
)
model = CustomForCausalLM(config)
```
Here is a simple helper to count the number of model parameters:
```python
import torch.nn as nn


def get_model_size(model: nn.Module):
    """
    Return the total number of model parameters.
    """
    return sum(p.numel() for p in model.parameters())
```
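A quick usage check (the exact count depends on details of `CustomForCausalLM`, such as whether the embedding and output head share weights, but with this config it should land near the 1.2B figure used in the memory estimate below):

```python
model_size = get_model_size(model)
print(f"Model size: {model_size:,} parameters")  # roughly 1.2e9 with this configuration
```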
Once the model is ready we can train it. Below is a minimal training loop:
```python
for epoch in range(args.epochs):
    for idx, batch in enumerate(data_loader):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        logits, loss = outputs
        # Backpropagation
        loss.backward()
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=args.max_norm
        )
        # Parameter update
        optimizer.step()
        # Learning-rate update
        scheduler.step()
        # Clear gradients
        optimizer.zero_grad()
```
For a 1.2B-parameter model, the model weights, optimizer states, and gradients together take roughly the following amount of memory (one FP32 copy each for the weights and the gradients, plus two more for AdamW's first and second moment estimates, which is where the factor of 4 comes from):

$$\frac{1.2 \times 10^9 \times 4\ \text{bytes (FP32)}}{1024^3} \times 4 \approx 17.88\ \text{GB}$$
Now estimate the memory taken by intermediate activations. Assume a batch size of 16 and a padded sequence length of 512 for this batch, so `input_ids` has shape (16, 512).
- After the embedding layer the shape becomes (16, 512, 4096)
- The attention layer projects to Q, K, and V, each of shape (16, 512, 4096)
- The attention output projection again produces (16, 512, 4096)
- The up projection and the gate projection in the feed-forward network each produce (16, 512, 16384)
- The down projection produces (16, 512, 4096)
- The projection to the vocabulary size (about 57,000 here) produces (16, 512, 57000)
The embedding output is about 32M elements, one attention layer produces about 128M, one feed-forward network about 288M, and the final vocabulary projection about 445M. With 3 decoder layers, and ignoring layer-norm intermediates, the model holds roughly 1725M intermediate values in total; at 4 bytes each that is about 6.7 GB of memory.

That is for a sequence length of 512. In practice my training corpus contains many texts around 2k tokens, which multiplies the activation memory accordingly: for a 2k-token batch the estimate grows to about 26.8 GB, as the quick calculation below shows.
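To make the arithmetic reproducible, here is a small helper that reuses the per-layer shapes listed above (the 57,000 vocabulary size is the approximate figure mentioned earlier):

```python
# Rough activation count for one forward pass (elements, not bytes),
# following the per-layer shapes listed above.
def activation_elements(batch_size, seq_len, hidden=4096, inter=16384,
                        vocab=57000, layers=3):
    embed = batch_size * seq_len * hidden                    # embedding output
    attn = 4 * batch_size * seq_len * hidden                 # Q, K, V + output projection
    ffn = 2 * batch_size * seq_len * inter + batch_size * seq_len * hidden  # up, gate, down
    lm_head = batch_size * seq_len * vocab                   # vocabulary projection
    return embed + layers * (attn + ffn) + lm_head


for seq_len in (512, 2048):
    gb = activation_elements(16, seq_len) * 4 / 1024**3
    print(f"seq_len={seq_len}: ~{gb:.1f} GB of FP32 activations")
```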
And that is the best case: real training allocates various other buffers on top of this, so GPU memory quickly overflows and training becomes impossible. Fortunately, gradient checkpointing was built into the model implementation: only the intermediate results at key nodes are kept, and during backpropagation everything else is recomputed from the nearest checkpoint, which saves a large amount of memory.

In this model, calling `model.enable_gradient_checkpoint()` is all it takes to turn gradient checkpointing on.
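For intuition, here is a minimal sketch of what such a switch can look like. It is not the actual modeling_custom.py implementation, and the `TinyDecoderStack` class below is made up for illustration: each decoder layer is wrapped with `torch.utils.checkpoint`, so only the layer inputs are stored and the inner activations are recomputed during the backward pass (`use_reentrant=False` requires a reasonably recent PyTorch).

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyDecoderStack(nn.Module):
    """Illustrative stand-in for the real decoder stack."""

    def __init__(self, num_layers: int = 3, hidden_size: int = 4096):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.TransformerEncoderLayer(hidden_size, nhead=32, batch_first=True)
             for _ in range(num_layers)]
        )
        self.gradient_checkpointing = False

    def enable_gradient_checkpoint(self):
        self.gradient_checkpointing = True

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Keep only the layer input; recompute inner activations in backward.
                hidden_states = checkpoint(layer, hidden_states, use_reentrant=False)
            else:
                hidden_states = layer(hidden_states)
        return hidden_states
```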
Besides gradient checkpointing, we can reduce the batch size to shrink the activation memory, but a batch that is too small can make the loss oscillate and fail to converge. Gradient accumulation solves this: after backpropagating the gradients of one mini-batch, we neither update the weights nor clear the gradients; instead we accumulate over several mini-batches, then perform a single update and clear the gradients, as sketched below.
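A minimal sketch of the idea (the full script later also handles the leftover mini-batches at the end of an epoch and mixed precision; `accumulation_steps` is just an illustrative value):

```python
accumulation_steps = 8  # illustrative value

for idx, batch in enumerate(data_loader):
    batch = {k: v.to(device) for k, v in batch.items()}
    logits, loss = model(**batch)
    # Scale the loss so the accumulated gradient equals a large-batch average.
    (loss / accumulation_steps).backward()
    if (idx + 1) % accumulation_steps == 0:
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
```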
Finally, we can use mixed-precision training, which not only speeds up training but also significantly reduces the memory taken by intermediate activations.

With these strategies in place, we can train the model. To make the configuration easy to modify and to log information for monitoring the whole run, the pieces above are wrapped up; the final code is given below.
```python
import json
import os
import random
from dataclasses import dataclass
from typing import Optional, Union

import numpy as np
import torch
import torch.nn as nn
from datasets import Dataset
from torch.cuda.amp import GradScaler, autocast
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, SequentialLR
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling

from config import CustomConfig
from modeling_custom import CustomForCausalLM
from tokenization_custom import CustomTokenizer
from utils import get_model_size

SEED = 42


def set_seed(seed: int):
    torch.manual_seed(seed=seed)
    torch.cuda.manual_seed(seed=seed)
    torch.cuda.manual_seed_all(seed=seed)
    np.random.seed(seed=seed)
    random.seed(seed)


def get_lr_warmup(warmup_steps: int):
    def lr_warmup(current_step: int):
        return float(current_step) / float(max(1, warmup_steps))

    return lr_warmup


@dataclass
class TrainingArgs:
    output_dir: str
    logging_steps: int = 500
    saving_steps: int = 500
    batch_size: int = 1
    epochs: int = 3
    lr: float = 1e-4
    weight_decay: float = 1e-4
    max_norm: float = 1.0
    warm_up_ratio: float = 0.1
    gradient_checkpointing: bool = False
    gradient_accumulation_steps: int = 24


def train(
    model: nn.Module,
    args: TrainingArgs,
    dataset: Dataset,
    device: Optional[Union[str, torch.device]] = None,
    data_collator=None,
):
    data_loader = DataLoader(
        dataset=dataset,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=data_collator,
        num_workers=8,
    )
    # Number of complete effective (optimizer) steps per epoch
    complete_steps_per_epoch = len(data_loader) // args.gradient_accumulation_steps
    # Leftover mini-batches that form one final, smaller effective step
    last_mini_steps = len(data_loader) % args.gradient_accumulation_steps
    # Effective steps per epoch
    if last_mini_steps != 0:
        steps_per_epoch = complete_steps_per_epoch + 1
    else:
        steps_per_epoch = complete_steps_per_epoch
    total_steps = steps_per_epoch * args.epochs
    # Optimizer
    optimizer = AdamW(
        params=model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    # Learning-rate schedule: linear warmup followed by cosine annealing
    warmup_steps = int(total_steps * args.warm_up_ratio)
    cosine_steps = total_steps - warmup_steps
    warmup_scheduler = LambdaLR(
        optimizer=optimizer, lr_lambda=get_lr_warmup(warmup_steps=warmup_steps)
    )
    cosine_scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=cosine_steps)
    scheduler = SequentialLR(
        optimizer=optimizer,
        schedulers=[warmup_scheduler, cosine_scheduler],
        milestones=[warmup_steps],
    )
    # Device
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    os.makedirs(args.output_dir, exist_ok=True)
    model = model.to(device=device)
    if args.gradient_checkpointing:
        model.enable_gradient_checkpoint()
    logging_info = []
    current_step = 0
    progress_bar = tqdm(range(total_steps))
    scaler = GradScaler()
    for epoch in range(args.epochs):
        current_loss = 0.0
        for idx, batch in enumerate(data_loader):
            batch = {k: v.to(device) for k, v in batch.items()}
            # Divide the loss by the size of the accumulation group this
            # mini-batch belongs to: the full accumulation window, or the
            # smaller leftover group at the end of the epoch.
            if last_mini_steps == 0 or len(data_loader) - (idx + 1) >= last_mini_steps:
                current_accumulation = args.gradient_accumulation_steps
            else:
                current_accumulation = last_mini_steps
            with autocast(dtype=torch.bfloat16):
                outputs = model(**batch)
                logits, loss = outputs
                loss /= current_accumulation
            current_loss += loss.item()
            # Backpropagation (with loss scaling)
            scaler.scale(loss).backward()
            if (idx + 1) % args.gradient_accumulation_steps == 0 or (idx + 1) == len(
                data_loader
            ):
                # Gradient clipping (unscale first)
                scaler.unscale_(optimizer=optimizer)
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=args.max_norm
                )
                # Parameter update
                scaler.step(optimizer=optimizer)
                # Update the loss-scaling factor
                scaler.update()
                # Learning-rate update
                scheduler.step()
                # Clear gradients
                optimizer.zero_grad()
                progress_bar.update(1)
                current_step += 1
                if current_step % args.logging_steps == 0:
                    current_epochs = current_step / steps_per_epoch
                    info = {
                        "Epoch": f"{current_epochs:.2f}/{args.epochs}",
                        "Step": f"{current_step}/{total_steps}",
                        "Loss": current_loss,
                        "LR": scheduler.get_last_lr()[0],
                    }
                    logging_info.append(info)
                    print(info)
                if current_step % args.saving_steps == 0:
                    ckpt_path = os.path.join(
                        args.output_dir,
                        f"checkpoint-{current_step}.pt",
                    )
                    torch.save(model.state_dict(), ckpt_path)
                current_loss = 0.0
    ckpt_path = os.path.join(
        args.output_dir,
        "last.pt",
    )
    torch.save(model.state_dict(), ckpt_path)
    with open("logging.jsonl", "w", encoding="utf-8") as fw:
        for logging_data in logging_info:
            fw.write(json.dumps(logging_data) + "\n")


if __name__ == "__main__":
    set_seed(SEED)
    tokenizer = CustomTokenizer.from_pretrained("tokenizer")
    config = CustomConfig(
        vocab_size=len(tokenizer.get_vocab()),
        max_position_embeddings=2048,
        hidden_size=4096,
        intermediate_size=16384,
        num_hidden_layers=3,
        pad_token_id=tokenizer.pad_token_id,
    )
    model = CustomForCausalLM(config)
    print(f"Model size is {get_model_size(model)}")
    dataset = Dataset.load_from_disk("nlp_datas/cached")
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    args = TrainingArgs(
        output_dir="result",
        gradient_checkpointing=True,
        batch_size=4,
        logging_steps=50,
        warm_up_ratio=0.03,
        epochs=1,
        gradient_accumulation_steps=8,
        lr=1e-3,
        weight_decay=1e-5,
    )
    train(model=model, args=args, dataset=dataset, data_collator=data_collator)
```
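After training finishes, the saved weights can be loaded back for evaluation or generation. A small sketch, assuming the same config as above (`result/last.pt` is the path the script writes at the end):

```python
model = CustomForCausalLM(config)
state_dict = torch.load("result/last.pt", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
```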