當(dāng)需要從一個(gè)圖中抽取某一個(gè)或一系列節(jié)點(diǎn)(目標(biāo)節(jié)點(diǎn))周圍的k級連續(xù)節(jié)點(diǎn)時(shí),可以用PyG的 k_hop_subgraph() 方法。
k=1:與目標(biāo)節(jié)點(diǎn)直接相連的節(jié)點(diǎn);
k=2:與目標(biāo)節(jié)點(diǎn)隔1個(gè)節(jié)點(diǎn)相連的節(jié)點(diǎn);
……
一、構(gòu)建原始大圖
第一部分內(nèi)容純粹是為了準(zhǔn)備原始大圖(相關(guān)內(nèi)容見:PyG構(gòu)建圖對象并轉(zhuǎn)換成networkx圖對象),與本文的核心內(nèi)容無關(guān)。
1.1 原始數(shù)據(jù)準(zhǔn)備
import torch
from torch_geometric.data import Data
import networkx as nx
from torch_geometric.utils import to_networkx, k_hop_subgraph
import matplotlib.pyplot as plt
# 節(jié)點(diǎn)特征矩陣(一行對應(yīng)一個(gè)節(jié)點(diǎn)的特征,共7個(gè)節(jié)點(diǎn)(=節(jié)點(diǎn)特征矩陣的行數(shù)),每個(gè)節(jié)點(diǎn)有3個(gè)特征)
my_node_features = torch.tensor([[0, 0, 0],
[-1, -1, -1],
[-2, -2, -2],
[-3, -3, -3],
[-4, -4, -4],
[-5, -5, -5],
[-6, -6, -6]],dtype=torch.float)
# 邊的節(jié)點(diǎn)對,共有6條邊(7個(gè)節(jié)點(diǎn):0、1、2、3、4、5、6,與節(jié)點(diǎn)特征矩陣的行標(biāo)一一對應(yīng))
my_edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
[2, 2, 4, 4, 6, 6]])
# 邊特征矩陣(一行對應(yīng)一條邊的特征,每條邊有4個(gè)特征)
my_edge_attr = torch.tensor([[11, 11, 11, 11],
[22, 22, 22, 22],
[33, 33, 33, 33],
[44, 44, 44, 44],
[55, 55, 55, 55],
[66, 66, 66, 66]], dtype=torch.float)
# 邊權(quán)重,共有6個(gè)邊權(quán)重,一條邊一個(gè)
my_edge_weight = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.float)
1.2 根據(jù)原始數(shù)據(jù)構(gòu)建PyG對象
# 構(gòu)建Pyg對象
pyg_G = Data(x=my_node_features,
edge_index=my_edge_index,
edge_attr=my_edge_attr,
edge_weight=my_edge_weight)
print(pyg_G)
輸出:Data(x=[7, 3], edge_index=[2, 6], edge_attr=[6, 4], edge_weight=[6])
# 輸出信息,為to_networkx()的參數(shù)提供參考
print(pyg_G.node_attrs())
print(pyg_G.edge_attrs())
輸出:
['x']
['edge_attr', 'edge_index', 'edge_weight']
1.3 將PyG對象轉(zhuǎn)化成networkx對象,用于成圖
這里需要注意的是如果原始數(shù)據(jù)準(zhǔn)備不恰當(dāng),可能會導(dǎo)致to_networkx()將PyG對象轉(zhuǎn)化成networkx對象后,多出來一些節(jié)點(diǎn),詳見:PyG構(gòu)建圖對象并轉(zhuǎn)換成networkx圖對象
# PyG對象轉(zhuǎn)化成networkx對象
nx_G = to_networkx(data=pyg_G,
node_attrs=['x'],
edge_attrs=['edge_weight', 'edge_attr'],
to_undirected=False) # 將PyG的Data對象轉(zhuǎn)化成networkx的數(shù)據(jù)對象
print(f'節(jié)點(diǎn)名:{nx_G.nodes}')
print(f'邊的節(jié)點(diǎn)對:{nx_G.edges}')
print('每個(gè)節(jié)點(diǎn)的屬性:')
# print(nx_G.nodes(data=True))
for node in nx_G.nodes(data=True):
print(node)
print('每條邊的屬性:')
# print(nx_G.edges(data=True))
for edge in nx_G.edges(data=True):
print(edge)
# 畫圖
plt.figure(figsize=(4, 4))
pos = nx.spring_layout(nx_G) # 迭代計(jì)算‘可視化圖片’上每個(gè)節(jié)點(diǎn)的坐標(biāo)
nx.draw(nx_G, pos, node_size=400, with_labels=True) # 繪圖
plt.show()

