OGOGOG committed
Commit 9e9bfcf · verified · 1 Parent(s): 9ce231d

Rename dataset generate to dataset.py

Files changed (2):
  1. dataset generate +0 -5
  2. dataset.py +29 -0
dataset generate DELETED
@@ -1,5 +0,0 @@
- from transformers import pipeline
- gen = pipeline("text2text-generation", model=MODEL, torch_dtype="auto", device_map="auto")
-
- def gen_one(ings):
-     return gen(f"ingredients: {ings}", max_new_tokens=180, do_sample=True, temperature=0.9, top_p=0.95)[0]["generated_text"]
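For reference, the deleted helper wrapped the transformers text2text-generation pipeline but referenced MODEL without ever defining it. A minimal self-contained sketch of that approach (using the checkpoint name that dataset.py now pins; this is illustration, not part of the commit):

# Sketch of the removed pipeline-based helper; MODEL is assumed to be the
# same checkpoint that dataset.py loads explicitly.
from transformers import pipeline

MODEL = "erwanlc/t5-cocktails_recipe-base"
gen = pipeline("text2text-generation", model=MODEL, torch_dtype="auto", device_map="auto")

def gen_one(ings):
    # Sample one recipe for a comma-separated ingredient string.
    return gen(f"ingredients: {ings}", max_new_tokens=180, do_sample=True,
               temperature=0.9, top_p=0.95)[0]["generated_text"]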
dataset.py ADDED
@@ -0,0 +1,29 @@
+ # make_dataset.py
+ import random, json, torch
+ from datasets import Dataset
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+
+ MODEL = "erwanlc/t5-cocktails_recipe-base"
+ tokenizer = AutoTokenizer.from_pretrained(MODEL)
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL, device_map="auto")
+
+ ING_POOL = ["vodka","gin","rum","tequila","whiskey","triple sec","vermouth","lime juice",
+             "lemon juice","cranberry juice","pineapple juice","simple syrup","agave syrup",
+             "bitters","ginger beer","soda water","tonic water","mint","basil","cucumber"]
+
+ def rand_ings():
+     return ", ".join(random.sample(ING_POOL, k=random.randint(3,6)))
+
+ rows = []
+ for i in range(1000):
+     ings = rand_ings()
+     prompt = f"ingredients: {ings}"
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+     ids = model.generate(**inputs, max_new_tokens=180, do_sample=True, temperature=0.9, top_p=0.95)
+     text = tokenizer.decode(ids[0], skip_special_tokens=True)
+     rows.append({"id": i, "ingredients_text": ings, "generated_text": text})
+
+ ds = Dataset.from_list(rows)
+ ds.to_parquet("cocktail_synth.parquet")
+ print("Wrote cocktail_synth.parquet")
+
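To sanity-check the result, the parquet file can be read back with the datasets library. A small sketch, assuming dataset.py has already been run in the working directory:

# Load the generated file and inspect one row.
from datasets import Dataset

ds = Dataset.from_parquet("cocktail_synth.parquet")
print(ds)                          # 1000 rows: id, ingredients_text, generated_text
print(ds[0]["ingredients_text"])
print(ds[0]["generated_text"])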