Bridge MLP: Wan2.2 latents → V-JEPA-2.1 features

Lightweight MLP head that maps Wan2.2-T2V-A14B VAE latents into the V-JEPA-2.1 ViT-G/384 encoder feature space. Used by the gradient-guidance refinement mode (ext3_grad / vjepa_grad) in the self-refine pipeline.

  • Base model: Wan2.2-T2V-A14B (Wan-AI/Wan2.2-T2V-A14B-Diffusers)
  • VAE latent: 16 channels, resolution 81×480×832
  • Backbone teacher: V-JEPA-2.1 ViT-Gigantic-384 (hierarchical, 8-layer concat)
  • Training data: OpenVid (~64k clips, Wan2.2 VAE latents cached at cache/openvid_wan22_64k)
  • Loss: L1 + relational (alpha=0.5, beta=0.1, rel_subset=256)
  • Output: 6656-d features on the V-JEPA token grid (24, 24, 24)
  • Selected by: EMA-of-loss minimum (step 2490, EMA decay 0.98)
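
The "EMA-of-loss minimum" selection rule can be illustrated with a minimal sketch. It assumes the EMA is seeded with the first loss value and uses no bias correction; the card does not say how the EMA was initialised or how often the loss was sampled. The function name `select_step_by_ema` and the toy loss values are hypothetical.

```python
def select_step_by_ema(losses, decay=0.98):
    """Return (step, ema) at which the exponential moving average of the loss is lowest.

    Assumption: the EMA starts at the first loss value (no bias correction).
    Steps are 1-indexed.
    """
    ema = None
    best_step, best_ema = None, float("inf")
    for step, loss in enumerate(losses, start=1):
        # Standard EMA update: keep `decay` of the old value, blend in the new loss.
        ema = loss if ema is None else decay * ema + (1.0 - decay) * loss
        if ema < best_ema:
            best_step, best_ema = step, ema
    return best_step, best_ema


# Toy loss curve with one noisy dip at step 3 (hypothetical values).
toy_losses = [1.0, 0.9, 0.2, 0.9, 0.8, 0.7, 0.6]
best_step, best_ema = select_step_by_ema(toy_losses, decay=0.9)
raw_min_step = toy_losses.index(min(toy_losses)) + 1
print(best_step, raw_min_step)  # the smoothed minimum ignores the one-off dip
```

Selecting on the smoothed curve rather than the raw minimum makes the choice less sensitive to a single lucky batch, which is presumably why the card reports an EMA decay alongside the selected step.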

Usage

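import torch
from bridge.models.bridge_net import BridgeMLP
from huggingface_hub import hf_hub_download

# Download the checkpoint from the Hub (cached locally after the first call).
ckpt = hf_hub_download(repo_id="itruonghai/wan22-vjepa-bridge",
                       filename="bridge_mlp_best.pt")
state = torch.load(ckpt, map_location="cpu")

# The checkpoint stores the constructor arguments next to the weights.
bridge = BridgeMLP(**state["config"])
bridge.load_state_dict(state["model"])
bridge.eval()  # inference mode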

In the self-refine pipeline this is wired up automatically: pass --bridge_ckpt (or rely on the default path).
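
To give a rough idea of what gradient guidance through a bridge looks like, the sketch below backpropagates a feature-space loss through a frozen MLP to obtain a gradient with respect to the latents. It is not the pipeline's ext3_grad / vjepa_grad implementation, which this card does not show. `TinyBridge`, the toy dimensions, the plain L1 objective, and the step size are all hypothetical stand-ins.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyBridge(nn.Module):
    """Hypothetical stand-in for BridgeMLP with toy dimensions."""

    def __init__(self, in_dim=16, hidden=64, out_dim=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.GELU(), nn.Linear(hidden, out_dim)
        )

    def forward(self, x):
        return self.net(x)


def guidance_grad(bridge, latent_tokens, target_feats):
    """Return (loss, d loss / d latents) for an L1 feature-matching loss."""
    # Treat the latents as the only differentiable input.
    latent_tokens = latent_tokens.detach().requires_grad_(True)
    pred = bridge(latent_tokens)
    loss = F.l1_loss(pred, target_feats)
    (grad,) = torch.autograd.grad(loss, latent_tokens)
    return loss.detach(), grad


torch.manual_seed(0)
bridge = TinyBridge().eval()
for p in bridge.parameters():
    p.requires_grad_(False)  # the bridge stays frozen during guidance

latents = torch.randn(8, 16)   # 8 toy tokens with 16 latent channels
target = torch.randn(8, 32)    # toy target features (assumed given)

loss, grad = guidance_grad(bridge, latents, target)
refined = latents - 0.1 * grad  # one illustrative gradient step on the latents
```

Because the bridge operates directly on VAE latents, a step like this needs no VAE decode or forward pass through the V-JEPA encoder; that is the usual motivation for a lightweight bridge head.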
