SAN
文献阅读(3)SAN
标题:Partial Transfer Learning with Selective Adversarial Networks
清华大学龙明盛团队 2018 年发表在计算机视觉顶级会议 CVPR 上的文章提出了一个选择性迁移网络 (Partial Transfer Learning)。作者认为,在大数据时代,通常我们会有大量的源域数据。这些源域数据比目标域数据,在类别上通常都是丰富的。比如基于 ImageNet训练的图像分类器,必然是针对几千个类别进行的分类。我们实际用的时候,目标域往往只是其中的一部分类别。这样就会带来一个问题:那些只存在于源域中的类别在迁移时,会对迁移结果产生负迁移影响。
这种情况通常来说是非常普遍的。因此,就要求相应的迁移学习方法能够对目标域,选择相似的源域样本 (类别),同时也要避免负迁移。但是目标域通常是没有标签的,不知道和源域中哪个类别更相似。作者指出这个问题叫做 partial transfer learning。这个 partial,就是只迁移源域中那部分和目标域相关的样本。下图展示了部分迁移学习的思想。
image-20230418192557331 作者提出了一个叫做 Selective Adversarial Networks (SAN) [Cao et al., 2017] 的方法来处理 partial transfer 问题。在 partial 问题中,传统的对抗网络不再适用。所以就需要对进行一些修改,使得它能够适用于 partial 问题。因为不知道目标域的标签,也就没办法知道到底是源域中哪些类是目标域的。为了达到这个目的,作者对目标域按照类别分组,把原来的一整个判别器分成了个:,每一个子判别器都对它所在的第 k 个类进行判别。作者观察到了这样的事实:对于每个数据点来说,分类器的预测结果其实是对于整个类别空间的一个概率分布。因此,在进行对抗时,需要考虑每个样本属于每个类别的影响。这个影响就是由概率来刻画。所以作者提出了一个概率权重的判别器:
1 亮点
2 模型示意图
3 实验内容
4 方法
5 代码
5.1 train_san_w_t.py
if __name__ == "__main__":
# args中包括[gpu_id, net, dset, s_dset_path, t_dset_path, test_interval, snapshot_interval, output_dir]
# gpu_id:
# net: 应该是指选择提取特征的网络层;
# dset: 使用的dataset
# s_dset_path: source_dataset路径
# t_dset_path: target_dataset路径
# test_interval: 两次测试之间的间隔
# snapshot_interval: 模型输出的间隔
# output_dir: 输出路径
# train config
# config = {}
# config["num_iterations"] = 50004 # 迭代次数
# config["test_interval"] = args.test_interval
# config["snapshot_interval"] = args.snapshot_interval
# config["output_path"] = "../snapshot/" + args.output_dir