Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import numpy as np | |
| import pickle | |
| from keras.models import load_model | |
| from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input | |
| from tensorflow.keras.preprocessing import image | |
| from tensorflow.keras.preprocessing.sequence import pad_sequences | |
# Trained captioning network, restored from its HDF5 checkpoint.
model = load_model('image_captioning_model.h5')

# ImageNet-pretrained ResNet50 without its classification head; global average
# pooling makes it emit a single feature vector per image.
resnet_model = ResNet50(
    weights='imagenet',
    include_top=False,
    pooling='avg',
    input_shape=(224, 224, 3),
)

# Vocabulary lookups in both directions, plus the saved configuration.
with open('words_to_indices.pkl', 'rb') as vocab_file:
    words_to_indices = pickle.load(vocab_file)

with open('indices_to_words.pkl', 'rb') as reverse_vocab_file:
    indices_to_words = pickle.load(reverse_vocab_file)

with open('config.pkl', 'rb') as config_file:
    config = pickle.load(config_file)

# Settings read from the saved configuration and used by the decoders below.
vocab_size, max_length = config['vocab_size'], config['max_length']
def extract_features(img):
    """Extract a feature vector from an image using ResNet50.

    Args:
        img: A PIL image (it must provide ``mode``, ``convert`` and
            ``resize``).

    Returns:
        The ResNet50 output for the image with singleton dimensions removed.
    """
    # ResNet50 was built with a 3-channel input; RGBA, grayscale or palette
    # images would otherwise produce an array with the wrong channel count.
    if img.mode != "RGB":
        img = img.convert("RGB")
    img = img.resize((224, 224))
    x = image.img_to_array(img)
    # The network expects a batch dimension, even for a single image.
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)
    features = resnet_model.predict(x, verbose=0)
    return features.squeeze()
def greedy_search(photo):
    """Generate a caption by always taking the highest-scoring next word.

    Args:
        photo: Image feature array with 2048 elements (it is reshaped to
            ``(1, 2048)`` before being passed to the model).

    Returns:
        The generated caption as a space-separated string, without the
        ``<start>`` and ``<end>`` marker tokens.
    """
    photo = photo.reshape(1, 2048)
    in_text = '<start>'
    for _ in range(max_length):
        # Words missing from the vocabulary map to index 0.
        sequence = [words_to_indices.get(s, 0) for s in in_text.split(" ")]
        sequence = pad_sequences([sequence], maxlen=max_length, padding='post')
        y_pred = model.predict([photo, sequence], verbose=0)
        word = indices_to_words.get(int(np.argmax(y_pred[0])), 'Unk')
        in_text += ' ' + word
        if word == '<end>':
            break

    # Drop the leading '<start>' marker.
    final = in_text.split()[1:]
    # Only strip the trailing token when it really is the end marker; if the
    # loop ran out of steps without emitting '<end>', the last token is a
    # genuine caption word and must be kept.
    if final and final[-1] == '<end>':
        final = final[:-1]
    return " ".join(final)
def _top_k_extensions(photo, text, score, k):
    """Return the ``k`` best one-word extensions of ``text``.

    Each result is an ``(extended_text, cumulative_score)`` tuple, where the
    cumulative score is ``score`` multiplied by the model's output for the
    appended word. Results are ordered from highest to lowest model output.
    """
    # Words missing from the vocabulary map to index 0.
    sequence = [words_to_indices.get(s, 0) for s in text.split(" ")]
    sequence = pad_sequences([sequence], maxlen=max_length, padding='post')
    scores = model.predict([photo, sequence], verbose=0).reshape(-1)
    # A stable sort on the negated scores yields descending order while
    # keeping the lower index first among ties; slicing also copes with a
    # ``k`` larger than the vocabulary.
    best = np.argsort(-scores, kind="stable")[:k]
    return [
        (text + ' ' + indices_to_words.get(int(i), 'Unk'), scores[i] * score)
        for i in best
    ]


def beam_search(photo, k=3):
    """Generate a caption using beam search.

    Args:
        photo: Image feature array with 2048 elements (it is reshaped to
            ``(1, 2048)`` before being passed to the model).
        k: Beam width, i.e. how many partial captions are kept per step.

    Returns:
        The highest-scoring caption as a space-separated string, without the
        ``<start>`` and ``<end>`` marker tokens.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError("beam width k must be at least 1")

    photo = photo.reshape(1, 2048)

    # Seed the beam with the k best first words.
    beams = _top_k_extensions(photo, '<start>', 1.0, k)

    for _ in range(max_length):
        # Nothing left to expand once every hypothesis has ended.
        if all(text.split(" ")[-1] == "<end>" for text, _ in beams):
            break

        candidates = []
        for text, score in beams:
            if text.split(" ")[-1] == "<end>":
                # Keep finished hypotheses in the running instead of
                # discarding them (and every beam after them).
                candidates.append((text, score))
                continue
            candidates.extend(_top_k_extensions(photo, text, score, k))

        beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:k]

    # Drop the leading '<start>' marker.
    final = beams[0][0].split()[1:]
    # Only strip the trailing token when it really is the end marker, so a
    # caption that hit the length limit keeps its last word.
    if final and final[-1] == '<end>':
        final = final[:-1]
    return " ".join(final)
def generate_caption(img, search_method="Greedy Search"):
    """Generate a caption for an image; entry point for the Gradio UI.

    Args:
        img: The uploaded PIL image, or ``None`` when nothing was uploaded.
        search_method: ``"Greedy Search"`` selects greedy decoding; any other
            value selects beam search with a beam width of 3.

    Returns:
        The generated caption, or a human-readable message when no image was
        supplied or caption generation failed.
    """
    # Nothing was uploaded: tell the user what to do instead of surfacing an
    # AttributeError from the feature extractor.
    if img is None:
        return "Please upload an image first."

    try:
        features = extract_features(img)
        if search_method == "Greedy Search":
            caption = greedy_search(features)
        else:  # Beam Search
            caption = beam_search(features, k=3)
        return caption
    except Exception as e:
        # UI boundary: report the failure in the output box rather than
        # letting the exception propagate out of the callback.
        return f"Error generating caption: {str(e)}"
# Input widgets: the picture to caption and the decoding strategy to use.
image_input = gr.Image(type="pil", label="Upload Image")
method_input = gr.Radio(
    choices=["Greedy Search", "Beam Search (k=3)"],
    value="Greedy Search",
    label="Caption Generation Method",
)

# Output widget showing the text returned by generate_caption.
caption_output = gr.Textbox(label="Generated Caption")

# Wire the widgets to the captioning function.
demo = gr.Interface(
    fn=generate_caption,
    inputs=[image_input, method_input],
    outputs=caption_output,
    title="Image Captioning with LSTM",
    description="Upload an image to generate a descriptive caption. Choose between Greedy Search or Beam Search for caption generation.",
    examples=[
        # Add example images if you have them
    ],
    theme=gr.themes.Soft(),
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()