Spaces:
Sleeping
Sleeping
Rename dataset.py to app.py
Browse files- app.py +195 -0
- dataset.py +0 -29
app.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ========================
|
| 2 |
+
# 1) Install deps (Colab/Space)
|
| 3 |
+
# ========================
|
| 4 |
+
# NOTE: `!pip install ...` is IPython/Colab shell magic and is a SyntaxError in
# a plain Python file such as app.py. Declare the dependencies in
# requirements.txt instead (one per line):
#   datasets
#   sentence-transformers
#   gradio
|
| 5 |
+
|
| 6 |
+
# ========================
|
| 7 |
+
# 2) Imports
|
| 8 |
+
# ========================
|
| 9 |
+
import re, numpy as np, gradio as gr
|
| 10 |
+
from datasets import load_dataset
|
| 11 |
+
from sentence_transformers import SentenceTransformer
|
| 12 |
+
|
| 13 |
+
# ========================
|
| 14 |
+
# 3) Load dataset
|
| 15 |
+
# ========================
|
| 16 |
+
# Identifier passed to load_dataset() below; change it to point at another dataset.
DATASET_ID = "motimmom/cocktails_clean_nobrand"
|
| 17 |
+
# Load the "train" split of DATASET_ID; every row becomes one searchable recipe.
ds = load_dataset(DATASET_ID, split="train")
|
| 18 |
+
|
| 19 |
+
# ========================
|
| 20 |
+
# 4) Tagging rules (base + flavors)
|
| 21 |
+
# ========================
|
| 22 |
+
# Base-spirit tag -> regex patterns that identify it in lower-cased recipe text.
# Insertion order matters: tag_base() returns the first tag whose patterns hit.
BASE_SPIRITS = dict(
    vodka=[r"\bvodka\b"],
    gin=[r"\bgin\b"],
    rum=[r"\brum\b", r"\bwhite rum\b", r"\bdark rum\b"],
    tequila=[r"\btequila\b"],
    whiskey=[r"\bwhisk(?:e|)y\b", r"\bbourbon\b", r"\bscotch\b", r"\brye\b"],
    mezcal=[r"\bmezcal\b"],
    brandy=[r"\bbrandy\b", r"\bcognac\b"],
    vermouth=[r"\bvermouth\b"],
    other=[
        r"\btriple sec\b",
        r"\bliqueur\b",
        r"\bcointreau\b",
        r"\baperol\b",
        r"\bcampari\b",
    ],
)
|
| 33 |
+
|
| 34 |
+
# Flavor tag -> regex patterns searched in lower-cased recipe text.
# A recipe receives every tag for which at least one pattern matches.
FLAVORS = {
    "citrus": [r"lime", r"lemon", r"grapefruit", r"orange", r"citrus"],
    "sweet": [r"simple syrup", r"sugar", r"honey", r"agave", r"maple", r"grenadine", r"vanilla", r"sweet"],
    "sour": [r"\bsour\b", r"lemon juice", r"lime juice", r"acid"],
    "bitter": [r"\bbitter", r"\bamaro\b", r"\bcampari\b", r"\baperol\b"],
    "smoky": [r"\bsmoky\b", r"\bsmoked\b", r"\bmezcal\b", r"\bpeated\b"],
    # The pattern was stored mis-encoded and could never match; \u00f1 is "n with
    # tilde", written as a regex escape so it survives re-encoding. Plain "n" is
    # accepted too.
    "spicy": [r"spicy", r"chili", r"ginger", r"jalape[n\u00f1]o", r"cayenne"],
    "herbal": [r"mint", r"basil", r"rosemary", r"thyme", r"herb", r"chartreuse"],
    "fruity": [r"pineapple", r"cranberr", r"strawberr", r"mango", r"passion", r"peach", r"fruit"],
    "creamy": [r"cream", r"coconut cream", r"egg white", r"creamy"],
    # Anchored so "rosemary" and "prosecco" are no longer tagged floral, while
    # "rose", "roses" and "rosewater" still are.
    "floral": [r"\brose(?!mary)", r"violet", r"elderflower", r"lavender", r"floral"],
    "refreshing": [r"soda water", r"tonic", r"highball", r"collins", r"fizz", r"refreshing"],
    "boozy": [r"stirred", r"martini", r"old fashioned", r"\bboozy\b", r"\bstrong\b"],
}
|
| 48 |
+
|
| 49 |
+
def _clean(s):
    """Return ``s`` stripped of surrounding whitespace, or "" if it is not a str."""
    if not isinstance(s, str):
        return ""
    return s.strip()
