0

説明

A3C モデルで PTAN ライブラリを使用しており、wandb のスイープ (sweep) を使おうとしていますが、いくつか奇妙な問題に遭遇しました。これがスイープのバグなのか (スレッドが関与しない単純なモデルでは問題なく動作します)、それとも私が何か間違ったことをしているのかはわかりません。

再現方法

トレーニング関数:

def train(conf):
    """Run the central optimisation loop, consuming entries from ``train_queue``.

    Relies on module-level names defined elsewhere in the script: ``writer``,
    ``train_queue``, ``net``, ``optimizer``, ``device``, ``data_proc_list``,
    ``SAVING_FOLDER`` and ``PROJECT_NAME``.

    Queue entries are dispatched by type:

    * ``TotalReward`` -- forwarded to ``tracker.reward``; a truthy return
      value ends training.
    * ``TotalProfit`` -- forwarded to ``tracker.profits``.
    * anything else -- counted as one step and appended to the current batch.

    Once ``conf['batch_size']`` samples are collected, one optimiser step is
    performed on the sum of the value, policy and entropy losses.

    Args:
        conf: Mapping with the keys ``epochs``, ``reward_bound``,
            ``batch_size``, ``gamma``, ``reward_steps``, ``entropy_beta``
            and ``clip_grad``.

    The worker processes in ``data_proc_list`` are always terminated and
    joined on exit, whether the loop ends normally or by an exception.
    """
    batch = []
    step_idx = 0
    try:
        # Config lookups live inside the ``try`` so that a missing key still
        # triggers the worker clean-up in ``finally``.
        epoch = conf['epochs']
        batch_size = conf['batch_size']
        # Loop invariants, computed once instead of on every iteration.
        last_val_gamma = conf['gamma'] ** conf['reward_steps']
        checkpoint_path = os.path.join(SAVING_FOLDER, PROJECT_NAME)

        with commune.RewardTracker(writer, stop_reward=conf['reward_bound']) as tracker:
            with ptan.common.utils.TBMeanTracker(writer, batch_size=100) as tb_tracker:
                # ``epochs`` bounds the number of training samples consumed,
                # not the number of optimiser steps.
                while step_idx != epoch:
                    train_entry = train_queue.get()

                    # Bookkeeping entries do not count as training steps.
                    if isinstance(train_entry, TotalReward):
                        if tracker.reward(train_entry.reward, step_idx):
                            break
                        continue
                    if isinstance(train_entry, TotalProfit):
                        tracker.profits(train_entry.total_profit, train_entry.curr_profit, step_idx)
                        continue

                    step_idx += 1
                    if step_idx % 100 == 0:
                        # Periodic checkpoint; always overwrites the same file.
                        torch.save(net.state_dict(), checkpoint_path)

                    batch.append(train_entry)
                    if len(batch) < batch_size:
                        continue

                    states_v, actions_t, vals_ref_v = commune.unpack_batch(
                        batch, net,
                        last_val_gamma=last_val_gamma,
                        device=device)
                    batch.clear()

                    optimizer.zero_grad()
                    logits_v, value_v = net(states_v)

                    # Drop the trailing axis of the value head once, so the
                    # same tensor is used for the value loss and the advantage.
                    value_flat_v = value_v.squeeze(-1)
                    loss_value_v = F.mse_loss(value_flat_v, vals_ref_v)

                    log_prob_v = F.log_softmax(logits_v, dim=1)
                    # Subtracting the un-squeezed ``value_v`` would broadcast
                    # (B,) - (B, 1) into a (B, B) matrix and corrupt the
                    # policy loss; the squeezed tensor gives one advantage
                    # per sample. ``detach`` keeps policy gradients out of
                    # the value head.
                    adv_v = vals_ref_v - value_flat_v.detach()
                    log_prob_actions_v = adv_v * log_prob_v[range(batch_size), actions_t]
                    loss_policy_v = -log_prob_actions_v.mean()

                    prob_v = F.softmax(logits_v, dim=1)
                    # sum(p * log p) is the negative entropy, so minimising
                    # this term pushes the policy towards higher entropy.
                    entropy_loss_v = conf['entropy_beta'] * (prob_v * log_prob_v).sum(dim=1).mean()

                    loss_v = entropy_loss_v + loss_value_v + loss_policy_v
                    loss_v.backward()
                    nn_utils.clip_grad_norm_(net.parameters(), conf['clip_grad'])
                    optimizer.step()

                    tb_tracker.track("advantage", adv_v, step_idx)
                    tb_tracker.track("values", value_v, step_idx)
                    tb_tracker.track("batch_rewards", vals_ref_v, step_idx)
                    tb_tracker.track("loss_entropy", entropy_loss_v, step_idx)
                    tb_tracker.track("loss_policy", loss_policy_v, step_idx)
                    tb_tracker.track("loss_value", loss_value_v, step_idx)
                    tb_tracker.track("loss_total", loss_v, step_idx)
    finally:
        # Stop the data-gathering workers however the loop ended.
        for p in data_proc_list:
            p.terminate()
            p.join()

メイン関数:

