import gradio as gr
import tensorflow as tf
import numpy as np
import json
import pickle
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, LSTM, Embedding, Add, BatchNormalization
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input
from PIL import Image

# --- Configuration paths ---
MODEL_WEIGHTS_PATH = "model_weights.h5"
TOKENIZER_DATA_PATH = "tokenizer_data.json"
CONFIG_PATH = "model_config.pkl"

# --- Load configuration ---
print("Loading model configuration...")
with open(CONFIG_PATH, 'rb') as f:
    config = pickle.load(f)

max_caption_length = config['max_caption_length']
cnn_output_dim = config['cnn_output_dim']

# --- Load tokenizer data from JSON ---
print("Loading tokenizer data...")
with open(TOKENIZER_DATA_PATH, 'r') as f:
    tokenizer_data = json.load(f)


class SimpleTokenizer:
    """Minimal stand-in for the Keras Tokenizer, rebuilt from exported JSON."""

    def __init__(self, word_index, index_word):
        self.word_index = word_index
        # JSON object keys are strings; convert them back to integer indices.
        self.index_word = {int(k): v for k, v in index_word.items()}

    def texts_to_sequences(self, texts):
        sequences = []
        for text in texts:
            words = text.lower().split()
            # Unknown words map to 0, which the Embedding layer masks out.
            sequences.append([self.word_index.get(word, 0) for word in words])
        return sequences


# Initialize tokenizer
tokenizer = SimpleTokenizer(tokenizer_data['word_index'], tokenizer_data['index_word'])
vocab_size = len(tokenizer.word_index) + 1
print(f"Tokenizer loaded! Vocabulary size: {vocab_size}")

# --- Load InceptionV3 for feature extraction ---
print("Loading InceptionV3 model...")
inception_v3_model = InceptionV3(weights='imagenet', input_shape=(299, 299, 3))
# Drop the classification head: the penultimate (global average pooling)
# layer serves as a fixed-length image feature extractor.
inception_v3_model = tf.keras.Model(
    inputs=inception_v3_model.inputs,
    outputs=inception_v3_model.layers[-2].output
)

# --- Rebuild the caption model architecture ---
print("Building caption model architecture...")

# Image feature branch
image_features_input = Input(shape=(cnn_output_dim,), name='Features_Input')
image_features_bn = BatchNormalization()(image_features_input)
image_features_dense = Dense(256, activation='relu')(image_features_bn)
image_features_bn2 = BatchNormalization()(image_features_dense)

# Caption sequence branch
sequence_input = Input(shape=(max_caption_length,), name='Sequence_Input')
sequence_embedding = Embedding(vocab_size, 256, mask_zero=True)(sequence_input)
sequence_lstm = LSTM(256)(sequence_embedding)

# Merge the two branches and predict the next word
merged = Add()([image_features_bn2, sequence_lstm])
merged_dense = Dense(256, activation='relu')(merged)
output = Dense(vocab_size, activation='softmax', name='Output_Layer')(merged_dense)

caption_model = Model(inputs=[image_features_input, sequence_input], outputs=output)

# Load weights
print("Loading model weights...")
caption_model.load_weights(MODEL_WEIGHTS_PATH)
print("Model loaded successfully!")


# --- Utility functions ---
def preprocess_image(image_pil):
    """Resize and normalize a PIL image for InceptionV3."""
    img = image_pil.resize((299, 299))
    img = img_to_array(img)
    img = np.expand_dims(img, axis=0)
    return preprocess_input(img)


def extract_image_features(model, image_pil):
    """Extract a flat feature vector from an image using InceptionV3."""
    img_processed = preprocess_image(image_pil)
    features = model.predict(img_processed, verbose=0)
    return features.flatten()


def greedy_generator(image_features):
    """Generate a caption by greedily taking the most likely next word."""
    in_text = 'start'
    for _ in range(max_caption_length):
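        # Encode the caption generated so far, pad it to the training length,
        # and predict the next word from the image features plus this prefix.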
        sequence = tokenizer.texts_to_sequences([in_text])[0]
        sequence = pad_sequences([sequence], maxlen=max_caption_length)
        prediction = caption_model.predict(
            [image_features.reshape(1, cnn_output_dim), sequence], verbose=0
        )
        idx = int(np.argmax(prediction))
        if idx == 0 or idx not in tokenizer.index_word:
            break
        word = tokenizer.index_word[idx]
        in_text += ' ' + word
        if word == 'end':
            break
    # Drop the 'start' and 'end' sentinel tokens from the generated text.
    words = [w for w in in_text.split() if w not in ('start', 'end')]
    return ' '.join(words)


def beam_search_generator(image_features, K_beams=3):
    """Generate a caption with beam search over the K most likely sequences."""
    start = [tokenizer.word_index.get('start', 1)]
    # Each beam is a [token_sequence, cumulative_log_probability] pair.
    start_word = [[start, 0.0]]
    for _ in range(max_caption_length):
        temp = []
        for s in start_word:
            sequence = pad_sequences([s[0]], maxlen=max_caption_length)
            preds = caption_model.predict(
                [image_features.reshape(1, cnn_output_dim), sequence], verbose=0
            )
            # Expand this beam with its K most likely next words.
            word_preds = np.argsort(preds[0])[-K_beams:]
            for w in word_preds:
                if w == 0:  # skip the padding index
                    continue
                next_cap, prob = s[0][:], s[1]
                next_cap.append(int(w))
                prob += np.log(preds[0][w] + 1e-10)
                temp.append([next_cap, prob])
        # Keep only the K highest-scoring beams. Finished beams are still
        # extended here; decoding below truncates at the 'end' token.
        start_word = sorted(temp, reverse=True, key=lambda b: b[1])[:K_beams]

    best_caption_sequence = start_word[0][0]
    captions_ = [tokenizer.index_word[i] for i in best_caption_sequence
                 if i in tokenizer.index_word]

    final_caption = []
    for word in captions_:
        if word == 'end':
            break
        if word != 'start':
            final_caption.append(word)
    return ' '.join(final_caption).strip()


# --- Gradio interface ---
def generate_captions_gradio(image):
    """Run both caption generators for the Gradio interface."""
    if image is None:
        return "Please upload an image.", "Please upload an image."
    try:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image_features = extract_image_features(inception_v3_model, image)
        greedy_cap = greedy_generator(image_features)
        beam_cap = beam_search_generator(image_features, K_beams=3)
        # gr.Textbox renders plain text (not Markdown), and the output boxes
        # are already labeled, so return the bare captions.
        return greedy_cap.capitalize(), beam_cap.capitalize()
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        return error_msg, error_msg


# --- Gradio app ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🖼️ Image Captioning Model
        ### Generate natural language descriptions for your images

        Upload an image and get captions from two decoding algorithms:
        - **Greedy Search**: fast caption generation
        - **Beam Search (K=3)**: higher-quality captions
        """
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            generate_btn = gr.Button("Generate Captions", variant="primary")
        with gr.Column():
            greedy_output = gr.Textbox(label="Greedy Search Result", lines=3)
            beam_output = gr.Textbox(label="Beam Search Result", lines=3)
    gr.Markdown(
        """
        ---
        **Model**: CNN-RNN (InceptionV3 + LSTM) | **Dataset**: Flickr8k

        **Created by**: Prabhat Kumar Singh | [GitHub](https://github.com/Prabhat9801/Image_Captioning_Model)
        """
    )

    generate_btn.click(
        fn=generate_captions_gradio,
        inputs=[image_input],
        outputs=[greedy_output, beam_output]
    )
    # Also regenerate captions whenever a new image is uploaded.
    image_input.change(
        fn=generate_captions_gradio,
        inputs=[image_input],
        outputs=[greedy_output, beam_output]
    )

if __name__ == "__main__":
    demo.launch()
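
# --- Companion-file format (hedged sketch) ---
# The three artifacts loaded at the top of this script are assumed to be
# exported by the training code. A minimal sketch of how they could be
# produced from a trained Keras Tokenizer and caption model;
# `trained_tokenizer`, `trained_model`, and `max_len` are illustrative names,
# not identifiers from this repository:
#
#     import json, pickle
#
#     with open("tokenizer_data.json", "w") as f:
#         json.dump({"word_index": trained_tokenizer.word_index,
#                    "index_word": trained_tokenizer.index_word}, f)
#
#     with open("model_config.pkl", "wb") as f:
#         pickle.dump({"max_caption_length": max_len,  # longest training caption
#                      "cnn_output_dim": 2048}, f)     # InceptionV3 avg_pool size
#
#     trained_model.save_weights("model_weights.h5")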