|
| 51 |
+
|
| 52 |
+
def _extract_names_from_pairs(pairs):
    """Collect unique, lower-cased names from the second slot of each pair.

    Entries that are not a list/tuple of length >= 2, or whose second element
    is falsy or blank after stripping, are ignored. First-seen order is kept.
    """
    if not pairs:
        return []

    # A dict preserves insertion order, which gives ordered de-duplication.
    unique = {}
    for entry in pairs:
        if not isinstance(entry, (list, tuple)):
            continue
        if len(entry) < 2 or not entry[1]:
            continue
        label = str(entry[1]).strip().lower()
        if label:
            unique.setdefault(label, None)
    return list(unique)
|
| 61 |
+
|
| 62 |
+
def _split_if_text(value):
    """Split a comma/semicolon/newline-delimited string; pass other values through."""
    if isinstance(value, str):
        return re.split(r"[,;\n]", value)
    return value


def _get_ingredients(row, cols):
    """Return the row's ingredients as a list of lower-cased, stripped strings.

    Columns are tried in this order, and the first usable one wins:
    ``ingredient_tokens``, ``ingredients``, then the pair-style columns
    ``ingredients_raw`` / ``ingredient_list`` / ``ingredients_list``.

    Args:
        row: Mapping for one dataset record (only ``.get`` and ``[]`` are used).
        cols: Container of column names present in the dataset.

    Returns:
        A list of ingredient strings, or ``[]`` when nothing usable is found.
    """
    tokens = row.get("ingredient_tokens") if "ingredient_tokens" in cols else None
    if tokens:
        # A plain string would otherwise be iterated one character at a time,
        # and None entries would be stringified into the ingredient "none".
        return [
            str(tok).strip().lower()
            for tok in _split_if_text(tokens)
            if tok is not None and str(tok).strip()
        ]

    listed = row.get("ingredients") if "ingredients" in cols else None
    if listed:
        names = [
            item.strip().lower()
            for item in _split_if_text(listed)
            if isinstance(item, str) and item.strip()
        ]
        if names:
            return names

    for key in ("ingredients_raw", "ingredient_list", "ingredients_list"):
        if key in cols and row.get(key):
            names = _extract_names_from_pairs(row[key])
            if names:
                return names
    return []
|
| 77 |
+
|
| 78 |
+
def _get_instructions(row, cols):
    """Return the first non-blank instruction text found in the row, else "".

    Candidate columns are checked in a fixed order; a column is considered
    only if it is listed in ``cols`` and holds a str that is non-empty after
    stripping.
    """
    candidates = ("instructions", "howto", "preparation", "strInstructions", "recipe")
    for column in candidates:
        if column not in cols:
            continue
        raw = row.get(column)
        text = raw.strip() if isinstance(raw, str) else ""
        if text:
            return text
    return ""
|
| 83 |
+
|
| 84 |
+
def tag_base(text):
    """Return the first BASE_SPIRITS tag whose patterns match ``text``.

    Matching is done on the lower-cased text, in BASE_SPIRITS insertion
    order; "other" is returned when no pattern matches.
    """
    lowered = text.lower()
    return next(
        (
            spirit
            for spirit, patterns in BASE_SPIRITS.items()
            if any(re.search(pattern, lowered) for pattern in patterns)
        ),
        "other",
    )
|
| 90 |
+
|
| 91 |
+
def tag_flavors(text):
    """Return every FLAVORS tag with at least one pattern matching ``text``.

    Matching is done on the lower-cased text; tags come back in FLAVORS
    insertion order.
    """
    lowered = text.lower()
    return [
        flavor
        for flavor, patterns in FLAVORS.items()
        if any(re.search(pattern, lowered) for pattern in patterns)
    ]
