GraphSAGE 代碼解析(一) - unsupervised_train.py


原創文章~轉載請注明出處哦。其他部分內容參見以下鏈接~

GraphSAGE 代碼解析(二) - layers.py

GraphSAGE 代碼解析(三) - aggregators.py

GraphSAGE 代碼解析(四) - models.py

GraphSAGE代碼詳解

example_data:

1. toy-ppi-G.json 圖的信息

{ 
  directed: false
  graph : {
              {name: disjoint_union(,) }
           nodes:  [
                        {  
                                test: false
                         id: 0
                         features: [ ... ]
                         val: false
                          lable: [ ... ]
                       }
                       {...}
                         ...
                  ]

            links: [
                       {  
                                test_removed: false
                        train_removed: false
                        target: 800 # 指向的節點id(默認從小節點指向大節點)
                        source: 0   # 從0節點按順序展示
                         }
                         {...}
                           ...
                    ]
      }
}
View Code

2. toy-ppi-class_map.json

3. toy-ppi-feats.npy 預訓練好得到的features

4. toy-ppi-id_map.json 節點編號與序號的一一對應;數據格式為:{"0": 0, "1": 1,..., "14754": 14754}

5. toy-ppi-walks.txt 

    從一點出發隨機游走到鄰居節點的情況,對於每個點取198次(即可能有重復情況)

    例如:0    708 表示從0點走到708點。

1. __init__.py

1 from __future__ import print_function  
2 #即使在python2.X,使用print就得像python3.X那樣加括號使用。
3 
4 from __future__ import division          
5 # 導入python未來支持的語言特征division(精確除法),
6 # 當我們沒有在程序中導入該特征時,"/"操作符執行的是截斷除法(Truncating Division);
7 # 當我們導入精確除法之后,"/"執行的是精確除法, "//"執行截斷除除法

2. unsupervised_train.py

1 if __name__ == '__main__':
2   tf.app.run()
3 # https://blog.csdn.net/fxjzzyo/article/details/80466321
4 # tf.app.run()的作用:通過處理flag解析,然后執行main函數
5 # 如果你的代碼中的入口函數不叫main(),而是一個其他名字的函數,如test(),則你應該這樣寫入口tf.app.run(test())
6 # 如果你的代碼中的入口函數叫main(),則你就可以把入口寫成tf.app.run()
1 def main(argv=None):
2   print("Loading training data..")
3   train_data = load_data(FLAGS.train_prefix, load_walks=True)
4   # load_data函數在graphsage.utils中定義
5 
6   print("Done loading training data..")
7   train(train_data)
8   # train函數在該文件中定義def train(train_data, test_data=None)

3. utils.py - func: load_data

(1) 讀入id_map, class_map

1 if isinstance(G.nodes()[0], int):
2         def conversion(n): return int(n)
3     else:
4         def conversion(n): return n

a. isinstance() 函數來判斷一個對象是否是一個已知的類型,類似 type()。 

isinstance(object, classinfo)

參數
object -- 實例對象。
classinfo -- 可以是直接或間接類名、基本類型或者由它們組成的元組。

返回值
如果對象的類型與參數二的類型(classinfo)相同則返回 True,否則返回 False。

>>>a = 2
>>> isinstance (a,int)
True
>>> isinstance (a,str)
False
>>> isinstance (a,(str,int,list))    # 是元組中的一個返回 True
True

type() 與 isinstance() 區別:

type() 不會認為子類是一種父類類型,不考慮繼承關系。
isinstance() 會認為子類是一種父類類型,考慮繼承關系。
如果要判斷兩個類型是否相同推薦使用 isinstance()。

 1 class A:
 2     pass
 3  
 4 class B(A):
 5     pass
 6  
 7 isinstance(A(), A)    # returns True
 8 type(A()) == A        # returns True
 9 isinstance(B(), A)    # returns True
10 type(B()) == A        # returns False
View Code

b. G.nodes()

返回的是圖中節點n與節點屬性nodedata。https://networkx.github.io/documentation/stable/reference/classes/generated/networkx.Graph.nodes.html

例子:

>>> G = nx.path_graph(3)
>>> list(G.nodes)
[0, 1, 2]
>>> list(G)
[0, 1, 2]
View Code

獲取nodedata:

>>> G.add_node(1, time='5pm')
>>> G.nodes[0]['foo'] = 'bar'
>>> list(G.nodes(data=True))
[(0, {'foo': 'bar'}), (1, {'time': '5pm'}), (2, {})]
>>> list(G.nodes.data())
[(0, {'foo': 'bar'}), (1, {'time': '5pm'}), (2, {})]

>>> list(G.nodes(data='foo'))
[(0, 'bar'), (1, None), (2, None)]

>>> list(G.nodes(data='time'))
[(0, None), (1, '5pm'), (2, None)]

>>> list(G.nodes(data='time', default='Not Available'))
[(0, 'Not Available'), (1, '5pm'), (2, 'Not Available')]
View Code

If some of your nodes have an attribute and the rest are assumed to have a default attribute value you can create a dictionary from node/attribute pairs using the default keyword argument to guarantee the value is never None:

>>> G = nx.Graph()
>>> G.add_node(0)
>>> G.add_node(1, weight=2)
>>> G.add_node(2, weight=3)
>>> dict(G.nodes(data='weight', default=1))
{0: 1, 1: 2, 2: 3}
View Code

----------------------------

在utils.py中,判斷G.nodes()[0] 是否為int型(即不帶nodedata)。

若為int型,則將n轉為int型;否則直接返回n.

b. conversion() 函數

1 id_map = json.load(open(prefix + "-id_map.json"))
2 id_map = {conversion(k): int(v) for k, v in id_map.items()}

前面定義的conversion()函數在id_map這里用到了,把外存中的文件內容讀到內存中,用dict類型的id_map存儲。

id_map.json文件中數據格式為:{"0": 0, "1": 1,..., "14754": 14754},也即id_map的迭代中k為str類型,v為int型。數據文件中G.nodes()[0] 顯然是帶nodedata的,也就算一般采用 def conversion(n): return n,返回的n為類型的(就是前面形參k的類型);

但是為什么當G.nodes()[0] 不帶nodedata時,要返回int(n)?

c. class_map:  {"0": [.0,1,..], "1": [.0,1,..]...} ?含義?

list(class_map.values()): [ [...], [...], ... ,[...] ]
list(class_map.values())[0]: 表示取第一個[...] =>含義? 
if isinstance(list(class_map.values())[0], list):
    def lab_conversion(n): return n else: def lab_conversion(n): return int(n)

(2) Remove node

1 # Remove all nodes that do not have val/test annotations
2     # (necessary because of networkx weirdness with the Reddit data)
3     broken_count = 0
4     for node in G.nodes():
5         if not 'val' in G.node[node] or not 'test' in G.node[node]:
6             G.remove_node(node)
7             broken_count += 1

這里刪除的節點是不具有'val','test'屬性 的節點,而不是'val','test' 屬性值為None的節點。

區分開 if not 'val' in G.node[node] 和 if not G.node[n]['val']的不同意義。

broken_count  記錄刪去的沒有val 或者 test的屬性的節點的數目。

e. G.edges()

1 for edge in G.edges():
2         if (G.node[edge[0]]['val'] or G.node[edge[1]]['val'] or
3                 G.node[edge[0]]['test'] or G.node[edge[1]]['test']):
4             G[edge[0]][edge[1]]['train_removed'] = True
5         else:
6             G[edge[0]][edge[1]]['train_removed'] = False

G.edges() 得到edge_list, [( , ), ( , ), ... ( , )].list中每一個元素是所表示邊的兩個節點信息。若設置data = True,則會顯示邊的權重等屬性信息。

>>> G = nx.Graph()   # or DiGraph, MultiGraph, MultiDiGraph, etc
>>> G.add_path([0,1,2])
>>> G.add_edge(2,3,weight=5)
>>> G.edges()
[(0, 1), (1, 2), (2, 3)]
>>> G.edges(data=True) # default edge data is {} (empty dictionary)
[(0, 1, {}), (1, 2, {}), (2, 3, {'weight': 5})]
>>> list(G.edges_iter(data='weight', default=1))
[(0, 1, 1), (1, 2, 1), (2, 3, 5)]
>>> G.edges([0,3])
[(0, 1), (3, 2)]
>>> G.edges(0)
[(0, 1)]
View Code

代碼中edge對edges迭代,每次去list中的一個元組,而edge[0], edge[1]則分別表示兩個頂點。

若兩個頂點中至少有一個的val/test不為空,則將該邊的'train_removed'設為True,否則為False.

該操作為保證'train_removed'不為空。

(3) 獲取訓練數據features並標准化

1 if normalize and not feats is None:
2         from sklearn.preprocessing import StandardScaler
3         train_ids = np.array([id_map[n] for n in G.nodes(
4         ) if not G.node[n]['val'] and not G.node[n]['test']])
5         train_feats = feats[train_ids]
6         scaler = StandardScaler()
7         scaler.fit(train_feats)
8         feats = scaler.transform(feats)

這里if not feats is None 等價於 if feats is not None.

將val,test均為None的node選為訓練數據,通過id_map獲取其在feature表中的索引值,添加到train_ids數組中。根據索引train_ids,train_fests獲取這些nodes的features.

StandardScaler的用法:

http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.StandardScaler.html

Methods:

fit(X[, y]) : Compute the mean and std to be used for later scaling.

transform(X[, y, copy]) : Perform standardization by centering and scaling

fit_transform(X[, y]) : Fit to data, then transform it.

例子:

>>> from sklearn.preprocessing import StandardScaler
>>> data = [[0, 0], [0, 0], [1, 1], [1, 1]]
>>> scaler = StandardScaler()
>>> print(scaler.fit(data))
StandardScaler(copy=True, with_mean=True, with_std=True)
>>> print(scaler.mean_)
[0.5 0.5]
>>> print(scaler.transform(data))
[[-1. -1.]
 [-1. -1.]
 [ 1.  1.]
 [ 1.  1.]]
>>> print(scaler.transform([[2, 2]]))
[[3. 3.]]

# 計算得
# 均值[0.5, 0.5], 
# 方差:1/4 * [(0 - 0.5)^2 * 2 + (1 - 0.5)^2 * 2] = 1/4 = 0.25
# 標准差:0.5
# 對於[2,2] transform 標准化之后: (2 - 0.5) / 0.5 = 3
View Code

(4) Load walks

在unsupervised_train.py的main函數中:

1 train_data = load_data(FLAGS.train_prefix, load_walks=True)

load_walks = True,需要執行utils.py中的load_walks操作。

1 if load_walks:  # false by default
2         with open(prefix + "-walks.txt") as fp:
3             for line in fp:
4                 walks.append(map(conversion, line.split()))

map() 的用法:http://www.runoob.com/python/python-func-map.html

map(function, iterable, ...)

map() 會根據提供的函數對指定序列做映射。

第一個參數 function 以參數序列中的每一個元素調用 function 函數,返回包含每次 function 函數返回值的新列表。

例子:

>>>def square(x) :            # 計算平方數
...     return x ** 2
... 
>>> map(square, [1,2,3,4,5])   # 計算列表各個元素的平方
[1, 4, 9, 16, 25]
>>> map(lambda x: x ** 2, [1, 2, 3, 4, 5])  # 使用 lambda 匿名函數
[1, 4, 9, 16, 25]
 
# 提供了兩個列表,對相同位置的列表數據進行相加
>>> map(lambda x, y: x + y, [1, 3, 5, 7, 9], [2, 4, 6, 8, 10])
[3, 7, 11, 15, 19]
View Code

walks初始化為[], 之后append的是游走的節點對的對象。

例子:walks.txt:

0    708
0    3163
0    276
1 def conversion(n): return n
2 walks = []
3 with open("walks.txt") as fp:
4     for line in fp:
5         print(line.split())
6         walks.append(map(conversion, line.split()))
7 print(walks) 
8 print(len(walks))
View Code

輸出:

['0', '708']
['0', '3163']
['0', '276']
[<map object at 0x7f5bc0d68da0>, <map object at 0x7f5bc0d68e48>, <map object at 0x7f5bc0d68f28>]
3

(5) 函數返回值

1 return G, feats, id_map, walks, class_map

------------------------------------------------------------------------------------

4. unsupervised_train.py - func: train(train_data)

1 def train(train_data, test_data=None):

這里的train_data是上文所述的load_data函數的返回值。

變量含義:

G = train_data[0]    #
features = train_data[1]    # 訓練數據的features
id_map = train_data[2]     # "n" : n
context_pairs = train_data[3] if FLAGS.random_context else None #random walk的點對
1 if not features is None:
2     # pad with dummy zero vector
3     features = np.vstack([features, np.zeros((features.shape[1],))])

