NLP - AI 写诗

时间:2024-05-08 18:12:41

文章目录

    • 代码
      • 测试已训练好的模型
      • Tokenize
      • 加载数据集
      • 训练


代码


测试已训练好的模型

由于模型较大,up主上传时使用了分卷压缩,这里我们先合并解压

# 将四个文件合并为1个
$ cat save.zip.00* > save.zip

# 解压;这一步也可以手动
$ unzip save.zip

import torch

model_path = 'xx/Chinese_Poetry_Generate-main/save.model'
model = torch.load(model_path)

generate('秋', row=4, col=7)

然后根据需要,一步步添加下方的方法。

这里我得到的效果为:


> generate('秋', row=4, col=7)
0 [CLS] 秋 树 摇 风 满 远 林 , 登 临 此 兴 可 同 吟 。 云 生 半 夜 天 风 捲 , 月 过 东 篱 露 气 深 。
1 [CLS] 秋 阴 凝 未 尽 人 意 , 老 去 人 间 已 非 客 。 一 叶 舟 中 无 所 依 , 秋 风 满 前 山 色 黑 。
2 [CLS] 秋 霜 萧 萧 一 叶 秋 , 秋 风 萧 萧 吹 远 愁 。 高 堂 夜 静 月 照 影 , 独 立 空 台 愁 白 头 。

generate('莫', row=4, col=5)
0 [CLS] 莫 把 年 华 念 , 吾 衰 鬓 未 斑 。 病 身 今 老 矣 , 衰 病 旧 依 然 。
1 [CLS] 莫 信 江 山 胜 , 还 因 古 塞 名 。 云 收 天 半 暮 , 霜 坠 雁 馀 清 。
2 [CLS] 莫 辞 寒 食 过 , 为 惜 雨 晴 时 。 风 景 今 犹 昨 , 江 山 几 更 非 。

Tokenize

from transformers import AutoTokenizer

#加载编码器
tokenizer = AutoTokenizer.from_pretrained('uer/gpt2-chinese-cluecorpussmall')

print(tokenizer)

#编码试算
tokenizer.batch_encode_plus([
    '欲出未出光辣达,千山万山如火发.须臾走向天上来,逐却残星赶却月.',
    '满目*四望幽,白云高卷嶂烟收.日回禽影穿疏木,风递猿声入小楼.远岫似屏横碧落,断帆如叶截中流.'
])

PreTrainedTokenizerFast(name_or_path='uer/gpt2-chinese-cluecorpussmall', vocab_size=21128, model_max_len=1024, is_fast=True, padding_side='right', truncation_side='right', 
special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

{'input_ids': [[101, 3617, 1139, 3313, 1139, 1045, 6793, 6809, 117, 1283, 2255, 674, 2255, 1963, 4125, 1355, 119, 7557, 5640, 6624, 1403, 1921, 677, 3341, 117, 6852, 1316, 3655, 3215, 6628, 1316, 3299, 119, 102], [101, 4007, 4680, 3736, 2255, 1724, 3307, 2406, 117, 4635, 756, 7770, 1318, 2322, 4170, 3119, 119, 3189, 1726, 4896, 2512, 4959, 4541, 3312, 117, 7599, 6853, 4351, 1898, 1057, 2207, 3517, 119, 6823, 2273, 849, 2242, 3566, 4819, 5862, 117, 3171, 2359, 1963, 1383, 2779, 704, 3837, 119, 102]], 
'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 
'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}

加载数据集

import torch


#简单数据集
class Dataset(torch.utils.data.Dataset):

    def __init__(self):
        with open('chinese_poems.txt') as f:
            lines = f.readlines()
        lines = [i.strip() for i in lines]

        self.lines = lines

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, i):
        return self.lines[i]


dataset = Dataset()

len(dataset), dataset[0]

(304752, '欲出未出光辣达,千山万山如火发.须臾走向天上来,逐却残星赶却月.')

import torch
import os
import pandas as pd


#更多数据数据集
class Dataset(torch.utils.data.Dataset):

    def __init__(self):
        data = []
        for i in os.listdir('more_datas'):
            if i == '.ipynb_checkpoints':
                continue
            data.append(pd.read_csv('more_datas/%s' % i))

        data = pd.concat(data).reset_index()

        data = data['内容']

        data = data.str.strip()

        #移除一些标点符号
        data = data.str.replace('[《》“”「」]', '', regex=True)

        #正则过滤
        select = data.str.match('^[\w,。?、!:;]+$', na=False)
        data = data[select]

        #标点符号合并
        data = data.str.replace('[?!;]', '。', regex=True)
        data = data.str.replace('[、:]', ',', regex=True)

        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data.iloc[i]


dataset = Dataset()

len(dataset), dataset[0]

(839587,
 '不饮强须饮,今日是重阳。向来健者安在,世事两茫茫。叔子去人远矣,正复何关人事,堕泪忽成行。叔子泪自堕,湮没使人伤。燕何归,鸿欲断,蝶休忙。渊明自无可奈,冷眼菊花黄。看取龙山落日,又见骑台荒草,谁弱复谁强。酒亦有何好,暂醉得相忘。')

def collate_fn(data):
    data = tokenizer.batch_encode_plus(data,
                                       padding=True,
                                       truncation=True,
                                       max_length=512,
                                       return_tensors='pt')

    data['labels'] = data['input_ids'].clone()

    return data


#数据加载器
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=8,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
)

