dk-crazydiv committed on
Commit
39f82cb
1 Parent(s): 4dcbec3

Adding v0 streamlit app for flax-community/t5-recipe-generation

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__/
Build Ingredients Vocab.ipynb ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 3,
6
+ "metadata": {
7
+ "ExecuteTime": {
8
+ "end_time": "2021-07-14T12:54:01.369853Z",
9
+ "start_time": "2021-07-14T12:49:27.961404Z"
10
+ }
11
+ },
12
+ "outputs": [
13
+ {
14
+ "name": "stderr",
15
+ "output_type": "stream",
16
+ "text": [
17
+ "Using custom data configuration default-fdc6acb780b42528\n"
18
+ ]
19
+ },
20
+ {
21
+ "name": "stdout",
22
+ "output_type": "stream",
23
+ "text": [
24
+ "Downloading and preparing dataset recipe_nlg/default (download: Unknown size, generated: 2.04 GiB, post-processed: Unknown size, total: 2.04 GiB) to /home/rtx/.cache/huggingface/datasets/recipe_nlg/default-fdc6acb780b42528/1.0.0/20c969e1192265af03a7186457bdb4a9109d5d68b92cad04c3ec894d6e5aee61...\n"
25
+ ]
26
+ },
27
+ {
28
+ "data": {
29
+ "application/vnd.jupyter.widget-view+json": {
30
+ "model_id": "",
31
+ "version_major": 2,
32
+ "version_minor": 0
33
+ },
34
+ "text/plain": [
35
+ "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))"
36
+ ]
37
+ },
38
+ "metadata": {},
39
+ "output_type": "display_data"
40
+ },
41
+ {
42
+ "name": "stdout",
43
+ "output_type": "stream",
44
+ "text": [
45
+ "\r",
46
+ "Dataset recipe_nlg downloaded and prepared to /home/rtx/.cache/huggingface/datasets/recipe_nlg/default-fdc6acb780b42528/1.0.0/20c969e1192265af03a7186457bdb4a9109d5d68b92cad04c3ec894d6e5aee61. Subsequent calls will reuse this data.\n"
47
+ ]
48
+ }
49
+ ],
50
+ "source": [
51
+ "from datasets import load_dataset\n",
52
+ "DATA_DIR = \"~/Downloads/dataset/\"\n",
53
+ "dataset = load_dataset(\"recipe_nlg\", data_dir=DATA_DIR)"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": 10,
59
+ "metadata": {
60
+ "ExecuteTime": {
61
+ "end_time": "2021-07-14T12:58:25.150105Z",
62
+ "start_time": "2021-07-14T12:55:27.486385Z"
63
+ }
64
+ },
65
+ "outputs": [
66
+ {
67
+ "name": "stderr",
68
+ "output_type": "stream",
69
+ "text": [
70
+ "100%|██████████| 2231142/2231142 [02:57<00:00, 12558.59it/s]\n"
71
+ ]
72
+ }
73
+ ],
74
+ "source": [
75
+ "from collections import Counter\n",
76
+ "from tqdm import tqdm\n",
77
+ "ctr = Counter()\n",
78
+ "\n",
79
+ "for row in tqdm(dataset[\"train\"]):\n",
80
+ " for item in row[\"ner\"]:\n",
81
+ " ctr[item] += 1"
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "code",
86
+ "execution_count": 23,
87
+ "metadata": {
88
+ "ExecuteTime": {
89
+ "end_time": "2021-07-14T13:02:09.315817Z",
90
+ "start_time": "2021-07-14T13:02:09.259046Z"
91
+ }
92
+ },
93
+ "outputs": [],
94
+ "source": [
95
+ "# Lower-case and de-duplicate while keeping most-common-first order, so the slices taken later really are the top-ranked ingredients (a set would scramble the ranking).\n",
+ "first_500 = list(dict.fromkeys(x[0].lower() for x in ctr.most_common(500)))"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": 25,
101
+ "metadata": {
102
+ "ExecuteTime": {
103
+ "end_time": "2021-07-14T13:02:28.864546Z",
104
+ "start_time": "2021-07-14T13:02:28.856279Z"
105
+ }
106
+ },
107
+ "outputs": [
108
+ {
109
+ "data": {
110
+ "text/plain": [
111
+ "443"
112
+ ]
113
+ },
114
+ "execution_count": 25,
115
+ "metadata": {},
116
+ "output_type": "execute_result"
117
+ }
118
+ ],
119
+ "source": [
120
+ "len(first_500)"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": 26,
126
+ "metadata": {
127
+ "ExecuteTime": {
128
+ "end_time": "2021-07-14T13:02:53.656711Z",
129
+ "start_time": "2021-07-14T13:02:53.653868Z"
130
+ }
131
+ },
132
+ "outputs": [],
133
+ "source": [
134
+ "first_100 = sorted(first_500[:100])\n",
135
+ "next_100 = sorted(first_500[100:200])"
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "code",
140
+ "execution_count": 29,
141
+ "metadata": {
142
+ "ExecuteTime": {
143
+ "end_time": "2021-07-14T13:03:35.640538Z",
144
+ "start_time": "2021-07-14T13:03:35.634368Z"
145
+ }
146
+ },
147
+ "outputs": [],
148
+ "source": [
149
+ "d = {\n",
150
+ " \"first_100\": first_100,\n",
151
+ " \"next_100\": next_100\n",
152
+ "}"
153
+ ]
154
+ },
155
+ {
156
+ "cell_type": "code",
157
+ "execution_count": 31,
158
+ "metadata": {
159
+ "ExecuteTime": {
160
+ "end_time": "2021-07-14T13:03:52.682190Z",
161
+ "start_time": "2021-07-14T13:03:52.679624Z"
162
+ }
163
+ },
164
+ "outputs": [],
165
+ "source": [
166
+ "import json\n",
167
+ "with open(\"config.json\", \"w\") as f:\n",
168
+ " f.write(json.dumps(d))"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": null,
174
+ "metadata": {},
175
+ "outputs": [],
176
+ "source": []
177
+ }
178
+ ],
179
+ "metadata": {
180
+ "kernelspec": {
181
+ "display_name": "Python 3",
182
+ "language": "python",
183
+ "name": "python3"
184
+ },
185
+ "language_info": {
186
+ "codemirror_mode": {
187
+ "name": "ipython",
188
+ "version": 3
189
+ },
190
+ "file_extension": ".py",
191
+ "mimetype": "text/x-python",
192
+ "name": "python",
193
+ "nbconvert_exporter": "python",
194
+ "pygments_lexer": "ipython3",
195
+ "version": "3.7.1"
196
+ },
197
+ "toc": {
198
+ "base_numbering": 1,
199
+ "nav_menu": {},
200
+ "number_sections": true,
201
+ "sideBar": true,
202
+ "skip_h1_title": false,
203
+ "title_cell": "Table of Contents",
204
+ "title_sidebar": "Contents",
205
+ "toc_cell": false,
206
+ "toc_position": {},
207
+ "toc_section_display": true,
208
+ "toc_window_display": false
209
+ }
210
+ },
211
+ "nbformat": 4,
212
+ "nbformat_minor": 2
213
+ }
README.md CHANGED
@@ -1,33 +1,10 @@
1
- ---
2
- title: Chef Transformer
3
- emoji: 👀
4
- colorFrom: indigo
5
- colorTo: blue
6
- sdk: streamlit
7
- app_file: app.py
8
- pinned: false
9
- ---
10
 
11
- # Configuration
12
 
13
- `title`: _string_
14
- Display title for the Space
 
 
15
 
16
- `emoji`: _string_
17
- Space emoji (emoji-only character allowed)
18
-
19
- `colorFrom`: _string_
20
- Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
21
-
22
- `colorTo`: _string_
23
- Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
24
-
25
- `sdk`: _string_
26
- Can be either `gradio` or `streamlit`
27
-
28
- `app_file`: _string_
29
- Path to your main application file (which contains either `gradio` or `streamlit` Python code).
30
- Path is relative to the root of the repository.
31
-
32
- `pinned`: _boolean_
33
- Whether the Space stays on top of your list.
 
1
+ # Streamlit demo for Chef Transformers
 
 
 
 
 
 
 
 
2
 
 
3
 
4
+ ### Launch demo:
5
+ ```
6
+ streamlit run server.py
7
+ ```
8
 
9
+ ### Modify config
10
+ Add custom ingredients by editing `config.json`: entries under the `first_100` key are shown as options in the multi-select, and entries under the `next_100` key are offered as autocomplete suggestions in the custom-ingredient input as you type.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
beam_search.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForSeq2SeqLM
3
+ from transformers import AutoTokenizer
4
+ from transformers import pipeline
5
+
6
+ from pprint import pprint
7
+ import re
8
+
9
+
10
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+ # MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
12
+ # tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
13
+ # model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
14
+
15
+
16
def skip_special_tokens_and_prettify(text, tokenizer):
    """Turn a raw decoded generation into a structured recipe.

    The text is cleaned in two passes: every string listed in
    ``tokenizer.all_special_tokens`` is deleted, then the recipe markers are
    translated (``<sep>`` becomes ``--`` and ``<section>`` becomes a newline).
    Each resulting line is then matched against the ``title:``,
    ``ingredients:`` and ``directions:`` prefixes.

    Args:
        text: Decoded model output that still contains its marker tokens.
        tokenizer: Any object exposing ``all_special_tokens`` (an iterable of
            strings).

    Returns:
        A dict with the keys ``"title"`` (str), ``"ingredients"`` (list of
        str) and ``"directions"`` (list of str). Sections that do not appear
        in ``text`` keep their empty defaults.
    """
    recipe_maps = {"<sep>": "--", "<section>": "\n"}
    recipe_map_pattern = "|".join(map(re.escape, recipe_maps))

    # Special tokens are literal strings, not regex fragments: escape them so
    # a token such as "[CLS]" is not read as a character class. Longest first
    # so a token that is a prefix of another cannot shadow it.
    special_tokens = sorted(
        {token for token in tokenizer.all_special_tokens if token},
        key=len,
        reverse=True,
    )
    if special_tokens:
        text = re.sub("|".join(map(re.escape, special_tokens)), "", text)
    text = re.sub(recipe_map_pattern, lambda m: recipe_maps[m.group()], text)

    def split_items(body):
        # Drop blank entries (e.g. from a trailing separator) so callers do
        # not end up rendering empty bullet points.
        return [item.strip() for item in body.split("--") if item.strip()]

    data = {"title": "", "ingredients": [], "directions": []}
    for section in text.split("\n"):
        section = section.strip()
        # Slice off only the leading label; str.replace would also delete the
        # same word if it appeared later inside the section body.
        if section.startswith("title:"):
            data["title"] = section[len("title:"):].strip()
        elif section.startswith("ingredients:"):
            data["ingredients"] = split_items(section[len("ingredients:"):])
        elif section.startswith("directions:"):
            data["directions"] = split_items(section[len("directions:"):])

    return data
40
+
41
+
42
+ def post_generator(output_tensors, tokenizer):
43
+ output_tensors = [output_tensors[i]["generated_token_ids"] for i in range(len(output_tensors))]
44
+ texts = tokenizer.batch_decode(output_tensors, skip_special_tokens=False)
45
+ texts = [skip_special_tokens_and_prettify(text, tokenizer) for text in texts]
46
+ return texts
47
+
48
+
49
# Example
# Keyword arguments that server.py unpacks into the text2text-generation
# pipeline call when the "Beam Search" mode is selected.
generate_kwargs = {
    "max_length": 512,
    "min_length": 64,
    "no_repeat_ngram_size": 3,
    "early_stopping": True,
    "num_beams": 5,
    "length_penalty": 1.5,
    # NOTE(review): server.py only renders the first returned sequence;
    # confirm whether more than one is actually needed.
    "num_return_sequences": 2
}
# Sample input for the commented-out example that follows; not referenced
# elsewhere in the visible code.
items = "potato, cheese"
60
+ # generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
61
+ # generated = generator(items, return_tensors=True, return_text=False, **generate_kwargs)
62
+ # outputs = post_generator(generated, tokenizer)
63
+ # pprint(outputs)
config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"first_100": ["allspice", "almond extract", "applesauce", "avocado", "balsamic vinegar", "basil", "bay leaf", "beets", "bread crumbs", "bread flour", "buns", "catsup", "cayenne", "cherry tomatoes", "chicken breasts", "chives", "chocolate cake", "coconut milk", "cold butter", "cold milk", "cooking oil", "cornstarch", "crab meat", "crackers", "cream of chicken soup", "cream of tartar", "cumin", "cumin seeds", "curry powder", "egg yolk", "extra-virgin olive oil", "feta cheese", "flaked coconut", "flat leaf parsley", "flour tortillas", "fresh chives", "fresh cilantro", "fresh mint", "fresh oregano", "fresh rosemary", "frozen strawberries", "gingerroot", "green olives", "ground allspice", "ground chuck", "ground coriander", "ground cumin", "ground pork", "ground red pepper", "hamburger", "hazelnuts", "heavy cream", "heavy whipping cream", "hot pepper", "italian dressing", "lean ground beef", "lemon juice", "lemon pepper", "marjoram", "miracle", "noodles", "nuts", "oatmeal", "oats", "oleo", "olive oil", "onion salt", "onions", "orange", "paprika", "parmesan cheese", "parsley", "pasta", "peaches", "pecans", "pork sausage", "pork tenderloin", "poultry seasoning", "powdered sugar", "pumpkin", "red potatoes", "red wine vinegar", "rosemary", "salmon", "scallion", "sesame oil", "shell", "stalks celery", "tabasco sauce", "tarragon", "tomatoes", "unsalted butter", "vanilla wafers", "vegetables", "warm water", "whipping cream", "white wine vinegar", "whole wheat flour", "yellow squash", "yogurt"], "next_100": ["active dry yeast", "almonds", "apple", "apple cider", "apple cider vinegar", "avocados", "baby spinach", "bay leaves", "bean sprouts", "beef", "beef broth", "broccoli", "cabbage", "capers", "cashews", "celery", "celery salt", "cherries", "cherry pie filling", "chicken", "chicken broth", "chicken stock", "chickpeas", "chili sauce", "chocolate", "cinnamon", "cloves", "corn", "cottage cheese", "cranberry sauce", "egg noodles", "egg yolks", "extra virgin olive oil", 
"freshly ground pepper", "garlic powder", "golden raisins", "graham cracker crust", "graham crackers", "green peppers", "ground black pepper", "ground nutmeg", "ground pepper", "ground turmeric", "kosher salt", "lemon rind", "mango", "mint", "mustard", "nutmeg", "orange juice", "orange zest", "oregano", "peanut butter", "peas", "pecan halves", "pepperoni", "pine nuts", "pinto beans", "pizza sauce", "plain yogurt", "potatoes", "raisins", "red", "red bell peppers", "red pepper", "red peppers", "rhubarb", "ricotta cheese", "salad oil", "sauce", "scallions", "sesame seeds", "sherry", "shredded cheese", "skinless", "soda", "soy sauce", "spinach", "strawberries", "sugar", "sweet onion", "sweet potatoes", "swiss cheese", "t", "tomato", "tomato paste", "tomato soup", "tuna", "vanilla bean", "vanilla ice cream", "vanilla pudding", "vegetable broth", "vegetable oil", "vegetable shortening", "whipped cream", "white onion", "white sugar", "yellow cake", "yellow cornmeal", "zucchini"]}
images/chef-transformer.png ADDED
images/logo.png ADDED
server.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
2
+ from datetime import datetime as dt
3
+ import streamlit as st
4
+ from streamlit_tags import st_tags
5
+ import beam_search
6
+ import top_sampling
7
+ from pprint import pprint
8
+ import json
9
+
10
# Ingredient vocabularies for the page: cfg["first_100"] feeds the multi-select
# and cfg["next_100"] feeds the autocomplete suggestions further down.
# The path is relative, so the app is expected to be launched from the
# directory that contains config.json.
with open("config.json", encoding="utf-8") as f:
    cfg = json.load(f)

st.set_page_config(layout="wide")
14
+
15
@st.cache(allow_output_mutation=True)
def load_model():
    """Build the recipe generator and its tokenizer.

    Returns:
        tuple: ``(generator, tokenizer)`` where ``generator`` is a
        ``text2text-generation`` pipeline built from the
        ``flax-community/t5-recipe-generation`` checkpoint and ``tokenizer``
        is the tokenizer loaded from that same checkpoint.
    """
    checkpoint = "flax-community/t5-recipe-generation"
    recipe_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    recipe_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    recipe_pipeline = pipeline(
        "text2text-generation",
        model=recipe_model,
        tokenizer=recipe_tokenizer,
    )
    return recipe_pipeline, recipe_tokenizer
21
+
22
def sampling_changed(obj):
    """Print ``obj`` to standard output.

    NOTE(review): nothing in the visible code calls this; it looks like a
    leftover debug callback. Confirm before removing.
    """
    print(obj)
24
+
25
+
26
# Build (or reuse) the generation pipeline before drawing the page.
with st.spinner('Loading model...'):
    generator, tokenizer = load_model()

# --- Page header ------------------------------------------------------------
st.header("Chef transformers (flax-community)")
st.markdown("This demo uses [t5 trained on recipe-nlg](https://huggingface.co/flax-community/t5-recipe-generation) to generate recipe from a given set of ingredients")

# --- Sidebar: logo, placeholder presets and decoding mode --------------------
st.sidebar.image("images/chef-transformer.png", width=200)
st.sidebar.title("Popular recipes:")
st.sidebar.text("Recipe preset(example#1)")
st.sidebar.text("Recipe preset(example#2)")

st.sidebar.title("Mode:")
# One table drives both the select box and the dispatch below, so the option
# labels cannot drift away from the code that handles them. Both modules
# expose `generate_kwargs` and `post_generator`.
decoding_strategies = {
    "Beam Search": beam_search,
    "Top-k Sampling": top_sampling,
}
sampling_mode = st.sidebar.selectbox("select a Mode", index=0, options=list(decoding_strategies))

# --- Ingredient selection ---------------------------------------------------
# config.json is meant to be edited by users; keep only the preselected
# ingredients that are still offered, because the multi-select rejects
# defaults that are not among its options.
default_ingredients = [
    ingredient
    for ingredient in ["parmesan cheese", "fresh oregano", "basil", "whole wheat flour"]
    if ingredient in cfg["first_100"]
]
original_keywords = st.multiselect(
    "Choose ingredients",
    cfg["first_100"],
    default_ingredients,
)

st.write("Add custom ingredients here:")
custom_keywords = st_tags(
    label="",
    text='Press enter to add more',
    value=['salt'],
    suggestions=cfg["next_100"],
    maxtags=15,
    key='1',
)
all_ingredients = ", ".join([*original_keywords, *custom_keywords])
st.markdown("**Generate recipe for:** " + all_ingredients)

# --- Generation -------------------------------------------------------------
submit = st.button('Get Recipe!')
if submit:
    if not all_ingredients:
        # Nothing to condition the model on; ask for input instead of
        # generating from an empty prompt.
        st.warning("Please choose or add at least one ingredient.")
    else:
        strategy = decoding_strategies[sampling_mode]
        with st.spinner('Generating recipe...'):
            generated = generator(all_ingredients, return_tensors=True, return_text=False, **strategy.generate_kwargs)
            outputs = strategy.post_generator(generated, tokenizer)

        if not outputs:
            st.error("The model did not return a recipe. Please try again.")
        else:
            # Only the first returned sequence is displayed.
            output = outputs[0]
            markdown_lines = [f"## {output['title'].capitalize()}", "#### Ingredients:"]
            markdown_lines.extend(f"- {item}" for item in output["ingredients"])
            markdown_lines.append("#### Directions:")
            markdown_lines.extend(f"- {step}" for step in output["directions"])
            st.markdown("\n".join(markdown_lines) + "\n")
            st.balloons()
80
+
top_sampling.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForSeq2SeqLM
3
+ from transformers import AutoTokenizer
4
+ from transformers import pipeline
5
+
6
+ from pprint import pprint
7
+ import re
8
+
9
+
10
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+ # MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
12
+ # tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
13
+ # model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
14
+
15
+
16
def skip_special_tokens_and_prettify(text, tokenizer):
    """Turn a raw decoded generation into a structured recipe.

    The text is cleaned in two passes: every string listed in
    ``tokenizer.all_special_tokens`` is deleted, then the recipe markers are
    translated (``<sep>`` becomes ``--`` and ``<section>`` becomes a newline).
    Each resulting line is then matched against the ``title:``,
    ``ingredients:`` and ``directions:`` prefixes.

    Args:
        text: Decoded model output that still contains its marker tokens.
        tokenizer: Any object exposing ``all_special_tokens`` (an iterable of
            strings).

    Returns:
        A dict with the keys ``"title"`` (str), ``"ingredients"`` (list of
        str) and ``"directions"`` (list of str). Sections that do not appear
        in ``text`` keep their empty defaults.
    """
    recipe_maps = {"<sep>": "--", "<section>": "\n"}
    recipe_map_pattern = "|".join(map(re.escape, recipe_maps))

    # Special tokens are literal strings, not regex fragments: escape them so
    # a token such as "[CLS]" is not read as a character class. Longest first
    # so a token that is a prefix of another cannot shadow it.
    special_tokens = sorted(
        {token for token in tokenizer.all_special_tokens if token},
        key=len,
        reverse=True,
    )
    if special_tokens:
        text = re.sub("|".join(map(re.escape, special_tokens)), "", text)
    text = re.sub(recipe_map_pattern, lambda m: recipe_maps[m.group()], text)

    def split_items(body):
        # Drop blank entries (e.g. from a trailing separator) so callers do
        # not end up rendering empty bullet points.
        return [item.strip() for item in body.split("--") if item.strip()]

    data = {"title": "", "ingredients": [], "directions": []}
    for section in text.split("\n"):
        section = section.strip()
        # Slice off only the leading label; str.replace would also delete the
        # same word if it appeared later inside the section body.
        if section.startswith("title:"):
            data["title"] = section[len("title:"):].strip()
        elif section.startswith("ingredients:"):
            data["ingredients"] = split_items(section[len("ingredients:"):])
        elif section.startswith("directions:"):
            data["directions"] = split_items(section[len("directions:"):])

    return data
40
+
41
+
42
def post_generator(output_tensors, tokenizer):
    """Decode generated token ids and parse each sequence into a recipe dict.

    Args:
        output_tensors: Iterable of mappings, each holding one generated
            sequence under the ``"generated_token_ids"`` key.
        tokenizer: Object providing ``batch_decode``; it is also forwarded to
            ``skip_special_tokens_and_prettify``.

    Returns:
        A list with one parsed recipe (as returned by
        ``skip_special_tokens_and_prettify``) per generated sequence.
    """
    token_ids = [output["generated_token_ids"] for output in output_tensors]
    # Decode with special tokens kept; the prettifier removes them itself.
    texts = tokenizer.batch_decode(token_ids, skip_special_tokens=False)
    return [skip_special_tokens_and_prettify(text, tokenizer) for text in texts]
47
+
48
+
49
# Example
# Keyword arguments that server.py unpacks into the text2text-generation
# pipeline call when the "Top-k Sampling" mode is selected.
generate_kwargs = {
    "max_length": 512,
    "min_length": 64,
    "no_repeat_ngram_size": 3,
    "do_sample": True,
    "top_k": 60,
    "top_p": 0.95,
    # NOTE(review): server.py only renders the first returned sequence;
    # confirm whether more than one is actually needed.
    "num_return_sequences": 3
}
59
+ # items = "potato, cheese"
60
+ # generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
61
+ # generated = generator(items, return_tensors=True, return_text=False, **generate_kwargs)
62
+ # outputs = post_generator(generated, tokenizer)
63
+ # pprint(outputs)