あいどる💖たいむ

あいどるやってます。

numpyの配列から上位n個を取得する

概要

numpy.darrayから値を比較し、上位n個を取得したい。
argsortを使えば実現できた。

環境

$ python --version
Python 3.5.1

argsortを使ってみる

argsortでソート後のindexのリストを取得できる
numpy.argsort — NumPy v1.12 Manual

In [1]: import numpy as np

In [3]: x = np.array([[0, 3, 1, 2],
   ...:        [1, 0, 2, 2],
   ...:        [4, 4, 4, 4]])
   ...:

In [4]: x
Out[4]:
array([[0, 3, 1, 2],
       [1, 0, 2, 2],
       [4, 4, 4, 4]])

# axis=0の場合(列ごとに比較するイメージ)
In [5]: np.argsort(x, axis=0)
Out[5]:
array([[0, 1, 0, 0],
       [1, 0, 1, 1],
       [2, 2, 2, 2]])

# axis=0の場合(行ごとに比較するイメージ)
In [6]: np.argsort(x, axis=1)
Out[6]:
array([[0, 2, 3, 1],
       [1, 0, 2, 3],
       [0, 1, 2, 3]])

# axisのデフォルトは-1(最後の次元)
In [7]: np.argsort(x)
Out[7]:
array([[0, 2, 3, 1],
       [1, 0, 2, 3],
       [0, 1, 2, 3]])

argsortを使って、上位2個を取得する

In [3]: np.argsort(x)[0][::-1][:2]
Out[3]: array([1, 3])

# 1つのスライスで処理するなら・・・
In [4]: np.argsort(x)[0][:-3:-1]
Out[4]: array([1, 3])

こんな感じでしょうか