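"""Gradio demo for visualizing x2 super-resolution on axengine.

Two exported axmodels are compared: EDSR (operates on all three colour
channels) and ESPCN (operates on the Y channel only, with the Cb/Cr
channels upscaled by bicubic interpolation).
"""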
import os
import tempfile
import time

import cv2
import gradio as gr
import numpy as np

import axengine as axe

import common
import imgproc

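# Pixel value range expected by the models and the upscaling factor of the x2 axmodels.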
rgb_range = 255
scale = 2

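# Small array helpers used when post-processing the raw model outputs.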
def from_numpy(x):
    return x if isinstance(x, np.ndarray) else np.array(x)


def quantize(img, rgb_range):
    pixel_range = 255 / rgb_range
    return np.round(np.clip(img * pixel_range, 0, 255)) / pixel_range

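# Load both super-resolution models once at import time so every request reuses the same sessions.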
def init_SRmodel(EDSR_path="../model_convert/axmodel/edsr_baseline_x2_1.axmodel",
                 ESPCN_path="../model_convert/axmodel/espcn_x2_T9.axmodel"):
    EDSR_session = axe.InferenceSession(EDSR_path)
    ESPCN_session = axe.InferenceSession(ESPCN_path)
    return [EDSR_session, ESPCN_session]


SR_sessions = init_SRmodel()

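# EDSR path: the whole 3-channel frame (prepared by the `common` helpers) is fed to
# the network; the output is de-quantized and returned as a uint8 HxWxC image.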
def EDSR_infer(frame, EDSR_session=SR_sessions[0]):
    output_names = [x.name for x in EDSR_session.get_outputs()]
    input_name = EDSR_session.get_inputs()[0].name

    lr_image, = common.set_channel(frame, n_channels=3)
    lr_image, = common.np_prepare(lr_image, rgb_range=rgb_range)

    sr = EDSR_session.run(output_names, {input_name: lr_image})

    if isinstance(sr, (list, tuple)):
        sr = from_numpy(sr[0]) if len(sr) == 1 else [from_numpy(x) for x in sr]
    else:
        sr = from_numpy(sr)

    sr = quantize(sr, rgb_range).squeeze(0)
    normalized = sr * 255 / rgb_range
    ndarr = normalized.transpose(1, 2, 0).astype(np.uint8)

    return ndarr

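# ESPCN path: only the Y (luma) channel is super-resolved by the network; the Cb/Cr
# channels are upscaled with bicubic interpolation and merged back before converting to BGR.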
def ESPCN_infer(frame, ESPCN_session=SR_sessions[1]):
    output_names = [x.name for x in ESPCN_session.get_outputs()]
    input_name = ESPCN_session.get_inputs()[0].name

    lr_y_image, lr_cb_image, lr_cr_image = imgproc.preprocess_one_frame(frame)
    bic_cb_image = cv2.resize(lr_cb_image,
                              (int(lr_cb_image.shape[1] * scale),
                               int(lr_cb_image.shape[0] * scale)),
                              interpolation=cv2.INTER_CUBIC)
    bic_cr_image = cv2.resize(lr_cr_image,
                              (int(lr_cr_image.shape[1] * scale),
                               int(lr_cr_image.shape[0] * scale)),
                              interpolation=cv2.INTER_CUBIC)

    sr = ESPCN_session.run(output_names, {input_name: lr_y_image})

    if isinstance(sr, (list, tuple)):
        sr = from_numpy(sr[0]) if len(sr) == 1 else [from_numpy(x) for x in sr]
    else:
        sr = from_numpy(sr)

    ndarr = imgproc.array_to_image(sr)
    sr_y_image = ndarr.astype(np.float32) / 255.0
    sr_ycbcr_image = cv2.merge([sr_y_image[:, :, 0], bic_cb_image, bic_cr_image])
    sr_image = imgproc.ycbcr_to_bgr(sr_ycbcr_image)
    sr_image = np.clip(sr_image * 255.0, 0, 255).astype(np.uint8)

    return sr_image

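# Thin wrappers that apply the chosen model to either a single image or a list of video frames.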
def EDSR_MODEL(input_data, is_video=False):
    if is_video:
        output_frames = []
        for frame in input_data:
            output_frames.append(EDSR_infer(frame=frame))
        return output_frames
    else:
        return EDSR_infer(frame=input_data)


def ESPCN_MODEL(input_data, is_video=False):
    if is_video:
        output_frames = []
        for frame in input_data:
            output_frames.append(ESPCN_infer(frame=frame))
        return output_frames
    else:
        return ESPCN_infer(frame=input_data)

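# Holds the results of the most recent run so the original / super-resolved toggle
# buttons can switch views without re-running inference.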
class AppState:
    def __init__(self):
        self.original_img = None
        self.sr_img = None
        self.is_video = False


app_state = AppState()

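# Main callback: runs the selected model on the uploaded file and returns one update per
# output component wired to the run button, in the same order as the `outputs` list below.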
def process_super_resolution(input_file, model_choice):
    global app_state
    if input_file is None:
        raise gr.Error("Please upload an image or video first!")

    file_path = input_file
    app_state = AppState()
    info_text = ""

    is_video = file_path.lower().endswith(('.mp4', '.avi', '.mov', '.mkv'))

    if is_video:
        cap = cv2.VideoCapture(file_path)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        info_text += f"🎬 Video info:\n- Total frames: {total_frames}\n- Frame rate: {fps:.2f} FPS\n"
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
        cap.release()

        model_func = EDSR_MODEL if model_choice == "EDSR_MODEL" else ESPCN_MODEL
        start_time = time.time()
        output_data = model_func(frames, is_video=True)
        infer_time = time.time() - start_time
        info_text += f"\n⏱️ Inference time: {infer_time:.2f} s\n"

        full_video_path = os.path.join(tempfile.gettempdir(), "sr_video_x2.mp4")
        h_out, w_out = output_data[0].shape[:2]
        info_text += f"- Super-resolved size: {w_out} x {h_out}\n"
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out_video = cv2.VideoWriter(full_video_path, fourcc, fps, (w_out, h_out))
        for frame in output_data:
            out_video.write(frame)
        out_video.release()

        app_state.is_video = True

        return (
            gr.update(value=None, visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(value="Current: none", visible=False),
            gr.update(value=full_video_path, visible=True),
            gr.update(value=full_video_path, visible=True),
            gr.update(visible=False),
            info_text
        )

    else:
        img = cv2.imread(file_path)
        if img is None:
            raise gr.Error("Failed to read the image!")
        h, w = img.shape[:2]
        info_text += f"🖼️ Image info:\n- Original size: {w} x {h}\n"

        app_state.original_img = img.copy()
        model_func = EDSR_MODEL if model_choice == "EDSR_MODEL" else ESPCN_MODEL
        start_time = time.time()
        sr_img = model_func(img, is_video=False)
        infer_time = time.time() - start_time
        info_text += f"\n⏱️ Inference time: {infer_time:.2f} s\n"

        h_out, w_out = sr_img.shape[:2]
        info_text += f"- Super-resolved size: {w_out} x {h_out}\n"

        sr_img_path = os.path.join(tempfile.gettempdir(), "sr_image_x2.png")
        cv2.imwrite(sr_img_path, sr_img)
        app_state.sr_img = sr_img

        app_state.is_video = False

        return (
            gr.update(value=app_state.original_img[:, :, ::-1], visible=True),
            gr.update(visible=True),
            gr.update(visible=True),
            gr.update(value="Current: original", visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(value=sr_img_path, visible=True),
            info_text
        )

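# Toggle callbacks for the image view: swap between the cached original and super-resolved
# images, reversing the channel order for display.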
def show_original():
    if app_state.original_img is None:
        return gr.update(), gr.update()

    rgb_img = app_state.original_img[:, :, ::-1]
    return gr.update(value=rgb_img), gr.update(value="Current: original")


def show_sr():
    if app_state.sr_img is None:
        return gr.update(), gr.update()
    rgb_img = app_state.sr_img[:, :, ::-1]
    return gr.update(value=rgb_img), gr.update(value="Current: super-resolved")

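# Gradio UI: file upload, model selection, an image viewer with an original/super-resolved
# toggle, a video player, download links, and an info box.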
with gr.Blocks(title="Super-Resolution Visualization Tool", theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🚀 Super-Resolution Model Visualization")
    gr.Markdown("Upload an image or video, choose a model, and click the arrow buttons to switch between the original and the super-resolved image!")

    input_file = gr.File(
        label="📂 Upload an image or video",
        file_types=["image", "video"],
        file_count="single"
    )

    with gr.Row():
        model_choice = gr.Radio(
            choices=["EDSR_MODEL", "ESPCN_MODEL"],
            value="EDSR_MODEL",
            label="🔍 Super-resolution model"
        )
        run_btn = gr.Button("🚀 Run super-resolution", variant="primary")

    with gr.Column(visible=False) as image_section:
        image_label = gr.Textbox(value="Current: original", interactive=False, lines=1)
        image_display = gr.Image(
            label="🖼️ Image view",
            width=800,
            height=600
        )
        with gr.Row():
            btn_original = gr.Button("◀ Original")
            btn_sr = gr.Button("Super-resolved ▶")

    output_video_player = gr.Video(
        label="▶️ Super-resolved video (high resolution)",
        visible=False,
        height=450
    )

    with gr.Row():
        download_image = gr.File(label="📥 Download super-resolved image (full size)", visible=False)
        download_video = gr.File(label="📥 Download super-resolved video (full resolution)", visible=False)

    info_box = gr.Textbox(label="📊 Processing info", lines=6, interactive=False)

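    # Wire the run button: the `outputs` order must match the tuples returned by process_super_resolution.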
    run_btn.click(
        fn=process_super_resolution,
        inputs=[input_file, model_choice],
        outputs=[
            image_display,
            btn_original,
            btn_sr,
            image_label,
            output_video_player,
            download_video,
            download_image,
            info_box
        ]
    )

    btn_original.click(show_original, outputs=[image_display, image_label])
    btn_sr.click(show_sr, outputs=[image_display, image_label])

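    # Show the image widgets for image uploads and the video widgets for video uploads as soon as a file is selected.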
    def toggle_ui(file):
        if file is None:
            return (
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False)
            )
        if file.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
            return (
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=True),
                gr.update(visible=True)
            )
        else:
            return (
                gr.update(visible=True),
                gr.update(visible=True),
                gr.update(visible=False),
                gr.update(visible=False)
            )

    input_file.change(
        fn=toggle_ui,
        inputs=input_file,
        outputs=[
            image_section,
            download_image,
            output_video_player,
            download_video
        ]
    )

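# Bind to all interfaces (0.0.0.0) on port 7860 so the demo is reachable from other machines.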
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)