# Geraldine's picture
# Update app.py
# 79ad491 verified
import gradio as gr
from dataclasses import dataclass
from typing import List, Tuple, Dict
import pandas as pd
import numpy as np
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
import json
# Load the corpus records once, at import time, into the module-level DATA.
# build_corpus() reads the 'id', 'year', 'title' and 'abstract' keys of each
# record.
with open("/app/data.json", "r", encoding="utf-8") as data_file:
    DATA = json.load(data_file)
@dataclass
class Doc:
    """One searchable document: an identifier plus the text that is indexed."""

    # Identifier of the source record.
    id: int
    # Text handed to the retrievers.
    text: str
def normalize_text(s: str) -> str:
    """Return *s* in lowercase (the only normalization applied)."""
    lowered = s.lower()
    return lowered
def build_corpus() -> List[Doc]:
    """Convert every record of DATA into a Doc.

    The indexed text is "<year>. <title>. <abstract>", passed through
    normalize_text(); the Doc id is the record's 'id' field.

    Returns:
        One Doc per record, in DATA order.
    """

    def _as_doc(record) -> Doc:
        # Concatenate the searchable fields first, then normalize once.
        combined = f"{record['year']}. {record['title']}. {record['abstract']}"
        return Doc(id=record['id'], text=normalize_text(combined))

    return [_as_doc(record) for record in DATA]
def bm25_search(corpus: List["Doc"], query: str, k: int = 5) -> List[Tuple[int, float]]:
    """Rank *corpus* against *query* with BM25 and return the top *k* hits.

    Documents and query are tokenized by whitespace splitting.

    Args:
        corpus: Documents to search; each needs ``.id`` and ``.text``.
        query: Query string (expected to be normalized like the documents).
        k: Maximum number of hits to return.

    Returns:
        ``(doc_id, score)`` pairs, best score first. Empty when the corpus is
        empty or ``k`` is not positive.
    """
    # BM25Okapi cannot be built from an empty corpus (it averages document
    # length over the corpus size), and a non-positive k would make the
    # slice below misbehave (negative k drops the tail instead of limiting).
    if not corpus or k <= 0:
        return []

    # NOTE: the BM25 index is rebuilt on every call; acceptable for a small
    # demo corpus, worth caching for a large one.
    tokenized_corpus = [doc.text.split() for doc in corpus]
    bm25 = BM25Okapi(tokenized_corpus)
    scores = bm25.get_scores(query.split())

    # argsort is ascending, so reverse it before taking the first k positions.
    top_positions = np.argsort(scores)[::-1][:k]
    return [(corpus[i].id, float(scores[i])) for i in top_positions]
def show(results, title="Résultats"):
    """Render ranked ``(doc_id, score)`` pairs as a DataFrame for display.

    Args:
        results: Iterable of ``(doc_id, score)`` pairs, already in rank order.
        title: Label for the result set. It is accepted so call sites can name
            the stage they display, but it is not used in the output.

    Returns:
        A DataFrame with columns ``rank``, ``id``, ``title`` and ``score``
        (score rounded to 4 decimals). The columns are present even when
        *results* is empty, so the table keeps its headers.

    Raises:
        KeyError: If a ``doc_id`` has no matching record in DATA.
    """
    columns = ['rank', 'id', 'title', 'score']
    rows = []
    for rank, (doc_id, score) in enumerate(results, start=1):
        # First record whose id matches; None instead of a leaked
        # StopIteration when the id is unknown.
        row = next((item for item in DATA if item['id'] == doc_id), None)
        if row is None:
            raise KeyError(f"no record with id {doc_id!r} in DATA")
        rows.append({
            'rank': rank,
            'id': doc_id,
            'title': row['title'],
            'score': round(score, 4)
        })
    return pd.DataFrame(rows, columns=columns)
