栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

[半监督学习] MixMatch: A Holistic Approach to Semi-Supervised Learning

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

[半监督学习] MixMatch: A Holistic Approach to Semi-Supervised Learning

MixMatch 融合了半监督学习的主要方法, 对于数据增强后的未标记示例, MixMatch 预测其低熵标签, 并使用 MixUp 混合标记和未标记数据.

半监督学习中的一种方法为添加损失项, 该损失项是在未标记数据上计算的. 在许多工作中, 这个损失项包含以下三类:

最小化熵. 它鼓励模型对未标记的数据输出可信的预测.一致性正则化. 它鼓励模型在其输入受到扰动时产生相同的输出分布.通用正则化. 它鼓励模型很好地泛化并避免过度拟合训练数据.

论文中提出一种 SSL 方法: MixMatch, 其引入一个单一的损失, 将伪标签和一致性正则方法统一于其中. 与以前的方法不同, MixMatch 一次针对所有属性, 一些实验成果如下:

实验证明, MixMatch 在所有标准图像基准上都获得了最先进的结果, 并将 CIFAR-10 的错误率降低了 4 倍.在消融研究中进一步表明 MixMatch 大于其部分的总和.MixMatch 对于差分隐私学习很有用, 使 PATE 框架中的 Students 能够获得新的最先进的结果, 同时加强隐私保证和准确性.

简而言之, MixMatch 为未标记数据引入了一个统一的损失项, 可以无缝地降低熵, 同时保持一致性并与传统的正则化技术保持兼容.

1. 损失项 1.1 一致性正则化(Consistency Regularization)

监督学习中一种常见的正则化技术是数据增强. 例如, 在图像分类中, 通常对输入图像进行弹性变形或添加噪声, 这可以在不改变其标签的情况下显着改变图像的像素内容. 这可以通过生成近乎无限的新修改数据流来人为地扩大训练集的大小. 一致性正则化将数据增强应用于半监督学习, 其强制一个未标记的示例 x x x 的分类应该与 D a t a A u g m e n t ( x ) {rm DataAugment}(x) DataAugment(x) 的分类相同.

在我前面的几篇文章中介绍了利用一致性正则化的方法: Π Pi Π-model, Temporal Ensembling, Mean-Teacher, Dual-Student, VAT, ICT, UDA.

1.2 熵最小化(Entropy Minimization)

许多半监督学习方法中一个常见的基本假设是分类器的决策边界不应穿过边缘数据分布的高密度区域. 强制执行此操作的一种方法是要求分类器输出对未标记数据的低熵预测. 损失项使未标记数据 x x x 的 p m o d e l ( y ∣ x ; θ ) p_{model}(y | x; theta) pmodel​(y∣x;θ) 的熵最小化. 这种形式的熵最小化与 VAT 相结合以获得更强的结果. "伪标签"通过对未标记数据的高置信度预测构建 one-hot 标签并将其用作标准交叉熵损失中的训练目标, 从而隐式地进行熵最小化. MixMatch 还通过在未标记数据的目标分布上使用"锐化"函数来隐式实现熵最小化.

1.3 通正则化(Traditional Regularization)

正则化是对模型施加约束, 希望其能更好地泛化数据. 使用权重衰减来惩罚模型参数的 L2 范数, 还在 MixMatch 中使用 MixUp. 将 MixUp 用作正则化器(应用于标记数据点)和半监督学习方法(应用于未标记数据点). MixUp 方法在之前就已经应用于半监督学习, 例如 ICT.

2. MixMatch

MixMatch 是一种"整体"方法, 它结合了 SSL 主流范式中的思想. 给定一个 batch 的 X mathcal{X} X, 其标签为 one-hot 编码, 代表 L L L 个可能标签中的一个, 和一个相同大小 batch 的未标记示例 U mathcal{U} U, MixMatch 生成一批经过处理的增强标记示例 X ′ mathcal{X}' X′ 和一批具有"猜测"标签 U ′ mathcal{U}' U′ 的增强未标记示例. U ′ mathcal{U}' U′ 和 X ′ mathcal{X}' X′ 用于计算单独的标记和未标记损失项.

MixMatch 中使用的标签猜测过程如下图:

将随机数据增强(Stochastic data augmentation)应用于未标记的图像 K K K 次, 每个增强后的图像都通过分类器进行输入. 然后, 通过调整温度分布(distribution’s temperature)来"锐化"这 K K K 个预测的平均值.

MixMatch 整体算法如下:

MixMatch 具体步骤如下:

2.1 数据增强(Data Augmentation)

正如许多 SSL 方法中的典型情况一样, 对标记和未标记的数据都使用数据增强. 对 X mathcal{X} X 中的每个 x b x_b xb​, 生成一个转换后的版本 x ^ b = A u g m e n t ( x b ) hat{x}_b = mathrm{Augment}(x_b) x^b​=Augment(xb​), 在上面 algorithm 1的第3行. 对未标记数据 U mathcal{U} U 中的每个 u b u_b ub​, 生成 K K K 个增强 u ^ b , k = A u g m e n t ( u b ) , k ∈ ( 1 , … , K ) hat{u}_{b,k} = mathrm{Augment}(u_b), k in (1, dots ,K) u^b,k​=Augment(ub​),k∈(1,…,K), 在上面 algorithm 1的第5行. 使用这些单独的增强来为每个 u b u_b ub​ 生成一个"猜测标签" q b q_b qb​. 其中 batch_size= B B B.

2.2 标签猜测(Label Guessing)

对于 U mathcal{U} U 中的每个未标记示例, 增强后使用 MixMatch 模型的预测为示例的标签生成一个"猜测"值. 这个猜测在后来被用在无监督损失项中. 为此, 计算其平均值, 在上面 algorithm 1的第7行:
q ‾ b = 1 K ∑ K = 1 K P m o d e l ( y ∣ u ^ b , k ; θ ) (1) overline{q}_b=frac{1}{K} sum_{K=1}^K P_{model}(y vert hat{u}_{b,k}; theta) tag{1} q​b​=K1​K=1∑K​Pmodel​(y∣u^b,k​;θ)(1)
锐化操作(Sharpening): 在生成标签猜测时, 执行一个额外的步骤, 即对于给定增强的平均预测 q ‾ b overline{q}_b q​b​, 应用锐化函数来减少标签分布的熵. 在实践中, 调整分类的"温度"分布是常用方法, 在上面 algorithm 1的第8行, 它被定义为:
S h a r p e n ( p , T ) i = p i 1 T / ∑ j = 1 L p j 1 T (2) mathrm{Sharpen}(p,T)_i=p_i^{frac{1}{T}} bigg/ sum_{j=1}^L p_j^{frac{1}{T}} tag{2} Sharpen(p,T)i​=piT1​​/j=1∑L​pjT1​​(2)
其中 p p p 是一些输入分类的分布(特别是在 MixMatch 中, p p p 是对增强 q ‾ b overline{q}_b q​b​ 的平均类别预测, T T T 是超参数. 随着 T → 0 T rightarrow 0 T→0, S h a r p e n ( p , T ) mathrm{Sharpen}(p, T) Sharpen(p,T) 的输出将接近 Dirac(“one-hot”)分布.

2.3 MixUp

使用 MixUp 进行半监督学习, 与过去的 SSL 工作不同, 将标记示例和未标记示例与猜测标签混合在一起. 为了与损失兼容, 这里定义了一个修改过的 MixUp 版本. 对于具有相应标签概率的两个示例 ( x 1 , p 1 ) (x_1, p_1) (x1​,p1​), ( x 2 , p 2 ) (x_2, p_2) (x2​,p2​), 通过以下式计算 ( x ′ , p ′ ) (x', p') (x′,p′):
λ ∼ B e t a ( α , α ) (3) lambda sim mathrm{Beta}(alpha,alpha) tag{3} λ∼Beta(α,α)(3)
λ ′ = max ⁡ ( λ , 1 − λ ) (4) lambda'=max(lambda,1-lambda) tag{4} λ′=max(λ,1−λ)(4)
x ′ = λ ′ x 1 + ( 1 − λ ′ ) x 2 (5) x'=lambda'x_1+(1-lambda')x_2 tag{5} x′=λ′x1​+(1−λ′)x2​(5)
p ′ = λ ′ p 1 + ( 1 − λ ′ ) p 2 (6) p'=lambda'p_1+(1-lambda')p_2 tag{6} p′=λ′p1​+(1−λ′)p2​(6)
其中 α alpha α 是超参数. 将 ( x ′ , p ′ ) (x', p') (x′,p′) 作为增强数据或者虚拟训练数据. 为了应用 MixUp, 见上面 algorithm 1的第10-11行, 首先将所有带标签的增强标记示例和所有带猜测标签的未标记示例收集到:
X ^ = ( ( x ^ b , p b ) ; b ∈ ( 1 , … , B ) ) (7) hat{mathcal{X}}=((hat{x}_b,p_b);bin(1,dots,B)) tag{7} X^=((x^b​,pb​);b∈(1,…,B))(7)
U ^ = ( ( u ^ b , k , q b ) ; b ∈ ( 1 , … , B ) , k ∈ ( 1 , … , K ) ) (8) hat{mathcal{U}}=((hat{u}_{b,k},q_b);bin(1,dots,B),kin(1,dots,K)) tag{8} U^=((u^b,k​,qb​);b∈(1,…,B),k∈(1,…,K))(8)
然后, 组合这些集合并将结果打乱以形成 W mathcal{W} W, 它将作为 MixUp 的数据源, 见上面 algorithm 1的第12行. 对于 X ^ hat{mathcal{X}} X^ 中的第 i i i 个带标签的示例对, 计算 M i x U p ( X ^ i , W i ) {rm MixUp}(hat{mathcal{X}}_i, mathcal{W}_i) MixUp(X^i​,Wi​) 其中 i ∈ ( 1 , … , ∣ X ^ ∣ ) i in (1, dots, vert hat{mathcal{X}}vert) i∈(1,…,∣X^∣), 并将结果添加到 X ′ mathcal{X}' X′ 里, 见上面 algorithm 1的第13行. 计算 U i ′ = M i x U p ( U ^ i , W i + ∣ X ^ ∣ ) mathcal{U}'_i = mathrm{MixUp}( hat{mathcal{U}}_i, mathcal{W}_{i+vert hat{mathcal{X}}vert}) Ui′​=MixUp(U^i​,Wi+∣X^∣​) 其中 i ∈ ( 1 , … , ∣ U ^ ∣ ) i in (1, dots, vert hat{mathcal{U}}vert) i∈(1,…,∣U^∣), 见上面 algorithm 1的第14行.

总而言之, MixMatch 将 X mathcal{X} X 转换为 X ′ mathcal{X}' X′, 这是一组应用了数据增强和 MixUp 的标记示例. 类似地, U mathcal{U} U 被转换为 U ′ mathcal{U}' U′, 即每个未标记示例的多个增强的集合, 并带有相应的猜测标签. X ′ mathcal{X}' X′, U ′ mathcal{U}' U′ 是对有标签和无标签数据进行增强之后得到的新训练数据.

2.4 损失函数(Loss Function)

组合损失 L mathcal{L} L 定义如下:
X ′ , U ′ = M i x M a t c h ( X , U , T , K , α ) (9) mathcal{X}',mathcal{U}'=mathrm{MixMatch}(mathcal{X},mathcal{U},T,K,alpha) tag{9} X′,U′=MixMatch(X,U,T,K,α)(9)
L X = 1 ∣ X ′ ∣ ∑ x , p ∈ X ′ H ( p , P m o d e l ( y ∣ x ; θ ) ) (10) mathcal{L}_{mathcal{X}}=frac{1}{vert mathcal{X}' vert} sum_{x,p in mathcal{X}'} mathrm{H}(p,P_{model}(y vert x; theta)) tag{10} LX​=∣X′∣1​x,p∈X′∑​H(p,Pmodel​(y∣x;θ))(10)
L U = 1 L ∣ U ′ ∣ ∑ u , q ∈ U ′ ∥ q − P m o d e l ( y ∣ u ; θ ) ∥ 2 2 (11) mathcal{L}_{mathcal{U}}=frac{1}{Lvert mathcal{U}' vert} sum_{u,q in mathcal{U}'} lVert q-P_{model}(y vert u; theta)rVert_2^2 tag{11} LU​=L∣U′∣1​u,q∈U′∑​∥q−Pmodel​(y∣u;θ)∥22​(11)
L = L X + λ U L U (12) mathcal{L}=mathcal{L}_{mathcal{X}}+lambda_{mathcal{U}}mathcal{L}_{mathcal{U}} tag{12} L=LX​+λU​LU​(12)
其中 H ( p , q ) mathrm{H}(p, q) H(p,q) 是分布 p p p 和 q q q 之间的交叉熵, T T T、 K K K、 α alpha α 和 λ U lambda_{mathcal{U}} λU​ 是超参数. 式子(12)将来自 X ′ mathcal{X}' X′ 的有标签数据之间的交叉熵损失与来自 U ′ mathcal{U}' U′ 的预测和猜测标签的平方 L 2 L_2 L2​ 损失相结合.

2.5 一些超参数设置

在实践中发现, MixMatch 的大多数超参数都可以固定, 不需要在每个实验或每个数据集的基础上进行调整. 具体来说, 对于所有实验, 设置 T = 0.5 T = 0.5 T=0.5, K = 2 K = 2 K=2. 此外, 仅在每个数据集的基础上更改 α alpha α 和 λ U lambda_{mathcal{U}} λU​. 文中发现 α = 0.75 alpha = 0.75 α=0.75 和 λ U = 100 lambda_{mathcal{U}}=100 λU​=100 是调整的良好起点.

TensorFlow 版本代码: https://github.com/google-research/mixmatch
PyTorch 版本代码: https://github.com/YU1ut/MixMatch-pytorch

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/726060.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号