MobileSAM实战:轻量级图像分割模型的快速上手指南
MobileSAM作为SAM的轻量级版本,通过模型压缩技术实现了体积和速度的大幅优化,同时保持了良好的分割精度。本文介绍了MobileSAM的特点、安装配置方法,并通过一个完整的代码示例展示了如何使用MobileSAM进行图像分割。轻量级:参数量仅为6.8M,适合在资源受限环境中运行高效:推理速度快,可以实现实时处理易用:提供了简单直观的API接口,方便集成到现有项目中通用:可以处理各种类型的图像
MobileSAM实战:轻量级图像分割模型的快速上手指南

一、引言
图像分割是计算机视觉领域的重要任务,它将图像分割成不同的区域,以便进一步分析和处理。随着深度学习技术的发展,图像分割模型的性能不断提升,但同时也面临着模型体积大、推理速度慢等问题,限制了在移动设备和边缘计算场景下的应用。
Meta公司开源的Segment Anything Model (SAM) 是一个强大的图像分割模型,但由于其参数量大(约9.1B),在资源受限的环境中难以高效运行。为此,研究人员开发了MobileSAM,这是SAM的轻量级版本,通过模型压缩技术将参数量减少到只有6.8M,同时保持了相似的分割质量。
本文将带你深入了解MobileSAM的特点、安装配置方法,并通过实战代码示例展示如何在实际项目中使用MobileSAM进行图像分割。
二、MobileSAM的特点
1. 超轻量级模型
MobileSAM的最大特点是其超轻量级的模型体积。与原始SAM相比,MobileSAM通过以下技术实现了模型压缩:
- 将ViT-H骨干网络替换为更轻量级的ViT-Tiny
- 对模型结构进行优化,减少冗余计算
- 在保持分割精度的前提下,降低模型复杂度
这些优化使得MobileSAM的参数量仅为6.8M,不到原始SAM的1/1000,同时推理速度提升了数十倍。
2. 保持良好的分割精度
尽管模型体积大幅减小,MobileSAM仍然保持了良好的分割精度。实验结果表明,MobileSAM在多个标准数据集上的表现与原始SAM相当接近,能够满足大多数应用场景的需求。
3. 支持多种分割方式
MobileSAM支持与原始SAM相同的三种分割方式:
- 点提示分割:通过指定图像中的点(前景或背景)进行分割
- 边界框分割:通过指定边界框进行分割
- 掩码提示分割:通过已有掩码进行分割
4. 适用于移动端和边缘设备
由于其轻量级特性,MobileSAM非常适合在移动端和边缘设备上运行,为实时图像分割应用提供了可能。
三、环境配置与安装
要使用MobileSAM,我们需要先配置相应的环境并安装必要的依赖库。以下是详细的安装步骤:
1. 安装Python环境
MobileSAM需要Python 3.8或更高版本。建议使用虚拟环境来隔离项目依赖:
# 创建虚拟环境
python -m venv mobile_sam_env
# 激活虚拟环境
# Windows系统
source mobile_sam_env/Scripts/activate
# Linux/Mac系统
source mobile_sam_env/bin/activate
2. 安装必要的依赖库
MobileSAM依赖以下库:
pip install torch torchvision opencv-python matplotlib pillow
pip install mobile-sam
3. 下载预训练模型
MobileSAM需要预训练模型文件才能正常工作。我们可以从GitHub仓库下载预训练模型:
# 下载MobileSAM预训练模型
git clone https://github.com/ChaoningZhang/MobileSAM.git
cd MobileSAM
wget https://github.com/ChaoningZhang/MobileSAM/releases/download/v0.0.1/mobile_sam.pt
或者直接从发布页面下载:https://github.com/ChaoningZhang/MobileSAM/releases/download/v0.0.1/mobile_sam.pt
四、MobileSAM实战代码示例
下面我们将通过一个完整的代码示例来展示如何使用MobileSAM进行图像分割。这个示例将包含模型加载、图像预处理、不同分割方式的实现以及结果可视化等功能。
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
MobileSAM 使用示例
这是一个简化的 MobileSAM 模型使用示例,演示如何加载模型、处理图像并进行目标分割,
不依赖大模型生成边界框,而是通过交互式点击或预设坐标进行分割。
MobileSAM 是 SAM (Segment Anything Model) 的轻量级版本,拥有更快的推理速度,
同时保持了相似的分割质量,非常适合在资源有限的环境中使用。
"""
import os
import io
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
from PIL import Image
import warnings
# 设置中文字体显示
plt.rcParams["font.family"] = ["SimHei", "WenQuanYi Micro Hei", "Heiti TC", "Arial Unicode MS", "Microsoft YaHei", "sans-serif"]
# 忽略MobileSAM模型注册警告
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting tiny_vit_.* in registry")
# 忽略timm库弃用警告
warnings.filterwarnings("ignore", category=FutureWarning, message="Importing from timm\.models\.layers is deprecated")
# 尝试导入MobileSAM库
try:
from mobile_sam import sam_model_registry, SamPredictor
MOBILE_SAM_AVAILABLE = True
except ImportError:
print("未找到MobileSAM库。请使用 'pip install mobile-sam' 安装。")
MOBILE_SAM_AVAILABLE = False
class MobileSAMSegmenter:
"""MobileSAM分割器类,封装MobileSAM模型的加载和使用"""
def __init__(self, model_path="mobile_sam.pt", model_type="vit_t", device=None):
"""
初始化MobileSAM分割器
参数:
model_path: MobileSAM模型权重文件路径
model_type: 模型类型,默认为"vit_t"(tiny vision transformer)
device: 运行设备,如"cuda"或"cpu",默认为自动选择
"""
self.model_path = model_path
self.model_type = model_type
self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
self.predictor = None
# 检查模型文件是否存在
if not os.path.exists(self.model_path):
# 尝试在当前目录查找
current_dir = os.path.dirname(os.path.abspath(__file__))
self.model_path = os.path.join(current_dir, self.model_path)
if not os.path.exists(self.model_path):
print(f"警告: 未找到MobileSAM模型文件 '{self.model_path}'")
print("请下载模型文件并放置在正确的路径下")
print("下载链接: https://github.com/ChaoningZhang/MobileSAM/releases/download/v0.0.1/mobile_sam.pt")
# 初始化预测器
if MOBILE_SAM_AVAILABLE:
self._init_predictor()
def _init_predictor(self):
"""初始化MobileSAM预测器"""
if not os.path.exists(self.model_path):
return
try:
# 加载模型权重
mobile_sam = sam_model_registry[self.model_type](checkpoint=self.model_path)
# 将模型移至指定设备
mobile_sam.to(device=self.device)
# 设置为评估模式
mobile_sam.eval()
# 创建预测器
self.predictor = SamPredictor(mobile_sam)
print(f"MobileSAM模型已成功加载到{self.device}设备")
except Exception as e:
print(f"加载MobileSAM模型时出错: {e}")
self.predictor = None
def set_image(self, image):
"""
设置要分割的图像
参数:
image: 图像数据,可以是PIL Image、numpy数组或文件路径
返回:
bool: 设置是否成功
"""
if not MOBILE_SAM_AVAILABLE or self.predictor is None:
return False
try:
# 处理不同类型的图像输入
if isinstance(image, str):
# 如果是文件路径
if not os.path.exists(image):
print(f"图像文件不存在: {image}")
return False
image = Image.open(image)
image_np = np.array(image)
elif isinstance(image, Image.Image):
# 如果是PIL Image
image_np = np.array(image)
elif isinstance(image, np.ndarray):
# 如果是numpy数组
image_np = image
else:
print("不支持的图像类型")
return False
# 如果是RGBA格式,转换为RGB
if image_np.shape[-1] == 4:
image_np = image_np[..., :3]
# 设置图像到预测器
self.predictor.set_image(image_np)
return True
except Exception as e:
print(f"设置图像时出错: {e}")
return False
def segment_by_box(self, box):
"""
通过边界框进行分割
参数:
box: 边界框坐标,格式为 [x1, y1, x2, y2]
返回:
tuple: (mask, bool) - 分割掩码和是否成功
"""
if not MOBILE_SAM_AVAILABLE or self.predictor is None:
return None, False
try:
# 将边界框转换为张量
input_box = torch.tensor([box], dtype=torch.float32, device=self.predictor.device)
# 应用坐标变换
transformed_box = self.predictor.transform.apply_boxes_torch(input_box,
self.predictor.original_size)
# 预测掩码
masks, _, _ = self.predictor.predict_torch(
point_coords=None,
point_labels=None,
boxes=transformed_box,
multimask_output=False,
)
# 返回掩码(从GPU移至CPU并转换为numpy数组)
return masks[0].cpu().numpy(), True
except Exception as e:
print(f"通过边界框分割时出错: {e}")
return None, False
def segment_by_points(self, point_coords, point_labels):
"""
通过点进行分割
参数:
point_coords: 点坐标列表,格式为 [[x1, y1], [x2, y2], ...]
point_labels: 点标签列表,1表示前景,0表示背景
返回:
tuple: (mask, bool) - 分割掩码和是否成功
"""
if not MOBILE_SAM_AVAILABLE or self.predictor is None:
return None, False
try:
# 将点坐标转换为张量
input_points = torch.tensor(point_coords, dtype=torch.float32, device=self.predictor.device)
input_labels = torch.tensor(point_labels, dtype=torch.int32, device=self.predictor.device)
# 应用坐标变换
transformed_points = self.predictor.transform.apply_coords_torch(
input_points, self.predictor.original_size
)
# 预测掩码
masks, _, _ = self.predictor.predict_torch(
point_coords=transformed_points,
point_labels=input_labels,
boxes=None,
multimask_output=True,
)
# 返回掩码(从GPU移至CPU并转换为numpy数组)
# 选择第一个掩码(通常是最佳结果)
return masks[0].cpu().numpy(), True
except Exception as e:
print(f"通过点分割时出错: {e}")
return None, False
def show_mask(self, mask, ax, color=None, alpha=0.5):
"""在图像上显示分割掩码"""
if color is None:
color = np.array([30/255, 144/255, 255/255, alpha])
else:
color = np.array(color + [alpha])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
def show_mask_border(self, mask, ax, color=None, linewidth=2):
"""只显示分割掩码的边框"""
if color is None:
color = [0, 1, 0] # 默认为绿色
h, w = mask.shape[-2:]
mask_2d = mask.reshape(h, w)
mask_uint8 = (mask_2d * 255).astype(np.uint8)
# 查找轮廓
contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# 绘制轮廓
for contour in contours:
contour = contour.squeeze().astype(float)
if len(contour.shape) == 1:
contour = contour.reshape(-1, 2)
if contour.shape[0] > 2:
ax.fill(contour[:, 0], contour[:, 1], edgecolor=color, facecolor='none', linewidth=linewidth)
def main():
"""主函数,演示MobileSAM的基本使用方法"""
if not MOBILE_SAM_AVAILABLE:
print("请先安装MobileSAM库: pip install mobile-sam")
return
# 创建MobileSAM分割器实例
segmenter = MobileSAMSegmenter()
# 如果预测器初始化失败,尝试下载模型的提示
if segmenter.predictor is None:
print("\n模型初始化失败。请手动下载模型文件并放置在正确位置:")
print("1. 访问 https://github.com/ChaoningZhang/MobileSAM/releases/download/v0.0.1/mobile_sam.pt")
print("2. 下载模型文件并保存到当前目录")
print("3. 重新运行此脚本")
return
# 使用示例图像(可以替换为您自己的图像路径)
# 这里使用一个示例图像路径,实际使用时请替换
image_path = "sample_image.jpg"
# 如果示例图像不存在,创建一个简单的测试图像
if not os.path.exists(image_path):
print(f"未找到示例图像 {image_path},创建测试图像...")
create_test_image(image_path)
# 设置图像
if not segmenter.set_image(image_path):
print("设置图像失败")
return
# 示例1: 通过边界框进行分割
print("\n示例1: 通过边界框进行分割")
# 假设我们有一个感兴趣区域的边界框 [x1, y1, x2, y2]
# 注意:这里的坐标是示例,实际使用时需要根据您的图像调整
box = [100, 100, 300, 300] # 假设的边界框坐标
# 执行分割
mask, success = segmenter.segment_by_box(box)
if success and mask is not None:
# 显示结果
show_result(image_path, mask, box=box, title="通过边界框分割结果")
# 示例2: 通过点进行分割
print("\n示例2: 通过点进行分割")
# 假设我们有一些点坐标和标签
# 注意:这里的坐标是示例,实际使用时需要根据您的图像调整
point_coords = [[200, 200]] # 点坐标
point_labels = [1] # 1表示前景,0表示背景
# 执行分割
mask, success = segmenter.segment_by_points(point_coords, point_labels)
if success and mask is not None:
# 显示结果
show_result(image_path, mask, points=(point_coords, point_labels), title="通过点分割结果")
print("\n演示完成!您可以按任意键关闭图像窗口。")
plt.show(block=True)
def create_test_image(image_path):
"""创建一个简单的测试图像用于演示"""
try:
# 创建一个白色背景的图像
image = np.ones((500, 500, 3), dtype=np.uint8) * 255
# 在图像中央绘制一个蓝色圆形
cv2.circle(image, (250, 250), 100, (255, 0, 0), -1)
# 在蓝色圆形周围绘制一些红色小圆点
for i in range(12):
angle = i * 30
rad = np.radians(angle)
x = int(250 + 150 * np.cos(rad))
y = int(250 + 150 * np.sin(rad))
cv2.circle(image, (x, y), 10, (0, 0, 255), -1)
# 保存图像
cv2.imwrite(image_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
print(f"已创建测试图像: {image_path}")
except Exception as e:
print(f"创建测试图像时出错: {e}")
def show_result(image_path, mask, box=None, points=None, title="分割结果"):
"""显示分割结果"""
try:
# 读取原始图像
image = Image.open(image_path)
image_np = np.array(image)
# 创建图形
plt.figure(figsize=(10, 10))
plt.imshow(image_np)
plt.title(title)
plt.axis('off')
# 创建MobileSAM分割器实例用于显示
segmenter = MobileSAMSegmenter()
# 显示掩码
segmenter.show_mask(mask, plt.gca())
segmenter.show_mask_border(mask, plt.gca())
# 显示边界框(如果提供)
if box is not None:
x1, y1, x2, y2 = box
plt.gca().add_patch(plt.Rectangle((x1, y1), x2-x1, y2-y1, edgecolor='red', fill=False, linewidth=2))
# 显示点(如果提供)
if points is not None:
coords, labels = points
for i, (x, y) in enumerate(coords):
color = 'green' if labels[i] == 1 else 'red'
plt.scatter(x, y, color=color, s=100, edgecolor='white', linewidth=1)
plt.tight_layout()
plt.draw() # 绘制但不阻塞
except Exception as e:
print(f"显示结果时出错: {e}")
if __name__ == "__main__":
main()
五、代码解析
上面的代码示例提供了一个完整的MobileSAM使用框架,下面我们对代码的关键部分进行解析:
1. MobileSAMSegmenter类
这个类封装了MobileSAM模型的加载和使用,主要包括以下方法:
__init__: 初始化分割器,设置模型路径、类型和运行设备_init_predictor: 加载模型权重并创建预测器实例set_image: 设置要分割的图像,可以接受文件路径、PIL Image或numpy数组segment_by_box: 通过边界框进行分割segment_by_points: 通过点进行分割show_mask和show_mask_border: 用于可视化分割结果
2. 主要功能演示
在main函数中,我们演示了两种主要的分割方式:
- 通过边界框进行分割:指定一个矩形区域,MobileSAM会分割出该区域内的对象
- 通过点进行分割:指定一个或多个点(前景或背景),MobileSAM会根据这些点进行分割
3. 测试图像生成
为了方便演示,代码还提供了create_test_image函数,用于创建一个简单的测试图像。如果指定的图像文件不存在,程序会自动创建这个测试图像。
4. 结果可视化
show_result函数用于显示分割结果,包括原始图像、分割掩码、边界框和点标记等。
六、MobileSAM的应用场景
MobileSAM的轻量级特性使其适用于多种应用场景,特别是在资源受限的环境中:
1. 移动端实时图像处理
由于MobileSAM体积小、推理速度快,可以在移动设备上实现实时图像分割,为移动应用提供强大的视觉功能。
2. 边缘计算设备
在边缘计算设备上,MobileSAM可以高效运行,实现本地化的图像分割,减少对云端服务器的依赖,降低延迟和带宽消耗。
3. 视频分析
MobileSAM的高效推理能力使其适合处理视频流,实现实时视频分割和分析,可应用于监控、自动驾驶等领域。
4. 增强现实
在增强现实应用中,MobileSAM可以快速分割出场景中的对象,为虚拟物体与现实场景的融合提供基础。
5. 医疗影像分析
在医疗影像分析中,MobileSAM可以帮助医生快速分割和识别病变区域,提高诊断效率。
七、注意事项与优化建议
在使用MobileSAM时,有一些注意事项和优化建议可以帮助你获得更好的效果:
1. 模型文件路径
确保正确设置MobileSAM模型文件的路径。如果模型文件不存在,程序会给出下载提示。
2. 图像预处理
对于不同分辨率和格式的图像,MobileSAM可能需要不同的预处理方法。代码中提供了基本的预处理功能,但在实际应用中可能需要根据具体情况进行调整。
3. 分割参数调整
在segment_by_box和segment_by_points方法中,有一些参数可以调整以获得更好的分割效果,例如multimask_output参数可以控制是否生成多个掩码。
4. 性能优化
- 对于需要处理大量图像的场景,可以考虑将模型加载和图像预处理部分进行优化,减少重复计算
- 对于实时应用,可以考虑降低输入图像的分辨率,以提高处理速度
- 在GPU设备上运行可以显著提高推理速度
5. 错误处理
代码中包含了基本的错误处理机制,但在实际应用中可能需要根据具体情况进行扩展,以提高程序的稳定性和可靠性。
八、总结
MobileSAM作为SAM的轻量级版本,通过模型压缩技术实现了体积和速度的大幅优化,同时保持了良好的分割精度。本文介绍了MobileSAM的特点、安装配置方法,并通过一个完整的代码示例展示了如何使用MobileSAM进行图像分割。
与传统的图像分割方法相比,MobileSAM具有以下优势:
- 轻量级:参数量仅为6.8M,适合在资源受限环境中运行
- 高效:推理速度快,可以实现实时处理
- 易用:提供了简单直观的API接口,方便集成到现有项目中
- 通用:可以处理各种类型的图像和对象,无需针对特定任务进行训练
随着移动计算和边缘计算的发展,MobileSAM有望在更多领域得到应用,为计算机视觉技术的普及和推广做出贡献。
如果你对MobileSAM感兴趣,不妨尝试在自己的项目中使用它,相信它会给你带来惊喜!
代码文件链接:https://pan.quark.cn/s/75f131a0a304
魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐


所有评论(0)