本文从Cora的例子来展示PYG如何加载图数据集。
Cora 是一个小型的有标注的图数据集,包含以下内容:

  • data.x:2708 个节点(即 2708 篇论文),每个节点有 1433 个特征,形状为 (2708, 1433)。
  • data.edge_index:5429 条边(即 5429 个引用关系),形状为 (2, 5429)。
  • data.y:节点标签,共 7 类,形状为 (2708,)。(共有 7 个类别,表示论文的研究领域)
  • data.train_mask:训练集掩码,布尔向量,表示哪些节点用于训练。
  • data.val_mask:验证集掩码,布尔向量,表示哪些节点用于验证。
  • data.test_mask:测试集掩码,布尔向量,表示哪些节点用于测试。

数据主要描述了论文之间的引用关系以及每篇论文的主题。可用于进行训练节点分类问题(即判断每篇论文属于哪个类别)

1.自动加载

1.1 数据加载操作详解

PYG库提供了自动加载数据集的方法:

from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='data/Planetoid', name='Cora')
dataset[0]
print(len(dataset))  # 输出: 1
print(data)

1
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

对于 Planetoid 类来说:

  • 它是一个专门为 Planetoid 系列数据集(Cora、CiteSeer、PubMed) 设计的类。
  • 这些数据集的主要特点是:它们实际上是单图数据集,即整个数据集中只包含一个图。

dataset 是一个包含 单个 Data 对象(图) 的数据集对象。


由于 Planetoid 类的数据集中只有一个图,因此:

  • dataset[0] 返回了这个唯一的图,类型是 Data 对象,表示整个 Cora 数据集的图。
  • Dataset 是一个可索引的对象,dataset[0] 的作用就是提取第一(也是唯一)个图。

  • dataset = Planetoid(root='data/Planetoid', name='Cora') 加载了 Cora 数据集,它是一个 单图数据集,包含一张图的节点特征、边索引、节点标签和数据集划分信息。
  • dataset[0] 提取了该图的数据,返回了一个 Data 对象,表示整个图。
  • dataset 本身是一个数据集管理器,帮助加载和存储数据,同时提供一些元信息和操作方法。

1. 2 数据加载的过程

  1. 下载数据:

    • 如果指定路径 'data/Planetoid' 下没有数据集文件,Planetoid 类会从 指定的远程服务器(由 PyG 维护)下载 Cora 数据集文件,并存储在 'data/Planetoid/Cora' 文件夹下。
    • 数据集下载地址为:
  2. 解压文件:

    • 下载的数据集是 .zip.tar 格式,会被自动解压为一系列文件,主要包括:
      • ind.cora.x:训练节点的特征矩阵;
      • ind.cora.tx:测试节点的特征矩阵;
      • ind.cora.allx:包含训练节点和一些验证节点的特征矩阵;
      • ind.cora.y:训练节点的标签;
      • ind.cora.ty:测试节点的标签;
      • ind.cora.ally:训练和验证节点的标签;
      • ind.cora.graph:节点的邻接表(图结构信息);
      • ind.cora.test.index:测试节点的索引。
        如图所示:
        请添加图片描述
  3. 解析数据:

    • PyG 将原始文件的内容解析为图数据格式(Data 对象),将以下内容整合起来:
      • 节点特征矩阵 x
      • 图的边信息 edge_index
      • 节点标签 y
      • 训练、验证和测试集的掩码(train_maskval_masktest_mask)。
  4. 数据存储:

    • 如果数据加载成功,解析后的数据将被缓存到指定路径(data/Planetoid/Cora)中,后续运行时会直接加载解析后的缓存文件,而不会重复下载和解析。

2. 数据集原始文件的形式

原始文件(以 ind.cora.* 为前缀)是以下几种内容的存储形式:

文件名 内容描述
ind.cora.x 稀疏矩阵,训练集中节点的特征矩阵,大小为 (num_train_nodes, num_features)
ind.cora.tx 稀疏矩阵,测试集中节点的特征矩阵,大小为 (num_test_nodes, num_features)
ind.cora.allx 稀疏矩阵,包含训练集和部分验证集中节点的特征矩阵,大小为 (num_allx_nodes, num_features)
ind.cora.y 训练集的标签,大小为 (num_train_nodes, num_classes) 的独热编码矩阵。
ind.cora.ty 测试集的标签,大小为 (num_test_nodes, num_classes) 的独热编码矩阵。
ind.cora.ally 训练和验证集的标签,大小为 (num_allx_nodes, num_classes) 的独热编码矩阵。
ind.cora.graph 字典格式,存储图的邻接表,键为节点 ID,值为该节点的邻居节点列表。
ind.cora.test.index 列表形式,包含测试节点的索引。

