akjedidtz committed
Commit
8575608
1 Parent(s): 022aa5e

Upload 9 files

Files changed (9)
  1. .gitignore +171 -0
  2. Dockerfile +19 -0
  3. LICENSE +21 -0
  4. README.md +153 -13
  5. __init__.py +4 -0
  6. inference.py +704 -0
  7. inference_vision.py +259 -0
  8. requirements.txt +20 -4
  9. server.py +96 -0
.gitignore ADDED
@@ -0,0 +1,171 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ *pyc
3
+ *pth
4
+ checkpoint*/
5
+ output/
6
+ .DS_Store
7
+ *.wav
8
+ *.webp
9
+ *.jpg
10
+
11
+ __pycache__/
12
+ *.py[cod]
13
+ *$py.class
14
+
15
+ # C extensions
16
+ *.so
17
+
18
+ # Distribution / packaging
19
+ .Python
20
+ build/
21
+ develop-eggs/
22
+ dist/
23
+ downloads/
24
+ eggs/
25
+ .eggs/
26
+ lib/
27
+ lib64/
28
+ parts/
29
+ sdist/
30
+ var/
31
+ wheels/
32
+ share/python-wheels/
33
+ *.egg-info/
34
+ .installed.cfg
35
+ *.egg
36
+ MANIFEST
37
+
38
+ # PyInstaller
39
+ # Usually these files are written by a python script from a template
40
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
41
+ *.manifest
42
+ *.spec
43
+
44
+ # Installer logs
45
+ pip-log.txt
46
+ pip-delete-this-directory.txt
47
+
48
+ # Unit test / coverage reports
49
+ htmlcov/
50
+ .tox/
51
+ .nox/
52
+ .coverage
53
+ .coverage.*
54
+ .cache
55
+ nosetests.xml
56
+ coverage.xml
57
+ *.cover
58
+ *.py,cover
59
+ .hypothesis/
60
+ .pytest_cache/
61
+ cover/
62
+
63
+ # Translations
64
+ *.mo
65
+ *.pot
66
+
67
+ # Django stuff:
68
+ *.log
69
+ local_settings.py
70
+ db.sqlite3
71
+ db.sqlite3-journal
72
+
73
+ # Flask stuff:
74
+ instance/
75
+ .webassets-cache
76
+
77
+ # Scrapy stuff:
78
+ .scrapy
79
+
80
+ # Sphinx documentation
81
+ docs/_build/
82
+
83
+ # PyBuilder
84
+ .pybuilder/
85
+ target/
86
+
87
+ # Jupyter Notebook
88
+ .ipynb_checkpoints
89
+
90
+ # IPython
91
+ profile_default/
92
+ ipython_config.py
93
+
94
+ # pyenv
95
+ # For a library or package, you might want to ignore these files since the code is
96
+ # intended to run in multiple environments; otherwise, check them in:
97
+ # .python-version
98
+
99
+ # pipenv
100
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
101
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
102
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
103
+ # install all needed dependencies.
104
+ #Pipfile.lock
105
+
106
+ # poetry
107
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
108
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
109
+ # commonly ignored for libraries.
110
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
111
+ #poetry.lock
112
+
113
+ # pdm
114
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
115
+ #pdm.lock
116
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
117
+ # in version control.
118
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
119
+ .pdm.toml
120
+ .pdm-python
121
+ .pdm-build/
122
+
123
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
124
+ __pypackages__/
125
+
126
+ # Celery stuff
127
+ celerybeat-schedule
128
+ celerybeat.pid
129
+
130
+ # SageMath parsed files
131
+ *.sage.py
132
+
133
+ # Environments
134
+ .env
135
+ .venv
136
+ env/
137
+ venv/
138
+ ENV/
139
+ env.bak/
140
+ venv.bak/
141
+
142
+ # Spyder project settings
143
+ .spyderproject
144
+ .spyproject
145
+
146
+ # Rope project settings
147
+ .ropeproject
148
+
149
+ # mkdocs documentation
150
+ /site
151
+
152
+ # mypy
153
+ .mypy_cache/
154
+ .dmypy.json
155
+ dmypy.json
156
+
157
+ # Pyre type checker
158
+ .pyre/
159
+
160
+ # pytype static type analyzer
161
+ .pytype/
162
+
163
+ # Cython debug symbols
164
+ cython_debug/
165
+
166
+ # PyCharm
167
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
168
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
169
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
170
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
171
+ #.idea/
Dockerfile ADDED
@@ -0,0 +1,19 @@
1
+ # Base image with Python 3.10
2
+ FROM python:3.10
3
+
4
+ # Set the working directory
5
+ WORKDIR /app
6
+
7
+ # Copy local files into the container
8
+ COPY . .
9
+
10
+ # Install Linux packages
11
+ RUN apt-get update && \
12
+ apt-get install -y ffmpeg
13
+
14
+ # Install Python dependencies
15
+ RUN pip install --no-cache-dir -r requirements.txt && \
16
+ pip install PyAudio==0.2.14
17
+
18
+ # Commands to run the server and the Streamlit app
19
+ CMD ["bash", "-c", "python3 server.py --ip '0.0.0.0' --port 60808 & API_URL=http://0.0.0.0:60808/chat streamlit run webui/omni_streamlit.py --server.port=8501"]
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 gpt-omni
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,13 +1,153 @@
1
- ---
2
- title: Bo2
3
- emoji: 💬
4
- colorFrom: yellow
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 5.0.1
8
- app_file: app.py
9
- pinned: false
10
- license: llama3
11
- ---
12
-
13
- An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
1
+
2
+ # Mini-Omni2
3
+
4
+ <p align="center">
5
+ <img src="./data/figures/title_new.png" width="90%"/>
6
+ </p>
7
+
8
+
9
+ <p align="center">
10
+ 🤗 <a href="https://huggingface.co/gpt-omni/mini-omni2">Hugging Face</a> | 📖 <a href="https://github.com/gpt-omni/mini-omni2">Github</a>
11
+ | 📑 <a href="https://arxiv.org/abs/2410.11190">Technical report</a>
12
+ </p>
13
+
14
+ <h5 align="center"> If you like our little Omni2, please give us a star ⭐ and cite our <a href="https://arxiv.org/abs/2410.11190">paper</a>!</h5>
15
+
16
+ ## Introduction
17
+ Mini-Omni2 is an **omni-interactive** model. It can **understand image, audio and text inputs and hold end-to-end voice conversations with users**. It features **real-time voice output**, **omni-capable multimodal understanding** and flexible interaction, **including an interruption mechanism while speaking**.
18
+
19
+ <p align="center">
20
+ <img src="./data/figures/framework.jpeg" width="100%"/>
21
+ </p>
22
+
23
+
24
+ ## Updates
25
+
26
+ - **2024.10:** Released the model, technical report, and inference and chat demo code.
27
+
28
+ ## Features
29
+ ✅ **Multimodal interaction**: with the ability to understand images, speech and text, just like GPT-4o.
30
+
31
+ ✅ **Real-time speech-to-speech** conversational capabilities. No extra ASR or TTS models required, just like [Mini-Omni](https://github.com/gpt-omni/mini-omni).
32
+
33
+ <!-- ✅ **Streaming audio output**: with first-chunk latency of audio stream less than 0.3s. -->
34
+
35
+ <!-- ✅ **Duplex interaction**: hearing while speaking, it can be interrupted by key words like "stop omni". -->
36
+
37
+
38
+ ## Demo
39
+
40
+ NOTE: you need to unmute the video first.
41
+
42
+ https://github.com/user-attachments/assets/ad97ca7f-f8b4-40c3-a7e8-fa54b4edf155
43
+
44
+
45
+ ## ToDo
46
+ - [ ] update interruption mechanism
47
+
48
+
49
+ ## Install
50
+
51
+ Create a new conda environment and install the required packages:
52
+
53
+ ```sh
54
+ conda create -n omni python=3.10
55
+ conda activate omni
56
+
57
+ git clone https://github.com/gpt-omni/mini-omni2.git
58
+ cd mini-omni2
59
+ pip install -r requirements.txt
60
+ ```
61
+
62
+ ## Quick start
63
+
64
+ **Interactive demo**
65
+
66
+ - start server
67
+
68
+ NOTE: you need to start the server before running the streamlit or gradio demo, with API_URL set to the server address.
69
+
70
+ ```sh
71
+ sudo apt-get install ffmpeg
72
+ conda activate omni
73
+ cd mini-omni2
74
+ python3 server.py --ip '0.0.0.0' --port 60808
75
+ ```
76
+
77
+
78
+ - run streamlit demo
79
+
80
+ NOTE: you need to run streamlit **locally** with PyAudio installed.
81
+
82
+ ```sh
83
+ pip install PyAudio==0.2.14
84
+ API_URL=http://0.0.0.0:60808/chat streamlit run webui/omni_streamlit.py
85
+ ```
86
+
87
+
88
+ **Local test**
89
+
90
+ ```sh
91
+ conda activate omni
92
+ cd mini-omni2
93
+ # test run the preset audio samples and questions
94
+ python inference_vision.py
95
+ ```
96
+
97
+ ## Mini-Omni2 Overview
98
+
99
+ **1. Multimodal Modeling**:
100
+ We use multiple parallel sequences as the input and output of the model. On the input side, we concatenate image, audio and text features to perform a range of comprehensive tasks, as shown in the figure below. On the output side, we use text-guided delayed parallel output to generate real-time speech responses.
101
+ <p align="center">
102
+ <img src="./data/figures/inputids.png" width="100%"/>
103
+ </p>
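To make the layout concrete, here is a minimal sketch of the eight parallel streams used at inference time. The vocabulary sizes and offsets are taken from `inference.py` in this commit; the helper itself is illustrative rather than the repo's exact code:

```python
# Eight parallel streams: layers 0-6 carry SNAC audio codes, layer 7 carries text tokens.
# Audio ids of layer i are "layer-shifted" into a disjoint range so a single embedding
# table can serve all streams (mirrors the offsets used in inference.py).
padded_text_vocabsize = 151936 + 64   # Qwen2 text vocab + special tokens
padded_audio_vocabsize = 4096 + 64    # SNAC codebook + special tokens

def layershift(audio_token: int, layer: int) -> int:
    return audio_token + padded_text_vocabsize + layer * padded_audio_vocabsize

# Each decoding step emits one text token plus seven layer-shifted audio tokens; the text
# stream is decoded slightly ahead, so it "guides" the parallel audio streams in real time.
step_output = [layershift(1234, layer) for layer in range(7)] + [42]  # 7 audio ids + 1 text id
assert len(step_output) == 8
```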
104
+
105
+ **2. Multi-stage Training**:
106
+ We propose an efficient alignment training method: the three training stages cover encoder adaptation, modal alignment, and multimodal fine-tuning, respectively.
107
+ <p align="center">
108
+ <img src="./data/figures/training.jpeg" width="100%"/>
109
+ </p>
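The staging can be pictured as progressively unfreezing parts of the model. The sketch below is schematic only: the module names (`llm`, `audio_adapter`, `vision_adapter`) and the freezing schedule are placeholders, not the repo's actual training code (see the technical report for the real recipe):

```python
def set_trainable(module, flag: bool):
    for p in module.parameters():
        p.requires_grad = flag

def configure_stage(model, stage: int):
    # Stage 1 (encoder adaptation): train only the modality adapters, keep the LLM frozen.
    # Stage 2 (modal alignment): unfreeze the LLM so text, audio and image share one space.
    # Stage 3 (multimodal fine-tuning): train everything end to end on conversational data.
    set_trainable(model.audio_adapter, True)
    set_trainable(model.vision_adapter, True)
    set_trainable(model.llm, stage >= 2)
```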
110
+
111
+ <!-- **3. Cases**:
112
+ Here are more cases of Mini-Omni2:
113
+ <p align="center">
114
+ <img src="./data/figures/samples.png" width="100%"/>
115
+ </p> -->
116
+
117
+ ## FAQ
118
+
119
+ **1. Does the model support other languages?**
120
+
121
+ No, the model is only trained on English. However, since we use Whisper as the audio encoder, the model can understand other languages supported by Whisper (such as Chinese), but the output is only in English.
122
+
123
+ **2. Error: cannot run streamlit in a local browser with a remote streamlit server**
124
+
125
+ You need to start streamlit **locally** with PyAudio installed.
126
+
127
+
128
+ ## Acknowledgements
129
+
130
+ - [Qwen2](https://github.com/QwenLM/Qwen2/) as the LLM backbone.
131
+ - [litGPT](https://github.com/Lightning-AI/litgpt/) for training and inference.
132
+ - [whisper](https://github.com/openai/whisper/) for audio encoding.
133
+ - [clip](https://github.com/openai/CLIP) for image encoding.
134
+ - [snac](https://github.com/hubertsiuzdak/snac/) for audio decoding.
135
+ - [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) for generating synthetic speech.
136
+ - [OpenOrca](https://huggingface.co/datasets/Open-Orca/OpenOrca) and [MOSS](https://github.com/OpenMOSS/MOSS/tree/main) for alignment.
137
+
138
+ ## Citation
139
+
140
+ ```bibtex
141
+ @article{xie2024miniomni2opensourcegpt4ovision,
142
+ title={Mini-Omni2: Towards Open-source GPT-4o with Vision, Speech and Duplex Capabilities},
143
+ author={Zhifei Xie and Changqiao Wu},
144
+ year={2024},
145
+ eprint={2410.11190},
146
+ archivePrefix={arXiv},
147
+ primaryClass={eess.AS},
148
+ journal={ArXiv},
149
+ volume={abs/2410.11190},
150
+ }
151
+ ```
152
+ ## Star History
153
+ [![Star History Chart](https://api.star-history.com/svg?repos=gpt-omni/mini-omni2&type=Date)](https://star-history.com/#gpt-omni/mini-omni2&Date)
__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ import sys
2
+ import os
3
+
4
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
inference.py ADDED
@@ -0,0 +1,704 @@
1
+ import os
2
+ import lightning as L
3
+ import torch
4
+ import glob
5
+ import time
6
+ from snac import SNAC
7
+ from litgpt import Tokenizer
8
+ from litgpt.utils import (
9
+ num_parameters,
10
+ )
11
+ from litgpt.generate.base import (
12
+ generate_AA,
13
+ generate_ASR,
14
+ generate_TA,
15
+ generate_TT,
16
+ generate_AT,
17
+ generate_TA_BATCH,
18
+ next_token_image_batch
19
+ )
20
+ import soundfile as sf
21
+ from litgpt.model import GPT, Config
22
+ from lightning.fabric.utilities.load import _lazy_load as lazy_load
23
+ from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
24
+ from utils.snac_utils import get_snac, generate_audio_data
25
+ import whisper
26
+ from tqdm import tqdm
27
+ from huggingface_hub import snapshot_download
28
+
29
+
30
+ torch.set_printoptions(sci_mode=False)
31
+
32
+
33
+ # TODO
34
+ text_vocabsize = 151936
35
+ text_specialtokens = 64
36
+ audio_vocabsize = 4096
37
+ audio_specialtokens = 64
38
+
39
+ padded_text_vocabsize = text_vocabsize + text_specialtokens
40
+ padded_audio_vocabsize = audio_vocabsize + audio_specialtokens
41
+
42
+ _eot = text_vocabsize
43
+ _pad_t = text_vocabsize + 1
44
+ _input_t = text_vocabsize + 2
45
+ _answer_t = text_vocabsize + 3
46
+ _asr = text_vocabsize + 4
47
+
48
+ _eoa = audio_vocabsize
49
+ _pad_a = audio_vocabsize + 1
50
+ _input_a = audio_vocabsize + 2
51
+ _answer_a = audio_vocabsize + 3
52
+ _split = audio_vocabsize + 4
53
+ _image = audio_vocabsize + 5
54
+ _eoimage = audio_vocabsize + 6
55
+
56
+
57
+ def get_input_ids_TA(text, text_tokenizer):
58
+ input_ids_item = [[] for _ in range(8)]
59
+ text_tokens = text_tokenizer.encode(text)
60
+ for i in range(7):
61
+ input_ids_item[i] = [layershift(_pad_a, i)] * (len(text_tokens) + 2) + [
62
+ layershift(_answer_a, i)
63
+ ]
64
+ input_ids_item[i] = torch.tensor(input_ids_item[i]).unsqueeze(0)
65
+ input_ids_item[-1] = [_input_t] + text_tokens.tolist() + [_eot] + [_answer_t]
66
+ input_ids_item[-1] = torch.tensor(input_ids_item[-1]).unsqueeze(0)
67
+ return input_ids_item
68
+
69
+
70
+ def get_input_ids_TT(text, text_tokenizer):
71
+ input_ids_item = [[] for i in range(8)]
72
+ text_tokens = text_tokenizer.encode(text).tolist()
73
+
74
+ for i in range(7):
75
+ input_ids_item[i] = torch.tensor(
76
+ [layershift(_pad_a, i)] * (len(text_tokens) + 3)
77
+ ).unsqueeze(0)
78
+ input_ids_item[-1] = [_input_t] + text_tokens + [_eot] + [_answer_t]
79
+ input_ids_item[-1] = torch.tensor(input_ids_item[-1]).unsqueeze(0)
80
+
81
+ return input_ids_item
82
+
83
+
84
+ def get_input_ids_whisper(
85
+ mel, leng, whispermodel, device,
86
+ special_token_a=_answer_a, special_token_t=_answer_t,
87
+ ):
88
+
89
+ with torch.no_grad():
90
+ mel = mel.unsqueeze(0).to(device)
91
+ # audio_feature = whisper.decode(whispermodel,mel, options).audio_features
92
+ audio_feature = whispermodel.embed_audio(mel)[0][:leng]
93
+
94
+ T = audio_feature.size(0)
95
+ input_ids = []
96
+ for i in range(7):
97
+ input_ids_item = []
98
+ input_ids_item.append(layershift(_input_a, i))
99
+ input_ids_item += [layershift(_pad_a, i)] * T
100
+ input_ids_item += [(layershift(_eoa, i)), layershift(special_token_a, i)]
101
+ input_ids.append(torch.tensor(input_ids_item).unsqueeze(0))
102
+ input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, special_token_t])
103
+ input_ids.append(input_id_T.unsqueeze(0))
104
+ return audio_feature.unsqueeze(0), input_ids
105
+
106
+
107
+ def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
108
+ with torch.no_grad():
109
+ mel = mel.unsqueeze(0).to(device)
110
+ # audio_feature = whisper.decode(whispermodel,mel, options).audio_features
111
+ audio_feature = whispermodel.embed_audio(mel)[0][:leng]
112
+ T = audio_feature.size(0)
113
+ input_ids_AA = []
114
+ for i in range(7):
115
+ input_ids_item = []
116
+ input_ids_item.append(layershift(_input_a, i))
117
+ input_ids_item += [layershift(_pad_a, i)] * T
118
+ input_ids_item += [(layershift(_eoa, i)), layershift(_answer_a, i)]
119
+ input_ids_AA.append(torch.tensor(input_ids_item))
120
+ input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, _answer_t])
121
+ input_ids_AA.append(input_id_T)
122
+
123
+ input_ids_AT = []
124
+ for i in range(7):
125
+ input_ids_item = []
126
+ input_ids_item.append(layershift(_input_a, i))
127
+ input_ids_item += [layershift(_pad_a, i)] * T
128
+ input_ids_item += [(layershift(_eoa, i)), layershift(_pad_a, i)]
129
+ input_ids_AT.append(torch.tensor(input_ids_item))
130
+ input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, _answer_t])
131
+ input_ids_AT.append(input_id_T)
132
+
133
+ input_ids = [input_ids_AA, input_ids_AT]
134
+ stacked_inputids = [[] for _ in range(8)]
135
+ for i in range(2):
136
+ for j in range(8):
137
+ stacked_inputids[j].append(input_ids[i][j])
138
+ stacked_inputids = [torch.stack(tensors) for tensors in stacked_inputids]
139
+ return torch.stack([audio_feature, audio_feature]), stacked_inputids
140
+
141
+
142
+ def load_audio(path):
143
+ audio = whisper.load_audio(path)
144
+ duration_ms = (len(audio) / 16000) * 1000
145
+ audio = whisper.pad_or_trim(audio)
146
+ mel = whisper.log_mel_spectrogram(audio)
147
+ return mel, int(duration_ms / 20) + 1
148
+
149
+
150
+ def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
151
+ snacmodel, out_dir=None):
152
+ with fabric.init_tensor():
153
+ model.set_kv_cache(batch_size=2)
154
+ tokenlist = generate_TA_BATCH(
155
+ model,
156
+ audio_feature,
157
+ input_ids,
158
+ [leng, leng],
159
+ ["A1A2", "A1T2"],
160
+ max_returned_tokens=2048,
161
+ temperature=0.9,
162
+ top_k=1,
163
+ eos_id_a=_eoa,
164
+ eos_id_t=_eot,
165
+ pad_id_t=_pad_t,
166
+ shift=padded_text_vocabsize,
167
+ include_prompt=True,
168
+ generate_text=True,
169
+ )
170
+ text_tokenlist = tokenlist[-1]
171
+ if text_vocabsize in text_tokenlist:
172
+ text_tokenlist = text_tokenlist[: text_tokenlist.index(text_vocabsize)]
173
+ text = text_tokenizer.decode(torch.tensor(text_tokenlist)).strip()
174
+
175
+ audio_tokenlist = tokenlist[:-1]
176
+ audiolist = reconscruct_snac(audio_tokenlist)
177
+ audio = reconstruct_tensors(audiolist)
178
+ if out_dir is None:
179
+ out_dir = "./output/default/A1-A2-batch"
180
+ else:
181
+ out_dir = out_dir + "/A1-A2-batch"
182
+ if not os.path.exists(out_dir):
183
+ os.makedirs(out_dir)
184
+ with torch.inference_mode():
185
+ audio_hat = snacmodel.decode(audio)
186
+ sf.write(
187
+ f"{out_dir}/{step:02d}.wav",
188
+ audio_hat.squeeze().cpu().numpy(),
189
+ 24000,
190
+ )
191
+ model.clear_kv_cache()
192
+ return text
193
+
194
+
195
+ def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
196
+ with fabric.init_tensor():
197
+ model.set_kv_cache(batch_size=1)
198
+ tokenlist = generate_AT(
199
+ model,
200
+ audio_feature,
201
+ input_ids,
202
+ [leng],
203
+ ["AT"],
204
+ max_returned_tokens=2048,
205
+ temperature=0.9,
206
+ top_k=1,
207
+ eos_id_a=_eoa,
208
+ eos_id_t=_eot,
209
+ pad_id_t=_pad_t,
210
+ shift=padded_text_vocabsize,
211
+ include_prompt=True,
212
+ generate_text=True,
213
+ )
214
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
215
+
216
+
217
+ def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
218
+ snacmodel, out_dir=None):
219
+ with fabric.init_tensor():
220
+ model.set_kv_cache(batch_size=1)
221
+ tokenlist = generate_AA(
222
+ model,
223
+ audio_feature,
224
+ input_ids,
225
+ [leng],
226
+ ["A1T2"],
227
+ max_returned_tokens=2048,
228
+ temperature=0.9,
229
+ top_k=1,
230
+ eos_id_a=_eoa,
231
+ eos_id_t=_eot,
232
+ pad_id_t=_pad_t,
233
+ shift=padded_text_vocabsize,
234
+ include_prompt=True,
235
+ generate_text=True,
236
+ )
237
+ audiolist = reconscruct_snac(tokenlist)
238
+ tokenlist = tokenlist[-1]
239
+ if text_vocabsize in tokenlist:
240
+ tokenlist = tokenlist[: tokenlist.index(text_vocabsize)]
241
+ if out_dir is None:
242
+ out_dir = "./output/default/A1-A2"
243
+ else:
244
+ out_dir = out_dir + "/A1-A2"
245
+ if not os.path.exists(out_dir):
246
+ os.makedirs(out_dir)
247
+
248
+ audio = reconstruct_tensors(audiolist)
249
+ with torch.inference_mode():
250
+ audio_hat = snacmodel.decode(audio)
251
+ sf.write(
252
+ f"{out_dir}/{step:02d}.wav",
253
+ audio_hat.squeeze().cpu().numpy(),
254
+ 24000,
255
+ )
256
+ model.clear_kv_cache()
257
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
258
+
259
+
260
+ def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
261
+ with fabric.init_tensor():
262
+ model.set_kv_cache(batch_size=1)
263
+ tokenlist = generate_ASR(
264
+ model,
265
+ audio_feature,
266
+ input_ids,
267
+ [leng],
268
+ ["A1T1"],
269
+ max_returned_tokens=2048,
270
+ temperature=0.9,
271
+ top_k=1,
272
+ eos_id_a=_eoa,
273
+ eos_id_t=_eot,
274
+ pad_id_t=_pad_t,
275
+ shift=padded_text_vocabsize,
276
+ include_prompt=True,
277
+ generate_text=True,
278
+ )
279
+ model.clear_kv_cache()
280
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
281
+
282
+
283
+ def T1_A2(fabric, input_ids, model, text_tokenizer, step,
284
+ snacmodel, out_dir=None):
285
+ with fabric.init_tensor():
286
+ model.set_kv_cache(batch_size=1)
287
+ tokenlist = generate_TA(
288
+ model,
289
+ None,
290
+ input_ids,
291
+ None,
292
+ ["T1A2"],
293
+ max_returned_tokens=2048,
294
+ temperature=0.9,
295
+ top_k=1,
296
+ eos_id_a=_eoa,
297
+ eos_id_t=_eot,
298
+ pad_id_t=_pad_t,
299
+ shift=padded_text_vocabsize,
300
+ include_prompt=True,
301
+ generate_text=True,
302
+ )
303
+
304
+ audiolist = reconscruct_snac(tokenlist)
305
+ tokenlist = tokenlist[-1]
306
+
307
+ if text_vocabsize in tokenlist:
308
+ tokenlist = tokenlist[: tokenlist.index(text_vocabsize)]
309
+ audio = reconstruct_tensors(audiolist)
310
+ if out_dir is None:
311
+ out_dir = "./output/default/T1-A2"
312
+ else:
313
+ out_dir = out_dir + "/T1-A2"
314
+ if not os.path.exists(out_dir):
315
+ os.makedirs(out_dir)
316
+
317
+ with torch.inference_mode():
318
+ audio_hat = snacmodel.decode(audio)
319
+ sf.write(
320
+ f"{out_dir}/{step:02d}.wav",
321
+ audio_hat.squeeze().cpu().numpy(),
322
+ 24000,
323
+ )
324
+ model.clear_kv_cache()
325
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
326
+
327
+
328
+ def T1_T2(fabric, input_ids, model, text_tokenizer, step):
329
+
330
+ with fabric.init_tensor():
331
+ model.set_kv_cache(batch_size=1)
332
+ tokenlist = generate_TT(
333
+ model,
334
+ None,
335
+ input_ids,
336
+ None,
337
+ ["T1T2"],
338
+ max_returned_tokens=2048,
339
+ temperature=0.9,
340
+ top_k=1,
341
+ eos_id_a=_eoa,
342
+ eos_id_t=_eot,
343
+ pad_id_t=_pad_t,
344
+ shift=padded_text_vocabsize,
345
+ include_prompt=True,
346
+ generate_text=True,
347
+ )
348
+ model.clear_kv_cache()
349
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
350
+
351
+
352
+ def load_model(ckpt_dir, device):
353
+ snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
354
+ whisper_model_path = ckpt_dir + "/small.pt"
355
+ if not os.path.exists(whisper_model_path):
356
+ whisper_model_path = "small"
357
+ whispermodel = whisper.load_model(whisper_model_path).to(device)
358
+ text_tokenizer = Tokenizer(ckpt_dir)
359
+ fabric = L.Fabric(devices=1, strategy="auto")
360
+ config = Config.from_file(ckpt_dir + "/model_config.yaml")
361
+ config.post_adapter = False
362
+
363
+ with fabric.init_module(empty_init=False):
364
+ model = GPT(config)
365
+
366
+ model = fabric.setup(model)
367
+ state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
368
+ model.load_state_dict(state_dict, strict=True)
369
+ model.to(device).eval()
370
+
371
+ return fabric, model, text_tokenizer, snacmodel, whispermodel
372
+
373
+
374
+ def download_model(ckpt_dir):
375
+ repo_id = "gpt-omni/mini-omni2"
376
+ snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")
377
+
378
+
379
+ def get_text_stream(list_output, index, text_tokenizer):
380
+ text_tokens = list_output[-1][index:]
381
+ index += len(text_tokens)
382
+ is_text_end = False
383
+ if text_vocabsize in text_tokens:
384
+ text_tokens = text_tokens[:text_tokens.index(text_vocabsize)]
385
+ is_text_end = True
386
+ if len(text_tokens) == 0:
387
+ return "", index, is_text_end
388
+ res_text = text_tokenizer.decode(torch.tensor(text_tokens))
389
+ return res_text, index, is_text_end
390
+
391
+
392
+ class OmniInference:
393
+
394
+ def __init__(self, ckpt_dir='./checkpoint', device='cuda:0'):
395
+ self.device = device
396
+ if not os.path.exists(ckpt_dir):
397
+ print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
398
+ download_model(ckpt_dir)
399
+ self.fabric, self.model, self.text_tokenizer, self.snacmodel, self.whispermodel = load_model(ckpt_dir, device)
400
+
401
+ def warm_up(self, sample='./data/samples/output1.wav'):
402
+ for _ in self.run_AT_batch_stream(sample):
403
+ pass
404
+
405
+ @torch.inference_mode()
406
+ def run_AT_batch_stream(self,
407
+ audio_path,
408
+ stream_stride=4,
409
+ max_returned_tokens=2048,
410
+ temperature=0.9,
411
+ top_k=1,
412
+ top_p=1.0,
413
+ eos_id_a=_eoa,
414
+ eos_id_t=_eot,
415
+ save_path=None
416
+ ):
417
+
418
+ assert os.path.exists(audio_path), f"audio file {audio_path} not found"
419
+ model = self.model
420
+
421
+ with self.fabric.init_tensor():
422
+ model.set_kv_cache(batch_size=2, device=self.device)
423
+
424
+ mel, leng = load_audio(audio_path)
425
+ audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
426
+ T = input_ids[0].size(1)
427
+ device = input_ids[0].device
428
+
429
+ assert max_returned_tokens > T, f"max_returned_tokens {max_returned_tokens} should be greater than audio length {T}"
430
+
431
+ if model.max_seq_length < max_returned_tokens - 1:
432
+ raise NotImplementedError(
433
+ f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
434
+ )
435
+
436
+ input_pos = torch.tensor([T], device=device)
437
+ list_output = [[] for i in range(8)]
438
+ tokens_A, token_T = next_token_image_batch(
439
+ model,
440
+ audio_feature.to(torch.float32).to(model.device),
441
+ None,
442
+ input_ids,
443
+ [T - 3, T - 3],
444
+ ["A1T2", "A1T2"],
445
+ input_pos=torch.arange(0, T, device=device),
446
+ temperature=temperature,
447
+ top_k=top_k,
448
+ top_p=top_p,
449
+ )
450
+
451
+ for i in range(7):
452
+ list_output[i].append(tokens_A[i].tolist()[0])
453
+ list_output[7].append(token_T.tolist()[0])
454
+
455
+ model_input_ids = [[] for i in range(8)]
456
+ for i in range(7):
457
+ tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize
458
+ model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
459
+ model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=device))
460
+ model_input_ids[i] = torch.stack(model_input_ids[i])
461
+
462
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
463
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
464
+ model_input_ids[-1] = torch.stack(model_input_ids[-1])
465
+
466
+ text_end = False
467
+ index = 1
468
+ nums_generate = stream_stride
469
+ begin_generate = False
470
+ current_index = 0
471
+
472
+ text_index = 0
473
+ is_text_end = False
474
+
475
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
476
+ tokens_A, token_T = next_token_image_batch(
477
+ model,
478
+ None,
479
+ None,
480
+ model_input_ids,
481
+ None,
482
+ None,
483
+ input_pos=input_pos,
484
+ temperature=temperature,
485
+ top_k=top_k,
486
+ top_p=top_p,
487
+ )
488
+
489
+ if text_end:
490
+ token_T = torch.tensor([_pad_t], device=device)
491
+
492
+ if tokens_A[-1] == eos_id_a:
493
+ break
494
+
495
+ if token_T == eos_id_t:
496
+ text_end = True
497
+
498
+ for i in range(7):
499
+ list_output[i].append(tokens_A[i].tolist()[0])
500
+ list_output[7].append(token_T.tolist()[0])
501
+
502
+ model_input_ids = [[] for i in range(8)]
503
+ for i in range(7):
504
+ tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize
505
+ model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
506
+ model_input_ids[i].append(
507
+ torch.tensor([layershift(4097, i)], device=device)
508
+ )
509
+ model_input_ids[i] = torch.stack(model_input_ids[i])
510
+
511
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
512
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
513
+ model_input_ids[-1] = torch.stack(model_input_ids[-1])
514
+
515
+ if index == 7:
516
+ begin_generate = True
517
+
518
+ if begin_generate:
519
+ current_index += 1
520
+ if current_index == nums_generate:
521
+ current_index = 0
522
+ snac = get_snac(list_output, index, nums_generate)
523
+ audio_stream = generate_audio_data(snac, self.snacmodel, self.device)
524
+ if is_text_end:
525
+ text_stream = ""
526
+ else:
527
+ text_stream, text_index, is_text_end = get_text_stream(list_output, text_index, self.text_tokenizer)
528
+
529
+ yield (audio_stream, text_stream)
530
+
531
+ input_pos = input_pos.add_(1)
532
+ index += 1
533
+ text = self.text_tokenizer.decode(torch.tensor(list_output[-1]))
534
+ print(f"text output: {text}")
535
+
536
+ if save_path is not None:
537
+ audiolist = reconscruct_snac(list_output)
538
+ audio = reconstruct_tensors(audiolist)
539
+ with torch.inference_mode():
540
+ audio_hat = self.snacmodel.decode(audio)
541
+ sf.write(save_path, audio_hat.squeeze().cpu().numpy(), 24000)
542
+
543
+ model.clear_kv_cache()
544
+ return list_output
545
+
546
+
547
+ def test_infer():
548
+ device = "cuda:0"
549
+ out_dir = f"./output/{get_time_str()}"
550
+ ckpt_dir = "./checkpoint"
551
+ if not os.path.exists(ckpt_dir):
552
+ print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
553
+ download_model(ckpt_dir)
554
+
555
+ fabric, model, text_tokenizer, snacmodel, whispermodel = load_model(ckpt_dir, device)
556
+
557
+ task = ['A1A2', 'asr', "T1A2", "AA-BATCH", 'T1T2', 'AT']
558
+
559
+ # prepare test data
560
+ # TODO
561
+ test_audio_list = sorted(glob.glob('./data/samples/output*.wav'))
562
+ test_audio_transcripts = [
563
+ "What is your name?",
564
+ "what are your hobbies?",
565
+ "Do you like beijing",
566
+ "How are you feeling today?",
567
+ "what is the weather like today?",
568
+ ]
569
+ test_text_list = [
570
+ "What is your name?",
571
+ "How are you feeling today?",
572
+ "Can you describe your surroundings?",
573
+ "What did you do yesterday?",
574
+ "What is your favorite book and why?",
575
+ "How do you make a cup of tea?",
576
+ "What is the weather like today?",
577
+ "Can you explain the concept of time?",
578
+ "Can you tell me a joke?",
579
+ ]
580
+
581
+ # LOAD MODEL
582
+ with torch.no_grad():
583
+ if "A1A2" in task:
584
+ print("===============================================================")
585
+ print(" testing A1A2")
586
+ print("===============================================================")
587
+ step = 0
588
+ for path in test_audio_list:
589
+ try:
590
+ mel, leng = load_audio(path)
591
+ audio_feature, input_ids = get_input_ids_whisper(mel, leng, whispermodel, device)
592
+ text = A1_A2(
593
+ fabric,
594
+ audio_feature,
595
+ input_ids,
596
+ leng,
597
+ model,
598
+ text_tokenizer,
599
+ step,
600
+ snacmodel,
601
+ out_dir=out_dir,
602
+ )
603
+ print(f"input: {test_audio_transcripts[step]}")
604
+ print(f"output: {text}")
605
+ step += 1
606
+ print(
607
+ "+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++"
608
+ )
609
+ except Exception:
610
+ print(f"[error] failed to process {path}")
611
+ print("===============================================================")
612
+
613
+ if 'asr' in task:
614
+ print("===============================================================")
615
+ print(" testing asr")
616
+ print("===============================================================")
617
+
618
+ index = 0
619
+ step = 0
620
+ for path in test_audio_list:
621
+ mel, leng = load_audio(path)
622
+ audio_feature, input_ids = get_input_ids_whisper(mel, leng, whispermodel, device, special_token_a=_pad_a, special_token_t=_asr)
623
+ output = A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, index).lower().replace(',', '').replace('.', '').replace('?', '')
624
+ print(f"audio_path: {path}")
625
+ print(f"audio transcript: {test_audio_transcripts[index]}")
626
+ print(f"asr output: {output}")
627
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
628
+ index += 1
629
+
630
+ if "T1A2" in task:
631
+ step = 0
632
+ print("\n")
633
+ print("===============================================================")
634
+ print(" testing T1A2")
635
+ print("===============================================================")
636
+ for text in test_text_list:
637
+ input_ids = get_input_ids_TA(text, text_tokenizer)
638
+ text_output = T1_A2(fabric, input_ids, model, text_tokenizer, step,
639
+ snacmodel, out_dir=out_dir)
640
+ print(f"input: {text}")
641
+ print(f"output: {text_output}")
642
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
643
+ step += 1
644
+ print("===============================================================")
645
+
646
+ if "T1T2" in task:
647
+ step = 0
648
+ print("\n")
649
+ print("===============================================================")
650
+ print(" testing T1T2")
651
+ print("===============================================================")
652
+
653
+ for text in test_text_list:
654
+ input_ids = get_input_ids_TT(text, text_tokenizer)
655
+ text_output = T1_T2(fabric, input_ids, model, text_tokenizer, step)
656
+ print(f" Input: {text}")
657
+ print(f"Output: {text_output}")
658
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
659
+ print("===============================================================")
660
+
661
+ if "AT" in task:
662
+ print("===============================================================")
663
+ print(" testing A1T2")
664
+ print("===============================================================")
665
+ step = 0
666
+ for path in test_audio_list:
667
+ mel, leng = load_audio(path)
668
+ audio_feature, input_ids = get_input_ids_whisper(
669
+ mel, leng, whispermodel, device,
670
+ special_token_a=_pad_a, special_token_t=_answer_t
671
+ )
672
+ text = A1_T2(
673
+ fabric, audio_feature, input_ids, leng, model, text_tokenizer, step
674
+ )
675
+ print(f"input: {test_audio_transcripts[step]}")
676
+ print(f"output: {text}")
677
+ step += 1
678
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
679
+ print("===============================================================")
680
+
681
+ if "AA-BATCH" in task:
682
+ print("===============================================================")
683
+ print(" testing A1A2-BATCH")
684
+ print("===============================================================")
685
+ step = 0
686
+ for path in test_audio_list:
687
+ mel, leng = load_audio(path)
688
+ audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device)
689
+ text = A1_A2_batch(
690
+ fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
691
+ snacmodel, out_dir=out_dir
692
+ )
693
+ print(f"input: {test_audio_transcripts[step]}")
694
+ print(f"output: {text}")
695
+ step += 1
696
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
697
+ print("===============================================================")
698
+
699
+ print("*********************** test end *****************************")
700
+
701
+
702
+
703
+ if __name__ == "__main__":
704
+ test_infer()
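For reference, a minimal sketch of driving the streaming interface above directly (the same API that `server.py` wraps); the sample path is the default used by `warm_up`:

```python
from inference import OmniInference

client = OmniInference(ckpt_dir="./checkpoint", device="cuda:0")
client.warm_up()  # one full pass over ./data/samples/output1.wav to build the caches

text = ""
for audio_chunk, text_chunk in client.run_AT_batch_stream(
        "./data/samples/output1.wav", save_path="./answer.wav"):
    # audio_chunk is a short wav segment (bytes); text_chunk is the incremental transcript
    text += text_chunk
print("text output:", text)
```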
inference_vision.py ADDED
@@ -0,0 +1,259 @@
1
+ import os
2
+ import torch
3
+ from litgpt.generate.base import next_token_image_batch
4
+ import soundfile as sf
5
+ from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
6
+ from utils.snac_utils import get_snac, generate_audio_data
7
+ import clip
8
+ import inference
9
+ from tqdm import tqdm
10
+ from inference import OmniInference, load_model, load_audio, download_model
11
+ from inference import text_vocabsize, padded_text_vocabsize, get_text_stream
12
+ from PIL import Image
13
+
14
+
15
+ torch.set_printoptions(sci_mode=False)
16
+
17
+ _image = inference._image
18
+ _eoimage = inference._eoimage
19
+ _pad_t = inference._pad_t
20
+ _input_t = inference._input_t
21
+ _answer_t = inference._answer_t
22
+ _eot = inference._eot
23
+ _eoa = inference._eoa
24
+ _pad_a = inference._pad_a
25
+ _input_a = inference._input_a
26
+ _answer_a = inference._answer_a
27
+
28
+
29
+ def get_input_ids_ImageQA_ATBatch(mel, leng, whispermodel, device):
30
+
31
+ with torch.no_grad():
32
+ mel = mel.unsqueeze(0).to(device)
33
+ audio_feature = whispermodel.embed_audio(mel)[0][:leng]
34
+
35
+ audio_len = audio_feature.size(0)
36
+
37
+ input_ids = []
38
+ input_ids_item = [[] for i in range(8)]
39
+ for i in range(7):
40
+ input_ids_item[i] = [layershift(_image,i)] + [layershift(_pad_a,i)] * 50 + [layershift(_eoimage,i)]
41
+ input_ids_item[i] += [layershift(_input_a,i)]+[layershift(_pad_a,i)]*(audio_len)+[layershift(_eoa,i)]
42
+ input_ids_item[i] += [layershift(_answer_a,i)]
43
+
44
+ input_ids_item[-1] = [_pad_t]* (52 + 2 + audio_len) + [_answer_t]
45
+ input_ids_item = [torch.tensor(item) for item in input_ids_item]
46
+
47
+ input_ids.append(input_ids_item)
48
+
49
+ input_ids_item = [[] for i in range(8)]
50
+ for i in range(7):
51
+ input_ids_item[i] = [layershift(_image,i)] + [layershift(_pad_a,i)] * 50 + [layershift(_eoimage,i)]
52
+ input_ids_item[i] += [layershift(_input_a,i)]+[layershift(_pad_a,i)]*(audio_len)+[layershift(_eoa,i)] + [layershift(_pad_a,i)]
53
+
54
+ input_ids_item[-1] = [_pad_t]* (52 + 2 + audio_len) + [_answer_t]
55
+
56
+ input_ids_item = [torch.tensor(item) for item in input_ids_item]
57
+ input_ids.append(input_ids_item)
58
+
59
+ stacked_inputids = [[] for _ in range(8)]
60
+ for i in range(2):
61
+ for j in range(8):
62
+ stacked_inputids[j].append(input_ids[i][j])
63
+ stacked_inputids = [torch.stack(tensors) for tensors in stacked_inputids]
64
+
65
+ return torch.stack([audio_feature,audio_feature]), stacked_inputids
66
+
67
+
68
+ def load_clip_model(ckpt_dir, device):
69
+ clip_model_path = ckpt_dir + "/ViT-B-32.pt"
70
+ if not os.path.exists(clip_model_path):
71
+ clip_model_path = "ViT-B/32"
72
+ clipmodel, clippreprocess = clip.load(clip_model_path, device=device)
73
+ return clipmodel, clippreprocess
74
+
75
+
76
+ class OmniVisionInference(OmniInference):
77
+
78
+ def __init__(self, ckpt_dir='./checkpoint', device='cuda:0'):
79
+ self.device = device
80
+ if not os.path.exists(ckpt_dir):
81
+ print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
82
+ download_model(ckpt_dir)
83
+ self.fabric, self.model, self.text_tokenizer, self.snacmodel, self.whispermodel = load_model(ckpt_dir, device)
84
+ self.clipmodel, self.clippreprocess = load_clip_model(ckpt_dir, device)
85
+
86
+ def warm_up(self,
87
+ audio_sample='./data/samples/vision_qa_audio.wav',
88
+ image_sample='./data/samples/vision_qa_image.jpg'
89
+ ):
90
+ for _ in self.run_vision_AA_batch_stream(audio_sample, image_sample,
91
+ save_path="./data/samples/vision_qa_output.wav",
92
+ warm_up=True):
93
+ pass
94
+
95
+ @torch.inference_mode()
96
+ def run_vision_AA_batch_stream(self, audio_path, image_path,
97
+ stream_stride=4,
98
+ max_returned_tokens=2048,
99
+ temperature=0.9,
100
+ top_k=1,
101
+ top_p=1.0,
102
+ eos_id_a=_eoa,
103
+ eos_id_t=_eot,
104
+ pad_id=_pad_t,
105
+ save_path=None,
106
+ warm_up=False
107
+ ):
108
+ with self.fabric.init_tensor():
109
+ self.model.set_kv_cache(batch_size=2)
110
+
111
+ model = self.model
112
+
113
+ mel, leng = load_audio(audio_path)
114
+ img = Image.open(image_path)
115
+
116
+ audio_feature, input_ids = get_input_ids_ImageQA_ATBatch(mel, leng, self.whispermodel, self.device)
117
+ ima = self.clippreprocess(img).unsqueeze(0).to(self.device)
118
+ ima_feature = self.clipmodel.encode_image(ima).squeeze(0).to(self.device)
119
+
120
+ ima_feature = torch.stack([ima_feature.clone(),ima_feature.clone()]).to(self.device)
121
+ leng = [leng,leng]
122
+ task = ['ImageQA_A','ImageQA_AT']
123
+
124
+ T = input_ids[0].size(1)
125
+ assert max_returned_tokens > T, f"max_returned_tokens {max_returned_tokens} should be greater than audio length {T}"
126
+
127
+ if model.max_seq_length < max_returned_tokens - 1:
128
+ raise NotImplementedError(
129
+ f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
130
+ )
131
+
132
+ list_output = [[] for i in range(8)]
133
+
134
+ tokens_A, token_T = next_token_image_batch(
135
+ model,
136
+ audio_feature.to(torch.float32).to(self.device),
137
+ ima_feature.to(torch.float32).to(self.device),
138
+ input_ids,
139
+ whisper_lens=leng,
140
+ task=task,
141
+ input_pos=torch.arange(0, T, device=self.device),
142
+ temperature=temperature,
143
+ top_k=top_k,
144
+ top_p=top_p
145
+ )
146
+ for i in range(7): list_output[i].append(tokens_A[i].tolist()[0])
147
+ list_output[7].append(token_T.tolist()[0])
148
+
149
+ text_end = False
150
+ index = 1
151
+ nums_generate = stream_stride
152
+ begin_generate = False
153
+ current_index = 0
154
+ input_pos = torch.tensor([T], device=self.device)
155
+
156
+ model_input_ids = [[] for i in range(8)]
157
+ for i in range(7):
158
+ tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * 4160
159
+ model_input_ids[i].append(tokens_A[i].clone().to(self.device).to(torch.int32))
160
+ model_input_ids[i].append(torch.tensor([layershift(4097,i)],device=self.device))
161
+ model_input_ids[i] = torch.stack(model_input_ids[i])
162
+
163
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
164
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
165
+ model_input_ids[-1] = torch.stack(model_input_ids[-1])
166
+
167
+ text_index = 0
168
+ is_text_end = False
169
+
170
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
171
+
172
+ tokens_A, token_T = next_token_image_batch(model, None, None,
173
+ input_ids=model_input_ids,
174
+ whisper_lens=None,
175
+ task=None,
176
+ input_pos=input_pos,
177
+ temperature=temperature,
178
+ top_k=top_k,
179
+ top_p=top_p)
180
+
181
+ if text_end:
182
+ token_T = torch.tensor([_pad_t], device=self.device)
183
+
184
+ if tokens_A[-1] == eos_id_a:
185
+ break
186
+ if token_T == eos_id_t:
187
+ text_end = True
188
+
189
+ for i in range(7): list_output[i].append(tokens_A[i].tolist()[0])
190
+ list_output[7].append(token_T.tolist()[0])
191
+
192
+
193
+ if index == 7:
194
+ begin_generate = True
195
+
196
+ if begin_generate:
197
+ current_index += 1
198
+ if current_index == nums_generate:
199
+ current_index = 0
200
+ snac = get_snac(list_output,index,nums_generate)
201
+ audio_stream = generate_audio_data(snac, self.snacmodel, self.device)
202
+ if is_text_end:
203
+ text_stream = ""
204
+ else:
205
+ text_stream, text_index, is_text_end = get_text_stream(list_output, text_index, self.text_tokenizer)
206
+
207
+ yield (audio_stream, text_stream)
208
+
209
+ if warm_up:
210
+ break
211
+
212
+ input_pos = input_pos.add_(1)
213
+ model_input_ids = [[] for i in range(8)]
214
+ for i in range(7):
215
+ tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * 4160
216
+ model_input_ids[i].append(tokens_A[i].clone().to(self.device).to(torch.int32))
217
+ model_input_ids[i].append(torch.tensor([layershift(4097,i)],device=self.device))
218
+ model_input_ids[i] = torch.stack(model_input_ids[i])
219
+
220
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
221
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
222
+ model_input_ids[-1] = torch.stack(model_input_ids[-1])
223
+
224
+ index += 1
225
+
226
+ text_tokens = list_output[-1]
227
+ if text_vocabsize in text_tokens:
228
+ text_tokens = text_tokens[:text_tokens.index(text_vocabsize)]
229
+ res_text = self.text_tokenizer.decode(torch.tensor(text_tokens))
230
+ print(f"text output: {res_text}")
231
+
232
+ if save_path is not None:
233
+ audiolist = reconscruct_snac(list_output)
234
+ audio = reconstruct_tensors(audiolist)
235
+ with torch.inference_mode():
236
+ audio_hat = self.snacmodel.decode(audio)
237
+ sf.write(save_path, audio_hat.squeeze().cpu().numpy(), 24000)
238
+
239
+ model.clear_kv_cache()
240
+
241
+
242
+ def test_vision_infer():
243
+ client = OmniVisionInference()
244
+ client.warm_up()
245
+ input_audio_path = './data/samples/vision_qa_audio.wav'
246
+ input_image_path = './data/samples/vision_qa_image.jpg'
247
+
248
+ res_text = ""
249
+ for audio_stream, text_stream in client.run_vision_AA_batch_stream(
250
+ input_audio_path,
251
+ input_image_path,
252
+ save_path="./vision_qa_output.wav"
253
+ ):
254
+ res_text += text_stream
255
+ print(f"text_output: {res_text}")
256
+
257
+
258
+ if __name__ == "__main__":
259
+ test_vision_infer()
requirements.txt CHANGED
@@ -1,4 +1,20 @@
1
- huggingface_hub==0.25.2
2
- transformers
3
- scipy
4
- accelerate
1
+ torch==2.3.1
2
+ torchvision==0.18.1
3
+ torchaudio==2.3.1
4
+ litgpt==0.4.3
5
+ snac==1.2.0
6
+ soundfile==0.12.1
7
+ openai-whisper
8
+ tokenizers==0.19.1
9
+ streamlit==1.37.1
10
+ streamlit-webrtc
11
+ # PyAudio==0.2.14
12
+ pydub==0.25.1
13
+ onnxruntime==1.19.0
14
+ # numpy==1.26.3
15
+ librosa==0.10.2.post1
16
+ flask==3.0.3
17
+ fire
18
+ git+https://github.com/mini-omni/CLIP.git
19
+ gradio_webrtc[vad]==0.0.11
20
+ twilio
server.py ADDED
@@ -0,0 +1,96 @@
1
+ import sys
2
+ import os
3
+
4
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
5
+
6
+
7
+ import flask
8
+ import base64
9
+ import tempfile
10
+ import traceback
11
+ from flask import Flask, Response, stream_with_context
12
+ from inference_vision import OmniVisionInference
13
+
14
+
15
+ class OmniChatServer(object):
16
+ def __init__(self, ip='0.0.0.0', port=60808, run_app=True,
17
+ ckpt_dir='./checkpoint', device='cuda:0') -> None:
18
+ server = Flask(__name__)
19
+ # CORS(server, resources=r"/*")
20
+ # server.config["JSON_AS_ASCII"] = False
21
+
22
+ self.client = OmniVisionInference(ckpt_dir, device)
23
+ self.client.warm_up()
24
+
25
+ server.route("/chat", methods=["POST"])(self.chat)
26
+
27
+ if run_app:
28
+ server.run(host=ip, port=port, threaded=False)
29
+ else:
30
+ self.server = server
31
+
32
+ def chat(self) -> Response:
33
+
34
+ req_data = flask.request.get_json()
35
+ try:
36
+ audio_data_buf = req_data["audio"].encode("utf-8")
37
+ audio_data_buf = base64.b64decode(audio_data_buf)
38
+ stream_stride = req_data.get("stream_stride", 4)
39
+ max_tokens = req_data.get("max_tokens", 2048)
40
+
41
+ image_data_buf = req_data.get("image", None)
42
+ if image_data_buf:
43
+ image_data_buf = image_data_buf.encode("utf-8")
44
+ image_data_buf = base64.b64decode(image_data_buf)
45
+
46
+ audio_path, img_path = None, None
47
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_f, \
48
+ tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as img_f:
49
+ audio_f.write(audio_data_buf)
50
+ audio_path = audio_f.name
51
+
52
+ if image_data_buf:
53
+ img_f.write(image_data_buf)
54
+ img_path = img_f.name
55
+ else:
56
+ img_path = None
57
+
58
+ if img_path is not None:
59
+ resp_generator = self.client.run_vision_AA_batch_stream(audio_f.name, img_f.name,
60
+ stream_stride, max_tokens,
61
+ save_path='./vision_qa_out_cache.wav')
62
+ else:
63
+ resp_generator = self.client.run_AT_batch_stream(audio_f.name, stream_stride,
64
+ max_tokens,
65
+ save_path='./audio_qa_out_cache.wav')
66
+ return Response(stream_with_context(self.generator(resp_generator)),
67
+ mimetype='multipart/x-mixed-replace; boundary=frame')
68
+ except Exception as e:
69
+ print(traceback.format_exc())
70
+ return Response("An error occurred", status=500)
71
+
72
+ def generator(self, resp_generator):
73
+ for audio_stream, text_stream in resp_generator:
74
+ yield b'\r\n--frame\r\n'
75
+ yield b'Content-Type: audio/wav\r\n\r\n'
76
+ yield audio_stream
77
+ yield b'\r\n--frame\r\n'
78
+ yield b'Content-Type: text/plain\r\n\r\n'
79
+ yield text_stream.encode()
80
+
81
+
82
+ # CUDA_VISIBLE_DEVICES=1 gunicorn -w 2 -b 0.0.0.0:60808 'server:create_app()'
83
+ def create_app():
84
+ server = OmniChatServer(run_app=False)
85
+ return server.server
86
+
87
+
88
+ def serve(ip='0.0.0.0', port=60808, device='cuda:0'):
89
+
90
+ OmniChatServer(ip, port=port, run_app=True, device=device)
91
+
92
+
93
+ if __name__ == "__main__":
94
+ import fire
95
+ fire.Fire(serve)
96
+
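For reference, a minimal client sketch for the `/chat` endpoint above. The JSON fields mirror what `chat()` reads; `requests` is an extra dependency not pinned in requirements.txt, and the buffering below is a simplification of proper multipart parsing:

```python
import base64
import requests  # not in requirements.txt; install separately

with open("question.wav", "rb") as f:
    payload = {
        "audio": base64.b64encode(f.read()).decode("utf-8"),
        # "image": base64.b64encode(open("scene.jpg", "rb").read()).decode("utf-8"),  # optional
        "stream_stride": 4,
        "max_tokens": 2048,
    }

resp = requests.post("http://0.0.0.0:60808/chat", json=payload, stream=True)
buffer = b"".join(resp.iter_content(chunk_size=4096))
# parts alternate audio/wav and text/plain, separated by the "--frame" boundary;
# each part still starts with its own Content-Type header and a blank line
frames = [p for p in buffer.split(b"\r\n--frame\r\n") if p]
print(f"received {len(frames)} frames")
```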