Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import numpy as np | |
| import gradio as gr | |
| from datasets import load_dataset | |
| from sentence_transformers import SentenceTransformer | |
# ========================
# Config
# ========================
DATASET_ID = "motimmom/cocktails_clean_nobrand"
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
# Additive bonus applied to a cocktail's similarity score when its flavor
# tags contain the flavor the user picked.
FLAVOR_BOOST = 0.20
# Background image: bar.jpg at the root of the Space repo, served by Gradio
# from the Space files (avoids depending on an external URL).
BACKGROUND_IMAGE_URL = "file=bar.jpg"
# Alternative remote URL (note the Space name is hyphenated: "AI-Bartender"):
# BACKGROUND_IMAGE_URL = "https://huggingface.co/spaces/OGOGOG/AI-Bartender/resolve/main/bar.jpg"

# If the dataset is private, add a Space secret HF_TOKEN (read scope).
HF_READ_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
# Keyword arguments forwarded to load_dataset(). Only `token` is passed:
# `use_auth_token` is deprecated in `datasets` and was removed in 3.x, where
# an unrecognised keyword is forwarded to the builder config and fails.
load_kwargs = {}
if HF_READ_TOKEN:
    load_kwargs["token"] = HF_READ_TOKEN
# ========================
# Base & Flavor tagging rules
# ========================
# Base-spirit regexes, searched against lower-cased text. Insertion order is
# significant: tag_base() returns the FIRST key with a matching pattern (so a
# recipe with both gin and vermouth is tagged "gin"), and falls back to
# "other" when nothing matches at all.
BASE_SPIRITS = dict(
    vodka=[r"\bvodka\b"],
    gin=[r"\bgin\b"],
    rum=[r"\brum\b", r"\bwhite rum\b", r"\bdark rum\b"],
    tequila=[r"\btequila\b"],
    whiskey=[r"\bwhisk(?:e|)y\b", r"\bbourbon\b", r"\bscotch\b", r"\brye\b"],
    mezcal=[r"\bmezcal\b"],
    brandy=[r"\bbrandy\b", r"\bcognac\b"],
    vermouth=[r"\bvermouth\b"],
    other=[r"\btriple sec\b", r"\bliqueur\b", r"\bcointreau\b", r"\baperol\b", r"\bcampari\b"],
)
# Flavor-tag regexes, searched against lower-cased text. A recipe can carry
# several flavor tags: tag_flavors() collects every key with at least one
# matching pattern. Several patterns have no word boundary, so they also match
# inside longer words (e.g. "cranberr" matches "cranberry" and "cranberries").
FLAVORS = dict(
    citrus=[r"lime", r"lemon", r"grapefruit", r"orange", r"\bcitrus\b"],
    sweet=[r"simple syrup", r"\bsugar\b", r"\bhoney\b", r"\bagave\b", r"\bmaple\b", r"\bgrenadine\b", r"\bvanilla\b", r"\bsweet\b"],
    sour=[r"\bsour\b", r"lemon juice", r"lime juice", r"\bacid\b"],
    bitter=[r"\bbitter", r"\bamaro\b", r"\bcampari\b", r"\baperol\b"],
    smoky=[r"\bsmoky\b", r"\bsmoked\b", r"\bmezcal\b", r"\bpeated\b"],
    spicy=[r"\bspicy\b", r"\bchili\b", r"\bginger\b", r"\bjalapeño\b", r"\bcayenne\b"],
    herbal=[r"\bmint\b", r"\bbasil\b", r"\brosemary\b", r"\bthyme\b", r"\bherb", r"\bchartreuse\b"],
    fruity=[r"pineapple", r"cranberr", r"strawberr", r"mango", r"passion", r"peach", r"\bfruit"],
    creamy=[r"\bcream\b", r"coconut cream", r"\begg white\b", r"\bcreamy\b"],
    floral=[r"\brose\b", r"\bviolet\b", r"\belderflower\b", r"\blavender\b", r"\bfloral\b"],
    refreshing=[r"soda water", r"\btonic\b", r"\bhighball\b", r"\bcollins\b", r"\bfizz\b", r"\brefreshing\b"],
    boozy=[r"\bstirred\b", r"\bmartini\b", r"old fashioned", r"\bboozy\b", r"\bstrong\b"],
)
# Flavor names in FLAVORS insertion order; used for the UI dropdown and to
# validate the requested flavor in recommend().
FLAVOR_OPTIONS = list(FLAVORS)
| # ======================== | |
| # Robust extraction helpers (with measures) | |
| # ======================== | |
| def _clean(s): return s.strip() if isinstance(s, str) else "" | |
| def _norm_measure(s: str) -> str: | |
| if not isinstance(s, str): return "" | |
| s = re.sub(r"\s+", " ", s.strip()) | |
| s = re.sub(r"\bml\b", "ml", s, flags=re.I) | |
| s = re.sub(r"\boz\b", "oz", s, flags=re.I) | |
| s = re.sub(r"\btsp\b", "tsp", s, flags=re.I) | |
| s = re.sub(r"\btbsp\b", "tbsp", s, flags=re.I) | |
| return s | |
def _join_measure_name(measure, name):
    """Combine a measure and an ingredient name into "measure name".

    The measure is normalised with _norm_measure(); a non-string name counts
    as empty. When only one part is non-empty, that part is returned alone.
    """
    measure_txt = _norm_measure(measure)
    name_txt = name.strip() if isinstance(name, str) else ""
    if not (measure_txt and name_txt):
        return name_txt if name_txt else measure_txt
    return f"{measure_txt} {name_txt}"
| def _split_ingredient_blob(s): | |
| if not isinstance(s, str): return [] | |
| parts = re.split(r"[,\n;•\-–]+", s) | |
| return [p.strip() for p in parts if p and p.strip()] | |
| def _from_list_of_pairs(val): | |
| out_disp, out_tokens = [], [] | |
| for x in val: | |
| if not isinstance(x, (list, tuple)) or len(x) == 0: continue | |
| if len(x) == 1: | |
| name = str(x[0]).strip() | |
| if name: out_disp.append(name); out_tokens.append(name.lower()); continue | |
| a, b = str(x[0]).strip(), str(x[1]).strip() | |
| if re.search(r"\d", a) and not re.search(r"\d", b): | |
| disp = _join_measure_name(a, b); out_disp.append(disp); out_tokens.append(b.lower()) | |
| elif re.search(r"\d", b) and not re.search(r"\d", a): | |
| disp = _join_measure_name(b, a); out_disp.append(disp); out_tokens.append(a.lower()) | |
| else: | |
| disp = (a + " " + b).strip(); out_disp.append(disp); out_tokens.append((b if len(b) > len(a) else a).lower()) | |
| return out_disp, out_tokens | |
| def _from_list_of_dicts(val): | |
| out_disp, out_tokens = [], [] | |
| for x in val: | |
| if not isinstance(x, dict): continue | |
| name = next((x[k].strip() for k in ["name","ingredient","item","raw","text","strIngredient"] if isinstance(x.get(k), str) and x[k].strip()), None) | |
| meas = next((x[k].strip() for k in ["measure","qty","quantity","amount","unit","Measure","strMeasure"] if isinstance(x.get(k), str) and x[k].strip()), None) | |
| if name and meas: | |
| out_disp.append(_join_measure_name(meas, name)); out_tokens.append(name.lower()) | |
| elif name: | |
| out_disp.append(name); out_tokens.append(name.lower()) | |
| return out_disp, out_tokens | |
| def _ingredients_from_any(val): | |
| if isinstance(val, str): | |
| lines = _split_ingredient_blob(val) | |
| tokens = [] | |
| for line in lines: | |
| parts = re.split(r"\s+", line); idx = 0 | |
| for i, p in enumerate(parts): | |
| if re.search(r"[A-Za-z]", p): idx = i; break | |
| tokens.append(" ".join(parts[idx:]).lower()) | |
| return lines, tokens | |
| if isinstance(val, list) and all(isinstance(x, str) for x in val): | |
| disp = [x.strip() for x in val if x and x.strip()] | |
| return disp, [x.lower().strip() for x in disp] | |
| if isinstance(val, list) and any(isinstance(x, (list, tuple)) for x in val): | |
| return _from_list_of_pairs(val) | |
| if isinstance(val, list) and any(isinstance(x, dict) for x in val): | |
| return _from_list_of_dicts(val) | |
| return [], [] | |
def _get_title(row, cols):
    """Return the first non-blank title-like column of ``row``, else "Untitled"."""
    candidates = ("title", "name", "cocktail_name", "drink", "Drink", "strDrink")
    for key in candidates:
        if key not in cols:
            continue
        value = _clean(row.get(key))
        if value:
            return value
    return "Untitled"
| def _get_ingredients_with_measures(row, cols): | |
| if "ingredient_tokens" in cols and row.get("ingredient_tokens"): | |
| toks = [str(x).strip().lower() for x in row["ingredient_tokens"] if str(x).strip()] | |
| for mkey in ["measure_tokens","measures","measure_list"]: | |
| if mkey in cols and row.get(mkey) and isinstance(row[mkey], list) and len(row[mkey]) == len(toks): | |
| disp = [] | |
| for m, n in zip(row[mkey], row["ingredient_tokens"]): | |
| m = _norm_measure(str(m)); n = str(n).strip() | |
| disp.append(_join_measure_name(m, n) if m else n) | |
| return disp, toks | |
| return toks, toks | |
| for key in ["ingredients","ingredients_raw","raw_ingredients","Raw_Ingredients","Raw Ingredients","ingredient_list","ingredients_list"]: | |
| if key in cols and row.get(key) not in (None, "", [], {}): return _ingredients_from_any(row[key]) | |
| return [], [] | |
def tag_base(text):
    """Return the first BASE_SPIRITS key whose patterns match ``text``.

    Matching is case-insensitive (the text is lower-cased); "other" is
    returned when no key matches.
    """
    lowered = text.lower()
    return next(
        (
            base
            for base, patterns in BASE_SPIRITS.items()
            if any(re.search(p, lowered) for p in patterns)
        ),
        "other",
    )
def tag_flavors(text):
    """Return every FLAVORS key (in dict order) with a pattern matching ``text``."""
    lowered = text.lower()
    return [
        flavor
        for flavor, patterns in FLAVORS.items()
        if any(re.search(p, lowered) for p in patterns)
    ]
# ========================
# Load dataset & build docs
# ========================
ds = load_dataset(DATASET_ID, split="train", **load_kwargs)
cols = ds.column_names


def _build_doc(row):
    """Turn one dataset row into the record used for retrieval and display.

    "text" (title + comma-joined ingredient tokens) is what gets embedded and
    tagged for base spirit and flavors.
    """
    title = _get_title(row, cols)
    display, tokens = _get_ingredients_with_measures(row, cols)
    display = [line for line in display if line]
    tokens = [tok for tok in tokens if tok]
    text = f"{title}\nIngredients: {', '.join(tokens)}"
    return {
        "title": title,
        "ingredients_display": display,
        "ingredients_tokens": tokens,
        "text": text,
        "base": tag_base(text),
        "flavors": tag_flavors(text),
    }


DOCS = [_build_doc(row) for row in ds]
# ========================
# Embeddings
# ========================
encoder = SentenceTransformer(EMBED_MODEL)
# One normalised float32 embedding per DOCS entry, in DOCS order, so that
# doc_embs[i] corresponds to DOCS[i].
doc_embs = encoder.encode(
    [doc["text"] for doc in DOCS],
    convert_to_numpy=True,
    normalize_embeddings=True,
).astype("float32")
| # ======================== | |
| # Pretty ingredient formatting | |
| # ======================== | |
| _MEASURE_RE = re.compile(r"^\s*(?P<meas>(?:\d+(\.\d+)?|\d+\s*/\s*\d+|\d+\s*\d*/\d+)\s*(?:ml|oz|tsp|tbsp)?|\d+\s*(?:ml|oz|tsp|tbsp)|(?:dash|dashes|drop|drops|barspoon)s?)\b[\s\-–:]*", flags=re.I) | |
| def _split_measure_name_line(line: str): | |
| if not isinstance(line, str): return None, line | |
| m = _MEASURE_RE.match(line.strip()) | |
| if m: | |
| meas = _norm_measure(m.group("meas")); name = line[m.end():].strip() | |
| return meas, name or "" | |
| return "", line.strip() | |
| def _format_ingredients_markdown(lines): | |
| """Bullet points as 'Ingredient (amount)'. Also removes [ and ].""" | |
| if not lines: return "—" | |
| formatted = [] | |
| for ln in lines: | |
| ln = ln.replace("[","").replace("]","") | |
| meas, name = _split_measure_name_line(ln) | |
| if name and meas: formatted.append(f"- {name} ({meas})") | |
| elif name: formatted.append(f"- {name}") | |
| else: formatted.append(f"- {ln}") | |
| return "\n".join(formatted) | |
# ========================
# Recommendation
# ========================
def _format_recommendation(doc, score):
    """Render one DOCS entry and its score as a Markdown section."""
    ing_md = _format_ingredients_markdown(doc["ingredients_display"] or doc["ingredients_tokens"])
    meta = f"**Base:** {doc['base']} | **Flavor tags:** {', '.join(doc['flavors']) or '—'} | **Score:** {score:.3f}"
    return f"### {doc['title']}\n{meta}\n\n**Ingredients:**\n{ing_md}"


def recommend(base_alcohol_text, flavor, top_k=3):
    """Return Markdown describing the top-k cocktails for a base spirit + flavor.

    Candidates are the DOCS whose tagged base equals tag_base() of the typed
    base. If the base is blank, or no doc has that base, every doc is a
    candidate. Each candidate scores the dot product of its embedding with the
    query embedding, plus FLAVOR_BOOST when its flavor tags contain ``flavor``.

    Args:
        base_alcohol_text: Free-text spirit typed by the user (may be blank/None).
        flavor: One of FLAVOR_OPTIONS.
        top_k: Number of results; values below 1 are raised to 1.

    Returns:
        Markdown sections separated by horizontal rules, or a short user-facing
        message when the flavor is invalid or there is nothing to rank.
    """
    if flavor not in FLAVOR_OPTIONS:
        return "Please choose a flavor."
    base_text = (base_alcohol_text or "").strip()
    k = max(1, int(top_k or 1))

    # A blank base means "no preference". tag_base("") would otherwise pick
    # the "other" bucket and silently hide most of the dataset.
    idxs = []
    if base_text:
        inferred_base = tag_base(base_text)
        idxs = [i for i, d in enumerate(DOCS) if d["base"] == inferred_base]
    if not idxs:
        idxs = list(range(len(DOCS)))
    if not idxs:  # empty dataset: nothing to score (and no matrix to index)
        return "No matches found."

    q_text = f"Base spirit: {base_text}. Flavor: {flavor}. Cocktail recipe."
    q_emb = encoder.encode([q_text], normalize_embeddings=True, convert_to_numpy=True).astype("float32")[0]
    sims = doc_embs[idxs].dot(q_emb)

    scored = sorted(
        (
            (float(sim) + (FLAVOR_BOOST if flavor in DOCS[i]["flavors"] else 0.0), i)
            for sim, i in zip(sims, idxs)
        ),
        reverse=True,
    )
    return "\n\n---\n\n".join(_format_recommendation(DOCS[i], sc) for sc, i in scored[:k])
# ========================
# Background + UI (robust)
# ========================
# Global stylesheet passed to gr.Blocks(css=...). It is an f-string so that
# BACKGROUND_IMAGE_URL can be interpolated into url(...); every literal CSS
# brace is therefore doubled ("{{" / "}}").
# Layering: the image is set on <body>; a fixed, semi-transparent body::before
# overlay (z-index 0) darkens it for text contrast; .gradio-container is made
# transparent and raised to z-index 1 so the app renders above the overlay.
# .glass-card is the frosted, rounded panel class applied via elem_classes.
CUSTOM_CSS = f"""
html, body, #root {{ height: 100%; }}
/* Background on BODY to avoid component stacking issues */
body {{
background-image: url('{BACKGROUND_IMAGE_URL}');
background-size: cover;
background-position: center;
background-attachment: fixed;
}}
/* Dark overlay for text contrast */
body::before {{
content: "";
position: fixed;
inset: 0;
background: rgba(0,0,0,0.30); /* slightly lighter so image shows */
z-index: 0;
}}
/* Make the app transparent and float above overlay */
.gradio-container {{ background: transparent !important; position: relative; z-index: 1; }}
.glass-card {{
background: rgba(255, 255, 255, 0.08);
backdrop-filter: blur(6px);
-webkit-backdrop-filter: blur(6px);
border-radius: 14px;
padding: 18px;
border: 1px solid rgba(255, 255, 255, 0.12);
}}
"""
def _preset(values):
    """Return a zero-argument callback that always yields ``values``.

    Used so each example button fills (base text, flavor, top-k) with its own
    preset; the factory avoids late-binding bugs of lambdas created in a loop.
    """
    return lambda: values


# (button label, (base text, flavor, number of recommendations))
_EXAMPLE_PRESETS = (
    ("Example: Gin + Citrus", ("gin", "citrus", 3)),
    ("Example: Rum + Fruity", ("white rum", "fruity", 3)),
    ("Example: Mezcal + Smoky", ("mezcal", "smoky", 3)),
)

with gr.Blocks(css=CUSTOM_CSS) as demo:
    with gr.Column(elem_classes=["glass-card"]):
        gr.Markdown("# 🍹 AI Bartender — Type a Base + Flavor")
        with gr.Row():
            base_text = gr.Textbox(value="gin", label="Base alcohol (type any spirit, e.g., 'gin', 'white rum', 'bourbon')")
            flavor = gr.Dropdown(choices=FLAVOR_OPTIONS, value="citrus", label="Flavor")
            topk = gr.Slider(1, 10, value=3, step=1, label="Number of recommendations")
        with gr.Row():
            example_buttons = [(gr.Button(label), values) for label, values in _EXAMPLE_PRESETS]
        out = gr.Markdown()
        recommend_button = gr.Button("Recommend")
        recommend_button.click(recommend, [base_text, flavor, topk], out)
        # Example buttons only fill in the inputs; the user still clicks Recommend.
        for button, values in example_buttons:
            button.click(_preset(values), outputs=[base_text, flavor, topk])

if __name__ == "__main__":
    demo.launch()