1

私は最近、https: //arxiv.org/abs/2102.07501 で説明されているように、Annealed Flow Transport Method の実装に取り​​組んでいます。ある時点でのタスクは、SGD を使用して正規化フローを学習することにより、特定の損失関数を最小化することです。この問題がもたらすいくつかのトピックについて多くの論文を研究しましたが、アイデアを結び付ける方法がわかりません。だから、ここに問題があります:

分布 p のサンプル (x_1,...,x_N) が与えられたとします。ここで、(T(x_1),...,T(x_N)) がターゲット分布 q の適切なサンプルになるように、各粒子を輸送する正規化フロー T を学習したいと考えています。前述のソースで説明されているように、これは T(p) と q のカルバック ライブラー ダイバージェンスを最小化することによって行われます。結果として得られる損失関数 (最小化したい関数) には、L または L(T) というラベルが付けられます。

著者はアルゴリズムを非常に詳細に説明していますが、この時点では「L を最小化するために SGD を使用して T を学習する」とだけ述べています。

私の意図は、TensorFlow と Keras を使用し、L をカスタム損失関数として使用し、著者が示唆するように Adam オプティマイザを使用することでしたが、現状では、ここに私のコードがあります:

def LearnFlow_Test(train_iters, x_train, W_train, x_val, W_val):
    
    # Initialize
    
    identity = lambda x: x # Initialize flow
    flows = np.array(identity)
    
    y_true = np.array([f_target(identity(x)) for x in x_val])
    y_pred = np.array([f_initial(x)/jacobian_det(identity,x) for x in x_val])
        
    val_losses = loss_function(y_true, y_pred)
    
    # Learn
    
    for j in range(train_iters):
        
        # Compute training loss
        
        y_true = np.array([f_target(identity(x)) for x in x_train])
        y_pred = np.array([f_initial(x)/jacobian_det(identity,x) for x in x_train])
        
        train_loss = loss_function(y_true, y_pred)
        
        """        
        Update flow using SGD to minimize train_loss
        minimizing_flow =
        
        """         
        
        # Update list of flows & list of validation losses
        
        flows = np.append(flows, minimizing_flow)
        
        # Compute new validation loss and update the list
        
        y_true = np.array([f_target(minimizing_flow(x)) for x in x_val])
        y_pred = np.array([f_initial(x)/jacobian_det(minimizing_flow,x) for x in x_val])
        
        val_losses = np.append(val_losses,[loss_function(y_true, y_pred)])a
        
        
        
    return flows[np.argmin(val_losses)] # Return flow with the smallest validation error 

既存のコードの検索がうまくいかなかったので、アドバイスをいただければ幸いです。

どうもありがとう、クリスチャン

4

0 に答える 0