"""Gradio demo that visualises EDSR / ESPCN x2 super-resolution results.

The user uploads an image or a video, picks a model and runs inference
through ``axengine``. Images can be toggled between the original and the
super-resolved version; videos are re-encoded and offered for playback and
download.
"""

import os
import tempfile
import time

import cv2
import gradio as gr
import numpy as np

import axengine as axe
import common
import imgproc

rgb_range = 255
scale = 2

# File extensions that are routed to the video pipeline.
_VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mov', '.mkv')

# Frame rate used for the output file when the container does not report one.
_FALLBACK_FPS = 25.0


def from_numpy(x):
    """Return *x* unchanged if it is an ``np.ndarray``, else wrap it in one."""
    return x if isinstance(x, np.ndarray) else np.array(x)


def quantize(img, rgb_range):
    """Clip *img* to the 8-bit range and round it, expressed in *rgb_range* units."""
    pixel_range = 255 / rgb_range
    return np.round(np.clip(img * pixel_range, 0, 255)) / pixel_range


def _is_video_file(path):
    """Return True when *path* ends with one of the supported video extensions.

    The real extension is compared (case-insensitively) instead of searching
    for the extension anywhere in the path, so a directory or file stem that
    merely contains ".mp4" does not turn an image into a video.
    """
    return os.path.splitext(path)[1].lower() in _VIDEO_EXTENSIONS


def _unwrap_outputs(outputs):
    """Convert the result of ``session.run`` to numpy.

    A one-element list/tuple yields that single array; a longer list/tuple
    yields a list of arrays; anything else is converted directly.
    """
    if isinstance(outputs, (list, tuple)):
        if len(outputs) == 1:
            return from_numpy(outputs[0])
        return [from_numpy(x) for x in outputs]
    return from_numpy(outputs)


# Initialise the EDSR and ESPCN models.
def init_SRmodel(EDSR_path="../model_convert/axmodel/edsr_baseline_x2_1.axmodel",
                 ESPCN_path="../model_convert/axmodel/espcn_x2_T9.axmodel"):
    """Create one inference session per model and return ``[EDSR, ESPCN]``."""
    EDSR_session = axe.InferenceSession(EDSR_path)
    ESPCN_session = axe.InferenceSession(ESPCN_path)
    return [EDSR_session, ESPCN_session]


# Loaded once at import time; the *_infer functions bind these as defaults.
SR_sessions = init_SRmodel()


def EDSR_infer(frame, EDSR_session=SR_sessions[0]):
    """Run EDSR on a single frame and return the result as a uint8 image.

    The frame is prepared with ``common.set_channel`` / ``common.np_prepare``
    (project helpers; their exact output layout is not visible here). The
    model output is quantised, its leading axis is squeezed away and the
    remaining axes are transposed with ``(1, 2, 0)`` before the uint8 cast.
    """
    output_names = [x.name for x in EDSR_session.get_outputs()]
    input_name = EDSR_session.get_inputs()[0].name

    lr_y_image, = common.set_channel(frame, n_channels=3)
    lr_y_image, = common.np_prepare(lr_y_image, rgb_range=rgb_range)

    sr = _unwrap_outputs(EDSR_session.run(output_names, {input_name: lr_y_image}))

    sr = quantize(sr, rgb_range).squeeze(0)
    normalized = sr * 255 / rgb_range
    ndarr = normalized.transpose(1, 2, 0).astype(np.uint8)
    return ndarr


def ESPCN_infer(frame, ESPCN_session=SR_sessions[1]):
    """Run ESPCN on a single frame and return a uint8 BGR image.

    Only the Y channel goes through the network; Cb and Cr are upscaled by
    ``scale`` with bicubic interpolation and merged back with the
    super-resolved Y channel before converting to BGR.
    """
    output_names = [x.name for x in ESPCN_session.get_outputs()]
    input_name = ESPCN_session.get_inputs()[0].name

    lr_y_image, lr_cb_image, lr_cr_image = imgproc.preprocess_one_frame(frame)

    bic_cb_image = cv2.resize(
        lr_cb_image,
        (int(lr_cb_image.shape[1] * scale), int(lr_cb_image.shape[0] * scale)),
        interpolation=cv2.INTER_CUBIC,
    )
    bic_cr_image = cv2.resize(
        lr_cr_image,
        (int(lr_cr_image.shape[1] * scale), int(lr_cr_image.shape[0] * scale)),
        interpolation=cv2.INTER_CUBIC,
    )

    sr = _unwrap_outputs(ESPCN_session.run(output_names, {input_name: lr_y_image}))

    ndarr = imgproc.array_to_image(sr)
    sr_y_image = ndarr.astype(np.float32) / 255.0
    sr_ycbcr_image = cv2.merge([sr_y_image[:, :, 0], bic_cb_image, bic_cr_image])
    sr_image = imgproc.ycbcr_to_bgr(sr_ycbcr_image)
    sr_image = np.clip(sr_image * 255.0, 0, 255).astype(np.uint8)
    return sr_image