這里vstack為features添加列一行0向量,用於WX + b中與b相加。

1 placeholders = construct_placeholders()
2 # def construct_placeholders()定義的placeholders包含:
3 # batch1, batch2, neg_samples, dropout, batch_size

minibatch是EdgeMinibatchIterator的一個實例,轉至minibatch.py看class EdgeMinibatchIterator(object)的定義。

5. minibatch.py - class EdgeMinibatchIterator

https://www.cnblogs.com/shiyublog/p/9902423.html

6. unsupervised_train.py - func train

繼續回來看unsupervised_trian.py 中的train函數

變量:

1 adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
2 adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

adj_info記錄鄰居信息,是一個矩陣,矩陣每一行對應每一個節點的鄰居節點編號數組。

(1)選擇模型

接下來根據輸入參數判斷選擇6種模型(graphsage_mean,gcn,graphsage_seq,graphsage_maxpool,graphsage_meanpool,n2v)中的哪一種。

以graphsage開頭的幾種是graphsage的幾種變體,由於aggregator不同而不同。可以通過設定SampleAndAggregate()中的aggregator_type進行選擇。默認為mean.

其中gcn與graphsage的參數不同在於:

gcn的aggregator中進行列concat的操作,因此其維數是graphsage的二倍。

a. graphsage_maxpool 

 1 sampler = UniformNeighborSampler(adj_info)

首先看UniformNeighborSampler,該類用於sample節點的鄰居,在neigh_samplers.py中。

neigh_samplers.py

 1 class UniformNeighborSampler(Layer):
 2     """
 3     Uniformly samples neighbors.
 4     Assumes that adj lists are padded with random re-sampling
 5     """
 6     def __init__(self, adj_info, **kwargs):
 7         super(UniformNeighborSampler, self).__init__(**kwargs)
 8         self.adj_info = adj_info
 9 
10     def _call(self, inputs):
11         ids, num_samples = inputs
12         adj_lists = tf.nn.embedding_lookup(self.adj_info, ids) 
13         adj_lists = tf.transpose(tf.random_shuffle(tf.transpose(adj_lists)))
14         adj_lists = tf.slice(adj_lists, [0,0], [-1, num_samples])
15         return adj_lists

1.  tf.nn.embedding_lookup 用於根據ids在adj_info中找到各個對應位的向量。

2. adj_lists = tf.transpose(tf.random_shuffle(tf.transpose(adj_lists)))

    adj_lists = tf.slice(adj_lists, [0,0], [-1, num_samples]) 的過程見下:

id0 id1 id2...   --transpose--> id0 [...]  --shuffle--> id1 [...]  --transpose--> id1 id2 id0 --slice--> id1 id2

[]    []    []                                id1 [...]                     id2 [...]                         []     []    []                  []    []

                                              id2 [...]                     id0 [...]

均勻:shuffle打亂0維的順序,即打亂行順序,以此使下面采樣可以“均勻”。為了使用shuffle函數,需要在shuffle前后transpose一下。

采樣:slice之后,相當於隨機挑選了num_samples個樣本,並保留了這些樣本的全部屬性特征。

3. 最后的adj_lists即為均勻采樣后的表示鄰居信息的矩陣。

---------------------------------------------------

回到unsupervised_train.py 的train()函數.

1 sampler = UniformNeighborSampler(adj_info)

sampler獲取均勻采樣后的鄰居節點信息。

---------------------------------------------------

1 layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
2                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

其中SAGEInfo在models.py中。

models.py 

https://www.cnblogs.com/shiyublog/p/9879875.html

1 # SAGEInfo is a namedtuple that specifies the parameters 
2 # of the recursive GraphSAGE layers
3 SAGEInfo = namedtuple("SAGEInfo",
4     ['layer_name', # name of the layer (to get feature embedding etc.)
5      'neigh_sampler', # callable neigh_sampler constructor
6      'num_samples',
7      'output_dim' # the output (i.e., hidden) dimension
8     ])

namedtuple 命名元組,可以給tuple命名,用法見下:

https://www.cnblogs.com/chenlin163/p/7259061.html

 1 import collections
 2 
 3 MyTupleClass = collections.namedtuple('MyTupleClass',['name', 'age', 'job'])
 4 obj = MyTupleClass("Tomsom",12,'Cooker')
 5 print(obj.name)
 6 print(obj.age)
 7 print(obj.job)
 8 
 9 # Output:
10 # Tomsom
11 # 12
12 # Cooker
13 #############################
14 
15 Person=collections.namedtuple('Person','name age gender') 
16 # 以空格分開,表示這個namedtuple有三個元素
17 
18 print( 'Type of Person:',type(Person))
19 Bob=Person(name='Bob',age=30,gender='male')
20 print( 'Representation:',Bob)
21 Jane=Person(name='Jane',age=29,gender='female')
22 print( 'Field by Name:',Jane.name)
23 for people in [Bob,Jane]:
24     print ("%s is %d years old %s" % people)
25 
26 # Output:
27 # Type of Person: <class 'type'>
28 # Representation: Person(name='Bob', age=30, gender='male')
29 # Field by Name: Jane
30 # Bob is 30 years old male
31 # Jane is 29 years old female
32 #############################
33 
34 # 在使用namedtyuple的時候要注意其中的名稱不能使用Python的關鍵字,如class def等
35 # 不能有重復的元素名稱,比如:不能有兩個’age age’。如果出現這些情況,程序會報錯。
36 # 但是,在實際使用的時候可能無法避免這種情況,
37 # 比如:可能我們的元素名稱是從數據庫里讀出來的記錄,這樣很難保證一定不會出現Python關鍵字。
38 # 這種情況下的解決辦法是將namedtuple的重命名模式打開,
39 # 這樣如果遇到Python關鍵字或者有重復元素名 時,自動進行重命名。
40 
41 with_class=collections.namedtuple('Person','name age class gender',rename=True)
42 print with_class._fields
43 two_ages=collections.namedtuple('Person','name age gender age',rename=True)
44 print two_ages._fields
45 
46 # Output:
47 # ('name', 'age', '_2', 'gender')
48 # ('name', 'age', 'gender', '_3')
49 
50 # 使用rename=True的方式打開重命名選項。
51 # 可以看到第一個集合中的class被重命名為 ‘_2' ; 
52 # 第二個集合中重復的age被重命名為 ‘_3'
53 # namedtuple在重命名的時候使用了下划線 _ 加元素所在索引數的方式進行重命名
54 ##############################
55 
56 # 附兩段官方文檔代碼實例:
57 # 1) namedtuple基本用法
58 >>> # Basic example
59 >>> Point = namedtuple('Point', ['x', 'y'])
60 >>> p = Point(11, y=22) # instantiate with positional or keyword arguments
61 >>> p[0] + p[1] # indexable like the plain tuple (11, 22)
62 33
63 >>> x, y = p # unpack like a regular tuple
64 >>> x, y
65 (11, 22)
66 >>> p.x + p.y # fields also accessible by name
67 33
68 >>> p # readable __repr__ with a name=value style
69 Point(x=11, y=22)
70 
71 # 2) namedtuple結合csv和sqlite用法
72 EmployeeRecord = namedtuple('EmployeeRecord', 'name, age, title, department, paygrade')
73 import csv
74 for emp in map(EmployeeRecord._make, csv.reader(open("employees.csv", "rb"))):
75 print(emp.name, emp.title)
76 
77 import sqlite3
78 conn = sqlite3.connect('/companydata')
79 cursor = conn.cursor()
80 cursor.execute('SELECT name, age, title, department, paygrade FROM employees')
81 for emp in map(EmployeeRecord._make, cursor.fetchall()):
82 print(emp.name, emp.title)
View Code

對於FLAGS.dim_1FLAGS.dim_2,定義為:

1 flags.DEFINE_integer(
2     'dim_1', 128, 'Size of output dim (final is 2x this, if using concat)')
3 flags.DEFINE_integer(
4     'dim_2', 128, 'Size of output dim (final is 2x this, if using concat)')

若GCN,因為有concat操作,故使用2x.

對於FLAGS.samples_1FLAGS.samples_2,定義為:

1 flags.DEFINE_integer('samples_1', 25, 'number of samples in layer 1')
2 flags.DEFINE_integer('samples_2', 10, 'number of users samples in layer 2')

對應論文中的K = 1 ,第一層S1 = 25; K = 2 ,第二層S2 = 10。

----------------------------------------------------------

