数据集下载网址

图像分类系列【Dogs-in-the-wild 犬类图像数据集】

使用 PyTorch 读取数据集

import os
import json
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset

class DogsInTheWildDataset(Dataset):
    """PyTorch ``Dataset`` for the Dogs-in-the-wild JSON annotation files.

    The annotations file is expected to hold ``{"annotations": [...]}``, where
    each entry has a ``"name"`` (image file name relative to ``img_dir``) and a
    ``"category id"``. The categories file is expected to hold
    ``{"categories": [...]}``, where each entry has a ``"category id"``.

    Each sample is returned as ``(image, label)``. ``image`` is the RGB PIL
    image, passed through ``transform`` when one is given. ``label`` is the
    category id as a ``torch.long`` scalar tensor.

    NOTE(review): the label is the raw category id, not a remapped 0..C-1
    index. If the ids are not contiguous from 0, losses such as
    ``CrossEntropyLoss`` will need a remapping; confirm against the data.
    """

    def __init__(self, annotations_file, categories_file, img_dir, transform=None):
        """Load both JSON files and remember where images live.

        Args:
            annotations_file: path to the per-image annotations JSON.
            categories_file: path to the categories JSON.
            img_dir: directory that image ``"name"`` entries are relative to.
            transform: optional callable applied to each RGB PIL image.
        """
        self.img_labels = self.parse_json(annotations_file)
        self.categories = self.parse_json(categories_file)
        self.img_dir = img_dir
        self.transform = transform
        # Build the id -> label lookup once instead of scanning every
        # category for every sample in get_label.
        self._label_lookup = self._build_label_lookup(self.categories)

    @staticmethod
    def _build_label_lookup(categories):
        """Map each category id in ``categories['categories']`` to its label value."""
        lookup = {}
        for category in categories['categories']:
            cat_id = category['category id']
            # setdefault keeps the first occurrence, matching the original
            # first-match linear search semantics.
            lookup.setdefault(cat_id, cat_id)
        return lookup

    def parse_json(self, json_file):
        """Read and return the decoded contents of ``json_file``."""
        # Explicit UTF-8 so decoding does not depend on the platform locale.
        with open(json_file, 'r', encoding='utf-8') as f:
            return json.load(f)

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.img_labels['annotations'])

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``."""
        annotation = self.img_labels['annotations'][idx]
        img_path = os.path.join(self.img_dir, annotation['name'])
        # The context manager closes the underlying file handle. convert()
        # returns a new, fully loaded image, so it stays valid after the
        # file is closed. Without this, handles leak across DataLoader workers.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        label = self.get_label(annotation['category id'])
        return image, label

    def get_label(self, category_id):
        """Return ``category_id`` as a ``torch.long`` scalar tensor.

        Raises:
            ValueError: if ``category_id`` is not listed in the categories file.
        """
        lookup = getattr(self, '_label_lookup', None)
        if lookup is None:
            # Instances built without __init__ (e.g. subclasses that skip
            # super().__init__) build the lookup lazily on first use.
            lookup = self._label_lookup = self._build_label_lookup(self.categories)
        try:
            value = lookup[category_id]
        except KeyError:
            raise ValueError(f'Category ID {category_id} not found.') from None
        return torch.from_numpy(np.array(value)).long()

读取时与平常使用 PyTorch 读取数据集的流程一致（需事先定义好 `transform`，并从 `torch.utils.data` 导入 `DataLoader`）：

# Example usage. `transform` must already be defined by the caller, e.g. a
# torchvision.transforms.Compose pipeline that turns the PIL image into a tensor.
# NOTE(review): with num_workers > 0 on spawn-based platforms (Windows/macOS),
# run this under an `if __name__ == "__main__":` guard.
from torch.utils.data import DataLoader

train_dataset = DogsInTheWildDataset(
    annotations_file='train.json',
    categories_file='category.json',
    img_dir='image/',
    transform=transform,
)
train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=8)
Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