书接上回LLAVA代码阅读:train.py-CSDN博客

make_supervised_data_moudle也是train.py文件的一部分

train函数中调用了make_supervised_data_module函数来构造数据模块,可见这是数据集构造部分的最高接口,通过拆解这个函数大概就能看清数据处理部分的全貌

Make_supervised_data_moudle

def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
                                data_path=data_args.data_path,
                                data_args=data_args)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset,
                eval_dataset=None,
                data_collator=data_collator)

此函数一共做了两件事:

1.调用LazySupervisedDataset来构建数据集,这里所谓Lazy,即当数据被需要时才加载到内存中

2.调用DataCollatorForSupervisedDataset来构建数据处理器,主要功能是将数据批次化并做一些额外的处理

LazySupervisedDataset类

初始化数据

class LazySupervisedDataset(Dataset)
    def __init__(self, data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 data_args: DataArguments):
        super(LazySupervisedDataset, self).__init__()
        list_data_dict = json.load(open(data_path, "r"))

        rank0_print("Formatting inputs...Skip in lazy mode")
        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.data_args = data_args

    def __len__(self):
        return len(self.list_data_dict)

初始化:调用父类(DataSet)初始化方法,并加载json文件内容,获得一个字典列表(一条数据是一个字典)

__len__:获取数据集大小

单条数据格式如下

{
 "id": 0,
 "image": "llava_image/00453/004539375.jpg",
 "conversations": [
      {
 "from": "human",
 "value": "<image>\nRender a clear and concise summary of the photo."
      },
      {
 "from": "gpt",
 "value": "select luxury furniture 3 - inch gel memory foam mattress topper"
      }
    ]
  }

 定义工具方法

@property
    def lengths(self):
        length_list = []
        for sample in self.list_data_dict:
            img_tokens = 128 if 'image' in sample else 0
            length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
        return length_list

    @property
    def modality_lengths(self):
        length_list = []
        for sample in self.list_data_dict:
            cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
            cur_len = cur_len if 'image' in sample else -cur_len
            length_list.append(cur_len)
        return length_list

定义两种获取文本长度的方式,第一种是直接给出文本长度如果有图像则预留128的空间,第二种则是不考虑图像的空间,并以返回数的正负性来标识是否存在图像。

这两个方法貌似没有用到

核心方法:getitem

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        sources = self.list_data_dict[i]
        if isinstance(i, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME
        if 'image' in sources[0]:
            image_file = self.list_data_dict[i]['image']
            image_folder = self.data_args.image_folder
            processor = self.data_args.image_processor
            image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
            if self.data_args.image_aspect_ratio == 'pad':
                def expand2square(pil_img, background_color):
                    width, height = pil_img.size
                    if width == height:
                        return pil_img
                    elif width > height:
                        result = Image.new(pil_img.mode, (width, width), background_color)
                        result.paste(pil_img, (0, (width - height) // 2))
                        return result
                    else:
                        result = Image.new(pil_img.mode, (height, height), background_color)
                        result.paste(pil_img, ((height - width) // 2, 0))
                        return result
                image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
                image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            else:
                image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            sources = preprocess_multimodal(
                copy.deepcopy([e["conversations"] for e in sources]),
                self.data_args)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])
        data_dict = preprocess(
            sources,
            self.tokenizer,
            has_image=('image' in self.list_data_dict[i]))
        if isinstance(i, int):
            data_dict = dict(input_ids=data_dict["input_ids"][0],
                             labels=data_dict["labels"][0])

        # image exist in the data
        if 'image' in self.list_data_dict[i]:
            data_dict['image'] = image
        elif self.data_args.is_multimodal:
            # image does not exist in the data, but the model is multimodal
            crop_size = self.data_args.image_processor.crop_size
            data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
        return data_dict

分析: 

前面的if和assert我没看懂他在干什么,我认为是一个无效操作,猜测是为未来的开发做准备?

Part1:

image_file和Image_folder拼接起来就是图像系统路径,初始化图像处理器,打开图像获得Image

Part2:

对图像进行填充后调用图像处理器processer进行处理获得新的image,并调用自定义的preprocess_multimodel预处理获得sources(如果数据中无image则直接复制对话文本)

Part3:

进一步调用自定义preprocess进行预处理,获得data_dict,并提取其中的input_ids和labels获得新的data_dict

inputs_ids是对对话内容进行分词

labels是一种“有效内容”的标记,大小和input_ids是一样的,只是将模型不需要关注的部分用一个特殊符号注释掉了  

具体可以去关注preprocess函数,都是一些非常细节的东西,这里就先不看了

Part4:

为data_dict添加图像

总而言之,调用getitem方法获得的是一个字典,里面有input_ids,labels,image三个字段

DataCollatorForSupervisedDataset类

先放总体代码

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances]
                                  for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels,
                                                 batch_first=True,
                                                 padding_value=IGNORE_INDEX)
        input_ids = input_ids[:, :self.tokenizer.model_max_length]
        labels = labels[:, :self.tokenizer.model_max_length]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

        if 'image' in instances[0]:
            images = [instance['image'] for instance in instances]
            if all(x is not None and x.shape == images[0].shape for x in images):
                batch['images'] = torch.stack(images)
            else:
                batch['images'] = images

        return batch

总体来看,这个类干的活就是提取input_ids和labels,并对其进行填充,同时把mask返回出来。并当所有图像大小都一样时将它们合成一个大张量。

总之,这个类为将上述data_dict批量化(就是把一堆dict拼在一起),并加了一个attention_mask,并返回了一个batch

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