数据集处理中的segmentor和vos_raw_dataset见上一篇【SAM2代码解析】数据集处理1

在这里插入图片描述

3. vos_sampler.py

3.1 基础模块类

1)SampledFramesAndObjects数据类

定义了视频帧数索引和对象id索引
在这里插入图片描述

2)VOSSampler 抽象基类

在这里插入图片描述

3.2 子类RandomUniformSampler

用于在视频对象分割任务中随机均匀地采样帧和对象

  • 初始化构造
    在这里插入图片描述
  • sample方法
    • 输入参数:video–包含帧信息的VOSVideo对象,segment_loader–用于加载视频帧掩码的加载器

    • 检查帧数量:确保视频帧数量足够采样num_frames帧,如果不够则抛出异常
      在这里插入图片描述

    • 随机选择起始帧进行采样:随机选择一个起始帧索引start,确保从该索引开始的连续num_frames帧都在视频范围内
      在这里插入图片描述

    • 可选的反转帧顺序:以概率反转采样帧的顺序
      在这里插入图片描述

    • 加载第一帧的对象掩码,并检查掩码中是否包含被检测对象
      在这里插入图片描述
      使用segment_loader的load方法,得到对象mask字典,字典的key是调色盘掩码png图像中,不同对象自身对应的像素值,字典的value是将不同对象分离后得到的单对象mask掩码,掩码的值是True和False。
      在这里插入图片描述
      这里的逻辑是,我们使用segment_load方法得到的mask是true.false填充的,此时直接计算sum,若和>0,则说明存在obj。将mask对应的key添加进变量visible中,若visible中的长度>0,则说明是有效采样,退出循环。

    • 随机采样对象ID:从可见对象ID列表中随机采样max_num_objects个对象ID,如果可见对象ID少于最大值,则全部采样
      在这里插入图片描述

    • 返回成基础数据类型
      在这里插入图片描述
      在这里插入图片描述

3.3 子类EvalSampler

# VOSSampler的子类
class EvalSampler(VOSSampler):
    """
    VOS Sampler for evaluation: sampling all the frames and all the objects in a video
    """

    def __init__(
        self,
    ):
        super().__init__()

    def sample(self, video, segment_loader, epoch=None):
        """
        Sampling all the frames and all the objects
        """
        if self.sort_frames:
            # ordered by frame id,按帧号排序
            frames = sorted(video.frames, key=lambda x: x.frame_idx)
        else:
            # use the original order
            frames = video.frames
        # 加载首帧的所有对象ID
        object_ids = segment_loader.load(frames[0].frame_idx).keys()
        if len(object_ids) == 0:
            raise Exception("First frame of the video has no objects")

        # 返回所有帧和对象ID
        return SampledFramesAndObjects(frames=frames, object_ids=object_ids)

4. vos_dataset.py

1)VOSDataset方法

  • 初始化
    在这里插入图片描述
  • __get_datapoint
    • 调用sampler.sample从视频中采样帧和对象
      具体见前面
      在这里插入图片描述
    • 调用construct方法构建数据集
    • 调用transform方法增强数据集
datapoint = self.construct(video, sampled_frms_and_objs, segment_loader)
        for transform in self._transforms:
            datapoint = transform(datapoint, epoch=self.curr_epoch)
        return datapoint
  • construct 数据集构建----构建一个videodatapoint 样例去进行transforms
    • 输入参数:
      在这里插入图片描述
      在这里插入图片描述在这里插入图片描述
    • 加载图像,通过load_images 高效读取RGB图像----调用load_images方法
      在这里插入图片描述
    • 构建VideoDatapoint数据
      • 遍历采样的样本 sampled_frames
      • 处理图像数据,实例化Frame数据类型,并添加进 images列表中
      • 调用segment_loader的load方法,得到单个obj的mask张量
        在这里插入图片描述
      • 检查得到的segments,确保segments中每一个张量都不为空,即都有掩膜
      • 检查segments是否包含全部obj,若不是则表明并不是全程都可监控到对象,若一开始的config设置里,设置的是一直都有检测对象,则创建一个虚假的全0掩码,若没有设置,则不做任何操作。
      • 将segments和先前得到的读取图像等放入Frame数据类型中,并构建成一个采样images列表
        在这里插入图片描述
      • 将上面得到的images组装成videopoint类型
        在这里插入图片描述
  • load_images,高效加载图像,避免重复读取相同路径
    • 可能存在的重复读取场景:多次采样同一视频的不同片段;数据增强时需要多次访问同一原始图像
    • 1、首先创建两个参数all_images 存储最终加载的PIL图像,cache记录已加载的文件路径和索引,避免重复id
    • 2、遍历所有帧,若存在已有张量数据,则直接将该张量数据转换成pil数据
    • 3、若不存在已有张量,则查看cache中是否有记录,若有则直接从all_images中复制
    • 4、若cache中没有记录,则读取文件并在cache中更新缓存的图像信息
