LDA模型數據的可視化


 1 """
 2     執行lda2vec.ipnb中的代碼
 3     模型LDA
 4     功能:訓練好后模型數據的可視化
 5 """
 6 
 7 from lda2vec import preprocess, Corpus
 8 import matplotlib.pyplot as plt
 9 import numpy as np
10 # %matplotlib inline
11 import pyLDAvis
12 try:
13     import seaborn
14 except:
15     pass
16 # 加載訓練好的主題-文檔模型,這里是查看數據使用。這里需要搞清楚數據的形式,還要去回看這個文件是怎么構成的
17 npz = np.load(open('D:/my_AI/lda2vec-master/examples/twenty_newsgroups/lda2vec/topics.pyldavis.npz', 'rb'))
18 # 數據
19 dat = {k: v for (k, v) in npz.iteritems()}
20 # 詞匯表變成list
21 dat['vocab'] = dat['vocab'].tolist()
22 
23 #####################################
24 ##  主題-詞匯
25 #####################################
26 # 主題個數為10
27 top_n = 10
28 # 主題對應10個最相關的詞
29 topic_to_topwords = {}
30 for j, topic_to_word in enumerate(dat['topic_term_dists']):
31     top = np.argsort(topic_to_word)[::-1][:top_n]               # 概率從大到小的下標索引值
32     msg = 'Topic %i '  % j
33     # 通過list的下標獲取關鍵詞
34     top_words = [dat['vocab'][i].strip()[:35] for i in top]
35     # 數據拼接
36     msg += ' '.join(top_words)
37     print(msg)
38     # 將數據保存到字典里面
39     topic_to_topwords[j] = top_words
40 
41 import warnings
42 warnings.filterwarnings('ignore')
43 prepared_data = pyLDAvis.prepare(dat['topic_term_dists'], dat['doc_topic_dists'],
44                                  dat['doc_lengths'] * 1.0, dat['vocab'], dat['term_frequency'] * 1.0, mds='tsne')
45 
46 from sklearn.datasets import fetch_20newsgroups
47 remove=('headers', 'footers', 'quotes')
48 texts = fetch_20newsgroups(subset='train', remove=remove).data
49 
50 
51 ##############################################
52 ##  選取一篇文章,確定該文章有哪些主題
53 ##############################################
54 
55 print(texts[1])
56 tt = dat['doc_topic_dists'][1]
57 msg = "{weight:02d}% in topic {topic_id:02d} which has top words {text:s}"
58 # 遍歷這20個主題,觀察一下它的權重,權重符合的跳出來
59 for topic_id, weight in enumerate(dat['doc_topic_dists'][1]):
60     if weight > 0.01:
61         # 權重符合要求,那么輸出該主題下的關聯詞匯
62         text = ', '.join(topic_to_topwords[topic_id])
63         print (msg.format(topic_id=topic_id, weight=int(weight * 100.0), text=text))
64 
65 # plt.bar(np.arange(20), dat['doc_topic_dists'][1])
66 
67 print(texts[51])
68 tt = texts[51]
69 msg = "{weight:02d}% in topic {topic_id:02d} which has top words {text:s}"
70 for topic_id, weight in enumerate(dat['doc_topic_dists'][51]):
71     if weight > 0.01:
72         text = ', '.join(topic_to_topwords[topic_id])
73         print(msg.format(topic_id=topic_id, weight=int(weight * 100.0), text=text))
74 
75 
76 # plt.bar(np.arange(20), dat['doc_topic_dists'][51])

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯繫本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM