1 python使用networkx或者graphviz,pygraphviz可視化RNN(recursive)中的二叉樹


代碼地址https://github.com/vijayvee/Recursive-neural-networks-TensorFlow

代碼實現的是結構遞歸神經網絡(Recursive NN,注意,不是Recurrent),里面需要構建樹。代碼寫的有不少錯誤,一步步調試就能解決。主要是隨着tensorflow版本的變更,一些函數的使用方式發生了變化。

2 數據樣式

(3 (2 (2 The) (2 Rock)) (4 (3 (2 is) (4 (2 destined) (2 (2 (2 (2 (2 to) (2 (2 be) (2 (2 the) (2 (2 21st) (2 (2 (2 Century) (2 's)) (2 (3 new) (2 (2 ``) (2 Conan)))))))) (2 '')) (2 and)) (3 (2 that) (3 (2 he) (3 (2 's) (3 (2 going) (3 (2 to) (4 (3 (2 make) (3 (3 (2 a) (3 splash)) (2 (2 even) (3 greater)))) (2 (2 than) (2 (2 (2 (2 (1 (2 Arnold) (2 Schwarzenegger)) (2 ,)) (2 (2 Jean-Claud) (2 (2 Van) (2 Damme)))) (2 or)) (2 (2 Steven) (2 Segal))))))))))))) (2 .)))

(4 (4 (4 (2 The) (4 (3 gorgeously) (3 (2 elaborate) (2 continuation)))) (2 (2 (2 of) (2 ``)) (2 (2 The) (2 (2 (2 Lord) (2 (2 of) (2 (2 the) (2 Rings)))) (2 (2 '') (2 trilogy)))))) (2 (3 (2 (2 is) (2 (2 so) (2 huge))) (2 (2 that) (3 (2 (2 (2 a) (2 column)) (2 (2 of) (2 words))) (2 (2 (2 (2 can) (1 not)) (3 adequately)) (2 (2 describe) (2 (3 (2 (2 co-writer\/director) (2 (2 Peter) (3 (2 Jackson) (2 's)))) (3 (2 expanded) (2 vision))) (2 (2 of) (2 (2 (2 J.R.R.) (2 (2 Tolkien) (2 's))) (2 Middle-earth))))))))) (2 .)))

這是兩行數據,可以構建兩棵樹。

首先,以第一棵樹為例,3是root節點,是label,只有葉子節點有word。word就是記錄的單詞。

3 依據文件構建樹的主要處理過程:

    with open(file, 'r') as fid:

        trees = [Tree(l) for l in fid.readlines()]

 

Tree構建的時候:
    def __init__(self, treeString, openChar='(', closeChar=')'):
        tokens = []
        self.open = '('
        self.close = ')'
        for toks in treeString.strip().split():
            tokens += list(toks)
        self.root = self.parse(tokens)
        # get list of labels as obtained through a post-order traversal
        self.labels = get_labels(self.root)
        self.num_words = len(self.labels)

 其中,程序得到的tokens,是如下形式:

tokens輸出的是字符的列表,即[‘(’,’3’,’(’,’2’,‘(’,’2’,’(‘,’T’,’h’,’e’………………]

Parse函數處理:(遞歸構建樹的過程),注意,其中的int('3')得到的是3,而不是字符'3'的ASCII碼值。

    Parse函數處理:(遞歸構建樹的過程)
    def parse(self, tokens, parent=None):
        assert tokens[0] == self.open, "Malformed tree"
        assert tokens[-1] == self.close, "Malformed tree"

        split = 2  # position after open and label
        countOpen = countClose = 0

        if tokens[split] == self.open: #假如是父節點,還有子節點的話,一定是(3(,即[2]對應的字符是一個open
            countOpen += 1
            split += 1
        # Find where left child and right child split
#下面的while循環就是處理,可以看到,能夠找到(2 (2 The) (2 Rock))字符序列是其左子樹。
#
        while countOpen != countClose: 
            if tokens[split] == self.open:
                countOpen += 1
            if tokens[split] == self.close:
                countClose += 1
            split += 1

        # New node
        
        print (tokens[1],int(tokens[1]))
        node = Node(int(tokens[1]))  # zero index labels
        node.parent = parent

        # leaf Node
        if countOpen == 0: #也就是葉子節點
            node.word = ''.join(tokens[2:-1]).lower()  # lower case?
            node.isLeaf = True
            return node

        node.left = self.parse(tokens[2:split], parent=node)
        node.right = self.parse(tokens[split:-1], parent=node)
        return node

4 networkx構建可視化二叉樹

代碼如下:

def plotTree_xiaojie(tree):
    
    positions,edges = _get_pos_edge_list(tree)
    nodes = [x for x in positions.keys()]
    labels = _get_label_list(tree)
    colors = []
    try:
        colors = _get_color_list(tree)
    except AttributeError:
        pass
    #使用networkx畫圖
    G=nx.Graph()
    G.add_edges_from(edges)
    G.add_nodes_from(nodes)
    
    if len(colors) > 0:
        nx.draw_networkx_nodes(G,positions,node_size=100,node_color=colors)
        nx.draw_networkx_edges(G,positions)
        nx.draw_networkx_labels(G,positions,labels,font_color='w')
    else:
        nx.draw_networkx_nodes(G,positions,node_size=100,node_color='r')
        nx.draw_networkx_edges(G,positions)
        nx.draw_networkx_labels(G,positions,labels)
    nx.draw(G)
    plt.axis('off')
    
    plt.savefig('./可視化二叉樹__曾傑.jpg')
    plt.show()
    #官網提供的下面的兩個方法,已經缺失了。
#    nx.draw_graphviz(G)
#    nx.write_dot(G,'xiaojie.dot')
    return None

其中,_get_pos_edge_list的主要作用是對樹進行遍歷,決定每個樹節點在畫布中的位置,比如root節點就在(0,0)坐標處,然后edge就是遍歷樹得到邊。

def _get_pos_edge_list(tree):
    """
    _get_pos_list(tree) -> Mapping. Produces a mapping
    of nodes as keys, and their coordinates for plotting
    as values. Since pyplot or networkx don't have built in
    methods for plotting binary search trees, this somewhat
    choppy method has to be used.
    """
    return _get_pos_edge_list_from(tree,tree.root,{},[],0,(0,0),1.0)

dot = None
def _get_pos_edge_list_from(tree,node,poslst,edgelist,index,coords,gap):
    #利用先序遍歷,遍歷一顆樹,將邊和節點生成networkx可以識別的內容。
    """
    _get_pos_list_from(tree,node,poslst,index,coords,gap) -> Mapping.
    Produces a mapping of nodes as keys, and their coordinates for
    plotting as values.

    Non-straightforward arguments:
    index: represents the index of node in
    a list of all Nodes in tree in preorder.
    coords: represents coordinates of node's parent. Used to
    determine coordinates of node for plotting.
    gap: represents horizontal distance from node and node's parent.
    To achieve plotting consistency each time we move down the tree
    we half this value.
    """
    global dot
    positions = poslst
    edges=edgelist
    if node and node == tree.root:
        dot.node(str(index),str(node.label))
        positions[index] = coords
        new_index = 1 +index+tree.get_element_count(node.left)
        if node.left:
            edges.append((0,1))
            dot.edge(str(index),str(index+1),constraint='false')
            positions,edges = _get_pos_edge_list_from(tree,node.left,positions,edges,1,coords,gap)
        if node.right:
            edges.append((0,new_index))
            dot.edge(str(index),str(new_index),constraint='false')
            positions,edges = _get_pos_edge_list_from(tree,node.right,positions,edges,new_index,coords,gap)
     
        return positions,edges
    elif node:
        dot.node(str(index),str(node.label))
        if node.parent.right and node.parent.right == node:
            #new_coords = (coords[0]+gap,coords[1]-1) #這樣的話,當節點過多的時候,很容易出現重合節點的情形。
            new_coords = (coords[0]+(tree.get_element_count(node.left)+1)*3,coords[1]-3)
            positions[index] = new_coords
        else:
            #new_coords = (coords[0]-gap,coords[1]-1)
            new_coords = (coords[0]-(tree.get_element_count(node.right)+1)*3,coords[1]-3)
            positions[index] = new_coords
        
        new_index = 1 + index + tree.get_element_count(node.left)
        if node.left:
            edges.append((index,index+1))
            dot.edge(str(index),str(index+1),constraint='false')
            positions,edges = _get_pos_edge_list_from(tree,node.left,positions,edges,index+1,new_coords,gap)    
        if node.right:
            edges.append((index,new_index))
            dot.edge(str(index),str(new_index),constraint='false')
            positions,edges = _get_pos_edge_list_from(tree,node.right,positions,edges,new_index,new_coords,gap)    
        
        return positions,edges
    else:
        return positions,edges

5 遇到的問題(畫的樹太丑了,不忍心看)

 

樹畫的特別的丑,而且能夠對樹進行描述的信息不多。這是我參考網上繪制二叉樹的開源項目:

見博客地址:http://www.studyai.com/article/9bf95027,其中引用的兩個庫是BSTree

  1. from pybst.bstree import BSTree
  2. from pybst.draw import plot_tree

由於BSTree有它自己的樹結構,而我下載的RNN網絡的樹又是另外一種結構。於是,我只能修改BSTree的代碼,產生了前述的代碼,即plotTree_xiaojie,加入到RNN項目的源碼當中去。

樹是什么樣子呢?

可以看到,在x軸中有重疊現象。

於是代碼中有如下改動:

        if node.parent.right and node.parent.right == node: #new_coords = (coords[0]+gap,coords[1]-1) #這樣的話,當節點過多的時候,很容易出現重合節點的情形。 new_coords = (coords[0]+(tree.get_element_count(node.left)+1)*1,coords[1]-1) positions[index] = new_coords else: #new_coords = (coords[0]-gap,coords[1]-1) new_coords = (coords[0]-(tree.get_element_count(node.right)+1)*1,coords[1]-1) positions[index] = new_coords

即在x軸方向上從單純的加減去一個1,而變成了加上和減去節點數確定的距離,如此一來,能夠保證二叉樹上的所有節點在x軸上不會出現重合。因為我畫樹的過程是先序遍歷的方式,所以y軸上所有節點從根本上是不可能重合的。而子節點的位置必然要依據父節點的位置來斷定,就會導致整顆樹的節點,在x軸上出現重合。

我畫了一個手稿示意圖如下:即依據子節點的左右子樹的節點數,確立子節點與父節點的位置關系(父節點當前的位置是知道的,要確立子節點的位置)

  

優化后的二叉樹長這個樣子:

通過之前的樹對比一下,可以發現沒有節點重合了。但是為什么在根節點處出現一大片紅色。這個原因不明確。但是通過對比前后兩個圖,是可以發現,3節點和其左子節點2之間,並沒有其它的節點。

但是,圖依舊很丑。

此外,networkx能夠記錄的信息有限。一個label是不夠的。我希望能夠展現出RNN的節點的當前的向量是多少,所以需要更豐富的展現形式。於是求助Graphviz

6 借助Graphviz展現二叉樹

參考:

https://blog.csdn.net/a1368783069/article/details/52067404

使用Graphviz繪圖(一)

https://www.cnblogs.com/taceywong/p/5439574.html

修改前述繪制樹的plotTree_xiaojie程序如下:

def plotTree_xiaojie(tree):
    global dot
    dot=Digraph("G",format="pdf")

    positions,edges = _get_pos_edge_list(tree)
    nodes = [x for x in positions.keys()]
    labels = _get_label_list(tree)
    colors = []
    try:
        colors = _get_color_list(tree)
    except AttributeError:
        pass
    print(dot.source)
    f=open('可視化二叉樹.dot', 'w+')
    f.write(dot.source)  
    f.close()

    dot.view()

    #使用networkx畫圖
    G=nx.Graph()
    G.add_edges_from(edges)
    G.add_nodes_from(nodes)
    
    if len(colors) > 0:
        nx.draw_networkx_nodes(G,positions,node_size=40,node_color=colors)
        nx.draw_networkx_edges(G,positions)
        nx.draw_networkx_labels(G,positions,labels,font_color='w')
    else:
        nx.draw_networkx_nodes(G,positions,node_size=40,node_color='r')
        nx.draw_networkx_edges(G,positions)
        nx.draw_networkx_labels(G,positions,labels)
    nx.draw(G)
    plt.axis('off')
    
    plt.savefig('./可視化二叉樹__曾傑.jpg')
    plt.show()
    #官網提供的下面的兩個方法,已經缺失了。
#    nx.draw_graphviz(G)
#    nx.write_dot(G,'xiaojie.dot')
    return None
在對樹進行遍歷的_get_pos_edge_list函數中也添加了dot的相關添加節點和邊的操作,見前述代碼。前述代碼中已經包含使用graphviz的相關操作了。
結果得到的圖是這個死樣子:

雖然節點和邊的關系是對的。但是太丑了,這哪是一顆樹。

博客:https://blog.csdn.net/theonegis/article/details/71772334宣稱,能夠將二叉樹變得好看。使用如下代碼:

dot tree.dot | gvpr -c -f binarytree.gvpr | neato -n -Tpng -o tree.png

結果,更丑了。

7 拋出問題:如何更好的展現一顆二叉樹,我希望用pygraphviz。

正在研究和使用中,后續更新在下篇博文中。

見本博客,2 pygraphviz在windows10 64位下的安裝問題(反斜杠的血案)

更新博文 2018年8月23日17:21:45


 

8 使用pygraphviz繪制二叉樹

代碼修改如下:

def plotTree_xiaojie(tree):
    positions,edges = _get_pos_edge_list(tree)
    nodes = [x for x in positions.keys()]
    G=pgv.AGraph(name='xiaojie_draw_RtNN_Tree',directed=True,strict=True)
    G.add_nodes_from(nodes)
    G.add_edges_from(edges)
    G.layout('dot')
    G.draw('xiaojie_draw_RtNN_Tree.png')
    return None

結果是:

是不是相當的好看?

而且還可以局部區域放大,完全是graphviz的強大特性。

這相當於什么了,把graphviz比作原版的android系統,然后pygraphviz就像是小米,oppo,華為等進行的升級版本。

哇咔咔。

可以對邊的顏色,節點大小,還可以添加附加信息。比如我想添加節點當前的計算向量等等。

這樣,一顆結構遞歸計算的樹就出來了。留待后續更新。

下面是一顆樹的局部區域展示。

 

 

 

 

 

 

 

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM