2

私は Jax を学んでいますが、奇妙な質問に遭遇しました。次のようにコードを使用すると、

import numpy as np
import jax.numpy as jnp
from jax import grad, value_and_grad
from jax import vmap # for auto-vectorizing functions
from functools import partial # for use with vmap
from jax import jit # for compiling functions for speedup
from jax import random # stax initialization uses jax.random
from jax.experimental import stax # neural network library
from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax # neural network layers
import matplotlib.pyplot as plt # visualization

net_init, net_apply = stax.serial(
    Dense(40), Relu,
    Dense(40), Relu,
    Dense(40), Relu,
    Dense(1)
)
rng = random.PRNGKey(0)
in_shape = (-1, 1,)
out_shape, params = net_init(rng, in_shape)

def loss(params, X, Y):
    predictions = net_apply(params, X)
    return jnp.mean((Y - predictions)**2)

@jit
def step(i, opt_state, x1, y1):
    p = get_params(opt_state)
    val, g = value_and_grad(loss)(p, x1, y1)
    return val, opt_update(i, g, opt_state)

opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3)
opt_state = opt_init(params)

val_his = []
for i in range(1000):
    val, opt_state = step(i, opt_state, xrange_inputs, targets)
    val_his.append(val)
params = get_params(opt_state)
val_his = jnp.array(val_his)

xrange_inputs = jnp.linspace(-5,5,100).reshape((100, 1)) # (k, 1)
targets = jnp.cos(xrange_inputs)
predictions = vmap(partial(net_apply, params))(xrange_inputs)
losses = vmap(partial(loss, params))(xrange_inputs, targets) # per-input loss

plt.plot(xrange_inputs, predictions, label='prediction')
plt.plot(xrange_inputs, losses, label='loss')
plt.plot(xrange_inputs, targets, label='target')
plt.legend()

ニューラル ネットワークは関数を適切に近似できcos(x)ます。

しかし、ニューラルネットワークの部分を自分で次のように書き直すと

import numpy as np
import jax.numpy as jnp
from jax import grad, value_and_grad
from jax import vmap # for auto-vectorizing functions
from functools import partial # for use with vmap
from jax import jit # for compiling functions for speedup
from jax import random # stax initialization uses jax.random
from jax.experimental import stax # neural network library
from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax # neural network layers
import matplotlib.pyplot as plt # visualization
import numpy as np
from jax.experimental import optimizers
from jax.tree_util import tree_multimap

def initialize_NN(layers, key):        
    params = []
    num_layers = len(layers)
    keys = random.split(key, len(layers))
    a = jnp.sqrt(0.1)
    #params.append(a)
    for l in range(0, num_layers-1):
        W = xavier_init((layers[l], layers[l+1]), keys[l])
        b = jnp.zeros((layers[l+1],), dtype=np.float32)
        params.append((W,b))
    return params

def xavier_init(size, key):
    in_dim = size[0]
    out_dim = size[1]      
    xavier_stddev = jnp.sqrt(2/(in_dim + out_dim))
    return random.truncated_normal(key, -2, 2, shape=(out_dim, in_dim), dtype=np.float32)*xavier_stddev
    
def net_apply(params, X):
    num_layers = len(params)
    #a = params[0]
    for l in range(0, num_layers-1):
        W, b = params[l]
        X = jnp.maximum(0, jnp.add(jnp.dot(X, W.T), b))
    W, b = params[-1]
    Y = jnp.dot(X, W.T)+ b
    Y = jnp.squeeze(Y)
    return Y
    
def loss(params, X, Y):
    predictions = net_apply(params, X)
    return jnp.mean((Y - predictions)**2)

key = random.PRNGKey(1)
layers = [1,40,40,40,1]
params = initialize_NN(layers, key)

@jit
def step(i, opt_state, x1, y1):
    p = get_params(opt_state)
    val, g = value_and_grad(loss)(p, x1, y1)
    return val, opt_update(i, g, opt_state)

opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3)
opt_state = opt_init(params)

xrange_inputs = jnp.linspace(-5,5,100).reshape((100, 1)) # (k, 1)
targets = jnp.cos(xrange_inputs)

val_his = []
for i in range(1000):
    val, opt_state = step(i, opt_state, xrange_inputs, targets)
    val_his.append(val)
params = get_params(opt_state)
val_his = jnp.array(val_his)

predictions = vmap(partial(net_apply, params))(xrange_inputs)
losses = vmap(partial(loss, params))(xrange_inputs, targets) # per-input loss

plt.plot(xrange_inputs, predictions, label='prediction')
plt.plot(xrange_inputs, losses, label='loss')
plt.plot(xrange_inputs, targets, label='target')
plt.legend()

私のニューラル ネットワークは常に定数に収束しますが、これは極小値によってトラップされているようです。しかし、最初の部分と同じニューラル ネットワークがうまく機能します。私はそれについて本当に混乱しています。

唯一の違いは、初期化、ニューラル ネットワーク部分、およびパラメーターの設定ですparams。私は別の初期化を試みましたが、違いはありません。最適化の設定paramsが間違っているのか、収束に至りません。

4

1 に答える 1

1

質問者はすでにこの問題を自分で解決しているようです。ただし、この質問者が直面したまったく同じ問題に直面したため、実際に何が起こったのかを説明したいと思います。

実際、質問者が を削除する前のニューラル ネットワークのぎこちない動作はY = jnp.squeeze(Y)、 の形状Ypredictionsの関数定義では、loss(params, X, Y)実際には異なる形状を持っていました:predictionsは列ベクトル (サイズ(N, 1)) であり、「スクイーズ」操作の後は、次のように修正されました。質問者自身Yは、行ベクトル (サイズ(1, N)) になります。

NumPy と JAX の NumPy (実際には MATLAB にもあります) には、配列操作のブロードキャストと呼ばれる機能があります。この機能により、インタープリターは次のような計算を行います。

\begin{pmatrix}
a_{1}\\
a_{2}\\
\vdots \\
a_{m}
\end{pmatrix} -\begin{pmatrix}
b_{1} & b_{2} & \cdots  & b_{n}
\end{pmatrix} =\begin{pmatrix}
a_{1} -b_{1} & a_{1} -b_{2} & \cdots  & a_{1} -b_{n}\\
a_{2} -b_{1} & a_{2} -b_{2} & \cdots  & a_{2} -b_{n}\\
\vdots  & \vdots  & \ddots  & \vdots \\
a_{m} -b_{1} & a_{m} -b_{2} & \cdots  & a_{m} -b_{n}
\end{pmatrix}

(この式は LaTeX によって解釈される必要があります)

したがって、質問者自身の修正の前に、Y - 予測は実際には shape の行列であり、この N*N 行列のすべてのエントリ(N, N)np.means()平均しましたが、これはもちろん、計算したい目的の MSELoss ではなく、奇妙な収束動作を引き起こしました。アスカーが示した。

于 2022-01-22T09:42:13.143 に答える