【YOLOE: Real-Time Seeing Anything】Analysis of the visual-prompt inference code in predict_visual_prompt.py (detection version)

Date: 2025-03-30 17:25:05

1. build_transforms in YOLODataset loads LoadVisualPrompt

transforms.append(LoadVisualPrompt(nc=nc, augment=self.augment))

LoadVisualPrompt lives in augment.py and builds the visual-prompt features: it converts the annotated ground-truth boxes or masks into masks at feature-map resolution (1/8 of the input) and stores them in labels["visuals"].

import torch
import torch.nn.functional as F
from ultralytics.utils.ops import xywh2xyxy

class LoadVisualPrompt:
    def __init__(self, nc, augment):
        self.nc = nc
        self.min_interval = 5
        self.augment = augment
        self.scale_factor = 1/8
    
    def make_mask(self, boxes, h, w):
        """Rasterize xyxy boxes into (n, h, w) boolean masks via broadcasting."""
        x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # each of shape (n, 1, 1)
        r = torch.arange(w)[None, None, :]  # x / column indices, shape (1, 1, w)
        c = torch.arange(h)[None, :, None]  # y / row indices, shape (1, h, 1)

        return (r >= x1) * (r < x2) * (c >= y1) * (c < y2)
    
    def __call__(self, labels):
        imgsz = labels["img"].shape[1:]
        masksz = (int(imgsz[0] * self.scale_factor), int(imgsz[1] * self.scale_factor))
        if "bboxes" in labels:
            bboxes = labels["bboxes"]
            bboxes = xywh2xyxy(bboxes) * torch.tensor(masksz)[[1, 0, 1, 0]]  # target boxes
            masks = self.make_mask(bboxes, *masksz).float()
        elif "masks" in labels:
            assert(not self.augment)
            masks = F.interpolate(torch.from_numpy(labels["masks"]).unsqueeze(1), 
                                  masksz, mode="nearest").squeeze(1).float()
        else:
            raise ValueError("LoadVisualPrompt must have bboxes or masks in the label")
        cls = labels["cls"].squeeze(-1).to(torch.int)
        """
        #类别排序cls_unique [0, 1, 2, 3], 类别标签inverse_indices:[3, 2, 1, 0])

        类别排序cls_unique:[0, 1, 2, 3, 4, 5, 6] ,   类别标签inverse_indices: [4, 0, 1, 2, 1, 0, 4, 5, 3, 6, 0])
        """
        cls_unique, inverse_indices = torch.unique(cls, sorted=True, return_inverse=True)
        if len(cls_unique) != 0 and self.augment:
            # During training, class ids are expected to be contiguous and start at 0
            assert len(cls_unique) == cls_unique[-1] + 1
        elif not self.augment:
            # At inference time, prompt class ids need not be contiguous
            pass
        """
        #[4,80,80] 
        #[7,80,80]
        torch.logical_or(visuals[idx], mask):制作对应类别的mask,这张mask上包含所有该类别的物体,将2个mask中只要有true位置合并,都是false的位置还是false
        """
        visuals = torch.zeros(len(cls_unique), *masksz)  # one mask per unique class
        for idx, mask in zip(inverse_indices, masks):
            visuals[idx] = torch.logical_or(visuals[idx], mask)
        labels["visuals"] = visuals
        return labels
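
To make the transform concrete, here is a minimal standalone run. The import path follows section 1's note that the class lives in augment.py (it may differ across versions); the sizes and boxes are illustrative.

import torch
from ultralytics.data.augment import LoadVisualPrompt  # assumed path, see note above

# Two boxes of the same class on a 640x640 image -> a single merged 80x80 class mask.
labels = {
    "img": torch.zeros(3, 640, 640),                     # CHW tensor; shape[1:] -> (H, W)
    "bboxes": torch.tensor([[0.50, 0.50, 0.20, 0.20],
                            [0.25, 0.25, 0.10, 0.10]]),  # normalized xywh
    "cls": torch.tensor([[0.0], [0.0]]),                 # same class for both boxes
}
out = LoadVisualPrompt(nc=1, augment=False)(labels)
print(out["visuals"].shape)  # torch.Size([1, 80, 80]): one mask per unique class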

2. Predicting with visual prompts: predict_visual_prompt.py

from ultralytics import YOLOE
import numpy as np
from ultralytics.models.yolo.yoloe.predict_vp import YOLOEVPSegPredictor, YOLOEVPDetectPredictor
model = YOLOE("/runs/detect/train/weights/best.pt")
# Handcrafted shape can also be passed, please refer to app.py
# Multiple boxes or handcrafted shapes can also be passed as visual prompt in an image
"""
如果多个框类别是同类,则读取后的vp特征是[1, 1, 80, 80])
如果多个框类别是不同类,如1,2,则读取后vp的特征是[1, 2, 80, 80]
"""
visuals = dict(
    bboxes=np.array(
        [
            [500.0, 230.0, 521.0, 260.0],  # For person
            [503.0, 284.0, 553.0, 341.0],  # For person
        ],
    ),
    cls=np.array(
        [
            2,  # For 
            1,  # For glasses
        ]
    )
)
source_image = '/path/to/your_test_img.jpg'
# model.predict(source_image, save=True, prompts=visuals, predictor=YOLOEVPSegPredictor)
# Call the trained detection model here, not a segmentation-trained one
model.predict(source_image, save=True, prompts=visuals, predictor=YOLOEVPDetectPredictor)
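
As an aside, pre_transform (shown in 4.2 below) also accepts binary masks as prompts. A hedged sketch of that prompt dict, with the mask shape and dtype inferred from the code (one (H, W) mask per prompt, matching the source image size; the regions here are illustrative):

mask = np.zeros((2, 720, 1280), dtype=np.float32)  # assumes a 1280x720 source image
mask[0, 230:260, 500:521] = 1.0                    # roughly the first box above
mask[1, 284:341, 503:553] = 1.0                    # roughly the second box above
visual_masks = dict(masks=mask, cls=np.array([2, 1]))
# model.predict(source_image, save=True, prompts=visual_masks, predictor=YOLOEVPDetectPredictor)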

① model.predict() is called to run prediction.
② The predict function in engine/model.py sets some arguments and then invokes self.predictor.
③ The __call__ function in predictor.py runs.
④ stream_inference() carries out the inference; the main flow is:

    def stream_inference(self, source=None, model=None, *args, **kwargs):
        """Streams real-time inference on camera feed and saves results to file."""
        if self.args.verbose:
            LOGGER.info("")
        # Setup model
        if not self.model:
            self.setup_model(model)
        with self._lock:  # for thread-safe inference
            # Setup source every time predict is called
            self.setup_source(source if source is not None else self.args.source)
            # Check if save_dir/ label file exists
            if self.args.save or self.args.save_txt:
                (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)

            # Warmup model
            if not self.done_warmup:
                self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
                self.done_warmup = True

            self.seen, self.windows, self.batch = 0, [], None
            profilers = (
                ops.Profile(device=self.device),
                ops.Profile(device=self.device),
                ops.Profile(device=self.device),
            )
            self.run_callbacks("on_predict_start")
            for self.batch in self.dataset:
                self.run_callbacks("on_predict_batch_start")
                paths, im0s, s = self.batch
                # Preprocess
                with profilers[0]:
                    im = self.preprocess(im0s)

                # Inference
                with profilers[1]:
                    preds = self.inference(im, *args, **kwargs)
                    if self.args.embed:
                        yield from [preds] if isinstance(preds, torch.Tensor) else preds  # yield embedding tensors
                        continue

                # Postprocess
                with profilers[2]:
                    self.results = self.postprocess(preds, im, im0s)
                self.run_callbacks("on_predict_postprocess_end")

                # Visualize, save, write results
                n = len(im0s)
                for i in range(n):
                    self.seen += 1
                    self.results[i].speed = {
                        "preprocess": profilers[0].dt * 1e3 / n,
                        "inference": profilers[1].dt * 1e3 / n,
                        "postprocess": profilers[2].dt * 1e3 / n,
                    }
                    if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
                        s[i] += self.write_results(i, Path(paths[i]), im, s)

                # Print batch results
                if self.args.verbose:
                    LOGGER.info("\n".join(s))

                self.run_callbacks("on_predict_batch_end")
                yield from self.results

        # Release assets
        for v in self.vid_writer.values():
            if isinstance(v, cv2.VideoWriter):
                v.release()

        # Print final results
        if self.args.verbose and self.seen:
            t = tuple(x.t / self.seen * 1e3 for x in profilers)  # speeds per image
            LOGGER.info(
                f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape "
                f"{(min(self.args.batch, self.seen), 3, *im.shape[2:])}" % t
            )
        if self.args.save or self.args.save_txt or self.args.save_crop:
            nl = len(list(self.save_dir.glob("labels/*.txt")))  # number of labels
            s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ""
            LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
        self.run_callbacks("on_predict_end")

4.1 self.setup_source() reads the given image path ('/path/to/your_test_img.jpg').
4.2 im = self.preprocess(im0s) preprocesses the image.

    def preprocess(self, im):
        """
        Prepares input image before inference.

        Args:
            im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
        """
        not_tensor = not isinstance(im, torch.Tensor)
        if not_tensor:
            im = np.stack(self.pre_transform(im))
            im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
            im = np.ascontiguousarray(im)  # contiguous
            im = torch.from_numpy(im)

        im = im.to(self.device)
        im = im.half() if self.model.fp16 else im.float()  # uint8 to fp16/32
        if not_tensor:
            im /= 255  # 0 - 255 to 0.0 - 1.0
        return im
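
For intuition, here are the layout and normalization steps in isolation, on a dummy frame (pre_transform/letterboxing omitted; shapes are illustrative):

import numpy as np
import torch

frames = [np.zeros((480, 640, 3), dtype=np.uint8)]  # one HWC BGR image
x = np.stack(frames)                                 # (1, 480, 640, 3), BHWC
x = x[..., ::-1].transpose(0, 3, 1, 2)               # BGR -> RGB, BHWC -> BCHW
x = torch.from_numpy(np.ascontiguousarray(x)).float() / 255
print(x.shape, x.dtype)                              # torch.Size([1, 3, 480, 640]) torch.float32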

At this point preprocess calls self.pre_transform(im), which is overridden in predict_vp.py. Its key step is LoadVisualPrompt, which turns the prompt boxes on the image into the encoded visual prompts.

    def pre_transform(self, im):
        letterbox = LetterBox(
            self.imgsz,
            auto=False,
            stride=int(self.model.stride[-1].item()),
        )
        assert(len(im) == 1)

        if "bboxes" in self.prompts and len(self.prompts["bboxes"]) > 0:
            labels = dict(
                img=im[0],
                instances=Instances(bboxes=self.prompts["bboxes"], 
                                    segments=np.zeros((0, 1000, 2), dtype=np.float32), 
                                    bbox_format="xyxy", normalized=False),
                cls=torch.tensor(self.prompts["cls"]).unsqueeze(-1)
            )
            labels = letterbox(labels)
            instances = labels.pop("instances")
            h, w = labels["img"].shape[:2]
            instances.normalize(w, h)
            instances.convert_bbox(format="xywh")
            labels["bboxes"] = torch.from_numpy(instances.bboxes)
        elif "masks" in self.prompts:
            masks = self.prompts["masks"]

            img = letterbox(image=im[0])
            resized_masks = []
            for i in range(len(masks)):
                resized_masks.append(letterbox(image=masks[i]))
            masks = np.stack(resized_masks)
            masks[masks == 114] = 0
            labels = dict(
                img=img,
                masks=masks,
                cls=torch.tensor(self.prompts["cls"]).unsqueeze(-1)
            )
        else:
            raise ValueError("Please provide valid bboxes or masks")

        labels["img"] = labels["img"].transpose(2, 0, 1)
        # Convert the input bboxes into visual prompts
        load_vp = LoadVisualPrompt(nc=len(self.prompts["cls"]), augment=False)
        labels = load_vp(labels)
        
        cls = np.unique(self.prompts["cls"])
        self.prompts = labels["visuals"].unsqueeze(0).to(self.device)
        self.model.model[-1].nc = self.prompts.shape[1]
        self.model.names = [f"object{cls[i]}" for i in range(self.prompts.shape[1])]
        return [labels["img"].transpose(1, 2, 0)]
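
This is also an easy place to verify the shape note from section 2 ([1, 2, 80, 80] for two boxes of different classes) by running LoadVisualPrompt standalone. The import path and the 640x640 letterboxed size are assumptions:

import torch
from ultralytics.data.augment import LoadVisualPrompt  # assumed path, see section 1

labels = {
    "img": torch.zeros(3, 640, 640),                 # letterboxed input assumed 640x640
    "bboxes": torch.tensor([[0.5, 0.5, 0.1, 0.1],
                            [0.3, 0.3, 0.1, 0.1]]),  # normalized xywh, as built above
    "cls": torch.tensor([[2.0], [1.0]]),             # two different classes
}
vp = LoadVisualPrompt(nc=2, augment=False)(labels)["visuals"]
print(vp.unsqueeze(0).shape)  # torch.Size([1, 2, 80, 80])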

4.3 stream_inference then runs the forward pass: preds = self.inference(im, *args, **kwargs). self.inference first hits the override in predict_vp.py, which (when return_vpe is set) saves the visual-prompt embedding vpe.

    def inference(self, im, *args, **kwargs):
        if self.return_vpe:
            self.vpe = self.model.get_visual_pe(im, visual=self.prompts)
        return super().inference(im, vpe=self.prompts, *args, **kwargs)
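
Side note: the saved vpe can be reused as class embeddings so that later runs need no visual prompt. A hedged sketch following the pattern in the YOLOE repo README (the return_vpe/set_classes names may differ across versions):

model.predict(source_image, prompts=visuals,
              predictor=YOLOEVPDetectPredictor, return_vpe=True)
model.set_classes(["object0", "object1"], model.predictor.vpe)  # one name per unique class
model.predictor = None  # drop the VP predictor so later calls use the default one
model.predict("/path/to/another_img.jpg", save=True)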

Back in the main flow, the call falls through to the parent inference() in predictor.py:

    def inference(self, im, *args, **kwargs):
        """Runs inference on a given image using the specified model and arguments."""
        visualize = (
            increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)
            if self.args.visualize and (not self.source_type.tensor)
            else False
        )
        return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)

The preprocessed image and visual features are fed into the model: self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs) returns the raw predictions.
4.4 During inference, the predict function of YOLOEModel in tasks.py runs the image and visual prompts through the network to produce the predictions.
4.5 self.postprocess(preds, im, im0s) applies NMS post-processing to the predictions to obtain the final results, which are then saved.
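
Finally, the post-NMS results come back as standard Ultralytics Results objects:

results = model.predict(source_image, save=True, prompts=visuals,
                        predictor=YOLOEVPDetectPredictor)
for r in results:
    print(r.boxes.xyxy)               # final boxes in original-image coordinates
    print(r.boxes.conf, r.boxes.cls)  # confidences and (prompt) class ids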