Prabhat9801 committed on
Commit
b9eae0a
·
verified ·
1 Parent(s): 814d167

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +134 -0
app.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import pickle
4
+ from keras.models import load_model
5
+ from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
6
+ from tensorflow.keras.preprocessing import image
7
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
8
+
9
def _read_pickle(path):
    """Return the object stored in the pickle file at ``path``."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Trained captioning network (image features + partial caption -> next word).
model = load_model('image_captioning_model.h5')

# ImageNet ResNet50 without its classifier head, used as the feature extractor.
resnet_model = ResNet50(
    include_top=False,
    weights='imagenet',
    pooling='avg',
    input_shape=(224, 224, 3),
)

# Vocabulary lookups in both directions, plus the saved config values.
words_to_indices = _read_pickle('words_to_indices.pkl')
indices_to_words = _read_pickle('indices_to_words.pkl')
config = _read_pickle('config.pkl')
vocab_size, max_length = config['vocab_size'], config['max_length']
26
+
27
def extract_features(img):
    """Extract a ResNet50 feature vector from a PIL image.

    Args:
        img: PIL image (the Gradio input component uses ``type="pil"``).

    Returns:
        The output of ``resnet_model.predict`` for this single image, with
        singleton dimensions squeezed out.
    """
    # Uploads may be RGBA, palette or grayscale; the feature extractor needs
    # exactly three channels, so normalise the mode before resizing.
    img = img.convert("RGB").resize((224, 224))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)  # add the batch dimension
    x = preprocess_input(x)
    features = resnet_model.predict(x, verbose=0)
    return features.squeeze()
35
+
36
def greedy_search(photo):
    """Generate a caption by always taking the most probable next word.

    Args:
        photo: image feature vector; it is reshaped to ``(1, 2048)``.

    Returns:
        The caption as a space-joined string without the ``<start>`` and
        ``<end>`` markers.
    """
    photo = photo.reshape(1, 2048)
    words = ['<start>']

    for _ in range(max_length):
        # Words missing from the vocabulary map to index 0.
        sequence = [words_to_indices.get(w, 0) for w in words]
        sequence = pad_sequences([sequence], maxlen=max_length, padding='post')
        y_pred = model.predict([photo, sequence], verbose=0)
        word = indices_to_words.get(np.argmax(y_pred[0]), 'Unk')
        words.append(word)
        if word == '<end>':
            break

    caption = words[1:]  # drop '<start>'
    # Only strip the end marker when it is really there; if the loop hit
    # max_length without '<end>', the last token is a real caption word.
    if caption and caption[-1] == '<end>':
        caption.pop()
    return " ".join(caption)
54
+
55
def _expand_beam(photo, text, score, k):
    """Return the k most probable one-word extensions of ``text`` as (text, score) pairs."""
    sequence = [words_to_indices.get(s, 0) for s in text.split(" ")]
    sequence = pad_sequences([sequence], maxlen=max_length, padding='post')
    probs = model.predict([photo, sequence], verbose=0).reshape(-1)
    # Stable sort on the negated scores: highest probability first, ties keep
    # the lower vocabulary index first.
    top = np.argsort(-probs, kind='stable')[:k]
    return [
        (text + ' ' + indices_to_words.get(int(i), 'Unk'), float(probs[i]) * score)
        for i in top
    ]


def beam_search(photo, k=3):
    """Generate a caption using beam search.

    Each hypothesis is scored by the product of the probabilities of its
    words. A hypothesis that has produced ``<end>`` is kept as a candidate
    unchanged, so it keeps competing with the hypotheses still being
    extended.

    Args:
        photo: image feature vector; it is reshaped to ``(1, 2048)``.
        k: beam width, i.e. the number of hypotheses kept per step.

    Returns:
        The best-scoring caption as a space-joined string without the
        ``<start>`` and ``<end>`` markers.

    Raises:
        ValueError: if ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError("k must be at least 1")

    photo = photo.reshape(1, 2048)
    beams = _expand_beam(photo, '<start>', 1.0, k)

    # One word was already generated above. Stopping at max_length - 1 further
    # steps keeps every model input within max_length tokens, so padding never
    # has to truncate '<start>' off the front.
    for _step in range(max_length - 1):
        candidates = []
        all_finished = True
        for text, score in beams:
            if text.split(" ")[-1] == "<end>":
                # Finished hypothesis: carry it forward instead of dropping it.
                candidates.append((text, score))
                continue
            all_finished = False
            candidates.extend(_expand_beam(photo, text, score, k))
        if all_finished:
            break
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:k]

    best = beams[0][0].split()[1:]  # drop '<start>'
    # Only strip the end marker when it is really there.
    if best and best[-1] == '<end>':
        best.pop()
    return " ".join(best)
96
+
97
def generate_caption(img, search_method="Greedy Search"):
    """Generate a caption for an uploaded image.

    Args:
        img: PIL image from the UI, or ``None`` when nothing was uploaded.
        search_method: ``"Greedy Search"`` selects greedy decoding; any other
            value selects beam search with ``k=3``.

    Returns:
        The generated caption, or a human-readable message when no image was
        supplied or caption generation failed.
    """
    if img is None:
        # Submitting without an upload passes None; say so plainly rather
        # than surfacing an AttributeError message.
        return "Please upload an image first."

    try:
        features = extract_features(img)
        if search_method == "Greedy Search":
            return greedy_search(features)
        return beam_search(features, k=3)
    except Exception as e:
        # UI boundary: keep the app responsive, but record the traceback so
        # failures can be diagnosed from the logs.
        import logging

        logging.getLogger(__name__).exception("Caption generation failed")
        return f"Error generating caption: {str(e)}"
112
+
113
# Assemble the Gradio UI: one image input, a decoding-method selector, and a
# text box for the resulting caption.
image_input = gr.Image(type="pil", label="Upload Image")
method_input = gr.Radio(
    choices=["Greedy Search", "Beam Search (k=3)"],
    value="Greedy Search",
    label="Caption Generation Method",
)
caption_output = gr.Textbox(label="Generated Caption")

demo = gr.Interface(
    fn=generate_caption,
    inputs=[image_input, method_input],
    outputs=caption_output,
    title="Image Captioning with LSTM",
    description="Upload an image to generate a descriptive caption. Choose between Greedy Search or Beam Search for caption generation.",
    examples=[],  # no example images are bundled yet
    theme=gr.themes.Soft(),
)

# Start the web server only when run as a script.
if __name__ == "__main__":
    demo.launch()