二、k_hop_subgraph()抽取子圖
2.1 方法解釋
result = k_hop_subgraph(
node_idx=, 目標(biāo)節(jié)點(diǎn)(int,或list int);
num_hops=, 待獲取的目標(biāo)節(jié)點(diǎn)的幾級周圍節(jié)點(diǎn)(int);
edge_index=, 原圖的邊節(jié)點(diǎn)對矩陣(tensor),其shape=[2, 邊數(shù)];
relabel_nodes=, 是否對獲取的節(jié)點(diǎn)從0開始重新順序編號(True/False)(要注意★★);
flow=, 根據(jù)邊的方向選擇節(jié)點(diǎn)(str = target_to_source(目標(biāo)節(jié)點(diǎn)到其他節(jié)點(diǎn))或source_to_target(其他節(jié)點(diǎn)到目標(biāo)節(jié)點(diǎn))(要注意★★);
directed=, 如果=False,將包括所有所有采樣節(jié)點(diǎn)之間的邊。(默認(rèn)=True)
)
k_hop_subgraph()的返回值result是一個(gè)包含四個(gè)元素的tuple:
result[0]:抽取出來的節(jié)點(diǎn)(包括目標(biāo)節(jié)點(diǎn))list,已經(jīng)按照從小到大順序排列好了;
result[1]:抽取的節(jié)點(diǎn)的邊對,是個(gè)shape=[2, 抽取的邊的條數(shù)]的tensor;
result[2]:每個(gè)目標(biāo)節(jié)點(diǎn)在result[0]中的位置,是一個(gè)長度與目標(biāo)節(jié)點(diǎn)個(gè)數(shù)相同的一維tensor;
result[3]:抽取的每條邊在原圖邊對矩陣中的位置,一個(gè)由True和False組成的list,長度等原圖邊對矩陣的列數(shù)。
2.2 抽取子圖并繪圖
下面看具體例子(緊接著前面代碼):
2.2.1 抽取子圖信息
我們希望找到6號節(jié)點(diǎn)周圍的k=2的節(jié)點(diǎn)(即找到6號節(jié)點(diǎn)的1、2級節(jié)點(diǎn))。
設(shè)置:relabel_nodes=False,不對找到的節(jié)點(diǎn)重新命名;
設(shè)置:flow='source_to_target'),只要求邊是指向目標(biāo)節(jié)點(diǎn)的節(jié)點(diǎn);
(上面這兩個(gè)參數(shù)的設(shè)置對結(jié)果影響很大,特別注意。)
target_node_idx = [6] # 確定目標(biāo)節(jié)點(diǎn)序列
k = 2 # 目標(biāo)節(jié)點(diǎn)往周圍跳躍的次數(shù)(即幾級節(jié)點(diǎn))
# 設(shè)置重要參數(shù)
relabel_nodes=False
flow='source_to_target'
# 抽取節(jié)點(diǎn)
result = k_hop_subgraph(node_idx=target_node_idx,
num_hops=k,
edge_index=pyg_G.edge_index,
relabel_nodes=relabel_nodes,
flow=flow,
directed=False)
sub_nodes_names = result[0]
sub_edge_index = result[1]
target_node_map = result[2]
sub_edge_mask = result[3]
print(f'抽取的節(jié)點(diǎn)序列:{sub_nodes_names}')
print(f'抽取的邊節(jié)點(diǎn)對:{sub_edge_index}')
print(f'目標(biāo)節(jié)點(diǎn)在抽取節(jié)點(diǎn)序列中的位置:{target_node_map}')
print(f'選中的邊在原圖的邊序列中的位置:{sub_edge_mask}')
print(f'抽取目標(biāo)節(jié)點(diǎn):{sub_nodes_names[target_node_map]}')

從上述‘輸出結(jié)果’看relabel_nodes=False時(shí),和‘我們的目標(biāo)’是完全一致的。
relabel_nodes=True時(shí),只有邊的節(jié)點(diǎn)對的序號被從0開始重新命名了,其他沒變。
2.2.2 計(jì)算抽取邊在原圖邊中的序號
這一步是為了從原圖的一些其他數(shù)據(jù)中抽取跟子圖相關(guān)的數(shù)據(jù)。
# 計(jì)算抽取的邊對矩陣sub_edge_index的每一條邊對在原邊對矩陣my_edge_index中的位置序號
match_indices_list = []
for row in sub_edge_index.t():
res = torch.where(torch.all(torch.isin(my_edge_index.t(), row), dim=1))
if res[0].numel()!=0:
print(res)
match_indices_list.append(res[0].item())
#match_indices_list
2.2.3 構(gòu)建子圖PyG對象
# 創(chuàng)建子圖的 Data 對象
sub_pyg_G = Data(x=my_node_features[sub_nodes_names,:], # 在原節(jié)點(diǎn)特征矩陣中抽取被選中的節(jié)點(diǎn)的特征
edge_index=sub_edge_index, # 在原邊節(jié)點(diǎn)對矩陣中抽取被選中的邊節(jié)點(diǎn)對
edge_attr=my_edge_attr[match_indices_list,:], # 在原邊特征矩陣中抽取被選中的邊特征
edge_weight=my_edge_weight[match_indices_list]) # 在原邊權(quán)重矩陣中抽取被選中的邊權(quán)重
sub_pyg_G
輸出:
Data(x=[3, 3], edge_index=[2, 2], edge_attr=[2, 4], edge_weight=[2])
# 輸出信息作為to_networkx()設(shè)置參數(shù)時(shí)的參考。
print(sub_pyg_G.node_attrs())
print(sub_pyg_G.edge_attrs())
輸出:
['x']
['edge_attr', 'edge_index', 'edge_weight']
2.2.4 將子圖的PyG對象轉(zhuǎn)換為networkx對象
# 將PyG對象轉(zhuǎn)換為networkx對象
sub_nx_G = to_networkx(data=sub_pyg_G,
node_attrs=['x'],
edge_attrs=['edge_attr', 'edge_weight'],
to_undirected=False)
2.2.5 將冗余節(jié)點(diǎn)從子圖的networkx圖對象中刪除
這一步是需要的,這里我們抽取的子圖節(jié)點(diǎn)為‘抽取的節(jié)點(diǎn)序列:tensor([2, 3, 4, 5, 6])’,顯然是不包括0和1號節(jié)點(diǎn),但因?yàn)閠o_networkx()方法本身的一些原因,會在轉(zhuǎn)化時(shí)把0和1號節(jié)點(diǎn)也加上(注意轉(zhuǎn)化時(shí)加上的0和1號節(jié)點(diǎn)與原圖中的0和1號節(jié)點(diǎn)是完全不同的,只是名稱相同),稱其為冗余節(jié)點(diǎn)。這樣一來,就會導(dǎo)致networkx對象的節(jié)點(diǎn)數(shù)比PyG對象的節(jié)點(diǎn)數(shù)多。因此冗余節(jié)點(diǎn)需要?jiǎng)h除。
如果k_hop_subgraph()函數(shù)的 relabel_nodes=Ture,則無論被選取的節(jié)點(diǎn)是什么,都會被從0開始重新命名,這時(shí)就不會出現(xiàn)冗余節(jié)點(diǎn),如果執(zhí)行下述代碼,反而會刪除有效節(jié)點(diǎn)。
(注意:這一步不是必須的,只有存在冗余節(jié)點(diǎn)時(shí)需要執(zhí)行。當(dāng)。)
# 將沒有被選中的節(jié)點(diǎn)從networkx圖對象中刪除
if not relabel_nodes:
nodes_to_remove = list(set(list(sub_nx_G.nodes)) - set(sub_nodes_names.tolist()))
sub_nx_G.remove_nodes_from(nodes_to_remove)
2.2.6 子圖繪圖
plt.figure(figsize=(4, 4))
pos = nx.spring_layout(sub_nx_G) # 定義節(jié)點(diǎn)的布局
nx.draw(sub_nx_G, pos, with_labels=True, node_color='red', edge_color="green", node_size=100, font_size=10)
