"""Cocktail recommender: pick a base spirit and a flavor, get matching recipes.

Loads a cocktail dataset from the Hugging Face Hub, tags each recipe with a
base spirit and flavor keywords via regex rules, embeds every recipe with a
sentence-transformers model, and serves a Gradio UI that ranks recipes by
cosine similarity to a templated query plus a bonus for the chosen flavor.
"""

# ========================
# 1) Install deps (Colab/Space)
# ========================
# `!pip install ...` is IPython-only syntax and a SyntaxError in a plain .py
# file. Installing through the running interpreter works in notebooks and
# scripts alike, and is skipped when the packages are already importable
# (e.g. when the environment preinstalled them).
import importlib.util
import subprocess
import sys

_REQUIRED_PACKAGES = {  # import name -> pip distribution name
    "datasets": "datasets",
    "sentence_transformers": "sentence-transformers",
    "gradio": "gradio",
}
_missing = [
    dist
    for mod, dist in _REQUIRED_PACKAGES.items()
    if importlib.util.find_spec(mod) is None
]
if _missing:
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", *_missing], check=True
    )
    importlib.invalidate_caches()  # make freshly installed packages importable

# ========================
# 2) Imports
# ========================
import re

import gradio as gr
import numpy as np
from datasets import load_dataset
from sentence_transformers import SentenceTransformer

# ========================
# 3) Load dataset
# ========================
DATASET_ID = "motimmom/cocktails_clean_nobrand"  # <-- change if needed
ds = load_dataset(DATASET_ID, split="train")

# ========================
# 4) Tagging rules (base + flavors)
# ========================
# Base spirits are checked in this dict's order; the first key with a matching
# pattern wins (see tag_base), so earlier entries take priority.
BASE_SPIRITS = {
    "vodka": [r"\bvodka\b"],
    "gin": [r"\bgin\b"],
    "rum": [r"\brum\b", r"\bwhite rum\b", r"\bdark rum\b"],
    "tequila": [r"\btequila\b"],
    "whiskey": [r"\bwhisk(?:e|)y\b", r"\bbourbon\b", r"\bscotch\b", r"\brye\b"],
    "mezcal": [r"\bmezcal\b"],
    "brandy": [r"\bbrandy\b", r"\bcognac\b"],
    "vermouth": [r"\bvermouth\b"],
    "other": [r"\btriple sec\b", r"\bliqueur\b", r"\bcointreau\b", r"\baperol\b", r"\bcampari\b"],
}

# A recipe receives every flavor tag whose patterns match (see tag_flavors).
FLAVORS = {
    "citrus": [r"lime", r"lemon", r"grapefruit", r"orange", r"citrus"],
    "sweet": [r"simple syrup", r"sugar", r"honey", r"agave", r"maple", r"grenadine", r"vanilla", r"sweet"],
    "sour": [r"\bsour\b", r"lemon juice", r"lime juice", r"acid"],
    "bitter": [r"\bbitter", r"\bamaro\b", r"\bcampari\b", r"\baperol\b"],
    "smoky": [r"\bsmoky\b", r"\bsmoked\b", r"\bmezcal\b", r"\bpeated\b"],
    "spicy": [r"spicy", r"chili", r"ginger", r"jalapeño", r"cayenne"],
    "herbal": [r"mint", r"basil", r"rosemary", r"thyme", r"herb", r"chartreuse"],
    "fruity": [r"pineapple", r"cranberr", r"strawberr", r"mango", r"passion", r"peach", r"fruit"],
    "creamy": [r"cream", r"coconut cream", r"egg white", r"creamy"],
    # "rose" must not match inside "rosemary", which is an herbal ingredient.
    "floral": [r"\brose(?!mary)", r"violet", r"elderflower", r"lavender", r"floral"],
    "refreshing": [r"soda water", r"tonic", r"highball", r"collins", r"fizz", r"refreshing"],
    "boozy": [r"stirred", r"martini", r"old fashioned", r"\bboozy\b", r"\bstrong\b"],
}


def _compile_rules(rules):
    """Precompile each tag's regex patterns once, preserving the dict's key order."""
    return {tag: [re.compile(p) for p in pats] for tag, pats in rules.items()}


_BASE_RX = _compile_rules(BASE_SPIRITS)
_FLAVOR_RX = _compile_rules(FLAVORS)

# Separators used when an ingredient column holds one delimited string.
_ITEM_SPLIT_RX = re.compile(r"[,;\n]")


def _clean(s):
    """Return *s* stripped of surrounding whitespace, or "" if it is not a str."""
    return s.strip() if isinstance(s, str) else ""


def _as_items(value):
    """Return *value* as a list of items.

    A plain string is split on commas, semicolons and newlines. Iterating it
    directly would yield single characters. Other iterables are listed as-is,
    and a non-iterable scalar becomes a one-element list.
    """
    if isinstance(value, str):
        return _ITEM_SPLIT_RX.split(value)
    try:
        return list(value)
    except TypeError:
        return [value]


def _extract_names_from_pairs(pairs):
    """Return unique, lower-cased names taken from the 2nd element of each pair.

    Items that are not lists/tuples, are shorter than two elements, or have a
    falsy/blank second element are skipped. First-seen order is preserved.
    """
    names = (
        str(p[1]).strip().lower()
        for p in (pairs or [])
        if isinstance(p, (list, tuple)) and len(p) >= 2 and p[1]
    )
    return list(dict.fromkeys(n for n in names if n))


def _get_ingredients(row, cols):
    """Return lower-cased ingredient names from *row*, trying several schemas.

    The schemas are checked in order, and the first one yielding any names wins:
    `ingredient_tokens` (any items), then `ingredients` (string items only),
    then pair lists in `ingredients_raw` / `ingredient_list` /
    `ingredients_list`. Returns [] if none apply.
    """
    if "ingredient_tokens" in cols and row.get("ingredient_tokens"):
        tokens = [
            str(x).strip().lower()
            for x in _as_items(row["ingredient_tokens"])
            if str(x).strip()
        ]
        if tokens:
            return tokens
    if "ingredients" in cols and row.get("ingredients"):
        vals = [
            x.strip().lower()
            for x in _as_items(row["ingredients"])
            if isinstance(x, str) and x.strip()
        ]
        if vals:
            return vals
    for k in ["ingredients_raw", "ingredient_list", "ingredients_list"]:
        if k in cols and row.get(k):
            names = _extract_names_from_pairs(row[k])
            if names:
                return names
    return []


def _get_instructions(row, cols):
    """Return the first non-blank instructions-like field of *row*, stripped, or ""."""
    for k in ["instructions", "howto", "preparation", "strInstructions", "recipe"]:
        if k in cols:
            text = _clean(row.get(k))
            if text:
                return text
    return ""


def tag_base(text):
    """Return the first BASE_SPIRITS key (in dict order) matching *text*, else "other"."""
    t = text.lower()
    for base, rxs in _BASE_RX.items():
        if any(rx.search(t) for rx in rxs):
            return base
    return "other"


def tag_flavors(text):
    """Return every FLAVORS key with at least one pattern matching *text*, in dict order."""
    t = text.lower()
    return [flv for flv, rxs in _FLAVOR_RX.items() if any(rx.search(t) for rx in rxs)]


# ========================
# 5) Build documents
# ========================
def _build_doc(row, cols):
    """Turn one dataset row into a recipe dict with fused text and rule-based tags."""
    raw_title = row.get("title") or row.get("name") or row.get("cocktail_name") or "Untitled"
    # Fall back after cleaning so a blank or non-string title still gets a label.
    title = _clean(raw_title) or "Untitled"
    ingredients = _get_ingredients(row, cols)
    instructions = _get_instructions(row, cols)
    fused = f"{title}\nIngredients: {', '.join(ingredients)}\nInstructions: {instructions}"
    return {
        "title": title,
        "ingredients": ingredients,
        "instructions": instructions,
        "text": fused,
        "base": tag_base(fused),
        "flavors": tag_flavors(fused),
    }


cols = ds.column_names
DOCS = [_build_doc(r, cols) for r in ds]

# ========================
# 6) Embeddings (semantic search)
# ========================
encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
if DOCS:
    embs = encoder.encode(
        [d["text"] for d in DOCS], normalize_embeddings=True, convert_to_numpy=True
    ).astype("float32")  # rows are unit-length, so a dot product is cosine similarity
else:
    # Keep a well-formed 2-D matrix so indexing in recommend() stays valid.
    embs = np.zeros((0, encoder.get_sentence_embedding_dimension() or 0), dtype="float32")

# ========================
# 7) Recommendation
# ========================
BASE_OPTIONS = list(BASE_SPIRITS.keys())
FLAVOR_OPTIONS = list(FLAVORS.keys())
BOOST = 0.20  # added to the cosine score of recipes tagged with the chosen flavor
MAX_INSTRUCTIONS_CHARS = 400  # instructions longer than this are truncated with "..."


def _format_doc(d, score):
    """Render one recipe dict and its score as a Markdown card."""
    ing_txt = ", ".join(d["ingredients"]) or "—"
    instr = d["instructions"] or "—"
    if len(instr) > MAX_INSTRUCTIONS_CHARS:
        instr = instr[:MAX_INSTRUCTIONS_CHARS] + "..."
    flavors = ", ".join(d["flavors"]) or "—"
    meta = f"**Base:** {d['base']} | **Flavors:** {flavors} | **Score:** {score:.3f}"
    return f"### {d['title']}\n{meta}\n\n**Ingredients:** {ing_txt}\n\n**Instructions:** {instr}"


def recommend(base_alcohol, flavor, top_k=3):
    """Return Markdown for the top-k recipes matching a base spirit and flavor.

    Args:
        base_alcohol: one of BASE_OPTIONS.
        flavor: one of FLAVOR_OPTIONS.
        top_k: number of recipes to show; coerced with int(), minimum 1.

    Returns:
        A Markdown string: the recipe cards separated by rules, or a short
        message when the input is invalid or no recipes exist.
    """
    if base_alcohol not in BASE_OPTIONS:
        return "Please pick a base alcohol."
    if flavor not in FLAVOR_OPTIONS:
        return "Please pick a flavor."

    # 1) Hard filter by base; fall back to the whole corpus (and say so).
    idxs = [i for i, d in enumerate(DOCS) if d["base"] == base_alcohol]
    note = ""
    if not idxs:
        idxs = list(range(len(DOCS)))
        note = (
            f"_No recipes tagged **{base_alcohol}**; "
            "showing the closest matches from all bases._\n\n"
        )
    if not idxs:
        return "No matches found."

    # 2) Query text that steers the embedding toward the user's constraints.
    q_text = f"Base spirit: {base_alcohol}. Flavor: {flavor}. Cocktail recipe."
    q_emb = encoder.encode(
        [q_text], normalize_embeddings=True, convert_to_numpy=True
    ).astype("float32")[0]

    # 3) Cosine similarity (dot product, since both sides are normalized),
    # 4) plus a flat bonus for recipes carrying the requested flavor tag.
    sims = embs[idxs] @ q_emb
    has_flavor = np.array([flavor in DOCS[i]["flavors"] for i in idxs], dtype=bool)
    scores = sims + BOOST * has_flavor

    # 5) Highest scores first.
    k = max(1, int(top_k))
    order = np.argsort(-scores, kind="stable")[:k]
    blocks = [_format_doc(DOCS[idxs[j]], float(scores[j])) for j in order]
    return note + "\n\n---\n\n".join(blocks)


# ========================
# 8) Gradio UI
# ========================
with gr.Blocks() as demo:
    gr.Markdown("# 🍹 Cocktail Recommender — Pick a Base & Flavor")
    with gr.Row():
        base = gr.Dropdown(choices=BASE_OPTIONS, value="gin", label="Base alcohol")
        flavor = gr.Dropdown(choices=FLAVOR_OPTIONS, value="citrus", label="Flavor")
    topk = gr.Slider(1, 10, value=3, step=1, label="Number of recommendations")
    with gr.Row():
        ex1 = gr.Button("Gin + Citrus")
        ex2 = gr.Button("Rum + Fruity")
        ex3 = gr.Button("Mezcal + Smoky")
    out = gr.Markdown()

    inputs = [base, flavor, topk]
    gr.Button("Recommend").click(recommend, inputs, out)
    # Preset buttons fill the inputs, then immediately show the results.
    ex1.click(lambda: ("gin", "citrus", 3), outputs=inputs).then(recommend, inputs, out)
    ex2.click(lambda: ("rum", "fruity", 3), outputs=inputs).then(recommend, inputs, out)
    ex3.click(lambda: ("mezcal", "smoky", 3), outputs=inputs).then(recommend, inputs, out)

if __name__ == "__main__":
    demo.launch(share=True)