OGOGOG commited on
Commit
d52e547
Β·
verified Β·
1 Parent(s): 9e9bfcf

Rename dataset.py to app.py

Browse files
Files changed (2) hide show
  1. app.py +195 -0
  2. dataset.py +0 -29
app.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========================
2
+ # 1) Install deps (Colab/Space)
3
+ # ========================
4
+ !pip -q install datasets sentence-transformers gradio --quiet
5
+
6
+ # ========================
7
+ # 2) Imports
8
+ # ========================
9
+ import re, numpy as np, gradio as gr
10
+ from datasets import load_dataset
11
+ from sentence_transformers import SentenceTransformer
12
+
13
+ # ========================
14
+ # 3) Load dataset
15
+ # ========================
16
+ DATASET_ID = "motimmom/cocktails_clean_nobrand" # <-- change if needed
17
+ ds = load_dataset(DATASET_ID, split="train")
18
+
19
+ # ========================
20
+ # 4) Tagging rules (base + flavors)
21
+ # ========================
22
+ BASE_SPIRITS = {
23
+ "vodka": [r"\bvodka\b"],
24
+ "gin": [r"\bgin\b"],
25
+ "rum": [r"\brum\b", r"\bwhite rum\b", r"\bdark rum\b"],
26
+ "tequila": [r"\btequila\b"],
27
+ "whiskey": [r"\bwhisk(?:e|)y\b", r"\bbourbon\b", r"\bscotch\b", r"\brye\b"],
28
+ "mezcal": [r"\bmezcal\b"],
29
+ "brandy": [r"\bbrandy\b", r"\bcognac\b"],
30
+ "vermouth": [r"\bvermouth\b"],
31
+ "other": [r"\btriple sec\b", r"\bliqueur\b", r"\bcointreau\b", r"\baperol\b", r"\bcampari\b"],
32
+ }
33
+
34
+ FLAVORS = {
35
+ "citrus": [r"lime", r"lemon", r"grapefruit", r"orange", r"citrus"],
36
+ "sweet": [r"simple syrup", r"sugar", r"honey", r"agave", r"maple", r"grenadine", r"vanilla", r"sweet"],
37
+ "sour": [r"\bsour\b", r"lemon juice", r"lime juice", r"acid"],
38
+ "bitter": [r"\bbitter", r"\bamaro\b", r"\bcampari\b", r"\baperol\b"],
39
+ "smoky": [r"\bsmoky\b", r"\bsmoked\b", r"\bmezcal\b", r"\bpeated\b"],
40
+ "spicy": [r"spicy", r"chili", r"ginger", r"jalapeΓ±o", r"cayenne"],
41
+ "herbal": [r"mint", r"basil", r"rosemary", r"thyme", r"herb", r"chartreuse"],
42
+ "fruity": [r"pineapple", r"cranberr", r"strawberr", r"mango", r"passion", r"peach", r"fruit"],
43
+ "creamy": [r"cream", r"coconut cream", r"egg white", r"creamy"],
44
+ "floral": [r"rose", r"violet", r"elderflower", r"lavender", r"floral"],
45
+ "refreshing": [r"soda water", r"tonic", r"highball", r"collins", r"fizz", r"refreshing"],
46
+ "boozy": [r"stirred", r"martini", r"old fashioned", r"\bboozy\b", r"\bstrong\b"],
47
+ }
48
+
49
+ def _clean(s):
50
+ return s.strip() if isinstance(s, str) else ""
51
+
52
+ def _extract_names_from_pairs(pairs):
53
+ out = []
54
+ if not pairs: return out
55
+ for p in pairs:
56
+ if isinstance(p, (list, tuple)) and len(p) >= 2 and p[1]:
57
+ name = str(p[1]).strip().lower()
58
+ if name and name not in out:
59
+ out.append(name)
60
+ return out
61
+
62
+ def _get_ingredients(row, cols):
63
+ # try a few common schemas
64
+ if "ingredient_tokens" in cols and row.get("ingredient_tokens"):
65
+ return [str(x).strip().lower() for x in row["ingredient_tokens"] if str(x).strip()]
66
+ if "ingredients" in cols and row.get("ingredients"):
67
+ vals = []
68
+ for x in row["ingredients"]:
69
+ if isinstance(x, str) and x.strip():
70
+ vals.append(x.strip().lower())
71
+ if vals: return vals
72
+ for k in ["ingredients_raw", "ingredient_list", "ingredients_list"]:
73
+ if k in cols and row.get(k):
74
+ names = _extract_names_from_pairs(row[k])
75
+ if names: return names
76
+ return []
77
+
78
+ def _get_instructions(row, cols):
79
+ for k in ["instructions", "howto", "preparation", "strInstructions", "recipe"]:
80
+ if k in cols and _clean(row.get(k)):
81
+ return _clean(row[k])
82
+ return ""
83
+
84
+ def tag_base(text):
85
+ t = text.lower()
86
+ for base, pats in BASE_SPIRITS.items():
87
+ if any(re.search(p, t) for p in pats):
88
+ return base
89
+ return "other"
90
+
91
+ def tag_flavors(text):
92
+ t = text.lower()
93
+ tags = []
94
+ for flv, pats in FLAVORS.items():
95
+ if any(re.search(p, t) for p in pats):
96
+ tags.append(flv)
97
+ return tags
98
+
99
+ # ========================
100
+ # 5) Build documents
101
+ # ========================
102
+ cols = ds.column_names
103
+ DOCS = []
104
+ for r in ds:
105
+ title = _clean(r.get("title") or r.get("name") or r.get("cocktail_name") or "Untitled")
106
+ ingredients = _get_ingredients(r, cols)
107
+ instructions = _get_instructions(r, cols)
108
+ fused = f"{title}\nIngredients: {', '.join(ingredients)}\nInstructions: {instructions}"
109
+ DOCS.append({
110
+ "title": title,
111
+ "ingredients": ingredients,
112
+ "instructions": instructions,
113
+ "text": fused,
114
+ "base": tag_base(fused),
115
+ "flavors": tag_flavors(fused),
116
+ })
117
+
118
+ # ========================
119
+ # 6) Embeddings (semantic search)
120
+ # ========================
121
+ encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
122
+ embs = encoder.encode([d["text"] for d in DOCS], normalize_embeddings=True, convert_to_numpy=True)
123
+ embs = embs.astype("float32") # cosine via dot since normalized
124
+
125
+ # ========================
126
+ # 7) Recommendation
127
+ # ========================
128
+ BASE_OPTIONS = list(BASE_SPIRITS.keys())
129
+ FLAVOR_OPTIONS = list(FLAVORS.keys())
130
+ BOOST = 0.20 # flavor bonus
131
+
132
+ def recommend(base_alcohol, flavor, top_k=3):
133
+ if base_alcohol not in BASE_OPTIONS:
134
+ return "Please pick a base alcohol."
135
+ if flavor not in FLAVOR_OPTIONS:
136
+ return "Please pick a flavor."
137
+
138
+ # 1) Hard filter by base
139
+ idxs = [i for i, d in enumerate(DOCS) if d["base"] == base_alcohol]
140
+ if not idxs:
141
+ # fallback if dataset has no exact base match
142
+ idxs = list(range(len(DOCS)))
143
+
144
+ # 2) Build a query text that steers embedding toward your constraints
145
+ q_text = f"Base spirit: {base_alcohol}. Flavor: {flavor}. Cocktail recipe."
146
+ q_emb = encoder.encode([q_text], normalize_embeddings=True, convert_to_numpy=True).astype("float32")[0]
147
+
148
+ # 3) Cosine similarity (dot, because normalized)
149
+ sims = embs[idxs].dot(q_emb)
150
+
151
+ # 4) Flavor boost if recipe carries the desired flavor tag
152
+ scores = []
153
+ for local_pos, i in enumerate(idxs):
154
+ base_score = float(sims[local_pos])
155
+ has_flavor = flavor in DOCS[i]["flavors"]
156
+ score = base_score + (BOOST if has_flavor else 0.0)
157
+ scores.append((score, i))
158
+ scores.sort(reverse=True)
159
+
160
+ # 5) Build output
161
+ k = max(1, int(top_k))
162
+ picks = scores[:k]
163
+ if not picks:
164
+ return "No matches found."
165
+
166
+ blocks = []
167
+ for sc, i in picks:
168
+ d = DOCS[i]
169
+ ing_txt = ", ".join(d["ingredients"]) if d["ingredients"] else "β€”"
170
+ instr = d["instructions"] or "β€”"
171
+ instr = instr[:400] + ("..." if len(instr) > 400 else "")
172
+ meta = f"**Base:** {d['base']} | **Flavors:** {', '.join(d['flavors']) or 'β€”'} | **Score:** {sc:.3f}"
173
+ blocks.append(f"### {d['title']}\n{meta}\n\n**Ingredients:** {ing_txt}\n\n**Instructions:** {instr}")
174
+ return "\n\n---\n\n".join(blocks)
175
+
176
+ # ========================
177
+ # 8) Gradio UI
178
+ # ========================
179
+ with gr.Blocks() as demo:
180
+ gr.Markdown("# 🍹 Cocktail Recommender β€” Pick a Base & Flavor")
181
+ with gr.Row():
182
+ base = gr.Dropdown(choices=BASE_OPTIONS, value="gin", label="Base alcohol")
183
+ flavor = gr.Dropdown(choices=FLAVOR_OPTIONS, value="citrus", label="Flavor")
184
+ topk = gr.Slider(1, 10, value=3, step=1, label="Number of recommendations")
185
+ with gr.Row():
186
+ ex1 = gr.Button("Gin + Citrus")
187
+ ex2 = gr.Button("Rum + Fruity")
188
+ ex3 = gr.Button("Mezcal + Smoky")
189
+ out = gr.Markdown()
190
+ gr.Button("Recommend").click(recommend, [base, flavor, topk], out)
191
+ ex1.click(lambda: ("gin", "citrus", 3), outputs=[base, flavor, topk])
192
+ ex2.click(lambda: ("rum", "fruity", 3), outputs=[base, flavor, topk])
193
+ ex3.click(lambda: ("mezcal", "smoky", 3), outputs=[base, flavor, topk])
194
+
195
+ demo.launch(share=True)
dataset.py DELETED
@@ -1,29 +0,0 @@
1
- # make_dataset.py
2
- import random, json, torch
3
- from datasets import Dataset
4
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
-
6
- MODEL = "erwanlc/t5-cocktails_recipe-base"
7
- tokenizer = AutoTokenizer.from_pretrained(MODEL)
8
- model = AutoModelForSeq2SeqLM.from_pretrained(MODEL, device_map="auto")
9
-
10
- ING_POOL = ["vodka","gin","rum","tequila","whiskey","triple sec","vermouth","lime juice",
11
- "lemon juice","cranberry juice","pineapple juice","simple syrup","agave syrup",
12
- "bitters","ginger beer","soda water","tonic water","mint","basil","cucumber"]
13
-
14
- def rand_ings():
15
- return ", ".join(random.sample(ING_POOL, k=random.randint(3,6)))
16
-
17
- rows = []
18
- for i in range(1000):
19
- ings = rand_ings()
20
- prompt = f"ingredients: {ings}"
21
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
22
- ids = model.generate(**inputs, max_new_tokens=180, do_sample=True, temperature=0.9, top_p=0.95)
23
- text = tokenizer.decode(ids[0], skip_special_tokens=True)
24
- rows.append({"id": i, "ingredients_text": ings, "generated_text": text})
25
-
26
- ds = Dataset.from_list(rows)
27
- ds.to_parquet("cocktail_synth.parquet")
28
- print("Wrote cocktail_synth.parquet")
29
-