Novel Writing with AI-Writer

Time: 2024-10-23 07:26:21
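The complete generation script follows. Adjust RUN_DEVICE, the model paths, and the sampling parameters at the top before running; it imports src.model and src.utils, so it has to run inside the AI-Writer project, with the model files placed under model/.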
import numpy as np
import time
import os
import math, json
import torch
import torch.nn as nn
from torch.nn import functional as F

import src.utils
from src.model import GPT, GPTConfig

# src.utils.set_seed(42)  # uncomment to fix the random seed (each run then generates the same text)

RUN_DEVICE = 'gpu'  # 'gpu', 'dml', or 'cpu'
MODEL_NAME = 'model/wangwen-2022-02-15'  # model name
WORD_NAME = 'model/wangwen-2022-02-15'   # vocabulary name; change this together with the model

NUM_OF_RUNS = 10       # how many passes to write
LENGTH_OF_EACH = 1024  # how many characters per pass

# top_p ranges from 0 to 1. Larger values give more variety; smaller values give
# more conservative text. Try 0, 0.5 and 1.0 to feel the difference.
top_p = 0.75
# Token probabilities sum to 1 (the default threshold is 1.0). With top_p < 1,
# probabilities are accumulated from highest to lowest until the total reaches
# top_p, and only those top-N tokens are kept as sampling candidates. A separate,
# higher threshold is used right after a newline.
top_p_newline = 0.9

# The opening matters a great deal: it should set up a plot point, and the
# better it is written, the better the continuation. A sloppy opening produces
# a sloppy continuation. (The prompt stays in Chinese to match the model.)
context = "他在一片黑暗中醒来,没有光的世界一片寂静,感知不到躯体的意识随着黑暗的浪潮漂浮升空,又陡然坠落。时间在这里失去了意义,直到偶然一个瞬间,他想起了自己的名字。"

ctx_len = 512
n_layer = 12
n_head = 12
n_embd = n_head * 64
n_attn = n_embd
n_ffn = n_embd

# Normalize the prompt: strip surrounding whitespace and full-width spaces per line.
context = context.strip().split('\n')
for c in range(len(context)):
    context[c] = context[c].strip().strip('\u3000')
context = '\n' + ('\n'.join(context)).strip()
# "Your opening has N characters. Note: the model only sees the last ctx_len characters."
print('您输入的开头有 ' + str(len(context)) + ' 个字。注意,模型只会看最后 ' + str(ctx_len) + ' 个字。')

# Load the vocabulary and build the char <-> id mappings.
with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
    word_table = json.load(result_file)
vocab_size = len(word_table)

train_dataset = lambda: None  # cheap namespace object
train_dataset.stoi = {v: int(k) for k, v in word_table.items()}
train_dataset.itos = {int(k): v for k, v in word_table.items()}
UNKNOWN_CHAR = train_dataset.stoi['\ue083']

print(f'\nLoading model for {RUN_DEVICE}...', end=' ')
if RUN_DEVICE == 'dml':
    # DirectML path: run the exported ONNX model through onnxruntime.
    import onnxruntime as rt
    sess_options = rt.SessionOptions()
    sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
    sess_options.execution_mode = rt.ExecutionMode.ORT_SEQUENTIAL
    sess_options.enable_mem_pattern = False
    rt_session = rt.InferenceSession(MODEL_NAME + '.onnx', sess_options=sess_options,
                                     providers=['DmlExecutionProvider'])
    rt_session.set_providers(['DmlExecutionProvider'])
else:
    model = GPT(GPTConfig(vocab_size, ctx_len, n_layer=n_layer, n_head=n_head,
                          n_embd=n_embd, n_attn=n_attn, n_ffn=n_ffn))
    m2 = torch.load(MODEL_NAME + '.pth', map_location='cpu').state_dict()
    # Pre-compute each layer's time-weighting matrix time_ww from the trained
    # time_w / time_alpha / time_beta parameters, then drop the originals.
    for i in range(n_layer):
        prefix = f'blocks.{i}.attn.'
        time_w = m2[prefix + 'time_w']
        time_alpha = m2[prefix + 'time_alpha']
        time_beta = m2[prefix + 'time_beta']
        TT = ctx_len
        T = ctx_len
        w = F.pad(time_w, (0, TT))
        w = torch.tile(w, [TT])
        w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
        w = w[:, :, TT - 1:]
        w = w[:, :T, :T] * time_alpha[:, :, :T] * time_beta[:, :T, :]
        m2[prefix + 'time_ww'] = w
        del m2[prefix + 'time_w']
        del m2[prefix + 'time_alpha']
        del m2[prefix + 'time_beta']
    if RUN_DEVICE == 'gpu':
        model = model.cuda()
    model.load_state_dict(m2)

print('done:', MODEL_NAME, '&', WORD_NAME)

##############################################################################

# Create a timestamped folder to store the results.
dir_name = str(int(time.time()))
os.makedirs("result/" + dir_name)

# Write NUM_OF_RUNS passages.
for run in range(NUM_OF_RUNS):
    x = np.array([train_dataset.stoi.get(s, UNKNOWN_CHAR) for s in context], dtype=np.int64)
    real_len = len(x)
    print_begin = 0
    for i in range(LENGTH_OF_EACH):
        if i == 0:
            print(('-' * 60) + '\n' + context.replace('\n', '\n ').strip('\n'), end='')
            # Also write the printed text into a file in the output folder
            # (the file name means "pass N").
            write_content = context.replace('\n', '\n ').strip('\n')
            with open("result/" + dir_name + "/第 " + str(run + 1) + " 遍.txt", 'a+', encoding="utf8") as f:
                f.write(write_content)
            print_begin = real_len
        with torch.no_grad():
            if RUN_DEVICE == 'dml':
                if real_len < ctx_len:
                    xxx = np.pad(x, (0, ctx_len - real_len))
                else:
                    xxx = x
                out = rt_session.run(None, {rt_session.get_inputs()[0].name: [xxx[-ctx_len:]]})
                out = torch.tensor(out[0])
            else:
                xxx = torch.tensor(x[-ctx_len:], dtype=torch.long)[None, ...]
                if RUN_DEVICE == 'gpu':
                    xxx = xxx.cuda()
                out, _ = model(xxx)
            out[:, :, UNKNOWN_CHAR] = -float('Inf')  # never sample the unknown-char token
        pos = -1 if real_len >= ctx_len else real_len - 1
        # Right after a newline, sample with the separate top_p_newline threshold;
        # otherwise use the normal top_p.
        if train_dataset.itos[int(x[real_len - 1])] == '\n':
            char = src.utils.sample_logits(out, pos, temperature=1.0, top_p=top_p_newline)
        else:
            char = src.utils.sample_logits(out, pos, temperature=1.0, top_p=top_p)
        x = np.append(x, char)
        real_len += 1
        # Flush output in small batches to keep printing cheap on the GPU path.
        if i % 2 == 1 or i == LENGTH_OF_EACH - 1 or i < 10 or RUN_DEVICE != 'gpu':
            completion = ''.join([train_dataset.itos[int(j)] for j in x[print_begin:real_len]])
            print(completion.replace('\n', '\n '), end='', flush=True)
            write_content = completion.replace('\n', '\n ')
            with open("result/" + dir_name + "/第 " + str(run + 1) + " 遍.txt", 'a+', encoding="utf8") as f:
                f.write(write_content)
            print_begin = real_len
    print()
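A few notes on the script. The vocabulary JSON maps token ids (string keys) to characters, which is why the script inverts it into stoi/itos; the '\ue083' private-use character serves as the unknown token. A hypothetical 4-entry vocabulary in the same assumed shape:

import json

# toy stand-in for WORD_NAME + '.json' (assumed format, for illustration only)
word_table = json.loads('{"0": "\\n", "1": "他", "2": "黑", "3": "\\ue083"}')
stoi = {v: int(k) for k, v in word_table.items()}  # char -> id
itos = {int(k): v for k, v in word_table.items()}  # id -> char
print(stoi['他'], itos[2])  # prints: 1 黑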
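The time_w pre-computation in the model-loading branch looks opaque, but the pad/tile/reshape sequence is just a trick for building a causal Toeplitz matrix of relative-position weights. A standalone sketch with a toy size (TT = 4 instead of 512, and a single head; shapes assumed from the code above):

import torch
from torch.nn import functional as F

TT = 4                                       # toy context length (script uses ctx_len = 512)
time_w = torch.arange(1., TT + 1)[None, :]   # one "head" with weights [1, 2, 3, 4]

w = F.pad(time_w, (0, TT))                   # append TT zeros -> length 2*TT
w = torch.tile(w, [TT])                      # repeat the padded row TT times
w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)   # wrap-around misaligns each row by one step
w = w[:, :, TT - 1:]                         # keep the causal TT x TT part
print(w[0])
# tensor([[4., 0., 0., 0.],
#         [3., 4., 0., 0.],
#         [2., 3., 4., 0.],
#         [1., 2., 3., 4.]])
# Row i, column j holds time_w[TT - 1 - (i - j)] for j <= i and 0 otherwise:
# one learned weight per relative distance i - j.

The script then scales this matrix by time_alpha and time_beta and stores it as time_ww, so the per-distance weighting is paid for once at load time instead of on every forward pass.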
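Finally, the generation loop delegates token selection to src.utils.sample_logits. For reference, here is a minimal sketch of nucleus (top-p) sampling as described in the top_p comments above; the function name sample_top_p and the exact cutoff handling are my assumptions for illustration, and the real sample_logits in AI-Writer may differ in detail:

import torch
import torch.nn.functional as F

def sample_top_p(logits: torch.Tensor, temperature: float = 1.0, top_p: float = 0.75) -> int:
    # logits: 1-D tensor of size vocab_size for the position being sampled
    probs = F.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # keep the smallest prefix of tokens whose cumulative probability reaches top_p
    cutoff = int(torch.searchsorted(cumulative, top_p)) + 1
    sorted_probs[cutoff:] = 0.0
    sorted_probs = sorted_probs / sorted_probs.sum()  # renormalize the candidate pool
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return int(sorted_idx[choice])

# e.g. next_id = sample_top_p(out[0, pos], temperature=1.0, top_p=0.75)

With top_p = 0 this degenerates to greedy decoding (only the single most likely token survives the cutoff), and with top_p = 1.0 it samples from the full distribution, which is why the script's comment suggests comparing 0, 0.5 and 1.0.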