taesiri committed on
Commit
6415f35
1 Parent(s): 8ace0c2

initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Created by https://www.toptal.com/developers/gitignore/api/macos,python
2
+ # Edit at https://www.toptal.com/developers/gitignore?templates=macos,python
3
+
4
+ ### macOS ###
5
+ # General
6
+ .DS_Store
7
+ .AppleDouble
8
+ .LSOverride
9
+
10
+ # Icon must end with two \r
11
+ Icon
12
+
13
+
14
+ # Thumbnails
15
+ ._*
16
+
17
+ # Files that might appear in the root of a volume
18
+ .DocumentRevisions-V100
19
+ .fseventsd
20
+ .Spotlight-V100
21
+ .TemporaryItems
22
+ .Trashes
23
+ .VolumeIcon.icns
24
+ .com.apple.timemachine.donotpresent
25
+
26
+ # Directories potentially created on remote AFP share
27
+ .AppleDB
28
+ .AppleDesktop
29
+ Network Trash Folder
30
+ Temporary Items
31
+ .apdisk
32
+
33
+ ### macOS Patch ###
34
+ # iCloud generated files
35
+ *.icloud
36
+
37
+ ### Python ###
38
+ # Byte-compiled / optimized / DLL files
39
+ __pycache__/
40
+ *.py[cod]
41
+ *$py.class
42
+
43
+ # C extensions
44
+ *.so
45
+
46
+ # Distribution / packaging
47
+ .Python
48
+ build/
49
+ develop-eggs/
50
+ dist/
51
+ downloads/
52
+ eggs/
53
+ .eggs/
54
+ lib/
55
+ lib64/
56
+ parts/
57
+ sdist/
58
+ var/
59
+ wheels/
60
+ share/python-wheels/
61
+ *.egg-info/
62
+ .installed.cfg
63
+ *.egg
64
+ MANIFEST
65
+
66
+ # PyInstaller
67
+ # Usually these files are written by a python script from a template
68
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
69
+ *.manifest
70
+ *.spec
71
+
72
+ # Installer logs
73
+ pip-log.txt
74
+ pip-delete-this-directory.txt
75
+
76
+ # Unit test / coverage reports
77
+ htmlcov/
78
+ .tox/
79
+ .nox/
80
+ .coverage
81
+ .coverage.*
82
+ .cache
83
+ nosetests.xml
84
+ coverage.xml
85
+ *.cover
86
+ *.py,cover
87
+ .hypothesis/
88
+ .pytest_cache/
89
+ cover/
90
+
91
+ # Translations
92
+ *.mo
93
+ *.pot
94
+
95
+ # Django stuff:
96
+ *.log
97
+ local_settings.py
98
+ db.sqlite3
99
+ db.sqlite3-journal
100
+
101
+ # Flask stuff:
102
+ instance/
103
+ .webassets-cache
104
+
105
+ # Scrapy stuff:
106
+ .scrapy
107
+
108
+ # Sphinx documentation
109
+ docs/_build/
110
+
111
+ # PyBuilder
112
+ .pybuilder/
113
+ target/
114
+
115
+ # Jupyter Notebook
116
+ .ipynb_checkpoints
117
+
118
+ # IPython
119
+ profile_default/
120
+ ipython_config.py
121
+
122
+ # pyenv
123
+ # For a library or package, you might want to ignore these files since the code is
124
+ # intended to run in multiple environments; otherwise, check them in:
125
+ # .python-version
126
+
127
+ # pipenv
128
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
129
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
130
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
131
+ # install all needed dependencies.
132
+ #Pipfile.lock
133
+
134
+ # poetry
135
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
136
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
137
+ # commonly ignored for libraries.
138
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
139
+ #poetry.lock
140
+
141
+ # pdm
142
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
143
+ #pdm.lock
144
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
145
+ # in version control.
146
+ # https://pdm.fming.dev/#use-with-ide
147
+ .pdm.toml
148
+
149
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
150
+ __pypackages__/
151
+
152
+ # Celery stuff
153
+ celerybeat-schedule
154
+ celerybeat.pid
155
+
156
+ # SageMath parsed files
157
+ *.sage.py
158
+
159
+ # Environments
160
+ .env
161
+ .venv
162
+ env/
163
+ venv/
164
+ ENV/
165
+ env.bak/
166
+ venv.bak/
167
+
168
+ # Spyder project settings
169
+ .spyderproject
170
+ .spyproject
171
+
172
+ # Rope project settings
173
+ .ropeproject
174
+
175
+ # mkdocs documentation
176
+ /site
177
+
178
+ # mypy
179
+ .mypy_cache/
180
+ .dmypy.json
181
+ dmypy.json
182
+
183
+ # Pyre type checker
184
+ .pyre/
185
+
186
+ # pytype static type analyzer
187
+ .pytype/
188
+
189
+ # Cython debug symbols
190
+ cython_debug/
191
+
192
+ # PyCharm
193
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
194
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
195
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
196
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
197
+ #.idea/
198
+
199
+ ### Python Patch ###
200
+ # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
201
+ poetry.toml
202
+
203
+ # ruff
204
+ .ruff_cache/
205
+
206
+ # LSP config files
207
+ pyrightconfig.json
208
+
209
+ # End of https://www.toptal.com/developers/gitignore/api/macos,python
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
- title: HuggingGPT Lite
3
- emoji: 🔥
4
  colorFrom: red
5
  colorTo: gray
6
  sdk: gradio
 
1
  ---
2
+ title: HuggingGPT - Lite
3
+ emoji: 🎐
4
  colorFrom: red
5
  colorTo: gray
6
  sdk: gradio
app.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ import gradio as gr
3
+ import re
4
+ from diffusers.utils import load_image
5
+ import requests
6
+ from awesome_chat import chat_huggingface
7
+ import os
8
+
9
# Make sure the directories used for stored/served media exist before the UI starts.
for _media_dir in ("public/images", "public/audios", "public/videos"):
    os.makedirs(_media_dir, exist_ok=True)

# Optional credentials read from the environment; each is None when unset.
HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")
OPENAI_KEY = os.environ.get("OPENAI_KEY")
15
+
16
+
17
class Client:
    """Chat state shared between the Gradio callbacks and ``chat_huggingface``.

    Holds the OpenAI key, the Hugging Face token and the running list of
    ``{"role", "content"}`` messages that is passed to ``chat_huggingface``.
    """

    # Media references are recognised purely by file extension. Compiled once
    # at class creation instead of on every call.
    _IMAGE_PATTERN = re.compile(
        r"(http(s?):|\/)?([\.\/_\w:-])*?\.(jpg|jpeg|tiff|gif|png)"
    )
    _AUDIO_PATTERN = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(flac|wav)")
    _VIDEO_PATTERN = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(mp4)")

    def __init__(self) -> None:
        self.OPENAI_KEY = OPENAI_KEY
        self.HUGGINGFACE_TOKEN = HUGGINGFACE_TOKEN
        self.all_messages = []

    def set_key(self, openai_key):
        """Store the OpenAI key and return it (used as a Gradio callback)."""
        self.OPENAI_KEY = openai_key
        return self.OPENAI_KEY

    def set_token(self, huggingface_token):
        """Store the Hugging Face token and return it (used as a Gradio callback)."""
        self.HUGGINGFACE_TOKEN = huggingface_token
        return self.HUGGINGFACE_TOKEN

    def add_message(self, content, role):
        """Append one ``{"role", "content"}`` entry to the conversation history."""
        message = {"role": role, "content": content}
        self.all_messages.append(message)

    def _has_valid_credentials(self):
        """Return True when both credentials are set and carry the expected prefix."""
        return bool(
            self.OPENAI_KEY
            and self.OPENAI_KEY.startswith("sk-")
            and self.HUGGINGFACE_TOKEN
            and self.HUGGINGFACE_TOKEN.startswith("hf_")
        )

    @staticmethod
    def _unique_matches(pattern, message):
        """Return every full match of *pattern* in *message*, first occurrence order."""
        return list(dict.fromkeys(m.group(0) for m in pattern.finditer(message)))

    @staticmethod
    def _resolve_media_url(url):
        """Prefix bare local paths with ``public/``; leave http(s) and public paths alone."""
        if url.startswith(("http", "public")):
            return url
        return "public/" + url

    @staticmethod
    def _store_media_copy(url, folder):
        """Copy the media at *url* into ``public/<folder>/`` and return the new path.

        Remote URLs are downloaded with ``requests``; anything else is read
        from the local filesystem (``requests`` cannot fetch plain paths).
        """
        ext = url.split(".")[-1]
        # str() must wrap the UUID *before* slicing: UUID objects are not subscriptable.
        name = f"public/{folder}/{str(uuid.uuid4())[:4]}.{ext}"
        if url.startswith("http"):
            content = requests.get(url).content
        else:
            with open(url, "rb") as source:
                content = source.read()
        with open(name, "wb") as f:
            f.write(content)
        return name

    def extract_medias(self, message):
        """Find media references in *message*.

        Returns:
            ``(urls, image_urls, audio_urls, video_urls)``. ``urls`` is always
            an empty list (generic URL extraction is not implemented); the
            other three are de-duplicated lists of matched strings.
        """
        urls = []
        image_urls = self._unique_matches(self._IMAGE_PATTERN, message)
        audio_urls = self._unique_matches(self._AUDIO_PATTERN, message)
        video_urls = self._unique_matches(self._VIDEO_PATTERN, message)
        return urls, image_urls, audio_urls, video_urls

    def add_text(self, messages, message):
        """Record a user message and attach copies of any media it references.

        Args:
            messages: Current chatbot history (list of ``(user, bot)`` pairs).
            message: Text entered by the user.

        Returns:
            ``(messages, textbox_value)``: the extended history and the new
            textbox content (``""`` on success, a warning when credentials
            are missing or malformed).
        """
        if not self._has_valid_credentials():
            return (
                messages,
                "Please set your OpenAI API key and Hugging Face token first!!!",
            )
        self.add_message(message, "user")
        messages = messages + [(message, None)]
        _, image_urls, audio_urls, video_urls = self.extract_medias(message)

        for image_url in image_urls:
            image = load_image(self._resolve_media_url(image_url))
            name = f"public/images/{str(uuid.uuid4())[:4]}.jpg"
            image.save(name)
            messages = messages + [((f"{name}",), None)]
        for audio_url in audio_urls:
            name = self._store_media_copy(self._resolve_media_url(audio_url), "audios")
            messages = messages + [((f"{name}",), None)]
        for video_url in video_urls:
            # Videos belong in public/videos, not public/audios.
            name = self._store_media_copy(self._resolve_media_url(video_url), "videos")
            messages = messages + [((f"{name}",), None)]
        return messages, ""

    def bot(self, messages):
        """Run ``chat_huggingface`` on the history and append its answer.

        Args:
            messages: Chatbot history whose last entry receives the reply.

        Returns:
            ``(messages, results)`` where ``results`` has string keys;
            ``(messages, {})`` when credentials are missing or malformed.
        """
        if not self._has_valid_credentials():
            return messages, {}
        message, results = chat_huggingface(
            self.all_messages, self.OPENAI_KEY, self.HUGGINGFACE_TOKEN
        )
        _, image_urls, audio_urls, video_urls = self.extract_medias(message)
        self.add_message(message, "assistant")
        messages[-1][1] = message
        # Only locally stored media are attached as extra chat entries; remote
        # (http) URLs are left as plain text inside the reply.
        for media_url in image_urls + audio_urls + video_urls:
            if not media_url.startswith("http"):
                media_url = media_url.replace("public/", "")
                messages = messages + [((None, (f"public/{media_url}",)))]
        # replace int key to string key
        results = {str(k): v for k, v in results.items()}
        return messages, results
141
+
142
+
143
# Fixed-height, scrollable JSON panel next to the chat window.
css = ".json {height: 527px; overflow: scroll;} .json-holder {height: 527px; overflow: scroll;}"
with gr.Blocks(css=css) as demo:
    # Holds the Client instance that every callback below operates on.
    state = gr.State(value={"client": Client()})

    # --- Header -----------------------------------------------------------
    gr.Markdown("<h1><center>HuggingGPT - Lite 🎐 </center></h1>")
    gr.Markdown(
        "<p align='center'><img src='https://i.ibb.co/qNH3Jym/logo.png' height='25' width='95'></p>"
    )
    gr.Markdown(
        "<p align='center' style='font-size: 20px;'>A system to connect LLMs with ML community. See our <a href='https://github.com/microsoft/JARVIS'>Project</a> and <a href='http://arxiv.org/abs/2303.17580'>Paper</a>.</p>"
    )
    gr.HTML(
        """<center><a href="https://huggingface.co/spaces/taesiri/HuggingGPT-Lite?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space and run securely with your OpenAI API Key and Hugging Face Token</center>"""
    )
    gr.Markdown(
        """>**Note**: This is a further lite version of the original HuggingGPT designed to run on CPU-only spaces. This model by default uses `gpt-3.5-turbo` which is much much cheaper than `text-davinci-003`. """
    )

    # --- Credential inputs (shown only when not provided via environment) --
    if not OPENAI_KEY:
        with gr.Row().style():
            with gr.Column(scale=0.85):
                openai_api_key = gr.Textbox(
                    show_label=False,
                    placeholder="Set your OpenAI API key here and press Enter",
                    lines=1,
                    type="password",
                ).style(container=False)
            with gr.Column(scale=0.15, min_width=0):
                btn1 = gr.Button("Submit").style(full_height=True)

    if not HUGGINGFACE_TOKEN:
        with gr.Row().style():
            with gr.Column(scale=0.85):
                hugging_face_token = gr.Textbox(
                    show_label=False,
                    placeholder="Set your Hugging Face Token here and press Enter",
                    lines=1,
                    type="password",
                ).style(container=False)
            with gr.Column(scale=0.15, min_width=0):
                btn3 = gr.Button("Submit").style(full_height=True)

    # --- Chat window and raw results --------------------------------------
    with gr.Row().style():
        with gr.Column(scale=0.6):
            chatbot = gr.Chatbot([], elem_id="chatbot").style(height=500)
        with gr.Column(scale=0.4):
            results = gr.JSON(elem_classes="json")

    # --- Message input -----------------------------------------------------
    with gr.Row().style():
        with gr.Column(scale=0.85):
            txt = gr.Textbox(
                show_label=False,
                placeholder="Enter text and press enter. The url must contain the media type. e.g, https://example.com/example.jpg",
                lines=1,
            ).style(container=False)
        with gr.Column(scale=0.15, min_width=0):
            btn2 = gr.Button("Send").style(full_height=True)

    # --- Callbacks: thin adapters that forward to the Client in `state` ----
    def set_key(state, openai_api_key):
        return state["client"].set_key(openai_api_key)

    def add_text(state, chatbot, txt):
        return state["client"].add_text(chatbot, txt)

    def set_token(state, hugging_face_token):
        return state["client"].set_token(hugging_face_token)

    def bot(state, chatbot):
        return state["client"].bot(chatbot)

    # --- Event wiring: pressing Enter and clicking the button do the same --
    if not OPENAI_KEY:
        for trigger in (openai_api_key.submit, btn1.click):
            trigger(set_key, [state, openai_api_key], [openai_api_key])

    if not HUGGINGFACE_TOKEN:
        for trigger in (hugging_face_token.submit, btn3.click):
            trigger(set_token, [state, hugging_face_token], [hugging_face_token])

    # First record the user message, then run the bot on the updated history.
    for trigger in (txt.submit, btn2.click):
        trigger(add_text, [state, chatbot, txt], [chatbot, txt]).then(
            bot, [state, chatbot], [chatbot, results]
        )

    gr.Examples(
        examples=[
            "Given a collection of image A: /examples/a.jpg, B: /examples/b.jpg, C: /examples/c.jpg, please tell me how many zebras in these picture?",
            "show me a joke and an image of cat",
            "what is in the examples/a.jpg",
        ],
        inputs=txt,
    )

demo.launch()
awesome_chat.py ADDED
@@ -0,0 +1,933 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import copy
3
+ import datetime
4
+ from io import BytesIO
5
+ import io
6
+ import os
7
+ import random
8
+ import time
9
+ import traceback
10
+ import uuid
11
+ import requests
12
+ import re
13
+ import json
14
+ import logging
15
+ import argparse
16
+ import yaml
17
+ from PIL import Image, ImageDraw
18
+ from diffusers.utils import load_image
19
+ from pydub import AudioSegment
20
+ import threading
21
+ from queue import Queue
22
+ from get_token_ids import get_token_ids_for_task_parsing, get_token_ids_for_choose_model, count_tokens, get_max_context_length
23
+ from huggingface_hub.inference_api import InferenceApi
24
+ from huggingface_hub.inference_api import ALL_TASKS
25
+ from models_server import models, status
26
+ from functools import partial
27
+ from huggingface_hub import Repository
28
+
29
# ---------------------------------------------------------------------------
# Module configuration: CLI arguments, YAML config, logging, endpoint
# selection, prompt templates and the model catalogue. Everything below runs
# at import time and defines the module-level names used by the functions in
# this file.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="config.yaml.dev")
parser.add_argument("--mode", type=str, default="cli")
if __name__ == "__main__":
    args = parser.parse_args()
else:
    # Imported by another program (e.g. the Gradio app): sys.argv belongs to
    # that program, so tolerate flags this parser does not know instead of
    # exiting, and always use the Gradio configuration.
    args, _ = parser.parse_known_args()
    args.config = "config.gradio.yaml"


def _read_text(path):
    """Return the whole content of the text file at *path*, closing it afterwards."""
    with open(path, "r") as f:
        return f.read()


with open(args.config, "r") as _config_file:
    # NOTE: FullLoader is kept for compatibility; the config is a local file
    # shipped with the app, not untrusted input.
    config = yaml.load(_config_file, Loader=yaml.FullLoader)

os.makedirs("logs", exist_ok=True)

# Timestamp used in log file names for this process.
now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

DATASET_REPO_URL = "https://huggingface.co/datasets/tricktreat/HuggingGPT_logs"
LOG_HF_TOKEN = os.environ.get("LOG_HF_TOKEN")
if LOG_HF_TOKEN:
    # Clone the log dataset into logs/ so that record_case can push to it.
    repo = Repository(
        local_dir="logs", clone_from=DATASET_REPO_URL, use_auth_token=LOG_HF_TOKEN
    )

logger = logging.getLogger(__name__)
logger.setLevel(logging.CRITICAL)

handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
if not config["debug"]:
    handler.setLevel(logging.INFO)
logger.addHandler(handler)

log_file = config["log_file"]
if log_file:
    log_file = log_file.replace("TIMESTAMP", now)
    filehandler = logging.FileHandler(log_file)
    filehandler.setLevel(logging.DEBUG)
    filehandler.setFormatter(formatter)
    logger.addHandler(filehandler)

LLM = config["model"]
use_completion = config["use_completion"]

# consistent: wrong msra model name
# gpt-3.5-turbo is mapped to the text-davinci-003 encoding for token counting.
LLM_encoding = LLM
if LLM == "gpt-3.5-turbo":
    LLM_encoding = "text-davinci-003"
task_parsing_highlight_ids = get_token_ids_for_task_parsing(LLM_encoding)
choose_model_highlight_ids = get_token_ids_for_choose_model(LLM_encoding)

# ENDPOINT MODEL NAME
# /v1/chat/completions gpt-4, gpt-4-0314, gpt-4-32k, gpt-4-32k-0314, gpt-3.5-turbo, gpt-3.5-turbo-0301
# /v1/completions text-davinci-003, text-davinci-002, text-curie-001, text-babbage-001, text-ada-001, davinci, curie, babbage, ada

api_name = "completions" if use_completion else "chat/completions"

if not config["dev"]:
    # "gradio" is accepted as a placeholder: the real key then arrives per request.
    if not config["openai"]["key"].startswith("sk-") and not config["openai"]["key"] == "gradio":
        raise ValueError("Incorrect OpenAI key. Please check your config.yaml file.")
    OPENAI_KEY = config["openai"]["key"]
    endpoint = f"https://api.openai.com/v1/{api_name}"
    if OPENAI_KEY.startswith("sk-"):
        HEADER = {
            "Authorization": f"Bearer {OPENAI_KEY}"
        }
    else:
        HEADER = None
else:
    endpoint = f"{config['local']['endpoint']}/v1/{api_name}"
    HEADER = None

PROXY = None
if config["proxy"]:
    PROXY = {
        "https": config["proxy"],
    }

inference_mode = config["inference_mode"]

# Few-shot demonstrations, read from the files named in the config.
parse_task_demos_or_presteps = _read_text(config["demos_or_presteps"]["parse_task"])
choose_model_demos_or_presteps = _read_text(config["demos_or_presteps"]["choose_model"])
response_results_demos_or_presteps = _read_text(config["demos_or_presteps"]["response_results"])

parse_task_prompt = config["prompt"]["parse_task"]
choose_model_prompt = config["prompt"]["choose_model"]
response_results_prompt = config["prompt"]["response_results"]

parse_task_tprompt = config["tprompt"]["parse_task"]
choose_model_tprompt = config["tprompt"]["choose_model"]
response_results_tprompt = config["tprompt"]["response_results"]

# Model catalogue: one JSON object per line.
with open("data/p0_models.jsonl", "r") as _models_file:
    MODELS = [json.loads(line) for line in _models_file]

# task name -> list of model records, in catalogue order.
MODELS_MAP = {}
for model in MODELS:
    MODELS_MAP.setdefault(model["task"], []).append(model)

# model id -> model record.
METADATAS = {}
for model in MODELS:
    METADATAS[model["id"]] = model
133
+
134
def convert_chat_to_completion(data):
    """Convert a chat-style request payload into a completion-style one, in place.

    The ``messages`` list is removed from *data* and flattened into a single
    ``prompt`` string using ``<im_start>role\\n...<im_end>`` markers. A leading
    system message is prepended verbatim (without markers); any other role
    that is not ``user`` or ``assistant`` is rendered as ``system``.

    Args:
        data: Request dict; may contain ``messages``, ``stop`` and ``max_tokens``.

    Returns:
        The same dict with ``prompt`` set, and ``stop`` / ``max_tokens`` filled
        in when the caller did not provide them.
    """
    messages = data.pop('messages', [])
    tprompt = ""
    # Guard against an empty/missing message list before peeking at element 0.
    if messages and messages[0]['role'] == "system":
        tprompt = messages[0]['content']
        messages = messages[1:]

    turns = []
    for message in messages:
        role = message['role'] if message['role'] in ("user", "assistant") else "system"
        turns.append("<im_start>" + role + "\n" + message['content'] + "<im_end>\n")
    final_prompt = tprompt + "".join(turns) + "<im_start>assistant"

    data["prompt"] = final_prompt
    data.setdefault('stop', ["<im_end>"])
    # Only tokenize the prompt when a budget actually has to be derived.
    if 'max_tokens' not in data:
        data['max_tokens'] = max(get_max_context_length(LLM) - count_tokens(LLM_encoding, final_prompt), 1)
    return data
154
+
155
def send_request(data):
    """POST *data* to the configured LLM endpoint and return the generated text.

    Args:
        data: Request payload. An optional ``openaikey`` entry is removed
            before sending and, when it starts with ``sk-``, used as the
            bearer token for this request only.

    Returns:
        The stripped text of the first choice, or the whole decoded JSON
        response when it contains no ``choices`` (e.g. an API error object).
    """
    openaikey = data.pop("openaikey", None)
    if use_completion:
        data = convert_chat_to_completion(data)

    # Keep a per-request key local: writing it into the module-level HEADER
    # would make later requests without a key reuse another caller's key.
    headers = HEADER
    if openaikey and openaikey.startswith("sk-"):
        headers = {
            "Authorization": f"Bearer {openaikey}"
        }

    response = requests.post(endpoint, json=data, headers=headers, proxies=PROXY)
    logger.debug(response.text.strip())
    payload = response.json()  # decode once instead of on every access
    if "choices" not in payload:
        return payload
    first_choice = payload["choices"][0]
    if use_completion:
        return first_choice["text"].strip()
    return first_choice["message"]["content"].strip()
173
+
174
def replace_slot(text, entries):
    """Fill ``{{key}}`` placeholders in *text* with the values from *entries*.

    Non-string values are converted with ``str``. Double quotes in a value
    become single quotes and newlines are removed before substitution.
    """
    for slot, raw in entries.items():
        rendered = raw if isinstance(raw, str) else str(raw)
        sanitized = rendered.replace('"', "'").replace('\n', "")
        placeholder = "{{" + slot + "}}"
        text = text.replace(placeholder, sanitized)
    return text
180
+
181
def find_json(s):
    """Return the span from the first ``{`` to the last ``}`` of *s*.

    Single quotes are converted to double quotes first and newlines are
    removed from the extracted span.
    """
    normalized = s.replace("\'", "\"")
    first_brace = normalized.find("{")
    last_brace = normalized.rfind("}")
    return normalized[first_brace:last_brace + 1].replace("\n", "")
188
+
189
def field_extract(s, field):
    """Extract the first double-quoted value that follows ``field ... :`` in *s*.

    The match is case-insensitive and tolerant of arbitrary characters (on the
    same line) between the field name, the colon and the opening quote.

    Args:
        s: Text to search, typically a JSON-like model answer.
        field: Field name; it is interpolated into the regular expression as-is.

    Returns:
        The text between the quotes.

    Raises:
        AttributeError: If the field is not found. This matches the exception
            the previous implementation produced when its search came back
            empty, so existing handlers keep working.
    """
    # The former fallback pattern (field:\ *"...") only matched strings this
    # pattern already matches, so a single search is sufficient.
    match = re.search(f'{field}.*?:.*?"(.*?)"', s, re.IGNORECASE)
    if match is None:
        raise AttributeError(f"field {field!r} not found")
    return match.group(1).replace("\"", "\'")
197
+
198
def get_id_reason(choose_str):
    """Pull the chosen model id and the reason out of a model-selection answer.

    Returns:
        ``(id, reason, choose)``: the stripped id, the stripped reason, and a
        dict holding both values as extracted (not stripped).
    """
    reason_text = field_extract(choose_str, "reason")
    model_id = field_extract(choose_str, "id")
    return (
        model_id.strip(),
        reason_text.strip(),
        {"id": model_id, "reason": reason_text},
    )
203
+
204
def record_case(success, **args):
    """Append a successful case to this run's JSONL log and push it if configured.

    Args:
        success: Nothing is recorded when this is falsy.
        **args: JSON-serialisable fields written as one line of the log.
    """
    if not success:
        return
    # `with` guarantees the file is closed even if serialisation fails.
    with open(f"logs/log_success_{now}.jsonl", "a") as f:
        f.write(json.dumps(args) + "\n")
    if LOG_HF_TOKEN:
        # Non-blocking push; the returned value is not needed.
        repo.push_to_hub(blocking=False)
213
+
214
def image_to_bytes(img_url):
    """Load the image at *img_url* and return it encoded as PNG bytes.

    The image is always re-encoded as PNG, whatever the extension of the source.
    """
    img_byte = io.BytesIO()
    load_image(img_url).save(img_byte, format="png")
    return img_byte.getvalue()
220
+
221
def resource_has_dep(command):
    """Return True if any argument of *command* contains the ``<GENERATED>`` marker."""
    return any("<GENERATED>" in value for value in command["args"].values())
227
+
228
def fix_dep(tasks):
    """Recompute each task's ``dep`` list from its ``<GENERATED>-N`` arguments.

    ``dep`` becomes the list of distinct referenced task ids in argument
    order, or ``[-1]`` when no argument references a generated resource.
    The tasks are modified in place and the same list is returned.
    """
    for task in tasks:
        arguments = task["args"]
        deps = task["dep"] = []
        for value in arguments.values():
            if "<GENERATED>" not in value:
                continue
            # The id is the number following the first "-" in the value.
            referenced_id = int(value.split("-")[1])
            if referenced_id not in deps:
                deps.append(referenced_id)
        if not deps:
            task["dep"] = [-1]
    return tasks
240
+
241
def _unfold_one(task):
    """Return single-resource copies of *task*, or None if nothing needs unfolding.

    The first argument holding several comma-separated ``<GENERATED>-N``
    items is split: one deep copy of the task is produced per item, with that
    argument set to the item and ``dep`` set to the item's task id.
    """
    for key, value in task["args"].items():
        if "<GENERATED>" not in value:
            continue
        generated_items = value.split(",")
        if len(generated_items) > 1:
            copies = []
            for item in generated_items:
                new_task = copy.deepcopy(task)
                new_task["dep"] = [int(item.split("-")[1])]
                new_task["args"][key] = item
                copies.append(new_task)
            return copies
    return None


def unfold(tasks):
    """Split tasks whose arguments reference several generated resources.

    A task with an argument such as ``"<GENERATED>-0,<GENERATED>-1"`` is
    replaced by one copy per item. Copies are examined again, so further
    multi-valued arguments are unfolded as well. Untouched tasks keep their
    relative order and come first; unfolded copies follow.

    The list is updated in place and also returned. If unfolding fails (for
    example on a malformed reference) the list is left unchanged and the
    failure is logged.
    """
    # Same logger object as the module-level `logger` (identical name).
    log = logging.getLogger(__name__)
    flag_unfold_task = False
    try:
        # Work on a separate queue: mutating `tasks` while iterating it made
        # the loop skip the element after each removed task.
        pending = list(tasks)
        unfolded = []
        position = 0
        while position < len(pending):
            task = pending[position]
            position += 1
            copies = _unfold_one(task)
            if copies is None:
                unfolded.append(task)
            else:
                flag_unfold_task = True
                pending.extend(copies)
        tasks[:] = unfolded
    except Exception:
        flag_unfold_task = False
        traceback.print_exc()
        log.debug("unfold task failed.")

    if flag_unfold_task:
        log.debug(f"unfold tasks: {tasks}")

    return tasks
266
+
267
def chitchat(messages, openaikey=None):
    """Send *messages* to the configured LLM as-is and return the result of ``send_request``."""
    return send_request({"model": LLM, "messages": messages, "openaikey": openaikey})
274
+
275
def parse_task(context, input, openaikey=None):
    """Ask the LLM to turn the user request into a task list.

    The request is built from the task-parsing system prompt, the few-shot
    demonstrations and a user prompt containing *input* plus as much of
    *context* (earlier chat messages) as fits.

    Args:
        context: Earlier chat messages, oldest first.
        input: The current user request.
        openaikey: Optional per-request OpenAI key.

    Returns:
        Whatever ``send_request`` returns for the assembled payload.
    """
    messages = json.loads(parse_task_demos_or_presteps)
    messages.insert(0, {"role": "system", "content": parse_task_tprompt})

    # cut chat logs: drop the two oldest context entries at a time until the
    # prompt leaves more than 800 tokens of headroom.
    # NOTE(review): if no window fits, the loop ends with the user prompt
    # popped, so the request is sent without it — confirm this is intended.
    for start in range(0, len(context) + 1, 2):
        prompt = replace_slot(
            parse_task_prompt,
            {"input": input, "context": context[start:]},
        )
        messages.append({"role": "user", "content": prompt})
        history_text = "<im_end>\nuser<im_start>".join(m["content"] for m in messages)
        num = count_tokens(LLM_encoding, history_text)
        if get_max_context_length(LLM) - num > 800:
            break
        messages.pop()

    logger.debug(messages)
    logit_bias = {
        item: config["logit_bias"]["parse_task"] for item in task_parsing_highlight_ids
    }
    return send_request(
        {
            "model": LLM,
            "messages": messages,
            "temperature": 0,
            "logit_bias": logit_bias,
            "openaikey": openaikey,
        }
    )
305
+
306
def choose_model(input, task, metas, openaikey = None):
    """Ask the LLM to pick a model for *task* among the candidates in *metas*.

    Both the user prompt and the few-shot demonstrations are templates filled
    with *input*, *task* and *metas*.

    Returns:
        Whatever ``send_request`` returns for the assembled payload.
    """
    slots = {"input": input, "task": task, "metas": metas}
    prompt = replace_slot(choose_model_prompt, slots)
    demos_or_presteps = replace_slot(choose_model_demos_or_presteps, slots)

    messages = [
        {"role": "system", "content": choose_model_tprompt},
        *json.loads(demos_or_presteps),
        {"role": "user", "content": prompt},
    ]
    logger.debug(messages)

    logit_bias = {
        item: config["logit_bias"]["choose_model"] for item in choose_model_highlight_ids
    }  # 5
    return send_request(
        {
            "model": LLM,
            "messages": messages,
            "temperature": 0,
            "logit_bias": logit_bias,
            "openaikey": openaikey,
        }
    )
329
+
330
+
331
def response_results(input, results, openaikey=None):
    """Ask the LLM to write the final answer from the collected task results.

    Args:
        input: The original user request.
        results: Mapping of task key to result; values are passed to the
            prompt ordered by key.
        openaikey: Optional per-request OpenAI key.

    Returns:
        Whatever ``send_request`` returns for the assembled payload.
    """
    ordered_results = [results[key] for key in sorted(results)]
    prompt = replace_slot(response_results_prompt, {"input": input})
    demos_or_presteps = replace_slot(
        response_results_demos_or_presteps,
        {"input": input, "processes": ordered_results},
    )

    messages = [
        {"role": "system", "content": response_results_tprompt},
        *json.loads(demos_or_presteps),
        {"role": "user", "content": prompt},
    ]
    logger.debug(messages)
    return send_request(
        {
            "model": LLM,
            "messages": messages,
            "temperature": 0,
            "openaikey": openaikey,
        }
    )
351
+
352
+ def huggingface_model_inference(model_id, data, task, huggingfacetoken=None):
353
+ if huggingfacetoken is None:
354
+ HUGGINGFACE_HEADERS = {}
355
+ else:
356
+ HUGGINGFACE_HEADERS = {
357
+ "Authorization": f"Bearer {huggingfacetoken}",
358
+ }
359
+ task_url = f"https://api-inference.huggingface.co/models/{model_id}" # InferenceApi does not yet support some tasks
360
+ inference = InferenceApi(repo_id=model_id, token=huggingfacetoken)
361
+
362
+ # NLP tasks
363
+ if task == "question-answering":
364
+ inputs = {"question": data["text"], "context": (data["context"] if "context" in data else "" )}
365
+ result = inference(inputs)
366
+ if task == "sentence-similarity":
367
+ inputs = {"source_sentence": data["text1"], "target_sentence": data["text2"]}
368
+ result = inference(inputs)
369
+ if task in ["text-classification", "token-classification", "text2text-generation", "summarization", "translation", "conversational", "text-generation"]:
370
+ inputs = data["text"]
371
+ result = inference(inputs)
372
+
373
+ # CV tasks
374
+ if task == "visual-question-answering" or task == "document-question-answering":
375
+ img_url = data["image"]
376
+ text = data["text"]
377
+ img_data = image_to_bytes(img_url)
378
+ img_base64 = base64.b64encode(img_data).decode("utf-8")
379
+ json_data = {}
380
+ json_data["inputs"] = {}
381
+ json_data["inputs"]["question"] = text
382
+ json_data["inputs"]["image"] = img_base64
383
+ result = requests.post(task_url, headers=HUGGINGFACE_HEADERS, json=json_data).json()
384
+ # result = inference(inputs) # not support
385
+
386
+ if task == "image-to-image":
387
+ img_url = data["image"]
388
+ img_data = image_to_bytes(img_url)
389
+ # result = inference(data=img_data) # not support
390
+ HUGGINGFACE_HEADERS["Content-Length"] = str(len(img_data))
391
+ r = requests.post(task_url, headers=HUGGINGFACE_HEADERS, data=img_data)
392
+ result = r.json()
393
+ if "path" in result:
394
+ result["generated image"] = result.pop("path")
395
+
396
+ if task == "text-to-image":
397
+ inputs = data["text"]
398
+ img = inference(inputs)
399
+ name = str(uuid.uuid4())[:4]
400
+ img.save(f"public/images/{name}.png")
401
+ result = {}
402
+ result["generated image"] = f"/images/{name}.png"
403
+
404
+ if task == "image-segmentation":
405
+ img_url = data["image"]
406
+ img_data = image_to_bytes(img_url)
407
+ image = Image.open(BytesIO(img_data))
408
+ predicted = inference(data=img_data)
409
+ colors = []
410
+ for i in range(len(predicted)):
411
+ colors.append((random.randint(100, 255), random.randint(100, 255), random.randint(100, 255), 155))
412
+ for i, pred in enumerate(predicted):
413
+ label = pred["label"]
414
+ mask = pred.pop("mask").encode("utf-8")
415
+ mask = base64.b64decode(mask)
416
+ mask = Image.open(BytesIO(mask), mode='r')
417
+ mask = mask.convert('L')
418
+
419
+ layer = Image.new('RGBA', mask.size, colors[i])
420
+ image.paste(layer, (0, 0), mask)
421
+ name = str(uuid.uuid4())[:4]
422
+ image.save(f"public/images/{name}.jpg")
423
+ result = {}
424
+ result["generated image with segmentation mask"] = f"/images/{name}.jpg"
425
+ result["predicted"] = predicted
426
+
427
+ if task == "object-detection":
428
+ img_url = data["image"]
429
+ img_data = image_to_bytes(img_url)
430
+ predicted = inference(data=img_data)
431
+ image = Image.open(BytesIO(img_data))
432
+ draw = ImageDraw.Draw(image)
433
+ labels = list(item['label'] for item in predicted)
434
+ color_map = {}
435
+ for label in labels:
436
+ if label not in color_map:
437
+ color_map[label] = (random.randint(0, 255), random.randint(0, 100), random.randint(0, 255))
438
+ for label in predicted:
439
+ box = label["box"]
440
+ draw.rectangle(((box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])), outline=color_map[label["label"]], width=2)
441
+ draw.text((box["xmin"]+5, box["ymin"]-15), label["label"], fill=color_map[label["label"]])
442
+ name = str(uuid.uuid4())[:4]
443
+ image.save(f"public/images/{name}.jpg")
444
+ result = {}
445
+ result["generated image with predicted box"] = f"/images/{name}.jpg"
446
+ result["predicted"] = predicted
447
+
448
+ if task in ["image-classification"]:
449
+ img_url = data["image"]
450
+ img_data = image_to_bytes(img_url)
451
+ result = inference(data=img_data)
452
+
453
+ if task == "image-to-text":
454
+ img_url = data["image"]
455
+ img_data = image_to_bytes(img_url)
456
+ HUGGINGFACE_HEADERS["Content-Length"] = str(len(img_data))
457
+ r = requests.post(task_url, headers=HUGGINGFACE_HEADERS, data=img_data)
458
+ result = {}
459
+ if "generated_text" in r.json()[0]:
460
+ result["generated text"] = r.json()[0].pop("generated_text")
461
+
462
+ # AUDIO tasks
463
+ if task == "text-to-speech":
464
+ inputs = data["text"]
465
+ response = inference(inputs, raw_response=True)
466
+ # response = requests.post(task_url, headers=HUGGINGFACE_HEADERS, json={"inputs": text})
467
+ name = str(uuid.uuid4())[:4]
468
+ with open(f"public/audios/{name}.flac", "wb") as f:
469
+ f.write(response.content)
470
+ result = {"generated audio": f"/audios/{name}.flac"}
471
+ if task in ["automatic-speech-recognition", "audio-to-audio", "audio-classification"]:
472
+ audio_url = data["audio"]
473
+ audio_data = requests.get(audio_url, timeout=10).content
474
+ response = inference(data=audio_data, raw_response=True)
475
+ result = response.json()
476
+ if task == "audio-to-audio":
477
+ content = None
478
+ type = None
479
+ for k, v in result[0].items():
480
+ if k == "blob":
481
+ content = base64.b64decode(v.encode("utf-8"))
482
+ if k == "content-type":
483
+ type = "audio/flac".split("/")[-1]
484
+ audio = AudioSegment.from_file(BytesIO(content))
485
+ name = str(uuid.uuid4())[:4]
486
+ audio.export(f"public/audios/{name}.{type}", format=type)
487
+ result = {"generated audio": f"/audios/{name}.{type}"}
488
+ return result
489
+
490
def _rename_path(results, key):
    """Move the "path" entry of *results* (when present) to the user-facing *key*."""
    if "path" in results:
        results[key] = results.pop("path")
    return results


def local_model_inference(model_id, data, task):
    """Run *task* with the locally deployed model *model_id*.

    Args:
        model_id: id of the model; ControlNet models are recognised by their
            id prefix/suffix and handled before the task dispatch.
        data: task arguments, read by key ("image", "text", "audio").
        task: task name used to pick the request payload.

    Returns:
        Whatever the local ``models`` callable returns, with its "path"
        entry (if any) renamed to a descriptive "generated ..." key. For
        object detection the result holds the path of the annotated image and
        the raw predictions. Unsupported tasks yield an
        ``{"error": {"message": ...}}`` dict instead of ``None`` so callers
        can keep using ``"error" in result``.
    """
    inference = partial(models, model_id)

    # ControlNet models are selected by model id rather than by task name.
    if model_id.startswith("lllyasviel/sd-controlnet-"):
        results = inference({"img_url": data["image"], "text": data["text"]})
        return _rename_path(results, "generated image")
    if model_id.endswith("-control"):
        results = inference({"img_url": data["image"]})
        return _rename_path(results, "generated image")

    if task == "text-to-video":
        return _rename_path(inference(data), "generated video")

    # NLP tasks
    if task == "question-answering" or task == "sentence-similarity":
        return inference(json=data)
    if task in ["text-classification", "token-classification", "text2text-generation", "summarization", "translation", "conversational", "text-generation"]:
        return inference(json=data)

    # CV tasks
    if task == "depth-estimation":
        results = inference({"img_url": data["image"]})
        return _rename_path(results, "generated depth image")
    if task == "image-segmentation":
        results = inference({"img_url": data["image"]})
        # Guarded like every other branch: an error result carries no "path".
        return _rename_path(results, "generated image with segmentation mask")
    if task == "image-to-image":
        results = inference({"img_url": data["image"]})
        return _rename_path(results, "generated image")
    if task == "text-to-image":
        return _rename_path(inference(data), "generated image")
    if task == "object-detection":
        img_url = data["image"]
        predicted = inference({"img_url": img_url})
        if "error" in predicted:
            return predicted
        image = load_image(img_url)
        draw = ImageDraw.Draw(image)
        # One random colour per distinct label so boxes of a class match.
        color_map = {}
        for item in predicted:
            if item["label"] not in color_map:
                color_map[item["label"]] = (random.randint(0, 255), random.randint(0, 100), random.randint(0, 255))
        for item in predicted:
            box = item["box"]
            draw.rectangle(((box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])), outline=color_map[item["label"]], width=2)
            draw.text((box["xmin"]+5, box["ymin"]-15), item["label"], fill=color_map[item["label"]])
        name = str(uuid.uuid4())[:4]
        image.save(f"public/images/{name}.jpg")
        results = {}
        results["generated image with predicted box"] = f"/images/{name}.jpg"
        results["predicted"] = predicted
        return results
    if task in ["image-classification", "image-to-text", "document-question-answering", "visual-question-answering"]:
        # "text" is optional here; it is forwarded as None when absent.
        return inference({"img_url": data["image"], "text": data.get("text")})

    # AUDIO tasks
    if task == "text-to-speech":
        return _rename_path(inference(data), "generated audio")
    if task in ["automatic-speech-recognition", "audio-to-audio", "audio-classification"]:
        return inference({"audio_url": data["audio"]})

    # No branch matched: report it instead of silently returning None.
    return {"error": {"message": f"task {task} is not supported by local inference."}}
583
+
584
+
585
def model_inference(model_id, data, hosted_on, task, huggingfacetoken=None):
    """Run *task* on *model_id* at the endpoint named by *hosted_on*.

    Args:
        model_id: id of the model to run.
        data: task arguments forwarded to the endpoint-specific helper.
        hosted_on: "local", "huggingface", or "unknown". With "unknown" the
            local status is probed first, then the Hugging Face status URL.
        task: task name forwarded to the endpoint-specific helper.
        huggingfacetoken: optional token sent as a Bearer authorization header.

    Returns:
        The helper's inference result, or ``{"error": {"message": ...}}`` when
        inference raises or no endpoint could be determined.
    """
    if huggingfacetoken:
        HUGGINGFACE_HEADERS = {
            "Authorization": f"Bearer {huggingfacetoken}",
        }
    else:
        HUGGINGFACE_HEADERS = None
    if hosted_on == "unknown":
        r = status(model_id)
        logger.debug("Local Server Status: " + str(r))
        if "loaded" in r and r["loaded"]:
            hosted_on = "local"
        else:
            huggingfaceStatusUrl = f"https://api-inference.huggingface.co/status/{model_id}"
            # Decode the body: testing `"loaded" in <Response>` never matches.
            r = requests.get(huggingfaceStatusUrl, headers=HUGGINGFACE_HEADERS, proxies=PROXY).json()
            logger.debug("Huggingface Status: " + str(r))
            if "loaded" in r and r["loaded"]:
                hosted_on = "huggingface"
    try:
        if hosted_on == "local":
            inference_result = local_model_inference(model_id, data, task)
        elif hosted_on == "huggingface":
            inference_result = huggingface_model_inference(model_id, data, task, huggingfacetoken)
        else:
            # Previously this path left inference_result unbound and the
            # return below raised UnboundLocalError.
            inference_result = {"error": {"message": f"model {model_id} is not available on any inference endpoint (hosted_on: {hosted_on})."}}
    except Exception as e:
        print(e)
        traceback.print_exc()
        inference_result = {"error":{"message": str(e)}}
    return inference_result
613
+
614
+
615
def get_model_status(model_id, url, headers, queue = None):
    """Report whether *model_id* is loaded at the endpoint implied by *url*.

    A *url* containing "huggingface" is fetched over HTTP and its JSON body is
    inspected; any other value makes the call fall back to the module's
    ``status(model_id)`` helper.

    Args:
        model_id: id of the model to probe.
        url: status URL; only consulted for the Hugging Face endpoint.
        headers: HTTP headers for the Hugging Face request.
        queue: optional queue that receives one
            ``(model_id, loaded, endpoint_type_or_None)`` tuple per call.

    Returns:
        True when the status reply has a truthy "loaded" entry, else False.
    """
    endpoint_type = "huggingface" if "huggingface" in url else "local"
    try:
        if "huggingface" in url:
            # Decode the body: `"loaded" in <Response>` would never match.
            r = requests.get(url, headers=headers, proxies=PROXY).json()
        else:
            r = status(model_id)
        loaded = bool("loaded" in r and r["loaded"])
    except Exception as e:
        # This runs as a thread target whose consumer blocks on the queue:
        # a failed probe must still post a result, so count it as not loaded.
        logger.debug(f"status check failed for {model_id} ({endpoint_type}): {e}")
        loaded = False

    if queue is not None:
        queue.put((model_id, loaded, endpoint_type if loaded else None))
    return loaded
629
+
630
def get_avaliable_models(candidates, topk=10, huggingfacetoken = None):
    """Probe *candidates* concurrently and collect the ids of loaded models.

    One thread per (candidate, endpoint) pair calls ``get_model_status``;
    which endpoints are probed depends on the module-level ``inference_mode``
    and ``config["local_deployment"]``.

    Args:
        candidates: iterable of dicts, each with an "id" entry.
        topk: stop collecting once this many models have been found.
        huggingfacetoken: optional token sent as a Bearer authorization header.

    Returns:
        ``{"local": [...], "huggingface": [...]}`` with the ids found loaded.
    """
    all_available_models = {"local": [], "huggingface": []}
    threads = []
    result_queue = Queue()
    # Send no Authorization header at all rather than "Bearer None".
    if huggingfacetoken:
        HUGGINGFACE_HEADERS = {
            "Authorization": f"Bearer {huggingfacetoken}",
        }
    else:
        HUGGINGFACE_HEADERS = None

    for candidate in candidates:
        model_id = candidate["id"]

        if inference_mode != "local":
            huggingfaceStatusUrl = f"https://api-inference.huggingface.co/status/{model_id}"
            thread = threading.Thread(target=get_model_status, args=(model_id, huggingfaceStatusUrl, HUGGINGFACE_HEADERS, result_queue))
            threads.append(thread)
            thread.start()

        if inference_mode != "huggingface" and config["local_deployment"] != "minimal":
            thread = threading.Thread(target=get_model_status, args=(model_id, "", {}, result_queue))
            threads.append(thread)
            thread.start()

    result_count = len(threads)
    while result_count:
        # `is_loaded` rather than `status`, which would shadow the helper.
        model_id, is_loaded, endpoint_type = result_queue.get()
        # Dedupe against the endpoint's own list; the previous test compared
        # the id with the dict keys ("local"/"huggingface") and never fired.
        if is_loaded and model_id not in all_available_models[endpoint_type]:
            all_available_models[endpoint_type].append(model_id)
        if len(all_available_models["local"]) + len(all_available_models["huggingface"]) >= topk:
            break
        result_count -= 1

    for thread in threads:
        thread.join()

    return all_available_models
664
+
665
def collect_result(command, choose, inference_result):
    """Bundle a task, the model choice and the inference output into one record.

    Args:
        command: the task description, stored under "task".
        choose: the model-selection info, stored under "choose model result".
        inference_result: the inference output, stored under "inference result".

    Returns:
        A new dict holding the three values; the inference result is also
        written to the debug log.
    """
    logger.debug(f"inference result: {inference_result}")
    return {
        "task": command,
        "inference result": inference_result,
        "choose model result": choose,
    }
671
+
672
+
673
def run_task(input, command, results, openaikey = None, huggingfacetoken = None):
    """Execute one planned task and store its outcome in ``results``.

    Steps: resolve ``<GENERATED>`` placeholders in the task arguments from the
    results of its dependencies, prefix local image/audio paths with
    "public/", pick a model (ControlNet, ChatGPT, or one chosen among the
    available candidates), run inference and record the outcome.

    Args:
        input: the user's request text; used in prompts and in recorded cases.
        command: task dict with "id", "args", "task" and "dep" entries; its
            "args" (and in some cases "task") entry is rewritten in place.
        results: mapping of task id to collected result, read for dependencies
            and written with this task's result under ``command["id"]``.
        openaikey: key forwarded to ``chitchat`` / ``choose_model``.
        huggingfacetoken: token forwarded to the model lookup and inference.

    Returns:
        True when a result without an "error" entry was stored, False otherwise.
    """
    id = command["id"]
    args = command["args"]
    task = command["task"]
    deps = command["dep"]
    # NOTE(review): deps[0] raises IndexError on an empty "dep" list —
    # presumably fix_dep() guarantees at least [-1]; confirm.
    if deps[0] != -1:
        dep_tasks = [results[dep] for dep in deps]
    else:
        dep_tasks = []

    logger.debug(f"Run task: {id} - {task}")
    logger.debug("Deps: " + json.dumps(dep_tasks))

    # Explicit placeholders of the form "<GENERATED>-<task id>": take the
    # resource generated by exactly that task, when it produced one.
    if deps[0] != -1:
        if "image" in args and "<GENERATED>-" in args["image"]:
            resource_id = int(args["image"].split("-")[1])
            if "generated image" in results[resource_id]["inference result"]:
                args["image"] = results[resource_id]["inference result"]["generated image"]
        if "audio" in args and "<GENERATED>-" in args["audio"]:
            resource_id = int(args["audio"].split("-")[1])
            if "generated audio" in results[resource_id]["inference result"]:
                args["audio"] = results[resource_id]["inference result"]["generated audio"]
        if "text" in args and "<GENERATED>-" in args["text"]:
            resource_id = int(args["text"].split("-")[1])
            if "generated text" in results[resource_id]["inference result"]:
                args["text"] = results[resource_id]["inference result"]["generated text"]

    # Fallback: scan every dependency for a resource of each type, preferring
    # a generated one over the dependency's own input argument. Later
    # dependencies overwrite earlier ones.
    text = image = audio = None
    for dep_task in dep_tasks:
        if "generated text" in dep_task["inference result"]:
            text = dep_task["inference result"]["generated text"]
            logger.debug("Detect the generated text of dependency task (from results):" + text)
        elif "text" in dep_task["task"]["args"]:
            text = dep_task["task"]["args"]["text"]
            logger.debug("Detect the text of dependency task (from args): " + text)
        if "generated image" in dep_task["inference result"]:
            image = dep_task["inference result"]["generated image"]
            logger.debug("Detect the generated image of dependency task (from results): " + image)
        elif "image" in dep_task["task"]["args"]:
            image = dep_task["task"]["args"]["image"]
            logger.debug("Detect the image of dependency task (from args): " + image)
        if "generated audio" in dep_task["inference result"]:
            audio = dep_task["inference result"]["generated audio"]
            logger.debug("Detect the generated audio of dependency task (from results): " + audio)
        elif "audio" in dep_task["task"]["args"]:
            audio = dep_task["task"]["args"]["audio"]
            logger.debug("Detect the audio of dependency task (from args): " + audio)

    # Any placeholder still unresolved is replaced by the fallback resource.
    if "image" in args and "<GENERATED>" in args["image"]:
        if image:
            args["image"] = image
    if "audio" in args and "<GENERATED>" in args["audio"]:
        if audio:
            args["audio"] = audio
    if "text" in args and "<GENERATED>" in args["text"]:
        if text:
            args["text"] = text

    # Non-empty paths that are neither URLs nor already under "public/" get
    # the "public/" prefix.
    for resource in ["image", "audio"]:
        if resource in args and not args[resource].startswith("public/") and len(args[resource]) > 0 and not args[resource].startswith("http"):
            args[resource] = f"public/{args[resource]}"

    # A "<control>-text-to-image" task without text degrades to the matching
    # control/segmentation/depth task.
    if "-text-to-image" in command['task'] and "text" not in args:
        logger.debug("control-text-to-image task, but text is empty, so we use control-generation instead.")
        control = task.split("-")[0]

        if control == "seg":
            task = "image-segmentation"
            command['task'] = task
        elif control == "depth":
            task = "depth-estimation"
            command['task'] = task
        else:
            # NOTE(review): unlike the two branches above, command['task'] is
            # not updated here, so model_inference() below still receives the
            # original "-text-to-image" name — confirm this is intended.
            task = f"{control}-control"

    command["args"] = args
    logger.debug(f"parsed task: {command}")

    if task.endswith("-text-to-image") or task.endswith("-control"):
        # ControlNet tasks are only served by the local deployment.
        if inference_mode != "huggingface":
            if task.endswith("-text-to-image"):
                control = task.split("-")[0]
                best_model_id = f"lllyasviel/sd-controlnet-{control}"
            else:
                best_model_id = task
            hosted_on = "local"
            reason = "ControlNet is the best model for this task."
            choose = {"id": best_model_id, "reason": reason}
            logger.debug(f"chosen model: {choose}")
        else:
            logger.warning(f"Task {command['task']} is not available. ControlNet need to be deployed locally.")
            record_case(success=False, **{"input": input, "task": command, "reason": f"Task {command['task']} is not available. ControlNet need to be deployed locally.", "op":"message"})
            inference_result = {"error": f"service related to ControlNet is not available."}
            results[id] = collect_result(command, "", inference_result)
            return False
    elif task in ["summarization", "translation", "conversational", "text-generation", "text2text-generation"]: # text tasks answered by ChatGPT directly
        best_model_id = "ChatGPT"
        reason = "ChatGPT performs well on some NLP tasks as well."
        choose = {"id": best_model_id, "reason": reason}
        messages = [{
            "role": "user",
            "content": f"[ {input} ] contains a task in JSON format {command}, 'task' indicates the task type and 'args' indicates the arguments required for the task. Don't explain the task to me, just help me do it and give me the result. The result must be in text form without any urls."
        }]
        response = chitchat(messages, openaikey)
        results[id] = collect_result(command, choose, {"response": response})
        return True
    else:
        if task not in MODELS_MAP:
            logger.warning(f"no available models on {task} task.")
            record_case(success=False, **{"input": input, "task": command, "reason": f"task not support: {command['task']}", "op":"message"})
            inference_result = {"error": f"{command['task']} not found in available tasks."}
            results[id] = collect_result(command, "", inference_result)
            return False

        # Only the first 20 candidates for the task are probed.
        candidates = MODELS_MAP[task][:20]
        all_avaliable_models = get_avaliable_models(candidates, config["num_candidate_models"], huggingfacetoken)
        all_avaliable_model_ids = all_avaliable_models["local"] + all_avaliable_models["huggingface"]
        logger.debug(f"avaliable models on {command['task']}: {all_avaliable_models}")

        if len(all_avaliable_model_ids) == 0:
            logger.warning(f"no available models on {command['task']}")
            record_case(success=False, **{"input": input, "task": command, "reason": f"no available models: {command['task']}", "op":"message"})
            inference_result = {"error": f"no available models on {command['task']} task."}
            results[id] = collect_result(command, "", inference_result)
            return False

        if len(all_avaliable_model_ids) == 1:
            best_model_id = all_avaliable_model_ids[0]
            hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface"
            reason = "Only one model available."
            choose = {"id": best_model_id, "reason": reason}
            logger.debug(f"chosen model: {choose}")
        else:
            # NOTE(review): "inference endpoint" receives the whole list of
            # model ids available at that endpoint (the .get() result), not
            # the endpoint name — verify this is what the prompt expects.
            cand_models_info = [
                {
                    "id": model["id"],
                    "inference endpoint": all_avaliable_models.get(
                        "local" if model["id"] in all_avaliable_models["local"] else "huggingface"
                    ),
                    "likes": model.get("likes"),
                    "description": model.get("description", "")[:config["max_description_length"]],
                    "language": model.get("language"),
                    "tags": model.get("tags"),
                }
                for model in candidates
                if model["id"] in all_avaliable_model_ids
            ]

            choose_str = choose_model(input, command, cand_models_info, openaikey)
            logger.debug(f"chosen model: {choose_str}")
            try:
                choose = json.loads(choose_str)
                reason = choose["reason"]
                best_model_id = choose["id"]
                hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface"
            except Exception as e:
                # The reply was not strict JSON (or lacked a key): fall back
                # to extracting the id and reason from the raw text.
                logger.warning(f"the response [ {choose_str} ] is not a valid JSON, try to find the model id and reason in the response.")
                choose_str = find_json(choose_str)
                best_model_id, reason, choose = get_id_reason(choose_str)
                hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface"
    inference_result = model_inference(best_model_id, args, hosted_on, command['task'], huggingfacetoken)

    # The result is stored in both cases; only the return value differs.
    if "error" in inference_result:
        logger.warning(f"Inference error: {inference_result['error']}")
        record_case(success=False, **{"input": input, "task": command, "reason": f"inference error: {inference_result['error']}", "op":"message"})
        results[id] = collect_result(command, choose, inference_result)
        return False

    results[id] = collect_result(command, choose, inference_result)
    return True
843
+
844
def chat_huggingface(messages, openaikey = None, huggingfacetoken = None, return_planning = False, return_results = False):
    """Answer the last message by planning tasks, running them and summarising.

    Args:
        messages: chat history; the last entry's "content" is the request and
            the preceding entries are passed on as context.
        openaikey: key forwarded to the LLM helpers.
        huggingfacetoken: token forwarded to ``run_task``.
        return_planning: when true, return the planned task list and stop.
        return_results: when true, return the raw per-task results dict.

    Returns:
        ``(response, results)`` normally. ``(str(task_str), {})`` when the
        planner output contains "error", and ``(response, {})`` when the
        request is answered by plain chat. With ``return_planning`` the task
        list alone, with ``return_results`` the results dict alone.
    """
    start = time.time()
    context = messages[:-1]
    input = messages[-1]["content"]
    logger.info("*"*80)
    logger.info(f"input: {input}")

    task_str = parse_task(context, input, openaikey)
    logger.info(task_str)

    if "error" in task_str:
        return str(task_str), {}
    else:
        task_str = task_str.strip()

    try:
        tasks = json.loads(task_str)
    except Exception as e:
        # The plan is not valid JSON: answer with plain chat instead.
        logger.debug(e)
        response = chitchat(messages, openaikey)
        record_case(success=False, **{"input": input, "task": task_str, "reason": "task parsing fail", "op":"chitchat"})
        return response, {}

    if task_str == "[]": # using LLM response for empty task
        record_case(success=False, **{"input": input, "task": [], "reason": "task parsing fail: empty", "op": "chitchat"})
        response = chitchat(messages, openaikey)
        return response, {}

    # A single pure-text task is answered by the chat model directly.
    if len(tasks)==1 and tasks[0]["task"] in ["summarization", "translation", "conversational", "text-generation", "text2text-generation"]:
        # NOTE(review): the recorded reason says "task parsing fail: empty"
        # although success=True — looks copied from the branch above; confirm.
        record_case(success=True, **{"input": input, "task": tasks, "reason": "task parsing fail: empty", "op": "chitchat"})
        response = chitchat(messages, openaikey)
        best_model_id = "ChatGPT"
        reason = "ChatGPT performs well on some NLP tasks as well."
        choose = {"id": best_model_id, "reason": reason}
        return response, collect_result(tasks[0], choose, {"response": response})

    tasks = unfold(tasks)
    tasks = fix_dep(tasks)
    logger.debug(tasks)

    if return_planning:
        return tasks

    threads = []
    tasks = tasks[:]
    d = dict()  # task id -> collected result, filled in by the run_task threads
    retry = 0
    while True:
        num_threads = len(threads)
        # Iterate over a snapshot: removing from `tasks` while iterating it
        # directly skipped the element following each removed task.
        for task in list(tasks):
            dep = task["dep"]
            # A dependency on itself or on a later task can never be met;
            # treat such a task as independent.
            for dep_id in dep:
                if dep_id >= task["id"]:
                    task["dep"] = [-1]
                    dep = [-1]
                    break
            # Start the task once all of its dependencies have a result.
            if len(list(set(dep).intersection(d.keys()))) == len(dep) or dep[0] == -1:
                tasks.remove(task)
                thread = threading.Thread(target=run_task, args=(input, task, d, openaikey, huggingfacetoken))
                thread.start()
                threads.append(thread)
        if num_threads == len(threads):
            # Nothing new could start: wait for running tasks to finish.
            time.sleep(0.5)
            retry += 1
        if retry > 80:
            logger.debug("User has waited too long, Loop break.")
            break
        if len(tasks) == 0:
            break
    for thread in threads:
        thread.join()

    results = d.copy()

    logger.debug(results)
    if return_results:
        return results

    response = response_results(input, results, openaikey).strip()

    end = time.time()
    during = end - start

    record_case(success=True, **{"input": input, "task": task_str, "results": results, "response": response, "during": during, "op":"response"})
    logger.info(f"response: {response}")
    return response, results
config.gradio.yaml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ openai:
2
+ key: gradio # "gradio" (set when request) or your_personal_key
3
+ huggingface:
4
+ token: # required: huggingface token @ https://huggingface.co/settings/tokens
5
+ dev: false
6
+ debug: true
7
+ log_file: logs/debug_TIMESTAMP.log
8
+ model: gpt-3.5-turbo
9
+ use_completion: false
10
+ inference_mode: hybrid # local, huggingface or hybrid
11
+ local_deployment: standard # minimal, standard or full
12
+ num_candidate_models: 5
13
+ max_description_length: 100
14
+ proxy:
15
+ logit_bias:
16
+ parse_task: 0.5
17
+ choose_model: 5
18
+ tprompt:
19
+ parse_task: >-
20
+ #1 Task Planning Stage: The AI assistant can parse user input to several tasks: [{"task": task, "id": task_id, "dep": dependency_task_id, "args": {"text": text or <GENERATED>-dep_id, "image": image_url or <GENERATED>-dep_id, "audio": audio_url or <GENERATED>-dep_id}}]. The special tag "<GENERATED>-dep_id" refer to the one genereted text/image/audio in the dependency task (Please consider whether the dependency task generates resources of this type.) and "dep_id" must be in "dep" list. The "dep" field denotes the ids of the previous prerequisite tasks which generate a new resource that the current task relies on. The "args" field must in ["text", "image", "audio"], nothing else. The task MUST be selected from the following options: "token-classification", "text2text-generation", "summarization", "translation", "question-answering", "conversational", "text-generation", "sentence-similarity", "tabular-classification", "object-detection", "image-classification", "image-to-image", "image-to-text", "text-to-image", "text-to-video", "visual-question-answering", "document-question-answering", "image-segmentation", "depth-estimation", "text-to-speech", "automatic-speech-recognition", "audio-to-audio", "audio-classification", "canny-control", "hed-control", "mlsd-control", "normal-control", "openpose-control", "canny-text-to-image", "depth-text-to-image", "hed-text-to-image", "mlsd-text-to-image", "normal-text-to-image", "openpose-text-to-image", "seg-text-to-image". There may be multiple tasks of the same type. Think step by step about all the tasks needed to resolve the user's request. Parse out as few tasks as possible while ensuring that the user request can be resolved. Pay attention to the dependencies and order among tasks. If the user input can't be parsed, you need to reply empty JSON [].
21
+ choose_model: >-
22
+ #2 Model Selection Stage: Given the user request and the parsed tasks, the AI assistant helps the user to select a suitable model from a list of models to process the user request. The assistant should focus more on the description of the model and find the model that has the most potential to solve requests and tasks. Also, prefer models with local inference endpoints for speed and stability.
23
+ response_results: >-
24
+ #4 Response Generation Stage: With the task execution logs, the AI assistant needs to describe the process and inference results.
25
+ demos_or_presteps:
26
+ parse_task: demos/demo_parse_task.json
27
+ choose_model: demos/demo_choose_model.json
28
+ response_results: demos/demo_response_results.json
29
+ prompt:
30
+ parse_task: The chat log [ {{context}} ] may contain the resources I mentioned. Now I input { {{input}} }. Pay attention to the input and output types of tasks and the dependencies between tasks.
31
+ choose_model: >-
32
+ Please choose the most suitable model from {{metas}} for the task {{task}}. The output must be in a strict JSON format: {"id": "id", "reason": "your detail reasons for the choice"}.
33
+ response_results: >-
34
+ Yes. Please first think carefully and directly answer my request based on the inference results. Some of the inferences may not always turn out to be correct and require you to make careful consideration in making decisions. Then please detail your workflow including the used models and inference results for my request in your friendly tone. Please filter out information that is not relevant to my request. Tell me the complete path or urls of files in inference results. If there is nothing in the results, please tell me you can't make it. }
data/p0_models.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
demos/demo_choose_model.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "role": "user",
4
+ "content": "{{input}}"
5
+ },
6
+ {
7
+ "role": "assistant",
8
+ "content": "{{task}}"
9
+ }
10
+ ]
demos/demo_parse_task.json ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "role": "user",
4
+ "content": "Give you some pictures e1.jpg, e2.png, e3.jpg, help me count the number of sheep?"
5
+ },
6
+ {
7
+ "role": "assistant",
8
+ "content": "[{\"task\": \"image-to-text\", \"id\": 0, \"dep\": [-1], \"args\": {\"image\": \"e1.jpg\" }}, {\"task\": \"object-detection\", \"id\": 1, \"dep\": [-1], \"args\": {\"image\": \"e1.jpg\" }}, {\"task\": \"visual-question-answering\", \"id\": 2, \"dep\": [1], \"args\": {\"image\": \"<GENERATED>-1\", \"text\": \"How many sheep in the picture\" }}, {\"task\": \"image-to-text\", \"id\": 3, \"dep\": [-1], \"args\": {\"image\": \"e2.png\" }}, {\"task\": \"object-detection\", \"id\": 4, \"dep\": [-1], \"args\": {\"image\": \"e2.png\" }}, {\"task\": \"visual-question-answering\", \"id\": 5, \"dep\": [4], \"args\": {\"image\": \"<GENERATED>-4\", \"text\": \"How many sheep in the picture\" }}, {\"task\": \"image-to-text\", \"id\": 6, \"dep\": [-1], \"args\": {\"image\": \"e3.jpg\" }}, {\"task\": \"object-detection\", \"id\": 7, \"dep\": [-1], \"args\": {\"image\": \"e3.jpg\" }}, {\"task\": \"visual-question-answering\", \"id\": 8, \"dep\": [7], \"args\": {\"image\": \"<GENERATED>-7\", \"text\": \"How many sheep in the picture\"}}]"
9
+ },
10
+
11
+ {
12
+ "role":"user",
13
+ "content":"Look at /e.jpg, can you tell me how many objects in the picture? Give me a picture and video similar to this one."
14
+ },
15
+ {
16
+ "role":"assistant",
17
+ "content":"[{\"task\": \"image-to-text\", \"id\": 0, \"dep\": [-1], \"args\": {\"image\": \"/e.jpg\" }}, {\"task\": \"object-detection\", \"id\": 1, \"dep\": [-1], \"args\": {\"image\": \"/e.jpg\" }}, {\"task\": \"visual-question-answering\", \"id\": 2, \"dep\": [1], \"args\": {\"image\": \"<GENERATED>-1\", \"text\": \"how many objects in the picture?\" }}, {\"task\": \"text-to-image\", \"id\": 3, \"dep\": [0], \"args\": {\"text\": \"<GENERATED>-0\" }}, {\"task\": \"image-to-image\", \"id\": 4, \"dep\": [-1], \"args\": {\"image\": \"/e.jpg\" }}, {\"task\": \"text-to-video\", \"id\": 5, \"dep\": [0], \"args\": {\"text\": \"<GENERATED>-0\" }}]"
18
+ },
19
+
20
+ {
21
+ "role":"user",
22
+ "content":"given a document /images/e.jpeg, answer me what is the student amount? And describe the image with your voice"
23
+ },
24
+ {
25
+ "role":"assistant",
26
+ "content":"[{\"task\": \"document-question-answering\", \"id\": 0, \"dep\": [-1], \"args\": {\"image\": \"/images/e.jpeg\", \"text\": \"what is the student amount?\" }}, {\"task\": \"visual-question-answering\", \"id\": 1, \"dep\": [-1], \"args\": {\"image\": \"/images/e.jpeg\", \"text\": \"what is the student amount?\" }}, {\"task\": \"image-to-text\", \"id\": 2, \"dep\": [-1], \"args\": {\"image\": \"/images/e.jpeg\" }}, {\"task\": \"text-to-speech\", \"id\": 3, \"dep\": [2], \"args\": {\"text\": \"<GENERATED>-2\" }}]"
27
+ },
28
+
29
+ {
30
+ "role": "user",
31
+ "content": "Given an image /example.jpg, first generate a hed image, then based on the hed image generate a new image where a girl is reading a book"
32
+ },
33
+ {
34
+ "role": "assistant",
35
+ "content": "[{\"task\": \"hed-control\", \"id\": 0, \"dep\": [-1], \"args\": {\"image\": \"/example.jpg\" }}, {\"task\": \"hed-text-to-image\", \"id\": 1, \"dep\": [0], \"args\": {\"text\": \"a girl is reading a book\", \"image\": \"<GENERATED>-0\" }}]"
36
+ },
37
+
38
+ {
39
+ "role": "user",
40
+ "content": "please show me a video and an image of (based on the text) 'a boy is running' and dub it"
41
+ },
42
+ {
43
+ "role": "assistant",
44
+ "content": "[{\"task\": \"text-to-video\", \"id\": 0, \"dep\": [-1], \"args\": {\"text\": \"a boy is running\" }}, {\"task\": \"text-to-speech\", \"id\": 1, \"dep\": [-1], \"args\": {\"text\": \"a boy is running\" }}, {\"task\": \"text-to-image\", \"id\": 2, \"dep\": [-1], \"args\": {\"text\": \"a boy is running\" }}]"
45
+ },
46
+
47
+
48
+ {
49
+ "role": "user",
50
+ "content": "please show me a joke and an image of cat"
51
+ },
52
+ {
53
+ "role": "assistant",
54
+ "content": "[{\"task\": \"conversational\", \"id\": 0, \"dep\": [-1], \"args\": {\"text\": \"please show me a joke of cat\" }}, {\"task\": \"text-to-image\", \"id\": 1, \"dep\": [-1], \"args\": {\"text\": \"a photo of cat\" }}]"
55
+ },
56
+
57
+ {
58
+ "role": "user",
59
+ "content": "give me a picture about a cut dog, then describe the image to me and tell a story about it"
60
+ },
61
+ {
62
+ "role": "assistant",
63
+ "content": "[{\"task\": \"text-to-image\", \"id\": 0, \"dep\": [-1], \"args\": {\"text\": \"a picture of a cut dog\" }}, {\"task\": \"image-to-text\", \"id\": 1, \"dep\": [0], \"args\": {\"image\": \"<GENERATED>-0\" }}, {\"task\": \"text-generation\", \"id\": 2, \"dep\": [1], \"args\": {\"text\": \"<GENERATED>-1\" }}, {\"task\": \"text-to-speech\", \"id\": 3, \"dep\": [2], \"args\": {\"text\": \"<GENERATED>-2\" }}]"
64
+ },
65
+
66
+ {
67
+ "role": "user",
68
+ "content": "give you a picture /example.jpg, what's in it and tell me a joke about it"
69
+ },
70
+ {
71
+ "role": "assistant",
72
+ "content": "[{\"task\": \"image-to-text\", \"id\": 0, \"dep\": [-1], \"args\": {\"image\": \"/example.jpg\" }}, {\"task\": \"object-detection\", \"id\": 1, \"dep\": [-1], \"args\": {\"image\": \"/example.jpg\" }}, {\"task\": \"conversational\", \"id\": 2, \"dep\": [0], \"args\": {\"text\": \"<GENERATED>-0\" }}, {\"task\": \"text-to-speech\", \"id\": 3, \"dep\": [2], \"args\": {\"text\": \"<GENERATED>-2\" }}]"
73
+ }
74
+ ]
demos/demo_response_results.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "role": "user",
4
+ "content": "{{input}}"
5
+ },
6
+ {
7
+ "role": "assistant",
8
+ "content": "Before giving you a response, I want to introduce my workflow for your request, which is shown in the following JSON data: {{processes}}. Do you have any demands regarding my response?"
9
+ }
10
+ ]
get_token_ids.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tiktoken
2
+
3
# Model names grouped by the tiktoken encoding each one is mapped to below.
_ENCODING_FAMILIES = (
    ("cl100k_base", ("gpt-3.5-turbo", "gpt-3.5-turbo-0301")),
    ("p50k_base", ("text-davinci-003", "text-davinci-002")),
    (
        "r50k_base",
        (
            "text-davinci-001",
            "text-curie-001",
            "text-babbage-001",
            "text-ada-001",
            "davinci",
            "curie",
            "babbage",
            "ada",
        ),
    ),
)