# ======================
# Super-resolution model entry points
# ======================
def EDSR_MODEL(input_data, is_video=False):
    """Apply EDSR to one frame, or to every frame of a list when *is_video*."""
    if is_video:
        return [EDSR_infer(frame=frame) for frame in input_data]
    return EDSR_infer(frame=input_data)


def ESPCN_MODEL(input_data, is_video=False):
    """Apply ESPCN to one frame, or to every frame of a list when *is_video*."""
    if is_video:
        return [ESPCN_infer(frame=frame) for frame in input_data]
    return ESPCN_infer(frame=input_data)


def _select_model(model_choice):
    """Map the radio-button value to a model function (ESPCN is the fallback)."""
    return EDSR_MODEL if model_choice == "EDSR_MODEL" else ESPCN_MODEL


# ======================
# Global state (single user)
# ======================
class AppState:
    """Holds the most recent result so the toggle buttons can redisplay it."""

    def __init__(self):
        self.original_img = None  # original image (BGR, full resolution)
        self.sr_img = None        # super-resolved image (BGR, full resolution)
        self.is_video = False


app_state = AppState()


# ======================
# Core processing
# ======================
def _process_video(file_path, model_func):
    """Super-resolve every frame of *file_path* and write the result as MP4.

    Returns the 8-tuple of component updates expected by ``run_btn.click``.
    Raises ``gr.Error`` when the video cannot be opened, has no readable
    frames, or the output file cannot be created.
    """
    cap = cv2.VideoCapture(file_path)
    try:
        if not cap.isOpened():
            raise gr.Error("无法读取视频!")
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)

        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
    finally:
        # Release the capture even when opening or decoding fails.
        cap.release()

    if not frames:
        # Without this guard output_data[0] below raises IndexError.
        raise gr.Error("无法读取视频!")

    info_text = f"🎬 视频信息:\n- 总帧数: {total_frames}\n- 帧率: {fps:.2f} FPS\n"

    start_time = time.time()
    output_data = model_func(frames, is_video=True)
    infer_time = time.time() - start_time
    info_text += f"\n⏱️ 推理时间: {infer_time:.2f} 秒\n"

    full_video_path = os.path.join(tempfile.gettempdir(), "sr_video_x2.mp4")
    h_out, w_out = output_data[0].shape[:2]
    info_text += f"- 超分后尺寸: {w_out} x {h_out}\n"

    # Some containers report 0 FPS; fall back so the writer gets a valid rate.
    writer_fps = fps if fps > 0 else _FALLBACK_FPS
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out_video = cv2.VideoWriter(full_video_path, fourcc, writer_fps, (w_out, h_out))
    try:
        if not out_video.isOpened():
            raise gr.Error("无法写入超分视频!")
        for frame in output_data:
            out_video.write(frame)
    finally:
        out_video.release()

    app_state.is_video = True

    return (
        gr.update(value=None, visible=False),             # image_display
        gr.update(visible=False),                         # btn_original
        gr.update(visible=False),                         # btn_sr
        gr.update(value="当前: 无", visible=False),        # image_label
        gr.update(value=full_video_path, visible=True),   # output_video_player
        gr.update(value=full_video_path, visible=True),   # download_video
        gr.update(visible=False),                         # download_image
        info_text,
    )


def _process_image(file_path, model_func):
    """Super-resolve the image at *file_path* and save the result as PNG.

    Returns the 8-tuple of component updates expected by ``run_btn.click``.
    Raises ``gr.Error`` when the image cannot be read or the result cannot
    be written.
    """
    img = cv2.imread(file_path)
    if img is None:
        raise gr.Error("无法读取图片!")

    h, w = img.shape[:2]
    info_text = f"🖼️ 图片信息:\n- 原始尺寸: {w} x {h}\n"

    app_state.original_img = img.copy()

    start_time = time.time()
    sr_img = model_func(img, is_video=False)
    infer_time = time.time() - start_time
    info_text += f"\n⏱️ 推理时间: {infer_time:.2f} 秒\n"

    h_out, w_out = sr_img.shape[:2]
    info_text += f"- 超分后尺寸: {w_out} x {h_out}\n"

    sr_img_path = os.path.join(tempfile.gettempdir(), "sr_image_x2.png")
    if not cv2.imwrite(sr_img_path, sr_img):
        raise gr.Error("无法保存超分图片!")

    app_state.sr_img = sr_img
    app_state.is_video = False

    # Show the original first (full resolution; the UI constrains the size).
    return (
        gr.update(value=app_state.original_img[:, :, ::-1], visible=True),  # BGR -> RGB
        gr.update(visible=True),                                            # btn_original
        gr.update(visible=True),                                            # btn_sr
        gr.update(value="当前: 原图", visible=True),                          # image_label
        gr.update(visible=False),                                           # output_video_player
        gr.update(visible=False),                                           # download_video
        gr.update(value=sr_img_path, visible=True),                         # download_image
        info_text,
    )


def process_super_resolution(input_file, model_choice):
    """Run the selected model on the uploaded file.

    Args:
        input_file: Path of the uploaded file, or ``None`` if nothing was
            uploaded.
        model_choice: ``"EDSR_MODEL"`` selects EDSR; any other value selects
            ESPCN.

    Returns:
        A tuple of eight values matching the ``outputs`` list of
        ``run_btn.click``: updates for the image display, both toggle
        buttons, the image label, the video player, the video download, the
        image download, and finally the info text.

    Raises:
        gr.Error: If no file was uploaded or the file cannot be processed.
    """
    global app_state
    if input_file is None:
        raise gr.Error("请先上传图片或视频!")

    file_path = input_file
    app_state = AppState()  # drop any previous result before processing
    model_func = _select_model(model_choice)

    if _is_video_file(file_path):
        return _process_video(file_path, model_func)
    return _process_image(file_path, model_func)


# ======================
# Display toggles (use the stored full-resolution images directly)
# ======================
def show_original():
    """Return updates that show the stored original image and its label."""
    if app_state.original_img is None:
        return gr.update(), gr.update()
    # OpenCV BGR -> RGB
    rgb_img = app_state.original_img[:, :, ::-1]
    return gr.update(value=rgb_img), gr.update(value="当前: 原图")


def show_sr():
    """Return updates that show the stored super-resolved image and its label."""
    if app_state.sr_img is None:
        return gr.update(), gr.update()
    rgb_img = app_state.sr_img[:, :, ::-1]
    return gr.update(value=rgb_img), gr.update(value="当前: 超分图")


# ======================
# Gradio UI
# ======================
with gr.Blocks(title="超分辨率可视化工具", theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🚀 超分辨率模型效果可视化")
    gr.Markdown("上传图片或视频,选择模型,点击箭头切换原图/超分图!")

    input_file = gr.File(
        label="📂 上传图片或视频",
        file_types=["image", "video"],
        file_count="single"
    )

    with gr.Row():
        model_choice = gr.Radio(
            choices=["EDSR_MODEL", "ESPCN_MODEL"],
            value="EDSR_MODEL",
            label="🔍 选择超分辨率模型"
        )
        run_btn = gr.Button("🚀 开始超分", variant="primary")

    # Image area: fixed size, shows the full-resolution image directly.
    with gr.Column(visible=False) as image_section:
        image_label = gr.Textbox(value="当前: 原图", interactive=False, lines=1)
        image_display = gr.Image(
            label="🖼️ 图像显示",
            width=800,   # fixed width
            height=600   # fixed height
        )
        with gr.Row():
            btn_original = gr.Button("◀ 原图")
            btn_sr = gr.Button("超分图 ▶")

    # Video area: fixed height, width adapts.
    output_video_player = gr.Video(
        label="▶️ 超分视频(高分辨率)",
        visible=False,
        height=450
    )

    with gr.Row():
        download_image = gr.File(label="📥 下载超分图片(原图)", visible=False)
        download_video = gr.File(label="📥 下载超分视频(完整分辨率)", visible=False)

    info_box = gr.Textbox(label="📊 处理信息", lines=6, interactive=False)

    run_btn.click(
        fn=process_super_resolution,
        inputs=[input_file, model_choice],
        outputs=[
            image_display,
            btn_original,
            btn_sr,
            image_label,
            output_video_player,
            download_video,
            download_image,
            info_box
        ]
    )

    btn_original.click(show_original, outputs=[image_display, image_label])
    btn_sr.click(show_sr, outputs=[image_display, image_label])

    def toggle_ui(file):
        """Show the image widgets or the video widgets to match the upload.

        Returns visibility updates for, in order: ``image_section``,
        ``download_image``, ``output_video_player``, ``download_video``.
        Everything is hidden when *file* is ``None``.
        """
        if file is None:
            show_image = show_video = False
        else:
            show_video = _is_video_file(file)
            show_image = not show_video
        return (
            gr.update(visible=show_image),
            gr.update(visible=show_image),
            gr.update(visible=show_video),
            gr.update(visible=show_video)
        )

    input_file.change(
        fn=toggle_ui,
        inputs=input_file,
        outputs=[
            image_section,
            download_image,
            output_video_player,
            download_video
        ]
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)