連續(xù)子圖抽?。焊鶕?jù)指定目標(biāo)節(jié)點(diǎn)抽取連續(xù)子圖k_hop_subgraph()

當(dāng)需要從一個(gè)圖中抽取某一個(gè)或一系列節(jié)點(diǎn)(目標(biāo)節(jié)點(diǎn))周圍的k級連續(xù)節(jié)點(diǎn)時(shí),可以用PyG的 k_hop_subgraph() 方法。
k=1:與目標(biāo)節(jié)點(diǎn)直接相連的節(jié)點(diǎn);
k=2:與目標(biāo)節(jié)點(diǎn)隔1個(gè)節(jié)點(diǎn)相連的節(jié)點(diǎn);
……

一、構(gòu)建原始大圖

第一部分內(nèi)容純粹是為了準(zhǔn)備原始大圖(相關(guān)內(nèi)容見:PyG構(gòu)建圖對象并轉(zhuǎn)換成networkx圖對象),與本文的核心內(nèi)容無關(guān)。

1.1 原始數(shù)據(jù)準(zhǔn)備

import torch
from torch_geometric.data import Data
import networkx as nx
from torch_geometric.utils import to_networkx,  k_hop_subgraph

import matplotlib.pyplot as plt
# 節(jié)點(diǎn)特征矩陣(一行對應(yīng)一個(gè)節(jié)點(diǎn)的特征,共7個(gè)節(jié)點(diǎn)(=節(jié)點(diǎn)特征矩陣的行數(shù)),每個(gè)節(jié)點(diǎn)有3個(gè)特征)
my_node_features = torch.tensor([[0, 0, 0], 
                                 [-1, -1, -1], 
                                 [-2, -2, -2],
                                 [-3, -3, -3],
                                 [-4, -4, -4],
                                 [-5, -5, -5],
                                 [-6, -6, -6]],dtype=torch.float)

# 邊的節(jié)點(diǎn)對,共有6條邊(7個(gè)節(jié)點(diǎn):0、1、2、3、4、5、6,與節(jié)點(diǎn)特征矩陣的行標(biāo)一一對應(yīng))
my_edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                              [2, 2, 4, 4, 6, 6]])

# 邊特征矩陣(一行對應(yīng)一條邊的特征,每條邊有4個(gè)特征)
my_edge_attr = torch.tensor([[11, 11, 11, 11],
                             [22, 22, 22, 22],
                             [33, 33, 33, 33],
                             [44, 44, 44, 44],
                             [55, 55, 55, 55],
                             [66, 66, 66, 66]], dtype=torch.float)

# 邊權(quán)重,共有6個(gè)邊權(quán)重,一條邊一個(gè)
my_edge_weight = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.float)

1.2 根據(jù)原始數(shù)據(jù)構(gòu)建PyG對象

# 構(gòu)建Pyg對象
pyg_G = Data(x=my_node_features,
             edge_index=my_edge_index,
             edge_attr=my_edge_attr,
             edge_weight=my_edge_weight)
print(pyg_G)
輸出:Data(x=[7, 3], edge_index=[2, 6], edge_attr=[6, 4], edge_weight=[6])
# 輸出信息,為to_networkx()的參數(shù)提供參考
print(pyg_G.node_attrs())
print(pyg_G.edge_attrs())
輸出:
['x']
['edge_attr', 'edge_index', 'edge_weight']

1.3 將PyG對象轉(zhuǎn)化成networkx對象,用于成圖

這里需要注意的是如果原始數(shù)據(jù)準(zhǔn)備不恰當(dāng),可能會導(dǎo)致to_networkx()將PyG對象轉(zhuǎn)化成networkx對象后,多出來一些節(jié)點(diǎn),詳見:PyG構(gòu)建圖對象并轉(zhuǎn)換成networkx圖對象

# PyG對象轉(zhuǎn)化成networkx對象
nx_G = to_networkx(data=pyg_G, 
                   node_attrs=['x'],
                   edge_attrs=['edge_weight', 'edge_attr'],
                   to_undirected=False)  # 將PyG的Data對象轉(zhuǎn)化成networkx的數(shù)據(jù)對象

print(f'節(jié)點(diǎn)名:{nx_G.nodes}')
print(f'邊的節(jié)點(diǎn)對:{nx_G.edges}')
print('每個(gè)節(jié)點(diǎn)的屬性:')
# print(nx_G.nodes(data=True))
for node in nx_G.nodes(data=True):
    print(node)
print('每條邊的屬性:')
# print(nx_G.edges(data=True))
for edge in nx_G.edges(data=True):
    print(edge)

# 畫圖
plt.figure(figsize=(4, 4))
pos = nx.spring_layout(nx_G)  # 迭代計(jì)算‘可視化圖片’上每個(gè)節(jié)點(diǎn)的坐標(biāo)
nx.draw(nx_G, pos, node_size=400, with_labels=True)  # 繪圖
plt.show()
原大圖.png

二、k_hop_subgraph()抽取子圖

2.1 方法解釋

result = k_hop_subgraph(
node_idx=, 目標(biāo)節(jié)點(diǎn)(int,或list int);
num_hops=, 待獲取的目標(biāo)節(jié)點(diǎn)的幾級周圍節(jié)點(diǎn)(int);
edge_index=, 原圖的邊節(jié)點(diǎn)對矩陣(tensor),其shape=[2, 邊數(shù)];
relabel_nodes=, 是否對獲取的節(jié)點(diǎn)從0開始重新順序編號(True/False)(要注意★★);
flow=, 根據(jù)邊的方向選擇節(jié)點(diǎn)(str = target_to_source(目標(biāo)節(jié)點(diǎn)到其他節(jié)點(diǎn))或source_to_target(其他節(jié)點(diǎn)到目標(biāo)節(jié)點(diǎn))(要注意★★);
directed=, 如果=False,將包括所有所有采樣節(jié)點(diǎn)之間的邊。(默認(rèn)=True)
)

k_hop_subgraph()的返回值result是一個(gè)包含四個(gè)元素的tuple:
result[0]:抽取出來的節(jié)點(diǎn)(包括目標(biāo)節(jié)點(diǎn))list,已經(jīng)按照從小到大順序排列好了;
result[1]:抽取的節(jié)點(diǎn)的邊對,是個(gè)shape=[2, 抽取的邊的條數(shù)]的tensor;
result[2]:每個(gè)目標(biāo)節(jié)點(diǎn)在result[0]中的位置,是一個(gè)長度與目標(biāo)節(jié)點(diǎn)個(gè)數(shù)相同的一維tensor;
result[3]:抽取的每條邊在原圖邊對矩陣中的位置,一個(gè)由True和False組成的list,長度等原圖邊對矩陣的列數(shù)。

2.2 抽取子圖并繪圖

下面看具體例子(緊接著前面代碼):

2.2.1 抽取子圖信息

我們希望找到6號節(jié)點(diǎn)周圍的k=2的節(jié)點(diǎn)(即找到6號節(jié)點(diǎn)的1、2級節(jié)點(diǎn))。
設(shè)置:relabel_nodes=False,不對找到的節(jié)點(diǎn)重新命名;
設(shè)置:flow='source_to_target'),只要求邊是指向目標(biāo)節(jié)點(diǎn)的節(jié)點(diǎn)
(上面這兩個(gè)參數(shù)的設(shè)置對結(jié)果影響很大,特別注意。)

target_node_idx = [6]  # 確定目標(biāo)節(jié)點(diǎn)序列
k = 2  # 目標(biāo)節(jié)點(diǎn)往周圍跳躍的次數(shù)(即幾級節(jié)點(diǎn))

# 設(shè)置重要參數(shù)
relabel_nodes=False
flow='source_to_target'

# 抽取節(jié)點(diǎn)
result = k_hop_subgraph(node_idx=target_node_idx,
                        num_hops=k,
                        edge_index=pyg_G.edge_index,
                        relabel_nodes=relabel_nodes,
                        flow=flow,
                        directed=False)

sub_nodes_names = result[0]
sub_edge_index  = result[1]
target_node_map = result[2]
sub_edge_mask   = result[3]

print(f'抽取的節(jié)點(diǎn)序列:{sub_nodes_names}')
print(f'抽取的邊節(jié)點(diǎn)對:{sub_edge_index}')
print(f'目標(biāo)節(jié)點(diǎn)在抽取節(jié)點(diǎn)序列中的位置:{target_node_map}')
print(f'選中的邊在原圖的邊序列中的位置:{sub_edge_mask}')
print(f'抽取目標(biāo)節(jié)點(diǎn):{sub_nodes_names[target_node_map]}')
抽取結(jié)果.png