3. 加载后的数据形式

加载后,数据以 torch_geometric.data.Data 对象的形式存储,主要包含以下内容:

属性 描述 形状
data.x 节点的特征矩阵,每一行表示一个节点的特征向量。 (num_nodes, num_features)
data.edge_index 图的边信息,存储为 COO 格式的索引矩阵(两个一维数组,分别表示边的起始节点和结束节点)。 (2, num_edges)
data.y 节点的标签,每个节点对应一个整数,表示其所属类别的索引值。 (num_nodes,)
data.train_mask 训练节点的布尔掩码,值为 True 的位置表示该节点属于训练集。 (num_nodes,)
data.val_mask 验证节点的布尔掩码,值为 True 的位置表示该节点属于验证集。 (num_nodes,)
data.test_mask 测试节点的布尔掩码,值为 True 的位置表示该节点属于测试集。 (num_nodes,)

4. 加载后的具体内容

Cora 数据集为例,加载后的数据具有以下具体特性:

  • 节点数num_nodes = 2708(共 2708 篇论文)。
  • 特征数num_features = 1433(每篇论文的特征是一个 1433 维向量,表示词袋模型中的单词出现情况)。
  • 边数num_edges = 10556(论文之间的引用关系,构成无向图)。
  • 类别数num_classes = 7(每篇论文属于 7 个主题之一)。
  • 掩码分布
    • 训练集:140 个节点;
    • 验证集:500 个节点;
    • 测试集:1000 个节点。

手动读取数据集

下面手动实现的 CoraData 类代码,经过修改后与 PyTorch Geometric (PyG) 的 Planetoid 类功能一致,可以直接生成标准的 Data 对象,用于图神经网络训练。


完整代码:CoraData

import os
import os.path as osp
import pickle
import numpy as np
import torch
from torch_geometric.data import Data
import scipy.sparse as sp
import urllib.request


class CoraData(object):
    download_url = "https://github.com/kimiyoung/planetoid/raw/master/data"
    filenames = ["ind.cora.{}".format(name) for name in
                 ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']]

    def __init__(self, data_root="cora", rebuild=False):
        """
        Cora 数据加载器,包括下载、处理和缓存功能。
        处理后的数据可以通过属性 .data 获取,返回 PyG 标准的 Data 对象。

        Args:
            data_root: str, 数据存储的根目录
            rebuild: bool, 是否强制重新构建数据
        """
        self.data_root = data_root
        save_file = osp.join(self.data_root, "processed_cora.pkl")
        if osp.exists(save_file) and not rebuild:
            print("Using Cached file: {}".format(save_file))
            self._data = pickle.load(open(save_file, "rb"))
        else:
            self.maybe_download()
            self._data = self.process_data()
            with open(save_file, "wb") as f:
                pickle.dump(self.data, f)
            print("Cached file: {}".format(save_file))

    @property
    def data(self):
        """返回 PyG 标准的 Data 对象"""
        return self._data

    def maybe_download(self):
        save_path = osp.join(self.data_root, "raw")
        for name in self.filenames:
            if not osp.exists(osp.join(save_path, name)):
                self.download_data("{}/{}".format(self.download_url, name), save_path)

    def process_data(self):
        """
        处理数据并生成 PyG 标准的 Data 对象,包括以下属性:
        - x: 节点特征,(2708, 1433)
        - y: 节点标签,共 7 类,(2708,)
        - edge_index: 图边索引,(2, num_edges)
        - train_mask: 训练集掩码,(2708,)
        - val_mask: 验证集掩码,(2708,)
        - test_mask: 测试集掩码,(2708,)
        """
        print("Processing data ...")
        # 读取原始数据
        x, tx, allx, y, ty, ally, graph, test_index = [
            self.read_data(osp.join(self.data_root, "raw", name)) for name in self.filenames
        ]

        train_index = np.arange(y.shape[0])  # 训练集索引 [0, 1, ..., 139]
        val_index = np.arange(y.shape[0], y.shape[0] + 500)  # 验证集索引 [140, ..., 639]
        sorted_test_index = sorted(test_index)  # 排序后的测试集索引

        # 特征和标签拼接
        x = np.concatenate((allx, tx), axis=0)  # (2708, 1433)
        y = np.concatenate((ally, ty), axis=0).argmax(axis=1)  # (2708,)

        # 重新排序测试集数据
        x[test_index] = x[sorted_test_index]
        y[test_index] = y[sorted_test_index]

        # 创建训练、验证、测试掩码
        num_nodes = x.shape[0]
        train_mask = np.zeros(num_nodes, dtype=np.bool_)
        val_mask = np.zeros(num_nodes, dtype=np.bool_)
        test_mask = np.zeros(num_nodes, dtype=np.bool_)
        train_mask[train_index] = True
        val_mask[val_index] = True
        test_mask[test_index] = True

        # 构造 edge_index
        edge_index = self.build_edge_index(graph)

        # 转换为 PyTorch 格式
        x = torch.tensor(x, dtype=torch.float32)
        y = torch.tensor(y, dtype=torch.long)
        edge_index = torch.tensor(edge_index, dtype=torch.long)
        train_mask = torch.tensor(train_mask, dtype=torch.bool)
        val_mask = torch.tensor(val_mask, dtype=torch.bool)
        test_mask = torch.tensor(test_mask, dtype=torch.bool)

        # 打印基本信息
        print("Node feature shape: ", x.shape)
        print("Node label shape: ", y.shape)
        print("Edge index shape: ", edge_index.shape)
        print("Number of training nodes: ", train_mask.sum().item())
        print("Number of validation nodes: ", val_mask.sum().item())
        print("Number of test nodes: ", test_mask.sum().item())

        # 返回 PyG 的 Data 对象
        return Data(x=x, y=y, edge_index=edge_index,
                    train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)

    @staticmethod
    def build_edge_index(graph):
        """
        根据邻接表生成 edge_index 格式 (2, num_edges)。
        """
        edge_index = []
        for src, dst in graph.items():
            edge_index.extend([[src, v] for v in dst])  # 正向边
            edge_index.extend([[v, src] for v in dst])  # 反向边
        edge_index = np.array(edge_index).T  # 转置为 (2, num_edges)
        return edge_index

    @staticmethod
    def read_data(path):
        """
        读取数据文件,根据文件名选择加载方式。
        """
        name = osp.basename(path)
        if name == "ind.cora.test.index":
            out = np.genfromtxt(path, dtype="int64")
            return out
        else:
            out = pickle.load(open(path, "rb"), encoding="latin1")
            out = out.toarray() if hasattr(out, "toarray") else out
            return out

    @staticmethod
    def download_data(url, save_path):
        """
        从指定 URL 下载数据,并保存到本地路径。
        """
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        data = urllib.request.urlopen(url)
        filename = os.path.split(url)[-1]

        with open(os.path.join(save_path, filename), 'wb') as f:
            f.write(data.read())

        return True

代码解析

  1. 下载和缓存功能

    • 如果处理后的数据已缓存 (processed_cora.pkl),直接加载缓存。
    • 如果未缓存,则从 GitHub 下载原始数据,处理后存储为缓存文件。
  2. 数据处理:process_data

    • 加载原始数据,并将训练、验证、测试节点特征拼接成完整矩阵。
    • 生成 PyG 格式的 edge_index(用于图神经网络的邻接表表示)。
    • 生成训练、验证和测试集掩码。
  3. 邻接表转换为边索引

    • build_edge_index 将邻接表 (graph) 转换为 edge_index 格式。
    • edge_index 是一个形状为 (2, num_edges) 的数组,列表示一条边的起点和终点。
  4. 返回 PyG 数据对象

    • 数据对象包括 xyedge_indextrain_maskval_masktest_mask

运行代码测试

要测试 CoraData 类,可以直接运行以下代码:

cora_data = CoraData(data_root="cora", rebuild=True)
data = cora_data.data  # 获取 PyG 的 Data 对象
print(data)

输出示例:

Processing data ...
Node feature shape:  torch.Size([2708, 1433])
Node label shape:  torch.Size([2708])
Edge index shape:  torch.Size([2, 10556])
Number of training nodes:  140
Number of validation nodes:  500
Number of test nodes:  1000
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

该类的功能与 PyTorch Geometric 的 Planetoid 类一致,支持加载 Cora 数据集,并生成标准的 PyG Data 对象,适用于图神经网络模型训练。

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