# OGOGOG's picture
# Update app.py
# d97596f verified
import os
import re
import numpy as np
import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
# ========================
# Config
# ========================
DATASET_ID: str = "motimmom/cocktails_clean_nobrand"  # Hub dataset holding the recipes
EMBED_MODEL: str = "sentence-transformers/all-MiniLM-L6-v2"  # sentence-transformers encoder
FLAVOR_BOOST: float = 0.20  # added to the similarity score of docs tagged with the chosen flavor

# Background image used by the page CSS. "file=bar.jpg" points at bar.jpg in the
# root of the Space repo and is fetched through Gradio's own file route.
# NOTE(review): newer Gradio versions only serve local files that are listed in
# launch(allowed_paths=...) — confirm for the Gradio version in use.
# Alternative (absolute Hub URL; the Space name is spelled with a hyphen):
#   "https://huggingface.co/spaces/OGOGOG/AI-Bartender/resolve/main/bar.jpg"
BACKGROUND_IMAGE_URL: str = "file=bar.jpg"
# If dataset is private, add Space secret HF_TOKEN (read scope)
HF_READ_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
load_kwargs = {}
if HF_READ_TOKEN:
load_kwargs["token"] = HF_READ_TOKEN
load_kwargs["use_auth_token"] = HF_READ_TOKEN
# ========================
# Base & Flavor tagging rules
# ========================
# Base-spirit rules: label -> regex patterns, searched against lower-cased text
# by tag_base(). tag_base() returns the FIRST label (in this dict's insertion
# order) that has any matching pattern, so the order below is significant:
# e.g. text mentioning both gin and vermouth is tagged "gin".
# "other" covers liqueurs/aperitifs; tag_base() also returns "other" when nothing
# matches, so a hit on that entry and no hit at all give the same result.
BASE_SPIRITS = {
"vodka": [r"\bvodka\b"],
"gin": [r"\bgin\b"],  # \b keeps "ginger" from matching
"rum": [r"\brum\b", r"\bwhite rum\b", r"\bdark rum\b"],  # \brum\b alone already matches the two longer forms
"tequila": [r"\btequila\b"],
"whiskey": [r"\bwhisk(?:e|)y\b", r"\bbourbon\b", r"\bscotch\b", r"\brye\b"],  # "whiskey" and "whisky"
"mezcal": [r"\bmezcal\b"],  # checked after tequila, so text naming both is tagged "tequila"
"brandy": [r"\bbrandy\b", r"\bcognac\b"],
"vermouth": [r"\bvermouth\b"],
"other": [r"\btriple sec\b", r"\bliqueur\b", r"\bcointreau\b", r"\baperol\b", r"\bcampari\b"],
}
# Flavor rules: label -> regex patterns. tag_flavors() returns EVERY label with at
# least one pattern matching the lower-cased text, in this dict's order.
# Tagging runs on the doc's fused text (title + ingredient tokens), so words such
# as "stirred" or "highball" only hit when they appear in a title or ingredient.
# Several patterns are unanchored substrings/prefixes, so they also match
# plurals and compounds ("cranberr" -> cranberries, "\bbitter" -> bitters) and
# can over-match (e.g. "orange" also fires on "orange bitters" -> "citrus").
FLAVORS = {
"citrus": [r"lime", r"lemon", r"grapefruit", r"orange", r"\bcitrus\b"],
"sweet": [r"simple syrup", r"\bsugar\b", r"\bhoney\b", r"\bagave\b", r"\bmaple\b", r"\bgrenadine\b", r"\bvanilla\b", r"\bsweet\b"],
"sour": [r"\bsour\b", r"lemon juice", r"lime juice", r"\bacid\b"],
"bitter": [r"\bbitter", r"\bamaro\b", r"\bcampari\b", r"\baperol\b"],
"smoky": [r"\bsmoky\b", r"\bsmoked\b", r"\bmezcal\b", r"\bpeated\b"],
"spicy": [r"\bspicy\b", r"\bchili\b", r"\bginger\b", r"\bjalapeño\b", r"\bcayenne\b"],
"herbal": [r"\bmint\b", r"\bbasil\b", r"\brosemary\b", r"\bthyme\b", r"\bherb", r"\bchartreuse\b"],
"fruity": [r"pineapple", r"cranberr", r"strawberr", r"mango", r"passion", r"peach", r"\bfruit"],
"creamy": [r"\bcream\b", r"coconut cream", r"\begg white\b", r"\bcreamy\b"],
"floral": [r"\brose\b", r"\bviolet\b", r"\belderflower\b", r"\blavender\b", r"\bfloral\b"],
"refreshing": [r"soda water", r"\btonic\b", r"\bhighball\b", r"\bcollins\b", r"\bfizz\b", r"\brefreshing\b"],
"boozy": [r"\bstirred\b", r"\bmartini\b", r"old fashioned", r"\bboozy\b", r"\bstrong\b"],
}
# Flavor labels in FLAVORS order: the UI dropdown choices and the set that
# recommend() validates the requested flavor against.
FLAVOR_OPTIONS = list(FLAVORS.keys())
# ========================
# Robust extraction helpers (with measures)
# ========================
def _clean(s): return s.strip() if isinstance(s, str) else ""
def _norm_measure(s: str) -> str:
if not isinstance(s, str): return ""
s = re.sub(r"\s+", " ", s.strip())
s = re.sub(r"\bml\b", "ml", s, flags=re.I)
s = re.sub(r"\boz\b", "oz", s, flags=re.I)
s = re.sub(r"\btsp\b", "tsp", s, flags=re.I)
s = re.sub(r"\btbsp\b", "tbsp", s, flags=re.I)
return s
def _join_measure_name(measure, name):
    """Combine a measure and an ingredient name into one display string.

    The measure is normalised with _norm_measure() and a string name is stripped.
    Returns "<measure> <name>" when both are non-empty. Otherwise it returns
    whichever one is non-empty (the name wins), or "" when both are empty.
    """
    name_txt = name.strip() if isinstance(name, str) else ""
    measure_txt = _norm_measure(measure)
    if not (measure_txt and name_txt):
        return name_txt or measure_txt
    return f"{measure_txt} {name_txt}"
def _split_ingredient_blob(s):
if not isinstance(s, str): return []
parts = re.split(r"[,\n;•\-–]+", s)
return [p.strip() for p in parts if p and p.strip()]
def _from_list_of_pairs(val):
out_disp, out_tokens = [], []
for x in val:
if not isinstance(x, (list, tuple)) or len(x) == 0: continue
if len(x) == 1:
name = str(x[0]).strip()
if name: out_disp.append(name); out_tokens.append(name.lower()); continue
a, b = str(x[0]).strip(), str(x[1]).strip()
if re.search(r"\d", a) and not re.search(r"\d", b):
disp = _join_measure_name(a, b); out_disp.append(disp); out_tokens.append(b.lower())
elif re.search(r"\d", b) and not re.search(r"\d", a):
disp = _join_measure_name(b, a); out_disp.append(disp); out_tokens.append(a.lower())
else:
disp = (a + " " + b).strip(); out_disp.append(disp); out_tokens.append((b if len(b) > len(a) else a).lower())
return out_disp, out_tokens
def _from_list_of_dicts(val):
out_disp, out_tokens = [], []
for x in val:
if not isinstance(x, dict): continue
name = next((x[k].strip() for k in ["name","ingredient","item","raw","text","strIngredient"] if isinstance(x.get(k), str) and x[k].strip()), None)
meas = next((x[k].strip() for k in ["measure","qty","quantity","amount","unit","Measure","strMeasure"] if isinstance(x.get(k), str) and x[k].strip()), None)
if name and meas:
out_disp.append(_join_measure_name(meas, name)); out_tokens.append(name.lower())
elif name:
out_disp.append(name); out_tokens.append(name.lower())
return out_disp, out_tokens
def _ingredients_from_any(val):
if isinstance(val, str):
lines = _split_ingredient_blob(val)
tokens = []
for line in lines:
parts = re.split(r"\s+", line); idx = 0
for i, p in enumerate(parts):
if re.search(r"[A-Za-z]", p): idx = i; break
tokens.append(" ".join(parts[idx:]).lower())
return lines, tokens
if isinstance(val, list) and all(isinstance(x, str) for x in val):
disp = [x.strip() for x in val if x and x.strip()]
return disp, [x.lower().strip() for x in disp]
if isinstance(val, list) and any(isinstance(x, (list, tuple)) for x in val):
return _from_list_of_pairs(val)
if isinstance(val, list) and any(isinstance(x, dict) for x in val):
return _from_list_of_dicts(val)
return [], []
def _get_title(row, cols):
for k in ["title","name","cocktail_name","drink","Drink","strDrink"]:
if k in cols and _clean(row.get(k)): return _clean(row[k])
return "Untitled"
def _get_ingredients_with_measures(row, cols):
if "ingredient_tokens" in cols and row.get("ingredient_tokens"):
toks = [str(x).strip().lower() for x in row["ingredient_tokens"] if str(x).strip()]
for mkey in ["measure_tokens","measures","measure_list"]:
if mkey in cols and row.get(mkey) and isinstance(row[mkey], list) and len(row[mkey]) == len(toks):
disp = []
for m, n in zip(row[mkey], row["ingredient_tokens"]):
m = _norm_measure(str(m)); n = str(n).strip()
disp.append(_join_measure_name(m, n) if m else n)
return disp, toks
return toks, toks
for key in ["ingredients","ingredients_raw","raw_ingredients","Raw_Ingredients","Raw Ingredients","ingredient_list","ingredients_list"]:
if key in cols and row.get(key) not in (None, "", [], {}): return _ingredients_from_any(row[key])
return [], []
def tag_base(text):
    """Return the base-spirit label for ``text``.

    The label is the first BASE_SPIRITS key (in dict order) with a pattern that
    matches the lower-cased text. Returns "other" when nothing matches.
    """
    lowered = text.lower()
    return next(
        (
            spirit
            for spirit, patterns in BASE_SPIRITS.items()
            if any(re.search(pattern, lowered) for pattern in patterns)
        ),
        "other",
    )
def tag_flavors(text):
    """Return every FLAVORS label whose patterns match ``text``, in FLAVORS order.

    Matching uses the lower-cased text. The result may be an empty list.
    """
    lowered = text.lower()
    return [
        label
        for label, patterns in FLAVORS.items()
        if any(re.search(pattern, lowered) for pattern in patterns)
    ]
# ========================
# Load dataset & build docs
# ========================
ds = load_dataset(DATASET_ID, split="train", **load_kwargs)
cols = ds.column_names


def _row_to_doc(row):
    """Build the searchable record for one dataset row.

    The "text" field (title plus comma-joined ingredient tokens) is what gets
    embedded, and what the base/flavor regex taggers run over.
    """
    title = _get_title(row, cols)
    display, tokens = _get_ingredients_with_measures(row, cols)
    display = [line for line in display if line]
    tokens = [tok for tok in tokens if tok]
    fused = f"{title}\nIngredients: {', '.join(tokens)}"
    return {
        "title": title,
        "ingredients_display": display,
        "ingredients_tokens": tokens,
        "text": fused,
        "base": tag_base(fused),
        "flavors": tag_flavors(fused),
    }


# One record per dataset row, in dataset order; index i lines up with doc_embs[i].
DOCS = [_row_to_doc(row) for row in ds]
# ========================
# Embeddings
# ========================
encoder = SentenceTransformer(EMBED_MODEL)
# One normalised float32 vector per doc (row i <-> DOCS[i]). A dot product with a
# normalised query embedding is therefore a cosine similarity.
doc_embs = encoder.encode(
    [doc["text"] for doc in DOCS],
    convert_to_numpy=True,
    normalize_embeddings=True,
).astype(np.float32)
# ========================
# Pretty ingredient formatting
# ========================
_MEASURE_RE = re.compile(r"^\s*(?P<meas>(?:\d+(\.\d+)?|\d+\s*/\s*\d+|\d+\s*\d*/\d+)\s*(?:ml|oz|tsp|tbsp)?|\d+\s*(?:ml|oz|tsp|tbsp)|(?:dash|dashes|drop|drops|barspoon)s?)\b[\s\-–:]*", flags=re.I)
def _split_measure_name_line(line: str):
if not isinstance(line, str): return None, line
m = _MEASURE_RE.match(line.strip())
if m:
meas = _norm_measure(m.group("meas")); name = line[m.end():].strip()
return meas, name or ""
return "", line.strip()
def _format_ingredients_markdown(lines):
"""Bullet points as 'Ingredient (amount)'. Also removes [ and ]."""
if not lines: return "—"
formatted = []
for ln in lines:
ln = ln.replace("[","").replace("]","")
meas, name = _split_measure_name_line(ln)
if name and meas: formatted.append(f"- {name} ({meas})")
elif name: formatted.append(f"- {name}")
else: formatted.append(f"- {ln}")
return "\n".join(formatted)
# ========================
# Recommendation
# ========================
def _format_doc_block(d, score):
    """Render one cocktail record, with its ranking score, as a Markdown section."""
    ing_md = _format_ingredients_markdown(d["ingredients_display"] or d["ingredients_tokens"])
    meta = f"**Base:** {d['base']} | **Flavor tags:** {', '.join(d['flavors']) or '—'} | **Score:** {score:.3f}"
    return f"### {d['title']}\n{meta}\n\n**Ingredients:**\n{ing_md}"