# Lookup table: model name -> tiktoken encoding object used to tokenize for it.
encodings = {
    model: tiktoken.get_encoding(family)
    for family, family_models in _ENCODING_FAMILIES
    for model in family_models
}
17
+
18
# Lookup table: model name -> maximum context length (in tokens) assumed for it.
max_length = {
    **dict.fromkeys(
        (
            "gpt-3.5-turbo",
            "gpt-3.5-turbo-0301",
            "text-davinci-003",
            "text-davinci-002",
        ),
        4096,
    ),
    **dict.fromkeys(
        (
            "text-davinci-001",
            "text-curie-001",
            "text-babbage-001",
            "text-ada-001",
            "davinci",
            "curie",
            "babbage",
            "ada",
        ),
        2049,
    ),
}
32
+
33
+
34
def count_tokens(model_name, text):
    """Return how many tokens ``text`` becomes under ``model_name``'s encoding.

    Raises:
        KeyError: if ``model_name`` has no entry in ``encodings``.
    """
    tokenizer = encodings[model_name]
    token_ids = tokenizer.encode(text)
    return len(token_ids)
36
+
37
+
38
def get_max_context_length(model_name):
    """Look up the context-length limit recorded for ``model_name``.

    Raises:
        KeyError: if ``model_name`` has no entry in ``max_length``.
    """
    limit = max_length[model_name]
    return limit
40
+
41
+
42
def get_token_ids_for_task_parsing(model_name):
    """Return the distinct token ids that make up the task-parsing vocabulary.

    The vocabulary text below lists the task names and JSON keys written in a
    task plan; it is encoded with ``model_name``'s tokenizer and de-duplicated.

    Raises:
        KeyError: if ``model_name`` has no entry in ``encodings``.
    """
    vocabulary = """{"task": "text-classification", "token-classification", "text2text-generation", "summarization", "translation", "question-answering", "conversational", "text-generation", "sentence-similarity", "tabular-classification", "object-detection", "image-classification", "image-to-image", "image-to-text", "text-to-image", "visual-question-answering", "document-question-answering", "image-segmentation", "text-to-speech", "text-to-video", "automatic-speech-recognition", "audio-to-audio", "audio-classification", "canny-control", "hed-control", "mlsd-control", "normal-control", "openpose-control", "canny-text-to-image", "depth-text-to-image", "hed-text-to-image", "mlsd-text-to-image", "normal-text-to-image", "openpose-text-to-image", "seg-text-to-image", "args", "text", "path", "dep", "id", "<GENERATED>-"}"""
    distinct_ids = set(encodings[model_name].encode(vocabulary))
    return list(distinct_ids)
47
+
48
+
49
def get_token_ids_for_choose_model(model_name):
    """Return the distinct token ids that make up the model-choice vocabulary.

    The vocabulary is the JSON skeleton ``{"id": "reason"}``, encoded with
    ``model_name``'s tokenizer and de-duplicated.

    Raises:
        KeyError: if ``model_name`` has no entry in ``encodings``.
    """
    vocabulary = """{"id": "reason"}"""
    distinct_ids = set(encodings[model_name].encode(vocabulary))
    return list(distinct_ids)
models_server.py ADDED
@@ -0,0 +1,779 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import random
4
+ import uuid
5
+ import numpy as np
6
+ from transformers import pipeline
7
+ from diffusers import (
8
+ DiffusionPipeline,
9
+ StableDiffusionControlNetPipeline,
10
+ ControlNetModel,
11
+ UniPCMultistepScheduler,
12
+ )
13
+ from diffusers.utils import load_image
14
+ from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
15
+ from diffusers.utils import export_to_video
16
+ from transformers import BlipProcessor, BlipForConditionalGeneration
17
+ from transformers import (
18
+ TrOCRProcessor,
19
+ VisionEncoderDecoderModel,
20
+ ViTImageProcessor,
21
+ AutoTokenizer,
22
+ )
23
+ from datasets import load_dataset
24
+ from PIL import Image
25
+ import io
26
+ from torchvision import transforms
27
+ import torch
28
+ import torchaudio
29
+ from speechbrain.pretrained import WaveformEnhancement
30
+ import joblib
31
+ from huggingface_hub import hf_hub_url, cached_download
32
+ from transformers import AutoImageProcessor, TimesformerForVideoClassification
33
+ from transformers import (
34
+ MaskFormerFeatureExtractor,
35
+ MaskFormerForInstanceSegmentation,
36
+ AutoFeatureExtractor,
37
+ )
38
+ from controlnet_aux import (
39
+ OpenposeDetector,
40
+ MLSDdetector,
41
+ HEDdetector,
42
+ CannyDetector,
43
+ MidasDetector,
44
+ )
45
+ from controlnet_aux.open_pose.body import Body
46
+ from controlnet_aux.mlsd.models.mbv2_mlsd_large import MobileV2_MLSD_Large
47
+ from controlnet_aux.hed import Network
48
+ from transformers import DPTForDepthEstimation, DPTFeatureExtractor
49
+ import warnings
50
+ import time
51
+ from espnet2.bin.tts_inference import Text2Speech
52
+ import soundfile as sf
53
+ from asteroid.models import BaseModel
54
+ import traceback
55
+ import os
56
+ import yaml
57
+
58
# Suppress all Python warnings raised while the model libraries load and run.
warnings.filterwarnings("ignore")

parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="config.yaml")

if __name__ == "__main__":
    args = parser.parse_args()
else:
    # When this module is imported, the command line belongs to the importing
    # process. Ignore flags this parser does not know (instead of exiting) and
    # always use the Gradio configuration file.
    args, _unknown_args = parser.parse_known_args()
    args.config = "config.gradio.yaml"

# Module logger writing INFO and above to the console.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)

# NOTE: FullLoader can build more than plain data types; the configuration
# file is treated as trusted input here.
with open(args.config, "r") as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

# Which tier of models to host in this process; nothing is hosted locally when
# inference is delegated to Hugging Face.
local_deployment = config["local_deployment"]
if config["inference_mode"] == "huggingface":
    local_deployment = "none"

# Optional HTTPS proxy mapping built from the configuration.
PROXY = None
if config["proxy"]:
    PROXY = {
        "https": config["proxy"],
    }

# Reference time for the "[ ready ]" message printed after the models load.
start = time.time()

# Prefix prepended to model repository ids; set to "models/" to read the
# weights from a local directory instead.
local_models = ""
91
+
92
+
93
def load_pipes(local_deployment):
    """Instantiate the models hosted in-process for a deployment tier.

    Args:
        local_deployment: tier name. ``"full"`` loads everything below,
            ``"standard"`` skips the MaskFormer entry, ``"minimal"`` only runs
            the ControlNet loading step; any other value loads nothing.

    Returns:
        dict mapping a model id to an entry dict holding ``"model"``,
        ``"device"`` and, for MaskFormer, ``"feature_extractor"``.

    Only the models listed here are loaded. ``models`` also contains recipes
    for other ids (for example BLIP captioning, text-to-video, TrOCR,
    SpeechT5, Whisper, Stable Diffusion text-to-image, ViLT VQA, LayoutLM
    document QA, espnet TTS and DCCRNet enhancement) that this function does
    not register.
    """

    def repo(name):
        # Resolve a repository id against the optional local model directory.
        return f"{local_models}{name}"

    other_pipes = {}
    standard_pipes = {}

    if local_deployment == "full":
        maskformer_id = "facebook/maskformer-swin-base-coco"
        maskformer_extractor = MaskFormerFeatureExtractor.from_pretrained(
            repo(maskformer_id)
        )
        maskformer_model = MaskFormerForInstanceSegmentation.from_pretrained(
            repo(maskformer_id)
        )
        other_pipes[maskformer_id] = {
            "feature_extractor": maskformer_extractor,
            "model": maskformer_model,
            "device": "cpu",
        }

    if local_deployment in ("full", "standard"):
        # (model id, transformers pipeline task), loaded in this order.
        pipeline_specs = (
            ("Intel/dpt-large", "depth-estimation"),
            ("facebook/detr-resnet-50-panoptic", "image-segmentation"),
            ("facebook/detr-resnet-101", "object-detection"),
            ("ydshieh/vit-gpt2-coco-en", "image-to-text"),
        )
        for spec_id, spec_task in pipeline_specs:
            standard_pipes[spec_id] = {
                "model": pipeline(task=spec_task, model=repo(spec_id)),
                "device": "cpu",
            }

    if local_deployment in ("full", "standard", "minimal"):
        # NOTE(review): these three objects are loaded but never placed in the
        # returned mapping, so no request can reach them -- confirm whether
        # they should be registered or the loading removed.
        canny_controlnet = ControlNetModel.from_pretrained(
            repo("lllyasviel/sd-controlnet-canny"), torch_dtype=torch.float16
        )
        controlnet_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
            repo("runwayml/stable-diffusion-v1-5"),
            controlnet=canny_controlnet,
            torch_dtype=torch.float16,
        )
        hed_detector = HEDdetector.from_pretrained("lllyasviel/ControlNet")

    # Standard entries come first, then the "full"-only entries.
    return {**standard_pipes, **other_pipes}
340
+
341
+
342
# Load the configured models once, at import time; `pipes` is the registry the
# request handlers below read from.
pipes = load_pipes(local_deployment)

# Report how long start-up took, measured from the module-level `start`.
end = time.time()
during = end - start
print("[ ready ] {}s".format(during))
348
+
349
+
350
def running():
    """Liveness probe: always reports that the server is running."""
    return dict(running=True)
352
+
353
+
354
def status(model_id):
    """Report whether ``model_id`` is loaded locally and usable.

    A model counts as loaded when it has an entry in ``pipes`` and is not one
    of the explicitly disabled ids below.

    Returns:
        ``{"loaded": True}`` or ``{"loaded": False}``.
    """
    disabled_models = (
        "microsoft/trocr-base-printed",
        "microsoft/trocr-base-handwritten",
    )

    is_available = model_id in pipes and model_id not in disabled_models
    if not is_available:
        print(f"[ check {model_id} ] failed")
        return {"loaded": False}

    print(f"[ check {model_id} ] success")
    return {"loaded": True}
365
+
366
+
367
# Candidate labels passed to the object-detection pipelines by `_run_model`.
_DETECTION_LABELS = [
    "cat", "couch", "person", "car", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella",
    "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
    "sports ball", "kite", "baseball bat", "baseball glove", "skateboard",
    "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork",
    "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
    "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair",
    "couch", "potted plant", "bed", "dining table", "toilet", "tv",
    "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave",
    "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase",
    "scissors", "teddy bear", "hair drier", "toothbrush", "traffic light",
    "fire hydrant", "stop sign", "parking meter", "bench", "bird",
]


def _short_name():
    """Return a 4-character random name used for generated output files."""
    return str(uuid.uuid4())[:4]


def _move_pipe(pipe, device):
    """Move ``pipe`` to ``device``.

    Tries ``pipe.to(device)`` first; if that fails, falls back to setting
    ``pipe.device`` and moving ``pipe.model`` instead.
    """
    try:
        pipe.to(device)
    except Exception:
        pipe.device = torch.device(device)
        pipe.model.to(device)


