Spaces:
Sleeping
Sleeping
zhzluke96
commited on
Commit
•
f34bda5
1
Parent(s):
b44532e
update
Browse files- modules/api/impl/openai_api.py +50 -4
- modules/normalization.py +21 -2
- modules/utils/zh_normalization/num.py +15 -6
- webui.py +3 -0
modules/api/impl/openai_api.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from fastapi import HTTPException, Body
|
2 |
from fastapi.responses import StreamingResponse
|
3 |
|
4 |
import io
|
@@ -14,7 +14,7 @@ from modules.normalization import text_normalize
|
|
14 |
from modules import generate_audio as generate
|
15 |
|
16 |
|
17 |
-
from typing import Literal
|
18 |
import pyrubberband as pyrb
|
19 |
|
20 |
from modules.api import utils as api_utils
|
@@ -106,8 +106,29 @@ async def openai_speech_api(
|
|
106 |
raise HTTPException(status_code=500, detail=str(e))
|
107 |
|
108 |
|
109 |
-
|
110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
"/v1/audio/speech",
|
112 |
response_class=FileResponse,
|
113 |
description="""
|
@@ -122,3 +143,28 @@ openai api document:
|
|
122 |
> model 可填任意值
|
123 |
""",
|
124 |
)(openai_speech_api)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import File, Form, HTTPException, Body, UploadFile
|
2 |
from fastapi.responses import StreamingResponse
|
3 |
|
4 |
import io
|
|
|
14 |
from modules import generate_audio as generate
|
15 |
|
16 |
|
17 |
+
from typing import List, Literal, Optional, Union
|
18 |
import pyrubberband as pyrb
|
19 |
|
20 |
from modules.api import utils as api_utils
|
|
|
106 |
raise HTTPException(status_code=500, detail=str(e))
|
107 |
|
108 |
|
109 |
+
class TranscribeSegment(BaseModel):
|
110 |
+
id: int
|
111 |
+
seek: float
|
112 |
+
start: float
|
113 |
+
end: float
|
114 |
+
text: str
|
115 |
+
tokens: List[int]
|
116 |
+
temperature: float
|
117 |
+
avg_logprob: float
|
118 |
+
compression_ratio: float
|
119 |
+
no_speech_prob: float
|
120 |
+
|
121 |
+
|
122 |
+
class TranscriptionsVerboseResponse(BaseModel):
|
123 |
+
task: str
|
124 |
+
language: str
|
125 |
+
duration: float
|
126 |
+
text: str
|
127 |
+
segments: List[TranscribeSegment]
|
128 |
+
|
129 |
+
|
130 |
+
def setup(app: APIManager):
|
131 |
+
app.post(
|
132 |
"/v1/audio/speech",
|
133 |
response_class=FileResponse,
|
134 |
description="""
|
|
|
143 |
> model 可填任意值
|
144 |
""",
|
145 |
)(openai_speech_api)
|
146 |
+
|
147 |
+
@app.post(
|
148 |
+
"/v1/audio/transcriptions",
|
149 |
+
response_class=TranscriptionsVerboseResponse,
|
150 |
+
description="WIP",
|
151 |
+
)
|
152 |
+
async def transcribe(
|
153 |
+
file: UploadFile = File(...),
|
154 |
+
model: str = Form(...),
|
155 |
+
language: Optional[str] = Form(None),
|
156 |
+
prompt: Optional[str] = Form(None),
|
157 |
+
response_format: str = Form("json"),
|
158 |
+
temperature: float = Form(0),
|
159 |
+
timestamp_granularities: List[str] = Form(["segment"]),
|
160 |
+
):
|
161 |
+
# TODO: Implement transcribe
|
162 |
+
return {
|
163 |
+
"file": file.filename,
|
164 |
+
"model": model,
|
165 |
+
"language": language,
|
166 |
+
"prompt": prompt,
|
167 |
+
"response_format": response_format,
|
168 |
+
"temperature": temperature,
|
169 |
+
"timestamp_granularities": timestamp_granularities,
|
170 |
+
}
|
modules/normalization.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
from modules.utils.zh_normalization.text_normlization import *
|
2 |
import emojiswitch
|
3 |
from modules.utils.markdown import markdown_to_text
|
@@ -5,12 +6,28 @@ from modules import models
|
|
5 |
import re
|
6 |
|
7 |
|
|
|
8 |
def is_chinese(text):
|
9 |
# 中文字符的 Unicode 范围是 \u4e00-\u9fff
|
10 |
chinese_pattern = re.compile(r"[\u4e00-\u9fff]")
|
11 |
return bool(chinese_pattern.search(text))
|
12 |
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
post_normalize_pipeline = []
|
15 |
pre_normalize_pipeline = []
|
16 |
|
@@ -123,7 +140,7 @@ def apply_character_map(text):
|
|
123 |
|
124 |
@post_normalize()
|
125 |
def apply_emoji_map(text):
|
126 |
-
lang =
|
127 |
return emojiswitch.demojize(text, delimiters=("", ""), lang=lang)
|
128 |
|
129 |
|
@@ -144,6 +161,8 @@ def replace_unk_tokens(text):
|
|
144 |
"""
|
145 |
chat_tts = models.load_chat_tts()
|
146 |
if "tokenizer" not in chat_tts.pretrain_models:
|
|
|
|
|
147 |
return text
|
148 |
tokenizer = chat_tts.pretrain_models["tokenizer"]
|
149 |
vocab = tokenizer.get_vocab()
|
@@ -223,7 +242,7 @@ def sentence_normalize(sentence_text: str):
|
|
223 |
pattern = re.compile(r"(\[.+?\])|([^[]+)")
|
224 |
|
225 |
def normalize_part(part):
|
226 |
-
sentences = tx.normalize(part) if
|
227 |
dest_text = ""
|
228 |
for sentence in sentences:
|
229 |
sentence = apply_post_normalize(sentence)
|
|
|
1 |
+
from functools import lru_cache
|
2 |
from modules.utils.zh_normalization.text_normlization import *
|
3 |
import emojiswitch
|
4 |
from modules.utils.markdown import markdown_to_text
|
|
|
6 |
import re
|
7 |
|
8 |
|
9 |
+
@lru_cache(maxsize=64)
|
10 |
def is_chinese(text):
|
11 |
# 中文字符的 Unicode 范围是 \u4e00-\u9fff
|
12 |
chinese_pattern = re.compile(r"[\u4e00-\u9fff]")
|
13 |
return bool(chinese_pattern.search(text))
|
14 |
|
15 |
|
16 |
+
@lru_cache(maxsize=64)
|
17 |
+
def is_eng(text):
|
18 |
+
eng_pattern = re.compile(r"[a-zA-Z]")
|
19 |
+
return bool(eng_pattern.search(text))
|
20 |
+
|
21 |
+
|
22 |
+
@lru_cache(maxsize=64)
|
23 |
+
def guess_lang(text):
|
24 |
+
if is_chinese(text):
|
25 |
+
return "zh"
|
26 |
+
if is_eng(text):
|
27 |
+
return "en"
|
28 |
+
return "zh"
|
29 |
+
|
30 |
+
|
31 |
post_normalize_pipeline = []
|
32 |
pre_normalize_pipeline = []
|
33 |
|
|
|
140 |
|
141 |
@post_normalize()
|
142 |
def apply_emoji_map(text):
|
143 |
+
lang = guess_lang(text)
|
144 |
return emojiswitch.demojize(text, delimiters=("", ""), lang=lang)
|
145 |
|
146 |
|
|
|
161 |
"""
|
162 |
chat_tts = models.load_chat_tts()
|
163 |
if "tokenizer" not in chat_tts.pretrain_models:
|
164 |
+
# 这个地方只有在 huggingface spaces 中才会触发
|
165 |
+
# 因为 hugggingface 自动处理模型卸载加载,所以如果拿不到就算了...
|
166 |
return text
|
167 |
tokenizer = chat_tts.pretrain_models["tokenizer"]
|
168 |
vocab = tokenizer.get_vocab()
|
|
|
242 |
pattern = re.compile(r"(\[.+?\])|([^[]+)")
|
243 |
|
244 |
def normalize_part(part):
|
245 |
+
sentences = tx.normalize(part) if guess_lang(part) == "zh" else [part]
|
246 |
dest_text = ""
|
247 |
for sentence in sentences:
|
248 |
sentence = apply_post_normalize(sentence)
|
modules/utils/zh_normalization/num.py
CHANGED
@@ -144,13 +144,22 @@ def replace_number(match) -> str:
|
|
144 |
sign = match.group(1)
|
145 |
number = match.group(2)
|
146 |
pure_decimal = match.group(5)
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
result =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
return result
|
|
|
|
|
154 |
|
155 |
|
156 |
# 范围表达式
|
|
|
144 |
sign = match.group(1)
|
145 |
number = match.group(2)
|
146 |
pure_decimal = match.group(5)
|
147 |
+
|
148 |
+
# TODO 也许可以把 num2str 完全替换成 cn2an
|
149 |
+
import cn2an
|
150 |
+
text = pure_decimal if pure_decimal else f"{sign}{number}"
|
151 |
+
try:
|
152 |
+
result = cn2an.an2cn(text, "low")
|
153 |
+
except ValueError:
|
154 |
+
if pure_decimal:
|
155 |
+
result = num2str(pure_decimal)
|
156 |
+
else:
|
157 |
+
sign: str = "负" if sign else ""
|
158 |
+
number: str = num2str(number)
|
159 |
+
result = f"{sign}{number}"
|
160 |
return result
|
161 |
+
|
162 |
+
|
163 |
|
164 |
|
165 |
# 范围表达式
|
webui.py
CHANGED
@@ -45,6 +45,9 @@ from modules import refiner, config
|
|
45 |
from modules.utils import env, audio
|
46 |
from modules.SentenceSplitter import SentenceSplitter
|
47 |
|
|
|
|
|
|
|
48 |
torch._dynamo.config.cache_size_limit = 64
|
49 |
torch._dynamo.config.suppress_errors = True
|
50 |
torch.set_float32_matmul_precision("high")
|
|
|
45 |
from modules.utils import env, audio
|
46 |
from modules.SentenceSplitter import SentenceSplitter
|
47 |
|
48 |
+
# fix: If the system proxy is enabled in the Windows system, you need to skip these
|
49 |
+
os.environ["NO_PROXY"] = "localhost,127.0.0.1,0.0.0.0"
|
50 |
+
|
51 |
torch._dynamo.config.cache_size_limit = 64
|
52 |
torch._dynamo.config.suppress_errors = True
|
53 |
torch.set_float32_matmul_precision("high")
|