在过去的十年中,深度学习系统在许多人工智能任务中取得了显著成功,但它们的应用往往较为狭窄。例如,一个训练用于识别猫和狗的计算机视觉系统需要大量重新训练才能开始识别鲨鱼和海龟。
元学习是一种旨在将机器学习系统转变为通用学习者的范式。元学习模型在一系列相关任务上进行训练,但它不仅学习如何执行这些任务,还学习如何学习执行它们。这样,它就可以通过仅少量标记训练示例适应新任务,大幅减少劳动密集型数据标注的需求。
在国际学习表征会议(ICLR)上,我们提出了一种方法,在不增加数据标注要求的情况下提高元学习任务的性能。关键思想是调整元学习过程,使其能够利用少量未标记数据,以及传统的标记示例。
在元学习中,机器学习模型学习如何学习。在元训练期间,模型使用来自"支持集"的数据在一组相关任务上进行训练,并使用来自"查询集"的数据进行测试。但查询集是标记的,因此模型可以评估其学习效果。在元测试期间,模型再次在一组支持集上进行训练,但评估其分类未标记查询数据的能力。
直觉是即使没有标签,这些额外数据仍包含大量有用信息。例如,假设一个在陆地动物图像上训练的元学习系统正在适应识别水生动物。未标记的水生动物图像仍然告诉模型有关学习任务的信息,例如水下照片典型的光照条件和背景颜色。
在实验中,我们在目标识别元学习任务上将通过我们的方法训练的模型与16个不同基线进行比较。我们发现,根据底层神经网络架构的不同,我们的方法将单样本学习的性能提高了11%到16%。
元学习原理
在传统机器学习中,模型被输入一组标记数据,并学习将数据特征与标签关联起来。然后输入另一组测试数据,并评估其对数据标签的预测能力。为了评估目的,系统设计者可以访问测试数据标签,但模型本身不能。
元学习增加了另一层复杂性。在元训练期间,模型学习执行一系列相关任务。每个任务都有自己的训练数据集和测试数据集,模型都能看到。也就是说,其元训练的一部分是学习对训练数据的特定响应方式如何影响其在测试数据上的性能。
在元测试期间,它再次在一系列任务上进行训练。这些任务与元训练期间看到的任务相关但不完全相同。同样,对于每个任务,模型看到训练数据和测试数据。但在元训练期间,测试数据是标记的,而在元测试期间,标签未知且必须预测。
方法创新
我们的方法有两个关键创新。首先,在元训练期间,我们不学习单一的全局模型。相反,我们训练一个辅助神经网络,基于相应的支持集为每个任务生成局部模型。其次且更重要的是,在元训练期间,我们还训练第二个辅助网络以利用查询集的未标记数据。然后,在元测试期间,我们可以使用查询集微调局部模型,提高性能。
利用未标记数据
机器学习系统由一组参数控制,在元学习中,元训练为特定任务族优化这些参数。在元测试或操作部署期间,模型使用少量训练示例为新的任务优化这些参数。
特定的参数值集定义了多维空间中的一个点,适应新任务可以被视为在空间中搜索表示最优新设置的点。
在传统元学习中,搜索从全局模型定义的点开始;这是初始化步骤。然后,使用支持集的标记数据,它朝着对应于新任务的设置工作;这是适应步骤。
相比之下,在我们的方法中,初始化网络根据支持集中的数据选择起始搜索位置。然后它使用查询集的未标记数据朝着最优设置工作。更准确地说,第二个辅助神经网络估计查询集数据隐含的梯度。
性能与展望
尽管我们的系统在单样本学习任务上击败了所有16个基线,但在五样本学习中,有几个基线系统的性能优于我们的系统。这些基线使用的方法与我们的方法互补,我们相信结合方法可以产生更低的错误率。展望未来,这是我们将追求的这项工作的几个扩展方向之一。