# Source: Hugging Face Space file "app.py" (author: OGOGOG), commit d52e547
# ("Rename dataset.py to app.py"), 7.57 kB.
# ========================
# 1) Install deps (Colab/Space)
# ========================
# `!pip install ...` is IPython/Colab shell magic and is a SyntaxError when this
# file runs as a plain script (e.g. app.py on a Hugging Face Space, where the
# packages belong in requirements.txt). Installing only what is missing works
# in both environments and is a no-op when everything is already present.
import importlib.util
import subprocess
import sys

_REQUIRED_PACKAGES = {  # import name -> pip distribution name
    "datasets": "datasets",
    "sentence_transformers": "sentence-transformers",
    "gradio": "gradio",
}
_missing = [
    pip_name
    for mod_name, pip_name in _REQUIRED_PACKAGES.items()
    if importlib.util.find_spec(mod_name) is None
]
if _missing:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", *_missing])
# ========================
# 2) Imports
# ========================
import os
import re

import gradio as gr
import numpy as np
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
# ========================
# 3) Load dataset
# ========================
# Override without editing the file, e.g. `DATASET_ID=user/other_set python app.py`.
DATASET_ID = os.environ.get("DATASET_ID", "motimmom/cocktails_clean_nobrand")
ds = load_dataset(DATASET_ID, split="train")
# ========================
# 4) Tagging rules (base + flavors)
# ========================
# Patterns are applied with re.search to lower-cased text.
# The base tagger takes the FIRST key below whose patterns match, so this
# dict's order is a priority order (e.g. a recipe mentioning both vodka and
# gin is tagged "vodka").
BASE_SPIRITS = {
    "vodka": [r"\bvodka\b"],
    "gin": [r"\bgin\b"],
    "rum": [r"\brum\b", r"\bwhite rum\b", r"\bdark rum\b"],
    "tequila": [r"\btequila\b"],
    "whiskey": [r"\bwhisk(?:e|)y\b", r"\bbourbon\b", r"\bscotch\b", r"\brye\b"],
    "mezcal": [r"\bmezcal\b"],
    "brandy": [r"\bbrandy\b", r"\bcognac\b"],
    "vermouth": [r"\bvermouth\b"],
    "other": [r"\btriple sec\b", r"\bliqueur\b", r"\bcointreau\b", r"\baperol\b", r"\bcampari\b"],
}
# A recipe gets every flavor tag that has at least one matching pattern.
FLAVORS = {
    "citrus": [r"lime", r"lemon", r"grapefruit", r"orange", r"citrus"],
    "sweet": [r"simple syrup", r"sugar", r"honey", r"agave", r"maple", r"grenadine", r"vanilla", r"sweet"],
    "sour": [r"\bsour\b", r"lemon juice", r"lime juice", r"acid"],
    "bitter": [r"\bbitter", r"\bamaro\b", r"\bcampari\b", r"\baperol\b"],
    "smoky": [r"\bsmoky\b", r"\bsmoked\b", r"\bmezcal\b", r"\bpeated\b"],
    # Accept both the accented and the common unaccented spelling.
    "spicy": [r"spicy", r"chili", r"ginger", r"jalape[nñ]o", r"cayenne"],
    "herbal": [r"mint", r"basil", r"rosemary", r"thyme", r"herb", r"chartreuse"],
    "fruity": [r"pineapple", r"cranberr", r"strawberr", r"mango", r"passion", r"peach", r"fruit"],
    "creamy": [r"cream", r"coconut cream", r"egg white", r"creamy"],
    # Word-bounded "rose" so "rosemary" (an herb) is not tagged floral.
    "floral": [r"\broses?\b", r"violet", r"elderflower", r"lavender", r"floral"],
    "refreshing": [r"soda water", r"tonic", r"highball", r"collins", r"fizz", r"refreshing"],
    "boozy": [r"stirred", r"martini", r"old fashioned", r"\bboozy\b", r"\bstrong\b"],
}
def _clean(s):
return s.strip() if isinstance(s, str) else ""
def _extract_names_from_pairs(pairs):
out = []
if not pairs: return out
for p in pairs:
if isinstance(p, (list, tuple)) and len(p) >= 2 and p[1]:
name = str(p[1]).strip().lower()
if name and name not in out:
out.append(name)
return out
def _get_ingredients(row, cols):
# try a few common schemas
if "ingredient_tokens" in cols and row.get("ingredient_tokens"):
return [str(x).strip().lower() for x in row["ingredient_tokens"] if str(x).strip()]
if "ingredients" in cols and row.get("ingredients"):
vals = []
for x in row["ingredients"]:
if isinstance(x, str) and x.strip():
vals.append(x.strip().lower())
if vals: return vals
for k in ["ingredients_raw", "ingredient_list", "ingredients_list"]:
if k in cols and row.get(k):
names = _extract_names_from_pairs(row[k])
if names: return names
return []
def _get_instructions(row, cols):
    """Return the first non-blank instruction-like field of *row*, stripped, or "".

    Candidate columns are checked in a fixed order and only if present in *cols*;
    non-str values are treated as blank.
    """
    candidates = ("instructions", "howto", "preparation", "strInstructions", "recipe")
    for key in candidates:
        if key not in cols:
            continue
        text = _clean(row.get(key))
        if text:
            return text
    return ""
