【零废话系列】python pytorch 读取 Dogs In The Wild 数据集(附数据集网址)
零废话,直接上数据集链接和代码
·
数据集下载网址
图像分类系列【Dogs-in-the-wild 犬类图像数据集】
pytorch读取数据集
import os
import json
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
class DogsInTheWildDataset(Dataset):
    """Map-style dataset that reads images and labels from JSON annotation files.

    The annotations file must contain an ``'annotations'`` list whose entries
    carry ``'category id'`` and ``'name'`` keys; the categories file must
    contain a ``'categories'`` list whose entries carry a ``'category id'`` key.

    NOTE(review): the key is spelled ``'category id'`` (with a space) exactly
    as in the original code -- confirm against the downloaded JSON files.

    Args:
        annotations_file: Path to the JSON file listing the samples.
        categories_file: Path to the JSON file listing the categories.
        img_dir: Directory that the annotation ``'name'`` values are joined to.
        transform: Optional callable applied to each RGB ``PIL.Image``.
            When ``None`` the image is returned untransformed.
    """

    def __init__(self, annotations_file, categories_file, img_dir, transform=None):
        self.img_labels = self.parse_json(annotations_file)
        self.categories = self.parse_json(categories_file)
        self.img_dir = img_dir
        self.transform = transform
        # Index the category ids once so get_label does not rescan the whole
        # category list for every sample. setdefault keeps the first entry
        # for a repeated id, matching a first-match linear scan.
        self._category_ids = {}
        for category in self.categories['categories']:
            category_id = category['category id']
            self._category_ids.setdefault(category_id, category_id)

    def parse_json(self, json_file):
        """Load ``json_file`` and return the decoded object."""
        # JSON is UTF-8 by specification; being explicit avoids decode errors
        # on platforms whose default locale encoding is not UTF-8.
        with open(json_file, 'r', encoding='utf-8') as f:
            return json.load(f)

    def __len__(self):
        """Return the number of annotated samples."""
        return len(self.img_labels['annotations'])

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        Raises:
            ValueError: If the sample's category id is not in the categories file.
        """
        annotation = self.img_labels['annotations'][idx]
        category_id = annotation['category id']
        img_path = os.path.join(self.img_dir, annotation['name'])
        # convert() returns a fully loaded copy, so the underlying file can be
        # closed as soon as the with-block exits.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        label = self.get_label(category_id)
        return image, label

    def get_label(self, category_id):
        """Return ``category_id`` as a 0-d ``torch.long`` tensor.

        Raises:
            ValueError: If ``category_id`` is not listed in the categories file.
        """
        try:
            known_id = self._category_ids[category_id]
        except KeyError:
            raise ValueError(f'Category ID {category_id} not found.') from None
        return torch.from_numpy(np.array(known_id)).long()
读取的时候和平常使用 PyTorch 读取数据集的操作一致(需要先导入 `DataLoader` 并定义好 `transform`):
# DataLoader is used below but is not imported anywhere else in this snippet.
from torch.utils.data import DataLoader

# NOTE(review): `transform` is not defined in this snippet; it must be bound
# to a callable that accepts an RGB PIL image before this code runs.
# Presumably it should produce equally sized tensors so that samples can be
# batched -- confirm against the training script.
train_dataset = DogsInTheWildDataset(
    annotations_file='train.json',
    categories_file='category.json',
    img_dir='image/',
    transform=transform
)
# Shuffled batches of 256 samples, loaded by 8 worker processes.
train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=8)

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐
所有评论(0)