def recommend(base_alcohol_text, flavor, top_k=3):
    """Return Markdown describing the best-matching cocktails.

    Args:
        base_alcohol_text: Free-text spirit name (e.g. "white rum"). tag_base()
            maps it to a base label, and only docs with that base are ranked.
            Every doc is ranked when no doc has that base or the text is blank.
        flavor: One of FLAVOR_OPTIONS. Docs tagged with it get FLAVOR_BOOST
            added to their embedding similarity.
        top_k: Number of results. It is converted to int and clamped to at
            least 1; values that cannot be converted fall back to 3.

    Returns:
        Markdown with one section per cocktail, separated by horizontal rules.
        Returns a short message instead when the flavor is invalid or there is
        nothing to rank.
    """
    if flavor not in FLAVOR_OPTIONS:
        return "Please choose a flavor."
    base_text = (base_alcohol_text or "").strip()
    idxs = []
    if base_text:
        inferred_base = tag_base(base_text)
        idxs = [i for i, d in enumerate(DOCS) if d["base"] == inferred_base]
    # Blank input used to map to "other" and silently restrict the search to docs
    # without a recognised spirit; now it searches everything, like an unmatched base.
    if not idxs:
        idxs = list(range(len(DOCS)))
    if not idxs:
        return "No matches found."
    try:
        k = max(1, int(top_k))
    except (TypeError, ValueError, OverflowError):
        k = 3

    q_text = f"Base spirit: {base_text}. Flavor: {flavor}. Cocktail recipe."
    q_emb = encoder.encode([q_text], normalize_embeddings=True, convert_to_numpy=True).astype("float32")[0]
    sims = doc_embs[idxs].dot(q_emb)
    scored = sorted(
        (
            (float(sim) + (FLAVOR_BOOST if flavor in DOCS[i]["flavors"] else 0.0), i)
            for sim, i in zip(sims, idxs)
        ),
        reverse=True,
    )
    return "\n\n---\n\n".join(_format_doc_block(DOCS[i], score) for score, i in scored[:k])
