1. YOLODataset's build_transforms loads LoadVisualPrompt
transforms.append(LoadVisualPrompt(nc=nc, augment=self.augment))
LoadVisualPrompt lives in augment.py and builds the visual-prompt features: it converts each annotated ground-truth box or mask into a mask at feature-map resolution and stores the result in labels["visuals"].
import torch
import torch.nn.functional as F

from ultralytics.utils.ops import xywh2xyxy


class LoadVisualPrompt:
    def __init__(self, nc, augment):
        self.nc = nc
        self.min_interval = 5
        self.augment = augment
        self.scale_factor = 1 / 8  # prompt masks live on the 1/8-resolution feature map

    def make_mask(self, boxes, h, w):
        """Rasterize xyxy boxes into (n, h, w) binary masks."""
        x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # each of shape (n, 1, 1)
        r = torch.arange(w)[None, None, :]  # column indices, shape (1, 1, w)
        c = torch.arange(h)[None, :, None]  # row indices, shape (1, h, 1)
        return (r >= x1) * (r < x2) * (c >= y1) * (c < y2)

    def __call__(self, labels):
        imgsz = labels["img"].shape[1:]
        masksz = (int(imgsz[0] * self.scale_factor), int(imgsz[1] * self.scale_factor))
        if "bboxes" in labels:
            bboxes = labels["bboxes"]
            bboxes = xywh2xyxy(bboxes) * torch.tensor(masksz)[[1, 0, 1, 0]]  # scale to mask size
            masks = self.make_mask(bboxes, *masksz).float()
        elif "masks" in labels:
            assert not self.augment
            masks = F.interpolate(
                torch.from_numpy(labels["masks"]).unsqueeze(1), masksz, mode="nearest"
            ).squeeze(1).float()
        else:
            raise ValueError("LoadVisualPrompt must have bboxes or masks in the label")
        cls = labels["cls"].squeeze(-1).to(torch.int)
        # e.g. sorted unique classes cls_unique: [0, 1, 2, 3, 4, 5, 6] and the per-instance
        # indices into it, inverse_indices: [4, 0, 1, 2, 1, 0, 4, 5, 3, 6, 0]
        cls_unique, inverse_indices = torch.unique(cls, sorted=True, return_inverse=True)
        if len(cls_unique) != 0 and self.augment:
            assert len(cls_unique) == cls_unique[-1] + 1  # training: class ids must be contiguous
        # visuals has shape (num_unique_classes, *masksz), e.g. (4, 80, 80) or (7, 80, 80).
        # logical_or merges every instance of a class into a single mask: a pixel is True if it
        # is True in either operand, so visuals[idx] ends up covering all objects of that class.
        visuals = torch.zeros(len(cls_unique), *masksz)
        for idx, mask in zip(inverse_indices, masks):
            visuals[idx] = torch.logical_or(visuals[idx], mask)
        labels["visuals"] = visuals
        return labels
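To see what the transform produces, here is a minimal sketch that applies the class above to a dummy labels dict (all values are made up; at this point in the real pipeline "img" is a CHW tensor and "bboxes" are normalized xywh):

import torch

labels = {
    "img": torch.zeros(3, 640, 640),               # CHW image, so masksz = (80, 80)
    "bboxes": torch.tensor([[0.5, 0.5, 0.2, 0.2],  # normalized xywh, class 0
                            [0.2, 0.3, 0.1, 0.1]]),  # class 1
    "cls": torch.tensor([[0], [1]]),
}
labels = LoadVisualPrompt(nc=2, augment=True)(labels)
print(labels["visuals"].shape)  # torch.Size([2, 80, 80]), one merged mask per unique class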
2. Predicting with visual prompts: predict_visual_prompt.py
from ultralytics import YOLOE
import numpy as np
from ultralytics.models.yolo.yoloe.predict_vp import YOLOEVPSegPredictor, YOLOEVPDetectPredictor
model = YOLOE("/runs/detect/train/weights/best.pt")
# Handcrafted shape can also be passed, please refer to app.py
# Multiple boxes or handcrafted shapes can also be passed as visual prompt in an image
"""
如果多个框类别是同类,则读取后的vp特征是[1, 1, 80, 80])
如果多个框类别是不同类,如1,2,则读取后vp的特征是[1, 2, 80, 80]
"""
visuals = dict(
    bboxes=np.array(
        [
            [500.0, 230.0, 521.0, 260.0],  # prompt box for the first object
            [503.0, 284.0, 553.0, 341.0],  # prompt box for the second object
        ],
    ),
    cls=np.array(
        [
            2,  # class id of the first box
            1,  # class id of the second box
        ]
    ),
)
source_image = '/path/to/your_test_img.jpg'
# model.predict(source_image, save=True, prompts=visuals, predictor=YOLOEVPSegPredictor)
# Use the detection predictor here: the trained weights are a detection model, not a segmentation model
model.predict(source_image, save=True, prompts=visuals, predictor=YOLOEVPDetectPredictor)
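What comes back is the standard ultralytics Results list. Since the two distinct prompt classes are remapped to names like object1/object2 inside pre_transform (shown below), a quick inspection looks like this (a sketch; field access follows the usual Results API):

results = model.predict(source_image, prompts=visuals, predictor=YOLOEVPDetectPredictor)
for r in results:
    print(r.names)                    # remapped prompt-class names, e.g. object1 / object2
    print(r.boxes.xyxy, r.boxes.cls)  # detected boxes and their prompt-class indices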
① model.predict() kicks off prediction.
② predict() in engine/model.py fills in a few arguments and then invokes self.predictor.
③ That runs __call__ in predictor.py.
④ __call__ delegates to stream_inference(), whose main flow is:
def stream_inference(self, source=None, model=None, *args, **kwargs):
"""Streams real-time inference on camera feed and saves results to file."""
if self.args.verbose:
LOGGER.info("")
# Setup model
if not self.model:
self.setup_model(model)
with self._lock: # for thread-safe inference
# Setup source every time predict is called
self.setup_source(source if source is not None else self.args.source)
# Check if save_dir/ label file exists
if self.args.save or self.args.save_txt:
(self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
# Warmup model
if not self.done_warmup:
self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
self.done_warmup = True
self.seen, self.windows, self.batch = 0, [], None
profilers = (
ops.Profile(device=self.device),
ops.Profile(device=self.device),
ops.Profile(device=self.device),
)
self.run_callbacks("on_predict_start")
for self.batch in self.dataset:
self.run_callbacks("on_predict_batch_start")
paths, im0s, s = self.batch
# Preprocess
with profilers[0]:
im = self.preprocess(im0s)
# Inference
with profilers[1]:
preds = self.inference(im, *args, **kwargs)
if self.args.embed:
yield from [preds] if isinstance(preds, torch.Tensor) else preds # yield embedding tensors
continue
# Postprocess
with profilers[2]:
self.results = self.postprocess(preds, im, im0s)
self.run_callbacks("on_predict_postprocess_end")
# Visualize, save, write results
n = len(im0s)
for i in range(n):
self.seen += 1
self.results[i].speed = {
"preprocess": profilers[0].dt * 1e3 / n,
"inference": profilers[1].dt * 1e3 / n,
"postprocess": profilers[2].dt * 1e3 / n,
}
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
s[i] += self.write_results(i, Path(paths[i]), im, s)
# Print batch results
if self.args.verbose:
LOGGER.info("\n".join(s))
self.run_callbacks("on_predict_batch_end")
yield from self.results
# Release assets
for v in self.vid_writer.values():
if isinstance(v, cv2.VideoWriter):
v.release()
# Print final results
if self.args.verbose and self.seen:
t = tuple(x.t / self.seen * 1e3 for x in profilers) # speeds per image
LOGGER.info(
f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape "
f"{(min(self.args.batch, self.seen), 3, *im.shape[2:])}" % t
)
if self.args.save or self.args.save_txt or self.args.save_crop:
nl = len(list(self.save_dir.glob("labels/*.txt"))) # number of labels
s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ""
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
self.run_callbacks("on_predict_end")
4.1 self.setup_source() resolves the input source, here the image path '/path/to/your_test_img.jpg'.
4.2 im = self.preprocess(im0s) then prepares the image batch:
def preprocess(self, im):
"""
Prepares input image before inference.
Args:
im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
"""
not_tensor = not isinstance(im, torch.Tensor)
if not_tensor:
im = np.stack(self.pre_transform(im))
im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
im = np.ascontiguousarray(im) # contiguous
im = torch.from_numpy(im)
im = im.to(self.device)
im = im.half() if self.model.fp16 else im.float() # uint8 to fp16/32
if not_tensor:
im /= 255 # 0 - 255 to 0.0 - 1.0
return im
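For intuition, here is the same transform written out for a single letterboxed BGR frame (a standalone sketch; the shapes are the point):

import numpy as np
import torch

im0 = np.zeros((640, 640, 3), dtype=np.uint8)          # one letterboxed HWC BGR frame
im = np.stack([im0])[..., ::-1].transpose(0, 3, 1, 2)  # BGR->RGB, BHWC->BCHW
im = torch.from_numpy(np.ascontiguousarray(im)).float() / 255  # uint8 -> float in [0, 1]
print(im.shape)  # torch.Size([1, 3, 640, 640])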
This calls self.pre_transform(im), which is overridden in predict_vp.py; its key step is LoadVisualPrompt, which encodes the prompt boxes drawn on the image into visual prompt masks.
def pre_transform(self, im):
letterbox = LetterBox(
self.imgsz,
auto=False,
stride=int(self.model.stride[-1].item()),
)
assert(len(im) == 1)
if "bboxes" in self.prompts and len(self.prompts["bboxes"]) > 0:
labels = dict(
img=im[0],
instances=Instances(bboxes=self.prompts["bboxes"],
segments=np.zeros((0, 1000, 2), dtype=np.float32),
bbox_format="xyxy", normalized=False),
cls=torch.tensor(self.prompts["cls"]).unsqueeze(-1)
)
labels = letterbox(labels)
instances = labels.pop("instances")
h, w = labels["img"].shape[:2]
instances.normalize(w, h)
instances.convert_bbox(format="xywh")
labels["bboxes"] = torch.from_numpy(instances.bboxes)
elif "masks" in self.prompts:
masks = self.prompts["masks"]
img = letterbox(image=im[0])
resized_masks = []
for i in range(len(masks)):
resized_masks.append(letterbox(image=masks[i]))
masks = np.stack(resized_masks)
masks[masks == 114] = 0
labels = dict(
img=img,
masks=masks,
cls=torch.tensor(self.prompts["cls"]).unsqueeze(-1)
)
else:
raise ValueError("Please provide valid bboxes or masks")
labels["img"] = labels["img"].transpose(2, 0, 1)
# convert the input boxes (or masks) into visual prompt masks
load_vp = LoadVisualPrompt(nc=len(self.prompts["cls"]), augment=False)
labels = load_vp(labels)
cls = np.unique(self.prompts["cls"])
self.prompts = labels["visuals"].unsqueeze(0).to(self.device)
self.model.model[-1].nc = self.prompts.shape[1]
self.model.names = [f"object{cls[i]}" for i in range(self.prompts.shape[1])]
return [labels["img"].transpose(1, 2, 0)]
4.3 stream_inference then runs the forward pass via preds = self.inference(im, *args, **kwargs). self.inference first hits the override in predict_vp.py, which can cache the visual prompt embedding (vpe):
def inference(self, im, *args, **kwargs):
if self.return_vpe:
self.vpe = self.model.get_visual_pe(im, visual=self.prompts)
return super().inference(im, vpe=self.prompts, *args, **kwargs)
It then falls through to the parent class inference() in predictor.py:
def inference(self, im, *args, **kwargs):
"""Runs inference on a given image using the specified model and arguments."""
visualize = (
increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)
if self.args.visualize and (not self.source_type.tensor)
else False
)
return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
The preprocessed image and the visual prompt features are passed into the model here: self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs) produces the raw predictions.
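As a side note on the return_vpe branch in the override above: per the ultralytics YOLOE docs, the cached embedding can be baked in as named classes so that later predictions need no prompts at all (a sketch; verify the API against your installed version):

model.predict(source_image, prompts=visuals, predictor=YOLOEVPDetectPredictor, return_vpe=True)
model.set_classes(["object1", "object2"], model.predictor.vpe)  # reuse the cached vpe as named classes
model.predictor = None  # drop the VP predictor so subsequent calls run prompt-free
results = model.predict("/path/to/another_img.jpg")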
4.4 Inference goes through YOLOEModel.predict in tasks.py, which runs the image and the visual prompts through the network to produce the raw outputs.
4.5 self.postprocess(preds, im, im0s) applies NMS to the raw predictions to obtain the final results, which are then written out and saved.
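The detection postprocess is essentially a call to ops.non_max_suppression over the head output; a minimal sketch of the core step (kwargs as in ultralytics.utils.ops; the real postprocess also rescales boxes back to the original image):

import torch
from ultralytics.utils import ops

def postprocess_core(raw_preds: torch.Tensor, nc: int):
    # raw_preds: (batch, 4 + nc, num_anchors) head output; returns one (n, 6) tensor of
    # (x1, y1, x2, y2, conf, cls) rows per image after NMS
    return ops.non_max_suppression(raw_preds, conf_thres=0.25, iou_thres=0.7, nc=nc)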