6

スパース行列の特定の条件を満たさない行をゼロに置き換える最良の方法は何だろうか。例(説明のために単純な配列を使用しています):

合計が 10 より大きいすべての行をゼロの行に置き換えたい

a = np.array([[0,0,0,1,1],
              [1,2,0,0,0],
              [6,7,4,1,0],  # sum > 10
              [0,1,1,0,1],
              [7,3,2,2,8],  # sum > 10 
              [0,1,0,1,2]])

a[2] と a[4] をゼロに置き換えたいので、出力は次のようになります。

array([[0, 0, 0, 1, 1],
       [1, 2, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 1, 1, 0, 1],
       [0, 0, 0, 0, 0],
       [0, 1, 0, 1, 2]])

これは、密な行列の場合はかなり簡単です。

row_sum = a.sum(axis=1)
to_keep = row_sum >= 10   
a[to_keep] = np.zeros(a.shape[1]) 

ただし、試してみると:

s = sparse.csr_matrix(a) 
s[to_keep, :] = np.zeros(a.shape[1])

次のエラーが表示されます。

    raise NotImplementedError("Fancy indexing in assignment not "
NotImplementedError: Fancy indexing in assignment not supported for csr matrices.

したがって、疎行列には別のソリューションが必要です。私はこれを思いついた:

def zero_out_unfit_rows(s_mat, limit_row_sum):
    row_sum = s_mat.sum(axis=1).T.A[0]
    to_keep = row_sum <= limit_row_sum
    to_keep = to_keep.astype('int8')
    temp_diag = get_sparse_diag_mat(to_keep)
    return temp_diag * s_mat

def get_sparse_diag_mat(my_diag):
    N = len(my_diag)
    my_diags = my_diag[np.newaxis, :]
    return sparse.dia_matrix((my_diags, [0]), shape=(N,N))

これは、単位行列の対角要素の 2 番目と 4 番目の要素をゼロに設定すると、事前に乗算された行列の行がゼロに設定されるという事実に依存しています。

しかし、より優れた、より科学的な解決策があると思います。より良い解決策はありますか?

4

1 に答える 1

4

それが非常にscithonicであるかどうかはわかりませんが、スパース行列に対する多くの操作は、guts に直接アクセスすることでより適切に実行できます。あなたの場合、私は個人的に次のようにします:

a = np.array([[0,0,0,1,1],
              [1,2,0,0,0],
              [6,7,4,1,0],  # sum > 10
              [0,1,1,0,1],
              [7,3,2,2,8],  # sum > 10 
              [0,1,0,1,2]])
sps_a = sps.csr_matrix(a)

# get sum of each row:
row_sum = np.add.reduceat(sps_a.data, sps_a.indptr[:-1])

# set values to zero
row_mask = row_sum > 10
nnz_per_row = np.diff(sps_a.indptr)
sps_a.data[np.repeat(row_mask, nnz_per_row)] = 0
# ask scipy.sparse to remove the zeroed entries
sps_a.eliminate_zeros()

>>> sps_a.toarray()
array([[0, 0, 0, 1, 1],
       [1, 2, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 1, 1, 0, 1],
       [0, 0, 0, 0, 0],
       [0, 1, 0, 1, 2]])
>>> sps_a.nnz # it does remove the entries, not simply set them to zero
10
于 2013-09-26T18:51:09.497 に答える