基于Qwen2-VL模型针对LaTeX OCR任务进行微调训练 - 单图推理

时间:2024-12-13 17:46:37

基于Qwen2-VL模型针对LaTeX OCR任务进行微调训练 - 单图推理

flyfish

基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_LoRA配置如何写
基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_单图推理
基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_原模型_单图推理
基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_原模型_多图推理
基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_多图推理
基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_数据处理
基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_训练
基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_训练过程
举两个例子

示例1

1 输入是一个不大清楚的图像

请添加图片描述

2 推理结果是

E _ { m \bar { m } } = \frac { 2 ^ { 7 } \sqrt { Q _ { c } } \pi ^ { 1 / 2 } } { \Gamma ( 1 / 4 ) ^ { 2 } } \frac { \log \left( L _ { 0 } / L \right) } { L } \int _ { 1 } ^ { \infty } d y \frac { y ^ { 2 } } { \sqrt { y ^ { 4 } - 1 } } .

3 可视化之后

E m m ˉ = 2 7 Q c π 1 / 2 Γ ( 1 / 4 ) 2 log ⁡ ( L 0 / L ) L ∫ 1 ∞ d y y 2 y 4 − 1 . E _ { m \bar { m } } = \frac { 2 ^ { 7 } \sqrt { Q _ { c } } \pi ^ { 1 / 2 } } { \Gamma ( 1 / 4 ) ^ { 2 } } \frac { \log \left( L _ { 0 } / L \right) } { L } \int _ { 1 } ^ { \infty } d y \frac { y ^ { 2 } } { \sqrt { y ^ { 4 } - 1 } } . Emmˉ=Γ(1/4)227Qc π1/2Llog(L0/L)1dyy41 y2.

示例2

1 输入是一个复杂的多个公式

请添加图片描述

2 推理结果是

\begin{array} { l } { \mathrm { H e n c e , f o r } \ ( x , y ) \in D , } \ { f _ { x } ( x ) = \int _ { - \infty } ^ { \infty } f ( x , y ) : d y = \int _ { 0 } ^ { 1 } \frac { x + 2 y } { 4 } : d y } \ { = \frac { x y + y ^ { 2 } } { 4 } \Big | _ { 0 } ^ { 1 } = \frac { x + 1 } { 4 } } \ { \mathrm { O t h e r w i s e , } \ f _ { x } ( x ) = 0 . \mathrm { T h a t i s , } } \ { f _ { x } ( x ) = \left{ \begin{array} { l l } { \frac { x + 1 } { 4 } , } & { \mathrm { i f } \ 0 < y < 1 \ \mathrm { a n d } \ 0 < x < 2 ; } \ { 0 , } & { \mathrm { o t h e r w i s e . } } \ \end{array} \right. } \ \end{array}

3 可视化之后

H e n c e , f o r   ( x , y ) ∈ D , f x ( x ) = ∫ − ∞ ∞ f ( x , y )   d y = ∫ 0 1 x + 2 y 4   d y = x y + y 2 4 ∣ 0 1 = x + 1 4 O t h e r w i s e ,   f x ( x ) = 0. T h a t i s , f x ( x ) = { x + 1 4 , i f   0 < y < 1   a n d   0 < x < 2 ; 0 , o t h e r w i s e . \begin{array} { l } { \mathrm { H e n c e , f o r } \ ( x , y ) \in D , } \\ { f _ { x } ( x ) = \int _ { - \infty } ^ { \infty } f ( x , y ) \: d y = \int _ { 0 } ^ { 1 } \frac { x + 2 y } { 4 } \: d y } \\ { = \frac { x y + y ^ { 2 } } { 4 } \Big | _ { 0 } ^ { 1 } = \frac { x + 1 } { 4 } } \\ { \mathrm { O t h e r w i s e , } \ f _ { x } ( x ) = 0 . \mathrm { T h a t i s , } } \\ { f _ { x } ( x ) = \left\{ \begin{array} { l l } { \frac { x + 1 } { 4 } , } & { \mathrm { i f } \ 0 < y < 1 \ \mathrm { a n d } \ 0 < x < 2 ; } \\ { 0 , } & { \mathrm { o t h e r w i s e . } } \\ \end{array} \right. } \\ \end{array} Hence,for (x,y)D,fx(x)=f(x,y)dy=014x+2ydy=4xy+y2 01=4x+1Otherwise, fx(x)=0.Thatis,fx(x)={4x+1,0,if 0<y<1 and 0<x<2;otherwise.

import argparse

# 从transformers库中导入Qwen2VLForConditionalGeneration类,用于条件生成任务,以及AutoProcessor类,用于数据处理等操作
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor

# 从qwen_vl_utils模块中导入process_vision_info函数,可能用于处理图像相关的输入信息
from qwen_vl_utils import process_vision_info

# 从peft库中导入PeftModel类(用于加载微调后的模型)和LoraConfig类(用于配置LoRA相关参数)
from peft import PeftModel, LoraConfig, TaskType  # 确保TaskType被正确导入
import torch


class LaTeXOCR:
    def __init__(self, local_model_path, lora_model_path):
        """
        初始化LaTeXOCR实例。

        参数:
            local_model_path (str): 本地基础模型的路径。
            lora_model_path (str): LoRA微调模型检查点的路径。
        """
        self.local_model_path = local_model_path
        self.lora_model_path = lora_model_path
        self._load_model_and_processor()

    def _load_model_and_processor(self):
        """
        加载模型和处理器。

        配置LoRA参数并加载预训练的基础模型、微调后的LoRA模型及相应的处理器。
        """
        config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,  # 设置任务类型为因果语言模型
            target_modules=[
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "gate_proj",
                "up_proj",
                "down_proj",
            ],  # 指定要应用LoRA的目标模块列表
            inference_mode=True,  # 推理模式
            r=64,  # LoRA秩参数
            lora_alpha=16,  # LoRA alpha参数
            lora_dropout=0.05,  # LoRA dropout概率
            bias="none",  # 不处理偏置项
        )

        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            self.local_model_path, torch_dtype=torch.float16, device_map="auto"
        )
        self.model = PeftModel.from_pretrained(
            self.model, self.lora_model_path, config=config
        )
        self.processor = AutoProcessor.from_pretrained(self.local_model_path)

    def generate_latex_from_image(self, test_image_path, prompt):
        """
        根据给定的测试图像路径和提示信息,生成对应的LaTeX格式文本。

        参数:
            test_image_path (str): 包含数学公式的测试图像路径。
            prompt (str): 提供给模型的提示信息。

        返回:
            str: 转换后的LaTeX格式文本。
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": test_image_path,
                        "resized_height": 100,  # 图像调整后的高度
                        "resized_width": 500,  # 图像调整后的宽度
                    },
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(
            "cuda" if torch.cuda.is_available() else "cpu"
        )  # 检查CUDA可用性

        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_new_tokens=8192)

        generated_ids_trimmed = [
            out_ids[len(in_ids) :]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        return output_text #[0]


def parse_arguments():
    """
    解析命令行参数,并设置默认值。

    返回:
        argparse.Namespace: 包含解析后的命令行参数的对象。
    """
    parser = argparse.ArgumentParser(description="LaTeX OCR using Qwen2-VL")

    parser.add_argument(
        "--local_model_path",
        type=str,
        default="./Qwen/Qwen2-VL-7B-Instruct",
        help='Path to the local model. ',
    )
    parser.add_argument(
        "--lora_model_path",
        type=str,
        default="./output/Qwen2-VL-7B-LatexOCR/checkpoint-1500",
        help='Path to the LoRA model checkpoint. ',
    )
    parser.add_argument(
        "--test_image_path",
        type=str,
        default="./LaTeX_OCR/987.jpg",
        help='Path to the test image. ',
    )

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_arguments()

    prompt = (
        "尊敬的Qwen2VL大模型,我需要你帮助我将一张包含数学公式的图片转换成LaTeX格式的文本。\n"
        "请按照以下说明进行操作:\n"
        "1. **图像中的内容**: 图像中包含的是一个或多个数学公式,请确保准确地识别并转换为LaTeX代码。\n"
        "2. **公式识别**: 请专注于识别和转换数学符号、希腊字母、积分、求和、分数、指数等数学元素。\n"
        "3. **LaTeX语法**: 输出时使用标准的LaTeX语法。确保所有的命令都是正确的,并且可以被LaTeX编译器正确解析。\n"
        "4. **结构保持**: 如果图像中的公式有特定的结构(例如多行公式、矩阵、方程组),请在输出的LaTeX代码中保留这些结构。\n"
        "5. **上下文无关**: 不要尝试解释公式的含义或者添加额外的信息,只需严格按照图像内容转换。\n"
        "6. **格式化**: 如果可能的话,使输出的LaTeX代码易于阅读,比如适当添加空格和换行。"
    )

    latex_ocr = LaTeXOCR(args.local_model_path, args.lora_model_path)
    result = latex_ocr.generate_latex_from_image(args.test_image_path, prompt)
    print(result)

加载的local_model_pathlora_model_path两个模型路径代表了不同类型的模型

1. local_model_path(基础模型)

这是指本地的基础预训练模型的路径。这个模型通常是在大规模数据集上预先训练好的,旨在捕捉语言或视觉特征的一般模式。对于Qwen2-VL这样的多模态模型来说,它是已经在大量文本和图像对上进行过预训练的模型。Qwen/Qwen2-VL-7B-Instruct是一个已经经过广泛预训练的大规模多模态模型。

2. lora_model_path(LoRA微调模型)

这指的是使用低秩自适应(Low-Rank Adaptation, LoRA)技术微调后的模型检查点路径。
基于基础模型,针对特定任务(如LaTeX OCR)进行了优化。在这个过程中,只有选定的层或模块中的参数被调整,其余大部分参数保持不变。
output/Qwen2-VL-7B-LatexOCR/checkpoint-1500是指在特定任务,这里是将数学公式图片转换为LaTeX代码)上进行了微调后的模型版本。该版本继承了基础模型的知识,并且针对特定任务进行了优化。