"""AI Bartender: recommend cocktails from a Hugging Face dataset.

Each recipe is tagged with a base spirit and flavour tags using regex rules and
embedded with a sentence-transformer. Recipes are ranked against a
"base + flavor" query by cosine similarity, plus FLAVOR_BOOST for recipes that
carry the requested flavour tag. A Gradio UI wraps ``recommend``.
"""
import os
import re

import gradio as gr
import numpy as np
from datasets import load_dataset
from sentence_transformers import SentenceTransformer

# ========================
# Config
# ========================
DATASET_ID = "motimmom/cocktails_clean_nobrand"
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
# Added to the cosine similarity of recipes tagged with the requested flavour.
FLAVOR_BOOST = 0.20

# Use the image uploaded at the root of the Space repo, served by Gradio through a
# `file=` URL (the file is whitelisted via `allowed_paths` in `demo.launch` below).
# NOTE(review): newer Gradio releases may expose local files under a different
# route; confirm the background actually loads with the installed Gradio version.
BACKGROUND_IMAGE_URL = "file=bar.jpg"
# If you prefer the remote URL, make sure the Space name uses the hyphen:
# BACKGROUND_IMAGE_URL = "https://huggingface.co/spaces/OGOGOG/AI-Bartender/resolve/main/bar.jpg"

# If the dataset is private, add a Space secret HF_TOKEN (read scope).
HF_READ_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
load_kwargs = {}
if HF_READ_TOKEN:
    # Pass only `token`. `use_auth_token` is deprecated in `datasets` and no longer
    # accepted by newer releases, so sending both can break loading.
    load_kwargs["token"] = HF_READ_TOKEN

# ========================
# Base & Flavor tagging rules
# ========================
# Checked in insertion order; the first base whose pattern matches wins.
BASE_SPIRITS = {
    "vodka": [r"\bvodka\b"],
    "gin": [r"\bgin\b"],
    "rum": [r"\brum\b", r"\bwhite rum\b", r"\bdark rum\b"],
    "tequila": [r"\btequila\b"],
    "whiskey": [r"\bwhisk(?:e|)y\b", r"\bbourbon\b", r"\bscotch\b", r"\brye\b"],
    "mezcal": [r"\bmezcal\b"],
    "brandy": [r"\bbrandy\b", r"\bcognac\b"],
    "vermouth": [r"\bvermouth\b"],
    "other": [r"\btriple sec\b", r"\bliqueur\b", r"\bcointreau\b", r"\baperol\b", r"\bcampari\b"],
}

# A recipe receives every flavour tag whose patterns match its text.
FLAVORS = {
    "citrus": [r"lime", r"lemon", r"grapefruit", r"orange", r"\bcitrus\b"],
    "sweet": [r"simple syrup", r"\bsugar\b", r"\bhoney\b", r"\bagave\b", r"\bmaple\b",
              r"\bgrenadine\b", r"\bvanilla\b", r"\bsweet\b"],
    "sour": [r"\bsour\b", r"lemon juice", r"lime juice", r"\bacid\b"],
    "bitter": [r"\bbitter", r"\bamaro\b", r"\bcampari\b", r"\baperol\b"],
    "smoky": [r"\bsmoky\b", r"\bsmoked\b", r"\bmezcal\b", r"\bpeated\b"],
    "spicy": [r"\bspicy\b", r"\bchili\b", r"\bginger\b", r"\bjalapeño\b", r"\bcayenne\b"],
    "herbal": [r"\bmint\b", r"\bbasil\b", r"\brosemary\b", r"\bthyme\b", r"\bherb", r"\bchartreuse\b"],
    "fruity": [r"pineapple", r"cranberr", r"strawberr", r"mango", r"passion", r"peach", r"\bfruit"],
    "creamy": [r"\bcream\b", r"coconut cream", r"\begg white\b", r"\bcreamy\b"],
    "floral": [r"\brose\b", r"\bviolet\b", r"\belderflower\b", r"\blavender\b", r"\bfloral\b"],
    "refreshing": [r"soda water", r"\btonic\b", r"\bhighball\b", r"\bcollins\b", r"\bfizz\b",
                   r"\brefreshing\b"],
    "boozy": [r"\bstirred\b", r"\bmartini\b", r"old fashioned", r"\bboozy\b", r"\bstrong\b"],
}
FLAVOR_OPTIONS = list(FLAVORS.keys())

