# sd_vid / loadmodel.py
# Uploaded by waveydaveygravy ("Upload 6 files", commit 78415c4).
#@title Load Model
import sys
from omegaconf import OmegaConf
import torch
sys.path.append("generative-models")
from sgm.util import default, instantiate_from_config
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
def load_model(
    config: str,
    device: str,
    num_frames: int,
    num_steps: int,
):
    """Instantiate a video model and a safety filter from an OmegaConf config.

    Loads the YAML at ``config``, overrides a few sampler/conditioner settings
    with the arguments below, builds the model on ``device`` and freezes it.

    Args:
        config: Path to the model's OmegaConf YAML file.
        device: Torch device string (e.g. ``"cuda"`` or ``"cpu"``). Used for
            model construction, for the OpenCLIP embedder's ``init_device``
            override, and for the safety filter.
        num_frames: Written to
            ``model.params.sampler_config.params.guider_config.params.num_frames``.
        num_steps: Written to ``model.params.sampler_config.params.num_steps``.

    Returns:
        A ``(model, filter)`` tuple: the instantiated model in eval mode with
        ``requires_grad`` disabled on all parameters, and a
        ``DeepFloydDataFiltering`` instance on ``device``.
    """
    # Keep the path argument and the loaded config in separate names.
    cfg = OmegaConf.load(config)

    # The first conditioner embedder carries an open_clip_embedding_config;
    # make it initialise on the target device. (Assumes the config layout of
    # the svd/svd_xt YAMLs -- TODO confirm for any other config.)
    cfg.model.params.conditioner_config.params.emb_models[
        0
    ].params.open_clip_embedding_config.params.init_device = device
    cfg.model.params.sampler_config.params.num_steps = num_steps
    cfg.model.params.sampler_config.params.guider_config.params.num_frames = (
        num_frames
    )

    # Under a torch.device context, tensor factory calls default to `device`,
    # so parameters are allocated there rather than built on CPU and copied.
    with torch.device(device):
        model = (
            instantiate_from_config(cfg.model)
            .to(device)
            .eval()
            .requires_grad_(False)
        )

    # Local name avoids shadowing the builtin `filter`; callers still receive
    # the same (model, filter) tuple.
    safety_filter = DeepFloydDataFiltering(verbose=False, device=device)
    return model, safety_filter
# Per-version sampling settings: (num_frames, num_steps, model config path).
_SVD_VERSIONS = {
    "svd": (14, 25, "generative-models/scripts/sampling/configs/svd.yaml"),
    "svd_xt": (25, 30, "generative-models/scripts/sampling/configs/svd_xt.yaml"),
}

# `version` is expected to be set by an earlier notebook cell; fall back to a
# default so this file also runs on its own instead of raising NameError.
version = globals().get("version", "svd_xt")

try:
    num_frames, num_steps, model_config = _SVD_VERSIONS[version]
except KeyError:
    raise ValueError(f"Version {version} does not exist.") from None

device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE: the module-level name `filter` shadows the builtin; it is kept because
# later cells presumably refer to the safety filter by this name.
model, filter = load_model(
    model_config,
    device,
    num_frames,
    num_steps,
)

# Move all models except the UNet (model.model) to CPU, so that only the UNet
# keeps occupying `device` memory.
model.conditioner.cpu()
model.first_stage_model.cpu()

# Store the UNet weights in half precision to halve their memory footprint.
# NOTE(review): on a CPU-only machine some fp16 ops may be unsupported -- confirm.
model.model.to(dtype=torch.float16)

# Release cached GPU blocks freed by the moves/casts above.
torch.cuda.empty_cache()

# Re-assert the frozen state after the dtype change (load_model already did this).
model = model.requires_grad_(False)