关键要点
- 图神经网络(GNN)非常适合对互联数据中的关系进行建模,是处理结构和依赖关系至关重要的任务的理想选择。
- GNN将传统的深度学习推广到图格式,允许从节点、边及其连接性中学习。
- 流行的GNN架构,如图卷积网络(GCN)、图注意力网络(GAT)和GraphSAGE,提供了跨节点聚合和传播信息的不同方式。
- GNN的实际应用涵盖社交网络分析、分子属性预测、推荐系统、欺诈检测和知识图谱。
- 使用PyTorch Geometric等现代框架实现GNN非常方便,它简化了数据处理、消息传递和模型构建。
- 实践演示了GNN的工作流程,包括加载图数据、定义模型、训练和评估结果。
- GNN发展迅速,其进步提高了处理日益庞大和复杂图的可扩展性、效率和性能。
什么是图神经网络?
图神经网络,简称GNN,是一种相对较新的神经网络类型,可以处理图结构的数据。图基本上是一组对象(表示为节点)以及这些对象之间的关系(表示为连接节点的边),GNN可以处理有向图(边有方向)和无向图(边没有特定方向)。这些图的尺寸和形状也可能差异很大。 GNN的架构由多层组成,每一层都从前一层获取信息。我们将一个图(表示为一组节点和边及其相关特征)输入到GNN中。我们得到的是输入图中每个节点的一组节点嵌入。这些嵌入代表了网络为每个节点学习到的特征。 与普通神经网络只处理向量、矩阵或张量不同,GNN可以处理结构化图形的数据。这使得它们在处理网络化数据(如社交网络、分子结构或交通系统)时非常灵活。所涉及的数学很复杂,但其核心思想是它们在图中进行迭代,在节点之间传递消息以学习有用的表示。
图神经网络如何工作?
图神经网络,简称GNN,其核心是学习网络中节点之间的模式。主要思想是每个节点向其相邻节点传递消息,共享自身的信息。
- 在GNN中,每个节点与其连接的节点交换信息,帮助模型逐步理解完整的图结构。
- 每个节点使用自身特征和相邻节点的特征创建一个“消息”。
- 同时,相邻节点生成并发送它们自己的消息回来。
- 当节点收到这些消息时,它通过聚合所有传入信息来更新其内部状态。
- 这种重复的消息传递允许知识在整个图中流动,让节点了解其直接邻居之外的图部分。
- 通过堆叠多个这样的处理层,GNN可以学习更复杂、更深层的关系。
- 每增加一层,模型就能构建出更丰富、更有意义的图特征表示。
在PyTorch中实现图神经网络
Cora数据集 Cora数据集是从事图表示学习的研究人员常用的一个流行基准。该数据集包含许多科学出版物,分为七类,如“基于案例”、“遗传算法”、“神经网络”、“概率方法”、“强化学习”和“规则学习”。 Cora数据集已经存在了一段时间,并且仍然是该领域许多项目的首选资源。它提供了一种方法来评估模型在分析文档文本内容和文档之间相互引用的网络方面的有效性。许多著名的图神经网络论文都使用Cora来评估在这双重任务上的性能。 它被构建为一个图,其中出版物是节点,它们之间的引用是连接节点的边。每个文档都与一个表示其内容的特征向量相关联。这里的挑战是开发一个模型,能够查看引用图、内容向量以及它们之间的关系,以预测任何给定出版物属于哪七个类别。
数据预处理
我们使用以下命令安装PyTorch Geometric库:
|
|
然后,我们可以使用PyTorch Geometric库来加载和预处理数据集。
|
|
Planetoid类加载Cora数据集并归一化特征向量。我们可以使用dataset获取预处理后的数据,这会给我们一个具有以下属性的Data对象:
x:形状为(num_nodes, num_features)的节点特征矩阵edge_index:形状为(2, num_edges)的边连接矩阵y:形状为(num_nodes)的节点标签向量train_mask,val_mask,test_mask:指示哪些节点用于训练、验证和测试的布尔掩码。
模型架构
在构建图神经网络时,选择合适的模型架构非常重要。我们将使用PyTorch的torch_geometric库来讲解一个基本实现。我们将使用图卷积网络,它是各种图学习任务的一个良好起点。
|
|
在上面的代码中,我们导入了torch和torch.nn.functional以访问一些有用的神经网络模块和函数。然后,我们定义了一个继承自torch.nn.Module的GNN类。
在__init__方法中,我们使用PyTorch Geometric的GCNConv模块定义了两个卷积层。这便于实现图卷积。我们还添加了一个简单的线性层。
前向传递首先通过两个卷积层传递输入,每次应用ReLU激活。然后它通过线性层,最后通过log softmax来压缩输出。
只需几行代码,我们就可以构建一个不错的图神经网络!
显然,这是一个简单的例子,但我们可以看到PyTorch和PyTorch Geometric如何让我们快速原型化和迭代图神经网络架构。GCNConv层使得将图结构融入我们的模型变得非常容易。
训练
对于训练,我们将使用交叉熵损失和Adam优化器。我们可以使用Data对象上的掩码属性将数据分成训练集、验证集和测试集。
|
|
train函数进行一轮训练并返回损失。test函数检查模型在训练集、验证集和测试集上的表现并返回准确率。我们训练了模型500个epoch,并在每个epoch打印训练和测试准确率。
计算GNN模型的准确率
下面的代码定义了一个函数来计算模型在整个数据集上的准确率。compute_accuracy()函数将模型切换到评估模式,执行前向传递,并预测每个节点的标签。它将预测的标签与真实标签进行比较,并计算正确预测的数量。然后,它将正确预测的数量除以数据集中的节点总数,得到准确率百分比。
|
|
在这种情况下,模型在Cora数据集上的准确率为0.8006。这意味着模型能够正确预测类别标签的概率大约为80%。这很好,但还不完美。准确率为我们提供了模型整体表现如何的快速、高层视图。但你必须更深入地挖掘才能真正理解它在哪些方面成功,哪些方面存在困难。 为了更深入地了解模型的有效性,建议考虑其他评估指标,如精确率、召回率、F1分数和混淆矩阵。这些指标提供了模型在不同方面表现的洞察,例如正确识别正例和负例以及处理不平衡数据集。 因此,虽然80%的准确率是可靠的,但我们希望在宣布这个模型是巨大成功之前获得更多背景信息。仅凭准确率指标并不能提供完整的情况。但它是衡量性能的一个良好起点。
评估
我们可以使用准确率、精确率、召回率和F1分数等指标来评估GNN的表现。但是,我们也可以使用t-SNE可视化模型学习的节点嵌入。它将高维嵌入投影到2D中,允许我们可视化它们。
|
|
注意:读者可以运行上面的代码。它将显示一个可以解释的散点图。 该代码使用t-SNE在2D散点图中显示学习到的节点嵌入,这是可视化高维数据的一种巧妙方法。让我们来看看发生了什么:
- 图中的每个点代表数据集中的一个节点。x轴和y轴是t-SNE将嵌入压缩成的两个维度。
- 每个点的颜色代表数据集中相应节点的真实标签。
- 具有相似嵌入的节点应该具有相似的标签,因此它们将在图上聚集在一起。反之,嵌入差异很大的节点可能具有不同的标签,因此它们之间的距离会更远。
- 总体而言,该图基于学习到的嵌入清晰地表示了节点之间的关系。你可以看到形成的一些分组,它们必须共享一些潜在的相似性。这是一种窥探模型内部并理解其如何组织概念的便捷方法。
潜在挑战与考虑因素
- 拥有2708个节点和5429条边,Cora数据集被认为规模较小。这可能会影响GNN的效率,因此需要采用更先进的方法,如数据增强和迁移学习。
- Cora数据集由一种节点类型和一种边类型组成,是一个同构网络。当应用于涉及不同节点和边类型的更复杂网络时,这可能会限制GNN的适用性。
- 选择合适的超参数值,如隐藏层数、隐藏单元数和学习率,可以显著影响GNN的性能,需要仔细调整。
常见问题解答
-
什么是图神经网络(GNN)? GNN是一种设计用于处理图结构数据的神经网络。它通过聚合来自相邻节点的信息来学习节点、边或图级别的表示。
-
GNN与传统神经网络有何不同? 传统网络处理网格状数据(图像、序列),而GNN可以建模不规则、互连的结构,如社交网络或分子图。
-
GNN用于解决哪些类型的问题? 它们通常用于节点分类、链接预测、图分类、推荐、欺诈检测和分子属性预测。
-
GNN中的消息传递是什么? 消息传递是核心机制,节点与其邻居交换信息、更新嵌入并学习上下文关系。
-
GNN能很好地扩展到非常大的图吗? 由于内存和邻域扩展问题,扩展可能具有挑战性。采样(GraphSAGE)、小批量处理和分布式训练等技术可以提供帮助。
-
实现GNN的最佳编程库是什么? 流行的框架包括PyTorch Geometric(PyG)和Deep Graph Library(DGL),两者都提供用于构建GNN模型的现成层和实用程序。
-
GNN适合实时应用吗? 是的,取决于模型的复杂性。轻量级架构和基于采样的方法有助于实现接近实时的性能。
-
GNN能处理动态或演变的图吗? 是的,动态GNN变体可以处理随时间变化的图,这对于交通预测、时间推荐和异常检测很有用。
-
GNN需要什么样的数据预处理? 通常需要准备邻接信息、节点/边特征,并确保图的格式适合您选择的库。
-
GNN可解释吗? 通过注意力机制和GNNExplainer等工具,可解释性正在提高,这些工具可以突出显示有影响的节点和边。
结论
在本文中,我们深入探讨了图神经网络(GNN)背后的核心概念,并探索了它们如何应用于不同领域。GNN特别适合处理图结构数据,使模型能够推理社交网络、分子图、交通系统等中发现的复杂关系。 为了在实践中演示这些想法,我们使用了流行的Cora数据集——图学习的基准。在这个数据集中,每篇出版物表示为一个节点,它们之间的引用形成了边。我们的目标是利用每篇论文的文本特征和引用链接来预测其类别。 我们使用PyTorch Geometric库准备了数据集,对特征向量进行了归一化,并将数据分成训练集、验证集和测试集。然后,我们构建了一个简单的GNN,使用图卷积层,接着是一个线性分类器,并使用带有Adam优化器的交叉熵损失进行训练。最后,我们通过测量其准确率来评估模型的性能。 虽然还有许多其他技术和改进可以探索,但这个项目很好地介绍了GNN如何运作以及它们在从复杂关系数据中学习方面的有效性。
资源
- Constructing Neural Networks From Scratch: Part 1
- What is Deep Learning? A Beginner’s Guide to Neural Networks
- Deep Learning Architectures Explained: ResNet, InceptionV3, SqueezeNet
- PyTorch 101: Learn Deep Learning with PyTorch