基于Qwen2-VL模型针对LaTeX OCR任务进行微调训练 - 单图推理
flyfish
基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_LoRA配置如何写
基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_单图推理
基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_原模型_单图推理
基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_原模型_多图推理
基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_多图推理
基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_数据处理
基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_训练
基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_训练过程
举两个例子
示例1
1 输入是一个不大清楚的图像
2 推理结果是
E _ { m \bar { m } } = \frac { 2 ^ { 7 } \sqrt { Q _ { c } } \pi ^ { 1 / 2 } } { \Gamma ( 1 / 4 ) ^ { 2 } } \frac { \log \left( L _ { 0 } / L \right) } { L } \int _ { 1 } ^ { \infty } d y \frac { y ^ { 2 } } { \sqrt { y ^ { 4 } - 1 } } .
3 可视化之后
E m m ˉ = 2 7 Q c π 1 / 2 Γ ( 1 / 4 ) 2 log ( L 0 / L ) L ∫ 1 ∞ d y y 2 y 4 − 1 . E _ { m \bar { m } } = \frac { 2 ^ { 7 } \sqrt { Q _ { c } } \pi ^ { 1 / 2 } } { \Gamma ( 1 / 4 ) ^ { 2 } } \frac { \log \left( L _ { 0 } / L \right) } { L } \int _ { 1 } ^ { \infty } d y \frac { y ^ { 2 } } { \sqrt { y ^ { 4 } - 1 } } . Emmˉ=Γ(1/4)227Qcπ1/2Llog(L0/L)∫1∞dyy4−1y2.
示例2
1 输入是一个复杂的多个公式
2 推理结果是
\begin{array} { l } { \mathrm { H e n c e , f o r } \ ( x , y ) \in D , } \ { f _ { x } ( x ) = \int _ { - \infty } ^ { \infty } f ( x , y ) : d y = \int _ { 0 } ^ { 1 } \frac { x + 2 y } { 4 } : d y } \ { = \frac { x y + y ^ { 2 } } { 4 } \Big | _ { 0 } ^ { 1 } = \frac { x + 1 } { 4 } } \ { \mathrm { O t h e r w i s e , } \ f _ { x } ( x ) = 0 . \mathrm { T h a t i s , } } \ { f _ { x } ( x ) = \left{ \begin{array} { l l } { \frac { x + 1 } { 4 } , } & { \mathrm { i f } \ 0 < y < 1 \ \mathrm { a n d } \ 0 < x < 2 ; } \ { 0 , } & { \mathrm { o t h e r w i s e . } } \ \end{array} \right. } \ \end{array}
3 可视化之后
H e n c e , f o r ( x , y ) ∈ D , f x ( x ) = ∫ − ∞ ∞ f ( x , y ) d y = ∫ 0 1 x + 2 y 4 d y = x y + y 2 4 ∣ 0 1 = x + 1 4 O t h e r w i s e , f x ( x ) = 0. T h a t i s , f x ( x ) = { x + 1 4 , i f 0 < y < 1 a n d 0 < x < 2 ; 0 , o t h e r w i s e . \begin{array} { l } { \mathrm { H e n c e , f o r } \ ( x , y ) \in D , } \\ { f _ { x } ( x ) = \int _ { - \infty } ^ { \infty } f ( x , y ) \: d y = \int _ { 0 } ^ { 1 } \frac { x + 2 y } { 4 } \: d y } \\ { = \frac { x y + y ^ { 2 } } { 4 } \Big | _ { 0 } ^ { 1 } = \frac { x + 1 } { 4 } } \\ { \mathrm { O t h e r w i s e , } \ f _ { x } ( x ) = 0 . \mathrm { T h a t i s , } } \\ { f _ { x } ( x ) = \left\{ \begin{array} { l l } { \frac { x + 1 } { 4 } , } & { \mathrm { i f } \ 0 < y < 1 \ \mathrm { a n d } \ 0 < x < 2 ; } \\ { 0 , } & { \mathrm { o t h e r w i s e . } } \\ \end{array} \right. } \\ \end{array} Hence,for (x,y)∈D,fx(x)=∫−∞∞f(x,y)dy=∫014x+2ydy=4xy+y2 01=4x+1Otherwise, fx(x)=0.Thatis,fx(x)={4x+1,0,if 0<y<1 and 0<x<2;otherwise.
import argparse
# 从transformers库中导入Qwen2VLForConditionalGeneration类,用于条件生成任务,以及AutoProcessor类,用于数据处理等操作
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
# 从qwen_vl_utils模块中导入process_vision_info函数,可能用于处理图像相关的输入信息
from qwen_vl_utils import process_vision_info
# 从peft库中导入PeftModel类(用于加载微调后的模型)和LoraConfig类(用于配置LoRA相关参数)
from peft import PeftModel, LoraConfig, TaskType # 确保TaskType被正确导入
import torch
class LaTeXOCR:
def __init__(self, local_model_path, lora_model_path):
"""
初始化LaTeXOCR实例。
参数:
local_model_path (str): 本地基础模型的路径。
lora_model_path (str): LoRA微调模型检查点的路径。
"""
self.local_model_path = local_model_path
self.lora_model_path = lora_model_path
self._load_model_and_processor()
def _load_model_and_processor(self):
"""
加载模型和处理器。
配置LoRA参数并加载预训练的基础模型、微调后的LoRA模型及相应的处理器。
"""
config = LoraConfig(
task_type=TaskType.CAUSAL_LM, # 设置任务类型为因果语言模型
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
], # 指定要应用LoRA的目标模块列表
inference_mode=True, # 推理模式
r=64, # LoRA秩参数
lora_alpha=16, # LoRA alpha参数
lora_dropout=0.05, # LoRA dropout概率
bias="none", # 不处理偏置项
)
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
self.local_model_path, torch_dtype=torch.float16, device_map="auto"
)
self.model = PeftModel.from_pretrained(
self.model, self.lora_model_path, config=config
)
self.processor = AutoProcessor.from_pretrained(self.local_model_path)
def generate_latex_from_image(self, test_image_path, prompt):
"""
根据给定的测试图像路径和提示信息,生成对应的LaTeX格式文本。
参数:
test_image_path (str): 包含数学公式的测试图像路径。
prompt (str): 提供给模型的提示信息。
返回:
str: 转换后的LaTeX格式文本。
"""
messages = [
{
"role": "user",
"content": [
{
"type": "image",
"image": test_image_path,
"resized_height": 100, # 图像调整后的高度
"resized_width": 500, # 图像调整后的宽度
},
{"type": "text", "text": prompt},
],
}
]
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(
"cuda" if torch.cuda.is_available() else "cpu"
) # 检查CUDA可用性
with torch.no_grad():
generated_ids = self.model.generate(**inputs, max_new_tokens=8192)
generated_ids_trimmed = [
out_ids[len(in_ids) :]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = self.processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
return output_text #[0]
def parse_arguments():
"""
解析命令行参数,并设置默认值。
返回:
argparse.Namespace: 包含解析后的命令行参数的对象。
"""
parser = argparse.ArgumentParser(description="LaTeX OCR using Qwen2-VL")
parser.add_argument(
"--local_model_path",
type=str,
default="./Qwen/Qwen2-VL-7B-Instruct",
help='Path to the local model. ',
)
parser.add_argument(
"--lora_model_path",
type=str,
default="./output/Qwen2-VL-7B-LatexOCR/checkpoint-1500",
help='Path to the LoRA model checkpoint. ',
)
parser.add_argument(
"--test_image_path",
type=str,
default="./LaTeX_OCR/987.jpg",
help='Path to the test image. ',
)
return parser.parse_args()
if __name__ == "__main__":
args = parse_arguments()
prompt = (
"尊敬的Qwen2VL大模型,我需要你帮助我将一张包含数学公式的图片转换成LaTeX格式的文本。\n"
"请按照以下说明进行操作:\n"
"1. **图像中的内容**: 图像中包含的是一个或多个数学公式,请确保准确地识别并转换为LaTeX代码。\n"
"2. **公式识别**: 请专注于识别和转换数学符号、希腊字母、积分、求和、分数、指数等数学元素。\n"
"3. **LaTeX语法**: 输出时使用标准的LaTeX语法。确保所有的命令都是正确的,并且可以被LaTeX编译器正确解析。\n"
"4. **结构保持**: 如果图像中的公式有特定的结构(例如多行公式、矩阵、方程组),请在输出的LaTeX代码中保留这些结构。\n"
"5. **上下文无关**: 不要尝试解释公式的含义或者添加额外的信息,只需严格按照图像内容转换。\n"
"6. **格式化**: 如果可能的话,使输出的LaTeX代码易于阅读,比如适当添加空格和换行。"
)
latex_ocr = LaTeXOCR(args.local_model_path, args.lora_model_path)
result = latex_ocr.generate_latex_from_image(args.test_image_path, prompt)
print(result)
加载的local_model_path
和lora_model_path
两个模型路径代表了不同类型的模型
1. local_model_path
(基础模型)
这是指本地的基础预训练模型的路径。这个模型通常是在大规模数据集上预先训练好的,旨在捕捉语言或视觉特征的一般模式。对于Qwen2-VL这样的多模态模型来说,它是已经在大量文本和图像对上进行过预训练的模型。Qwen/Qwen2-VL-7B-Instruct
是一个已经经过广泛预训练的大规模多模态模型。
2. lora_model_path
(LoRA微调模型)
这指的是使用低秩自适应(Low-Rank Adaptation, LoRA)技术微调后的模型检查点路径。
基于基础模型,针对特定任务(如LaTeX OCR)进行了优化。在这个过程中,只有选定的层或模块中的参数被调整,其余大部分参数保持不变。output/Qwen2-VL-7B-LatexOCR/checkpoint-1500
是指在特定任务,这里是将数学公式图片转换为LaTeX代码)上进行了微调后的模型版本。该版本继承了基础模型的知识,并且针对特定任务进行了优化。