Numpy 的廣播機制高效計算矩陣之間兩兩距離


利用numpy可以很方便的計算兩個二維數組之間的距離。二維數組之間的距離定義為:X的維度為(a,c),Y的維度為(b,c),Z為X到Y的距離數組,維度為(a,b)。且Z[0,0]是X[0]到Y[0]的距離。Z(m,n)為X[m]到Y[n]的距離。

例如: 計算 m*2 的矩陣 與  n * 2 的矩陣中,m*2 的每一行到  n*2 的兩兩之間歐氏距離。 

#computer the distance between text point x and train point x_train
import numpy as np
X = np.random.random((3,2))
X_train = np.random.random((5,2))
print('X:')
print(X)
print('X_train:')
print(X_train)

dist = np.zeros((X.shape[0],X_train.shape[0]))
print('--------------------')
#way 1:use two loops ,使用兩層循環
for i in range(X.shape[0]):
    for j in range(X_train.shape[0]):
        dist[i,j] = np.sum((X[i,:]-X_train[j,:])**2)
print('way 1 result:')
print(dist)

#way 2:use one loops ,使用一層循環
for i in range(X.shape[0]):
    dist[i,:] = np.sum((X_train-X[i,:])**2,axis=1)
print('--------------------')
print('way 2 result:')
print(dist)

#way 3:use no loops,不使用循環
dist = np.reshape(np.sum(X**2,axis=1),(X.shape[0],1))+ np.sum(X_train**2,axis=1)-2*X.dot(X_train.T)
print('--------------------')
print('way 3 result:')
print(dist)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM