import numpy as np
import time
import os
import math, json
import torch
import torch.nn as nn
from torch.nn import functional as F
import src.utils
from src.model import GPT, GPTConfig
# ---------------------------------------------------------------------------
# User-editable settings for the text-generation script.
# ---------------------------------------------------------------------------

# NOTE(review): a commented-out `.set_seed(42)` call lived here; it is
# presumably `src.utils.set_seed(42)` — confirm before re-enabling. Fixing the
# seed makes every launch produce the same generated text.

# Model architecture. These are passed to GPTConfig when the model is built.
ctx_len = 512
n_head = 12
n_layer = 12
n_embd = 64 * n_head
n_attn = n_ffn = n_embd

# Backend and checkpoint selection.
RUN_DEVICE = 'gpu'  # one of: gpu, dml, cpu
MODEL_NAME = 'model/wangwen-2022-02-15'  # model path, without file extension
WORD_NAME = 'model/wangwen-2022-02-15'  # vocabulary path; change it together with MODEL_NAME

# Amount of text to produce.
NUM_OF_RUNS = 10  # how many independent continuations to write
LENGTH_OF_EACH = 1024  # how many characters to generate per continuation

# Sampling thresholds, each in the range 0 to 1. A larger value gives more
# varied output and a smaller value gives more conventional output; try 0, 0.5
# and 1.0 to get a feel for the effect.
# The probabilities of all candidate tokens sum to 1 (so 1.0 keeps them all).
# When the threshold is below 1, probabilities are accumulated from the most
# likely token downwards until the threshold is reached, and only those top-N
# tokens remain candidates.
top_p = 0.75
top_p_newline = 0.9  # threshold meant for the character right after a newline

# The opening matters a great deal: it should set up a plot hook. The better
# the opening is written, the better the continuation; a careless opening
# yields a careless continuation.
context = "他在一片黑暗中醒来,没有光的世界一片寂静,感知不到躯体的意识随着黑暗的浪潮漂浮升空,又陡然坠落。时间在这里失去了意义,直到偶然一个瞬间,他想起了自己的名字。"
# Normalise the prompt line by line: trim ordinary whitespace and ideographic
# spaces (U+3000) from both ends of each line, rejoin the lines, and prefix the
# result with a single newline.
context = '\n' + '\n'.join(
    line.strip().strip('\u3000') for line in context.strip().split('\n')
).strip()
# Tell the user how long the prompt is and how much of it the model can see.
print('您输入的开头有 ' + str(len(context)) + ' 个字。注意,模型只会看最后 ' + str(ctx_len) + ' 个字。')
# Read the vocabulary file: a UTF-16 JSON object whose keys are token ids
# (as strings) and whose values are the corresponding tokens.
with open(WORD_NAME + '.json', "r", encoding="utf-16") as vocab_file:
    word_table = json.load(vocab_file)
vocab_size = len(word_table)

# A bare function object is used purely as an attribute holder for the two
# lookup tables: token -> id (stoi) and id -> token (itos).
train_dataset = lambda: None
train_dataset.stoi = dict((token, int(idx)) for idx, token in word_table.items())
train_dataset.itos = dict((int(idx), token) for idx, token in word_table.items())

# Id of the placeholder token U+E083; raises KeyError if the vocabulary lacks it.
UNKNOWN_CHAR = train_dataset.stoi['\ue083']
# Load the model for the selected backend:
#   'dml'         -> ONNX Runtime session (MODEL_NAME + '.onnx') on DirectML
#   'gpu' / 'cpu' -> PyTorch GPT built from MODEL_NAME + '.pth'
print(f'\nLoading model for {RUN_DEVICE}...', end=' ')
if RUN_DEVICE == 'dml':
    # Imported lazily so the PyTorch backends do not need onnxruntime installed.
    import onnxruntime as rt

    sess_options = rt.SessionOptions()
    sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
    # The DirectML execution provider does not support parallel execution or
    # memory-pattern optimisation, so both are configured explicitly.
    sess_options.execution_mode = rt.ExecutionMode.ORT_SEQUENTIAL
    sess_options.enable_mem_pattern = False
    rt_session = rt.InferenceSession(
        MODEL_NAME + '.onnx',
        sess_options=sess_options,
        providers=['DmlExecutionProvider'],
    )
else:
    model = GPT(
        GPTConfig(vocab_size, ctx_len, n_layer=n_layer, n_head=n_head,
                  n_embd=n_embd, n_attn=n_attn, n_ffn=n_ffn))

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    # The checkpoint stores a whole module, hence the .state_dict() call.
    m2 = torch.load(MODEL_NAME + '.pth', map_location='cpu').state_dict()

    # For every attention block, replace the three stored time-mixing tensors
    # (time_w, time_alpha, time_beta) with one precomputed tensor, time_ww.
    for i in range(n_layer):
        prefix = f'blocks.{i}.attn.'
        time_w = m2[prefix + 'time_w']
        time_alpha = m2[prefix + 'time_alpha']
        time_beta = m2[prefix + 'time_beta']

        TT = ctx_len
        T = ctx_len
        # Pad / tile / reshape trick: assuming the last dimension of time_w is
        # ctx_len (TODO confirm), each row of the reshaped tensor is the
        # previous row shifted by one position, so the kept TT x TT slice
        # depends only on the offset between row and column.
        w = F.pad(time_w, (0, TT))
        w = torch.tile(w, [TT])
        w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
        w = w[:, :, TT - 1:]
        w = w[:, :T, :T] * time_alpha[:, :, :T] * time_beta[:, :T, :]

        m2[prefix + 'time_ww'] = w
        del m2[prefix + 'time_w']
        del m2[prefix + 'time_alpha']
        del m2[prefix + 'time_beta']

    # Install the converted weights; without this the model would run with
    # its freshly initialised parameters.
    model.load_state_dict(m2)
    if RUN_DEVICE == 'gpu':
        model = model.cuda()
print('done:', MODEL_NAME, '&', WORD_NAME)
# Create a per-launch output folder under result/, named after the current
# Unix timestamp in whole seconds.
dir_name = str(int(time.time()))
# exist_ok=True keeps a second launch within the same second from crashing
# with FileExistsError.
os.makedirs("result/" + dir_name, exist_ok=True)
def _append_text(path, text):
    """Append *text* to the UTF-8 file at *path*, creating the file if needed."""
    with open(path, 'a+', encoding="utf8") as f:
        f.write(text)


# Write NUM_OF_RUNS independent continuations of `context`, each
# LENGTH_OF_EACH generated tokens long. Every continuation is echoed to the
# console and appended to its own file inside result/<dir_name>/.
for run in range(NUM_OF_RUNS):
    # Encode the prompt; characters missing from the vocabulary map to UNKNOWN_CHAR.
    x = np.array([train_dataset.stoi.get(s, UNKNOWN_CHAR) for s in context], dtype=np.int64)
    real_len = len(x)
    print_begin = 0
    result_path = "result/" + dir_name + "/第 " + str(run + 1) + " 遍.txt"

    for i in range(LENGTH_OF_EACH):
        if i == 0:
            # Emit the prompt itself once, both to the console and to the file.
            write_content = context.replace('\n', '\n ').strip('\n')
            print(('-' * 60) + '\n' + write_content, end='')
            _append_text(result_path, write_content)
            print_begin = real_len

        with torch.no_grad():
            if RUN_DEVICE == 'dml':
                # The ONNX session is fed exactly ctx_len ids: right-pad a
                # short sequence with zeros, otherwise keep the last ctx_len.
                if real_len < ctx_len:
                    xxx = np.pad(x, (0, ctx_len - real_len))
                else:
                    xxx = x
                out = rt_session.run(None, {rt_session.get_inputs()[0].name: [xxx[-ctx_len:]]})
                out = torch.tensor(out[0])
            else:
                # PyTorch path: the last ctx_len ids with a leading batch dimension.
                xxx = torch.tensor(x[-ctx_len:], dtype=torch.long)[None, ...]
                if RUN_DEVICE == 'gpu':
                    xxx = xxx.cuda()
                out, _ = model(xxx)
            # Never sample the unknown-character placeholder.
            out[:, :, UNKNOWN_CHAR] = -float('Inf')

        # Position of the newest real token inside the window: the last slot
        # once the window is full, otherwise real_len - 1.
        pos = -1 if real_len >= ctx_len else real_len - 1

        # Use the dedicated threshold when the previous character is a newline.
        if train_dataset.itos[int(x[real_len - 1])] == '\n':
            char = src.utils.sample_logits(out, pos, temperature=1.0, top_p=top_p_newline)
        else:
            char = src.utils.sample_logits(out, pos, temperature=1.0, top_p=top_p)

        x = np.append(x, char)
        real_len += 1

        # On the gpu backend, output is emitted on every step for the first
        # ten steps and on the last step, and on every other step in between;
        # the other backends emit on every step.
        if i % 2 == 1 or i == LENGTH_OF_EACH - 1 or i < 10 or RUN_DEVICE != 'gpu':
            completion = ''.join([train_dataset.itos[int(tok)] for tok in x[print_begin:real_len]])
            write_content = completion.replace('\n', '\n ')
            print(write_content, end='', flush=True)
            _append_text(result_path, write_content)
            print_begin = real_len