numpy中的argmax、argmin、argwhere、argsort、argpartition函數


楔子

numpy中有幾個以arg開頭的函數,非常的神奇,因為它們返回的不是元素、而是元素的索引,我們來看一下用法,這里只以一維數組為例。

np.argmax

首先np.max是獲取最大元素,那么np.argmax是做什么的呢?

import numpy as np

arr = np.array([3, 22, 4, 11, 2, 44, 9])
print(np.max(arr))  # 44
print(np.argmax(arr))  # 5

我們看到np.max是獲取數組中最大的元素,np.argmax是獲取數組中最大元素對應的索引。

同理還有np.argmin,np.min是獲取數組中最小的元素,顯然是2;np.argmin是獲取數組中最小元素對應的索引,顯然是4。

import numpy as np

arr = np.array([3, 22, 4, 11, 2, 44, 9])
print(np.min(arr))  # 2
print(np.argmin(arr))  # 4

np.argwhere

np.where我們算是經常使用了,先來復習一下它的用法吧。

import numpy as np

arr = np.array([1, 2, 3, 4, 5, 6, 7, 8])

# 如果元素大於4, 那么減去10; 否則擴大十倍
print(np.where(arr > 4, arr - 10, arr * 10))  # [10 20 30 40 -5 -4 -3 -2]

# 如果元素大於4, 那么保持不變, 否則變成4
print(np.where(arr > 4, arr, 4))  # [4 4 4 4 5 6 7 8]

和np.where作用類似的還有一個np.clip,來看一下。

import numpy as np

arr = np.array([1, 2, 3, 4, 5, 6, 7, 8])

# 小於2的換成2, 大於6的換成6, 一般在設置上下限的時候非常有用
print(np.clip(arr, 2, 6))  # [2 2 3 4 5 6 6 6]

那么np.where是做啥的呢?首先這個函數只接受一個參數,找出滿足條件的元素對應的索引。

import numpy as np

arr = np.array([3, 4, 5, 6, 7])
print(np.argwhere(arr % 2 != 0))
"""
[[0]
 [2]
 [4]]
"""
print(np.argwhere(arr % 2 != 0).flatten())  # [0 2 4]

顯然元素3、5、7在%2之后不為0,所以會篩選出它們的索引,因此是[0 2 4]。只不過默認不是一個一維數組,我們需要再調用一下flatten,將其扁平化。

np.argsort

np.sort是用來排序的,類似於Python的內置函數sorted。

import numpy as np

arr = np.array([4, 2, 3, 6, 5, 1])
print(np.sort(arr))  # [1 2 3 4 5 6] 

sort很容易,再來看看argsort。

import numpy as np

arr = np.array([4, 2, 3, 6, 5, 1])
print(np.sort(arr))  # [1 2 3 4 5 6]
print(np.argsort(arr))  # [5 1 2 0 4 3]

sort是將從小到大排序之后返回,argsort是返回從小到大排序之后元素對應的索引。比如:第一個元素是5,表示原來數組中索引為5的元素在排序之后應該排在第一個位置上。

因此,通過argsort我們可以選出topN的元素。

import numpy as np

arr = np.array([4, 2, 3, 6, 5, 1, 8, 9, 7])
print(arr[np.argsort(arr)[-3:]])  # [7 8 9]

# 當然sort本身也是可以的
print(np.sort(arr)[-3:])  # [7 8 9]

下面看一個問題,如果我想查看數組中每一個元素在排完序之后對應的索引該怎么辦呢?

以數組:[3 2 1 4]為例,在排完序之后結果顯然是[1 2 3 4],那么原來的元素3應該在索引為2的位置上、元素2在索引為1的位置上、元素1在索引為0的位置上、元素4在索引為3的位置上,所以我們希望得到一個數組[2 1 0 3],那么要怎么做?

import numpy as np

arr = np.array([88, 79, 86, 97, 89, 95, 84])

# 調用一次argsort顯然是不夠的, 它表示排完序之后原來的元素對應的索引
print(np.argsort(arr))  # [1 6 2 0 4 5 3]

# 如果我們連續調用兩次argsort的話, 另外np.argsort(arr) <==> arr.argsort()
print(arr.argsort().argsort())  # [3 0 2 6 4 5 1]

# 此時就大功告成了
# 數組[3 0 2 6 4 5 1]表示:
#   arr中第一個元素88在排完序之后應該處於索引為3的位置
#   79在排完序之后應該處於索引為0的位置
#   ...

以88為例,顯然它在排序之后索引為3,所以對arr.argsort()得到的數組再進行一次argsort即可得到我們想要的結果。這個可能有點繞,使用言語表達起來實在是不太容易,可以自己看着圖嘗試一下。

np.argpartition

argpartition類似於argwhere,但它只是局部排序,舉例說明:

import numpy as np

arr = np.array([66, 15, 27, 33, 19, 13, 10])

"""
np.partition(arr, n)
找出arr中第n + 1小的元素(將arr排序之后索引n的元素), 然后返回一個新數組
並將原來數組中第n + 1小的元素放在新數組索引為n的地方, 保證左邊的元素比它小, 右邊的元素比它大
"""
print(np.partition(arr, 3))  # [15 13 10 19 27 33 66]
# 第4小的元素(排完序之后索引為3)顯然是19, 那么將19放在索引為3的位置, 然后左邊的元素比它小, 右邊的元素比它大
# 至於兩邊的順序則沒有要求

# 雖然我們可以使用sort, 但是sort是全局排序
# 如果數組非常大, 我們只希望選擇最小的10個元素, 直接通過np.partition(arr, 9)即可
# 然后如果排序的話, 只對這選出來的10個元素排序即可, 而無需對整個大數組進行排序

# 同理還可以從后往前找, 比如:
# np.partition(arr, -2)表示找到第2大的元素(將arr排序之后索引-2的元素), 放在數組索引為-2的地方
# 然后左邊元素比它小, 右邊元素比它大
print(np.partition(arr, -2))  # [13 10 27 15 19 33 66]
# 第2大的元素顯然是33, 那么排在索引為-2的位置, 左邊元素比它小, 右邊元素比它大


# 然后argpartition不用想, 肯定是獲取排序之后的索引
print(np.argpartition(arr, 3))  # [1 5 6 4 2 3 0]
print(np.argpartition(arr, -2))  # [5 6 2 1 4 3 0]


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM