a = np.array([1, 2, 2, 3])
print(np.searchsorted(a, 0)) # 0
print(np.searchsorted(a, 1)) # 0
print(np.searchsorted(a, 2)) # 1
print(np.searchsorted(a, 2, 'left')) # 1
print(np.searchsorted(a, 2, 'right')) # 3
print(np.searchsorted(a, 2.5, 'right')) # 3
print(np.searchsorted(a, 2.5, 'left')) # 3
print(np.searchsorted(a, 3, 'left')) # 3
print(np.searchsorted(a, 3, 'right')) # 4
print(np.searchsorted(a, 4)) # 4
print(np.searchsorted(a, [0, 1, 2, 3, 4, 5, ])) # [0 0 1 3 4 4]
searchsorted有三個重要參數:
- a:待查找的有序數組
- v:待查找的值
- side:字符串,取值為left或者right,表示取下界(閉區間)還是取上界(開區間),默認參數為下界閉區間
利用searchsorted可以非常炫酷地實現輪盤賭隨機選取:
t = np.cumsum(weights)
sample = np.searchsorted(t, np.random.random() * t[-1])
cumsum保證了遞增,searchsorted二分查找,其中t[-1]表示全部元素之和,整個過程一氣呵成、美不勝收。
雖然如此,這種方式依然不是最好的方法。因為numpy提供了輪盤賭算法。
from collections import Counter
import numpy as np
a = []
for i in range(10000):
x = np.random.choice([1, 2, 3], 2, p=[0.1, 0.3, 0.6])
a.extend(x)
a = Counter(a)
a = np.array([np.array(i) for i in a.items()], dtype=np.float32)
a[:, 1] /= np.sum(a[:, 1])
print(a)
輸出為
[[1. 0.0993 ]
[2. 0.30325]
[3. 0.59745]]
因為searchsorted的下閉區間、上開區間效果有些奇特,所以可以wrap一下使它的行為更加明確
二分查找實際上可以寫成四種:
- 左閉區間
- 右閉區間
- 左開區間
- 右開區間
如果自己寫,一定要十分小心地考慮好邊界條件才能夠避免出錯。
import numpy as np
def bisearch(a, v, can_eq=True, side='left'):
x = np.searchsorted(a, v, side=side)
if x >= a.shape[0]:
return x
if can_eq:
if side == 'left':
if a[x] == v:
return x
else:
return x - 1
else:
if a[x] > v:
if x > 0 and a[x - 1] == v:
return x - 1
else:
return x
else:
return x
else:
if side == 'left':
if a[x] == v:
return x - 1
else:
return x
else:
return x
a = np.array([1, 2, 2, 4])
print(bisearch(a, 2, True, 'left'))#1
print(bisearch(a, 2, True, 'right'))#2
print(bisearch(a, 2, False, 'left'))#0
print(bisearch(a, 2, False, 'right'))#3
print(bisearch(a, -1, True, 'left'))#-1
print(bisearch(a, 5, True, 'right'))#4
