Modular torch💌2——Define your Dataset

"Dataset for classification, semantic sementation, object detection or ..."

Posted by fuhao7i on April 17, 2021

1. Dataset for classification
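
The dataset below reads an annotation text file in which each line holds an image file name and an integer class label, separated by a semicolon. As an illustration only, a hypothetical data.txt might look like:

img_0001.jpg;0
img_0002.jpg;3
img_0003.jpg;1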

from PIL import Image
from torch.utils.data import Dataset


class MyDataset(Dataset):  # custom dataset class, inherits from torch.utils.data.Dataset
    def __init__(self, datatxt, transform=None):
        super(MyDataset, self).__init__()
        data = []
        with open(datatxt, 'r') as fh:  # open the annotation txt file
            for line in fh:  # iterate over the file line by line
                words = line.split(';')  # split each line on the ';' separator
                data.append((words[0], int(words[1])))  # store (image name, label) pairs: words[0] is the image file name, words[1] is the label

        self.data = data
        self.transform = transform

    def __getitem__(self, index):  # return one sample, selected by index
        fn, label = self.data[index]  # fn is the image file name
        img = Image.open('/content/drive/MyDrive/search/mmdetection/data/imagenet-underwater/train/' + fn).convert('RGB')

        if self.transform is not None:  # apply the transform if one was given
            img = self.transform(img)
        return img, label  # whatever is returned here is what each batch yields when iterating during training

    def __len__(self):  # must be implemented; returns the size of the dataset
        return len(self.data)
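
For reference, here is a minimal sketch of how this dataset might be plugged into a DataLoader. The file name data.txt, the resize size, and the batch size are assumptions for illustration, not part of the original post:

import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# data.txt is a hypothetical annotation file with one "image_name;label" per line
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
train_dataset = MyDataset('data.txt', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

for imgs, labels in train_loader:
    # imgs: (16, 3, 224, 224) float tensor, labels: (16,) int tensor
    pass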

2. Dataset for semantic segmentation
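
Unlike the classification example, this dataset receives the annotation lines as a list (one line per sample) rather than a file path. Each line pairs an input jpg with its label png, separated by a semicolon; a hypothetical line might look like 0001.jpg;0001.png.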

import numpy as np
from PIL import Image
from random import shuffle
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, data_txt, split, image_size, num_classes, random_data):
        super(MyDataset, self).__init__()

        self.data_txt = data_txt          # list of annotation lines
        self.len_dataset = len(data_txt)
        self.split = split
        self.image_size = image_size
        self.num_classes = num_classes

        if self.split == 'train':
            self.path0 = '/content/drive/MyDrive/语义分割/dataset/jpg'
            self.path1 = '/content/drive/MyDrive/语义分割/dataset/L_png'
        else:
            self.path0 = '/content/drive/MyDrive/语义分割/dataset/val_jpg'
            self.path1 = '/content/drive/MyDrive/语义分割/dataset/L_val_png'

    def __getitem__(self, index):
        if index == 0:  # reshuffle the annotation list at the start of every epoch
            shuffle(self.data_txt)

        annotation_line = self.data_txt[index]
        name0 = annotation_line.split(';')[0]
        name1 = annotation_line.split(';')[1].replace("\n", "")
        # read the image and its label map from disk
        jpg = Image.open(self.path0 + '/' + name0)
        png = Image.open(self.path1 + '/' + name1)

        # convert the label map to a numpy array
        png = np.array(png)

        # clamp out-of-range class ids to class 0
        png[png >= self.num_classes] = 0

        # convert the label map to one-hot encoding
        seg_labels = np.eye(self.num_classes)[png.reshape([-1])]
        seg_labels = seg_labels.reshape((int(self.image_size[1]), int(self.image_size[0]), self.num_classes))

        # change the jpg layout from (h, w, c) to (c, h, w); torch expects channels first
        jpg = np.transpose(np.array(jpg), [2, 0, 1]) / 255

        # outputs
        # jpg: (3, 512, 512), normalized to [0, 1]
        # png: (512, 512), each pixel holds its class id, e.g. 0, 1, 2, 3, 4
        # seg_labels: (512, 512, num_classes), one-hot encoded
        return jpg, png, seg_labels

    def __len__(self):
        return self.len_dataset
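
A minimal usage sketch follows. The list file train_list.txt, the image size, the number of classes, and the batch size are assumptions for illustration; note that this dataset does not resize, so the images are assumed to already match image_size:

from torch.utils.data import DataLoader

# hypothetical annotation list file: each line is "jpg_name;png_name"
with open('train_list.txt', 'r') as f:
    train_lines = f.readlines()

train_dataset = MyDataset(train_lines, split='train', image_size=(512, 512),
                          num_classes=5, random_data=True)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

for jpgs, pngs, seg_labels in train_loader:
    # jpgs: (4, 3, 512, 512), pngs: (4, 512, 512), seg_labels: (4, 512, 512, 5)
    pass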

3. Dataset for object detection