class DenseIndex:
    """Dense (embedding-based) retriever over a list of documents.

    Texts are embedded with a SentenceTransformer model using normalized
    embeddings, and ranked by inner product. A FAISS ``IndexFlatIP`` is used
    when ``faiss`` can be imported and built; otherwise the embeddings are
    kept as a NumPy matrix and searched with a matrix-vector product.
    """

    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """Load the embedding model; call build() before search()."""
        self.model = SentenceTransformer(model_name)
        self.index = None        # FAISS index, when available
        self.embeddings = None   # NumPy fallback matrix, one row per document
        # Ids of the indexed documents, in index order. search() maps hit
        # positions through this list rather than through a global table.
        self.doc_ids: List[int] = []

    def encode(self, texts: List[str]):
        """Embed *texts* and return the normalized vectors as a float32 array."""
        vecs = self.model.encode(texts, normalize_embeddings=True, show_progress_bar=False)
        return np.asarray(vecs, dtype='float32')

    def build(self, docs: List["Doc"]):
        """Embed *docs* (objects with ``.id`` and ``.text``) and index them."""
        self.doc_ids = [d.id for d in docs]
        X = self.encode([d.text for d in docs])
        try:
            import faiss  # type: ignore
            index = faiss.IndexFlatIP(X.shape[1])
            index.add(X)
        except Exception:
            # Deliberate best-effort: any failure to import or build the FAISS
            # index (including an empty document list) falls back to NumPy.
            self.index = None
            self.embeddings = X
        else:
            self.index = index
            self.embeddings = None

    def search(self, query: str, k: int = 5) -> List[Tuple[int, float]]:
        """Return up to *k* ``(doc_id, score)`` pairs for *query*, best first.

        Raises:
            RuntimeError: If called before build().
        """
        if self.index is None and self.embeddings is None:
            raise RuntimeError("DenseIndex.search() called before build()")

        # Never ask for more hits than there are documents: FAISS pads missing
        # neighbours with index -1, which would otherwise select a wrong doc.
        k = min(int(k), len(self.doc_ids))
        if k <= 0:
            return []

        q = self.encode([query])
        if self.index is not None:
            D, I = self.index.search(q, k)
            scores = D[0].tolist()
            idxs = I[0].tolist()
        else:
            sims = self.embeddings @ q[0]
            idxs = np.argsort(sims)[::-1][:k].tolist()
            scores = sims[idxs].tolist()

        # Defensive filter on negative positions (see the padding note above).
        return [(self.doc_ids[i], float(s)) for i, s in zip(idxs, scores) if i >= 0]
def rrf_fusion(results: Dict[str, List[Tuple[int, float]]], k: int = 5, K: int = 60) -> List[Tuple[int, float]]:
    """Fuse several ranked lists with Reciprocal Rank Fusion (RRF).

    Each document scores ``sum(1 / (K + rank))`` over the systems that
    returned it, with ranks starting at 1. The input scores are ignored; only
    positions matter.

    Args:
        results: Mapping of system name to its ranked ``(doc_id, score)`` list.
        k: Number of fused results to return.
        K: RRF smoothing constant added to every rank.

    Returns:
        The top *k* ``(doc_id, fused_score)`` pairs, best first. Ties keep the
        order in which documents were first seen (system order, then rank),
        so the output is deterministic.
    """
    fused: Dict[int, float] = {}
    for ranking in results.values():
        seen_in_system = set()
        for rank, (doc_id, _score) in enumerate(ranking, start=1):
            # A document repeated inside one ranking counts once, at its best
            # (first) position.
            if doc_id in seen_in_system:
                continue
            seen_in_system.add(doc_id)
            fused[doc_id] = fused.get(doc_id, 0.0) + 1.0 / (K + rank)

    # sorted() is stable, so equal scores stay in first-seen order.
    ordered = sorted(fused.items(), key=lambda item: item[1], reverse=True)
    return ordered[:k]
def rerank_cross_encoder(query: str, doc_ids: List[int]):
    """Score the given documents against *query* with a cross-encoder.

    Args:
        query: Query string.
        doc_ids: Ids of the candidate documents (looked up in DATA).

    Returns:
        ``(doc_id, score)`` pairs in DATA order (not sorted by score). Empty
        when *doc_ids* is empty. If the reranker cannot be loaded or run, the
        scores are seeded pseudo-random values so the demo keeps working.
    """
    # Nothing to score: return before loading the (heavy) reranker model.
    if not doc_ids:
        return []

    wanted = set(doc_ids)
    candidates = [d for d in DATA if d['id'] in wanted]
    id_order = [d['id'] for d in candidates]
    pairs = [(query, normalize_text(f"{d['title']}. {d['abstract']}")) for d in candidates]

    try:
        from FlagEmbedding import FlagReranker
        reranker = FlagReranker('BAAI/bge-reranker-base', use_fp16=True)
        scores = reranker.compute_score(pairs, normalize=True)
        # Some FlagEmbedding versions return a bare float for a single pair;
        # wrap it so it is not mistaken for a failure below.
        if not isinstance(scores, (list, tuple, np.ndarray)):
            scores = [scores]
    except Exception as e:
        # Deliberate best-effort for the demo: report the cause, then fall
        # back to reproducible random scores.
        print("Reranker indisponible (fallback aléatoire pour la démo).")
        print(f"Cause: {e!r}")
        rng = np.random.default_rng(123)
        scores = rng.random(len(pairs))

    return list(zip(id_order, [float(s) for s in scores]))
def orchestrate_search(query: str, k: int = 5, do_rerank: bool = False):
    """Run the full pipeline for one query and return one table per stage.

    Stages: BM25 search, dense search, RRF fusion of the two, and optionally
    a cross-encoder rerank of the fused hits. Uses the module-level ``corpus``
    and ``dense`` objects.

    Args:
        query: Raw user query; it is normalized before searching.
        k: Number of hits requested from each stage.
        do_rerank: Whether to run the reranking stage.

    Returns:
        ``(sparse_df, dense_df, hybrid_df, rerank_df)``; ``rerank_df`` is an
        empty DataFrame when reranking is disabled.
    """
    normalized_query = normalize_text(query)

    # Sparse retrieval (BM25).
    sparse_hits = bm25_search(corpus, normalized_query, k)
    sparse_df = show(sparse_hits, title="BM25")

    # Dense retrieval (embeddings).
    dense_hits = dense.search(normalized_query, k)
    dense_df = show(dense_hits, title="Dense (Embeddings)")

    # Rank-based fusion of the two hit lists.
    fused_hits = rrf_fusion({"sparse": sparse_hits, "dense": dense_hits}, k)
    hybrid_df = show(fused_hits, title="Fusion Hybride (RRF)")

    if not do_rerank:
        return sparse_df, dense_df, hybrid_df, pd.DataFrame()

    # Rerank only the fused candidates, then order them by reranker score.
    candidate_ids = [doc_id for doc_id, _ in fused_hits]
    reranked = sorted(
        rerank_cross_encoder(normalized_query, candidate_ids),
        key=lambda pair: pair[1],
        reverse=True,
    )
    rerank_df = show(reranked, title="Reranking (cross-encoder)")
    return sparse_df, dense_df, hybrid_df, rerank_df
def gradio_interface(query: str, k: int, do_rerank: bool):
    """Gradio callback: run the search pipeline on the widget values.

    Args:
        query: Content of the query textbox (may arrive as ``None``).
        k: Slider value for the number of results.
        do_rerank: State of the reranking checkbox.

    Returns:
        The four DataFrames produced by orchestrate_search(), in the order
        sparse, dense, hybrid, reranked.
    """
    # Coerce at the UI boundary: Gradio documents the Slider value as a float,
    # and the pipeline slices with k, which requires an int. An absent
    # textbox value is treated as an empty query.
    return orchestrate_search(query or "", int(k), bool(do_rerank))
def _results_table():
    """Create one result table with the columns produced by show()."""
    return gr.DataFrame(
        headers=["rank", "id", "title", "score"],
        datatype=["number", "number", "str", "number"],
    )


with gr.Blocks() as demo:
    # Build the corpus and the dense index once, at startup; the search
    # callback reads them as module-level names.
    corpus = build_corpus()
    dense = DenseIndex()
    dense.build(corpus)

    gr.Markdown("# Hybrid Search Pipeline Demo")

    # NOTE(review): the source reached review with its indentation stripped,
    # so which widgets belong inside gr.Row() is an assumption (the three
    # inputs in the row, the button below it) - confirm against the original.
    with gr.Row():
        query_input = gr.Textbox(label="Query", placeholder="Enter your search query here...")
        k_input = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Number of results (k)")
        rerank_checkbox = gr.Checkbox(label="Enable Reranking (cross-encoder)", value=False)
    search_button = gr.Button("Run Search")

    # One tab per pipeline stage, plus a tab showing the raw corpus.
    with gr.Tabs():
        with gr.TabItem("BM25 (Sparse)"):
            bm25_output = _results_table()
        with gr.TabItem("Dense (Embeddings)"):
            dense_output = _results_table()
        with gr.TabItem("Hybrid (RRF Fusion)"):
            hybrid_output = _results_table()
        with gr.TabItem("Reranked Results"):
            rerank_output = _results_table()
        with gr.TabItem("Corpus Data"):
            gr.DataFrame(pd.DataFrame(DATA), label="Original Corpus Data")

    # Wire the button to the pipeline; outputs follow the callback's return order.
    search_button.click(
        gradio_interface,
        inputs=[query_input, k_input, rerank_checkbox],
        outputs=[bm25_output, dense_output, hybrid_output, rerank_output],
    )

demo.launch()