GraphSAGE Code Walkthrough - minibatch.py


class EdgeMinibatchIterator

    """ This minibatch iterator iterates over batches of sampled edges or
    random pairs of co-occurring edges.

    G -- networkx graph
    id2idx -- dict mapping node ids to index in feature tensor
    placeholders -- tensorflow placeholders object
    context_pairs -- if not None, then a list of co-occurring node pairs (from random walks)
    batch_size -- size of the minibatches
    max_degree -- maximum size of the downsampled adjacency lists
    n2v_retrain -- signals that the iterator is being used to add new embeddings to a n2v model
    fixed_n2v -- signals that the iterator is being used to retrain n2v with only existing nodes as context
    """

def __init__(self, G, id2idx, placeholders, context_pairs=None, batch_size=100, max_degree=25,
             n2v_retrain=False, fixed_n2v=False, **kwargs)

Within this constructor, the following lines deserve a closer look:

self.nodes = np.random.permutation(G.nodes())
# Both shuffle and permutation randomly reorder the elements of an array.
# shuffle works in place: it changes the order of the original array and returns None.
# permutation does not modify the original array; it returns a new, reordered array.
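
The difference between the two functions can be checked with a small sketch. The seeded RandomState and the variable names are choices made for this illustration only; they do not come from minibatch.py.

```python
import numpy as np

# Seeded only so that the sketch is repeatable
rng_state = np.random.RandomState(0)

nodes = np.arange(5)

# permutation returns a new reordered array and leaves the input untouched
shuffled = rng_state.permutation(nodes)
assert shuffled is not nodes
assert np.array_equal(nodes, np.arange(5))

# shuffle reorders the array in place and returns None
in_place = nodes.copy()
result = rng_state.shuffle(in_place)
assert result is None

# Both hold the same elements as the input, only the order may differ
assert sorted(shuffled.tolist()) == sorted(in_place.tolist()) == [0, 1, 2, 3, 4]
```

This is presumably why the constructor uses permutation: it produces a shuffled copy of the node list without needing a mutable array to shuffle in place.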
self.adj, self.deg = self.construct_adj()

The key function here is construct_adj().

def construct_adj(self):
    adj = len(self.id2idx) * \
        np.ones((len(self.id2idx) + 1, self.max_degree))
    # adj records, for each training node, the indices of its neighbors.
    # Only max_degree neighbors are kept per node; the sampling is shown below.
    # The matrix has one extra row (len(self.id2idx) + 1 rows), and every entry
    # starts as len(self.id2idx), an index that no real node uses.

    deg = np.zeros((len(self.id2idx),))
    # deg records the degree of each node.

    for nodeid in self.G.nodes():
        # Note: G.node[...] is networkx 1.x-style attribute access.
        if self.G.node[nodeid]['test'] or self.G.node[nodeid]['val']:
            continue
        neighbors = np.array([self.id2idx[neighbor]
                              for neighbor in self.G.neighbors(nodeid)
                              if (not self.G[nodeid][neighbor]['train_removed'])])
        # Graph.neighbors() returns the nodes connected to node n.
        # The neighbors are filtered here: of the nodes in G.neighbors(nodeid),
        # only those whose connecting edge has train_removed = False are kept,
        # that is, only neighbors that are not val or test nodes.
        # neighbors is now an array of neighbor indices.

        deg[self.id2idx[nodeid]] = len(neighbors)
        # Each entry of deg holds the degree of the corresponding node,
        # i.e. the number of neighbors left after the filtering above.

        if len(neighbors) == 0:
            continue
        if len(neighbors) > self.max_degree:
            neighbors = np.random.choice(
                neighbors, self.max_degree, replace=False)
        # Arguments: population = neighbors; size = max_degree;
        # replace = whether to sample with replacement
        # (replace=False means no neighbor is picked twice).
        # np.random.choice returns an array of `size` sampled elements.

        elif len(neighbors) < self.max_degree:
            neighbors = np.random.choice(
                neighbors, self.max_degree, replace=True)
        # After the random choice, every node has a neighbor array of fixed
        # size max_degree (25 by default); nodes with fewer neighbors are
        # padded with repeated entries.

        adj[self.id2idx[nodeid], :] = neighbors
        # Assign this node's neighbor array to its row in adj.
    return adj, deg
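
The sampling logic above can be tried on toy data without networkx. The following is a minimal sketch, not the GraphSAGE implementation: the function name build_padded_adj, the neighbor_lists input (neighbors assumed to be already filtered), the seed, and the integer dtype are all assumptions made for illustration.

```python
import numpy as np

def build_padded_adj(neighbor_lists, max_degree, seed=0):
    """Toy version of the sampling step: one fixed-width row per node.

    neighbor_lists -- hypothetical input; entry i holds the already
                      filtered neighbor indices of node i
    """
    rng = np.random.RandomState(seed)  # seeded only for repeatability
    num_nodes = len(neighbor_lists)
    # Every cell starts as num_nodes, an index no real node uses;
    # the extra last row belongs to that dummy index.
    adj = num_nodes * np.ones((num_nodes + 1, max_degree), dtype=np.int64)
    deg = np.zeros(num_nodes, dtype=np.int64)
    for idx, neighbors in enumerate(neighbor_lists):
        neighbors = np.asarray(neighbors, dtype=np.int64)
        deg[idx] = len(neighbors)
        if len(neighbors) == 0:
            continue  # the row keeps the dummy index
        if len(neighbors) > max_degree:
            # Too many neighbors: draw a subset without duplicates
            neighbors = rng.choice(neighbors, max_degree, replace=False)
        elif len(neighbors) < max_degree:
            # Too few neighbors: draw with replacement, so entries repeat
            neighbors = rng.choice(neighbors, max_degree, replace=True)
        adj[idx, :] = neighbors
    return adj, deg

adj, deg = build_padded_adj([[1, 2, 3, 4], [0], [], [0, 1], [0]], max_degree=3)
print(adj.shape)     # one row per node plus the dummy row
print(deg.tolist())  # degrees before down- or up-sampling
```

Note that deg records the true (filtered) neighbor count, while every row of adj has exactly max_degree entries regardless of that count.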

 

construct_test_adj() differs from the function above in that the neighbors are taken directly, without filtering on val/test/train_removed.

neighbors = np.array([self.id2idx[neighbor]
                      for neighbor in self.G.neighbors(nodeid)])
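
To make the contrast between the two neighbor lookups concrete, here is a small sketch on a made-up graph stored as plain dicts. The helper names neighbors_for_training and neighbors_for_testing, and the toy data, are hypothetical; only the train_removed attribute name and id2idx are taken from the code above.

```python
# Toy graph: node -> {neighbor -> edge attributes}; the data is invented
id2idx = {"a": 0, "b": 1, "c": 2}
edges = {
    "a": {"b": {"train_removed": False}, "c": {"train_removed": True}},
    "b": {"a": {"train_removed": False}},
    "c": {"a": {"train_removed": True}},
}

def neighbors_for_training(node):
    # Training view: skip edges flagged train_removed
    return [id2idx[n] for n, attrs in edges[node].items()
            if not attrs["train_removed"]]

def neighbors_for_testing(node):
    # Test view: keep every edge, no filtering
    return [id2idx[n] for n in edges[node]]

train_view = neighbors_for_training("a")
test_view = neighbors_for_testing("a")
print(train_view, test_view)
```

In this toy graph, node "a" sees only "b" during training, while the test view also includes "c", whose edge was flagged train_removed.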

 

