栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

PyG从networkx导入数据的节点名称问题

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

PyG从networkx导入数据的节点名称问题

最近在研究使用PyG进行图数据处理时,发现PyG从networkx导入数据时,原有节点的名字都被转为了整数,而后面还要把训练的节点嵌入与原有节点名字标签一一对应起来。

经过一番探索,发现是PyG的from_networkx函数在实现时,用到了networkx.relabel.convert_node_labels_to_integers函数:

def from_networkx(G, group_node_attrs: Optional[Union[List[str], all]] = None,
                  group_edge_attrs: Optional[Union[List[str], all]] = None):
    r"""Converts a :obj:`networkx.Graph` or :obj:`networkx.DiGraph` to a
    :class:`torch_geometric.data.Data` instance.

    Args:
        G (networkx.Graph or networkx.DiGraph): A networkx graph.
        group_node_attrs (List[str] or all, optional): The node attributes to
            be concatenated and added to :obj:`data.x`. (default: :obj:`None`)
        group_edge_attrs (List[str] or all, optional): The edge attributes to
            be concatenated and added to :obj:`data.edge_attr`.
            (default: :obj:`None`)

    .. note::

        All :attr:`group_node_attrs` and :attr:`group_edge_attrs` values must
        be numeric.
    """
    import networkx as nx

    G = nx.convert_node_labels_to_integers(G)
    G = G.to_directed() if not nx.is_directed(G) else G
    edge_index = torch.LongTensor(list(G.edges)).t().contiguous()

    data = defaultdict(list)
    # 后面省略

以上源码来自:torch_geometric.utils.convert — pytorch_geometric 2.0.1 documentation

下面来看networkx中convert_node_labels_to_integers的用法。

convert_node_labels_to_integers(Gfirst_label=0ordering='default'label_attribute=None)[source]

Returns a copy of the graph G with the nodes relabeled using consecutive integers.

Parameters

G graph

A NetworkX graph

first_label int, optional (default=0)

An integer specifying the starting offset in numbering nodes. The new integer labels are numbered first_label, …, n-1+first_label.

ordering string

“default” : inherit node ordering from G.nodes() “sorted” : inherit node ordering from sorted(G.nodes()) “increasing degree” : nodes are sorted by increasing degree “decreasing degree” : nodes are sorted by decreasing degree

label_attribute string, optional (default=None)

Name of node attribute to store old label. If None no attribute is created.

只要是节点在转整数时,排序方法一致,出来的顺序就是一样的。因此可以把原节点标签暂存到nodename属性上,后面再直接用整数对应也比较省事了。

最终解决方案如下:

import torch
from torch_geometric.nn import SAGEConv
from torch_geometric.datasets import Planetoid
import torch.nn.functional as F
from torch_geometric.utils import from_networkx, to_networkx, add_self_loops, degree as pyg_degree
import networkx as nx

if __name__ == "__main__":
    #构建图
    G = nx.DiGraph()
    G.add_node("来", embedding=[1.0,2.0])
    G.add_node("去", embedding=[1.4,2.1])
    G.add_node("灰", embedding=[1.6,1.1])
    G.add_node("粉", embedding=[2.1,2.9])
    G.add_edge("来","灰")
    G.add_edge("粉","灰")
    G.add_edge("去","灰")
    G.add_edge("来","粉")
    G.add_edge("灰","粉")

    #将G转为整数后,返回一个副本G1,原节点名存到nodename属性中
    G1 = nx.relabel.convert_node_labels_to_integers(G, label_attribute="nodename")
    print("AFTER relabel:")
    print( G1.nodes(data=True))

    #转为PyG数据,将'embedding'属性作为data['x']
    data = from_networkx(G, ['embedding'])

    #使用pytorch处理,每个数字加10
    data['x'].add_( 10.0)

    #处理完成后,再转出为networkx的图对象,带上训练后的节点属性
    G = to_networkx(data, node_attrs=['x'])
    print(nx.info(G))
    print(G.nodes(data=True))
    
    for n,d in G.nodes(data=True):
        print(n,d)
        #更新G1中的属性
        G1.nodes[n]['embedding']=d['x']
    print(G1.nodes(data=True))
 

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/331190.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号