2

Cython から clapack ルーチン dgelsy_ を呼び出す次のコードを作成しましたが、最小二乗問題の正しい解が得られません。

cimport numpy as np
import numpy as np
ctypedef np.float64_t NP_FLOAT_t
ctypedef np.int_t NP_INT_t
ctypedef np.uint8_t NP_BOOL_t    
ctypedef int integer

cdef extern from "clapack.h":
    integer dgelsy_(integer *m, integer *n, integer *nrhs, 
    double *a, integer *lda, double *b, integer *ldb, integer *
    jpvt, double *rcond, integer *rank, double *work, integer *
    lwork, integer *info)

cpdef dgelsy(np.ndarray[NP_FLOAT_t,ndim=2] A, np.ndarray[NP_FLOAT_t,ndim=1] b, np.ndarray[NP_INT_t,ndim=1] jpvt):
    cdef integer m = A.shape[0]
    cdef integer n = A.shape[1]
    cdef integer nrhs = 1
    cdef integer lda = m
    cdef integer ldb = m
    cdef integer rank
    cdef NP_FLOAT_t rcond = 1e-16
    cdef integer lwork = -1
    cdef integer info

    #First call as a workspace query
    cdef np.ndarray[NP_FLOAT_t, ndim=1] work1 = np.empty(shape=1,dtype=np.float)
    dgelsy_(&m, &n, &nrhs, <double*>A.data, &lda, <double*>b.data, &ldb, 
            <integer*>jpvt.data, &rcond, &rank, <double*>work1.data, &lwork, &info)

    #Now the actual call to solve the problem
    lwork = <integer>work1[0]
    cdef np.ndarray[NP_FLOAT_t, ndim=1] work2 = np.empty(shape=lwork,dtype=np.float)
    dgelsy_(&m, &n, &nrhs, <double*>A.data, &lda, <double*>b.data, &ldb, 
            <integer*>jpvt.data, &rcond, &rank, <double*>work2.data, &lwork, &info)
    return rank, info

私の setup.py ファイルは正しいと思います。私のコードはコンパイル、リンク、および実行されますが、コンパイル時の警告が表示され、得られたソリューションは正しくありません。ここに私のPythonテストコードがあります:

import numpy
import cylapack #cylapack is my cython module with the code above
numpy.random.seed(1)
A = numpy.random.normal(size=(100,10))
A_ = A.copy()
x = numpy.random.normal(size=10)
b = numpy.dot(A,x) + numpy.random.normal(size=100)
b_ = b.copy()
pivots = numpy.zeros(shape=10,dtype=numpy.int)

print cylapack.dgelsy(A,b,pivots)
print pivots
x_ = numpy.linalg.lstsq(A_,b_,1e-16)[0]
print numpy.sum((numpy.dot(A_,x_) - b_)**2)
print numpy.sum((numpy.dot(A_,b[0:10]) - b_)**2)

以下を出力します。

(10, 0)
[25769803780 12884901896 30064771077 38654705666  4294967306           0
           0           0           0           0]
99.8269537854
1087.62032064

最後の 2 つの数値は、それぞれ numpy および lapack ソリューションの残差二乗和です。どちらも同じはずですが、明らかに lapack ソリューションは実際には正しくありません。コンパイラの警告は次のとおりです。

cylapack.c:1424: warning: passing argument 1 of 'dgelsy_' from incompatible pointer type
cylapack.c:1424: warning: passing argument 2 of 'dgelsy_' from incompatible pointer type
cylapack.c:1424: warning: passing argument 3 of 'dgelsy_' from incompatible pointer type
cylapack.c:1424: warning: passing argument 5 of 'dgelsy_' from incompatible pointer type
cylapack.c:1424: warning: passing argument 7 of 'dgelsy_' from incompatible pointer type
cylapack.c:1424: warning: passing argument 8 of 'dgelsy_' from incompatible pointer type
cylapack.c:1424: warning: passing argument 10 of 'dgelsy_' from incompatible pointer type
cylapack.c:1424: warning: passing argument 12 of 'dgelsy_' from incompatible pointer type
cylapack.c:1424: warning: passing argument 13 of 'dgelsy_' from incompatible pointer type
cylapack.c:1495: warning: passing argument 1 of 'dgelsy_' from incompatible pointer type
cylapack.c:1495: warning: passing argument 2 of 'dgelsy_' from incompatible pointer type
cylapack.c:1495: warning: passing argument 3 of 'dgelsy_' from incompatible pointer type
cylapack.c:1495: warning: passing argument 5 of 'dgelsy_' from incompatible pointer type
cylapack.c:1495: warning: passing argument 7 of 'dgelsy_' from incompatible pointer type
cylapack.c:1495: warning: passing argument 8 of 'dgelsy_' from incompatible pointer type
cylapack.c:1495: warning: passing argument 10 of 'dgelsy_' from incompatible pointer type
cylapack.c:1495: warning: passing argument 12 of 'dgelsy_' from incompatible pointer type
cylapack.c:1495: warning: passing argument 13 of 'dgelsy_' from incompatible pointer type

明らかに、コンパイラはすべての整数ポインターについて不平を言っています (変更なしで代わりに long を使用してみました)。私が理解していない基本的な何かがあると思います。誰かが私が間違っているかもしれないことを教えてもらえますか?

4

1 に答える 1

3

私自身の質問に答えるつもりはありませんでしたが、今では理解できました。問題は、lapack が行列を Fortran スタイルの列優先順で期待しているのに対し、numpy はデフォルトで C スタイルの行優先順を使用することです。テストコードで次の行を変更した場合:

A = numpy.random.normal(size=(100,10))

これに:

A = numpy.random.normal(size=(10,100)).transpose()

その後、正常に動作します。ただし、コンパイラの警告やピボットの値はまだわかりませんが、問題の正しい解決策とは無関係のようです。

于 2013-02-14T19:09:00.290 に答える