1

wandb のドキュメントには、これを行う方法が説明されていないようですが、かなり一般的な使用例ではないでしょうか?

私はこのように望んでいたことをほとんど(完全ではありませんが)達成しましたが、少し不格好に見えますか?インスタンスself.aliasesにプロパティがあると予想していましたか?ArtifactCollection

ENTITY = os.environ.get("WANDB_ENTITY")
API_KEY = os.environ.get("WANDB_API_KEY")

def get_model_artifacts(key=None):
    wandb.login(key=key if key is not None else API_KEY)
    api = wandb.Api(overrides={"entity": ENTITY})
    model_names = [
        i
        for i in api.artifact_type(
            type_name="models", project="train"
        ).collections()
    ]
    for model in model_names:
        artifact = api.artifact("train/" + model.name + ":latest")
        model._attrs.update(artifact._attrs)
        model._attrs["metadata"] = json.loads(model._attrs["metadata"])
        model.aliases = [x["alias"] for x in model._attrs["aliases"]]
    return model_names

必要に応じてカスタムのgraph-qlクエリを作成するか、この不格好な方法を使用することを検討できると思います。

何か不足していますか?これを行うためのよりクリーンな方法はありますか?

この不格好な方法に欠けていることの1つは、古いエイリアスです-最新のモデルとそのエイリアスのみを表示します(「最新」と「v4」などとしましょう)-これがどのように表示されるか/すべきかはわかりませんが、古いエイリアス (つまり、アーティファクトの古いバージョンを指すエイリアス) も取得できることを望んでいました。ただし、これはそれほど重要ではありません。

編集- 彼らのSDKコードを数時間調べた後、私はこれを持っています(まだそれがいかに不格好であるかに満足していません):

ENTITY = os.environ.get("WANDB_ENTITY")
API_KEY = os.environ.get("WANDB_API_KEY")

def get_model_artifacts(key=None):
    wandb.login(key=key if key is not None else API_KEY)
    api = wandb.Api(overrides={"entity": ENTITY})
    model_artifacts = [
        a
        for a in api.artifact_type(
            type_name="models", project="train"
        ).collections()
    ]

    def get_alias_tuple(artifact_version):
        version = None
        aliases = []
        for a in artifact_version._attrs["aliases"]:
            if re.match(r"^v\d+$", a["alias"]):
                version = a["alias"]
            else:
                aliases.append(a["alias"])
        return version, aliases

    for model in model_artifacts:
        # artifact = api.artifact("train/" + model.name + ":latest")
        # model._attrs.update(artifact._attrs)
        # model._attrs["metadata"] = json.loads(model._attrs["metadata"])
        versions = model.versions()
        version_dict = dict(get_alias_tuple(version) for version in versions)
        model.version_dict = version_dict
        model.aliases = [
            x for key, val in model.version_dict.items() for x in [key] + val
        ]
    return model_artifacts
4

1 に答える 1