123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051 |
- import os
- import random
- from PIL import Image
- from torch.utils.data import Dataset
- random.seed(1)
- class CustomDataset(Dataset):
- # 自定义Dataset类,必须继承Dataset并重写__init__和__getitem__函数
- def __init__(self, data_dir, transform=None):
- """
- 花朵分类任务的Dataset
- :param data_dir: str, 数据集所在路径
- :param transform: torch.transform,数据预处理,默认不进行预处理
- """
- # data_info存储所有图片路径和标签(元组的列表),在DataLoader中通过index读取样本
- self.data_info = self.get_img_info(data_dir)
- self.transform = transform
- def __getitem__(self, index):
- path_img, label = self.data_info[index]
- # 打开图片,默认为PIL,需要转成RGB
- img = Image.open(path_img).convert('RGB')
- # 如果预处理的条件不为空,应该进行预处理操作
- if self.transform is not None:
- img = self.transform(img)
- return img, label
- def __len__(self):
- return len(self.data_info)
- # 自定义方法,用于返回所有图片的路径以及标签
- @staticmethod
- def get_img_info(data_dir):
- data_info = list()
- for root, dirs, _ in os.walk(data_dir):
- # 遍历类别
- for sub_dir in dirs:
- # listdir为列出文件夹下所有文件和文件夹名
- img_names = os.listdir(os.path.join(root, sub_dir))
- # 过滤出所有后缀名为jpg的文件名(那当然也就把文件夹过滤掉了)
- img_names = list(filter(lambda x: x.endswith('.png'), img_names))
- # 遍历图片
- for i in range(len(img_names)):
- img_name = img_names[i]
- path_img = os.path.join(root, sub_dir, img_name)
- # 在该任务中,文件夹名等于标签名
- label = sub_dir
- data_info.append((path_img, int(label)))
- return data_info
|