transE知識圖譜補全,FB15K-237數據集(python實現)


0 源代碼倉庫

https://github.com/Cpaulyz/BigDataAnalysis/tree/master/Assignment8

  • transE.py訓練程序
  • graph.py繪制損失函數折線
  • test.py驗證測試集

1 目的

知識圖譜補全是從已知的知識圖譜中提取出三元組(h,r,t),為實體和關系進行建模,通過訓練出的模型進行鏈接預測,以達成知識圖譜補全的目標。

本文實驗采用了FB15K-237數據集,分為訓練集和測試集。利用訓練集進行transE建模,通過訓練為每個實體和關系建立起向量映射,並在測試集中計算MeanRank和Hit10指標進行結果檢驗。

2 數據集

使用FB15K-237數據集

分為以下四個文件

  • entity2id.txt

    實體和id對

    image-20201027183657753

  • relation2id.txt

    關系和id對

    image-20201027183639557

  • train.txt

    訓練集三元組(實體,實體,關系)

    image-20201027183739267

  • test.txt

    測試集三元組(實體,實體,關系)

3 方法

3.1 TransE

3.1.1 原理

TransE將起始實體,關系,指向實體映射成同一空間的向量,如果(head,relation,tail)存在,那么h+r≈t

image-20201027172514794

目標函數為:

image-20201027172327488

3.1.2 算法

image-20201027172221770

(1)初始化

根據維度,為每個實體和關系初始化向量,並歸一化

    def emb_initialize(self):
        relation_dict = {}
        entity_dict = {}

        for relation in self.relation:
            r_emb_temp = np.random.uniform(-6 / math.sqrt(self.embedding_dim),
                                           6 / math.sqrt(self.embedding_dim),
                                           self.embedding_dim)
            relation_dict[relation] = r_emb_temp / np.linalg.norm(r_emb_temp, ord=2)

        for entity in self.entity:
            e_emb_temp = np.random.uniform(-6 / math.sqrt(self.embedding_dim),
                                           6 / math.sqrt(self.embedding_dim),
                                           self.embedding_dim)
            entity_dict[entity] = e_emb_temp / np.linalg.norm(e_emb_temp, ord=2)

(2)選取batch

設置nbatches為batch數目,batch_size = len(self.triple_list) // nbatches

從訓練集中隨機選擇batch_size個三元組,並隨機構成一個錯誤的三元組S',進行更新

    def train(self, epochs):
        nbatches = 400
        batch_size = len(self.triple_list) // nbatches
        print("batch size: ", batch_size)
        for epoch in range(epochs):
            start = time.time()
            self.loss = 0

            # Sbatch:list
            Sbatch = random.sample(self.triple_list, batch_size)
            Tbatch = []

            for triple in Sbatch:
                corrupted_triple = self.Corrupt(triple)
                if (triple, corrupted_triple) not in Tbatch:
                    Tbatch.append((triple, corrupted_triple))
            self.update_embeddings(Tbatch)

(3)梯度下降

定義距離d(x,y)來表示兩個向量之間的距離,一般情況下,我們會取L1,或者L2 normal。

在這里,我們需要定義一個距離,對於正確的三元組(h,r,t),距離d(h+r,t)越小越好;對於錯誤的三元組(h',r,t'),距離d(h'+r,t')越小越好。

image-20201027174637723

之后,使用梯度下降進行更新

3.1.3 結果

選擇迭代次數2000次,向量維度50,學習率0.01進行訓練

損失函數變化如下

image-20201027201057941

結果存儲在entity_50dimrelation_50dim

image-20201027201141073

3.2 鏈接預測

通過transE建模后,我們得到了每個實體關系的嵌入向量,利用嵌入向量,我們可以進行知識圖譜的鏈接預測

將三元組(head,relation,tail)記為(h,r,t)

鏈接預測分為三類

  1. 頭實體預測:(?,r,t)
  2. 關系預測:(h,?,t)
  3. 尾實體預測:(h,r,?)

但原理很簡單,利用向量的可加性即可實現。以(h,r,?)的預測為例:

假設t'=h+r,則在所有的實體中選擇與t'距離最近的向量,即為t的的預測值

4 指標

4.1 Mean rank

對於測試集的每個三元組,以預測tail實體為例,我們將(h,r,t)中的t用知識圖譜中的每個實體來代替,然后通過distance(h, r, t)函數來計算距離,這樣我們可以得到一系列的距離,之后按照升序將這些分數排列。

distance(h, r, t)函數值是越小越好,那么在上個排列中,排的越前越好。

現在重點來了,我們去看每個三元組中正確答案也就是真實的t到底能在上述序列中排多少位,比如說t1排100,t2排200,t3排60.......,之后對這些排名求平均,mean rank就得到了。

4.2 Hit@10

還是按照上述進行函數值排列,然后去看每個三元組正確答案是否排在序列的前十,如果在的話就計數+1

最終 排在前十的個數/總個數 就是Hit@10

4.3 代碼實現

def distance(h, r, t):
    h = np.array(h)
    r = np.array(r)
    t = np.array(t)
    s = h + r - t
    return np.linalg.norm(s)


def mean_rank(entity_set, triple_list):
    triple_batch = random.sample(triple_list, 100)
    mean = 0
    hit10 = 0
    hit3 = 0
    for triple in triple_batch:
        dlist = []
        h = triple[0]
        t = triple[1]
        r = triple[2]
        dlist.append((t, distance(entityId2vec[h], relationId2vec[r], entityId2vec[t])))
        for t_ in entity_set:
            if t_ != t:
                dlist.append((t_, distance(entityId2vec[h], relationId2vec[r], entityId2vec[t_])))
        dlist = sorted(dlist, key=lambda val: val[1])
        for index in range(len(dlist)):
            if dlist[index][0] == t:
                mean += index + 1
                if index < 3:
                    hit3 += 1
                if index <10:
                    hit10 += 1
                print(index)
                break
    print("mean rank:", mean / len(triple_batch))
    print("hit@3:", hit3 / len(triple_batch))
    print("hit@10:", hit10 / len(triple_batch))

image-20201028011618776

5 結論

經過transE建模后,在測試集的13584個實體,961個關系的 59071個三元組中,測試結果如下:

mean rank: 353.06935721419984
hit@3: 0.12181950534103028
hit@10: 0.2754989758087725

一方面可以看出訓練后的結果是有效的,但不是十分優秀,可能與transE模型的局限性有關,transE只能處理一對一的關系,不適合一對多/多對一關系。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM