MobileSAM实战:轻量级图像分割模型的快速上手指南

在这里插入图片描述

一、引言

图像分割是计算机视觉领域的重要任务,它将图像分割成不同的区域,以便进一步分析和处理。随着深度学习技术的发展,图像分割模型的性能不断提升,但同时也面临着模型体积大、推理速度慢等问题,限制了在移动设备和边缘计算场景下的应用。

Meta公司开源的Segment Anything Model (SAM) 是一个强大的图像分割模型,但由于其参数量大(约9.1B),在资源受限的环境中难以高效运行。为此,研究人员开发了MobileSAM,这是SAM的轻量级版本,通过模型压缩技术将参数量减少到只有6.8M,同时保持了相似的分割质量。

本文将带你深入了解MobileSAM的特点、安装配置方法,并通过实战代码示例展示如何在实际项目中使用MobileSAM进行图像分割。

二、MobileSAM的特点

1. 超轻量级模型

MobileSAM的最大特点是其超轻量级的模型体积。与原始SAM相比,MobileSAM通过以下技术实现了模型压缩:

  • 将ViT-H骨干网络替换为更轻量级的ViT-Tiny
  • 对模型结构进行优化,减少冗余计算
  • 在保持分割精度的前提下,降低模型复杂度

这些优化使得MobileSAM的参数量仅为6.8M,不到原始SAM的1/1000,同时推理速度提升了数十倍。

2. 保持良好的分割精度

尽管模型体积大幅减小,MobileSAM仍然保持了良好的分割精度。实验结果表明,MobileSAM在多个标准数据集上的表现与原始SAM相当接近,能够满足大多数应用场景的需求。

3. 支持多种分割方式

MobileSAM支持与原始SAM相同的三种分割方式:

  • 点提示分割:通过指定图像中的点(前景或背景)进行分割
  • 边界框分割:通过指定边界框进行分割
  • 掩码提示分割:通过已有掩码进行分割

4. 适用于移动端和边缘设备

由于其轻量级特性,MobileSAM非常适合在移动端和边缘设备上运行,为实时图像分割应用提供了可能。

三、环境配置与安装

要使用MobileSAM,我们需要先配置相应的环境并安装必要的依赖库。以下是详细的安装步骤:

1. 安装Python环境

MobileSAM需要Python 3.8或更高版本。建议使用虚拟环境来隔离项目依赖:

# 创建虚拟环境
python -m venv mobile_sam_env

# 激活虚拟环境
# Windows系统
source mobile_sam_env/Scripts/activate
# Linux/Mac系统
source mobile_sam_env/bin/activate

2. 安装必要的依赖库

MobileSAM依赖以下库:

pip install torch torchvision opencv-python matplotlib pillow
pip install mobile-sam

3. 下载预训练模型

MobileSAM需要预训练模型文件才能正常工作。我们可以从GitHub仓库下载预训练模型:

# 下载MobileSAM预训练模型
git clone https://github.com/ChaoningZhang/MobileSAM.git
cd MobileSAM
wget https://github.com/ChaoningZhang/MobileSAM/releases/download/v0.0.1/mobile_sam.pt

或者直接从发布页面下载:https://github.com/ChaoningZhang/MobileSAM/releases/download/v0.0.1/mobile_sam.pt

四、MobileSAM实战代码示例

下面我们将通过一个完整的代码示例来展示如何使用MobileSAM进行图像分割。这个示例将包含模型加载、图像预处理、不同分割方式的实现以及结果可视化等功能。

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
MobileSAM 使用示例

这是一个简化的 MobileSAM 模型使用示例,演示如何加载模型、处理图像并进行目标分割,
不依赖大模型生成边界框,而是通过交互式点击或预设坐标进行分割。

