如何使用 DeepSpeed-Chat 和自定义数据集训练类 ChatGPT 模型

时间:2025-02-17 21:07:58

如果你想使用自己的数据集进行训练,可以按照以下步骤操作:

1. 数据集格式要求

DeepSpeed-Chat 的数据集需要符合特定的格式。每个数据项应该是一个 JSON 对象,包含以下字段:

JSON复制

{
    "prompt": "Human: 你的问题", 
    "chosen": "好的回答", 
    "rejected": "不好的回答"
}
  • prompt 是问题或提示。

  • chosen 是被选择的、好的回答。

  • rejected 是被拒绝的、不好的回答。

2. 准备数据文件

将你的数据保存为 JSON 文件,例如 train.jsoneval.json,分别用于训练和评估。

3. 修改代码以使用自己的数据集

在 DeepSpeed-Chat 的代码中,需要修改数据加载部分以加载你的数据文件。具体步骤如下:

a. 修改 dschat/utils/data/raw_datasets.py

在该文件中添加一个新的类,定义你的数据集格式。例如:

Python复制

class MyDataset(PromptRawDataset):
    def __init__(self, path):
        super().__init__()
        self.data = self.load_data(path)

    def load_data(self, path):
        with open(path, 'r') as f:
            data = json.load(f)
        return data
b. 修改 dschat/utils/data/data_utils.py

get_raw_dataset 函数中添加一个条件,以便加载你的数据集。例如:

Python复制

if dataset_name == "my_dataset":
    return MyDataset(path)
c. 修改训练脚本

在训练脚本中,通过 --data_path 参数指定你的数据集路径。例如:

bash复制

python train.py --actor-model facebook/opt-1.3b --reward-model facebook/opt-350m --deployment-type single_gpu --data_path ./path/to/your/train.json

4. 注意事项

  • 如果你的数据集只包含单个回答(没有 rejected 字段),则只能在第一步(SFT)中使用。在这种情况下,需要将数据集名称添加到 --sft_only_data_path 参数中,而不是 --data_path

  • 如果你计划在第二步和第三步中使用数据集,建议使用包含两个回答(chosenrejected)的数据集,以确保训练的稳定性和模型质量。

通过以上步骤,你可以将自己准备的数据集用于 DeepSpeed-Chat 的训练。