def tag_base(text):
    """Return the first BASE_SPIRITS key with a pattern found in *text* (case-insensitive).

    Keys are tried in dict order, so earlier spirits win ties; "other" is
    returned when nothing matches.
    """
    lowered = text.lower()
    return next(
        (
            spirit
            for spirit, patterns in BASE_SPIRITS.items()
            if any(re.search(pattern, lowered) for pattern in patterns)
        ),
        "other",
    )
def tag_flavors(text):
    """Return every FLAVORS key with at least one pattern found in *text* (case-insensitive), in dict order."""
    haystack = text.lower()
    return [
        flavor_name
        for flavor_name, patterns in FLAVORS.items()
        if any(re.search(pattern, haystack) for pattern in patterns)
    ]
# ========================
# 5) Build documents
# ========================
cols = ds.column_names


def _get_title(row):
    """Return the first non-blank of the row's title/name/cocktail_name fields, stripped.

    Each candidate is cleaned before it is tested, so a whitespace-only or
    non-string field falls through to the next key instead of yielding an
    empty title. Returns "Untitled" when none qualifies.
    """
    for key in ("title", "name", "cocktail_name"):
        title = _clean(row.get(key))
        if title:
            return title
    return "Untitled"


DOCS = []
for r in ds:
    title = _get_title(r)
    ingredients = _get_ingredients(r, cols)
    instructions = _get_instructions(r, cols)
    # The fused text is both what gets embedded and what the base/flavor
    # taggers scan, so words in the title or instructions can drive tags.
    fused = f"{title}\nIngredients: {', '.join(ingredients)}\nInstructions: {instructions}"
    DOCS.append({
        "title": title,
        "ingredients": ingredients,
        "instructions": instructions,
        "text": fused,
        "base": tag_base(fused),
        "flavors": tag_flavors(fused),
    })
# ========================
# 6) Embeddings (semantic search)
# ========================
encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# One row per DOCS entry, unit-normalized, so a plain dot product with a
# normalized query embedding equals cosine similarity.
embs = encoder.encode(
    [doc["text"] for doc in DOCS],
    normalize_embeddings=True,
    convert_to_numpy=True,
).astype("float32")
# ========================
# 7) Recommendation
# ========================
# Dropdown choices, in the same order as the tagging-rule dicts.
BASE_OPTIONS = [*BASE_SPIRITS]
FLAVOR_OPTIONS = [*FLAVORS]
# Added to the cosine similarity of recipes tagged with the requested flavor.
BOOST = 0.20
def _format_doc(d, score):
    """Render one DOCS entry as a Markdown section; instructions over 400 chars are truncated with "..."."""
    ing_txt = ", ".join(d["ingredients"]) if d["ingredients"] else "—"
    instr = d["instructions"] or "—"
    instr = instr[:400] + ("..." if len(instr) > 400 else "")
    meta = f"**Base:** {d['base']} | **Flavors:** {', '.join(d['flavors']) or '—'} | **Score:** {score:.3f}"
    return f"### {d['title']}\n{meta}\n\n**Ingredients:** {ing_txt}\n\n**Instructions:** {instr}"


def recommend(base_alcohol, flavor, top_k=3):
    """Return the top-k cocktails for a base spirit and flavor as Markdown.

    Candidates are the DOCS tagged with ``base_alcohol``; if there are none,
    all DOCS are used and the output starts with a notice saying so. Each
    candidate is scored by the dot product of its normalized embedding with a
    normalized query embedding, plus ``BOOST`` if it carries the ``flavor`` tag.

    Args:
        base_alcohol: one of BASE_OPTIONS.
        flavor: one of FLAVOR_OPTIONS.
        top_k: number of results; coerced with int() and clamped to >= 1.
            Values int() rejects (e.g. None) fall back to 3.

    Returns:
        A Markdown string: a validation/empty message, or result sections
        separated by horizontal rules.
    """
    if base_alcohol not in BASE_OPTIONS:
        return "Please pick a base alcohol."
    if flavor not in FLAVOR_OPTIONS:
        return "Please pick a flavor."
    if not DOCS:
        return "No matches found."

    # 1) Hard filter by base.
    idxs = [i for i, d in enumerate(DOCS) if d["base"] == base_alcohol]
    notice = ""
    if not idxs:
        # Fallback when the dataset has no exact base match; say so, because
        # the results will be built on other spirits.
        idxs = list(range(len(DOCS)))
        notice = f"_No {base_alcohol} cocktails found — showing the closest matches from all spirits._\n\n"

    # 2) Query text that steers the embedding toward the user's constraints.
    q_text = f"Base spirit: {base_alcohol}. Flavor: {flavor}. Cocktail recipe."
    q_emb = encoder.encode([q_text], normalize_embeddings=True, convert_to_numpy=True).astype("float32")[0]

    # 3) Cosine similarity == dot product, since both sides are normalized.
    sims = embs[idxs].dot(q_emb)

    # 4) Flavor boost for recipes tagged with the requested flavor.
    scores = [
        (float(sim) + (BOOST if flavor in DOCS[i]["flavors"] else 0.0), i)
        for sim, i in zip(sims, idxs)
    ]
    scores.sort(reverse=True)

    # 5) Build output (idxs is non-empty here, so at least one result exists).
    try:
        k = max(1, int(top_k))
    except (TypeError, ValueError, OverflowError):
        k = 3
    return notice + "\n\n---\n\n".join(_format_doc(DOCS[i], sc) for sc, i in scores[:k])
# ========================
# 8) Gradio UI
# ========================
with gr.Blocks() as demo:
    gr.Markdown("# 🍹 Cocktail Recommender — Pick a Base & Flavor")
    with gr.Row():
        base = gr.Dropdown(choices=BASE_OPTIONS, value="gin", label="Base alcohol")
        flavor = gr.Dropdown(choices=FLAVOR_OPTIONS, value="citrus", label="Flavor")
        topk = gr.Slider(1, 10, value=3, step=1, label="Number of recommendations")
    with gr.Row():
        ex1 = gr.Button("Gin + Citrus")
        ex2 = gr.Button("Rum + Fruity")
        ex3 = gr.Button("Mezcal + Smoky")
    out = gr.Markdown()
    controls = [base, flavor, topk]
    gr.Button("Recommend").click(recommend, controls, out)
    # Example buttons fill in the controls and then run the recommendation,
    # so one click shows results.
    ex1.click(lambda: ("gin", "citrus", 3), outputs=controls).then(recommend, controls, out)
    ex2.click(lambda: ("rum", "fruity", 3), outputs=controls).then(recommend, controls, out)
    ex3.click(lambda: ("mezcal", "smoky", 3), outputs=controls).then(recommend, controls, out)

if __name__ == "__main__":
    # No share=True: Hugging Face Spaces does not support it (Gradio warns and
    # ignores it), and Gradio already creates a share link by default in Colab.
    demo.launch()