Spaces · Runtime error

dk-crazydiv committed
Commit 39f82cb · Parent(s): 4dcbec3
Adding v0 streamlit app for flax-community/t5-recipe-generation
Files changed:
- .gitignore +1 -0
- Build Ingredients Vocab.ipynb +213 -0
- README.md +7 -30
- beam_search.py +63 -0
- config.json +1 -0
- images/chef-transformer.png +0 -0
- images/logo.png +0 -0
- server.py +80 -0
- top_sampling.py +63 -0
.gitignore
ADDED
@@ -0,0 +1 @@
__pycache__/
Build Ingredients Vocab.ipynb
ADDED
@@ -0,0 +1,213 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-14T12:54:01.369853Z",
     "start_time": "2021-07-14T12:49:27.961404Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration default-fdc6acb780b42528\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading and preparing dataset recipe_nlg/default (download: Unknown size, generated: 2.04 GiB, post-processed: Unknown size, total: 2.04 GiB) to /home/rtx/.cache/huggingface/datasets/recipe_nlg/default-fdc6acb780b42528/1.0.0/20c969e1192265af03a7186457bdb4a9109d5d68b92cad04c3ec894d6e5aee61...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Dataset recipe_nlg downloaded and prepared to /home/rtx/.cache/huggingface/datasets/recipe_nlg/default-fdc6acb780b42528/1.0.0/20c969e1192265af03a7186457bdb4a9109d5d68b92cad04c3ec894d6e5aee61. Subsequent calls will reuse this data.\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "DATA_DIR = \"~/Downloads/dataset/\"\n",
    "dataset = load_dataset(\"recipe_nlg\", data_dir=DATA_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-14T12:58:25.150105Z",
     "start_time": "2021-07-14T12:55:27.486385Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2231142/2231142 [02:57<00:00, 12558.59it/s]\n"
     ]
    }
   ],
   "source": [
    "from collections import Counter\n",
    "from tqdm import tqdm\n",
    "ctr = Counter()\n",
    "\n",
    "for row in tqdm(dataset[\"train\"]):\n",
    "    for item in row[\"ner\"]:\n",
    "        ctr[item] += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-14T13:02:09.315817Z",
     "start_time": "2021-07-14T13:02:09.259046Z"
    }
   },
   "outputs": [],
   "source": [
    "first_500 = list(set([x[0].lower() for x in ctr.most_common()[0:500]]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-14T13:02:28.864546Z",
     "start_time": "2021-07-14T13:02:28.856279Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "443"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(first_500)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-14T13:02:53.656711Z",
     "start_time": "2021-07-14T13:02:53.653868Z"
    }
   },
   "outputs": [],
   "source": [
    "first_100 = sorted(first_500[:100])\n",
    "next_100 = sorted(first_500[100:200])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-14T13:03:35.640538Z",
     "start_time": "2021-07-14T13:03:35.634368Z"
    }
   },
   "outputs": [],
   "source": [
    "d = {\n",
    "    \"first_100\": first_100,\n",
    "    \"next_100\": next_100\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-14T13:03:52.682190Z",
     "start_time": "2021-07-14T13:03:52.679624Z"
    }
   },
   "outputs": [],
   "source": [
    "import json\n",
    "with open(\"config.json\", \"w\") as f:\n",
    "    f.write(json.dumps(d))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
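
For reference, the notebook above reduces to a short script: load RecipeNLG, count the NER-tagged ingredients, keep the most frequent ones, and write the two lists that the app later reads from `config.json`. The sketch below just condenses the notebook's own steps, under the same assumption that the RecipeNLG archive has already been downloaded manually into `DATA_DIR` (the `recipe_nlg` dataset is not auto-downloadable).

```python
# Sketch: build the ingredient vocab for the Streamlit app in one pass.
# Assumes the RecipeNLG data has been downloaded manually into DATA_DIR.
import json
from collections import Counter

from datasets import load_dataset
from tqdm import tqdm

DATA_DIR = "~/Downloads/dataset/"  # path used in the notebook; adjust as needed
dataset = load_dataset("recipe_nlg", data_dir=DATA_DIR)

# Count every NER-tagged ingredient across the training split.
ctr = Counter()
for row in tqdm(dataset["train"]):
    for item in row["ner"]:
        ctr[item] += 1

# Deduplicate the 500 most common ingredients (lowercased), then split into
# the two lists the app expects: multi-select options and autocomplete suggestions.
top = list(set(x[0].lower() for x in ctr.most_common()[:500]))
config = {"first_100": sorted(top[:100]), "next_100": sorted(top[100:200])}

with open("config.json", "w") as f:
    json.dump(config, f)
```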
README.md
CHANGED
@@ -1,33 +1,10 @@
----
-title: Chef Transformer
-emoji: 👀
-colorFrom: indigo
-colorTo: blue
-sdk: streamlit
-app_file: app.py
-pinned: false
----
-
-# Configuration
-
-`title`: _string_
-Display title for the Space
-
-`emoji`: _string_
-Space emoji (emoji-only character allowed)
-
-`colorFrom`: _string_
-Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
-
-`colorTo`: _string_
-Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
-
-`sdk`: _string_
-Can be either `gradio` or `streamlit`
-
-`app_file`: _string_
-Path to your main application file (which contains either `gradio` or `streamlit` Python code).
-Path is relative to the root of the repository.
-
-`pinned`: _boolean_
-Whether the Space stays on top of your list.
+# Streamlit demo for Chef Transformers
+
+
+### Launch demo:
+```
+streamlit run server.py
+```
+
+### Modify config
+Add custom ingredients to `config.json`: entries under the `first_100` key appear in the ingredient multi-select, while entries under `next_100` are used as autocomplete suggestions in the custom-ingredient section as you type.
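
As an illustration of that config structure, here is a sketch of adding one extra ingredient programmatically; the ingredient name is an arbitrary example, and `config.json` is assumed to sit next to `server.py`.

```python
# Sketch: add an ingredient to the multi-select options in config.json.
# "saffron" is an arbitrary example, not part of the shipped vocab.
import json

with open("config.json") as f:
    cfg = json.load(f)

if "saffron" not in cfg["first_100"]:
    cfg["first_100"].append("saffron")  # shown in the multi-select
    cfg["first_100"].sort()

with open("config.json", "w") as f:
    json.dump(cfg, f)
```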
beam_search.py
ADDED
@@ -0,0 +1,63 @@
import torch
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer
from transformers import pipeline

from pprint import pprint
import re


# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
# tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
# model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)


def skip_special_tokens_and_prettify(text, tokenizer):
    recipe_maps = {"<sep>": "--", "<section>": "\n"}
    recipe_map_pattern = "|".join(map(re.escape, recipe_maps.keys()))

    text = re.sub(
        recipe_map_pattern,
        lambda m: recipe_maps[m.group()],
        re.sub("|".join(tokenizer.all_special_tokens), "", text)
    )

    data = {"title": "", "ingredients": [], "directions": []}
    for section in text.split("\n"):
        section = section.strip()
        section = section.strip()
        if section.startswith("title:"):
            data["title"] = section.replace("title:", "").strip()
        elif section.startswith("ingredients:"):
            data["ingredients"] = [s.strip() for s in section.replace("ingredients:", "").split('--')]
        elif section.startswith("directions:"):
            data["directions"] = [s.strip() for s in section.replace("directions:", "").split('--')]
        else:
            pass

    return data


def post_generator(output_tensors, tokenizer):
    output_tensors = [output_tensors[i]["generated_token_ids"] for i in range(len(output_tensors))]
    texts = tokenizer.batch_decode(output_tensors, skip_special_tokens=False)
    texts = [skip_special_tokens_and_prettify(text, tokenizer) for text in texts]
    return texts


# Example
generate_kwargs = {
    "max_length": 512,
    "min_length": 64,
    "no_repeat_ngram_size": 3,
    "early_stopping": True,
    "num_beams": 5,
    "length_penalty": 1.5,
    "num_return_sequences": 2
}
items = "potato, cheese"
# generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
# generated = generator(items, return_tensors=True, return_text=False, **generate_kwargs)
# outputs = post_generator(generated, tokenizer)
# pprint(outputs)
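
Putting the commented-out pieces of `beam_search.py` together gives a small standalone check of the beam-search path. This is a sketch, assuming the `flax-community/t5-recipe-generation` checkpoint is reachable on the Hub and `beam_search.py` is importable from the working directory.

```python
# Sketch: run the beam-search configuration end to end, outside Streamlit.
from pprint import pprint

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

import beam_search

MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)

# Return token ids (not text) so post_generator can decode and structure them.
generated = generator(
    "potato, cheese",
    return_tensors=True,
    return_text=False,
    **beam_search.generate_kwargs,
)
outputs = beam_search.post_generator(generated, tokenizer)

# Each output is a dict with "title", "ingredients" and "directions" keys.
pprint(outputs)
```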
config.json
ADDED
@@ -0,0 +1 @@
{"first_100": ["allspice", "almond extract", "applesauce", "avocado", "balsamic vinegar", "basil", "bay leaf", "beets", "bread crumbs", "bread flour", "buns", "catsup", "cayenne", "cherry tomatoes", "chicken breasts", "chives", "chocolate cake", "coconut milk", "cold butter", "cold milk", "cooking oil", "cornstarch", "crab meat", "crackers", "cream of chicken soup", "cream of tartar", "cumin", "cumin seeds", "curry powder", "egg yolk", "extra-virgin olive oil", "feta cheese", "flaked coconut", "flat leaf parsley", "flour tortillas", "fresh chives", "fresh cilantro", "fresh mint", "fresh oregano", "fresh rosemary", "frozen strawberries", "gingerroot", "green olives", "ground allspice", "ground chuck", "ground coriander", "ground cumin", "ground pork", "ground red pepper", "hamburger", "hazelnuts", "heavy cream", "heavy whipping cream", "hot pepper", "italian dressing", "lean ground beef", "lemon juice", "lemon pepper", "marjoram", "miracle", "noodles", "nuts", "oatmeal", "oats", "oleo", "olive oil", "onion salt", "onions", "orange", "paprika", "parmesan cheese", "parsley", "pasta", "peaches", "pecans", "pork sausage", "pork tenderloin", "poultry seasoning", "powdered sugar", "pumpkin", "red potatoes", "red wine vinegar", "rosemary", "salmon", "scallion", "sesame oil", "shell", "stalks celery", "tabasco sauce", "tarragon", "tomatoes", "unsalted butter", "vanilla wafers", "vegetables", "warm water", "whipping cream", "white wine vinegar", "whole wheat flour", "yellow squash", "yogurt"], "next_100": ["active dry yeast", "almonds", "apple", "apple cider", "apple cider vinegar", "avocados", "baby spinach", "bay leaves", "bean sprouts", "beef", "beef broth", "broccoli", "cabbage", "capers", "cashews", "celery", "celery salt", "cherries", "cherry pie filling", "chicken", "chicken broth", "chicken stock", "chickpeas", "chili sauce", "chocolate", "cinnamon", "cloves", "corn", "cottage cheese", "cranberry sauce", "egg noodles", "egg yolks", "extra virgin olive oil", "freshly ground pepper", "garlic powder", "golden raisins", "graham cracker crust", "graham crackers", "green peppers", "ground black pepper", "ground nutmeg", "ground pepper", "ground turmeric", "kosher salt", "lemon rind", "mango", "mint", "mustard", "nutmeg", "orange juice", "orange zest", "oregano", "peanut butter", "peas", "pecan halves", "pepperoni", "pine nuts", "pinto beans", "pizza sauce", "plain yogurt", "potatoes", "raisins", "red", "red bell peppers", "red pepper", "red peppers", "rhubarb", "ricotta cheese", "salad oil", "sauce", "scallions", "sesame seeds", "sherry", "shredded cheese", "skinless", "soda", "soy sauce", "spinach", "strawberries", "sugar", "sweet onion", "sweet potatoes", "swiss cheese", "t", "tomato", "tomato paste", "tomato soup", "tuna", "vanilla bean", "vanilla ice cream", "vanilla pudding", "vegetable broth", "vegetable oil", "vegetable shortening", "whipped cream", "white onion", "white sugar", "yellow cake", "yellow cornmeal", "zucchini"]}
images/chef-transformer.png
ADDED
images/logo.png
ADDED
server.py
ADDED
@@ -0,0 +1,80 @@
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from datetime import datetime as dt
import streamlit as st
from streamlit_tags import st_tags
import beam_search
import top_sampling
from pprint import pprint
import json

with open("config.json") as f:
    cfg = json.loads(f.read())

st.set_page_config(layout="wide")

@st.cache(allow_output_mutation=True)
def load_model():
    tokenizer = AutoTokenizer.from_pretrained("flax-community/t5-recipe-generation")
    model = AutoModelForSeq2SeqLM.from_pretrained("flax-community/t5-recipe-generation")
    generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
    return generator, tokenizer

def sampling_changed(obj):
    print(obj)


with st.spinner('Loading model...'):
    generator, tokenizer = load_model()

# st.image("images/chef-transformer.png", width=400)
st.header("Chef transformers (flax-community)")
st.markdown("This demo uses [t5 trained on recipe-nlg](https://huggingface.co/flax-community/t5-recipe-generation) to generate recipe from a given set of ingredients")
img = st.sidebar.image("images/chef-transformer.png", width=200)
add_text_sidebar = st.sidebar.title("Popular recipes:")
add_text_sidebar = st.sidebar.text("Recipe preset(example#1)")
add_text_sidebar = st.sidebar.text("Recipe preset(example#2)")

add_text_sidebar = st.sidebar.title("Mode:")
sampling_mode = st.sidebar.selectbox("select a Mode", index=0, options=["Beam Search", "Top-k Sampling"])


original_keywords = st.multiselect("Choose ingredients",
                                   cfg["first_100"],
                                   ["parmesan cheese", "fresh oregano", "basil", "whole wheat flour"]
                                   )

st.write("Add custom ingredients here:")
custom_keywords = st_tags(
    label="",
    text='Press enter to add more',
    value=['salt'],
    suggestions=cfg["next_100"],
    maxtags = 15,
    key='1')
all_ingredients = []
all_ingredients.extend(original_keywords)
all_ingredients.extend(custom_keywords)
all_ingredients = ", ".join(all_ingredients)
st.markdown("**Generate recipe for:** "+all_ingredients)


submit = st.button('Get Recipe!')
if submit:
    with st.spinner('Generating recipe...'):
        if sampling_mode == "Beam Search":
            generated = generator(all_ingredients, return_tensors=True, return_text=False, **beam_search.generate_kwargs)
            outputs = beam_search.post_generator(generated, tokenizer)
        elif sampling_mode == "Top-k Sampling":
            generated = generator(all_ingredients, return_tensors=True, return_text=False, **top_sampling.generate_kwargs)
            outputs = top_sampling.post_generator(generated, tokenizer)
        output = outputs[0]
        markdown_output = ""
        markdown_output += f"## {output['title'].capitalize()}\n"
        markdown_output += f"#### Ingredients:\n"
        for o in output["ingredients"]:
            markdown_output += f"- {o}\n"
        markdown_output += f"#### Directions:\n"
        for o in output["directions"]:
            markdown_output += f"- {o}\n"
        st.markdown(markdown_output)
        st.balloons()
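
The rendering block at the end of `server.py` assumes each recipe dict coming out of `post_generator` has `title`, `ingredients`, and `directions` keys. Below is that formatting step isolated as a sketch, with a made-up recipe dict, which can be handy when tweaking the layout without loading the model.

```python
# Sketch: the markdown formatting used in server.py, isolated for testing.
# The recipe dict below is made up; real ones come from post_generator().
output = {
    "title": "cheesy potatoes",
    "ingredients": ["2 potatoes", "1 cup cheese", "salt"],
    "directions": ["slice potatoes.", "top with cheese.", "bake until golden."],
}

markdown_output = f"## {output['title'].capitalize()}\n"
markdown_output += "#### Ingredients:\n"
for o in output["ingredients"]:
    markdown_output += f"- {o}\n"
markdown_output += "#### Directions:\n"
for o in output["directions"]:
    markdown_output += f"- {o}\n"

print(markdown_output)  # in the app this string is passed to st.markdown()
```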
top_sampling.py
ADDED
@@ -0,0 +1,63 @@
import torch
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer
from transformers import pipeline

from pprint import pprint
import re


# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
# tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
# model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)


def skip_special_tokens_and_prettify(text, tokenizer):
    recipe_maps = {"<sep>": "--", "<section>": "\n"}
    recipe_map_pattern = "|".join(map(re.escape, recipe_maps.keys()))

    text = re.sub(
        recipe_map_pattern,
        lambda m: recipe_maps[m.group()],
        re.sub("|".join(tokenizer.all_special_tokens), "", text)
    )

    data = {"title": "", "ingredients": [], "directions": []}
    for section in text.split("\n"):
        section = section.strip()
        section = section.strip()
        if section.startswith("title:"):
            data["title"] = section.replace("title:", "").strip()
        elif section.startswith("ingredients:"):
            data["ingredients"] = [s.strip() for s in section.replace("ingredients:", "").split('--')]
        elif section.startswith("directions:"):
            data["directions"] = [s.strip() for s in section.replace("directions:", "").split('--')]
        else:
            pass

    return data


def post_generator(output_tensors, tokenizer):
    output_tensors = [output_tensors[i]["generated_token_ids"] for i in range(len(output_tensors))]
    texts = tokenizer.batch_decode(output_tensors, skip_special_tokens=False)
    texts = [skip_special_tokens_and_prettify(text, tokenizer) for text in texts]
    return texts


# Example
generate_kwargs = {
    "max_length": 512,
    "min_length": 64,
    "no_repeat_ngram_size": 3,
    "do_sample": True,
    "top_k": 60,
    "top_p": 0.95,
    "num_return_sequences": 3
}
# items = "potato, cheese"
# generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
# generated = generator(items, return_tensors=True, return_text=False, **generate_kwargs)
# outputs = post_generator(generated, tokenizer)
# pprint(outputs)
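
`top_sampling.py` duplicates `beam_search.py` except for its `generate_kwargs` (top-k/top-p sampling instead of beam search), so the two modules effectively bundle alternative decoding settings around the same post-processing. Below is a sketch of the mode switch `server.py` performs, assuming a `generator` pipeline and `tokenizer` like the ones returned by the app's `load_model()`.

```python
# Sketch: the decoding-mode switch in server.py, reduced to its essentials.
# `generator` and `tokenizer` are assumed to come from the app's load_model().
import beam_search
import top_sampling


def generate_recipe(generator, tokenizer, ingredients, mode="Beam Search"):
    # Deterministic beam search vs. stochastic top-k / top-p sampling:
    # the only difference between the two modules is their generate_kwargs.
    module = beam_search if mode == "Beam Search" else top_sampling
    generated = generator(
        ingredients,
        return_tensors=True,
        return_text=False,
        **module.generate_kwargs,
    )
    # Both modules expose the same post_generator() helper.
    return module.post_generator(generated, tokenizer)[0]
```

Calling `generate_recipe(generator, tokenizer, "potato, cheese", mode="Top-k Sampling")` would then return a single structured dict with `title`, `ingredients`, and `directions`.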