# LGM / app.py
# multimodalart's picture
# multimodalart HF Staff
# [Admin maintenance] Support new ZeroGPU hardware (#7)
# 0c0924e
import os
import sys
import shlex
import subprocess
import tempfile
import ctypes
import spaces
import torch
# ---------------------------------------------------------------------------
# Blackwell (sm_120) shim: xformers' FA3/FA2/CutlassF kernels in the prebuilt
# torch-2.10/2.11 wheel all reject compute capability 12.0, so any call into
# `xformers.ops.memory_efficient_attention(...)` raises NotImplementedError.
# The LGM stack calls MEA directly in two places:
# - core/attention.py: 4D shape (B, M, H, K) (the dino-style Attention class)
# - mvdream/mv_unet.py: 3D shape (B*H, M, K) (the cross-attention block)
# Route both through torch SDPA. Must be installed BEFORE the imports that
# pull in core/attention.py and mvdream/mv_unet.py.
# ---------------------------------------------------------------------------
import xformers
import xformers.ops as _xops
def _xformers_mea_sdpa(query, key, value, attn_bias=None, p=0.0, scale=None,
op=None, **kwargs):
if query.dim() == 3:
# (B, M, K) -> single-head; add an H=1 axis.
q = query.unsqueeze(1)
k = key.unsqueeze(1)
v = value.unsqueeze(1)
squeeze_out = True
else:
# (B, M, H, K) -> (B, H, M, K)
q = query.transpose(1, 2)
k = key.transpose(1, 2)
v = value.transpose(1, 2)
squeeze_out = False
attn_mask = attn_bias
if hasattr(attn_mask, "materialize"):
try:
attn_mask = attn_mask.materialize(
shape=(q.shape[0], q.shape[1], q.shape[2], k.shape[2]),
dtype=q.dtype,
device=q.device,
)
except Exception:
attn_mask = None
out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, dropout_p=p, scale=scale,
)
if squeeze_out:
return out.squeeze(1)
return out.transpose(1, 2)
# Install the SDPA shim on the xformers ops module, through both the alias and
# the fully-qualified name, so later attribute lookups of
# `memory_efficient_attention` resolve to the replacement.
setattr(_xops, "memory_efficient_attention", _xformers_mea_sdpa)
setattr(xformers.ops, "memory_efficient_attention", _xformers_mea_sdpa)
# ---------------------------------------------------------------------------
# Build `diff_gaussian_rasterization` from source against torch 2.10/2.11/cu128
# on the first GPU call. The vendored wheel/diff_gaussian_rasterization-...whl
# in this Space was built against torch 2.4 (`libcudart.so.11.0`) and won't
# load on Blackwell. We build the vendored `diff-gaussian-rasterization/`
# source tree from this repo instead (the fork core/gs.py is written against).
# ---------------------------------------------------------------------------
# Root of the CUDA toolkit used for the extension build: exported to the build
# subprocess as CUDA_HOME / CUDA_PATH, and its `bin/` is prepended to PATH.
# NOTE(review): the section comment above mentions cu128 while this path points
# at a CUDA 13.0 toolkit -- confirm which one is intended.
CUDA_HOME = "/cuda-image/usr/local/cuda-13.0"
# Library directory of that toolkit; `libcudart.so.13` is loaded from here and
# the directory is prepended to LD_LIBRARY_PATH after the build step.
CUDA_LIBDIR = os.path.join(CUDA_HOME, "lib64")
@spaces.GPU(duration=600)
def _first_gpu_setup():
    """Build and install ``diff_gaussian_rasterization`` if it cannot be imported.

    Returns immediately when the extension already imports. Otherwise it
    pip-installs the build tooling (setuptools, wheel, ninja, packaging) and
    then pip-installs the vendored ``diff-gaussian-rasterization/`` directory
    located next to this file, using ``--no-build-isolation --no-deps``.

    The build subprocess runs with CUDA_HOME / CUDA_PATH / PATH pointing at
    ``CUDA_HOME``, ``TORCH_CUDA_ARCH_LIST=12.0``, and a temporary
    ``sitecustomize.py`` on PYTHONPATH that replaces
    ``torch.utils.cpp_extension._check_cuda_version`` with a no-op
    (presumably to get past the toolkit/torch CUDA version comparison --
    confirm). The temporary directory is removed once the build finishes.

    Raises:
        subprocess.CalledProcessError: If either ``pip install`` fails.
    """
    try:
        import diff_gaussian_rasterization  # noqa: F401
    except ImportError:
        pass
    else:
        return

    def _prepend(entry, existing):
        # Skip empty parts so an unset variable does not leave a trailing
        # separator (an empty search-path entry, i.e. the current directory).
        return os.pathsep.join(part for part in (entry, existing) if part)

    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "--no-deps",
         "setuptools", "wheel", "ninja", "packaging"],
    )

    # Build the vendored `diff-gaussian-rasterization/` source in this repo.
    # This Space ships a custom fork that returns 4 outputs (image, radii,
    # depth, alpha); the upstream graphdeco-inria release returns only 2 and
    # would crash `core/gs.py` at unpack time. The vendored tree is what
    # LGM's training/inference code was written against.
    vendored_dgr = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                "diff-gaussian-rasterization")

    # The patch directory is only needed while the build subprocess runs, so
    # scope it to a context manager instead of leaking a mkdtemp() directory.
    with tempfile.TemporaryDirectory(prefix="torch_cuda_patch_") as patch_dir:
        with open(os.path.join(patch_dir, "sitecustomize.py"), "w") as f:
            f.write(
                "try:\n"
                " import torch.utils.cpp_extension as _c\n"
                " _c._check_cuda_version = lambda *a, **k: None\n"
                "except Exception:\n"
                " pass\n"
            )

        env = os.environ.copy()
        env["CUDA_HOME"] = CUDA_HOME
        env["CUDA_PATH"] = CUDA_HOME
        env["PATH"] = _prepend(os.path.join(CUDA_HOME, "bin"), env.get("PATH", ""))
        env["PYTHONPATH"] = _prepend(patch_dir, env.get("PYTHONPATH", ""))
        env["TORCH_CUDA_ARCH_LIST"] = "12.0"

        subprocess.check_call(
            [sys.executable, "-m", "pip", "install",
             "--no-build-isolation", "--no-deps",
             vendored_dgr],
            env=env,
        )
# Run the one-time extension build (returns early when it already imports).
_first_gpu_setup()

# Preload the toolkit's CUDA runtime with global symbol visibility, then expose
# the library directory to child processes through LD_LIBRARY_PATH. This stays
# best effort: a missing library is reported but does not stop startup.
try:
    ctypes.CDLL(os.path.join(CUDA_LIBDIR, "libcudart.so.13"), mode=ctypes.RTLD_GLOBAL)
except OSError as exc:
    print(f'[WARN] could not preload libcudart.so.13 from {CUDA_LIBDIR}: {exc}')
else:
    # Join only non-empty parts: a trailing separator would add an empty
    # entry, which the dynamic loader treats as the current directory.
    os.environ["LD_LIBRARY_PATH"] = os.pathsep.join(
        part for part in (CUDA_LIBDIR, os.environ.get("LD_LIBRARY_PATH", "")) if part
    )
# ---------------------------------------------------------------------------
# Now the usual app imports.
# ---------------------------------------------------------------------------
import tyro
import imageio
import numpy as np
import tqdm
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from safetensors.torch import load_file
import rembg
import gradio as gr
from huggingface_hub import hf_hub_download
# Fetch the pretrained LGM checkpoint from the Hugging Face Hub; the result is
# used below as the local path passed to `Options(resume=...)`.
ckpt_path = hf_hub_download(
    repo_id="ashawkey/LGM",
    filename="model_fp16_fixrot.safetensors",
)
import kiui
from kiui.op import recenter
from kiui.cam import orbit_camera
from core.options import AllConfigs, Options
from core.models import LGM
from mvdream.pipeline_mvdream import MVDreamPipeline
# Per-channel mean / std handed to `TF.normalize` on the input views in `process`.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
# File names of the generated outputs; joined onto `opt.workspace` in `process`.
GRADIO_VIDEO_PATH = 'gradio_output.mp4'
GRADIO_PLY_PATH = 'gradio_output.ply'
# Inference configuration; `resume` points at the checkpoint downloaded above.
opt = Options(
    input_size=256,
    up_channels=(1024, 1024, 512, 256, 128), # one more decoder
    up_attention=(True, True, True, False, False),
    splat_size=128,
    output_size=512, # render & supervise Gaussians at a higher resolution.
    batch_size=8,
    num_views=8,
    gradient_accumulation_steps=1,
    mixed_precision='bf16',
    resume=ckpt_path,
)

# model
model = LGM(opt)

# resume pretrained checkpoint
if opt.resume is None:
    print('[WARN] model randomly initialized, are you sure?')
else:
    if opt.resume.endswith('safetensors'):
        ckpt = load_file(opt.resume, device='cpu')
    else:
        # SECURITY: torch.load unpickles the file. The checkpoint is only fed
        # to load_state_dict, so restrict loading to tensors/containers
        # instead of allowing arbitrary pickled objects.
        ckpt = torch.load(opt.resume, map_location='cpu', weights_only=True)
    # strict=False: keys that do not match between checkpoint and model are
    # tolerated rather than raising.
    model.load_state_dict(ckpt, strict=False)
    print(f'[INFO] Loaded checkpoint from {opt.resume}')

# Run in half precision on the GPU when one is visible, otherwise on CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.half().to(device)
model.eval()
# Projection matrix consumed by the render loop (`cam_view @ proj_matrix`),
# written out as a single literal instead of filling a zero matrix entry by
# entry. Only five entries are non-zero.
tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
proj_matrix = torch.tensor(
    [
        [float(1 / tan_half_fov), 0.0, 0.0, 0.0],
        [0.0, float(1 / tan_half_fov), 0.0, 0.0],
        [0.0, 0.0, float((opt.zfar + opt.znear) / (opt.zfar - opt.znear)), 1.0],
        [0.0, 0.0, float(- (opt.zfar * opt.znear) / (opt.zfar - opt.znear)), 0.0],
    ],
    dtype=torch.float32,
).to(device)
# load dreams
def _load_mvdream_pipeline(repo_id):
    """Load an MVDream pipeline in float16 from ``repo_id`` and move it to ``device``."""
    pipeline = MVDreamPipeline.from_pretrained(
        repo_id, # remote weights
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    return pipeline.to(device)


# Text-conditioned pipeline, used when no input image is given.
pipe_text = _load_mvdream_pipeline('ashawkey/mvdream-sd2.1-diffusers')
# Image-conditioned pipeline, used when an input image is given.
pipe_image = _load_mvdream_pipeline("ashawkey/imagedream-ipmv-diffusers")

# load rembg
bg_remover = rembg.new_session()
# process function
def _render_orbit_frame(gaussians, elevation, azimuth, scale_modifier):
    """Render ``gaussians`` from one orbit camera and return uint8 frame(s).

    Builds the camera pose with ``orbit_camera`` at ``opt.cam_radius``, derives
    the view / view-projection matrices and camera position expected by
    ``model.gs.render``, and converts the rendered ``'image'`` entry to a
    channels-last uint8 numpy array.
    """
    cam_poses = torch.from_numpy(orbit_camera(elevation, azimuth, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
    cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction

    # cameras needed by gaussian rasterizer
    cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
    cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
    cam_pos = - cam_poses[:, :3, 3] # [V, 3]

    image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=scale_modifier)['image']
    return (image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8)


@spaces.GPU
def process(input_image, prompt, prompt_neg='', input_elevation=0, input_num_steps=30, input_seed=42):
    """Generate 3D Gaussians and a turntable video from text and/or an image.

    Args:
        input_image: Conditioning image convertible with ``np.array`` (the UI
            passes a PIL image), or ``None`` to use the text-only pipeline.
        prompt: Text prompt forwarded to the multi-view diffusion pipeline.
        prompt_neg: Negative prompt forwarded to the pipeline.
        input_elevation: Elevation forwarded to the pipeline and to
            ``model.prepare_default_rays``.
        input_num_steps: Number of diffusion inference steps.
        input_seed: Seed passed to ``kiui.seed_everything``.

    Returns:
        ``(mv_image_grid, output_video_path, output_ply_path)``: a 2x2 grid of
        the four generated views, and the paths of the video and ply files
        written under ``opt.workspace``.
    """
    # seed
    kiui.seed_everything(input_seed)

    os.makedirs(opt.workspace, exist_ok=True)
    output_video_path = os.path.join(opt.workspace, GRADIO_VIDEO_PATH)
    output_ply_path = os.path.join(opt.workspace, GRADIO_PLY_PATH)

    # text-conditioned
    if input_image is None:
        mv_image_uint8 = pipe_text(prompt, negative_prompt=prompt_neg, num_inference_steps=input_num_steps, guidance_scale=7.5, elevation=input_elevation)
        mv_image_uint8 = (mv_image_uint8 * 255).astype(np.uint8)
        # bg removal
        mv_image = []
        for i in range(4):
            image = rembg.remove(mv_image_uint8[i], session=bg_remover) # [H, W, 4]
            # to white bg
            image = image.astype(np.float32) / 255
            # Recenter on the alpha mask, as the image-conditioned branch
            # does; the red channel misses foreground pixels whose red is 0.
            image = recenter(image, image[..., -1] > 0, border_ratio=0.2)
            image = image[..., :3] * image[..., -1:] + (1 - image[..., -1:])
            mv_image.append(image)
    # image-conditioned (may also input text, but no text usually works too)
    else:
        input_image = np.array(input_image) # uint8
        # bg removal
        carved_image = rembg.remove(input_image, session=bg_remover) # [H, W, 4]
        mask = carved_image[..., -1] > 0
        image = recenter(carved_image, mask, border_ratio=0.2)
        image = image.astype(np.float32) / 255.0
        image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])
        mv_image = pipe_image(prompt, image, negative_prompt=prompt_neg, num_inference_steps=input_num_steps, guidance_scale=5.0, elevation=input_elevation)

    # 2x2 preview grid, views ordered 1, 2 / 3, 0.
    mv_image_grid = np.concatenate([
        np.concatenate([mv_image[1], mv_image[2]], axis=1),
        np.concatenate([mv_image[3], mv_image[0]], axis=1),
    ], axis=0)

    # generate gaussians
    input_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0) # [4, 256, 256, 3], float32
    input_image = torch.from_numpy(input_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256]
    input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)
    input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)

    rays_embeddings = model.prepare_default_rays(device, elevation=input_elevation)
    input_image = torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0) # [1, 4, 9, H, W]

    with torch.no_grad():
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            # generate gaussians
            gaussians = model.forward_gaussians(input_image)

        # save gaussians
        model.gs.save_ply(gaussians, output_ply_path)

        # render 360 video
        elevation = 0
        if opt.fancy_video:
            # Two turns at 4 degree steps; the splat scale grows during the
            # first turn and stays at 1 afterwards.
            azimuth = np.arange(0, 720, 4, dtype=np.int32)
        else:
            azimuth = np.arange(0, 360, 2, dtype=np.int32)

        images = []
        for azi in tqdm.tqdm(azimuth):
            scale = min(azi / 360, 1) if opt.fancy_video else 1
            images.append(_render_orbit_frame(gaussians, elevation, azi, scale))

        images = np.concatenate(images, axis=0)
        imageio.mimwrite(output_video_path, images, fps=30)

    return mv_image_grid, output_video_path, output_ply_path
# gradio UI
# Page title: passed to `gr.Blocks(title=...)` and rendered as the top heading.
_TITLE = '''LGM: Large Multi-View Gaussian Model for High-Resolution 3D Content Creation'''
# Markdown/HTML blurb rendered under the title (project badges and usage notes).
_DESCRIPTION = '''
<div>
<a style="display:inline-block" href="https://me.kiui.moe/lgm/"><img src='https://img.shields.io/badge/public_website-8A2BE2'></a>
<a style="display:inline-block; margin-left: .5em" href="https://github.com/3DTopia/LGM"><img src='https://img.shields.io/github/stars/3DTopia/LGM?style=social'/></a>
</div>
* Input can be only text, only image, or both image and text.
* Output is a `ply` file containing the 3D Gaussians, please check our [repo](https://github.com/3DTopia/LGM/blob/main/readme.md) for visualization and mesh conversion.
* If you find the output unsatisfying, try using different seeds!
'''
# Top-level Gradio app with request queuing enabled. Component creation order
# and `with` nesting define the page layout, so statements are kept in order.
block = gr.Blocks(title=_TITLE).queue()
with block:
    # Header: title heading followed by the description blurb.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown('# ' + _TITLE)
    gr.Markdown(_DESCRIPTION)

    with gr.Row(variant='panel'):
        # First column: all inputs plus the generate button.
        with gr.Column(scale=1):
            # input image
            input_image = gr.Image(label="image", type='pil')
            # input prompt
            input_text = gr.Textbox(label="prompt")
            # negative prompt
            input_neg_text = gr.Textbox(label="negative prompt", value='ugly, blurry, pixelated obscure, unnatural colors, poor lighting, dull, unclear, cropped, lowres, low quality, artifacts, duplicate')
            # elevation
            input_elevation = gr.Slider(label="elevation", minimum=-90, maximum=90, step=1, value=0)
            # inference steps
            input_num_steps = gr.Slider(label="inference steps", minimum=1, maximum=100, step=1, value=30)
            # random seed
            input_seed = gr.Slider(label="random seed", minimum=0, maximum=100000, step=1, value=0)
            # gen button
            button_gen = gr.Button("Generate")

        # Second column: results, split over two tabs.
        with gr.Column(scale=1):
            with gr.Tab("Video"):
                # final video results
                output_video = gr.Video(label="video")
                # ply file
                output_file = gr.File(label="3D Gaussians (ply format)")
            with gr.Tab("Multi-view Image"):
                # multi-view results
                output_image = gr.Image(interactive=False, show_label=False)

        # Inputs are passed positionally, matching the parameter order of
        # `process`; outputs match its (grid, video path, ply path) return.
        button_gen.click(process, inputs=[input_image, input_text, input_neg_text, input_elevation, input_num_steps, input_seed], outputs=[output_image, output_video, output_file])

    # Image-only examples: each entry fills `input_image`.
    # NOTE(review): with cache_examples=False, confirm whether `fn` is ever
    # invoked when an example is clicked, or whether it only fills the input.
    gr.Examples(
        examples=[
            "data_test/frog_sweater.jpg",
            "data_test/bird.jpg",
            "data_test/boy.jpg",
            "data_test/cat_statue.jpg",
            "data_test/dragontoy.jpg",
            "data_test/gso_rabbit.jpg",
        ],
        inputs=[input_image],
        outputs=[output_image, output_video, output_file],
        fn=lambda x: process(input_image=x, prompt=''),
        cache_examples=False,
        label='Image-to-3D Examples'
    )

    # Text-only examples: each entry fills `input_text`.
    gr.Examples(
        examples=[
            "teddy bear",
            "hamburger",
            "oldman's head sculpture",
            "headphone",
            "motorbike",
            "mech suit"
        ],
        inputs=[input_text],
        outputs=[output_image, output_video, output_file],
        fn=lambda x: process(input_image=None, prompt=x),
        cache_examples=False,
        label='Text-to-3D Examples'
    )

# Start the web server; runs at import time because this file is the entry script.
block.launch()