栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

2021-09-23如何把torch

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

2021-09-23如何把torch

问题描述:

如何用PyG 表示多张图(torch_geometric.data.Batch)?把 torch_geometric.data.Data 多张Data类图对象拼接成一个batch,其目的是批量化处理多张图,如图所示。

代码实例:
import torch
from torch_geometric.data import Data
from torch_geometric.data.batch import Batch


edge_index_s = torch.tensor([
    [0, 0, 0, 0],
    [1, 2, 3, 4],
])
x_s = torch.randn(5, 16)  # 5 nodes.
edge_index_t = torch.tensor([
    [0, 0, 0],
    [1, 2, 3],
])
x_t = torch.randn(4, 16)  # 4 nodes.

edge_index_3 = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
x_3 = torch.randn(4, 16)

data1= Data(x=x_s,edge_index=edge_index_s)
data2= Data(x=x_t,edge_index=edge_index_t)
data3= Data(x=x_3,edge_index=edge_index_3)
#上面是构建3张Data图对象
# * `Batch(Data)` in case `Data` objects are batched together
#* `Batch(HeteroData)` in case `HeteroData` objects are batched together

data_list = [data1, data2,data3]


loader = Batch.from_data_list(data_list)#调用该函数data_list里的data1、data2、data3 三张图形成一张大图,也就是batch
print('data_list:n',data_list)
#data_list: [Data(edge_index=[2, 4], x=[5, 16]), Data(edge_index=[2, 3], x=[4, 16]), Data(edge_index=[2, 4], x=[4, 16])]
print('batch:',loader.batch)
#batch: tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])
print('loader:',loader)
#loader: Batch(batch=[13], edge_index=[2, 11], x=[13, 16])
print('loader.edge_index:n',loader.edge_index) #batch的边的元组
#loader.edge_index:
#tensor([[ 0,  0,  0,  0,  5,  5,  5,  9, 10, 10, 11],
#        [ 1,  2,  3,  4,  6,  7,  8, 10,  9, 11, 10]])

print('loader.num_graphs:',loader.num_graphs)#该batch的图的个数,这里是3个
#loader.num_graphs: 3

Batch=Batch.to_data_list(loader)#大图Batch变回成3张小图
print(Batch)
#[Data(edge_index=[2, 4], x=[5, 16]), Data(edge_index=[2, 3], x=[4, 16]), Data(edge_index=[2, 4], x=[4, 16])]
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/269261.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号