# ========================
# Background + UI (robust)
# ========================
# Page CSS passed to gr.Blocks(css=...). This is an f-string: literal CSS braces
# are doubled ({{ }}) and only {BACKGROUND_IMAGE_URL} is interpolated.
# Layering:
#   * body carries the background image;
#   * body::before paints a fixed, semi-transparent dark overlay at z-index 0;
#   * .gradio-container is made transparent and lifted to z-index 1 so the UI
#     sits above the overlay.
# .glass-card styles the main column (applied via elem_classes=["glass-card"]).
# NOTE(review): whether the image is visible also depends on Gradio's own wrapper
# elements having transparent backgrounds in the installed version — confirm.
CUSTOM_CSS = f"""
html, body, #root {{ height: 100%; }}
/* Background on BODY to avoid component stacking issues */
body {{
background-image: url('{BACKGROUND_IMAGE_URL}');
background-size: cover;
background-position: center;
background-attachment: fixed;
}}
/* Dark overlay for text contrast */
body::before {{
content: "";
position: fixed;
inset: 0;
background: rgba(0,0,0,0.30); /* slightly lighter so image shows */
z-index: 0;
}}
/* Make the app transparent and float above overlay */
.gradio-container {{ background: transparent !important; position: relative; z-index: 1; }}
.glass-card {{
background: rgba(255, 255, 255, 0.08);
backdrop-filter: blur(6px);
-webkit-backdrop-filter: blur(6px);
border-radius: 14px;
padding: 18px;
border: 1px solid rgba(255, 255, 255, 0.12);
}}
"""
# Gradio UI: inputs for base spirit, flavor and result count, three example
# presets, and a Markdown output filled by recommend().
with gr.Blocks(css=CUSTOM_CSS) as demo:
    with gr.Column(elem_classes=["glass-card"]):
        gr.Markdown("# 🍹 AI Bartender — Type a Base + Flavor")
        with gr.Row():
            base_text = gr.Textbox(value="gin", label="Base alcohol (type any spirit, e.g., 'gin', 'white rum', 'bourbon')")
            flavor = gr.Dropdown(choices=FLAVOR_OPTIONS, value="citrus", label="Flavor")
            topk = gr.Slider(1, 10, value=3, step=1, label="Number of recommendations")
        with gr.Row():
            ex1 = gr.Button("Example: Gin + Citrus")
            ex2 = gr.Button("Example: Rum + Fruity")
            ex3 = gr.Button("Example: Mezcal + Smoky")
        out = gr.Markdown()
        recommend_btn = gr.Button("Recommend")

    search_inputs = [base_text, flavor, topk]
    recommend_btn.click(recommend, search_inputs, out)
    # Pressing Enter in the spirit box runs the same search.
    base_text.submit(recommend, search_inputs, out)
    # Example buttons fill the inputs and then run the search right away,
    # instead of leaving the user to click "Recommend" as a second step.
    ex1.click(lambda: ("gin", "citrus", 3), outputs=search_inputs).then(recommend, search_inputs, out)
    ex2.click(lambda: ("white rum", "fruity", 3), outputs=search_inputs).then(recommend, search_inputs, out)
    ex3.click(lambda: ("mezcal", "smoky", 3), outputs=search_inputs).then(recommend, search_inputs, out)
if __name__ == "__main__":
    # Gradio 4+ only serves local files that are explicitly allowed (Gradio 3
    # also served the working directory). A "file=" background image therefore
    # has to be whitelisted, or the CSS url() request is refused.
    _file_prefix = "file="
    _allowed_paths = (
        [BACKGROUND_IMAGE_URL[len(_file_prefix):]]
        if BACKGROUND_IMAGE_URL.startswith(_file_prefix)
        else []
    )
    demo.launch(allowed_paths=_allowed_paths)