Minillama3 -> DPO Training

Time: 2025-01-02 07:02:21
import sys
sys.path.append("/home/image_team/image_team_docker_home/lgd/e_commerce_llm/minillama3/")
import os
import re
import ujson
import torch
import numpy as np
from rich import progress
import pyarrow.parquet as pq
from tools.raw_data_process import delete_file
from config import ModelConfig, DPOPhiConfig, DPOQwenConfig, DPOMinillama3Config
from model.llm_model import PhiHandler, Minillama3Handler, QwenHandler
from transformers import GenerationConfig

torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
handler_dict = {
    "Phi": PhiHandler,
    "Minillama3": Minillama3Handler,
    "Qwen": QwenHandler
}
# The selector variable was lost in the source; `model_name` here is a
# reconstructed stand-in. Set it to one of "Phi" / "Minillama3" / "Qwen".
model_name = "Minillama3"
if model_name == 'Qwen':
    config = DPOQwenConfig()
elif model_name == "Phi":
    config = DPOPhiConfig()
elif model_name == "Minillama3":
    config = DPOMinillama3Config()
else:
    raise TypeError("only qwen/phi/minillama3 are supported!")
HandlerClass = handler_dict.get(model_name)
handler = HandlerClass(config)
# step 1. load the tokenizer
tokenizer = handler.load_tokenizer()
# step 2. initialize the SFT model
model = handler.get_sft_model().to(device)
gen_config = GenerationConfig(
    temperature=0.3,
    top_k=20,
    top_p=0.5,
    do_sample=True,
    num_beams=1,
    repetition_penalty=1.1,
    max_new_tokens=512,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)
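A note on the decoding settings: sampling with a low temperature and tight top_k/top_p keeps the SFT model's output fluent but usually weaker than the GPT-4 reference answer, which is exactly what the rejected side of a preference pair should look like. Below is a minimal, self-contained sketch of the same pattern with a generic Hugging Face causal LM; the checkpoint name is illustrative, not the project's model.

# --- illustrative sketch (not part of the script): GenerationConfig in isolation ---
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tok = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")   # illustrative checkpoint
lm = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")
cfg = GenerationConfig(temperature=0.3, top_k=20, top_p=0.5, do_sample=True,
                       max_new_tokens=64, eos_token_id=tok.eos_token_id,
                       pad_token_id=tok.eos_token_id)
ids = tok("1+1等于几?", return_tensors="pt")
out = lm.generate(**ids, generation_config=cfg)
print(tok.decode(out[0], skip_special_tokens=True))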
def process_alpaca_gpt4_data(max_len: int = 512) -> None:
    '''
    Extract the high-quality (chosen) answers.
    Dataset: /datasets/c-s-ale/alpaca-gpt4-data-zh
    '''
    read_file = PROJECT_ROOT + '/alpaca_gpt4_data_zh.json'
    save_file = PROJECT_ROOT + '/alpaca_gpt4_data_dop_chosen_zh.json'

    max_len += 8

    my_data = []
    with open(read_file, 'r', encoding='utf-8') as f:
        data = ujson.load(f)
        print('length of {} is {}'.format(read_file, len(data)))

        for item in progress.track(data):
            prompt = item['instruction']
            inputs = item['input']
            response = item['output']

            if len(response) > max_len: continue  # drop over-length responses

            if len(inputs.strip()) > 0:
                prompt = f"{prompt}{inputs}"

            if len(prompt) > max_len: continue
            if len(prompt) == 0 or len(response) == 0: continue

            my_data.append(
                {
                    'prompt': prompt,
                    'chosen': response
                }
            )

    print('length of {} is {}'.format(save_file, len(my_data)))
    with open(save_file, 'w', encoding='utf-8') as f:
        ujson.dump(my_data, f, indent=4, ensure_ascii=False)
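For reference, this is the transformation applied to each alpaca-gpt4 record; the field values below are invented for illustration.

# --- illustrative record shapes (values invented) ---
alpaca_item = {
    "instruction": "把下面的句子翻译成英文:",
    "input": "今天天气很好。",
    "output": "The weather is nice today."
}
# instruction and input are concatenated into the prompt,
# and the GPT-4 answer becomes the chosen response:
chosen_item = {
    "prompt": "把下面的句子翻译成英文:今天天气很好。",
    "chosen": "The weather is nice today."
}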
def generate_alpaca_gpt4_reject_response(groups_cnt: int = 50000, max_len: int = 320, batch_size: int = 32) -> None:
    '''Generate the less-satisfactory (rejected) answers with the SFT model.
    '''
    print('load model...')

    finetune_file = PROJECT_ROOT + '/alpaca_gpt4_data_dop_chosen_zh.json'
    save_rw_json_file = PROJECT_ROOT + '/alpaca_gpt4_data_dpo_zh.json'
    # save_rw_parquet_file = PROJECT_ROOT + '/data/my_rlhf_dataset.parquet'

    data = []
    with open(finetune_file, 'r', encoding='utf-8') as f:
        data = ujson.load(f)
    print('length of {} is {}'.format(save_rw_json_file, len(data)))

    model_outs = []
    batch_prompt = []
    process_item = []
    for i, item in progress.track(enumerate(data), total=len(data)):
        # answers generated by the SFT model serve as the rejected answers
        batch_prompt.append(f"{item['prompt']}[EOS]")
        process_item.append(item)

        if i % 500 == 0:
            print('process {} items.'.format(i))

        if len(batch_prompt) >= batch_size or i == len(data) - 1:
            encoded = tokenizer.batch_encode_plus(batch_prompt, truncation=False, padding=True)
            with torch.no_grad():
                input_ids = torch.LongTensor(encoded.input_ids).to(device)
                attention_mask = torch.LongTensor(encoded.attention_mask).to(device)
                outputs = model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    # max_seq_len=512,
                    # search_type='greedy',
                    generation_config=gen_config
                )
            outputs = tokenizer.batch_decode(outputs.cpu().numpy(), clean_up_tokenization_spaces=True,
                                             skip_special_tokens=True)
            model_outs.extend(outputs)
            batch_prompt = []

        # periodically checkpoint what has been generated so far; a separate
        # index `j` is used so the outer loop variable `i` is not clobbered
        if len(model_outs) > 0 and len(model_outs) % 2000 == 0:
            for j in range(len(model_outs)):
                process_item[j]['reject'] = model_outs[j]
            try:
                # checkpoint file name is missing in the source; path left as-is
                with open(PROJECT_ROOT + '/', 'w', encoding='utf-8') as f:
                    ujson.dump(process_item, f, indent=4, ensure_ascii=False)
            except Exception as e:
                print(e)

    for j in range(len(model_outs)):
        process_item[j]['reject'] = model_outs[j]

    with open(save_rw_json_file, 'w', encoding='utf-8') as f:
        ujson.dump(process_item, f, indent=4, ensure_ascii=False)

    # df = pd.DataFrame(data)
    # write_single_parquet_file(save_rw_parquet_file, df)
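Each record now carries both answers: the GPT-4 reference stays as 'chosen', and the SFT model's own sampled output becomes 'reject'. Example values below are invented.

# --- shape of a record in alpaca_gpt4_data_dpo_zh.json (values invented) ---
dpo_item = {
    "prompt": "把下面的句子翻译成英文:今天天气很好。",
    "chosen": "The weather is nice today.",   # GPT-4 reference answer
    "reject": "Today weather very good."      # weaker answer from the SFT model
}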
def replace_line(s: str) -> str:
    '''Replace the escaped sequence with the real character, i.e. turn \\n into \n.
    '''
    return re.sub('\\\\n', '\n', s)
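A quick check of what replace_line does: the pattern matches the two-character sequence backslash + 'n' and substitutes a real newline.

# --- usage example ---
s = "第一行\\n第二行"       # contains a literal backslash followed by 'n'
print(replace_line(s))      # prints the text on two lines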
def merge_rlhf_data(max_len: int = 512) -> None:
    '''
    Merge the preference data into one DPO dataset.
    Datasets: /datasets/Skepsun/huozi_rlhf_data_json
              /datasets/beyond/rlhf-reward-single-round-trans_chinese
    '''
    my_data = []

    read_files = [
        PROJECT_ROOT + '/huozi_rlhf_data.json',
        PROJECT_ROOT + '/alpaca_gpt4_data_dpo_zh.json',
    ]
    save_file = PROJECT_ROOT + '/dpo_data.json'
    if os.path.exists(save_file):
        assert delete_file(save_file)

    max_len += 8  # for eos token

    for read_file in read_files:
        items = []
        with open(read_file, 'r', encoding='utf-8') as f:
            items = ujson.load(f)

        for item in progress.track(items):
            prompt, chosen, reject = item['prompt'], item['chosen'], item['reject']
            if len(prompt) > max_len or len(chosen) > max_len or len(reject) > max_len:
                continue
            # drop empty fields, and pairs whose chosen and rejected texts are identical
            if len(prompt) == 0 or len(chosen) == 0 or len(reject) == 0 or chosen.strip() == reject.strip():
                continue
            my_data.append({
                'prompt': replace_line(prompt),
                'chosen': replace_line(chosen),
                'rejected': replace_line(reject),
            })

    # parquet file names are missing in the source; paths left as-is
    read_files = [
        PROJECT_ROOT + '/',
        PROJECT_ROOT + '/',
    ]
    for read_file in read_files:
        pf = pq.read_table(read_file)
        for prompt, chosen, rejected in progress.track(zip(pf['prompt'], pf['chosen'], pf['rejected']),
                                                       total=pf.num_rows):
            prompt, chosen, rejected = prompt.as_py(), chosen.as_py(), rejected.as_py()
            if len(prompt) > max_len or len(chosen) > max_len or len(rejected) > max_len:
                continue
            if len(prompt) == 0 or len(chosen) == 0 or len(rejected) == 0 or chosen.strip() == rejected.strip():
                continue
            my_data.append({
                'prompt': replace_line(prompt),
                'chosen': replace_line(chosen),
                'rejected': replace_line(rejected),
            })

    print('length of {} is {}'.format(save_file, len(my_data)))
    with open(save_file, 'w', encoding='utf-8') as f:
        ujson.dump(my_data, f, indent=4, ensure_ascii=False)
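A small sanity check of the merged file; note the key is renamed from 'reject' to 'rejected' during the merge. A sketch, assuming PROJECT_ROOT is set as in __main__ below.

# --- sanity-check sketch ---
with open(PROJECT_ROOT + '/dpo_data.json', 'r', encoding='utf-8') as f:
    pairs = ujson.load(f)
assert all({'prompt', 'chosen', 'rejected'} <= set(p) for p in pairs)
print(len(pairs), pairs[0])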
def split_train_eval_dataset() -> None:
    '''Split the merged data into train and eval sets.
    '''
    rw_json_file = PROJECT_ROOT + '/dpo_data.json'
    train_file = PROJECT_ROOT + '/dpo_train.json'
    eval_file = PROJECT_ROOT + '/dpo_eval.json'

    data = []
    with open(rw_json_file, 'r', encoding='utf-8') as f:
        data = ujson.load(f)

    np.random.shuffle(data)
    split_idx = int(len(data) * 0.99)
    train_data = data[0: split_idx]
    eval_data = data[split_idx:]
    print('train size: {}, eval size: {}'.format(len(train_data), len(eval_data)))

    with open(train_file, 'w', encoding='utf-8') as f:
        ujson.dump(train_data, f, indent=4, ensure_ascii=False)
    with open(eval_file, 'w', encoding='utf-8') as f:
        ujson.dump(eval_data, f, indent=4, ensure_ascii=False)
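np.random.shuffle is in-place and unseeded above, so the 99/1 split differs from run to run; seeding NumPy first makes it reproducible.

# --- reproducible split (optional) ---
np.random.seed(42)            # any fixed seed gives a repeatable split
split_train_eval_dataset()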
if __name__ == '__main__':
    PROJECT_ROOT = "/home/image_team/image_team_docker_home/lgd/e_commerce_llm/minillama3/data/"

    # 1. build the chosen texts
    # process_alpaca_gpt4_data()

    # 2. generate the rejected texts
    generate_alpaca_gpt4_reject_response()

    # 3. merge the datasets
    # merge_rlhf_data()

    # 4. split into train and eval datasets
    # split_train_eval_dataset()
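The {'prompt', 'chosen', 'rejected'} schema produced above is the format that trl's DPOTrainer consumes directly. A minimal training sketch follows, assuming trl and datasets are installed; the checkpoint path is a placeholder, and the keyword for the tokenizer varies across trl versions (older releases take tokenizer= instead of processing_class=).

# --- DPO training sketch with trl (illustrative; API details vary by version) ---
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

data_dir = "/home/image_team/image_team_docker_home/lgd/e_commerce_llm/minillama3/data/"
train_ds = load_dataset('json', data_files=data_dir + 'dpo_train.json', split='train')
eval_ds = load_dataset('json', data_files=data_dir + 'dpo_eval.json', split='train')

model = AutoModelForCausalLM.from_pretrained("path/to/sft_checkpoint")  # placeholder
tok = AutoTokenizer.from_pretrained("path/to/sft_checkpoint")           # placeholder

args = DPOConfig(output_dir="dpo_out", beta=0.1, per_device_train_batch_size=2)
trainer = DPOTrainer(model=model, args=args, train_dataset=train_ds,
                     eval_dataset=eval_ds, processing_class=tok)
trainer.train()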