LLaVA Code Reading: Data Processing with make_supervised_data_module
make_supervised_data_module is also part of train.py.
The train function calls make_supervised_data_module to build the data module, which makes it the top-level interface of the dataset-construction code; by unpacking this function we can see the whole picture of the data-processing pipeline.
make_supervised_data_module
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
                                          data_path=data_args.data_path,
                                          data_args=data_args)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset,
                eval_dataset=None,
                data_collator=data_collator)
This function does exactly two things:
1. It calls LazySupervisedDataset to build the dataset; "lazy" here means a sample is only loaded into memory when it is actually needed.
2. It calls DataCollatorForSupervisedDataset to build the collator, whose main job is to batch samples together and apply some extra processing.
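For context, a minimal sketch of how train() consumes the returned dict (the LLaVATrainer name matches the LLaVA repo; the model/args setup is elided):

data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
trainer = LLaVATrainer(model=model,
                       tokenizer=tokenizer,
                       args=training_args,
                       **data_module)
trainer.train()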
The LazySupervisedDataset class
Initialization
class LazySupervisedDataset(Dataset):
    def __init__(self, data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 data_args: DataArguments):
        super(LazySupervisedDataset, self).__init__()
        list_data_dict = json.load(open(data_path, "r"))

        rank0_print("Formatting inputs...Skip in lazy mode")
        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.data_args = data_args

    def __len__(self):
        return len(self.list_data_dict)
__init__: calls the parent class (Dataset) initializer and loads the JSON file, producing a list of dicts (one dict per sample).
__len__: returns the dataset size.
A single sample looks like this:
{
    "id": 0,
    "image": "llava_image/00453/004539375.jpg",
    "conversations": [
        {
            "from": "human",
            "value": "<image>\nRender a clear and concise summary of the photo."
        },
        {
            "from": "gpt",
            "value": "select luxury furniture 3 - inch gel memory foam mattress topper"
        }
    ]
}
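A quick way to sanity-check the loaded structure (the JSON filename here is hypothetical):

import json

data = json.load(open("llava_instruct_150k.json", "r"))
print(type(data), len(data))        # <class 'list'>, one dict per sample
print(data[0]["conversations"][0])  # {'from': 'human', 'value': '<image>\n...'}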
Utility methods
@property
def lengths(self):
    length_list = []
    for sample in self.list_data_dict:
        img_tokens = 128 if 'image' in sample else 0
        length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
    return length_list

@property
def modality_lengths(self):
    length_list = []
    for sample in self.list_data_dict:
        cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
        cur_len = cur_len if 'image' in sample else -cur_len
        length_list.append(cur_len)
    return length_list
These two properties give two ways of measuring a sample's text length: lengths returns the word count, reserving 128 extra tokens when the sample contains an image; modality_lengths ignores the image budget and instead uses the sign of the value to mark modality (positive means the sample has an image, negative means text-only).
They are not referenced in train.py itself, but in the LLaVA repo modality_lengths is consumed by the modality-grouped sampler in llava/train/llava_trainer.py when group_by_modality_length is enabled, so they are not dead code.
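A sketch of how a modality-grouped sampler can use the sign convention (variable names are illustrative):

lengths = train_dataset.modality_lengths
mm_indices = [i for i, l in enumerate(lengths) if l > 0]    # samples with an image
lang_indices = [i for i, l in enumerate(lengths) if l < 0]  # text-only samples
# Sorting each pool by abs(l) lets the sampler build batches that are
# homogeneous in modality and similar in length.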
The core method: __getitem__
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
    sources = self.list_data_dict[i]
    if isinstance(i, int):
        sources = [sources]
    assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME
    if 'image' in sources[0]:
        image_file = self.list_data_dict[i]['image']
        image_folder = self.data_args.image_folder
        processor = self.data_args.image_processor
        image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
        if self.data_args.image_aspect_ratio == 'pad':
            def expand2square(pil_img, background_color):
                width, height = pil_img.size
                if width == height:
                    return pil_img
                elif width > height:
                    result = Image.new(pil_img.mode, (width, width), background_color)
                    result.paste(pil_img, (0, (width - height) // 2))
                    return result
                else:
                    result = Image.new(pil_img.mode, (height, height), background_color)
                    result.paste(pil_img, ((height - width) // 2, 0))
                    return result
            image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
            image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        else:
            image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        sources = preprocess_multimodal(
            copy.deepcopy([e["conversations"] for e in sources]),
            self.data_args)
    else:
        sources = copy.deepcopy([e["conversations"] for e in sources])
    data_dict = preprocess(
        sources,
        self.tokenizer,
        has_image=('image' in self.list_data_dict[i]))
    if isinstance(i, int):
        data_dict = dict(input_ids=data_dict["input_ids"][0],
                         labels=data_dict["labels"][0])

    # image exist in the data
    if 'image' in self.list_data_dict[i]:
        data_dict['image'] = image
    elif self.data_args.is_multimodal:
        # image does not exist in the data, but the model is multimodal
        crop_size = self.data_args.image_processor.crop_size
        data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
    return data_dict
Analysis:
The leading if/assert is not a no-op: when i is an int, the single sample dict gets wrapped into a one-element list so that the downstream preprocess_multimodal / preprocess functions, which expect a list of conversation lists, can be reused unchanged; the assert then guards that exactly one sample came through. (The FIXME comment suggests the authors themselves found the wrapping awkward.)
Part1:
Joining image_folder and image_file gives the image's path on disk; the image processor is fetched from data_args, and the file is opened with PIL (converted to RGB) to get an Image object.
Part2:
When image_aspect_ratio == 'pad', the image is padded to a square (filled with the processor's per-channel mean color) and then run through the image processor to get the pixel tensor; the conversations are then passed through the custom preprocess_multimodal to produce the new sources. (When the sample has no image, the conversations are simply deep-copied.) A standalone demo of the padding logic follows below.
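A self-contained demo of the same padding logic (the function is re-declared here because the original is a closure inside __getitem__; the mean color assumes the OpenAI CLIP processor):

from PIL import Image

def expand2square(pil_img, background_color):
    # identical logic to the closure above, folded into one branch
    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    result = Image.new(pil_img.mode, (side, side), background_color)
    result.paste(pil_img, ((side - width) // 2, (side - height) // 2))
    return result

img = Image.new("RGB", (640, 480), "white")
bg = tuple(int(x * 255) for x in (0.48145466, 0.4578275, 0.40821073))  # CLIP mean -> (122, 116, 104)
print(expand2square(img, bg).size)  # (640, 640); the original is pasted at (0, 80)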
Part3:
The custom preprocess function is then applied to get data_dict, from which input_ids and labels are extracted to form a new data_dict.
input_ids is the tokenized conversation text.
labels marks the "valid" content: it has the same shape as input_ids, but the positions the model should not learn from (the non-answer parts of the conversation) are masked out with a special value; a toy illustration follows below.
For the fine print, look into the preprocess function itself; it is all very detailed bookkeeping, so I'll skip it here.
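A toy illustration of the labels convention; IGNORE_INDEX is -100 in LLaVA's constants, which PyTorch's cross-entropy loss ignores by default (the token ids below are made up):

import torch

IGNORE_INDEX = -100

input_ids = torch.tensor([1, 319, 13563, 29871, 887, 526, 263, 2])  # fake token ids
labels = input_ids.clone()
labels[:4] = IGNORE_INDEX  # mask the prompt: loss is computed only on the answer tokens
print(labels)  # tensor([-100, -100, -100, -100,  887,  526,  263,    2])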
Part4:
The image is attached to data_dict. If the sample is text-only but the model is multimodal, an all-zero tensor of the processor's crop size is attached instead, so every sample still carries an image field.
All in all, __getitem__ returns a dict with three fields: input_ids, labels, and image.
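What a fetched sample looks like in practice (the shapes are illustrative; 3x336x336 would correspond to a CLIP ViT-L/336 image processor):

sample = train_dataset[0]
print(sample["input_ids"].shape)  # torch.Size([seq_len]), 1-D per sample
print(sample["labels"].shape)     # same shape, with -100 over the masked positions
print(sample["image"].shape)      # e.g. torch.Size([3, 336, 336])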
The DataCollatorForSupervisedDataset class
The full code first:
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances]
                                  for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels,
                                                 batch_first=True,
                                                 padding_value=IGNORE_INDEX)
        input_ids = input_ids[:, :self.tokenizer.model_max_length]
        labels = labels[:, :self.tokenizer.model_max_length]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

        if 'image' in instances[0]:
            images = [instance['image'] for instance in instances]
            if all(x is not None and x.shape == images[0].shape for x in images):
                batch['images'] = torch.stack(images)
            else:
                batch['images'] = images

        return batch
Overall, this class extracts input_ids and labels from the instances, pads each to the longest sequence in the batch (input_ids with pad_token_id, labels with IGNORE_INDEX), truncates both to model_max_length, and derives attention_mask from the non-pad positions. When all images in the batch share the same shape it stacks them into one tensor; otherwise it returns them as a list.
In short, the class batches the per-sample data_dicts described above (stitching a pile of dicts into one), adds an attention_mask, and returns the batch.
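A minimal sketch of the collator in action with a plain DataLoader (the HF Trainer does the equivalent internally; batch_size is arbitrary):

from torch.utils.data import DataLoader

collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
loader = DataLoader(train_dataset, batch_size=4, collate_fn=collator)
batch = next(iter(loader))
print(batch["input_ids"].shape)       # (4, max_len_in_this_batch)
print(batch["attention_mask"].dtype)  # torch.bool, False at pad positions
print(batch["images"].shape)          # (4, 3, H, W) when all images share one shape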