從上述‘輸出結(jié)果’看relabel_nodes=False時(shí),和‘我們的目標(biāo)’是完全一致的。
relabel_nodes=True時(shí),只有邊的節(jié)點(diǎn)對的序號被從0開始重新命名了,其他沒變。

2.2.2 計(jì)算抽取邊在原圖邊中的序號

這一步是為了從原圖的一些其他數(shù)據(jù)中抽取跟子圖相關(guān)的數(shù)據(jù)。

# 計(jì)算抽取的邊對矩陣sub_edge_index的每一條邊對在原邊對矩陣my_edge_index中的位置序號
match_indices_list = []
for row in sub_edge_index.t():
    res = torch.where(torch.all(torch.isin(my_edge_index.t(), row), dim=1))
    if res[0].numel()!=0:
        print(res)
        match_indices_list.append(res[0].item())
        
#match_indices_list

2.2.3 構(gòu)建子圖PyG對象

# 創(chuàng)建子圖的 Data 對象
sub_pyg_G = Data(x=my_node_features[sub_nodes_names,:],  # 在原節(jié)點(diǎn)特征矩陣中抽取被選中的節(jié)點(diǎn)的特征
                 edge_index=sub_edge_index,  # 在原邊節(jié)點(diǎn)對矩陣中抽取被選中的邊節(jié)點(diǎn)對
                 edge_attr=my_edge_attr[match_indices_list,:],  # 在原邊特征矩陣中抽取被選中的邊特征
                 edge_weight=my_edge_weight[match_indices_list])   # 在原邊權(quán)重矩陣中抽取被選中的邊權(quán)重

sub_pyg_G
輸出:
Data(x=[3, 3], edge_index=[2, 2], edge_attr=[2, 4], edge_weight=[2])
# 輸出信息作為to_networkx()設(shè)置參數(shù)時(shí)的參考。
print(sub_pyg_G.node_attrs())
print(sub_pyg_G.edge_attrs())
輸出:
['x']
['edge_attr', 'edge_index', 'edge_weight']

2.2.4 將子圖的PyG對象轉(zhuǎn)換為networkx對象

# 將PyG對象轉(zhuǎn)換為networkx對象
sub_nx_G = to_networkx(data=sub_pyg_G, 
                       node_attrs=['x'],
                       edge_attrs=['edge_attr', 'edge_weight'],
                       to_undirected=False)

2.2.5 將冗余節(jié)點(diǎn)從子圖的networkx圖對象中刪除

這一步是需要的,這里我們抽取的子圖節(jié)點(diǎn)為‘抽取的節(jié)點(diǎn)序列:tensor([2, 3, 4, 5, 6])’,顯然是不包括0和1號節(jié)點(diǎn),但因?yàn)閠o_networkx()方法本身的一些原因,會在轉(zhuǎn)化時(shí)把0和1號節(jié)點(diǎn)也加上(注意轉(zhuǎn)化時(shí)加上的0和1號節(jié)點(diǎn)與原圖中的0和1號節(jié)點(diǎn)是完全不同的,只是名稱相同),稱其為冗余節(jié)點(diǎn)。這樣一來,就會導(dǎo)致networkx對象的節(jié)點(diǎn)數(shù)比PyG對象的節(jié)點(diǎn)數(shù)多。因此冗余節(jié)點(diǎn)需要?jiǎng)h除。
如果k_hop_subgraph()函數(shù)的 relabel_nodes=Ture,則無論被選取的節(jié)點(diǎn)是什么,都會被從0開始重新命名,這時(shí)就不會出現(xiàn)冗余節(jié)點(diǎn),如果執(zhí)行下述代碼,反而會刪除有效節(jié)點(diǎn)。
(注意:這一步不是必須的,只有存在冗余節(jié)點(diǎn)時(shí)需要執(zhí)行。當(dāng)。

# 將沒有被選中的節(jié)點(diǎn)從networkx圖對象中刪除
if not relabel_nodes:
    nodes_to_remove = list(set(list(sub_nx_G.nodes)) - set(sub_nodes_names.tolist()))
    sub_nx_G.remove_nodes_from(nodes_to_remove)

2.2.6 子圖繪圖

plt.figure(figsize=(4, 4))

pos = nx.spring_layout(sub_nx_G)  # 定義節(jié)點(diǎn)的布局
nx.draw(sub_nx_G, pos, with_labels=True, node_color='red', edge_color="green", node_size=100, font_size=10)
最終結(jié)果.png
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容