使用SAM(Segment Anything Model)模型做图像分割,包括单个、多个提示点进行目标分割

时间:2024-07-18 07:19:08

        最近需要做YOLO + SAM的工作,先是实现了YOLOV8 OBB  + SAM模型的级联,即在使用旋转目标检测模型YOLOV8 OBB对图像推理得到检测框之后分别计算检测框内部的单个和多个坐标点,在SAM模型中分别使用单个和多个提示点对目标进行分割(也可以对旋转检测框内的目标进行分割)

        SAM模型很强大,可以在SAM中使用单个提示点进行目标分割,使用多个提示点进行目标分割,使用方框(或者YOLO的检测框)对指定区域进行分割,使用提示点结合方框进行目标分割,使用多个同时输入的方框进行目标分割等等。

        下面先只介绍在SAM模型中分别使用单个和多个提示点对目标进行分割(也可以对旋转检测框内的目标进行分割)的方法(代码)。

1.使用单个提示点进行目标分割

import cv2
import numpy as np
import torch
from matplotlib import pyplot as plt
from segment_anything import sam_model_registry, SamPredictor


# 使用单个提示点进行目标分割
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)


DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
MODEL_TYPE = "vit_h"