2

Scipy を使用して Pandas DataFrame 列に適合させる最良の方法を知りたいです。AZ が A、B、C、および D に依存する列 ( 、BCDおよび)を持つデータ テーブル (Pandas DataFrame) がある場合Z_real、Z ( Z_pred)。

適合する各関数のシグネチャは

func(series, param_1, param_2...)

series は、DataFrame の各行に対応する Pandas シリーズです。さまざまな関数がさまざまな列の組み合わせを使用できるように、Pandas シリーズを使用します。

DataFrame をscipy.optimize.curve_fit使用して渡そうとしました

curve_fit(func, table, table.loc[:, 'Z_real'])

しかし、何らかの理由で、各 func インスタンスには、各行の Series ではなく、データテーブル全体が最初の引数として渡されます。また、DataFrame を Series オブジェクトのリストに変換しようとしましたが、これにより関数に Numpy 配列が渡されます (Scipy が Series のリストから Pandas を保持しない Numpy 配列への変換を実行するためだと思います)。シリーズ オブジェクト)。

4

1 に答える 1

5

Your call to curve_fit is incorrect. From the documentation:

xdata : An M-length sequence or an (k,M)-shaped array for functions with k predictors.

The independent variable where the data is measured.

ydata : M-length sequence

The dependent data — nominally f(xdata, ...)

In this case your independent variables xdata are the columns A to D, i.e. table[['A', 'B', 'C', 'D']], and your dependent variable ydata is table['Z_real'].

Also note that xdata should be a (k, M) array, where k is the number of predictor variables (i.e. columns) and M is the number of observations (i.e. rows). You should therefore transpose your input dataframe so that it is (4, M) rather than (M, 4), i.e. table[['A', 'B', 'C', 'D']].T.

The whole call to curve_fit might look something like this:

curve_fit(func, table[['A', 'B', 'C', 'D']].T, table['Z_real'])

Here's a complete example showing multiple linear regression:

import numpy as np
import pandas as pd
from scipy.optimize import curve_fit

X = np.random.randn(100, 4)     # independent variables
m = np.random.randn(4)          # known coefficients
y = X.dot(m)                    # dependent variable

df = pd.DataFrame(np.hstack((X, y[:, None])),
                  columns=['A', 'B', 'C', 'D', 'Z_real'])

def func(X, *params):
    return np.hstack(params).dot(X)

popt, pcov = curve_fit(func, df[['A', 'B', 'C', 'D']].T, df['Z_real'],
                       p0=np.random.randn(4))

print(np.allclose(popt, m))
# True
于 2016-02-05T22:00:17.350 に答える