• 首页 首页 icon
  • 工具库 工具库 icon
    • IP查询 IP查询 icon
  • 内容库 内容库 icon
    • 快讯库 快讯库 icon
    • 精品库 精品库 icon
    • 问答库 问答库 icon
  • 更多 更多 icon
    • 服务条款 服务条款 icon

基础论文笔记一(2018 NIPS)Conditional Adversarial Domain Adaptation CDAN条件对抗域适应

武飞扬头像
羊驼不驼a
帮助1

        对抗性学习已被嵌入到深层网络中,用于学习解纠缠和可转移的领域适应表示。在分类问题中,现有的对抗性域自适应方法可能无法有效地对齐多模态分布的不同域。作者指出当前一些对抗域适应方法仍存在三个问题:1.只考虑了特征对齐,没有考虑标签对齐。2.当数据分布体现出复杂的多模态结构时,对抗性自适应方法可能无法捕获这种多模态结构,也就是说即使判别器完全被混淆,也无法保证此时源域和目标域足够相似。并且这种风险不能通过单独的域鉴别器将特征和类的分布对齐来解决。3.条件域判别器中使用最大最小优化方法也许存在一定的问题,最大最小的对抗网络结构给所有的例子施加了相同的权重,但是有些比较难预测的例子可能会对网络的学习产生影响。 

        因此本文提出条件对抗域适应网络(CDAN)在一定程度上解决以上三个问题:1.通过对齐特征-类别的联合分布解决 2.使用了Multilinear Conditioning多线性调整的方法来解决 3.提出了在目标函数中添加Entropy Conditioning熵调整来解决

一、CDAN结构

学新通

        网络结构主要就是源域和目标域的数据通过深度神经网络Alexnet/ResNet对源域和目标域提取特征f,然后通过源分类器G得出预测标签g,与《Simultaneous Deep Transfer Across Domains and Tasks》相似,不过文中将预测结果g和特征f联合分布,多线性映射输入到与域判别器D中。

二、多线性调整

        文中通过联合变量学新通将域鉴别器D条件设定在分类器预测g上。这种条件域鉴别器可以潜在地解决对抗性域自适应的上述两个挑战。D的一个简单条件是)学新通,然而,对于级联策略,f和g彼此独立,无法完全捕捉特征表示和分类器预测之间的乘法相互作用,这对领域自适应至关重要。

        因此,多线性映射被定义为多个随机向量的外积。给定两个随机向量x和y,联合分布学新通可以通过互协方差学新通[学新通学新通学新通]来建模,其中φ是由一些再生核引起的特征图。这样的内核嵌入使得能够操纵多个随机变量之间的乘法相互作用。

        假设线性映射φ(x)=x和具有C类别数的one-hot标签变量y,可以验证的是,均值映射学新通分别独立计算x和y的平均值。相反,均值映射学新通计算每个C类条件分布学新通的均值。与f⨁g相比,多线性映射f⨂g的优点是可以完全捕捉复杂数据分布背后的多模态结构。然而有个劣势则是梯度爆炸,假如 学新通学新通 分别表示f和g的维度,那么f⨂g的维度就是学新通学新通,二者的维度通常较大,因此很容易发生梯度爆炸。

        因此文中通过随机方法来解决梯度爆炸的问题,即随机抽取f和g上的某些维度做多线性映射:学新通

         其中⨀表示逐元素运算,学新通/学新通别表示随机矩阵,其只被采样一次并且在训练过程中固定。d则是需要采样的维度。经过作者论证,在学新通上进行内积近似等于学新通上进行内积,因此,可以直接采用学新通用于计算以方便效率。因此,我们将条件域鉴别器D使用的条件策略定义为:

学新通

        其中4096是典型深度网络(例如AlexNet)中的最大单元数,如果多线性映射学新通的维数大于4096,则采用随机策略进行多线性映射,反之使用正常的多线性映射。

三、熵调整     

        条件域鉴别器的极大极小问题对不同的例子具有同等的重要性,而具有不确定预测的难以转移的例子可能会恶化条件对抗性自适应过程。为了实现安全转移,我们通过熵标准学新通来量化分类器预测的不确定性,其中c是类的数量,学新通是将示例预测到类c的概率。我们通过将条件域鉴别器的每个训练示例重新加权熵感知权重学新通。用于提高可转移性的CDAN的熵调节变体(CDAN E)被公式化为:

学新通

四、总体优化目标

学新通

学新通

        即联合最小化式子(1)中源分类器G和特征提取器F,最小化式子(2)中鉴别器D,即最大化特征提取器和源分类器G。即:

学新通

        其中学新通 是域特定特征表示f和分类器预测g的联合变量。通过将条件域鉴别器的每个训练示例重新加权熵感知权重后,最终优化目标是:

学新通

