1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
class MyDataset(Dataset):
    """Map-style dataset yielding (image, label map, one-hot label) triples.

    Each entry of ``data_txt`` is an annotation line of the form
    ``"<image file>;<label file>"`` (optionally newline-terminated). The two
    file names are resolved against split-specific directories chosen in
    ``__init__``.

    Args:
        data_txt: Mutable sequence of annotation lines. It is shuffled in
            place whenever index 0 is requested, so the caller's list is
            reordered too.
        split: ``'train'`` selects the training directories; any other value
            selects the validation directories.
        image_size: Indexable pair used to reshape the one-hot labels to
            ``(image_size[1], image_size[0], num_classes)`` -- presumably
            ``(width, height)``; confirm against the caller.
        num_classes: Number of segmentation classes. Label values greater
            than or equal to this are mapped to class 0.
        random_data: Accepted for interface compatibility; not used by this
            class.
    """

    def __init__(self, data_txt, split, image_size, num_classes, random_data):
        super().__init__()
        self.data_txt = data_txt
        self.len_dataset = len(data_txt)
        self.split = split
        self.image_size = image_size
        self.num_classes = num_classes
        # NOTE(review): `random_data` is never read; kept only so existing
        # call sites keep working.
        if self.split == 'train':
            self.path0 = '/content/drive/MyDrive/语义分割/dataset/jpg'
            self.path1 = '/content/drive/MyDrive/语义分割/dataset/L_png'
        else:
            self.path0 = '/content/drive/MyDrive/语义分割/dataset/val_jpg'
            self.path1 = '/content/drive/MyDrive/语义分割/dataset/L_val_png'

    def __getitem__(self, index):
        """Load the sample at ``index``.

        Returns:
            A tuple ``(jpg, png, seg_labels)`` where ``jpg`` is the image
            moved to channel-first layout and divided by 255, ``png`` is the
            label array with out-of-range values set to 0, and ``seg_labels``
            is the one-hot encoding of ``png`` reshaped to
            ``(image_size[1], image_size[0], num_classes)``.

        Raises:
            IndexError: If the annotation line has no ``';'`` separator.
            ValueError: If the label image does not contain exactly
                ``image_size[0] * image_size[1]`` pixels.
        """
        # Reshuffle the annotation list each time the first item is asked for.
        # NOTE(review): this mutates the list passed in by the caller and, with
        # a shuffling sampler or multiple workers, can reorder samples in the
        # middle of an epoch -- confirm this is intended.
        if index == 0:
            shuffle(self.data_txt)

        fields = self.data_txt[index].split(';')
        name0 = fields[0]
        name1 = fields[1].replace("\n", "")

        # Read both images from disk. The context managers close the
        # underlying files as soon as the pixels are copied into arrays.
        with Image.open(self.path0 + '/' + name0) as image_file:
            jpg = np.array(image_file)
        with Image.open(self.path1 + '/' + name1) as label_file:
            png = np.array(label_file)

        # Any label outside the known classes is treated as class 0.
        png[png >= self.num_classes] = 0

        # Convert the label map to one-hot form.
        seg_labels = np.eye(self.num_classes)[png.reshape([-1])]
        seg_labels = seg_labels.reshape(
            (int(self.image_size[1]), int(self.image_size[0]), self.num_classes)
        )

        # (h, w, c) => (c, h, w): torch expects the channel axis first.
        jpg = np.transpose(jpg, [2, 0, 1]) / 255

        # Outputs (shapes per the original author's note, assuming 512x512
        # inputs -- TODO confirm):
        #   jpg:        (3, 512, 512), scaled to [0, 1]
        #   png:        (512, 512), each pixel holds its class id
        #   seg_labels: (512, 512, num_classes), one-hot encoded
        return jpg, png, seg_labels

    def __len__(self):
        """Return the number of annotation lines recorded at construction."""
        # The original returned `self.train_batches`, which is never defined.
        return self.len_dataset