def _run_model(model_id, entry, pipe, data):
    """Run the inference recipe that matches ``model_id``.

    Args:
        model_id: id selecting the recipe.
        entry: the ``pipes[model_id]`` dict (model plus helper objects).
        pipe: ``entry["model"]``.
        data: request dict; see ``models`` for the keys read.

    Returns:
        The recipe's result, or ``None`` when no recipe matches ``model_id``.
    """
    result = None

    # text to video
    if model_id == "damo-vilab/text-to-video-ms-1.7b":
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            pipe.scheduler.config
        )
        prompt = data["text"]
        video_frames = pipe(prompt, num_inference_steps=50, num_frames=40).frames
        file_name = _short_name()
        video_path = export_to_video(video_frames, f"public/videos/{file_name}.mp4")

        # Re-encode with libx264; if ffmpeg produced no file, serve the
        # directly exported video instead.
        new_file_name = _short_name()
        os.system(
            f"ffmpeg -i {video_path} -vcodec libx264 public/videos/{new_file_name}.mp4"
        )
        if os.path.exists(f"public/videos/{new_file_name}.mp4"):
            result = {"path": f"/videos/{new_file_name}.mp4"}
        else:
            result = {"path": f"/videos/{file_name}.mp4"}

    # controlnet: swap in this model's control network, then generate
    if model_id.startswith("lllyasviel/sd-controlnet-"):
        pipe.controlnet.to("cpu")
        pipe.controlnet = entry["control"].to(entry["device"])
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        control_image = load_image(data["img_url"])
        out_image = pipe(
            data["text"], num_inference_steps=20, image=control_image
        ).images[0]
        file_name = _short_name()
        out_image.save(f"public/images/{file_name}.png")
        result = {"path": f"/images/{file_name}.png"}

    # control-image extraction (detectors)
    if model_id.endswith("-control"):
        image = load_image(data["img_url"])
        if "scribble" in model_id:
            control = pipe(image, scribble=True)
        elif "canny" in model_id:
            control = pipe(image, low_threshold=100, high_threshold=200)
        else:
            control = pipe(image)
        file_name = _short_name()
        control.save(f"public/images/{file_name}.png")
        result = {"path": f"/images/{file_name}.png"}

    # image to image
    if model_id == "lambdalabs/sd-image-variations-diffusers":
        im = load_image(data["img_url"])
        file_name = _short_name()
        # Keep a copy of the input image next to the generated variation.
        im.save(f"public/images/{file_name}.png")
        tform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize(
                    (224, 224),
                    interpolation=transforms.InterpolationMode.BICUBIC,
                    antialias=False,
                ),
                transforms.Normalize(
                    [0.48145466, 0.4578275, 0.40821073],
                    [0.26862954, 0.26130258, 0.27577711],
                ),
            ]
        )
        inp = tform(im).to(entry["device"]).unsqueeze(0)
        out = pipe(inp, guidance_scale=3)
        out["images"][0].save(f"public/images/{file_name}.jpg")
        result = {"path": f"/images/{file_name}.jpg"}

    # image to text
    if model_id == "Salesforce/blip-image-captioning-large":
        raw_image = load_image(data["img_url"]).convert("RGB")
        inputs = entry["processor"](raw_image, return_tensors="pt").to(
            entry["device"]
        )
        out = pipe.generate(**inputs)
        caption = entry["processor"].decode(out[0], skip_special_tokens=True)
        result = {"generated text": caption}
    if model_id == "ydshieh/vit-gpt2-coco-en":
        generated_text = pipe(data["img_url"])[0]["generated_text"]
        result = {"generated text": generated_text}
    if model_id == "nlpconnect/vit-gpt2-image-captioning":
        image = load_image(data["img_url"]).convert("RGB")
        pixel_values = entry["feature_extractor"](
            images=image, return_tensors="pt"
        ).pixel_values
        pixel_values = pixel_values.to(entry["device"])
        generated_ids = pipe.generate(pixel_values, max_length=200, num_beams=1)
        generated_text = entry["tokenizer"].batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]
        result = {"generated text": generated_text}

    # image to text: OCR
    if model_id in (
        "microsoft/trocr-base-printed",
        "microsoft/trocr-base-handwritten",
    ):
        image = load_image(data["img_url"]).convert("RGB")
        pixel_values = entry["processor"](image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(entry["device"])
        generated_ids = pipe.generate(pixel_values)
        generated_text = entry["processor"].batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]
        result = {"generated text": generated_text}

    # text to image
    if model_id == "runwayml/stable-diffusion-v1-5":
        file_name = _short_name()
        out = pipe(prompt=data["text"])
        out["images"][0].save(f"public/images/{file_name}.jpg")
        result = {"path": f"/images/{file_name}.jpg"}

    # object detection: the raw pipeline output is returned
    if model_id in ("google/owlvit-base-patch32", "facebook/detr-resnet-101"):
        result = pipe(data["img_url"], candidate_labels=_DETECTION_LABELS)

    # VQA
    if model_id == "dandelin/vilt-b32-finetuned-vqa":
        result = pipe(question=data["text"], image=data["img_url"])

    # DQA
    if model_id == "impira/layoutlm-document-qa":
        question = data["text"]
        img_url = data["img_url"]
        result = pipe(img_url, question)

    # depth-estimation (pipeline variant)
    if model_id == "Intel/dpt-large":
        output = pipe(data["img_url"])
        image = output["depth"]
        name = _short_name()
        image.save(f"public/images/{name}.jpg")
        result = {"path": f"/images/{name}.jpg"}

    # depth-estimation (raw model + feature extractor variant)
    if model_id == "Intel/dpt-hybrid-midas":
        image = load_image(data["img_url"])
        inputs = entry["feature_extractor"](images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = pipe(**inputs)
            predicted_depth = outputs.predicted_depth
        # Resize the prediction back to the input image's (height, width).
        prediction = torch.nn.functional.interpolate(
            predicted_depth.unsqueeze(1),
            size=image.size[::-1],
            mode="bicubic",
            align_corners=False,
        )
        output = prediction.squeeze().cpu().numpy()
        formatted = (output * 255 / np.max(output)).astype("uint8")
        image = Image.fromarray(formatted)
        name = _short_name()
        image.save(f"public/images/{name}.jpg")
        result = {"path": f"/images/{name}.jpg"}

    # TTS
    if model_id == "espnet/kan-bayashi_ljspeech_vits":
        wav = pipe(data["text"])["wav"]
        name = _short_name()
        sf.write(f"public/audios/{name}.wav", wav.cpu().numpy(), pipe.fs, "PCM_16")
        result = {"path": f"/audios/{name}.wav"}

    if model_id == "microsoft/speecht5_tts":
        inputs = entry["processor"](text=data["text"], return_tensors="pt")
        embeddings_dataset = entry["embeddings_dataset"]
        speaker_embeddings = (
            torch.tensor(embeddings_dataset[7306]["xvector"])
            .unsqueeze(0)
            .to(entry["device"])
        )
        entry["vocoder"].to(entry["device"])
        speech = pipe.generate_speech(
            inputs["input_ids"].to(entry["device"]),
            speaker_embeddings,
            vocoder=entry["vocoder"],
        )
        name = _short_name()
        sf.write(f"public/audios/{name}.wav", speech.cpu().numpy(), samplerate=16000)
        result = {"path": f"/audios/{name}.wav"}

    # ASR
    if model_id in ("openai/whisper-base", "microsoft/speecht5_asr"):
        result = {"text": pipe(data["audio_url"])["text"]}

    # audio to audio
    if model_id == "JorisCos/DCCRNet_Libri1Mix_enhsingle_16k":
        wav, sr = torchaudio.load(data["audio_url"])
        with torch.no_grad():
            result_wav = pipe(wav.to(entry["device"]))
        name = _short_name()
        sf.write(f"public/audios/{name}.wav", result_wav.cpu().squeeze().numpy(), sr)
        result = {"path": f"/audios/{name}.wav"}

    if model_id == "microsoft/speecht5_vc":
        wav, sr = torchaudio.load(data["audio_url"])
        inputs = entry["processor"](audio=wav, sampling_rate=sr, return_tensors="pt")
        embeddings_dataset = entry["embeddings_dataset"]
        speaker_embeddings = (
            torch.tensor(embeddings_dataset[7306]["xvector"])
            .unsqueeze(0)
            .to(entry["device"])
        )
        entry["vocoder"].to(entry["device"])
        # For audio input the SpeechT5 processor returns "input_values".
        speech = pipe.generate_speech(
            inputs["input_values"].to(entry["device"]),
            speaker_embeddings,
            vocoder=entry["vocoder"],
        )
        name = _short_name()
        sf.write(f"public/audios/{name}.wav", speech.cpu().numpy(), samplerate=16000)
        result = {"path": f"/audios/{name}.wav"}

    # segmentation: paint every mask onto the image in its own random colour
    if model_id == "facebook/detr-resnet-50-panoptic":
        segments = pipe(data["img_url"])
        image = load_image(data["img_url"])
        for segment in segments:
            color = (
                random.randint(100, 255),
                random.randint(100, 255),
                random.randint(100, 255),
                50,
            )
            mask = segment["mask"].convert("L")
            layer = Image.new("RGBA", mask.size, color)
            image.paste(layer, (0, 0), mask)
        name = _short_name()
        image.save(f"public/images/{name}.jpg")
        result = {"path": f"/images/{name}.jpg"}

    if model_id in (
        "facebook/maskformer-swin-base-coco",
        "facebook/maskformer-swin-large-ade",
    ):
        image = load_image(data["img_url"])
        inputs = entry["feature_extractor"](images=image, return_tensors="pt").to(
            entry["device"]
        )
        outputs = pipe(**inputs)
        panoptic = entry["feature_extractor"].post_process_panoptic_segmentation(
            outputs, target_sizes=[image.size[::-1]]
        )[0]
        predicted_panoptic_map = Image.fromarray(
            panoptic["segmentation"].cpu().numpy().astype(np.uint8)
        )
        name = _short_name()
        predicted_panoptic_map.save(f"public/images/{name}.jpg")
        result = {"path": f"/images/{name}.jpg"}

    return result


def models(model_id, data):
    """Run one inference request against a locally loaded model.

    Args:
        model_id: key into the module-level ``pipes`` registry.
        data: request dict. Depending on the model the recipe reads
            ``"text"``, ``"img_url"`` and/or ``"audio_url"``.

    Returns:
        ``{"path": ...}`` for generated images, audio and video,
        ``{"generated text": ...}`` for captioning and OCR, ``{"text": ...}``
        for speech recognition, and the raw pipeline output for object
        detection, visual QA and document QA. Failures are reported as
        ``{"error": {"message": ...}}``; an id that is not registered or has
        no recipe yields the "model not found" error.
    """
    if model_id not in pipes:
        # Unregistered ids get the same error as ids without a recipe.
        return {"error": {"message": "model not found"}}

    entry = pipes[model_id]

    # Cooperative "in use" flag: poll until the previous request on this
    # model has finished, then claim it.
    while "using" in entry and entry["using"]:
        print(f"[ inference {model_id} ] waiting")
        time.sleep(0.1)
    entry["using"] = True
    print(f"[ inference {model_id} ] start")

    start = time.time()
    pipe = entry["model"]
    result = None

    try:
        if "device" in entry:
            _move_pipe(pipe, entry["device"])

        try:
            result = _run_model(model_id, entry, pipe, data)
        except Exception as e:
            print(e)
            traceback.print_exc()
            result = {"error": {"message": "Error when running the model inference."}}

        # Park the model back on the CPU between requests.
        if "device" in entry:
            _move_pipe(pipe, "cpu")
    finally:
        # Always release the flag so a failure cannot block later requests.
        entry["using"] = False

    if result is None:
        result = {"error": {"message": "model not found"}}

    end = time.time()
    during = end - start
    print(f"[ complete {model_id} ] {during}s")
    print(f"[ result {model_id} ] {result}")

    return result
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ tesseract-ocr
public/examples/a.jpg ADDED
public/examples/b.jpg ADDED
public/examples/c.jpg ADDED
public/examples/d.jpg ADDED
public/examples/e.jpg ADDED
public/examples/f.jpg ADDED
public/examples/g.jpg ADDED
public/images/001d.jpg ADDED
public/images/12d7.jpg ADDED
public/images/1df7.jpg ADDED
public/images/267d.jpg ADDED
public/images/2bc2.jpg ADDED
public/images/2cc3.jpg ADDED
public/images/372f.jpg ADDED
public/images/a5e9.jpg ADDED
public/images/bc40.jpg ADDED
public/images/bda5.jpg ADDED
public/images/bfc4.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ git+https://github.com/huggingface/diffusers.git@8c530fc2f6a76a2aefb6b285dce6df1675092ac6#egg=diffusers
2
+ git+https://github.com/huggingface/transformers@c612628045822f909020f7eb6784c79700813eda#egg=transformers
3
+ git+https://github.com/patrickvonplaten/controlnet_aux@78efc716868a7f5669c288233d65b471f542ce40#egg=controlnet_aux
4
+ tiktoken==0.3.3
5
+ pydub==0.25.1
6
+ espnet==202301
7
+ espnet_model_zoo==0.1.7
8
+ flask==2.2.3
9
+ flask_cors==3.0.10
10
+ waitress==2.1.2
11
+ datasets==2.11.0
12
+ asteroid==0.6.0
13
+ speechbrain==0.5.14
14
+ timm==0.6.13
15
+ typeguard==2.13.3
16
+ accelerate==0.18.0
17
+ pytesseract==0.3.10
18
+ basicsr==1.4.2