代码如下:

  1.  
    class ConditionalDomainAdversarialLoss(nn.Module):
  2.  
    def __init__(self, domain_discriminator: nn.Module, entropy_conditioning: Optional[bool] = False,
  3.  
    randomized: Optional[bool] = False, num_classes: Optional[int] = -1,
  4.  
    features_dim: Optional[int] = -1, randomized_dim: Optional[int] = 1024,
  5.  
    reduction: Optional[str] = 'mean', sigmoid=True):
  6.  
    super(ConditionalDomainAdversarialLoss, self).__init__()
  7.  
    self.domain_discriminator = domain_discriminator
  8.  
    self.grl = WarmStartGradientReverseLayer(alpha=1., lo=0., hi=1., max_iters=1000, auto_step=True)
  9.  
    self.entropy_conditioning = entropy_conditioning
  10.  
    self.sigmoid = sigmoid
  11.  
    self.reduction = reduction
  12.  
     
  13.  
    #是否采用随机策略进行多线性映射
  14.  
    if randomized:
  15.  
    assert num_classes > 0 and features_dim > 0 and randomized_dim > 0
  16.  
    self.map = RandomizedMultiLinearMap(features_dim, num_classes, randomized_dim)
  17.  
    else:
  18.  
    self.map = MultiLinearMap()
  19.  
    self.bce = lambda input, target, weight: F.binary_cross_entropy(input, target, weight,
  20.  
    reduction=reduction) if self.entropy_conditioning \
  21.  
    else F.binary_cross_entropy(input, target, reduction=reduction)
  22.  
    self.domain_discriminator_accuracy = Nonedef forward(self, g_s: torch.Tensor, f_s: torch.Tensor, g_t: torch.Tensor, f_t: torch.Tensor) -> torch.Tensor:
  23.  
     
  24.  
    #将f和g进行联合分布,多线性映射后传给域分类器d计算域分类损失
  25.  
    f = torch.cat((f_s, f_t), dim=0)
  26.  
    g = torch.cat((g_s, g_t), dim=0)
  27.  
    g = F.softmax(g, dim=1).detach()
  28.  
    h = self.grl(self.map(f, g))
  29.  
    d = self.domain_discriminator(h)
  30.  
     
  31.  
    #熵调整
  32.  
    weight = 1.0 torch.exp(-entropy(g))
  33.  
    batch_size = f.size(0)
  34.  
    weight = weight / torch.sum(weight) * batch_size
  35.  
     
  36.  
    #是否采用二分类交叉熵损失函数
  37.  
    if self.sigmoid:
  38.  
    d_label = torch.cat((
  39.  
    torch.ones((g_s.size(0), 1)).to(g_s.device),
  40.  
    torch.zeros((g_t.size(0), 1)).to(g_t.device),
  41.  
    ))
  42.  
    self.domain_discriminator_accuracy = binary_accuracy(d, d_label)
  43.  
    if self.entropy_conditioning:
  44.  
    return F.binary_cross_entropy(d, d_label, weight.view_as(d), reduction=self.reduction)
  45.  
    else:
  46.  
    return F.binary_cross_entropy(d, d_label, reduction=self.reduction)
  47.  
    else:
  48.  
    d_label = torch.cat((
  49.  
    torch.ones((g_s.size(0), )).to(g_s.device),
  50.  
    torch.zeros((g_t.size(0), )).to(g_t.device),
  51.  
    )).long()
  52.  
    self.domain_discriminator_accuracy = accuracy(d, d_label)
  53.  
    if self.entropy_conditioning:
  54.  
    raise NotImplementedError("entropy_conditioning")
  55.  
    return F.cross_entropy(d, d_label, reduction=self.reduction)
  56.  
     
  57.  
    #随机策略多线性映射
  58.  
    class RandomizedMultiLinearMap(nn.Module):
  59.  
    def __init__(self, features_dim: int, num_classes: int, output_dim: Optional[int] = 1024):
  60.  
    super(RandomizedMultiLinearMap, self).__init__()
  61.  
    self.Rf = torch.randn(features_dim, output_dim)
  62.  
    self.Rg = torch.randn(num_classes, output_dim)
  63.  
    self.output_dim = output_dim
  64.  
     
  65.  
    def forward(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
  66.  
    f = torch.mm(f, self.Rf.to(f.device))
  67.  
    g = torch.mm(g, self.Rg.to(g.device))
  68.  
    output = torch.mul(f, g) / np.sqrt(float(self.output_dim))
  69.  
    return
  70.  
     
  71.  
    #正常多线性映射
  72.  
    class MultiLinearMap(nn.Module):
  73.  
    def __init__(self):
  74.  
    super(MultiLinearMap, self).__init__()
  75.  
     
  76.  
    def forward(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
  77.  
    batch_size = f.size(0)
  78.  
    output = torch.bmm(g.unsqueeze(2), f.unsqueeze(1))
  79.  
    return output.view(batch_size, -1)
学新通

这篇好文章是转载于:学新通技术网

  • 版权申明: 本站部分内容来自互联网,仅供学习及演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,请提供相关证据及您的身份证明,我们将在收到邮件后48小时内删除。
  • 本站站名: 学新通技术网
  • 本文地址: /boutique/detail/tanhhifeeh
系列文章
更多 icon
同类精品
更多 icon
继续加载