# 高效加载图像,利用缓存避免重复读取相同路径
def load_images(frames):
    all_images = []
    cache = {}
    for frame in frames:
        if frame.data is None:
            # Load the frame rgb data from file
            path = frame.image_path
            if path in cache:
                all_images.append(deepcopy(all_images[cache[path]]))
                continue
            with g_pathmgr.open(path, "rb") as fopen:
                all_images.append(PILImage.open(fopen).convert("RGB"))
            cache[path] = len(all_images) - 1
        else:
            # The frame rgb data has already been loaded
            # Convert it to a PILImage
            all_images.append(tensor_2_PIL(frame.data))

    return all_images

6. transforms.py

这里进行简单的讲解+伪代码叙述逻辑

  • 水平翻转 hflip—对指定帧的图像和所有对象掩膜进行水平翻转
def hflip(datapoint, index):
    # 翻转图像
    datapoint.frames[index].data = F.hflip(datapoint.frames[index].data)
    # 翻转每个对象的掩膜
    for obj in datapoint.frames[index].objects:
        if obj.segment is not None:
            obj.segment = F.hflip(obj.segment)
    return datapoint
  • 尺寸计算 get_size_with_aspect_ratio — 根据目标尺寸和最大限制计算保持宽高比的图像尺寸
def get_size_with_aspect_ratio(image_size, size, max_size):
    w, h = image_size
    # 处理最大尺寸限制
    if max_size and (max(w,h)/min(w,h)*size > max_size):
        size = max_size * min(w,h)/max(w,h)
    # 计算新尺寸
    if w < h:
        return (size, int(size * h/w))
    else:
        return (int(size * w/h), size)
  • 调整大小 resize — 调整图像和掩膜尺寸
def resize(datapoint, index, size, max_size, square, v2):
    # 计算目标尺寸
    if square:
        size = (size, size)
    else:
        size = get_size_with_aspect_ratio(cur_size, size, max_size)
    
    # 调整图像
    if v2:
        datapoint.frames[index].data = Fv2.resize(data, size, antialias=True)
    else:
        datapoint.frames[index].data = F.resize(data, size)
    
    # 调整掩膜
    for obj in datapoint.frames[index].objects:
        obj.segment = F.resize(obj.segment[None,None], size).squeeze()
  • 填充 pad --对图像和掩膜进行填充
def pad(datapoint, index, padding, v2):
    # 图像填充
    if len(padding) == 2:
        datapoint.frames[index].data = F.pad(data, (0,0,padding[0],padding[1]))
    else:
        datapoint.frames[index].data = F.pad(data, padding)
    
    # 掩膜填充
    for obj in datapoint.frames[index].objects:
        if v2:
            obj.segment = Fv2.pad(obj.segment, padding)
        else:
            obj.segment = F.pad(obj.segment, padding)
  • RandomHorizontalFlip​ – 随机水平翻转,支持帧间一致性
class RandomHorizontalFlip:
    def __call__(self, datapoint):
        if self.consistent_transform:
            if random.random() < self.p:
                for i in range(len(datapoint.frames)):
                    datapoint = hflip(datapoint, i)
        else:
            for i in range(len(datapoint.frames)):
                if random.random() < self.p:
                    datapoint = hflip(datapoint, i)
        return datapoint
  • RandomResizeAPI—随机调整尺寸
class RandomResizeAPI:
    def __call__(self, datapoint):
        size = random.choice(self.sizes)
        for i in range(len(datapoint.frames)):
            datapoint = resize(datapoint, i, size)
        return datapoint
  • ColorJitter — 随机调整颜色
  • RandomAffine — 随机仿射变换(旋转、平移、缩放、剪切)
  • RandomMosaicVideoAPI​ — 马赛克增强,将图像分割为网格并随机排列。
Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