import matplotlib.pyplot as plt
import numpy as np
data = np.array([1,1,-1,-1,1])
cmap = np.array([(1,0,0), (0,1,0)])
uniqdata, idx = np.unique(data, return_inverse=True)
N = len(data)
fig, ax = plt.subplots()
plt.scatter(np.zeros(N), np.arange(1, N+1), s=100, c=cmap[idx])
plt.grid()
plt.show()
収量
説明:
を出力するnp.unique(data, return_inverse=True)
と、配列のタプルが返されることがわかります。
In [71]: np.unique(data, return_inverse=True)
Out[71]: (array([-1, 1]), array([1, 1, 0, 0, 1]))
最初の配列は、 の一意の値data
が -1 と 1 であることを示しています。2 番目の配列は、-1 の場合は値 0 を割り当て、 1 の場合はdata
値 1 を割り当てdata
ます。基本的に、np.unique
に変換でき[1,1,-1,-1,1]
ます[1, 1, 0, 0, 1]
。これcmap[idx]
は RGB 値の配列です。
In [74]: cmap[idx]
Out[74]:
array([[0, 1, 0],
[0, 1, 0],
[1, 0, 0],
[1, 0, 0],
[0, 1, 0]])
これは、NumPy 配列に対するいわゆる「ファンシー インデックス」のアプリケーションです。cmap[0]
の最初の行ですcmap
。cmap[1]
の 2 行目ですcmap
。cmap[idx]
の i 番目の要素が であるような配列cmap[idx]
ですcmap[idx[i]]
。したがって、cmap[idx]
i 番目の行が である 2D 配列になることになりcmap[idx[i]]
ます。したがってcmap[idx]
、RGB カラー値のシーケンスと考えることができます。
複数のドットのセットがあり、それらを列にプロットしたい場合、私が考えることができる最も簡単な方法は、のax.scatter
リストごとに1回呼び出すことですdata
:
import matplotlib.pyplot as plt
import numpy as np
def plot_data(ax, data, xval):
N = len(data)
uniqdata, idx = np.unique(data, return_inverse=True)
ax.scatter(np.ones(N)*xval, np.arange(1, N+1), s=100, c=cmap[idx])
cmap = np.array([(1,0,0), (0,1,0)])
fig, ax = plt.subplots()
data = np.array([1,1,-1,-1,1])
data2 = np.array([1,-1,1,1,-1])
plot_data(ax, data, 0)
plot_data(ax, data2, 1)
plt.grid()
plt.show()
これの良いところは、比較的理解しやすいことです。これの悪い点は、ax.scatter
複数回呼び出すことです。大量のデータ セットがある場合は、データを照合して1 回呼び出すax.scatter
方が効率的です。これは Matplotlib の方が高速ですが、コーディングが少し複雑になります。
import matplotlib.pyplot as plt
import numpy as np
import itertools as IT
def plot_dots(ax, datasets):
N = sum(len(data) for data in datasets)
x = np.fromiter(
(i for i, data in enumerate(datasets) for j in np.arange(len(data))),
dtype='float', count=N)
y = np.fromiter(
(j for data in datasets for j in np.arange(1, len(data)+1)),
dtype='float', count=N)
c = np.fromiter(
(val for data in datasets
for rgb in cmap[np.unique(data, return_inverse=True)[-1]]
for val in rgb),
dtype='float', count=3*N).reshape(-1,3)
ax.scatter(x, y, s=100, c=c)
cmap = np.array([(1,0,0), (0,1,0)])
fig, ax = plt.subplots()
N = 100
datasets = [np.random.randint(2, size=5) for i in range(N)]
plot_dots(ax, datasets)
plt.grid()
plt.show()
参考文献: