dataloader.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. import os
  2. import random
  3. from PIL import Image
  4. from torch.utils.data import Dataset
  5. random.seed(1)
  6. class CustomDataset(Dataset):
  7. # 自定义Dataset类,必须继承Dataset并重写__init__和__getitem__函数
  8. def __init__(self, data_dir, transform=None):
  9. """
  10. 花朵分类任务的Dataset
  11. :param data_dir: str, 数据集所在路径
  12. :param transform: torch.transform,数据预处理,默认不进行预处理
  13. """
  14. # data_info存储所有图片路径和标签(元组的列表),在DataLoader中通过index读取样本
  15. self.data_info = self.get_img_info(data_dir)
  16. self.transform = transform
  17. def __getitem__(self, index):
  18. path_img, label = self.data_info[index]
  19. # 打开图片,默认为PIL,需要转成RGB
  20. img = Image.open(path_img).convert('RGB')
  21. # 如果预处理的条件不为空,应该进行预处理操作
  22. if self.transform is not None:
  23. img = self.transform(img)
  24. return img, label
  25. def __len__(self):
  26. return len(self.data_info)
  27. # 自定义方法,用于返回所有图片的路径以及标签
  28. @staticmethod
  29. def get_img_info(data_dir):
  30. data_info = list()
  31. for root, dirs, _ in os.walk(data_dir):
  32. # 遍历类别
  33. for sub_dir in dirs:
  34. # listdir为列出文件夹下所有文件和文件夹名
  35. img_names = os.listdir(os.path.join(root, sub_dir))
  36. # 过滤出所有后缀名为jpg的文件名(那当然也就把文件夹过滤掉了)
  37. img_names = list(filter(lambda x: x.endswith('.png'), img_names))
  38. # 遍历图片
  39. for i in range(len(img_names)):
  40. img_name = img_names[i]
  41. path_img = os.path.join(root, sub_dir, img_name)
  42. # 在该任务中,文件夹名等于标签名
  43. label = sub_dir
  44. data_info.append((path_img, int(label)))
  45. return data_info