for i, data in enumerate(loader):
    break

for k, v in data.items():
    print(k, v.shape)

len(loader)

input_ids torch.Size([8, 130])
token_type_ids torch.Size([8, 130])
attention_mask torch.Size([8, 130])
labels torch.Size([8, 130])
104948

from transformers import AutoModelForCausalLM, GPT2Model

#加载模型
model = AutoModelForCausalLM.from_pretrained(
    'uer/gpt2-chinese-cluecorpussmall')

#统计参数量
print(sum(i.numel() for i in model.parameters()) / 10000)

with torch.no_grad():
    out = model(**data)

out['loss'], out['logits'].shape

10206.8736

(tensor(9.5407), torch.Size([8, 130, 21128]))

def generate(text, row, col):

    def generate_loop(data):
        with torch.no_grad():
            out = model(**data)

        #取最后一个字
        #[5, b, 50257]
        out = out['logits']
        #[5, 50257]
        out = out[:, -1]

        #第50大的值,以此为分界线,小于该值的全部赋值为负无穷
        #[5, 50257] -> [5, 50]
        topk_value = torch.topk(out, 50).values
        #[5, 50] -> [5] -> [5, 1]
        topk_value = topk_value[:, -1].unsqueeze(dim=1)

        #赋值
        #[5, 50257]
        out = out.masked_fill(out < topk_value, -float('inf'))

        #不允许写特殊符号
        out[:, tokenizer.sep_token_id] = -float('inf')
        out[:, tokenizer.unk_token_id] = -float('inf')
        out[:, tokenizer.pad_token_id] = -float('inf')
        for i in ',。':
            out[:, tokenizer.get_vocab()[i]] = -float('inf')

        #根据概率采样,无放回,所以不可能重复
        #[5, 50257] -> [5, 1]
        out = out.softmax(dim=1)
        out = out.multinomial(num_samples=1)

        #强制添加标点符号
        c = data['input_ids'].shape[1] / (col + 1)
        if c % 1 == 0:
            if c % 2 == 0:
                out[:, 0] = tokenizer.get_vocab()['。']
            else:
                out[:, 0] = tokenizer.get_vocab()[',']

        data['input_ids'] = torch.cat([data['input_ids'], out], dim=1)
        data['attention_mask'] = torch.ones_like(data['input_ids'])
        data['token_type_ids'] = torch.zeros_like(data['input_ids'])
        data['labels'] = data['input_ids'].clone()

        if data['input_ids'].shape[1] >= row * col + row + 1:
            return data

        return generate_loop(data)

    #重复3遍
    data = tokenizer.batch_encode_plus([text] * 3, return_tensors='pt')
    data['input_ids'] = data['input_ids'][:, :-1]
    data['attention_mask'] = torch.ones_like(data['input_ids'])
    data['token_type_ids'] = torch.zeros_like(data['input_ids'])
    data['labels'] = data['input_ids'].clone()

    data = generate_loop(data)

    for i in range(3):
        print(i, tokenizer.decode(data['input_ids'][i]))


generate('秋高气爽', row=4, col=5)

训练

from transformers import AdamW
from transformers.optimization import get_scheduler


#训练
def train():
    global model
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    optimizer = AdamW(model.parameters(), lr=5e-5)
    scheduler = get_scheduler(name='linear',
                              num_warmup_steps=0,
                              num_training_steps=len(loader),
                              optimizer=optimizer)

    model.train()
    for i, data in enumerate(loader):
        for k in data.keys():
            data[k] = data[k].to(device)
        out = model(**data)
        loss = out['loss']

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters