Image Enhancement🐽1 - AODNet: All-in-One Dehazing Network

"Image Dehazing"

Posted by fuhao7i on March 29, 2021

paper: AOD-Net: All-in-One Dehazing Network

1. Physical Model: The Atmospheric Scattering Model

\[\large I(x) = J(x)t(x)+A(1-t(x)) \tag {1}\]

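where $I(x)$ is the observed hazy image, $J(x)$ is the scene radiance (the clear image), $A$ is the global atmospheric light, and $t(x)$ is the transmission map, defined as: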

\[\large t(x) = e^{- \beta d(x)} \tag 2\]

where $\beta$ is the atmospheric scattering coefficient and $d(x)$ is the distance from the object to the camera (the scene depth).
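To make Eqs. (1) and (2) concrete, here is a minimal PyTorch sketch that synthesizes a hazy image from a clear one. The function name `synthesize_haze`, the random image, the depth ramp, and the values $A = 0.8$ and $\beta = 1.0$ are illustrative assumptions, not taken from the paper.

```python
import torch

def synthesize_haze(J: torch.Tensor, depth: torch.Tensor, A: float, beta: float) -> torch.Tensor:
    """Hypothetical helper: apply Eq. (1) with t(x) from Eq. (2).

    J:     clear image J(x), shape (3, H, W), values in [0, 1]
    depth: scene depth d(x), shape (H, W); broadcast over the 3 channels
    """
    t = torch.exp(-beta * depth)      # transmission map, Eq. (2)
    return J * t + A * (1.0 - t)      # atmospheric scattering model, Eq. (1)

# Illustrative inputs (assumptions): random clear image, depth growing from 0 to 5
J = torch.rand(3, 4, 4)
depth = torch.linspace(0.0, 5.0, 16).reshape(4, 4)
I = synthesize_haze(J, depth, A=0.8, beta=1.0)
```

At zero depth $t = 1$ and the pixel is unchanged; as depth grows, $t \to 0$ and the pixel approaches the atmospheric light $A$, which is why distant objects look washed out.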

Rearranging this model shows how to recover the clear image from a hazy one, which is what produces the enhancement effect:

\[\large J(x) = {\frac{1}{t(x)}}I(x) - A{\frac{1}{t(x)}} + A \tag 3\]
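As a quick sanity check of Eq. (3), the sketch below assumes $t(x)$ and $A$ are known exactly (in practice they are not), applies Eq. (1), and then inverts it with Eq. (3). The helper name `dehaze_known` and the random inputs are hypothetical.

```python
import torch

def dehaze_known(I, t, A):
    """Hypothetical helper: Eq. (3), J = (1/t) I - A (1/t) + A."""
    return I / t - A / t + A

J = torch.rand(3, 4, 4)
t = torch.rand(4, 4) * 0.8 + 0.2   # keep t in [0.2, 1.0) so 1/t stays bounded
A = 0.9
I = J * t + A * (1 - t)            # forward model, Eq. (1)
J_hat = dehaze_known(I, t, A)      # inverse, Eq. (3)
print(torch.allclose(J_hat, J, atol=1e-5))
```

Note that dividing by $t(x)$ amplifies any error in the estimate of $t(x)$ where it is small, so errors in $t(x)$ and $A$ feed directly into $J(x)$.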

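$I(x)$ is already given: it is the hazy input image, so all that remains is to estimate $t(x)$ and $A$ with a neural network. Previous methods estimate $t(x)$ and $A$ separately, but doing so does not directly minimize the reconstruction error on $J(x)$, so the resulting model is not optimal. The authors therefore reformulate the equation as: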

\[\large J(x) = K(x)I(x) - K(x) + b, \quad \text{where} \quad K(x) = \frac{\frac{1}{t(x)}\left(I(x)-A\right)+(A-b)}{I(x)-1} \tag 4\]

In this way, $\frac{1}{t(x)}$ and $A$ are folded into a single new variable $K(x)$, and $b$ is a constant bias with a default value of 1.
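The sketch below numerically checks that Eq. (4) gives the same $J(x)$ as Eq. (3) when $K(x)$ is computed from a known $t(x)$ and $A$. The helper `k_from_t_and_A` and the random inputs are hypothetical; $I(x)$ is kept below 1 because $K(x)$ is undefined where $I(x) = 1$. In AOD-Net itself the network predicts $K(x)$ directly rather than computing it from $t(x)$ and $A$.

```python
import torch

def k_from_t_and_A(I, t, A, b=1.0):
    """Hypothetical helper: K(x) from Eq. (4), ((I - A)/t + (A - b)) / (I - 1)."""
    return ((I - A) / t + (A - b)) / (I - 1.0)

torch.manual_seed(0)
I = torch.rand(3, 4, 4) * 0.9      # keep I < 1 so that I - 1 != 0
t = torch.rand(4, 4) * 0.8 + 0.2   # keep t away from 0
A, b = 0.8, 1.0

K = k_from_t_and_A(I, t, A, b)
J_eq3 = I / t - A / t + A          # Eq. (3)
J_eq4 = K * I - K + b              # Eq. (4)
print(torch.allclose(J_eq3, J_eq4, atol=1e-4))
```

Algebraically, $K(x)\,(I(x) - 1) + b = \frac{1}{t(x)}(I(x) - A) + A$, which is exactly the right-hand side of Eq. (3).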

2. Model

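As shown in the figure, the model uses five convolutional layers, each with 3 output channels, and concatenates intermediate feature maps three times. The last layer outputs $K(x)$, which is then plugged into Eq. (4) to produce the clean image.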

Python implementation

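import torch
import torch.nn as nn
import torch.nn.functional as F

class AODnet(nn.Module):
    """K-estimation module of AOD-Net followed by the clean-image generation step, Eq. (4)."""

    def __init__(self):
        super().__init__()
        # Five conv layers, each producing 3 channels; padding keeps H and W unchanged
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=5, padding=2)   # input: cat(x1, x2)
        self.conv4 = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=7, padding=3)   # input: cat(x2, x3)
        self.conv5 = nn.Conv2d(in_channels=12, out_channels=3, kernel_size=3, padding=1)  # input: cat(x1, x2, x3, x4)
        self.b = 1  # constant bias b in Eq. (4)

    def forward(self, x):
        # x: hazy image I(x), shape (N, 3, H, W)
        x1 = F.relu(self.conv1(x))
        x2 = F.relu(self.conv2(x1))
        cat1 = torch.cat((x1, x2), 1)          # concatenate along channels: 6 channels
        x3 = F.relu(self.conv3(cat1))
        cat2 = torch.cat((x2, x3), 1)          # 6 channels
        x4 = F.relu(self.conv4(cat2))
        cat3 = torch.cat((x1, x2, x3, x4), 1)  # 12 channels
        k = F.relu(self.conv5(cat3))           # K(x), same shape as the input

        if k.size() != x.size():
            raise ValueError("k and the hazy image have different sizes!")

        # Eq. (4): J(x) = K(x) I(x) - K(x) + b; the final ReLU clamps negative values to 0
        output = k * x - k + self.b
        return F.relu(output)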

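model = AODnet()

haze = torch.rand(1, 3, 480, 640)  # example batch: one 3-channel hazy image with values in [0, 1]
out = model(haze)                  # model(...) goes through nn.Module.__call__, which runs forward()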
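The network is very small. As a rough check, the parameter count for the layer configuration in the code above can be worked out by hand; each `nn.Conv2d` has `out * in * k * k` weights plus one bias per output channel (PyTorch's default `bias=True`). The total should match `sum(p.numel() for p in model.parameters())` for the class above.

```python
# (name, in_channels, out_channels, kernel_size) for each Conv2d in AODnet
layers = [
    ("conv1", 3, 3, 1),
    ("conv2", 3, 3, 3),
    ("conv3", 6, 3, 5),   # input: cat(x1, x2) -> 3 + 3 channels
    ("conv4", 6, 3, 7),   # input: cat(x2, x3) -> 3 + 3 channels
    ("conv5", 12, 3, 3),  # input: cat(x1, x2, x3, x4) -> 4 * 3 channels
]

def conv2d_params(c_in, c_out, k):
    """Weights (c_out * c_in * k * k) plus one bias per output channel."""
    return c_out * c_in * k * k + c_out

counts = {name: conv2d_params(ci, co, k) for name, ci, co, k in layers}
total = sum(counts.values())
print(counts)
print(total)  # 1761
```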

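3. Loss

The network is trained with a mean squared error (MSE) loss between its output $J(x)$ and the ground-truth clear image.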

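#===== Loss function & optimizer =====
# Assumed context: `args` holds the training script's command-line options
# (e.g. from argparse) and `net` is the AODnet instance being trained.
criterion = torch.nn.MSELoss()  # pixel-wise MSE between output J(x) and the clear image

if args.cuda:
    criterion = criterion.cuda()

optimizer = torch.optim.Adam(net.parameters(), lr=args.lr, weight_decay=0.0001)
# Multiply the learning rate by 0.5 every 53760 calls to scheduler.step()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=53760, gamma=0.5)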
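A minimal sketch of what the two pieces above do, using small hypothetical tensors in place of the network output and the clear image: `nn.MSELoss` averages the squared per-pixel differences, and `StepLR` multiplies the learning rate by `gamma` every `step_size` calls to `scheduler.step()`. The `step_size=2` here is only so the halving is visible; the training script above uses 53760.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the network output and the ground-truth clear image
pred = torch.tensor([[0.2, 0.4], [0.6, 0.8]])
target = torch.tensor([[0.0, 0.5], [0.5, 1.0]])
loss = nn.MSELoss()(pred, target)         # mean of squared per-pixel differences
manual = ((pred - target) ** 2).mean()    # same quantity computed by hand

# StepLR with an illustrative step_size=2 and the same gamma=0.5
param = nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)
lrs = []
for _ in range(6):
    lrs.append(scheduler.get_last_lr()[0])  # learning rate used for this step
    optimizer.step()                        # no gradients here, so parameters are untouched
    scheduler.step()
print(loss.item(), lrs)
```

Whether the original script calls `scheduler.step()` once per iteration or once per epoch determines what "every 53760 steps" means in practice.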

4. Dataset

The input is the hazy image, and the label is the corresponding ground-truth clear image.
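A minimal sketch of how such (hazy, clear) pairs might be fed to training, assuming both sets of images are already loaded as tensors of the same shape. `PairedHazeDataset` is a hypothetical name, and the random stand-in data (hazy images made with Eq. (1) using a constant $t = 0.6$ and $A = 0.8$) is only for illustration; it is not the paper's training set.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class PairedHazeDataset(Dataset):
    """Hypothetical paired dataset: item i is (hazy image, ground-truth clear image)."""

    def __init__(self, hazy, clear):
        assert len(hazy) == len(clear), "every hazy image needs a clear counterpart"
        self.hazy = hazy
        self.clear = clear

    def __len__(self):
        return len(self.hazy)

    def __getitem__(self, idx):
        # The same index is used for both, so shuffling keeps pairs aligned
        return self.hazy[idx], self.clear[idx]

clear = torch.rand(8, 3, 16, 16)           # stand-in clear images
hazy = clear * 0.6 + 0.8 * (1 - 0.6)       # stand-in haze: Eq. (1) with t = 0.6, A = 0.8
loader = DataLoader(PairedHazeDataset(hazy, clear), batch_size=4, shuffle=True)
inputs, labels = next(iter(loader))        # inputs go to the network, labels to MSELoss
print(inputs.shape, labels.shape)
```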