# ========================
# Robust extraction helpers (with measures)
# ========================
# Separators inside a free-text ingredient blob: commas, newlines, semicolons and
# bullets anywhere; dashes only as a line-leading bullet or when surrounded by
# whitespace, so "half-and-half" and "1-2 dashes" stay intact.
_BLOB_SPLIT_RE = re.compile(r"[,\n;•]+|^\s*[-–]+|(?<=\s)[-–]+(?=\s)", flags=re.M)


def _clean(s):
    """Return ``s`` stripped if it is a string, else ``""``."""
    return s.strip() if isinstance(s, str) else ""


def _as_text(value):
    """Return ``value`` as a stripped string, mapping ``None`` to ``""`` (not ``"None"``)."""
    return "" if value is None else str(value).strip()


def _norm_measure(s: str) -> str:
    """Collapse whitespace in a measure and lowercase the ml/oz/tsp/tbsp unit words.

    Non-string input yields ``""``.
    """
    if not isinstance(s, str):
        return ""
    s = re.sub(r"\s+", " ", s.strip())
    for unit in ("ml", "oz", "tsp", "tbsp"):
        s = re.sub(rf"\b{unit}\b", unit, s, flags=re.I)
    return s


def _join_measure_name(measure, name):
    """Join a measure and an ingredient name as ``"<measure> <name>"``.

    Falls back to whichever part is non-empty.
    """
    m = _norm_measure(measure)
    n = name.strip() if isinstance(name, str) else ""
    if m and n:
        return f"{m} {n}"
    return n or m


def _split_ingredient_blob(s):
    """Split a free-text ingredient blob into stripped, non-empty lines."""
    if not isinstance(s, str):
        return []
    return [p.strip() for p in _BLOB_SPLIT_RE.split(s) if p and p.strip()]


def _from_list_of_pairs(val):
    """Parse ``[[measure, name], ...]`` entries (either order) into (display, tokens).

    The element containing a digit is taken as the measure. When neither or both
    contain digits, both are shown and the longer one becomes the token. One-element
    entries are taken as a bare ingredient name. Non-sequence entries are skipped.
    """
    out_disp, out_tokens = [], []
    for x in val:
        if not isinstance(x, (list, tuple)) or len(x) == 0:
            continue
        if len(x) == 1:
            name = _as_text(x[0])
            if name:
                out_disp.append(name)
                out_tokens.append(name.lower())
            continue
        a, b = _as_text(x[0]), _as_text(x[1])
        a_has_digit = bool(re.search(r"\d", a))
        b_has_digit = bool(re.search(r"\d", b))
        if a_has_digit and not b_has_digit:
            out_disp.append(_join_measure_name(a, b))
            out_tokens.append(b.lower())
        elif b_has_digit and not a_has_digit:
            out_disp.append(_join_measure_name(b, a))
            out_tokens.append(a.lower())
        else:
            out_disp.append((a + " " + b).strip())
            out_tokens.append((b if len(b) > len(a) else a).lower())
    return out_disp, out_tokens


def _from_list_of_dicts(val):
    """Parse a list of ingredient dicts into (display, tokens).

    The name is read from the first non-empty string among the known name keys, and
    the measure from the known measure keys. Dicts without a name are skipped.
    """
    name_keys = ["name", "ingredient", "item", "raw", "text", "strIngredient"]
    measure_keys = ["measure", "qty", "quantity", "amount", "unit", "Measure", "strMeasure"]
    out_disp, out_tokens = [], []
    for x in val:
        if not isinstance(x, dict):
            continue
        name = next((x[k].strip() for k in name_keys
                     if isinstance(x.get(k), str) and x[k].strip()), None)
        meas = next((x[k].strip() for k in measure_keys
                     if isinstance(x.get(k), str) and x[k].strip()), None)
        if not name:
            continue
        out_disp.append(_join_measure_name(meas, name) if meas else name)
        out_tokens.append(name.lower())
    return out_disp, out_tokens


def _ingredients_from_any(val):
    """Parse an ingredients field of unknown shape into (display lines, lowercase tokens).

    Accepts a free-text blob, a list of strings, a list of [measure, name] pairs or a
    list of dicts. Anything else yields ``([], [])``.
    """
    if isinstance(val, str):
        lines = _split_ingredient_blob(val)
        tokens = []
        for line in lines:
            words = re.split(r"\s+", line)
            # Skip leading words with no letters (quantities such as "2" or "1/2") so
            # the token starts at the first alphabetic word.
            start = next((i for i, w in enumerate(words) if re.search(r"[A-Za-z]", w)), 0)
            tokens.append(" ".join(words[start:]).lower())
        return lines, tokens
    if isinstance(val, list):
        if all(isinstance(x, str) for x in val):
            disp = [x.strip() for x in val if x and x.strip()]
            return disp, [x.lower() for x in disp]
        if any(isinstance(x, (list, tuple)) for x in val):
            return _from_list_of_pairs(val)
        if any(isinstance(x, dict) for x in val):
            return _from_list_of_dicts(val)
    return [], []


def _get_title(row, cols):
    """Return the first non-empty title-like column of ``row``, or ``"Untitled"``."""
    for k in ["title", "name", "cocktail_name", "drink", "Drink", "strDrink"]:
        if k in cols and _clean(row.get(k)):
            return _clean(row[k])
    return "Untitled"


def _get_ingredients_with_measures(row, cols):
    """Return (display lines, lowercase tokens) for a dataset row.

    Prefers a pre-tokenised ``ingredient_tokens`` column, paired with a measure
    column of the same length when one exists. Otherwise parses the first non-empty
    raw-ingredients column.
    """
    raw_tokens = row.get("ingredient_tokens") if "ingredient_tokens" in cols else None
    if isinstance(raw_tokens, str):
        # A plain string would otherwise be iterated character by character.
        if raw_tokens.strip():
            return _ingredients_from_any(raw_tokens)
    elif raw_tokens:
        raw_tokens = list(raw_tokens)
        toks = [t.lower() for t in map(_as_text, raw_tokens) if t]
        for mkey in ["measure_tokens", "measures", "measure_list"]:
            measures = row.get(mkey) if mkey in cols else None
            # Compare against the *unfiltered* token list: zip() below pairs measures
            # with raw tokens, so checking the filtered length could misalign them.
            if isinstance(measures, list) and measures and len(measures) == len(raw_tokens):
                disp = []
                for m, n in zip(measures, raw_tokens):
                    name = _as_text(n)
                    if not name:
                        continue  # mirror the empty-token filtering applied to `toks`
                    disp.append(_join_measure_name(_as_text(m), name))
                return disp, toks
        return toks, toks

    for key in ["ingredients", "ingredients_raw", "raw_ingredients", "Raw_Ingredients",
                "Raw Ingredients", "ingredient_list", "ingredients_list"]:
        if key in cols and row.get(key) not in (None, "", [], {}):
            return _ingredients_from_any(row[key])
    return [], []


def _match_base(text):
    """Return the first base spirit whose pattern matches ``text``, or ``None``."""
    t = (text or "").lower()
    for base, pats in BASE_SPIRITS.items():
        if any(re.search(p, t) for p in pats):
            return base
    return None


def tag_base(text):
    """Return the base-spirit tag for ``text``; ``"other"`` when nothing matches."""
    return _match_base(text) or "other"


def tag_flavors(text):
    """Return every flavour tag (in FLAVORS order) with a pattern matching ``text``."""
    t = (text or "").lower()
    return [flv for flv, pats in FLAVORS.items() if any(re.search(p, t) for p in pats)]


# ========================
# Pretty ingredient formatting
# ========================
# Quantity forms. Mixed numbers and fractions must precede plain numbers because
# regex alternation is ordered: a leading "\d+" would stop at the "1" of "1/2" or
# "1 1/2".
_MEASURE_NUMBER = r"(?:\d+\s+\d+\s*/\s*\d+|\d+\s*/\s*\d+|\d+(?:\.\d+)?)"
_MEASURE_UNIT = r"(?:ml|cl|oz|tsp|tbsp|dash(?:es)?|drops?|barspoons?|parts?|cups?|shots?)"
_MEASURE_RE = re.compile(
    rf"^\s*(?P<meas>{_MEASURE_NUMBER}(?:\s*{_MEASURE_UNIT}\b)?"
    rf"|(?:dash(?:es)?|drops?|barspoons?)\b)\b[\s\-–:]*",
    flags=re.I,
)


def _split_measure_name_line(line: str):
    """Split a display line into (measure, name).

    Returns ``("", line.strip())`` when the line has no leading measure, and
    ``(None, line)`` for non-string input.
    """
    if not isinstance(line, str):
        return None, line
    text = line.strip()
    # Match and slice the same (stripped) string so the match offset is valid.
    m = _MEASURE_RE.match(text)
    if not m:
        return "", text
    return _norm_measure(m.group("meas")), text[m.end():].strip()


def _format_ingredients_markdown(lines):
    """Bullet points as 'Ingredient (amount)'. Also removes [ and ]."""
    if not lines:
        return "—"
    formatted = []
    for ln in lines:
        ln = ln.replace("[", "").replace("]", "")
        if not ln.strip():
            continue
        meas, name = _split_measure_name_line(ln)
        if name and meas:
            formatted.append(f"- {name} ({meas})")
        elif name:
            formatted.append(f"- {name}")
        else:
            formatted.append(f"- {ln}")
    return "\n".join(formatted) or "—"


# ========================
# Load dataset & build docs
# ========================
def _build_doc(row, columns):
    """Build the searchable record (title, ingredients, fused text, tags) for one row."""
    title = _get_title(row, columns)
    ing_disp, ing_tokens = _get_ingredients_with_measures(row, columns)
    ing_disp = [x for x in ing_disp if x]
    ing_tokens = [x for x in ing_tokens if x]
    fused = f"{title}\nIngredients: {', '.join(ing_tokens)}"
    return {
        "title": title,
        "ingredients_display": ing_disp,
        "ingredients_tokens": ing_tokens,
        "text": fused,
        "base": tag_base(fused),
        "flavors": tag_flavors(fused),
    }


ds = load_dataset(DATASET_ID, split="train", **load_kwargs)
cols = ds.column_names
DOCS = [_build_doc(r, cols) for r in ds]

# ========================
# Embeddings
# ========================
encoder = SentenceTransformer(EMBED_MODEL)
doc_embs = encoder.encode(
    [d["text"] for d in DOCS], normalize_embeddings=True, convert_to_numpy=True
).astype("float32")


# ========================
# Recommendation
# ========================
def recommend(base_alcohol_text, flavor, top_k=3):
    """Return Markdown describing the best-matching cocktails.

    Args:
        base_alcohol_text: Free-text spirit typed by the user (e.g. "white rum").
        flavor: One of ``FLAVOR_OPTIONS``.
        top_k: Number of recipes to show (at least 1; slider floats are truncated).

    Returns:
        Markdown with one section per recipe, or a short message when the flavour is
        invalid or there is nothing to rank.
    """
    if flavor not in FLAVOR_OPTIONS:
        return "Please choose a flavor."
    if not DOCS:
        return "No matches found."

    base_text = (base_alcohol_text or "").strip()
    inferred_base = _match_base(base_text)
    # Restrict to same-base recipes only when the input names a spirit we recognise
    # (and such recipes exist). Otherwise (empty or unknown spirit) rank the whole
    # catalogue rather than silently restricting to recipes tagged "other".
    idxs = []
    if inferred_base is not None:
        idxs = [i for i, d in enumerate(DOCS) if d["base"] == inferred_base]
    if not idxs:
        idxs = list(range(len(DOCS)))

    q_text = f"Base spirit: {base_text}. Flavor: {flavor}. Cocktail recipe."
    q_emb = encoder.encode(
        [q_text], normalize_embeddings=True, convert_to_numpy=True
    ).astype("float32")[0]
    # Both sides are L2-normalised, so the dot product is the cosine similarity.
    sims = doc_embs[idxs].dot(q_emb)
    scored = [
        (float(sim) + (FLAVOR_BOOST if flavor in DOCS[i]["flavors"] else 0.0), i)
        for sim, i in zip(sims, idxs)
    ]
    scored.sort(reverse=True)
    picks = scored[: max(1, int(top_k or 1))]

    blocks = []
    for sc, i in picks:
        d = DOCS[i]
        ing_md = _format_ingredients_markdown(d["ingredients_display"] or d["ingredients_tokens"])
        meta = (f"**Base:** {d['base']} | **Flavor tags:** {', '.join(d['flavors']) or '—'}"
                f" | **Score:** {sc:.3f}")
        blocks.append(f"### {d['title']}\n{meta}\n\n**Ingredients:**\n{ing_md}")
    return "\n\n---\n\n".join(blocks)


# ========================
# Background + UI (robust)
# ========================
CUSTOM_CSS = f"""
html, body, #root {{ height: 100%; }}

/* Background on BODY to avoid component stacking issues */
body {{
  background-image: url('{BACKGROUND_IMAGE_URL}');
  background-size: cover;
  background-position: center;
  background-attachment: fixed;
}}

/* Dark overlay for text contrast */
body::before {{
  content: "";
  position: fixed;
  inset: 0;
  background: rgba(0,0,0,0.30); /* slightly lighter so image shows */
  z-index: 0;
}}

/* Make the app transparent and float above overlay */
.gradio-container {{
  background: transparent !important;
  position: relative;
  z-index: 1;
}}

.glass-card {{
  background: rgba(255, 255, 255, 0.08);
  backdrop-filter: blur(6px);
  -webkit-backdrop-filter: blur(6px);
  border-radius: 14px;
  padding: 18px;
  border: 1px solid rgba(255, 255, 255, 0.12);
}}
"""

with gr.Blocks(css=CUSTOM_CSS) as demo:
    with gr.Column(elem_classes=["glass-card"]):
        gr.Markdown("# 🍹 AI Bartender — Type a Base + Flavor")
        with gr.Row():
            base_text = gr.Textbox(
                value="gin",
                label="Base alcohol (type any spirit, e.g., 'gin', 'white rum', 'bourbon')",
            )
            flavor = gr.Dropdown(choices=FLAVOR_OPTIONS, value="citrus", label="Flavor")
        topk = gr.Slider(1, 10, value=3, step=1, label="Number of recommendations")
        with gr.Row():
            ex1 = gr.Button("Example: Gin + Citrus")
            ex2 = gr.Button("Example: Rum + Fruity")
            ex3 = gr.Button("Example: Mezcal + Smoky")
        out = gr.Markdown()
        gr.Button("Recommend").click(recommend, [base_text, flavor, topk], out)
        ex1.click(lambda: ("gin", "citrus", 3), outputs=[base_text, flavor, topk])
        ex2.click(lambda: ("white rum", "fruity", 3), outputs=[base_text, flavor, topk])
        ex3.click(lambda: ("mezcal", "smoky", 3), outputs=[base_text, flavor, topk])

if __name__ == "__main__":
    launch_kwargs = {}
    if BACKGROUND_IMAGE_URL.startswith("file="):
        # Recent Gradio versions only serve local files that are explicitly allowed.
        launch_kwargs["allowed_paths"] = [os.path.abspath(BACKGROUND_IMAGE_URL[len("file="):])]
    demo.launch(**launch_kwargs)