Spaces:
Runtime error
Runtime error
| import cv2 | |
| import numpy as np | |
| import streamlit as st | |
| import tensorflow as tf | |
| from models.cstylegan import cStyleGAN | |
| from models.gaugan import GauGAN | |
| from utils import fix_pred_label, onehot_to_rgb, rgb_to_onehot, color_dict | |
| from skimage import io | |
| def load_cstylegan(): | |
| conditional_style_gan = cStyleGAN(start_res=4, target_res=1024) | |
| conditional_style_gan.grow_model(256) | |
| conditional_style_gan.load_weights('checkpoints/cstylegan/cstylegan_256x256.ckpt').expect_partial() | |
| print('Conditional StyleGAN Model Loaded!') | |
| return conditional_style_gan | |
| def load_gaugan(batch_size): | |
| gaugan = GauGAN(image_size=1024, num_classes=7, batch_size=batch_size, latent_dim=512) | |
| gaugan.load_weights('checkpoints/gaugan/gaugan_1024x1024.ckpt').expect_partial() | |
| print('GauGAN Model Loaded!') | |
| return gaugan | |
| def set_seed(): | |
| tf.random.set_seed(seed=st.session_state.seed) | |
| def main(): | |
| st.title('RetinaGAN') | |
| st.sidebar.columns([1, 5, 1])[1].image(cv2.cvtColor(cv2.imread('assets/sample.jpeg'), cv2.COLOR_BGR2RGB)) | |
| st.sidebar.title('Menu') | |
| options = st.sidebar.selectbox('Select Option:', ('About', 'Random', 'Upload your own', 'Retina Template')) | |
| if options == 'About': | |
| st.write('Online Demo for **High-Fidelity Diabetic Retina Fundus Image Synthesis from Freestyle Lesion Maps**') | |
| st.write(''' | |
| Paper: https://opg.optica.org/abstract.cfm?uri=boe-14-2-533 | |
| Github: http://github.com/farrell236/RetinaGAN | |
| 👈 Select an Option From the drop down menu | |
| --- | |
| ''') | |
| st.write(''' | |
| RetinaGAN a two-step process for generating photo-realistic retinal | |
| Fundus images based on artificially generated or free-hand drawn semantic lesion maps. | |
| ''') | |
| st.columns([1, 5, 1])[1].image(cv2.cvtColor(cv2.imread('assets/RetinaGAN_pipeline.png'), cv2.COLOR_BGR2RGB), | |
| caption='RetinaGAN Pipeline') | |
| st.write(''' | |
| StyleGAN is modified to be conditional in to synthesize pathological lesion maps | |
| based on a specified DR grade (i.e., grades 0 to 4). The DR Grades are defined by the | |
| International Clinical Diabetic Retinopathy (ICDR) disease severity scale; | |
| no apparent retinopathy, {mild, moderate, severe} Non-Proliferative Diabetic Retinopathy (NPDR), | |
| and Proliferative Diabetic Retinopathy (PDR). The output of the network is a binary image with | |
| seven channels instead of class colors to avoid ambiguity. | |
| ''') | |
| st.columns([1, 5, 1])[1].image(cv2.cvtColor(cv2.imread('assets/cStyleGAN.png'), cv2.COLOR_BGR2RGB), | |
| caption='Conditional StyleGAN Model') | |
| st.write(''' | |
| The generated label maps are then passed through GauGAN, an image-to-image translation network, | |
| to turn them into photo-realistic retina fundus images. The input to the network are one-hot | |
| encoded labels. | |
| ''') | |
| st.columns([1, 5, 1])[1].image(cv2.cvtColor(cv2.imread('assets/GauGAN.png'), cv2.COLOR_BGR2RGB), | |
| caption='GauGAN Model') | |
| elif options == 'Random': | |
| st.session_state.seed = st.sidebar.number_input('Sampling Seed:', value=42, on_change=set_seed) | |
| ## Load Models | |
| conditional_style_gan = load_cstylegan() | |
| gaugan = load_gaugan(4) | |
| for idx, col in enumerate(st.columns(5)): | |
| z = tf.random.normal((1, conditional_style_gan.z_dim)) | |
| w = conditional_style_gan.mapping([z, conditional_style_gan.embedding(idx)]) | |
| noise = conditional_style_gan.generate_noise(batch_size=1) | |
| labels = conditional_style_gan.call({"style_code": w, "noise": noise, "alpha": 1.0, "class_label": idx}) | |
| labels = tf.keras.backend.softmax(labels) | |
| labels = tf.cast(labels > 0.5, dtype=tf.float32) | |
| labels = tf.image.resize(labels, (1024, 1024), method='nearest') | |
| fixed_labels = fix_pred_label(labels) | |
| fixed_labels = tf.tile(fixed_labels, (4, 1, 1, 1)) | |
| latent_vector = tf.random.normal(shape=(4, 512), mean=0.0, stddev=2.0) | |
| fake_image = gaugan.predict([latent_vector, fixed_labels]) | |
| with col: | |
| st.text(f'DR Grade {idx}') | |
| st.image(onehot_to_rgb(fixed_labels[0], color_dict), output_format='PNG') | |
| for im in fake_image: | |
| st.image(im) | |
| # Run again? | |
| st.button('Regenerate Images') | |
| elif options == 'Upload your own': | |
| st.session_state.seed = st.sidebar.number_input('Sampling Seed:', value=42, on_change=set_seed) | |
| ## Load Models | |
| gaugan = load_gaugan(1) | |
| uploaded_file = st.file_uploader('Choose an image...', type=('png')) | |
| if uploaded_file: | |
| col1, col2 = st.columns(2) | |
| # Read input image with size [H, W, 3] and range (0, 255) | |
| img_array = io.imread(uploaded_file)[..., 0:3] | |
| # Test for valid mask | |
| test_colours = np.unique(img_array.reshape(-1, img_array.shape[2]), axis=0) | |
| if not all([tuple(x) in color_dict.values() for x in test_colours]): | |
| st.info('Mask Contains invalid Class Colours') | |
| return | |
| # Resize image with padding to [1024, 1024, 3] | |
| img_array = tf.image.resize_with_pad(img_array, 1024, 1024, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) | |
| # Display input image | |
| with col1: | |
| st.image(img_array.numpy(), caption='Uploaded Image') | |
| img_label = rgb_to_onehot(img_array.numpy(), color_dict)[None, ...] | |
| latent_vector = tf.random.normal(shape=(1, 512), mean=0.0, stddev=2.0) | |
| fake_image = gaugan.predict([latent_vector, img_label])[0] | |
| with col2: | |
| st.image(fake_image, caption='Generated Image') | |
| # Run again? | |
| st.button('Regenerate Image') | |
| elif options == 'Retina Template': | |
| st.header('Template') | |
| st.write('Download the Retina Template image below. ' | |
| 'Using an image editor of your choice, paint lesions ' | |
| 'into the Vitreous Body and upload it to the model. ' | |
| 'NB: Images must be stored as lossless PNGs') | |
| template = np.uint8(cv2.circle(np.zeros((1024, 1024, 3)), [512, 512], 512, (255, 255, 255), -1)) | |
| st.columns([1, 5, 1])[1].image(template, use_column_width=True, output_format='PNG') | |
| st.header('Class Colours') | |
| cols = st.columns(7) | |
| for idx, cls in enumerate(color_dict): | |
| with cols[idx]: | |
| st.image(image=np.tile(color_dict[cls], (32, 32, 1)), | |
| caption=cls, | |
| output_format='PNG') | |
| # st.caption(color_dict[cls]) | |
| data = {'Class Name': [ | |
| 'Background', | |
| 'Hard Exudate', | |
| 'Hemohedge', | |
| 'Soft Exudate', | |
| 'Micro Aneurysms', | |
| 'Optical Disc', | |
| 'Vitreous Body'], | |
| 'RGB Colour': [ | |
| str(color_dict[0]), # BG | |
| str(color_dict[1]), # EX | |
| str(color_dict[2]), # HE | |
| str(color_dict[3]), # SE | |
| str(color_dict[4]), # MA | |
| str(color_dict[5]), # OD | |
| str(color_dict[6])] # VB | |
| } | |
| st.table(data) | |
| if __name__ == '__main__': | |
| # tf.config.set_visible_devices([], 'GPU') | |
| main() | |