MobileSAM 是 SAM (Segment Anything Model) 的轻量级版本,拥有更快的推理速度,
同时保持了相似的分割质量,非常适合在资源有限的环境中使用。
"""

import os
import io
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
from PIL import Image
import warnings

# 设置中文字体显示
plt.rcParams["font.family"] = ["SimHei", "WenQuanYi Micro Hei", "Heiti TC", "Arial Unicode MS", "Microsoft YaHei", "sans-serif"]
# 忽略MobileSAM模型注册警告
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting tiny_vit_.* in registry")
# 忽略timm库弃用警告
warnings.filterwarnings("ignore", category=FutureWarning, message="Importing from timm\.models\.layers is deprecated")

# 尝试导入MobileSAM库
try:
    from mobile_sam import sam_model_registry, SamPredictor
    MOBILE_SAM_AVAILABLE = True
except ImportError:
    print("未找到MobileSAM库。请使用 'pip install mobile-sam' 安装。")
    MOBILE_SAM_AVAILABLE = False


class MobileSAMSegmenter:
    """MobileSAM分割器类,封装MobileSAM模型的加载和使用"""
    
    def __init__(self, model_path="mobile_sam.pt", model_type="vit_t", device=None):
        """
        初始化MobileSAM分割器
        
        参数:
            model_path: MobileSAM模型权重文件路径
            model_type: 模型类型,默认为"vit_t"(tiny vision transformer)
            device: 运行设备,如"cuda"或"cpu",默认为自动选择
        """
        self.model_path = model_path
        self.model_type = model_type
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.predictor = None
        
        # 检查模型文件是否存在
        if not os.path.exists(self.model_path):
            # 尝试在当前目录查找
            current_dir = os.path.dirname(os.path.abspath(__file__))
            self.model_path = os.path.join(current_dir, self.model_path)
            if not os.path.exists(self.model_path):
                print(f"警告: 未找到MobileSAM模型文件 '{self.model_path}'")
                print("请下载模型文件并放置在正确的路径下")
                print("下载链接: https://github.com/ChaoningZhang/MobileSAM/releases/download/v0.0.1/mobile_sam.pt")
        
        # 初始化预测器
        if MOBILE_SAM_AVAILABLE:
            self._init_predictor()
    
    def _init_predictor(self):
        """初始化MobileSAM预测器"""
        if not os.path.exists(self.model_path):
            return
        
        try:
            # 加载模型权重
            mobile_sam = sam_model_registry[self.model_type](checkpoint=self.model_path)
            # 将模型移至指定设备
            mobile_sam.to(device=self.device)
            # 设置为评估模式
            mobile_sam.eval()
            # 创建预测器
            self.predictor = SamPredictor(mobile_sam)
            print(f"MobileSAM模型已成功加载到{self.device}设备")
        except Exception as e:
            print(f"加载MobileSAM模型时出错: {e}")
            self.predictor = None
    
    def set_image(self, image):
        """
        设置要分割的图像
        
        参数:
            image: 图像数据,可以是PIL Image、numpy数组或文件路径
        
        返回:
            bool: 设置是否成功
        """
        if not MOBILE_SAM_AVAILABLE or self.predictor is None:
            return False
        
        try:
            # 处理不同类型的图像输入
            if isinstance(image, str):
                # 如果是文件路径
                if not os.path.exists(image):
                    print(f"图像文件不存在: {image}")
                    return False
                image = Image.open(image)
                image_np = np.array(image)
            elif isinstance(image, Image.Image):
                # 如果是PIL Image
                image_np = np.array(image)
            elif isinstance(image, np.ndarray):
                # 如果是numpy数组
                image_np = image
            else:
                print("不支持的图像类型")
                return False
            
            # 如果是RGBA格式,转换为RGB
            if image_np.shape[-1] == 4:
                image_np = image_np[..., :3]
            
            # 设置图像到预测器
            self.predictor.set_image(image_np)
            return True
        except Exception as e:
            print(f"设置图像时出错: {e}")
            return False
    
    def segment_by_box(self, box):
        """
        通过边界框进行分割
        
        参数:
            box: 边界框坐标,格式为 [x1, y1, x2, y2]
        
        返回:
            tuple: (mask, bool) - 分割掩码和是否成功
        """
        if not MOBILE_SAM_AVAILABLE or self.predictor is None:
            return None, False
        
        try:
            # 将边界框转换为张量
            input_box = torch.tensor([box], dtype=torch.float32, device=self.predictor.device)
            
            # 应用坐标变换
            transformed_box = self.predictor.transform.apply_boxes_torch(input_box, 
                                                                       self.predictor.original_size)
            
            # 预测掩码
            masks, _, _ = self.predictor.predict_torch(
                point_coords=None,
                point_labels=None,
                boxes=transformed_box,
                multimask_output=False,
            )
            
            # 返回掩码(从GPU移至CPU并转换为numpy数组)
            return masks[0].cpu().numpy(), True
        except Exception as e:
            print(f"通过边界框分割时出错: {e}")
            return None, False
    
    def segment_by_points(self, point_coords, point_labels):
        """
        通过点进行分割
        
        参数:
            point_coords: 点坐标列表,格式为 [[x1, y1], [x2, y2], ...]
            point_labels: 点标签列表,1表示前景,0表示背景
        
        返回:
            tuple: (mask, bool) - 分割掩码和是否成功
        """
        if not MOBILE_SAM_AVAILABLE or self.predictor is None:
            return None, False
        
        try:
            # 将点坐标转换为张量
            input_points = torch.tensor(point_coords, dtype=torch.float32, device=self.predictor.device)
            input_labels = torch.tensor(point_labels, dtype=torch.int32, device=self.predictor.device)
            
            # 应用坐标变换
            transformed_points = self.predictor.transform.apply_coords_torch(
                input_points, self.predictor.original_size
            )
            
            # 预测掩码
            masks, _, _ = self.predictor.predict_torch(
                point_coords=transformed_points,
                point_labels=input_labels,
                boxes=None,
                multimask_output=True,
            )
            
            # 返回掩码(从GPU移至CPU并转换为numpy数组)
            # 选择第一个掩码(通常是最佳结果)
            return masks[0].cpu().numpy(), True
        except Exception as e:
            print(f"通过点分割时出错: {e}")
            return None, False
    
    def show_mask(self, mask, ax, color=None, alpha=0.5):
        """在图像上显示分割掩码"""
        if color is None:
            color = np.array([30/255, 144/255, 255/255, alpha])
        else:
            color = np.array(color + [alpha])
            
        h, w = mask.shape[-2:]
        mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
        ax.imshow(mask_image)
    
    def show_mask_border(self, mask, ax, color=None, linewidth=2):
        """只显示分割掩码的边框"""
        if color is None:
            color = [0, 1, 0]  # 默认为绿色
            
        h, w = mask.shape[-2:]
        mask_2d = mask.reshape(h, w)
        mask_uint8 = (mask_2d * 255).astype(np.uint8)
        
        # 查找轮廓
        contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        
        # 绘制轮廓
        for contour in contours:
            contour = contour.squeeze().astype(float)
            if len(contour.shape) == 1:
                contour = contour.reshape(-1, 2)
            if contour.shape[0] > 2:
                ax.fill(contour[:, 0], contour[:, 1], edgecolor=color, facecolor='none', linewidth=linewidth)


def main():
    """主函数,演示MobileSAM的基本使用方法"""
    if not MOBILE_SAM_AVAILABLE:
        print("请先安装MobileSAM库: pip install mobile-sam")
        return
    
    # 创建MobileSAM分割器实例
    segmenter = MobileSAMSegmenter()
    
    # 如果预测器初始化失败,尝试下载模型的提示
    if segmenter.predictor is None:
        print("\n模型初始化失败。请手动下载模型文件并放置在正确位置:")
        print("1. 访问 https://github.com/ChaoningZhang/MobileSAM/releases/download/v0.0.1/mobile_sam.pt")
        print("2. 下载模型文件并保存到当前目录")
        print("3. 重新运行此脚本")
        return
    
    # 使用示例图像(可以替换为您自己的图像路径)
    # 这里使用一个示例图像路径,实际使用时请替换
    image_path = "sample_image.jpg"
    
    # 如果示例图像不存在,创建一个简单的测试图像
    if not os.path.exists(image_path):
        print(f"未找到示例图像 {image_path},创建测试图像...")
        create_test_image(image_path)
    
    # 设置图像
    if not segmenter.set_image(image_path):
        print("设置图像失败")
        return
    
    # 示例1: 通过边界框进行分割
    print("\n示例1: 通过边界框进行分割")
    # 假设我们有一个感兴趣区域的边界框 [x1, y1, x2, y2]
    # 注意:这里的坐标是示例,实际使用时需要根据您的图像调整
    box = [100, 100, 300, 300]  # 假设的边界框坐标
    
    # 执行分割
    mask, success = segmenter.segment_by_box(box)
    
    if success and mask is not None:
        # 显示结果
        show_result(image_path, mask, box=box, title="通过边界框分割结果")
    
    # 示例2: 通过点进行分割
    print("\n示例2: 通过点进行分割")
    # 假设我们有一些点坐标和标签
    # 注意:这里的坐标是示例,实际使用时需要根据您的图像调整
    point_coords = [[200, 200]]  # 点坐标
    point_labels = [1]  # 1表示前景,0表示背景
    
    # 执行分割
    mask, success = segmenter.segment_by_points(point_coords, point_labels)
    
    if success and mask is not None:
        # 显示结果
        show_result(image_path, mask, points=(point_coords, point_labels), title="通过点分割结果")
    
    print("\n演示完成!您可以按任意键关闭图像窗口。")
    plt.show(block=True)


def create_test_image(image_path):
    """创建一个简单的测试图像用于演示"""
    try:
        # 创建一个白色背景的图像
        image = np.ones((500, 500, 3), dtype=np.uint8) * 255
        
        # 在图像中央绘制一个蓝色圆形
        cv2.circle(image, (250, 250), 100, (255, 0, 0), -1)
        
        # 在蓝色圆形周围绘制一些红色小圆点
        for i in range(12):
            angle = i * 30
            rad = np.radians(angle)
            x = int(250 + 150 * np.cos(rad))
            y = int(250 + 150 * np.sin(rad))
            cv2.circle(image, (x, y), 10, (0, 0, 255), -1)
        
        # 保存图像
        cv2.imwrite(image_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
        print(f"已创建测试图像: {image_path}")
    except Exception as e:
        print(f"创建测试图像时出错: {e}")


def show_result(image_path, mask, box=None, points=None, title="分割结果"):
    """显示分割结果"""
    try:
        # 读取原始图像
        image = Image.open(image_path)
        image_np = np.array(image)
        
        # 创建图形
        plt.figure(figsize=(10, 10))
        plt.imshow(image_np)
        plt.title(title)
        plt.axis('off')
        
        # 创建MobileSAM分割器实例用于显示
        segmenter = MobileSAMSegmenter()
        
        # 显示掩码
        segmenter.show_mask(mask, plt.gca())
        segmenter.show_mask_border(mask, plt.gca())
        
        # 显示边界框(如果提供)
        if box is not None:
            x1, y1, x2, y2 = box
            plt.gca().add_patch(plt.Rectangle((x1, y1), x2-x1, y2-y1, edgecolor='red', fill=False, linewidth=2))
        
        # 显示点(如果提供)
        if points is not None:
            coords, labels = points
            for i, (x, y) in enumerate(coords):
                color = 'green' if labels[i] == 1 else 'red'
                plt.scatter(x, y, color=color, s=100, edgecolor='white', linewidth=1)
        
        plt.tight_layout()
        plt.draw()  # 绘制但不阻塞
    except Exception as e:
        print(f"显示结果时出错: {e}")


if __name__ == "__main__":
    main()

五、代码解析

上面的代码示例提供了一个完整的MobileSAM使用框架,下面我们对代码的关键部分进行解析:

1. MobileSAMSegmenter类

这个类封装了MobileSAM模型的加载和使用,主要包括以下方法:

  • __init__: 初始化分割器,设置模型路径、类型和运行设备
  • _init_predictor: 加载模型权重并创建预测器实例
  • set_image: 设置要分割的图像,可以接受文件路径、PIL Image或numpy数组
  • segment_by_box: 通过边界框进行分割
  • segment_by_points: 通过点进行分割
  • show_maskshow_mask_border: 用于可视化分割结果

2. 主要功能演示

main函数中,我们演示了两种主要的分割方式:

  • 通过边界框进行分割:指定一个矩形区域,MobileSAM会分割出该区域内的对象
  • 通过点进行分割:指定一个或多个点(前景或背景),MobileSAM会根据这些点进行分割

3. 测试图像生成

为了方便演示,代码还提供了create_test_image函数,用于创建一个简单的测试图像。如果指定的图像文件不存在,程序会自动创建这个测试图像。

4. 结果可视化

show_result函数用于显示分割结果,包括原始图像、分割掩码、边界框和点标记等。

六、MobileSAM的应用场景

MobileSAM的轻量级特性使其适用于多种应用场景,特别是在资源受限的环境中:

1. 移动端实时图像处理

由于MobileSAM体积小、推理速度快,可以在移动设备上实现实时图像分割,为移动应用提供强大的视觉功能。

2. 边缘计算设备

在边缘计算设备上,MobileSAM可以高效运行,实现本地化的图像分割,减少对云端服务器的依赖,降低延迟和带宽消耗。

3. 视频分析

MobileSAM的高效推理能力使其适合处理视频流,实现实时视频分割和分析,可应用于监控、自动驾驶等领域。

4. 增强现实

在增强现实应用中,MobileSAM可以快速分割出场景中的对象,为虚拟物体与现实场景的融合提供基础。

5. 医疗影像分析

在医疗影像分析中,MobileSAM可以帮助医生快速分割和识别病变区域,提高诊断效率。

七、注意事项与优化建议

在使用MobileSAM时,有一些注意事项和优化建议可以帮助你获得更好的效果:

1. 模型文件路径

确保正确设置MobileSAM模型文件的路径。如果模型文件不存在,程序会给出下载提示。

2. 图像预处理

对于不同分辨率和格式的图像,MobileSAM可能需要不同的预处理方法。代码中提供了基本的预处理功能,但在实际应用中可能需要根据具体情况进行调整。

3. 分割参数调整

segment_by_boxsegment_by_points方法中,有一些参数可以调整以获得更好的分割效果,例如multimask_output参数可以控制是否生成多个掩码。

4. 性能优化

  • 对于需要处理大量图像的场景,可以考虑将模型加载和图像预处理部分进行优化,减少重复计算
  • 对于实时应用,可以考虑降低输入图像的分辨率,以提高处理速度
  • 在GPU设备上运行可以显著提高推理速度

5. 错误处理

代码中包含了基本的错误处理机制,但在实际应用中可能需要根据具体情况进行扩展,以提高程序的稳定性和可靠性。

八、总结

MobileSAM作为SAM的轻量级版本,通过模型压缩技术实现了体积和速度的大幅优化,同时保持了良好的分割精度。本文介绍了MobileSAM的特点、安装配置方法,并通过一个完整的代码示例展示了如何使用MobileSAM进行图像分割。

与传统的图像分割方法相比,MobileSAM具有以下优势:

  1. 轻量级:参数量仅为6.8M,适合在资源受限环境中运行
  2. 高效:推理速度快,可以实现实时处理
  3. 易用:提供了简单直观的API接口,方便集成到现有项目中
  4. 通用:可以处理各种类型的图像和对象,无需针对特定任务进行训练

随着移动计算和边缘计算的发展,MobileSAM有望在更多领域得到应用,为计算机视觉技术的普及和推广做出贡献。

如果你对MobileSAM感兴趣,不妨尝试在自己的项目中使用它,相信它会给你带来惊喜!

代码文件链接:https://pan.quark.cn/s/75f131a0a304

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