Upload 9 files
- .gitignore +171 -0
- Dockerfile +19 -0
- LICENSE +21 -0
- README.md +153 -13
- __init__.py +4 -0
- inference.py +704 -0
- inference_vision.py +259 -0
- requirements.txt +20 -4
- server.py +96 -0
.gitignore
ADDED
@@ -0,0 +1,171 @@
+# Byte-compiled / optimized / DLL files
+*pyc
+*pth
+checkpoint*/
+output/
+.DS_Store
+*.wav
+*.webp
+*.jpg
+
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
Dockerfile
ADDED
@@ -0,0 +1,19 @@
+# Use Python 3.10 as the base image
+FROM python:3.10
+
+# Set the working directory
+WORKDIR /app
+
+# Copy the local files into the container
+COPY . .
+
+# Install Linux packages
+RUN apt-get update && \
+    apt-get install -y ffmpeg
+
+# Install Python dependencies
+RUN pip install --no-cache-dir -r requirements.txt && \
+    pip install PyAudio==0.2.14
+
+# Commands to run the server and the Streamlit app
+CMD ["bash", "-c", "python3 server.py --ip '0.0.0.0' --port 60808 & API_URL=http://0.0.0.0:60808/chat streamlit run webui/omni_streamlit.py --server.port=8501"]
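The CMD above starts `server.py` in the background and immediately launches the Streamlit UI against `API_URL=http://0.0.0.0:60808/chat`. Since the server first downloads the checkpoint and warms up the model, a small readiness check can be useful before sending requests; the sketch below is a hypothetical stdlib-only helper, not part of this upload.

```python
# Hypothetical readiness check (not part of this repo): wait until the chat
# server accepts TCP connections on port 60808 before issuing requests.
import socket
import time

def wait_for_server(host="127.0.0.1", port=60808, timeout=600.0):
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            with socket.create_connection((host, port), timeout=2.0):
                return True
        except OSError:
            time.sleep(2.0)  # checkpoint download / model warm-up may still be running
    return False

if __name__ == "__main__":
    print("server ready" if wait_for_server() else "server did not come up in time")
```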
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 gpt-omni
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md
CHANGED
@@ -1,13 +1,153 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+# Mini-Omni2
+
+<p align="center">
+    <img src="./data/figures/title_new.png" width="90%"/>
+</p>
+
+
+<p align="center">
+🤗 <a href="https://huggingface.co/gpt-omni/mini-omni2">Hugging Face</a> | 📖 <a href="https://github.com/gpt-omni/mini-omni2">Github</a>
+| 📑 <a href="https://arxiv.org/abs/2410.11190">Technical report</a>
+</p>
+
+<h5 align="center">If you like little Omni2, please give us a star ⭐ and cite our <a href="https://arxiv.org/abs/2410.11190">paper</a>!</h5>
+
+## Introduction
+Mini-Omni2 is an **omni-interactive** model. It can **understand image, audio and text inputs and has end-to-end voice conversations with users**. It features **real-time voice output**, **omni-capable multimodal understanding**, and flexible interaction with an **interruption mechanism while speaking**.
+
+<p align="center">
+    <img src="./data/figures/framework.jpeg" width="100%"/>
+</p>
+
+
+## Updates
+
+- **2024.10:** Release the model, technical report, inference and chat demo code.
+
+## Features
+✅ **Multimodal interaction**: with the ability to understand images, speech and text, just like GPT-4o.
+
+✅ **Real-time speech-to-speech** conversational capabilities. No extra ASR or TTS models required, just like [Mini-Omni](https://github.com/gpt-omni/mini-omni).
+
+<!-- ✅ **Streaming audio output**: with first-chunk latency of audio stream less than 0.3s. -->
+
+<!-- ✅ **Duplex interaction**: hearing while speaking, it can be interrupted by key words like "stop omni". -->
+
+
+## Demo
+
+NOTE: you need to unmute the video first.
+
+https://github.com/user-attachments/assets/ad97ca7f-f8b4-40c3-a7e8-fa54b4edf155
+
+
+## ToDo
+- [ ] update interruption mechanism
+
+
+## Install
+
+Create a new conda environment and install the required packages:
+
+```sh
+conda create -n omni python=3.10
+conda activate omni
+
+git clone https://github.com/gpt-omni/mini-omni2.git
+cd mini-omni2
+pip install -r requirements.txt
+```
+
+## Quick start
+
+**Interactive demo**
+
+- start server
+
+NOTE: you need to start the server before running the streamlit or gradio demo, with API_URL set to the server address.
+
+```sh
+sudo apt-get install ffmpeg
+conda activate omni
+cd mini-omni2
+python3 server.py --ip '0.0.0.0' --port 60808
+```
+
+
+- run streamlit demo
+
+NOTE: you need to run streamlit **locally** with PyAudio installed.
+
+```sh
+pip install PyAudio==0.2.14
+API_URL=http://0.0.0.0:60808/chat streamlit run webui/omni_streamlit.py
+```
+
+
+**Local test**
+
+```sh
+conda activate omni
+cd mini-omni2
+# test run the preset audio samples and questions
+python inference_vision.py
+```
+
+## Mini-Omni2 Overview
+
+**1. Multimodal Modeling**:
+We use multiple sequences as the input and output of the model. On the input side, we concatenate image, audio and text features to perform a series of comprehensive tasks, as shown in the following figures. On the output side, we use text-guided delayed parallel output to generate real-time speech responses.
+<p align="center">
+    <img src="./data/figures/inputids.png" width="100%"/>
+</p>
+
+**2. Multi-stage Training**:
+We propose an efficient alignment training method: the three training stages perform encoder adaptation, modal alignment, and multimodal fine-tuning, respectively.
+<p align="center">
+    <img src="./data/figures/training.jpeg" width="100%"/>
+</p>
+
+<!-- **3. Cases**:
+Here are more cases of Mini-Omni2:
+<p align="center">
+    <img src="./data/figures/samples.png" width="100%"/>
+</p> -->
+
+## FAQ
+
+**1. Does the model support other languages?**
+
+No, the model is only trained on English. However, since we use Whisper as the audio encoder, the model can understand other languages supported by Whisper (such as Chinese), but the output is only in English.
+
+**2. Error: cannot run streamlit in a local browser with a remote streamlit server**
+
+You need to start streamlit **locally** with PyAudio installed.
+
+
+## Acknowledgements
+
+- [Qwen2](https://github.com/QwenLM/Qwen2/) as the LLM backbone.
+- [litGPT](https://github.com/Lightning-AI/litgpt/) for training and inference.
+- [whisper](https://github.com/openai/whisper/) for audio encoding.
+- [clip](https://github.com/openai/CLIP) for image encoding.
+- [snac](https://github.com/hubertsiuzdak/snac/) for audio decoding.
+- [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) for generating synthetic speech.
+- [OpenOrca](https://huggingface.co/datasets/Open-Orca/OpenOrca) and [MOSS](https://github.com/OpenMOSS/MOSS/tree/main) for alignment.
+
+## Citation
+
+```bibtex
+@article{xie2024miniomni2opensourcegpt4ovision,
+  title={Mini-Omni2: Towards Open-source GPT-4o with Vision, Speech and Duplex Capabilities},
+  author={Zhifei Xie and Changqiao Wu},
+  year={2024},
+  eprint={2410.11190},
+  archivePrefix={arXiv},
+  primaryClass={eess.AS},
+  journal={ArXiv},
+  volume={abs/2410.11190},
+}
+```
+## Star History
+[![Star History Chart](https://api.star-history.com/svg?repos=gpt-omni/mini-omni2&type=Date)](https://star-history.com/#gpt-omni/mini-omni2&Date)
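For reference, the `/chat` endpoint exposed by `server.py` (added later in this upload) accepts a JSON body with a base64-encoded `audio` field and streams back a multipart response. Below is a minimal stdlib-only client sketch, assuming the server from the Quick start is running on port 60808 and using the sample wav shipped with the repo; the output file name `reply_raw.bin` is just an illustrative choice.

```python
import base64
import json
import urllib.request

API_URL = "http://0.0.0.0:60808/chat"  # server started as in the Quick start

with open("./data/samples/output1.wav", "rb") as f:
    payload = {"audio": base64.b64encode(f.read()).decode("utf-8")}

req = urllib.request.Request(
    API_URL,
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)

# The reply is a multipart/x-mixed-replace stream of audio and text frames;
# here we simply save the raw bytes (see server.py for the frame layout).
with urllib.request.urlopen(req) as resp, open("reply_raw.bin", "wb") as out:
    while True:
        chunk = resp.read(4096)
        if not chunk:
            break
        out.write(chunk)
```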
__init__.py
ADDED
@@ -0,0 +1,4 @@
+import sys
+import os
+
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
inference.py
ADDED
@@ -0,0 +1,704 @@
+import os
+import lightning as L
+import torch
+import glob
+import time
+from snac import SNAC
+from litgpt import Tokenizer
+from litgpt.utils import (
+    num_parameters,
+)
+from litgpt.generate.base import (
+    generate_AA,
+    generate_ASR,
+    generate_TA,
+    generate_TT,
+    generate_AT,
+    generate_TA_BATCH,
+    next_token_image_batch
+)
+import soundfile as sf
+from litgpt.model import GPT, Config
+from lightning.fabric.utilities.load import _lazy_load as lazy_load
+from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
+from utils.snac_utils import get_snac, generate_audio_data
+import whisper
+from tqdm import tqdm
+from huggingface_hub import snapshot_download
+
+
+torch.set_printoptions(sci_mode=False)
+
+
+# TODO
+text_vocabsize = 151936
+text_specialtokens = 64
+audio_vocabsize = 4096
+audio_specialtokens = 64
+
+padded_text_vocabsize = text_vocabsize + text_specialtokens
+padded_audio_vocabsize = audio_vocabsize + audio_specialtokens
+
+_eot = text_vocabsize
+_pad_t = text_vocabsize + 1
+_input_t = text_vocabsize + 2
+_answer_t = text_vocabsize + 3
+_asr = text_vocabsize + 4
+
+_eoa = audio_vocabsize
+_pad_a = audio_vocabsize + 1
+_input_a = audio_vocabsize + 2
+_answer_a = audio_vocabsize + 3
+_split = audio_vocabsize + 4
+_image = audio_vocabsize + 5
+_eoimage = audio_vocabsize + 6
+
+
+def get_input_ids_TA(text, text_tokenizer):
+    input_ids_item = [[] for _ in range(8)]
+    text_tokens = text_tokenizer.encode(text)
+    for i in range(7):
+        input_ids_item[i] = [layershift(_pad_a, i)] * (len(text_tokens) + 2) + [
+            layershift(_answer_a, i)
+        ]
+        input_ids_item[i] = torch.tensor(input_ids_item[i]).unsqueeze(0)
+    input_ids_item[-1] = [_input_t] + text_tokens.tolist() + [_eot] + [_answer_t]
+    input_ids_item[-1] = torch.tensor(input_ids_item[-1]).unsqueeze(0)
+    return input_ids_item
+
+
+def get_input_ids_TT(text, text_tokenizer):
+    input_ids_item = [[] for i in range(8)]
+    text_tokens = text_tokenizer.encode(text).tolist()
+
+    for i in range(7):
+        input_ids_item[i] = torch.tensor(
+            [layershift(_pad_a, i)] * (len(text_tokens) + 3)
+        ).unsqueeze(0)
+    input_ids_item[-1] = [_input_t] + text_tokens + [_eot] + [_answer_t]
+    input_ids_item[-1] = torch.tensor(input_ids_item[-1]).unsqueeze(0)
+
+    return input_ids_item
+
+
+def get_input_ids_whisper(
+    mel, leng, whispermodel, device,
+    special_token_a=_answer_a, special_token_t=_answer_t,
+):
+
+    with torch.no_grad():
+        mel = mel.unsqueeze(0).to(device)
+        # audio_feature = whisper.decode(whispermodel,mel, options).audio_features
+        audio_feature = whispermodel.embed_audio(mel)[0][:leng]
+
+    T = audio_feature.size(0)
+    input_ids = []
+    for i in range(7):
+        input_ids_item = []
+        input_ids_item.append(layershift(_input_a, i))
+        input_ids_item += [layershift(_pad_a, i)] * T
+        input_ids_item += [(layershift(_eoa, i)), layershift(special_token_a, i)]
+        input_ids.append(torch.tensor(input_ids_item).unsqueeze(0))
+    input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, special_token_t])
+    input_ids.append(input_id_T.unsqueeze(0))
+    return audio_feature.unsqueeze(0), input_ids
+
+
+def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
+    with torch.no_grad():
+        mel = mel.unsqueeze(0).to(device)
+        # audio_feature = whisper.decode(whispermodel,mel, options).audio_features
+        audio_feature = whispermodel.embed_audio(mel)[0][:leng]
+    T = audio_feature.size(0)
+    input_ids_AA = []
+    for i in range(7):
+        input_ids_item = []
+        input_ids_item.append(layershift(_input_a, i))
+        input_ids_item += [layershift(_pad_a, i)] * T
+        input_ids_item += [(layershift(_eoa, i)), layershift(_answer_a, i)]
+        input_ids_AA.append(torch.tensor(input_ids_item))
+    input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, _answer_t])
+    input_ids_AA.append(input_id_T)
+
+    input_ids_AT = []
+    for i in range(7):
+        input_ids_item = []
+        input_ids_item.append(layershift(_input_a, i))
+        input_ids_item += [layershift(_pad_a, i)] * T
+        input_ids_item += [(layershift(_eoa, i)), layershift(_pad_a, i)]
+        input_ids_AT.append(torch.tensor(input_ids_item))
+    input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, _answer_t])
+    input_ids_AT.append(input_id_T)
+
+    input_ids = [input_ids_AA, input_ids_AT]
+    stacked_inputids = [[] for _ in range(8)]
+    for i in range(2):
+        for j in range(8):
+            stacked_inputids[j].append(input_ids[i][j])
+    stacked_inputids = [torch.stack(tensors) for tensors in stacked_inputids]
+    return torch.stack([audio_feature, audio_feature]), stacked_inputids
+
+
+def load_audio(path):
+    audio = whisper.load_audio(path)
+    duration_ms = (len(audio) / 16000) * 1000
+    audio = whisper.pad_or_trim(audio)
+    mel = whisper.log_mel_spectrogram(audio)
+    return mel, int(duration_ms / 20) + 1
+
+
+def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
+                snacmodel, out_dir=None):
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=2)
+    tokenlist = generate_TA_BATCH(
+        model,
+        audio_feature,
+        input_ids,
+        [leng, leng],
+        ["A1A2", "A1T2"],
+        max_returned_tokens=2048,
+        temperature=0.9,
+        top_k=1,
+        eos_id_a=_eoa,
+        eos_id_t=_eot,
+        pad_id_t=_pad_t,
+        shift=padded_text_vocabsize,
+        include_prompt=True,
+        generate_text=True,
+    )
+    text_tokenlist = tokenlist[-1]
+    if text_vocabsize in text_tokenlist:
+        text_tokenlist = text_tokenlist[: text_tokenlist.index(text_vocabsize)]
+    text = text_tokenizer.decode(torch.tensor(text_tokenlist)).strip()
+
+    audio_tokenlist = tokenlist[:-1]
+    audiolist = reconscruct_snac(audio_tokenlist)
+    audio = reconstruct_tensors(audiolist)
+    if out_dir is None:
+        out_dir = "./output/default/A1-A2-batch"
+    else:
+        out_dir = out_dir + "/A1-A2-batch"
+    if not os.path.exists(out_dir):
+        os.makedirs(out_dir)
+    with torch.inference_mode():
+        audio_hat = snacmodel.decode(audio)
+    sf.write(
+        f"{out_dir}/{step:02d}.wav",
+        audio_hat.squeeze().cpu().numpy(),
+        24000,
+    )
+    model.clear_kv_cache()
+    return text
+
+
+def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=1)
+    tokenlist = generate_AT(
+        model,
+        audio_feature,
+        input_ids,
+        [leng],
+        ["AT"],
+        max_returned_tokens=2048,
+        temperature=0.9,
+        top_k=1,
+        eos_id_a=_eoa,
+        eos_id_t=_eot,
+        pad_id_t=_pad_t,
+        shift=padded_text_vocabsize,
+        include_prompt=True,
+        generate_text=True,
+    )
+    return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
+
+
+def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
+          snacmodel, out_dir=None):
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=1)
+    tokenlist = generate_AA(
+        model,
+        audio_feature,
+        input_ids,
+        [leng],
+        ["A1T2"],
+        max_returned_tokens=2048,
+        temperature=0.9,
+        top_k=1,
+        eos_id_a=_eoa,
+        eos_id_t=_eot,
+        pad_id_t=_pad_t,
+        shift=padded_text_vocabsize,
+        include_prompt=True,
+        generate_text=True,
+    )
+    audiolist = reconscruct_snac(tokenlist)
+    tokenlist = tokenlist[-1]
+    if text_vocabsize in tokenlist:
+        tokenlist = tokenlist[: tokenlist.index(text_vocabsize)]
+    if out_dir is None:
+        out_dir = "./output/default/A1-A2"
+    else:
+        out_dir = out_dir + "/A1-A2"
+    if not os.path.exists(out_dir):
+        os.makedirs(out_dir)
+
+    audio = reconstruct_tensors(audiolist)
+    with torch.inference_mode():
+        audio_hat = snacmodel.decode(audio)
+    sf.write(
+        f"{out_dir}/{step:02d}.wav",
+        audio_hat.squeeze().cpu().numpy(),
+        24000,
+    )
+    model.clear_kv_cache()
+    return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
+
+
+def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=1)
+    tokenlist = generate_ASR(
+        model,
+        audio_feature,
+        input_ids,
+        [leng],
+        ["A1T1"],
+        max_returned_tokens=2048,
+        temperature=0.9,
+        top_k=1,
+        eos_id_a=_eoa,
+        eos_id_t=_eot,
+        pad_id_t=_pad_t,
+        shift=padded_text_vocabsize,
+        include_prompt=True,
+        generate_text=True,
+    )
+    model.clear_kv_cache()
+    return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
+
+
+def T1_A2(fabric, input_ids, model, text_tokenizer, step,
+          snacmodel, out_dir=None):
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=1)
+    tokenlist = generate_TA(
+        model,
+        None,
+        input_ids,
+        None,
+        ["T1A2"],
+        max_returned_tokens=2048,
+        temperature=0.9,
+        top_k=1,
+        eos_id_a=_eoa,
+        eos_id_t=_eot,
+        pad_id_t=_pad_t,
+        shift=padded_text_vocabsize,
+        include_prompt=True,
+        generate_text=True,
+    )
+
+    audiolist = reconscruct_snac(tokenlist)
+    tokenlist = tokenlist[-1]
+
+    if text_vocabsize in tokenlist:
+        tokenlist = tokenlist[: tokenlist.index(text_vocabsize)]
+    audio = reconstruct_tensors(audiolist)
+    if out_dir is None:
+        out_dir = "./output/default/T1-A2"
+    else:
+        out_dir = out_dir + "/T1-A2"
+    if not os.path.exists(out_dir):
+        os.makedirs(out_dir)
+
+    with torch.inference_mode():
+        audio_hat = snacmodel.decode(audio)
+    sf.write(
+        f"{out_dir}/{step:02d}.wav",
+        audio_hat.squeeze().cpu().numpy(),
+        24000,
+    )
+    model.clear_kv_cache()
+    return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
+
+
+def T1_T2(fabric, input_ids, model, text_tokenizer, step):
+
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=1)
+    tokenlist = generate_TT(
+        model,
+        None,
+        input_ids,
+        None,
+        ["T1T2"],
+        max_returned_tokens=2048,
+        temperature=0.9,
+        top_k=1,
+        eos_id_a=_eoa,
+        eos_id_t=_eot,
+        pad_id_t=_pad_t,
+        shift=padded_text_vocabsize,
+        include_prompt=True,
+        generate_text=True,
+    )
+    model.clear_kv_cache()
+    return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
+
+
+def load_model(ckpt_dir, device):
+    snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
+    whisper_model_path = ckpt_dir + "/small.pt"
+    if not os.path.exists(whisper_model_path):
+        whisper_model_path = "small"
+    whispermodel = whisper.load_model(whisper_model_path).to(device)
+    text_tokenizer = Tokenizer(ckpt_dir)
+    fabric = L.Fabric(devices=1, strategy="auto")
+    config = Config.from_file(ckpt_dir + "/model_config.yaml")
+    config.post_adapter = False
+
+    with fabric.init_module(empty_init=False):
+        model = GPT(config)
+
+    model = fabric.setup(model)
+    state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
+    model.load_state_dict(state_dict, strict=True)
+    model.to(device).eval()
+
+    return fabric, model, text_tokenizer, snacmodel, whispermodel
+
+
+def download_model(ckpt_dir):
+    repo_id = "gpt-omni/mini-omni2"
+    snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")
+
+
+def get_text_stream(list_output, index, text_tokenizer):
+    text_tokens = list_output[-1][index:]
+    index += len(text_tokens)
+    is_text_end = False
+    if text_vocabsize in text_tokens:
+        text_tokens = text_tokens[:text_tokens.index(text_vocabsize)]
+        is_text_end = True
+    if len(text_tokens) == 0:
+        return "", index, is_text_end
+    res_text = text_tokenizer.decode(torch.tensor(text_tokens))
+    return res_text, index, is_text_end
+
+
+class OmniInference:
+
+    def __init__(self, ckpt_dir='./checkpoint', device='cuda:0'):
+        self.device = device
+        if not os.path.exists(ckpt_dir):
+            print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
+            download_model(ckpt_dir)
+        self.fabric, self.model, self.text_tokenizer, self.snacmodel, self.whispermodel = load_model(ckpt_dir, device)
+
+    def warm_up(self, sample='./data/samples/output1.wav'):
+        for _ in self.run_AT_batch_stream(sample):
+            pass
+
+    @torch.inference_mode()
+    def run_AT_batch_stream(self,
+                            audio_path,
+                            stream_stride=4,
+                            max_returned_tokens=2048,
+                            temperature=0.9,
+                            top_k=1,
+                            top_p=1.0,
+                            eos_id_a=_eoa,
+                            eos_id_t=_eot,
+                            save_path=None
+                            ):
+
+        assert os.path.exists(audio_path), f"audio file {audio_path} not found"
+        model = self.model
+
+        with self.fabric.init_tensor():
+            model.set_kv_cache(batch_size=2, device=self.device)
+
+        mel, leng = load_audio(audio_path)
+        audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
+        T = input_ids[0].size(1)
+        device = input_ids[0].device
+
+        assert max_returned_tokens > T, f"max_returned_tokens {max_returned_tokens} should be greater than audio length {T}"
+
+        if model.max_seq_length < max_returned_tokens - 1:
+            raise NotImplementedError(
+                f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
+            )
+
+        input_pos = torch.tensor([T], device=device)
+        list_output = [[] for i in range(8)]
+        tokens_A, token_T = next_token_image_batch(
+            model,
+            audio_feature.to(torch.float32).to(model.device),
+            None,
+            input_ids,
+            [T - 3, T - 3],
+            ["A1T2", "A1T2"],
+            input_pos=torch.arange(0, T, device=device),
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+        )
+
+        for i in range(7):
+            list_output[i].append(tokens_A[i].tolist()[0])
+        list_output[7].append(token_T.tolist()[0])
+
+        model_input_ids = [[] for i in range(8)]
+        for i in range(7):
+            tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize
+            model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
+            model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=device))
+            model_input_ids[i] = torch.stack(model_input_ids[i])
+
+        model_input_ids[-1].append(token_T.clone().to(torch.int32))
+        model_input_ids[-1].append(token_T.clone().to(torch.int32))
+        model_input_ids[-1] = torch.stack(model_input_ids[-1])
+
+        text_end = False
+        index = 1
+        nums_generate = stream_stride
+        begin_generate = False
+        current_index = 0
+
+        text_index = 0
+        is_text_end = False
+
+        for _ in tqdm(range(2, max_returned_tokens - T + 1)):
+            tokens_A, token_T = next_token_image_batch(
+                model,
+                None,
+                None,
+                model_input_ids,
+                None,
+                None,
+                input_pos=input_pos,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+            )
+
+            if text_end:
+                token_T = torch.tensor([_pad_t], device=device)
+
+            if tokens_A[-1] == eos_id_a:
+                break
+
+            if token_T == eos_id_t:
+                text_end = True
+
+            for i in range(7):
+                list_output[i].append(tokens_A[i].tolist()[0])
+            list_output[7].append(token_T.tolist()[0])
+
+            model_input_ids = [[] for i in range(8)]
+            for i in range(7):
+                tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize
+                model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
+                model_input_ids[i].append(
+                    torch.tensor([layershift(4097, i)], device=device)
+                )
+                model_input_ids[i] = torch.stack(model_input_ids[i])
+
+            model_input_ids[-1].append(token_T.clone().to(torch.int32))
+            model_input_ids[-1].append(token_T.clone().to(torch.int32))
+            model_input_ids[-1] = torch.stack(model_input_ids[-1])
+
+            if index == 7:
+                begin_generate = True
+
+            if begin_generate:
+                current_index += 1
+                if current_index == nums_generate:
+                    current_index = 0
+                    snac = get_snac(list_output, index, nums_generate)
+                    audio_stream = generate_audio_data(snac, self.snacmodel, self.device)
+                    if is_text_end:
+                        text_stream = ""
+                    else:
+                        text_stream, text_index, is_text_end = get_text_stream(list_output, text_index, self.text_tokenizer)
+
+                    yield (audio_stream, text_stream)
+
+            input_pos = input_pos.add_(1)
+            index += 1
+        text = self.text_tokenizer.decode(torch.tensor(list_output[-1]))
+        print(f"text output: {text}")
+
+        if save_path is not None:
+            audiolist = reconscruct_snac(list_output)
+            audio = reconstruct_tensors(audiolist)
+            with torch.inference_mode():
+                audio_hat = self.snacmodel.decode(audio)
+            sf.write(save_path, audio_hat.squeeze().cpu().numpy(), 24000)
+
+        model.clear_kv_cache()
+        return list_output
+
+
+def test_infer():
+    device = "cuda:0"
+    out_dir = f"./output/{get_time_str()}"
+    ckpt_dir = f"./checkpoint"
+    if not os.path.exists(ckpt_dir):
+        print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
+        download_model(ckpt_dir)
+
+    fabric, model, text_tokenizer, snacmodel, whispermodel = load_model(ckpt_dir, device)
+
+    task = ['A1A2', 'asr', "T1A2", "AA-BATCH", 'T1T2', 'AT']
+
+    # prepare test data
+    # TODO
+    test_audio_list = sorted(glob.glob('./data/samples/output*.wav'))
+    test_audio_transcripts = [
+        "What is your name?",
+        "what are your hobbies?",
+        "Do you like beijing",
+        "How are you feeling today?",
+        "what is the weather like today?",
+    ]
+    test_text_list = [
+        "What is your name?",
+        "How are you feeling today?",
+        "Can you describe your surroundings?",
+        "What did you do yesterday?",
+        "What is your favorite book and why?",
+        "How do you make a cup of tea?",
+        "What is the weather like today?",
+        "Can you explain the concept of time?",
+        "Can you tell me a joke?",
+    ]
+
+    # LOAD MODEL
+    with torch.no_grad():
+        if "A1A2" in task:
+            print("===============================================================")
+            print(" testing A1A2")
+            print("===============================================================")
+            step = 0
+            for path in test_audio_list:
+                try:
+                    mel, leng = load_audio(path)
+                    audio_feature, input_ids = get_input_ids_whisper(mel, leng, whispermodel, device)
+                    text = A1_A2(
+                        fabric,
+                        audio_feature,
+                        input_ids,
+                        leng,
+                        model,
+                        text_tokenizer,
+                        step,
+                        snacmodel,
+                        out_dir=out_dir,
+                    )
+                    print(f"input: {test_audio_transcripts[step]}")
+                    print(f"output: {text}")
+                    step += 1
+                    print(
+                        "+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++"
+                    )
+                except:
+                    print(f"[error] failed to process {path}")
+            print("===============================================================")
+
+        if 'asr' in task:
+            print("===============================================================")
+            print(" testing asr")
+            print("===============================================================")
+
+            index = 0
+            step = 0
+            for path in test_audio_list:
+                mel, leng = load_audio(path)
+                audio_feature, input_ids = get_input_ids_whisper(mel, leng, whispermodel, device, special_token_a=_pad_a, special_token_t=_asr)
+                output = A1_T1(fabric, audio_feature, input_ids ,leng, model, text_tokenizer, index).lower().replace(',','').replace('.','').replace('?','')
+                print(f"audio_path: {path}")
+                print(f"audio transcript: {test_audio_transcripts[index]}")
+                print(f"asr output: {output}")
+                print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
+                index += 1
+
+        if "T1A2" in task:
+            step = 0
+            print("\n")
+            print("===============================================================")
+            print(" testing T1A2")
+            print("===============================================================")
+            for text in test_text_list:
+                input_ids = get_input_ids_TA(text, text_tokenizer)
+                text_output = T1_A2(fabric, input_ids, model, text_tokenizer, step,
+                                    snacmodel, out_dir=out_dir)
+                print(f"input: {text}")
+                print(f"output: {text_output}")
+                print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
+                step += 1
+            print("===============================================================")
+
+        if "T1T2" in task:
+            step = 0
+            print("\n")
+            print("===============================================================")
+            print(" testing T1T2")
+            print("===============================================================")
+
+            for text in test_text_list:
+                input_ids = get_input_ids_TT(text, text_tokenizer)
+                text_output = T1_T2(fabric, input_ids, model, text_tokenizer, step)
+                print(f" Input: {text}")
+                print(f"Output: {text_output}")
+                print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
+            print("===============================================================")
+
+        if "AT" in task:
+            print("===============================================================")
+            print(" testing A1T2")
+            print("===============================================================")
+            step = 0
+            for path in test_audio_list:
+                mel, leng = load_audio(path)
+                audio_feature, input_ids = get_input_ids_whisper(
+                    mel, leng, whispermodel, device,
+                    special_token_a=_pad_a, special_token_t=_answer_t
+                )
+                text = A1_T2(
+                    fabric, audio_feature, input_ids, leng, model, text_tokenizer, step
+                )
+                print(f"input: {test_audio_transcripts[step]}")
+                print(f"output: {text}")
+                step += 1
+                print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
+            print("===============================================================")
+
+        if "AA-BATCH" in task:
+            print("===============================================================")
+            print(" testing A1A2-BATCH")
+            print("===============================================================")
+            step = 0
+            for path in test_audio_list:
+                mel, leng = load_audio(path)
+                audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device)
+                text = A1_A2_batch(
+                    fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
+                    snacmodel, out_dir=out_dir
+                )
+                print(f"input: {test_audio_transcripts[step]}")
+                print(f"output: {text}")
+                step += 1
+                print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
+            print("===============================================================")
+
+    print("*********************** test end *****************************")
+
+
+
+if __name__ == "__main__":
+    test_infer()
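A minimal local-usage sketch for `OmniInference.run_AT_batch_stream` above, assuming a CUDA device and the `./checkpoint` weights (downloaded automatically on first use); each yielded `audio_chunk` is assumed to be a wav-encoded byte string, as it is served with `Content-Type: audio/wav` by `server.py`.

```python
from inference import OmniInference

client = OmniInference(ckpt_dir="./checkpoint", device="cuda:0")
client.warm_up()

text = ""
audio_chunks = []
for audio_chunk, text_chunk in client.run_AT_batch_stream(
    "./data/samples/output1.wav",   # sample question shipped with the repo
    stream_stride=4,                # snac frames decoded per yielded chunk
    save_path="./answer_full.wav",  # full decoded answer written at the end
):
    audio_chunks.append(audio_chunk)  # per-chunk audio for streaming playback
    text += text_chunk                # incremental text transcript
print(text)
```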
inference_vision.py
ADDED
@@ -0,0 +1,259 @@
+import os
+import torch
+from litgpt.generate.base import next_token_image_batch
+import soundfile as sf
+from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
+from utils.snac_utils import get_snac, generate_audio_data
+import clip
+import inference
+from tqdm import tqdm
+from inference import OmniInference, load_model, load_audio, download_model
+from inference import text_vocabsize, padded_text_vocabsize, get_text_stream
+from PIL import Image
+
+
+torch.set_printoptions(sci_mode=False)
+
+_image = inference._image
+_eoimage = inference._eoimage
+_pad_t = inference._pad_t
+_input_t = inference._input_t
+_answer_t = inference._answer_t
+_eot = inference._eot
+_eoa = inference._eoa
+_pad_a = inference._pad_a
+_input_a = inference._input_a
+_answer_a = inference._answer_a
+
+
+def get_input_ids_ImageQA_ATBatch(mel, leng, whispermodel, device):
+
+    with torch.no_grad():
+        mel = mel.unsqueeze(0).to(device)
+        audio_feature = whispermodel.embed_audio(mel)[0][:leng]
+
+    audio_len = audio_feature.size(0)
+
+    input_ids = []
+    input_ids_item = [[] for i in range(8)]
+    for i in range(7):
+        input_ids_item[i] = [layershift(_image,i)] + [layershift(_pad_a,i)] * 50 + [layershift(_eoimage,i)]
+        input_ids_item[i] += [layershift(_input_a,i)]+[layershift(_pad_a,i)]*(audio_len)+[layershift(_eoa,i)]
+        input_ids_item[i] += [layershift(_answer_a,i)]
+
+    input_ids_item[-1] = [_pad_t]* (52 + 2 + audio_len) + [_answer_t]
+    input_ids_item = [torch.tensor(item) for item in input_ids_item]
+
+    input_ids.append(input_ids_item)
+
+    input_ids_item = [[] for i in range(8)]
+    for i in range(7):
+        input_ids_item[i] = [layershift(_image,i)] + [layershift(_pad_a,i)] * 50 + [layershift(_eoimage,i)]
+        input_ids_item[i] += [layershift(_input_a,i)]+[layershift(_pad_a,i)]*(audio_len)+[layershift(_eoa,i)] + [layershift(_pad_a,i)]
+
+    input_ids_item[-1] = [_pad_t]* (52 + 2 + audio_len) + [_answer_t]
+
+    input_ids_item = [torch.tensor(item) for item in input_ids_item]
+    input_ids.append(input_ids_item)
+
+    stacked_inputids = [[] for _ in range(8)]
+    for i in range(2):
+        for j in range(8):
+            stacked_inputids[j].append(input_ids[i][j])
+    stacked_inputids = [torch.stack(tensors) for tensors in stacked_inputids]
+
+    return torch.stack([audio_feature,audio_feature]), stacked_inputids
+
+
+def load_clip_model(ckpt_dir, device):
+    clip_model_path = ckpt_dir + "/ViT-B-32.pt"
+    if not os.path.exists(clip_model_path):
+        clip_model_path = "ViT-B/32"
+    clipmodel, clippreprocess = clip.load(clip_model_path, device=device)
+    return clipmodel, clippreprocess
+
+
+class OmniVisionInference(OmniInference):
+
+    def __init__(self, ckpt_dir='./checkpoint', device='cuda:0'):
+        self.device = device
+        if not os.path.exists(ckpt_dir):
+            print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
+            download_model(ckpt_dir)
+        self.fabric, self.model, self.text_tokenizer, self.snacmodel, self.whispermodel = load_model(ckpt_dir, device)
+        self.clipmodel, self.clippreprocess = load_clip_model(ckpt_dir, device)
+
+    def warm_up(self,
+                audio_sample='./data/samples/vision_qa_audio.wav',
+                image_sample='./data/samples/vision_qa_image.jpg'
+                ):
+        for _ in self.run_vision_AA_batch_stream(audio_sample, image_sample,
+                                                 save_path="./data/samples/vision_qa_output.wav",
+                                                 warm_up=True):
+            pass
+
+    @torch.inference_mode()
+    def run_vision_AA_batch_stream(self, audio_path, image_path,
+                                   stream_stride=4,
+                                   max_returned_tokens=2048,
+                                   temperature=0.9,
+                                   top_k=1,
+                                   top_p=1.0,
+                                   eos_id_a=_eoa,
+                                   eos_id_t=_eot,
+                                   pad_id=_pad_t,
+                                   save_path=None,
+                                   warm_up=False
+                                   ):
+        with self.fabric.init_tensor():
+            self.model.set_kv_cache(batch_size=2)
+
+        model = self.model
+
+        mel, leng = load_audio(audio_path)
+        img = Image.open(image_path)
+
+        audio_feature, input_ids = get_input_ids_ImageQA_ATBatch(mel, leng, self.whispermodel, self.device)
+        ima = self.clippreprocess(img).unsqueeze(0).to(self.device)
+        ima_feature = self.clipmodel.encode_image(ima).squeeze(0).to(self.device)
+
+        ima_feature = torch.stack([ima_feature.clone(),ima_feature.clone()]).to(self.device)
+        leng = [leng,leng]
+        task = ['ImageQA_A','ImageQA_AT']
+
+        T = input_ids[0].size(1)
+        assert max_returned_tokens > T, f"max_returned_tokens {max_returned_tokens} should be greater than audio length {T}"
+
+        if model.max_seq_length < max_returned_tokens - 1:
+            raise NotImplementedError(
+                f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
+            )
+
+        list_output = [[] for i in range(8)]
+
+        tokens_A , token_T = next_token_image_batch(
+            model,
+            audio_feature.to(torch.float32).to(self.device),
+            ima_feature.to(torch.float32).to(self.device) ,
+            input_ids ,
+            whisper_lens = leng ,
+            task = task,
+            input_pos = torch.arange(0, T, device=self.device),
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p
+        )
+        for i in range(7): list_output[i].append(tokens_A[i].tolist()[0])
+        list_output[7].append(token_T.tolist()[0])
+
+        text_end = False
+        index = 1
+        nums_generate = stream_stride
+        begin_generate = False
+        current_index = 0
+        input_pos = torch.tensor([T], device=self.device)
+
+        model_input_ids = [[] for i in range(8)]
+        for i in range(7):
+            tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * 4160
+            model_input_ids[i].append(tokens_A[i].clone().to(self.device).to(torch.int32))
+            model_input_ids[i].append(torch.tensor([layershift(4097,i)],device=self.device))
+            model_input_ids[i] = torch.stack(model_input_ids[i])
+
+        model_input_ids[-1].append(token_T.clone().to(torch.int32))
+        model_input_ids[-1].append(token_T.clone().to(torch.int32))
+        model_input_ids[-1] = torch.stack(model_input_ids[-1])
+
+        text_index = 0
+        is_text_end = False
+
+        for _ in tqdm(range(2, max_returned_tokens - T + 1)):
+
+            tokens_A , token_T = next_token_image_batch(model, None , None ,
+                                                        input_ids = model_input_ids,
+                                                        whisper_lens= None,
+                                                        task = None,
+                                                        input_pos = input_pos,
+                                                        temperature=temperature,
+                                                        top_k=top_k,
+                                                        top_p=top_p)
+
+            if text_end:
+                token_T = torch.tensor([_pad_t], device=self.device)
+
+            if tokens_A[-1] == eos_id_a:
+                break
+            if token_T == eos_id_t:
+                text_end = True
+
+            for i in range(7): list_output[i].append(tokens_A[i].tolist()[0])
+            list_output[7].append(token_T.tolist()[0])
+
+
+            if index == 7:
+                begin_generate = True
+
+            if begin_generate:
+                current_index += 1
+                if current_index == nums_generate:
+                    current_index = 0
+                    snac = get_snac(list_output,index,nums_generate)
+                    audio_stream = generate_audio_data(snac, self.snacmodel, self.device)
+                    if is_text_end:
+                        text_stream = ""
+                    else:
+                        text_stream, text_index, is_text_end = get_text_stream(list_output, text_index, self.text_tokenizer)
+
+                    yield (audio_stream, text_stream)
+
+                    if warm_up:
+                        break
+
+            input_pos = input_pos.add_(1)
+            model_input_ids = [[] for i in range(8)]
+            for i in range(7):
+                tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * 4160
+                model_input_ids[i].append(tokens_A[i].clone().to(self.device).to(torch.int32))
+                model_input_ids[i].append(torch.tensor([layershift(4097,i)],device=self.device))
+                model_input_ids[i] = torch.stack(model_input_ids[i])
+
+            model_input_ids[-1].append(token_T.clone().to(torch.int32))
+            model_input_ids[-1].append(token_T.clone().to(torch.int32))
+            model_input_ids[-1] = torch.stack(model_input_ids[-1])
+
+            index += 1
+
+        text_tokens = list_output[-1]
+        if text_vocabsize in text_tokens:
+            text_tokens = text_tokens[:text_tokens.index(text_vocabsize)]
+        res_text = self.text_tokenizer.decode(torch.tensor(text_tokens))
+        print(f"text output: {res_text}")
+
+        if save_path is not None:
+            audiolist = reconscruct_snac(list_output)
+            audio = reconstruct_tensors(audiolist)
+            with torch.inference_mode():
+                audio_hat = self.snacmodel.decode(audio)
+            sf.write(save_path, audio_hat.squeeze().cpu().numpy(), 24000)
+
+        model.clear_kv_cache()
+
+
+def test_vision_infer():
+    client = OmniVisionInference()
+    client.warm_up()
+    input_audio_path = './data/samples/vision_qa_audio.wav'
+    input_image_path = './data/samples/vision_qa_image.jpg'
+
+    res_text = ""
+    for audio_stream, text_stream in client.run_vision_AA_batch_stream(
+        input_audio_path,
+        input_image_path,
+        save_path="./vision_qa_output.wav"
+    ):
+        res_text += text_stream
+        print(f"text_output: {res_text}")
+
+
+if __name__ == "__main__":
+    test_vision_infer()
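When the vision path is used through the HTTP server instead of `test_vision_infer`, the request simply adds a base64-encoded `image` field next to `audio` (both fields are read by `server.py` below). A stdlib-only sketch, reusing the sample files referenced in `warm_up`; the helper name `b64_file` is illustrative only.

```python
import base64
import json
import urllib.request

def b64_file(path):
    """Read a local file and return its base64-encoded contents as text."""
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

payload = {
    "audio": b64_file("./data/samples/vision_qa_audio.wav"),
    "image": b64_file("./data/samples/vision_qa_image.jpg"),
    "stream_stride": 4,
    "max_tokens": 2048,
}
req = urllib.request.Request(
    "http://0.0.0.0:60808/chat",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
raw = urllib.request.urlopen(req).read()  # multipart stream; see server.py generator()
```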
requirements.txt
CHANGED
@@ -1,4 +1,20 @@
-
-
-
-
+torch==2.3.1
+torchvision==0.18.1
+torchaudio==2.3.1
+litgpt==0.4.3
+snac==1.2.0
+soundfile==0.12.1
+openai-whisper
+tokenizers==0.19.1
+streamlit==1.37.1
+streamlit-webrtc
+# PyAudio==0.2.14
+pydub==0.25.1
+onnxruntime==1.19.0
+# numpy==1.26.3
+librosa==0.10.2.post1
+flask==3.0.3
+fire
+git+https://github.com/mini-omni/CLIP.git
+gradio_webrtc[vad]==0.0.11
+twilio
server.py
ADDED
@@ -0,0 +1,96 @@
+import sys
+import os
+
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
+
+import flask
+import base64
+import tempfile
+import traceback
+from flask import Flask, Response, stream_with_context
+from inference_vision import OmniVisionInference
+
+
+class OmniChatServer(object):
+    def __init__(self, ip='0.0.0.0', port=60808, run_app=True,
+                 ckpt_dir='./checkpoint', device='cuda:0') -> None:
+        server = Flask(__name__)
+        # CORS(server, resources=r"/*")
+        # server.config["JSON_AS_ASCII"] = False
+
+        self.client = OmniVisionInference(ckpt_dir, device)
+        self.client.warm_up()
+
+        server.route("/chat", methods=["POST"])(self.chat)
+
+        if run_app:
+            server.run(host=ip, port=port, threaded=False)
+        else:
+            self.server = server
+
+    def chat(self) -> Response:
+
+        req_data = flask.request.get_json()
+        try:
+            audio_data_buf = req_data["audio"].encode("utf-8")
+            audio_data_buf = base64.b64decode(audio_data_buf)
+            stream_stride = req_data.get("stream_stride", 4)
+            max_tokens = req_data.get("max_tokens", 2048)
+
+            image_data_buf = req_data.get("image", None)
+            if image_data_buf:
+                image_data_buf = image_data_buf.encode("utf-8")
+                image_data_buf = base64.b64decode(image_data_buf)
+
+            audio_path, img_path = None, None
+            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_f, \
+                 tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as img_f:
+                audio_f.write(audio_data_buf)
+                audio_path = audio_f.name
+
+                if image_data_buf:
+                    img_f.write(image_data_buf)
+                    img_path = img_f.name
+                else:
+                    img_path = None
+
+                if img_path is not None:
+                    resp_generator = self.client.run_vision_AA_batch_stream(audio_f.name, img_f.name,
+                                                                            stream_stride, max_tokens,
+                                                                            save_path='./vision_qa_out_cache.wav')
+                else:
+                    resp_generator = self.client.run_AT_batch_stream(audio_f.name, stream_stride,
+                                                                     max_tokens,
+                                                                     save_path='./audio_qa_out_cache.wav')
+                return Response(stream_with_context(self.generator(resp_generator)),
+                                mimetype='multipart/x-mixed-replace; boundary=frame')
+        except Exception as e:
+            print(traceback.format_exc())
+            return Response("An error occurred", status=500)
+
+    def generator(self, resp_generator):
+        for audio_stream, text_stream in resp_generator:
+            yield b'\r\n--frame\r\n'
+            yield b'Content-Type: audio/wav\r\n\r\n'
+            yield audio_stream
+            yield b'\r\n--frame\r\n'
+            yield b'Content-Type: text/plain\r\n\r\n'
+            yield text_stream.encode()
+
+
+# CUDA_VISIBLE_DEVICES=1 gunicorn -w 2 -b 0.0.0.0:60808 'server:create_app()'
+def create_app():
+    server = OmniChatServer(run_app=False)
+    return server.server
+
+
+def serve(ip='0.0.0.0', port=60808, device='cuda:0'):
+
+    OmniChatServer(ip, port=port,run_app=True, device=device)
+
+
+if __name__ == "__main__":
+    import fire
+    fire.Fire(serve)
+
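The `generator` method above frames each chunk as `\r\n--frame\r\n`, a `Content-Type` header, and the payload. Below is a client-side sketch for splitting a captured response back into audio and text parts; it is a naive split that assumes the boundary bytes never occur inside an audio payload, and it reads the `reply_raw.bin` file produced by the hypothetical client sketch shown after the README above.

```python
AUDIO_HDR = b"Content-Type: audio/wav\r\n\r\n"
TEXT_HDR = b"Content-Type: text/plain\r\n\r\n"

def split_frames(raw: bytes):
    """Yield ('audio' | 'text', payload) tuples from the multipart stream."""
    for frame in raw.split(b"\r\n--frame\r\n"):
        if frame.startswith(AUDIO_HDR):
            yield "audio", frame[len(AUDIO_HDR):]
        elif frame.startswith(TEXT_HDR):
            yield "text", frame[len(TEXT_HDR):]

# Example: print the streamed text from a saved response.
with open("reply_raw.bin", "rb") as f:
    for kind, payload in split_frames(f.read()):
        if kind == "text":
            print(payload.decode("utf-8", errors="ignore"), end="")
```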