|
| 98 |
+
|
| 99 |
+
# ========================
|
| 100 |
+
# 5) Build documents
|
| 101 |
+
# ========================
|
| 102 |
+
cols = ds.column_names


def _make_doc(row):
    """Turn one dataset row into the record stored in DOCS."""
    name = _clean(row.get("title") or row.get("name") or row.get("cocktail_name") or "Untitled")
    parts = _get_ingredients(row, cols)
    steps = _get_instructions(row, cols)
    # One fused string feeds both the regex tagging and the embedding model.
    body = f"{name}\nIngredients: {', '.join(parts)}\nInstructions: {steps}"
    return {
        "title": name,
        "ingredients": parts,
        "instructions": steps,
        "text": body,
        "base": tag_base(body),
        "flavors": tag_flavors(body),
    }


DOCS = [_make_doc(row) for row in ds]
|
| 117 |
+
|
| 118 |
+
# ========================
|
| 119 |
+
# 6) Embeddings (semantic search)
|
| 120 |
+
# ========================
|
| 121 |
+
encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
_corpus_texts = [doc["text"] for doc in DOCS]
# Embeddings are requested normalized, so recommend() can use a plain dot
# product as the cosine similarity.
embs = encoder.encode(
    _corpus_texts,
    normalize_embeddings=True,
    convert_to_numpy=True,
).astype("float32")
|
| 124 |
+
|
| 125 |
+
# ========================
|
| 126 |
+
# 7) Recommendation
|
| 127 |
+
# ========================
|
| 128 |
+
# Dropdown choices mirror the tagging tables, in their definition order.
BASE_OPTIONS = [*BASE_SPIRITS]
FLAVOR_OPTIONS = [*FLAVORS]
# Added to a recipe's similarity score when it carries the requested flavor tag.
BOOST = 0.20
|
| 131 |
+
|
| 132 |
+
def recommend(base_alcohol, flavor, top_k=3):
    """Build a Markdown list of the best-scoring cocktails for the given picks.

    Candidates are the DOCS entries whose ``base`` tag equals ``base_alcohol``
    (every entry if none match). Each candidate is scored by the dot product
    of its embedding with the embedding of a query sentence, plus ``BOOST``
    when its ``flavors`` tags contain ``flavor``.

    Args:
        base_alcohol: One of ``BASE_OPTIONS``.
        flavor: One of ``FLAVOR_OPTIONS``.
        top_k: Number of recipes to show; values below 1 are raised to 1 and
            ``None`` falls back to the default of 3.

    Returns:
        A Markdown string: recipe blocks separated by horizontal rules, or a
        short message when an input is invalid or there is nothing to rank.
    """
    if base_alcohol not in BASE_OPTIONS:
        return "Please pick a base alcohol."
    if flavor not in FLAVOR_OPTIONS:
        return "Please pick a flavor."

    # 1) Hard filter by base; fall back to the whole corpus if nothing matches.
    idxs = [i for i, d in enumerate(DOCS) if d["base"] == base_alcohol]
    if not idxs:
        idxs = list(range(len(DOCS)))
    if not idxs:
        # Empty corpus: nothing to embed-compare or rank.
        return "No matches found."

    # 2) Query text that steers the embedding toward the chosen constraints.
    q_text = f"Base spirit: {base_alcohol}. Flavor: {flavor}. Cocktail recipe."
    q_emb = encoder.encode([q_text], normalize_embeddings=True, convert_to_numpy=True).astype("float32")[0]

    # 3) Cosine similarity (plain dot product, because both sides are normalized).
    sims = embs[idxs].dot(q_emb)

    # 4) Add the flavor bonus and rank best-first; ties fall back to the index.
    scores = sorted(
        (
            (float(sim) + (BOOST if flavor in DOCS[i]["flavors"] else 0.0), i)
            for sim, i in zip(sims, idxs)
        ),
        reverse=True,
    )

    # 5) Build output.
    k = 3 if top_k is None else max(1, int(top_k))
    picks = scores[:k]
    if not picks:
        return "No matches found."

    # Em dash shown for missing fields. Written as an escape because the
    # literal character had been stored mis-encoded.
    placeholder = "\u2014"
    blocks = []
    for sc, i in picks:
        d = DOCS[i]
        ing_txt = ", ".join(d["ingredients"]) if d["ingredients"] else placeholder
        instr = d["instructions"] or placeholder
        instr = instr[:400] + ("..." if len(instr) > 400 else "")
        flavor_txt = ", ".join(d["flavors"]) or placeholder
        meta = f"**Base:** {d['base']} | **Flavors:** {flavor_txt} | **Score:** {sc:.3f}"
        blocks.append(f"### {d['title']}\n{meta}\n\n**Ingredients:** {ing_txt}\n\n**Instructions:** {instr}")
    return "\n\n---\n\n".join(blocks)
