Social LSTM 實現代碼分析


----- 2019.8.5更新 實現代碼思維導圖 -----

----- 初始原文 -----
Social LSTM最早提出於文獻 “Social LSTM: Human Trajectory Prediction in Crowded Spaces”,但經過資料查閱目前暫未找到原文獻作者所提供的程序代碼和數據,而在github上有許多針對該文獻的實現版本代碼。

本文接下來的實現代碼來自https://github.com/xuerenlv/social-lstm-tf,代碼語言為Python3,代碼大體實現了原論文中核心原創部分的模型,包括Vanilla LSTM(沒有考慮行人軌跡之間關聯性的LSTM)和Social LSTM(使用池化層考慮了行人軌跡之間關聯性的LSTM模型)的模型構建、訓練和小樣本測試的代碼,但對橫向對比的其他模型、模型量化評估方法等暫未實現。

本文下面將從代碼中矩陣數據和列表(list)數據的維度細說實現過程和模型的特點。

Vanilla LSTM 模型

訓練數據

主要功能代碼文件:util.py

數據格式:

input_data, target_data = dataLoader.next_batch()
# input_data : [batch_size, seq_length, 2]
# target_data : [batch_size, seq_length, 2]

批量處理數據大小 x 序列長度大小 x 二維地址數據(已經過標准化處理,介於\(0 - 1\)

數據解釋:

  1. 模型在實際使用時,對於每個輸入的位置數據(源於已知數據/上一步預測數據)LSTM Cell將該運行后得到的輸出就可用於下一時刻位置的預測,因此從dataLoader獲得的input_datatarget_data從數據維度上只在seq_length維度上有1個大小的錯位,對於行人已知的\(t_0 - t_{obs}\)的軌跡,訓練時參與損失函數計算的是網絡預測的\(t_1 - t_{obs+1}\)軌跡。
  2. 同時,其在由於訓練采用Minibatch,因此輸入和目標數據的有大小為batch_size的第一維度。

模型中間變量

LSTM序列網絡是模型的核心部分,輸入數據需要修改結構以滿足數據要求,同時序列網絡的輸出結果也需要經過處理才能夠使用,為此,模型主要有以下中間變量:

inputs, embedding_inputs

inputsinput_data的拆分版,將其拆解為序列模型每步運行時的輸入數據。

embedding_inputs是將inputs使用embedding層后得到的輸入數據,默認滿足embedding_size = rnn_size = 128,因此數據可直接用於lstm的輸入數據了。

# inputs : [N_0, N_1, N_2, ....], N_i = [batch_size, 2]
# embedding_inputs = [M_0, M_1, M_2, ....], M_i = [batch_size, embedding_size]

# embedding
embedding_w = tf.get_variables("embedding_w", [2, embedding_size])
embedding_b = tf.get_variables("embedding_b", [embedding_size])
for input in embedding_inputs:
	x = tf.nn.relu(tf.nn.xw_plus_b(input, embedding_w, embedding_b))
	embedding_inputs.append(x)

seq2seq.rnn_decoder

由於該源碼相比tensorflow的版本更迭還是有一定的年代感,其在運行LSTM模型時使用了不常用的方法:

outputs, last_state = tf.contrib.legacy_seq2seq.rnn_decoder(embedded_inputs, self.initial_state, cell, loop_function=None, scope="rnnlm")

此LSTM模型嚴格來說並不是seq2seq模型,其只是借用了seq2seq中decoder相同的操作步驟用在這里(手動實現也不復雜),具體來說,就是在for循環迭代embedded_inputs列表中的元素,使LSTM的cell運行對應的次數,而后將序列模型的每步運行輸出生成outputs列表,並返回最后一步運行的finial_state

output_w, output_b

LSTM模型輸出的原始outputs數據需經線性變換為合適結構才被進一步使用,在此是對於每個大小為rnn_size的輸出向量,線性變為為大小為\(5\)的結果向量,有關使用目的請參見下一節。

output_size = 5 # 具體賦值目的請參見下文與原文獻
# output : [batch_size * seq_length, rnn_size]
output = tf.reshape(tf.concat(outputs, 1), [-1, rnn_size])

output_w = tf.get_variable("output_w", [rnn_size, output_size])
output_b = tf.get_variable("output_b", [output_size])
# output : [batch_size * seq_length, 5]
output = tf.nn.xw_plus_b(output, output_w, output_b)

*output數據中最終含有\(batch\_size * seq\_length\)個預測的位置(每個位置由5個參數表述),相同的reshape策略可確保output中預測位置與target中實際位置的排列順序是相同的。

模型輸出

將序列模型每步輸出結果合並、線性變換和變形后得到output,傳入的target_data經過變形后得到flat_target_data

# model.py
# output : [batch_size * seq_length, 5]
# flat_target_data : [batch_size * seq_length, 2]

outputflag_target_data就是最終用於(訓練時)計算損失/(采樣時,不依賴於target)計算下一時刻位置的數據。

兩個變量的第一維度大小均為batch_size * seq_length(在reshape策略相同情況下,第二維度數據在數據批次和時間點上一一對應),而兩個變量在第二維度數據量的差異是:原文獻中假設了LSTM Cell輸出的rnn_size大小(默認為128)的結果滿足二維高斯分布(bivariate Gaussian distribution),因此使用線性變換矩陣后得到的恰是刻畫二維高斯分布的5個參數$\mu_x, \mu_y, \sigma_x, \sigma_y, \rho $(有關如何基於二維高斯分布求出預測點和損失值請原文獻的引用)。

Social LSTM模型

此部分暫時未完全整理出來,根據初步的代碼閱讀,Social LSTM與Vanilla LSTM整體的代碼框架和模型構建方法是相似的,具體有下述幾方面的差異:

  1. batch_sizemax pedestrian number,批量訓練數據的差異:在Vanilla LSTM訓練時,采用了Mini Batch的數據方式使每次模型迭代時具備一定的數量規模;而Social LSTM中由於池化層的加入使得同一時刻需要有MPN個LSTM序列迭代,而縱使存在多個LSTM序列,其實共享的是同一個Cell,因此同一場景的多位行人的軌跡(在代碼中稱作frame)其實就可以等價於一個batch,從而使訓練Cell時有一定的數據規模。

    # input_data format in vanilla lstm
    input_data = tf.placeholder(tf.float32, [None, seq_length, 2])
    ----
    # input_data format in social lstm
    input_data = tf.placeholder(tf.float32, [seq_length, maxNumPeds, 3])
    
  2. social tensor池化層:

    social LSTM結構從本質上就是vanilla lstm添加了池化層,在源代碼的grid.py包含主要的social tensor的支持方法。social tensor在原文中用\(H_i^t\)表示,每個行人\(i\)在不同時間點\(t\)中都有不同的social tensor

    對於每個張量中的值,實際是由上一時刻其他行人的Hidden State加和得到,Hidden State只有LSTM Cell真正跑起來才能得到,因此最終的social tensor是在模型運算中所得到的(這也是為什么運算量較大的原因以至原文獻中又提出了一種能夠在運算前得到張量的O-LSTM模型),不過在模型運行前,Hidden State的加和方式就可以通過輸入數據推算得出,grid.py做得主要就是這部分工作,其生成了數據為01真值的Grid Mask矩陣,在模型迭代時作為參數傳入,從而簡化生成social tensor的過程。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM