mmdetection 数据集准备,包括训练数据集和测试数据集
1. 训练数据集
在train.py文件中,构建训练数据集。
1
2
3
from mmdet.datasets import build_dataset
train_dataset = build_dataset(cfg.data.train)
1.1 参数cfg.data.train
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# dataset settings
dataset_type = 'VOCDataset'
data_root ='data/VOCdevkit/VOC2007/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(
type='Resize',
# img_scale=(1333, 800),
img_scale=(800,600),
keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
train=dict(
type=dataset_type,
ann_file=data_root + 'ImageSets/Main/train.txt',
img_prefix=data_root ,
pipeline=train_pipeline)
1.2 函数build_dataset()
其中,build_dataset函数在mmdet文件夹的datasets文件夹下的builder.py。
./mmdet/datasets/builder.py
1
2
3
4
5
6
7
8
9
10
11
12
def build_dataset(cfg, default_args=None):
if isinstance(cfg, (list, tuple)):
dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
elif cfg['type'] == 'RepeatDataset':
dataset = RepeatDataset(
build_dataset(cfg['dataset'], default_args), cfg['times'])
elif isinstance(cfg['ann_file'], (list, tuple)):
dataset = _concat_dataset(cfg, default_args)
else:
dataset = build_from_cfg(cfg, DATASETS, default_args)
return dataset
这里我们会执行dataset = build_from_cfg(cfg, DATASETS, default_args)
. 我们来具体看一下。
1.2.1 全局变量(注册表)DATASETS的构建
mmdet/datasets/registry.py
这里将类Registry
实例化,注册到注册表中。
1
2
3
4
from mmdet.utils import Registry
DATASETS = Registry('dataset')
PIPELINES = Registry('pipeline')
然后我们具体看一下Registry
类:
mmdet/utils/registry.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import inspect
import mmcv
class Registry(object):
def __init__(self, name):
self._name = name #Registry类的名字
self._module_dict = dict() #创建一个模块字典
def __repr__(self):
format_str = self.__class__.__name__ + '(name={}, items={})'.format(
self._name, list(self._module_dict.keys()))
return format_str
@property #将方法修饰为类属性
def name(self):
return self._name
@property
def module_dict(self):
return self._module_dict
def get(self, key):
return self._module_dict.get(key, None)
def _register_module(self, module_class): #Registry类的主要方法,用来注册模块。
"""Register a module.
Args:
module (:obj:`nn.Module`): Module to be registered.
"""
if not inspect.isclass(module_class):
raise TypeError('module must be a class, but got {}'.format(
type(module_class)))
module_name = module_class.__name__
if module_name in self._module_dict:
raise KeyError('{} is already registered in {}'.format(
module_name, self.name))
self._module_dict[module_name] = module_class
def register_module(self, cls):
self._register_module(cls)
return cls
然后我们点开mmdet/datasets
下的数据集定义py文件。这里我们以voc.py
数据集为例。
1.2.2 数据集定义文件 voc.py
这里我们用修饰符@修饰: @DATASETS.register_module
,就是将类VOCDataset
作为参数传入DATASETS.register_module()
函数。
1
2
@DATASETS.register_module
class VOCDataset(XMLDataset):
1.2.2.1 父类XMLDataset
VocDataset 继承自 XMLDataset, XMLDataset 又继承自 CustomDataset, CustomDataset继承自 Dataset from torch.utils.data import Dataset
.
1
2
3
4
5
class XMLDataset(CustomDataset):
def load_annotations(self, ann_file): # 用于初始化VocDataset的,主要初始化算imgs的位置,imgs和annotations的关联关系等
...
def get_ann_info(self, idx): # 用于在train/val/test中调取annotations
...
然后看一下在训练中时如何调用的,XMLDataset继承CustomDatset,在train过程中调用的时CustomDataset的函数:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# ./mmdnet/datasets/custom.py
class CustomDataset(Dataset):
def __init__(...):
...
# load annotations (and proposals)
self.img_infos = self.load_annotations(self.ann_file)
...
# processing pipeline
self.pipeline = Compose(pipeline)
def __getitem__(self, idx): # 重载Dataset的,就是在train过程中加载data和target的,其中用了prepare_train_img
if self.test_mode:
return self.prepare_test_img(idx)
while True:
data = self.prepare_train_img(idx)
if data is None:
idx = self._rand_another(idx)
continue
return data
def prepare_train_img(self, idx):
img_info = self.img_infos[idx]
ann_info = self.get_ann_info(idx)
results = dict(img_info=img_info, ann_info=ann_info)
if self.proposals is not None:
results['proposals'] = self.proposals[idx]
self.pre_pipeline(results)
return self.pipeline(results)
根据上述代码,可以看到CustomDataset中使用load_annotations初始化dataset,使用get_ann_info加载target,所以继承CustomDataset需要定义这两个函数,就可以完成自己Dataset的定义。
1.3 函数 build_from_cfg()
我们继续来看dataset = build_from_cfg(cfg, DATASETS, default_args)
.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def build_from_cfg(cfg, registry, default_args=None):
"""Build a module from config dict.
Args:
cfg (dict): Config dict. It should at least contain the key "type".
registry (:obj:`Registry`): The registry to search the type from.
default_args (dict, optional): Default initialization arguments.
Returns:
obj: The constructed object.
"""
assert isinstance(cfg, dict) and 'type' in cfg
assert isinstance(default_args, dict) or default_args is None
args = cfg.copy()
obj_type = args.pop('type')
if mmcv.is_str(obj_type):
obj_type = registry.get(obj_type) #将cfg中cfg中的'VOCDataset'变成VOCDataset这个类赋值给obj_cls
if obj_type is None:
raise KeyError('{} is not in the {} registry'.format(
obj_type, registry.name))
elif inspect.isclass(obj_type):
obj_cls = obj_type
else:
raise TypeError('type must be a str or valid type, but got {}'.format(
type(obj_type)))
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
return obj_type(**args) #实例化VOCDataset并返回,一路返回给test.py中的dataset
pytorch 的数据加载到模型的操作顺序是这样的:
① 创建一个 Dataset 对象
② 创建一个 DataLoader 对象
③ 循环这个 DataLoader 对象,将img, label加载到模型中进行训练
1
2
3
4
5
6
7
# https://blog.csdn.net/g11d111/article/details/81504637 非常好的Dataloader详解
dataset = MyDataset()
dataloader = DataLoader(dataset)
num_epoches = 100
for epoch in range(num_epoches):
for img, label in dataloader:
...
2. 测试数据集
和训练数据集的构建相同。
tips1: isinstance(object, classinfo)函数详解:
* object -- 实例对象。
* classinfo -- 可以是直接或间接类名、基本类型或者由它们组成的元组。
classinfo可以是:int,float,bool,complex,str(字符串),list,dict(字典),set,tuple,具体的类
判断对象object的类型是否和classinfo的类型相同。相同则返回True,否则返回False。
isinstance() 与 type() 区别: type() 不会认为子类是一种父类类型,不考虑继承关系。 isinstance() 会认为子类是一种父类类型,考虑继承关系。 如果要判断两个类型是否相同推荐使用 isinstance()。
e.g
1
2
3
>>> a = 2
>>> isinstance (a,(str,int,list)) # 是元组中的一个返回 True
True