|
| 175 |
+
|
| 176 |
+
# ========================
|
| 177 |
+
# 8) Gradio UI
|
| 178 |
+
# ========================
|
| 179 |
+
# NOTE(review): nesting below is reconstructed from statement order (the
# original indentation was not visible); confirm which widgets sit in each Row.
with gr.Blocks() as demo:
    # The emoji (tropical drink) and em dash are written as escapes because
    # the literal characters had been stored mis-encoded in the title.
    gr.Markdown("# \U0001F379 Cocktail Recommender \u2014 Pick a Base & Flavor")
    with gr.Row():
        base = gr.Dropdown(choices=BASE_OPTIONS, value="gin", label="Base alcohol")
        flavor = gr.Dropdown(choices=FLAVOR_OPTIONS, value="citrus", label="Flavor")
        topk = gr.Slider(1, 10, value=3, step=1, label="Number of recommendations")
    with gr.Row():
        ex1 = gr.Button("Gin + Citrus")
        ex2 = gr.Button("Rum + Fruity")
        ex3 = gr.Button("Mezcal + Smoky")
    out = gr.Markdown()
    gr.Button("Recommend").click(recommend, [base, flavor, topk], out)
    # The example buttons only pre-fill the three inputs; the user still
    # presses "Recommend" to run the search.
    ex1.click(lambda: ("gin", "citrus", 3), outputs=[base, flavor, topk])
    ex2.click(lambda: ("rum", "fruity", 3), outputs=[base, flavor, topk])
    ex3.click(lambda: ("mezcal", "smoky", 3), outputs=[base, flavor, topk])

demo.launch(share=True)
|
dataset.py
DELETED
|
@@ -1,29 +0,0 @@
|
|
| 1 |
-
# make_dataset.py
|
| 2 |
-
import random, json, torch
|
| 3 |
-
from datasets import Dataset
|
| 4 |
-
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 5 |
-
|
| 6 |
-
# Seq2seq checkpoint used to generate a synthetic recipe for each ingredient list.
MODEL = "erwanlc/t5-cocktails_recipe-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL, device_map="auto")

# Pool that the random ingredient lists are drawn from.
ING_POOL = [
    "vodka", "gin", "rum", "tequila", "whiskey", "triple sec", "vermouth",
    "lime juice", "lemon juice", "cranberry juice", "pineapple juice",
    "simple syrup", "agave syrup", "bitters", "ginger beer", "soda water",
    "tonic water", "mint", "basil", "cucumber",
]


def rand_ings():
    """Return 3 to 6 distinct ING_POOL entries joined by a comma and a space."""
    how_many = random.randint(3, 6)
    return ", ".join(random.sample(ING_POOL, k=how_many))


# Sampling settings passed to every generate() call.
_GEN_KWARGS = dict(max_new_tokens=180, do_sample=True, temperature=0.9, top_p=0.95)

rows = []
for i in range(1000):
    ings = rand_ings()
    encoded = tokenizer(f"ingredients: {ings}", return_tensors="pt").to(model.device)
    generated = model.generate(**encoded, **_GEN_KWARGS)
    rows.append(
        {
            "id": i,
            "ingredients_text": ings,
            "generated_text": tokenizer.decode(generated[0], skip_special_tokens=True),
        }
    )

ds = Dataset.from_list(rows)
ds.to_parquet("cocktail_synth.parquet")
print("Wrote cocktail_synth.parquet")
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|