# Script entry point: registers the sweep, builds the shared network and
# optimiser, starts the data-gathering worker processes and hands the
# training function to the wandb sweep agent.
if __name__ == "__main__":
    mp.set_start_method('fork')
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # NOTE(review): yaml.load with FullLoader is kept as in the original;
    # if the sweep file only holds plain mappings/scalars, yaml.safe_load
    # would be the stricter choice -- confirm before switching.
    with open(r'sweep_config.yaml') as file:
        sweep_config = yaml.load(file, Loader=yaml.FullLoader)

    logs_dir_name = "a3c_stock"
    wandb.tensorboard.patch(root_logdir=logs_dir_name)

    sweep_id = wandb.sweep(sweep_config, project="sweep_project", entity="vildnex")
    # NOTE(review): this run is initialised here, on the main thread, before
    # wandb.agent() is started, while the reported traceback shows the agent
    # invoking the callable from a separate thread (pyagent.py, _run_job).
    # The sweep documentation's examples call wandb.init() inside the function
    # given to the agent -- verify whether this ordering is what produces
    # "The wandb backend process has shutdown".
    wandb.init(config=config_default)

    config = wandb.config

    writer = SummaryWriter(comment=logs_dir_name)

    env = make_env(config)
    net = commune.AtariA2C(env.observation_space.shape, env.action_space.n).to(device)
    # The network's parameters are shared with the worker processes below.
    net.share_memory()

    # exist_ok avoids the check-then-create race and also creates any
    # missing parent directories.
    os.makedirs(SAVING_FOLDER, exist_ok=True)

    # Resume from an existing checkpoint if one is present.
    checkpoint_path = os.path.join(SAVING_FOLDER, PROJECT_NAME)
    if os.path.isfile(checkpoint_path):
        net.load_state_dict(torch.load(checkpoint_path, map_location=device))

    optimizer = optim.RMSprop(net.parameters(), lr=config.learning_rate, eps=1e-3)

    train_queue = mp.Queue(maxsize=config.processes_count)
    data_proc_list = []
    # Plain-dict copy of the config, passed to the workers and to train().
    dict_conf = dict(config)
    for _ in range(config.processes_count):
        data_proc = mp.Process(target=data_func, args=(net, device, train_queue, dict_conf))
        data_proc.start()
        data_proc_list.append(data_proc)

    wandb.agent(sweep_id, lambda: train(dict_conf))

エラーメッセージ:

Exception in thread Thread-6:
Traceback (most recent call last):
  File "<PATH>/venv/lib/python3.9/site-packages/wandb/agents/pyagent.py", line 303, in _run_job
    self._function()
  File "<PATH>/RL_TraningBot/EXPERIMENTS/A3C_TEST.py", line 191, in <lambda>
    wandb.agent(sweep_id, lambda: train(dict_conf))
  File "<PATH>/RL_TraningBot/EXPERIMENTS/A3C_TEST.py", line 105, in train
    tracker.profits(train_entry.total_profit, train_entry.curr_profit, step_idx)
  File "<PATH>/RL_TraningBot/EXPERIMENTS/commune.py", line 118, in profits
    self.writer.add_scalar("total_profit", total_profit, frame)
  File "<PATH>/venv/lib/python3.9/site-packages/torch/utils/tensorboard/writer.py", line 344, in add_scalar
    self._get_file_writer().add_summary(
  File "<PATH>/venv/lib/python3.9/site-packages/torch/utils/tensorboard/writer.py", line 250, in _get_file_writer
    self.file_writer = FileWriter(self.log_dir, self.max_queue,
  File "<PATH>/venv/lib/python3.9/site-packages/torch/utils/tensorboard/writer.py", line 60, in __init__
    self.event_writer = EventFileWriter(
  File "<PATH>/venv/lib/python3.9/site-packages/wandb/integration/tensorboard/monkeypatch.py", line 157, in __init__
    _notify_tensorboard_logdir(logdir, save=save, root_logdir=root_logdir_arg)
  File "<PATH>/venv/lib/python3.9/site-packages/wandb/integration/tensorboard/monkeypatch.py", line 167, in _notify_tensorboard_logdir
    wandb.run._tensorboard_callback(logdir, save=save, root_logdir=root_logdir)
  File "<PATH>/venv/lib/python3.9/site-packages/wandb/sdk/wandb_run.py", line 804, in _tensorboard_callback
    self._backend.interface.publish_tbdata(logdir, save, root_logdir)
  File "<PATH>/venv/lib/python3.9/site-packages/wandb/sdk/interface/interface.py", line 202, in publish_tbdata
    self._publish(rec)
  File "<PATH>/venv/lib/python3.9/site-packages/wandb/sdk/interface/interface.py", line 518, in _publish
    raise Exception("The wandb backend process has shutdown")
Exception: The wandb backend process has shutdown

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/lib/python3.9/threading.py", line 954, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.9/threading.py", line 892, in run
    self._target(*self._args, **self._kwargs)
  File "<PATH>/venv/lib/python3.9/site-packages/wandb/agents/pyagent.py", line 308, in _run_job
    wandb.finish(exit_code=1)
  File "<PATH>/venv/lib/python3.9/site-packages/wandb/sdk/wandb_run.py", line 2374, in finish
    wandb.run.finish(exit_code=exit_code)
  File "<PATH>/venv/lib/python3.9/site-packages/wandb/sdk/wandb_run.py", line 1144, in finish
    if self._wl and len(self._wl._global_run_stack) > 0:
  File "<PATH>/venv/lib/python3.9/site-packages/wandb/sdk/wandb_setup.py", line 234, in __getattr__
    return getattr(self._instance, name)
AttributeError: 'NoneType' object has no attribute '_global_run_stack'

環境

  • OS: Manjaro 5.21.5
  • 環境: PyCharm ローカル
  • Python バージョン: 3.9
4

0 に答える 0