Spaces:
Sleeping
Sleeping
# ========================
# 1) Install deps (Colab/Space)
# ========================
# `!pip` is IPython shell magic and a SyntaxError when this file runs as a plain
# Python script (e.g. a Space's app.py). Install through the current
# interpreter instead, and only for packages that are not importable yet.
import importlib
import importlib.util
import subprocess
import sys

_REQUIRED_PACKAGES = {  # import name -> pip distribution name
    "datasets": "datasets",
    "sentence_transformers": "sentence-transformers",
    "gradio": "gradio",
}
_missing = [
    dist for mod, dist in _REQUIRED_PACKAGES.items()
    if importlib.util.find_spec(mod) is None
]
if _missing:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", *_missing])
    # Let the import system notice the freshly installed packages.
    importlib.invalidate_caches()
| # ======================== | |
| # 2) Imports | |
| # ======================== | |
| import re, numpy as np, gradio as gr | |
| from datasets import load_dataset | |
| from sentence_transformers import SentenceTransformer | |
# ========================
# 3) Load dataset
# ========================
# Hugging Face Hub dataset repo id; point this at another dataset if needed.
DATASET_ID = "motimmom/cocktails_clean_nobrand"
ds = load_dataset(path=DATASET_ID, split="train")
# ========================
# 4) Tagging rules (base + flavors)
# ========================
# Base-spirit tag -> regex patterns searched in the lower-cased recipe text.
# Insertion order matters: tag_base returns the first spirit with a match.
BASE_SPIRITS = dict(
    vodka=[r"\bvodka\b"],
    gin=[r"\bgin\b"],
    rum=[r"\brum\b", r"\bwhite rum\b", r"\bdark rum\b"],
    tequila=[r"\btequila\b"],
    whiskey=[r"\bwhisk(?:e|)y\b", r"\bbourbon\b", r"\bscotch\b", r"\brye\b"],
    mezcal=[r"\bmezcal\b"],
    brandy=[r"\bbrandy\b", r"\bcognac\b"],
    vermouth=[r"\bvermouth\b"],
    other=[r"\btriple sec\b", r"\bliqueur\b", r"\bcointreau\b", r"\baperol\b", r"\bcampari\b"],
)
# Flavor tag -> regex patterns searched in the lower-cased recipe text.
# Most patterns are bare substrings (so "cranberr" covers cranberry/cranberries);
# word boundaries are used where a bare substring would misfire.
FLAVORS = {
    "citrus": [r"lime", r"lemon", r"grapefruit", r"orange", r"citrus"],
    "sweet": [r"simple syrup", r"sugar", r"honey", r"agave", r"maple", r"grenadine", r"vanilla", r"sweet"],
    "sour": [r"\bsour\b", r"lemon juice", r"lime juice", r"acid"],
    "bitter": [r"\bbitter", r"\bamaro\b", r"\bcampari\b", r"\baperol\b"],
    "smoky": [r"\bsmoky\b", r"\bsmoked\b", r"\bmezcal\b", r"\bpeated\b"],
    "spicy": [r"spicy", r"chili", r"ginger", r"jalapeño", r"cayenne"],
    "herbal": [r"mint", r"basil", r"rosemary", r"thyme", r"herb", r"chartreuse"],
    "fruity": [r"pineapple", r"cranberr", r"strawberr", r"mango", r"passion", r"peach", r"fruit"],
    "creamy": [r"cream", r"coconut cream", r"egg white", r"creamy"],
    # Word-bounded "rose": a bare substring also hit "rosemary" and "prosecco".
    "floral": [r"\brose(?:water|s)?\b", r"violet", r"elderflower", r"lavender", r"floral"],
    # "club soda" is the other common spelling of soda water.
    "refreshing": [r"soda water", r"club soda", r"tonic", r"highball", r"collins", r"fizz", r"refreshing"],
    "boozy": [r"stirred", r"martini", r"old fashioned", r"\bboozy\b", r"\bstrong\b"],
}
| def _clean(s): | |
| return s.strip() if isinstance(s, str) else "" | |
| def _extract_names_from_pairs(pairs): | |
| out = [] | |
| if not pairs: return out | |
| for p in pairs: | |
| if isinstance(p, (list, tuple)) and len(p) >= 2 and p[1]: | |
| name = str(p[1]).strip().lower() | |
| if name and name not in out: | |
| out.append(name) | |
| return out | |
| def _get_ingredients(row, cols): | |
| # try a few common schemas | |
| if "ingredient_tokens" in cols and row.get("ingredient_tokens"): | |
| return [str(x).strip().lower() for x in row["ingredient_tokens"] if str(x).strip()] | |
| if "ingredients" in cols and row.get("ingredients"): | |
| vals = [] | |
| for x in row["ingredients"]: | |
| if isinstance(x, str) and x.strip(): | |
| vals.append(x.strip().lower()) | |
| if vals: return vals | |
| for k in ["ingredients_raw", "ingredient_list", "ingredients_list"]: | |
| if k in cols and row.get(k): | |
| names = _extract_names_from_pairs(row[k]) | |
| if names: return names | |
| return [] | |
| def _get_instructions(row, cols): | |
| for k in ["instructions", "howto", "preparation", "strInstructions", "recipe"]: | |
| if k in cols and _clean(row.get(k)): | |
| return _clean(row[k]) | |
| return "" | |
def tag_base(text):
    """Return the first BASE_SPIRITS key that has a pattern found in *text*.

    The text is lower-cased before searching. Keys are tried in BASE_SPIRITS
    order, so an earlier spirit wins when several appear. Falls back to
    "other" when nothing matches.
    """
    lowered = text.lower()
    matches = (
        spirit
        for spirit, patterns in BASE_SPIRITS.items()
        if any(re.search(pattern, lowered) for pattern in patterns)
    )
    return next(matches, "other")
def tag_flavors(text):
    """Return every FLAVORS key with at least one pattern found in *text*.

    The text is lower-cased before searching; keys come back in FLAVORS order.
    """
    lowered = text.lower()
    return [
        name
        for name, patterns in FLAVORS.items()
        if any(re.search(pattern, lowered) for pattern in patterns)
    ]
# ========================
# 5) Build documents
# ========================
cols = ds.column_names


def _get_title(row):
    """Return the first non-blank title-like field of *row*, else "Untitled".

    Candidates are tried in order: ``title``, ``name``, ``cocktail_name``.
    Each is cleaned *before* the emptiness test, so a whitespace-only or
    non-string value falls through to the next column instead of yielding
    an empty heading.
    """
    for key in ("title", "name", "cocktail_name"):
        title = _clean(row.get(key))
        if title:
            return title
    return "Untitled"


# One record per recipe: display fields plus the fused text that is both
# embedded for semantic search and scanned by the regex taggers.
DOCS = []
for r in ds:
    title = _get_title(r)
    ingredients = _get_ingredients(r, cols)
    instructions = _get_instructions(r, cols)
    fused = f"{title}\nIngredients: {', '.join(ingredients)}\nInstructions: {instructions}"
    DOCS.append({
        "title": title,
        "ingredients": ingredients,
        "instructions": instructions,
        "text": fused,
        "base": tag_base(fused),
        "flavors": tag_flavors(fused),
    })
# ========================
# 6) Embeddings (semantic search)
# ========================
encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Vectors are unit-normalised, so a plain dot product later equals cosine similarity.
embs = encoder.encode(
    [doc["text"] for doc in DOCS],
    convert_to_numpy=True,
    normalize_embeddings=True,
).astype("float32")
# ========================
# 7) Recommendation
# ========================
BASE_OPTIONS = [*BASE_SPIRITS]  # dropdown choices, in tagging-rule order
FLAVOR_OPTIONS = [*FLAVORS]
BOOST = 0.20  # added to the similarity of recipes tagged with the requested flavor
def recommend(base_alcohol, flavor, top_k=3):
    """Return Markdown describing the best-matching cocktails for a base and flavor.

    Candidates are the DOCS tagged with *base_alcohol*, or all DOCS when none
    are. Each candidate's score is the dot product of its embedding with an
    embedded query sentence, plus BOOST when its flavor tags contain
    *flavor*. The ``max(1, int(top_k))`` highest-scoring recipes are rendered.

    Args:
        base_alcohol: One of BASE_OPTIONS.
        flavor: One of FLAVOR_OPTIONS.
        top_k: How many recipes to show; values below 1 still show one.

    Returns:
        A Markdown string, or a short message when an input is not
        recognised or nothing could be ranked.
    """
    # Guard clauses: reject anything that isn't a known dropdown value.
    if base_alcohol not in BASE_OPTIONS:
        return "Please pick a base alcohol."
    if flavor not in FLAVOR_OPTIONS:
        return "Please pick a flavor."

    candidates = [pos for pos, doc in enumerate(DOCS) if doc["base"] == base_alcohol]
    # No recipe carries this base tag: fall back to searching the whole corpus.
    candidates = candidates or list(range(len(DOCS)))

    # Phrase the query so its embedding leans toward the requested base and flavor.
    query = f"Base spirit: {base_alcohol}. Flavor: {flavor}. Cocktail recipe."
    query_vec = encoder.encode([query], normalize_embeddings=True, convert_to_numpy=True).astype("float32")[0]
    # Both sides are unit-normalised, so the dot product is cosine similarity.
    similarities = embs[candidates].dot(query_vec)

    ranked = sorted(
        (
            (float(sim) + (BOOST if flavor in DOCS[doc_idx]["flavors"] else 0.0), doc_idx)
            for sim, doc_idx in zip(similarities, candidates)
        ),
        reverse=True,
    )

    top = ranked[: max(1, int(top_k))]
    if not top:
        return "No matches found."

    sections = []
    for score, doc_idx in top:
        doc = DOCS[doc_idx]
        ingredients_md = ", ".join(doc["ingredients"]) if doc["ingredients"] else "—"
        steps = doc["instructions"] or "—"
        if len(steps) > 400:
            steps = steps[:400] + "..."
        flavors_md = ", ".join(doc["flavors"]) or "—"
        meta = f"**Base:** {doc['base']} | **Flavors:** {flavors_md} | **Score:** {score:.3f}"
        sections.append(
            f"### {doc['title']}\n{meta}\n\n**Ingredients:** {ingredients_md}\n\n**Instructions:** {steps}"
        )
    return "\n\n---\n\n".join(sections)
# ========================
# 8) Gradio UI
# ========================
with gr.Blocks() as demo:
    gr.Markdown("# 🍹 Cocktail Recommender — Pick a Base & Flavor")
    with gr.Row():
        base = gr.Dropdown(choices=BASE_OPTIONS, value="gin", label="Base alcohol")
        flavor = gr.Dropdown(choices=FLAVOR_OPTIONS, value="citrus", label="Flavor")
        topk = gr.Slider(1, 10, value=3, step=1, label="Number of recommendations")
    # Preset buttons only fill in the inputs; "Recommend" runs the search.
    with gr.Row():
        ex1 = gr.Button("Gin + Citrus")
        ex2 = gr.Button("Rum + Fruity")
        ex3 = gr.Button("Mezcal + Smoky")
    # Create the primary action before the output so it renders above the
    # results rather than underneath them.
    run_btn = gr.Button("Recommend")
    out = gr.Markdown()

    run_btn.click(recommend, inputs=[base, flavor, topk], outputs=out)
    ex1.click(lambda: ("gin", "citrus", 3), outputs=[base, flavor, topk])
    ex2.click(lambda: ("rum", "fruity", 3), outputs=[base, flavor, topk])
    ex3.click(lambda: ("mezcal", "smoky", 3), outputs=[base, flavor, topk])

# Start the server only when run as a script or notebook cell, so importing
# this module (e.g. from tests) builds `demo` without launching it.
if __name__ == "__main__":
    demo.launch(share=True)