1 model = SampleAndAggregate(placeholders,
2                            features,
3                            adj_info,
4                            minibatch.deg,
5                            layer_infos=layer_infos,
6                            aggregator_type="maxpool",
7                            model_size=FLAGS.model_size,
8                            identity_dim=FLAGS.identity_dim,
9                            logging=True)

SampleAndAggregate在models.py中。

class SampleAndAggregate(GeneralizedModel)主要包含的函數有:

1. def __init__(self, placeholders, features, adj, degrees, layer_infos, concat=True, aggregator_type="mean",  model_size="small", identity_dim=0, **kwargs)

2. def sample(self, inputs, layer_infos, batch_size=None)

3. def aggregate(self, samples, input_features, dims, num_samples, support_sizes, batch_size=None,
aggregators=None, name=None, concat=False, model_size="small")

4. def _build(self)

5. def build(self)

6. def _loss(self)

7. def _accuracy(self)

---------------------------------------------------------------

(2) Session

Config

 1 config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
 2 # 參數初始化為False:
 3 # tf.app.flags.DEFINE_boolean('log_device_placement', False,
 4 #                     """Whether to log device placement.""")
 5 
 6 config.gpu_options.allow_growth = True
 7 # 控制GPU資源使用率
 8 # 使用allow_growth option,剛一開始分配少量的GPU容量,然后按需慢慢的增加,
 9 # 由於不會釋放內存,所以會導致碎片
10 
11 #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
12 # 設置每個GPU應該拿出多少容量給進程使用,
13 # per_process_gpu_memory_fraction =0.4代表 40%
14 
15 config.allow_soft_placement = True
16 # 自動選擇運行設備
17 # 在tf中,通過命令 "with tf.device('/cpu:0'):",允許手動設置操作運行的設備。
18 # 如果手動設置的設備不存在或者不可用,就會導致tf程序等待或異常,
19 # 為了防止這種情況,可以設置tf.ConfigProto()中參數allow_soft_placement=True,
20 # 允許tf自動選擇一個存在並且可用的設備來運行操作。

 Initialize session

 1 # Initialize session
 2 sess = tf.Session(config=config)
 3 merged = tf.summary.merge_all()
 4 # tf.summary()能夠保存訓練過程以及參數分布圖並在tensorboard顯示。
 5 # merge_all 可以將所有summary全部保存到磁盤,以便tensorboard顯示。
 6 # 如果沒有特殊要求,一般用這一句就可一顯示訓練時的各種信息了
 7 
 8 summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)
 9 # 指定一個文件用來保存圖。
10 # 格式:tf.summary.FileWritter(path,sess.graph)
11 # 可以調用其add_summary()方法將訓練過程數據保存在filewriter指定的文件中

Init variables

1 sess.run(tf.global_variables_initializer(),
2      feed_dict={adj_info_ph: minibatch.adj})

---------------------------------------------------------

(4) Train model

1 feed_dict = minibatch.next_minibatch_feed_dict()

next_minibatch_feed_dict() 在minibatch.py的class EdgeMinibatchIterator(object)中定義。

1 def next_minibatch_feed_dict(self):
2     start_idx = self.batch_num * self.batch_size
3     self.batch_num += 1
4     end_idx = min(start_idx + self.batch_size, len(self.train_edges))
5     batch_edges = self.train_edges[start_idx: end_idx]
6     return self.batch_feed_dict(batch_edges)
View Code

函數中獲取下個edgeminibatch的起始與終止序號,將batch后的邊的信息傳給batch_feed_dict(self, batch_edges)函數,更新placeholders中的batch1, batch2, batch_size信息。

 1 def batch_feed_dict(self, batch_edges):
 2     batch1 = []
 3     batch2 = []
 4     for node1, node2 in batch_edges:
 5         batch1.append(self.id2idx[node1])
 6         batch2.append(self.id2idx[node2])
 7 
 8     feed_dict = dict()
 9     feed_dict.update({self.placeholders['batch_size']: len(batch_edges)})
10     feed_dict.update({self.placeholders['batch1']: batch1})
11     feed_dict.update({self.placeholders['batch2']: batch2})
12 
13     return feed_dict
View Code

也即next_minibatch_feed_dict()返回的是下一個edge minibatch的placeholders信息。

=======================================

     感謝您的支持!             感謝您的支持!

感謝您的打賞!

(夢想還是要有的,萬一您喜歡我的文章呢)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM