Spaces:
Runtime error
Runtime error
| """ | |
| RealCanvas-MJ4K | |
| A 16-GB-friendly Gradio Space that | |
| 1. streams the prompt dataset MohamedRashad/midjourney-detailed-prompts | |
| 2. generates realistic images using SDXL-Lightning | |
| 3. optionally displays random images from opendiffusionai/cc12m-4mp-realistic | |
| """ | |
| import gradio as gr | |
| import torch, os, random, json, requests | |
| from io import BytesIO | |
| from PIL import Image | |
| from datasets import load_dataset | |
| from huggingface_hub import hf_hub_download | |
| from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler | |
# -------------------------------------------------
# 1. Prompt dataset: stream it, then sample an in-memory pool
# -------------------------------------------------
print("π Streaming prompt dataset β¦")
ds_prompts = load_dataset(
    "MohamedRashad/midjourney-detailed-prompts",
    streaming=True,
    split="train",
)
# Shuffle with a fixed seed so the pool is reproducible across restarts,
# then materialise at most 500,000 records as a list so that
# random.choice() can index into it later.
# NOTE(review): the original comment estimated "≈ 5 MB RAM" for this pool;
# that figure is not verifiable from this file -- confirm the record size.
_shuffled_prompts = ds_prompts.shuffle(seed=42)
prompt_pool = list(_shuffled_prompts.take(500_000))
# -------------------------------------------------
# 2. Load SDXL base + the SDXL-Lightning 4-step LoRA
# -------------------------------------------------
MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
print("π€ Loading SDXL-Lightning β¦")

# Decide the device once so dtype and placement always agree.
_has_cuda = torch.cuda.is_available()
_device = "cuda" if _has_cuda else "cpu"
# Half precision is only used on GPU: many CPU kernels lack float16
# support, so the CPU fallback runs in float32 instead of failing at
# inference time.
_dtype = torch.float16 if _has_cuda else torch.float32

pipe = StableDiffusionXLPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=_dtype,
    # The fp16 weight files are downloaded in both cases; torch_dtype
    # controls the precision they are loaded as.
    variant="fp16",
    use_safetensors=True,
)

# SDXL-Lightning's model card asks for "trailing" timestep spacing; the
# base model's scheduler config would otherwise be inherited unchanged.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="trailing",
)

# Lightning LoRA (4-step variant) applied on top of the base weights.
lora_path = hf_hub_download(
    repo_id="ByteDance/SDXL-Lightning",
    filename="sdxl_lightning_4step_lora.safetensors",
)
pipe.load_lora_weights(lora_path)

pipe = pipe.to(_device)
pipe.enable_attention_slicing()
# -------------------------------------------------
# 3. CC12M-4MP sample pool (backs the optional demo panel)
# -------------------------------------------------
print("πΈ Streaming CC12M-4MP-realistic β¦")
ds_images = load_dataset(
    "opendiffusionai/cc12m-4mp-realistic",
    streaming=True,
    split="train",
)
# Fixed-seed shuffle, then keep at most 1,000 records in memory for
# random_cc12m_image() to pick from.
# NOTE(review): the original comment estimated "≈ 10 MB RAM" for this pool;
# not verifiable from this file -- confirm what each record contains.
_shuffled_images = ds_images.shuffle(seed=42)
img_pool = list(_shuffled_images.take(1_000))
def random_cc12m_image():
    """Pick a random record from ``img_pool`` and return its image at 512x512.

    Returns:
        The result of calling ``.resize((512, 512))`` on the record's
        ``"image"`` entry.
    """
    record = random.choice(img_pool)
    # NOTE(review): assumes each record exposes a PIL-like object under the
    # "image" key -- confirm against the dataset's actual columns.
    picture = record["image"]
    return picture.resize((512, 512))
| # ------------------------------------------------- | |
| # 4. Gradio UI | |
| # ------------------------------------------------- | |
def generate(prompt: str, steps: int = 4, guidance: float = 0.0):
    """Generate one image for ``prompt`` and return it resized to 768x768.

    Args:
        prompt: Text prompt. When it is ``None``, empty, or only whitespace,
            a random record is drawn from ``prompt_pool`` and its
            ``"prompt"`` field is used instead.
        steps: Number of inference steps passed to the pipeline.
        guidance: Guidance scale passed to the pipeline.

    Returns:
        The first image produced by ``pipe``, resized to 768x768.
    """
    # Guard against None as well as blank text; the original called
    # .strip() unconditionally and raised AttributeError on None.
    if not prompt or not prompt.strip():
        # NOTE(review): assumes pool records carry a "prompt" key -- confirm
        # against the dataset's column names.
        prompt = random.choice(prompt_pool)["prompt"]

    # UI sliders may hand over floats (e.g. 4.0); the step count must be an
    # integer for the scheduler, so coerce both values explicitly.
    result = pipe(
        prompt,
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
    )
    return result.images[0].resize((768, 768))
# Layout: prompt box, two sliders, a generate button with its output image,
# and a collapsed panel that shows a random dataset photo.
# NOTE(review): the heading and accordion label contain characters that look
# like mis-encoded emoji; they are kept byte-identical here -- confirm the
# intended glyphs before changing them.
with gr.Blocks(title="RealCanvas-MJ4K") as demo:
    gr.Markdown("# π¨ RealCanvas-MJ4K | Midjourney-level realism under 16 GB")

    with gr.Row():
        prompt_in = gr.Textbox(
            lines=2,
            label="Prompt (leave empty for random Midjourney-style prompt)",
        )

    with gr.Row():
        steps = gr.Slider(
            minimum=1,
            maximum=8,
            value=4,
            step=1,
            label="Inference steps (SDXL-Lightning)",
        )
        guidance = gr.Slider(
            minimum=0.0,
            maximum=2.0,
            value=0.0,
            step=0.1,
            label="Guidance scale",
        )

    btn = gr.Button("Generate", variant="primary")
    gallery = gr.Image(type="pil", label="Generated image")

    with gr.Accordion("πΈ Random CC12M-4MP sample", open=False):
        cc_btn = gr.Button("Show random CC12M-4MP image")
        cc_out = gr.Image(type="pil", label="Real photo from dataset")

    # Wire the buttons to their callbacks.
    btn.click(fn=generate, inputs=[prompt_in, steps, guidance], outputs=gallery)
    cc_btn.click(fn=random_cc12m_image, outputs=cc_out)

# Enable the request queue with room for 8 pending jobs, then start the app.
demo.queue(max_size=8)
demo.launch()