深度对抗神经网络(DANN)笔记
一 总体介绍
DANN是一种迁移学习方法,是对抗迁移学习方法的代表方法。基本结构由特征提取层f,分类器部分c和对抗部分d组成,其中f和c其实就是一个标准的分类模型,通过GAN(生成对抗网络)得到迁移对抗模型的灵感。但此时生成的不是假样本,而是假特征,一个足以让目标域和源域区分不开的假特征。
而领域判别器D其实是个标准的二分类分类器,0是源域,1是目标域。它本身的目标是区分源域和目标域,而我们想要的结果是使判别器越来越分不出数据特征来自源域还是目标域,感觉起来这很矛盾。但其实我们引入一个梯度反转层就可以完美避免这个问题。
引入梯度反转层(GRL),分类器c和判别器d朝着优化分类器效果的方向反向传播优化梯度。有了梯度反转层,简单的说就是判别器d反向传播时,梯度更新前引入了一个“ - ”。这样就可以同时满足判别器和我们需求的一致性。
下面是DANN的基本网络图。
二 UDTL代码库中的DANN网络
对抗网络部分代码。
-
from torch import nn
-
import numpy as np
-
-
-
def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0):
-
return np.float(2.0 * (high - low) / (1.0 np.exp(-alpha * iter_num / max_iter)) - (high - low) low)
-
### 如果self.trade_off_adversarial == 'Step',则调用此函数得到coeff的值,不然self.trade_off_adversarial == 'Cons',则coeff是个固定的值
-
###coeff——————coeff = self.lam_adversarial 其中trade_off_adverial充当域分类器部分的学习率随着迭代过程会逐渐递减——————学习率
-
-
def grl_hook(coeff):#补充连接,因为是引入块,需要连接到model层的梯度: grad.clone()
-
def fun1(grad):
-
return -coeff * grad.clone()
-
return fun1
-
-
class AdversarialNet(nn.Module):
-
def __init__(self, in_feature, hidden_size,max_iter=10000.0, trade_off_adversarial='Step', lam_adversarial=1.0):
-
super(AdversarialNet, self).__init__()
-
self.ad_layer1 = nn.Sequential(
-
nn.Linear(in_feature, hidden_size),
-
nn.ReLU(inplace=True),
-
nn.Dropout(),
-
)
-
self.ad_layer2 = nn.Sequential(
-
nn.Linear(hidden_size, hidden_size),
-
nn.ReLU(inplace=True),
-
nn.Dropout(),
-
)
-
self.ad_layer3 = nn.Linear(hidden_size, 1)
-
self.sigmoid = nn.Sigmoid()
-
# parameters
-
self.iter_num = 0
-
self.alpha = 10
-
self.low = 0.0
-
self.high = 1.0
-
self.max_iter = max_iter
-
self.trade_off_adversarial = trade_off_adversarial
-
self.lam_adversarial = lam_adversarial
-
self.__in_features = 1
-
-
def forward(self, x):
-
if self.training:
-
self.iter_num = 1
-
if self.trade_off_adversarial == 'Cons':
-
coeff = self.lam_adversarial
-
elif self.trade_off_adversarial == 'Step':
-
coeff = calc_coeff(self.iter_num, self.high, self.low, self.alpha, self.max_iter) #学习率
-
else:
-
raise Exception("loss not implement")
-
x = x * 1.0
-
x.register_hook(grl_hook(coeff))#register_hook的作用:即对x求导时,对x的导数进行操作,并且register_hook的参数只能以函数的形式传过去,
-
#grl_hook(coeff)则返回的是梯度 * “——”梯度反转层作用
-
##register_hook的作用:对x求导,并将梯度保存下来,这样可以作为参数通过优化器通过反向传播过程进行更新优化,实现DANN所需效果
-
x = self.ad_layer1(x)
-
x = self.ad_layer2(x)
-
y = self.ad_layer3(x)
-
y = self.sigmoid(y)
-
return y
-
-
def output_num(self):
-
return self.__in_features#输出通道是1 代表域判别值0:源域 1:目标域
以上是赵志斌老师UDTL代码中的对抗网络部分的介绍,用于故障诊断数据。
这里要声明的是对于AdversarialNet网络而言,一维的数据和二维数据都可以拿来直接使用,实质上它仅仅是多出了一个二分类判别器和一个梯度反转层而已。
ZhaoZhibin/UDTL: Source codes for the paper "Applications of Unsupervised Deep Transfer Learning to Intelligent Fault Diagnosis: A Survey and Comparative Study" published in TIM (github.com)https://github.com/ZhaoZhibin/UDTL
三 网络的其他写法
这里博主还找到了另外一种的对抗网络写法。
-
import torch.nn as nn
-
from functions import ReverseLayerF#从functions中导入梯度反转层这一类
-
-
-
class CNNModel(nn.Module):
-
-
def __init__(self):
-
super(CNNModel, self).__init__()
-
self.feature = nn.Sequential()
-
self.feature.add_module('f_conv1', nn.Conv2d(3, 64, kernel_size=5))#这里是因为数据是mnist数据所以输入通道为3
-
self.feature.add_module('f_bn1', nn.BatchNorm2d(64))
-
self.feature.add_module('f_pool1', nn.MaxPool2d(2))
-
self.feature.add_module('f_relu1', nn.ReLU(True))
-
self.feature.add_module('f_conv2', nn.Conv2d(64, 50, kernel_size=5))
-
self.feature.add_module('f_bn2', nn.BatchNorm2d(50))
-
self.feature.add_module('f_drop1', nn.Dropout2d())
-
self.feature.add_module('f_pool2', nn.MaxPool2d(2))
-
self.feature.add_module('f_relu2', nn.ReLU(True))
-
#上面是backbone部分也是网络的特征提取部分
-
self.class_classifier = nn.Sequential()
-
self.class_classifier.add_module('c_fc1', nn.Linear(50 * 4 * 4, 100))
-
self.class_classifier.add_module('c_bn1', nn.BatchNorm1d(100))
-
self.class_classifier.add_module('c_relu1', nn.ReLU(True))
-
self.class_classifier.add_module('c_drop1', nn.Dropout())
-
self.class_classifier.add_module('c_fc2', nn.Linear(100, 100))
-
self.class_classifier.add_module('c_bn2', nn.BatchNorm1d(100))
-
self.class_classifier.add_module('c_relu2', nn.ReLU(True))
-
self.class_classifier.add_module('c_fc3', nn.Linear(100, 10))
-
self.class_classifier.add_module('c_softmax', nn.LogSoftmax(dim=1))
-
#上面是源域的分类器部分,只要是要对源域数据进行有效的分类
-
self.domain_classifier = nn.Sequential()
-
self.domain_classifier.add_module('d_fc1', nn.Linear(50 * 4 * 4, 100))
-
self.domain_classifier.add_module('d_bn1', nn.BatchNorm1d(100))
-
self.domain_classifier.add_module('d_relu1', nn.ReLU(True))
-
self.domain_classifier.add_module('d_fc2', nn.Linear(100, 2))
-
self.domain_classifier.add_module('d_softmax', nn.LogSoftmax(dim=1))
-
#上面是领域判别器部分,主要任务是要区分出源域和目标域
-
def forward(self, input_data, alpha):
-
input_data = input_data.expand(input_data.data.shape[0], 3, 28, 28)
-
feature = self.feature(input_data)
-
feature = feature.view(-1, 50 * 4 * 4)
-
reverse_feature = ReverseLayerF.apply(feature, alpha)
-
#前向网络中注意到,reverse_feature是通过ReverseLayerF.apply将feature进行反向的梯度计算。
-
class_output = self.class_classifier(feature)
-
domain_output = self.domain_classifier(reverse_feature)
-
#并将处理过的reverse_feature特征给domain_classifer进行域判别。
-
return class_output, domain_output
对应的ReverseLayerF部分代码:
-
from torch.autograd import Function
-
-
-
class ReverseLayerF(Function):
-
-
-
def forward(ctx, x, alpha):
-
ctx.alpha = alpha
-
-
return x.view_as(x)
-
-
-
def backward(ctx, grad_output):
-
output = grad_output.neg() * ctx.alpha
-
####grad_output.neg()梯度取负操作,反向内容的核心。
-
-
return output, None
代码地址:https://github.com/fungtion/DANN_py3
通过介绍以上两种不同写法的对抗网络模型,相信你也可以看到对抗网络的核心其实很简单。
仅仅是多出了一个领域判别器和一个梯度反转层。
但采用对抗网络作为迁移网络方法又能很好的解决很多域迁移领域的问题,特别是在域之间的差异较大的情况时,往往要比以MMD(最大均值差异)为代表的度量学习方法效果要好。
以上是我学习过程中对DANN进行的一些总结工作,欢迎评论区讨论交流。
这篇好文章是转载于:学新通技术网
- 版权申明: 本站部分内容来自互联网,仅供学习及演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,请提供相关证据及您的身份证明,我们将在收到邮件后48小时内删除。
- 本站站名: 学新通技术网
- 本文地址: /boutique/detail/tanhiagecj
-
photoshop保存的图片太大微信发不了怎么办
PHP中文网 06-15 -
word里面弄一个表格后上面的标题会跑到下面怎么办
PHP中文网 06-20 -
photoshop扩展功能面板显示灰色怎么办
PHP中文网 06-14 -
《学习通》视频自动暂停处理方法
HelloWorld317 07-05 -
TikTok加速器哪个好免费的TK加速器推荐
TK小达人 10-01 -
Android 11 保存文件到外部存储,并分享文件
Luke 10-12 -
微信公众号没有声音提示怎么办
PHP中文网 03-31 -
excel下划线不显示怎么办
PHP中文网 06-23 -
微信运动停用后别人还能看到步数吗
PHP中文网 07-22 -
excel打印预览压线压字怎么办
PHP中文网 06-22