init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- ChatTTS/__init__.py +1 -0
- ChatTTS/config/__init__.py +1 -0
- ChatTTS/config/config.py +134 -0
- ChatTTS/core.py +669 -0
- ChatTTS/model/__init__.py +6 -0
- ChatTTS/model/cuda/__init__.py +1 -0
- ChatTTS/model/cuda/patch.py +18 -0
- ChatTTS/model/cuda/te_llama.py +192 -0
- ChatTTS/model/dvae.py +296 -0
- ChatTTS/model/embed.py +81 -0
- ChatTTS/model/gpt.py +613 -0
- ChatTTS/model/processors.py +58 -0
- ChatTTS/model/speaker.py +154 -0
- ChatTTS/model/tokenizer.py +138 -0
- ChatTTS/model/velocity/__init__.py +2 -0
- ChatTTS/model/velocity/block_manager.py +296 -0
- ChatTTS/model/velocity/configs.py +865 -0
- ChatTTS/model/velocity/llama.py +393 -0
- ChatTTS/model/velocity/llm.py +213 -0
- ChatTTS/model/velocity/llm_engine.py +833 -0
- ChatTTS/model/velocity/model_loader.py +69 -0
- ChatTTS/model/velocity/model_runner.py +817 -0
- ChatTTS/model/velocity/output.py +144 -0
- ChatTTS/model/velocity/sampler.py +120 -0
- ChatTTS/model/velocity/sampling_params.py +296 -0
- ChatTTS/model/velocity/scheduler.py +426 -0
- ChatTTS/model/velocity/sequence.py +450 -0
- ChatTTS/model/velocity/worker.py +251 -0
- ChatTTS/norm.py +253 -0
- ChatTTS/res/__init__.py +0 -0
- ChatTTS/res/homophones_map.json +0 -0
- ChatTTS/res/sha256_map.json +13 -0
- ChatTTS/utils/__init__.py +4 -0
- ChatTTS/utils/dl.py +220 -0
- ChatTTS/utils/gpu.py +40 -0
- ChatTTS/utils/io.py +44 -0
- ChatTTS/utils/log.py +16 -0
- Dockerfile +13 -0
- LICENSE +661 -0
- docs/cn/README.md +314 -0
- docs/es/README.md +255 -0
- docs/fr/README.md +283 -0
- docs/jp/README.md +134 -0
- docs/ru/README.md +136 -0
- examples/__init__.py +0 -0
- examples/api/README.md +23 -0
- examples/api/client.py +76 -0
- examples/api/main.py +107 -0
- examples/api/requirements.txt +2 -0
- examples/cmd/run.py +151 -0
ChatTTS/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .core import Chat
|
ChatTTS/config/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .config import Config
|
ChatTTS/config/config.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
|
4 |
+
@dataclass(repr=False, eq=False)
|
5 |
+
class Path:
|
6 |
+
vocos_ckpt_path: str = "asset/Vocos.pt"
|
7 |
+
dvae_ckpt_path: str = "asset/DVAE_full.pt"
|
8 |
+
gpt_ckpt_path: str = "asset/gpt"
|
9 |
+
decoder_ckpt_path: str = "asset/Decoder.pt"
|
10 |
+
tokenizer_path: str = "asset/tokenizer"
|
11 |
+
embed_path: str = "asset/Embed.safetensors"
|
12 |
+
|
13 |
+
|
14 |
+
@dataclass(repr=False, eq=False)
|
15 |
+
class Decoder:
|
16 |
+
idim: int = 384
|
17 |
+
odim: int = 384
|
18 |
+
hidden: int = 512
|
19 |
+
n_layer: int = 12
|
20 |
+
bn_dim: int = 128
|
21 |
+
|
22 |
+
|
23 |
+
@dataclass(repr=False, eq=False)
|
24 |
+
class VQ:
|
25 |
+
dim: int = 1024
|
26 |
+
levels: tuple = (5, 5, 5, 5)
|
27 |
+
G: int = 2
|
28 |
+
R: int = 2
|
29 |
+
|
30 |
+
|
31 |
+
@dataclass(repr=False, eq=False)
|
32 |
+
class DVAE:
|
33 |
+
encoder: Decoder = Decoder(
|
34 |
+
idim=512,
|
35 |
+
odim=1024,
|
36 |
+
hidden=256,
|
37 |
+
n_layer=12,
|
38 |
+
bn_dim=128,
|
39 |
+
)
|
40 |
+
decoder: Decoder = Decoder(
|
41 |
+
idim=512,
|
42 |
+
odim=512,
|
43 |
+
hidden=256,
|
44 |
+
n_layer=12,
|
45 |
+
bn_dim=128,
|
46 |
+
)
|
47 |
+
vq: VQ = VQ()
|
48 |
+
|
49 |
+
|
50 |
+
@dataclass(repr=False, eq=False)
|
51 |
+
class GPT:
|
52 |
+
hidden_size: int = 768
|
53 |
+
intermediate_size: int = 3072
|
54 |
+
num_attention_heads: int = 12
|
55 |
+
num_hidden_layers: int = 20
|
56 |
+
use_cache: bool = False
|
57 |
+
max_position_embeddings: int = 4096
|
58 |
+
|
59 |
+
spk_emb_dim: int = 192
|
60 |
+
spk_KL: bool = False
|
61 |
+
num_audio_tokens: int = 626
|
62 |
+
num_text_tokens: int = 21178
|
63 |
+
num_vq: int = 4
|
64 |
+
|
65 |
+
|
66 |
+
@dataclass(repr=False, eq=False)
|
67 |
+
class Embed:
|
68 |
+
hidden_size: int = 768
|
69 |
+
num_audio_tokens: int = 626
|
70 |
+
num_text_tokens: int = 21178
|
71 |
+
num_vq: int = 4
|
72 |
+
|
73 |
+
|
74 |
+
@dataclass(repr=False, eq=False)
|
75 |
+
class FeatureExtractorInitArgs:
|
76 |
+
sample_rate: int = 24000
|
77 |
+
n_fft: int = 1024
|
78 |
+
hop_length: int = 256
|
79 |
+
n_mels: int = 100
|
80 |
+
padding: str = "center"
|
81 |
+
|
82 |
+
|
83 |
+
@dataclass(repr=False, eq=False)
|
84 |
+
class FeatureExtractor:
|
85 |
+
class_path: str = "vocos.feature_extractors.MelSpectrogramFeatures"
|
86 |
+
init_args: FeatureExtractorInitArgs = FeatureExtractorInitArgs()
|
87 |
+
|
88 |
+
|
89 |
+
@dataclass(repr=False, eq=False)
|
90 |
+
class BackboneInitArgs:
|
91 |
+
input_channels: int = 100
|
92 |
+
dim: int = 512
|
93 |
+
intermediate_dim: int = 1536
|
94 |
+
num_layers: int = 8
|
95 |
+
|
96 |
+
|
97 |
+
@dataclass(repr=False, eq=False)
|
98 |
+
class Backbone:
|
99 |
+
class_path: str = "vocos.models.VocosBackbone"
|
100 |
+
init_args: BackboneInitArgs = BackboneInitArgs()
|
101 |
+
|
102 |
+
|
103 |
+
@dataclass(repr=False, eq=False)
|
104 |
+
class FourierHeadInitArgs:
|
105 |
+
dim: int = 512
|
106 |
+
n_fft: int = 1024
|
107 |
+
hop_length: int = 256
|
108 |
+
padding: str = "center"
|
109 |
+
|
110 |
+
|
111 |
+
@dataclass(repr=False, eq=False)
|
112 |
+
class FourierHead:
|
113 |
+
class_path: str = "vocos.heads.ISTFTHead"
|
114 |
+
init_args: FourierHeadInitArgs = FourierHeadInitArgs()
|
115 |
+
|
116 |
+
|
117 |
+
@dataclass(repr=False, eq=False)
|
118 |
+
class Vocos:
|
119 |
+
feature_extractor: FeatureExtractor = FeatureExtractor()
|
120 |
+
backbone: Backbone = Backbone()
|
121 |
+
head: FourierHead = FourierHead()
|
122 |
+
|
123 |
+
|
124 |
+
@dataclass(repr=False, eq=False)
|
125 |
+
class Config:
|
126 |
+
path: Path = Path()
|
127 |
+
decoder: Decoder = Decoder()
|
128 |
+
dvae: DVAE = DVAE()
|
129 |
+
gpt: GPT = GPT()
|
130 |
+
embed: Embed = Embed()
|
131 |
+
vocos: Vocos = Vocos()
|
132 |
+
spk_stat: str = (
|
133 |
+
"愐穤巩噅廷戇笉屈癐媄垹垧帶爲漈塀殐慄亅倴庲舴猂瑈圐狴夥圓帍戛挠腉耐劤坽喳幾战謇聀崒栄呥倸庭燡欈杁襐褄乭埗幺爃弔摁斐捔兕佖廐舏竾豃磐姓趡佄幒爚欄豄讐皳訵仩帆投謌荃蝐叄圝伆幦抂茁呄掑斃讹傮庞爣蜀橁偐祄亥兡常爂欍扉丐浔佱僈強払伅扂蛐徴憍傞巀戺欀艂琐嗴啥値彷刂權穈扒卤俔贲庛初笂卄贐枴仭亁庛剎猢扃缐趤刁偵幪舏伌煁婐潤晍位弾舙茥穁葏蠣訑企庤刊笍橁溑僔云偁庯戚伍潉膐脴僵噔廃艅匊祂唐憴壝嗙席爥欁虁谐牴帽势弿牳蜁兀蛐傄喩丿帔刔圆衁廐罤庁促帙劢伈汄樐檄勵伴弝舑欍罅虐昴劭勅帜刼朊蕁虐蓴樑伫幨扑謪剀堐稴丵伱弐舮諸赁習俔容厱幫牶謃孄糐答嗝僊帜燲笄終瀒判久僤帘爴茇千孑冄凕佳引扐蜁歁缏裄剽儺恘爋朏眿廐呄塍嘇幻爱茠詁訐剴唭俐幾戊欀硁菐贄楕偒巡爀弎屄莐睳賙凶彎刅漄區唐溴剑劋庽舽猄煃跐夔惥伾庮舎伈罁垑坄怅业怯刁朇獁嶏覔坩俳巶爜朐潁崐萄俹凛常爺笌穀聐此夡倛帡刀匉終窏舣販侽怿扉伥贿憐忓謩姆幌犊漂慆癒却甝兎帼戏欅詂浐朔仹壭帰臷弎恇菐獤帡偖帘爞伅腂皐纤囅充幓戠伥灂丐訤戱倱弋爮嬌癁恐孄侥劬忶刓國詀桒古偩嘄庬戚茝赂监燤嘑勌幦舽持呂諐棤姑再底舡笍艃瀐孴倉傔弋爔猠乁濑塄偽嘧恂舛缇襃厐窴仡刱忕別漇穁岏缴廽价庌爊謈硄讑惤倁儂庭爋伇蝂嶐莔摝傠库刞茄歃戏薤伍伯廮创笠塄熐兴勽俄帅剉最腀砐敤卝侍弆戺朒虃旐蚄梕亖幔牻朣扅贐玔堝噅帡剌圅摀崐彤流僳庙爖嬇啁渐悤堁丛幆刧挜彃悐幤刹嚟恕芁看聀摐焔向乁帖爭欁癃糒圄弙佱廜戤謍婀咐昴焍亩廦艏拼謿芐癤怹兽幸舳朇畁喐稔毝丼弈懲挀譂勑哴啁伎常舭笯晁堑俄叩剔廟爍欦絁夒伤休傑廳戌蜅潆癐彴摑勯床刽欅艁砐忄搉从廡舊猥潂唐委仱僜廼爤朄呃弐礔滵垓幩爄挂筁乐籤刕凟幵爠弉���乑吴勥伖帪舩茆婁碐幤叭乢巜艳猁桀桐啄唩俊幍舮猀艅焐螔琽亀帋爜缅噃咐斤喩予幩爛笆摀浐猴依侹幃刕園慄蛐栤澹仑座爼謉桃慐浔斕偻幛懰嬓衁愐氄悅仿应芔漄衃敐謤傁匩幹抃圉癄廐裄屵噉幍利謍聂搐蛔嚙坍怗舁圐畃膐栄刵东巆戤諾呃偑媤嗨跞忶爝眄祂朒嶔僭劉忾刐匋癄袐翴珅僷廲芄茈恈皐擄崑伄廉牍匃剃犏澤唑丄庺戃伃煀某杄偙亽帴切缌罄挐尴噙倰带舞漄橄塐糴俩僯帀般漀坂栐更両俇廱舌猁慂拐偤嶱卶应刪眉獁茐伔嘅偺帟舊漂恀栐暄喡乞庙舆匂敀潑恔劑侖延戦盽怶唯慳蝘蟃孫娎益袰玍屃痶翮笪儚裀倹椌玻翀詵筽舘惯堿某侰晈藏缮詗廦夸妎瑻瀒裔媀憞唃冶璭狻渠荑奬熹茅愺氰菣滠翦岓褌泣崲嚭欓湒聙宺爄蛅愸庍匃帆誔穮懌蓪玷澌氋抌訙屌臞廛玸听屺希疭孝凂紋新煎彃膲跱尪懁眆窴珏卓揨菸紭概囥显壌榄垫嘮嬭覤媸侵佮烒耸觌婀秋狃帹葯訤桜糨笾腢伀肶悍炂艤禖岅臺惘梷瞍友盁佨岧憳瓧嘴汬藊愌蘤嶠硴绤蜲襏括勾谂縨妥蓪澭竭萢藜纞糲煮愆瀯孯琓罂諺塿燗狟弙衯揻縷丱糅臄梱瀮杰巳猙亊符胠匃泀廏圃膂蒃籏礩岈簹缌劺燲褡孓膜拔蠿觮呋煣厌尷熜論弲牭紫寊誃紀橴賬傸箍弚窃侫簲慯烣渽祌壓媥噜夽夛諛玹疮禄冪謇媽衤盰缺繑薫兾萧嵱打滽箺嚯凣狢蠜崼覽烸簶盯籓摀苶峸懗泲涻凮愳緗剋笔懆廡瞿椏礤惐藥崍腈烄伹亯昣翬褍絋桫僨吨莌丛矄蜞娈憊苆塁蓏嚢嫼绻崱婋囱蠸篯晣芀繼索兓僖誹岯圪褰蠇唓妷胅巁渮砛傈蝷嵚冃購赁峍裋荂舾符熻岳墩寮粃凲袑彚太绲头摯繳狁俥籌冝諝註坎幫擤詒宒凕賐唶梎噔弼課屿覍囨焬櫱撪蝮蝬簸懰櫫涺嵍睻屪翔峞慘滟熲昱军烊舿尦舄糖奁溏凂彆蝲糴禍困皻灏牋睒诙嶱臀开蓈眎腼丢纻廏憤嫖暭袭崲肸螛妒榗紉谨窮袃瑠聍绊腆亿冲葐喋縔詖岑兾给堸赏旻桀蛨媆訂峦紷敯囬偐筨岸焸拭笵殒哜墒萍屓娓諙械臮望摰芑寭准僞谹氍旋憢菮屃划欣瘫谎蘻哐繁籥禦僿誵皯墓燀縿笞熦绗稹榎矻綞蓓帡戓沺区才畃洊詪糐裶盰窶耎偌劂誐庩惝滜沺哮呃煐譠崄槀猄肼蔐擋湌蠺篃恥諌瞦宍堫挪裕崑慩狲悠煋仛愞砈粵八棁害楐妋萔貨尵奂苰怫誎傫岆蕯屇脉夈仆茎刓繸芺壸碗曛汁戭炻獻凉媁兎狜爴怰賃纎袏娷禃蓥膹薪渻罸窿粫凾褄舺窮墫干苊繁冏僮訸夯绛蓪虛羽慲烏憷趎睊蠰莍塞成廎盁欏喓蜮譤崆楁囘矇薭伣艘虝帴奮苢渶虎暣翐蝃尾稈糶瀴罐嵚氮葯笫慐棌悶炯竻爅们媡姢嫺窷刮歫劈裩屬椕賑蜹薊刲義哯尗褦瓀稾礋揣窼舫尋姁椄侸嗫珺修纘媃腽蛛稹梭呛瀈蘟縀礉論夵售主梮蠉娅娭裀誼嶭観枳倊簈褃擞綿催瞃溶苊笛襹櫲盅六囫獩佃粨慯瓢眸旱荃婨蔞岋祗墼焻网牻琖詆峋秉胳媴袭澓賢経稟壩胫碯偏囫嶎纆窈槊賐撹璬莃缘誾宭愊眗喷监劋萘訯總槿棭戾墮犄恌縈簍樥蛔杁袭嫛憫倆篏墵賈羯茎觳蒜致娢慄勒覸蘍曲栂葭宆妋皽缽免盳猼蔂糥觧烳檸佯憓煶蔐筼种繷琲膌塄剰讎対腕棥渽忲俛浪譬秛惛壒嘸淫冻曄睻砃奫貯庴爅粓脮脡娎妖峵蘲討惋泊蠀㴆"
|
134 |
+
)
|
ChatTTS/core.py
ADDED
@@ -0,0 +1,669 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
import tempfile
|
4 |
+
from dataclasses import dataclass, asdict
|
5 |
+
from typing import Literal, Optional, List, Tuple, Dict, Union
|
6 |
+
from json import load
|
7 |
+
from pathlib import Path
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
from vocos import Vocos
|
12 |
+
from vocos.pretrained import instantiate_class
|
13 |
+
from huggingface_hub import snapshot_download
|
14 |
+
|
15 |
+
from .config import Config
|
16 |
+
from .model import DVAE, Embed, GPT, gen_logits, Tokenizer, Speaker
|
17 |
+
from .utils import (
|
18 |
+
check_all_assets,
|
19 |
+
download_all_assets,
|
20 |
+
select_device,
|
21 |
+
get_latest_modified_file,
|
22 |
+
del_all,
|
23 |
+
)
|
24 |
+
from .utils import logger as utils_logger
|
25 |
+
|
26 |
+
from .norm import Normalizer
|
27 |
+
|
28 |
+
|
29 |
+
class Chat:
|
30 |
+
def __init__(self, logger=logging.getLogger(__name__)):
|
31 |
+
self.logger = logger
|
32 |
+
utils_logger.set_logger(logger)
|
33 |
+
|
34 |
+
self.config = Config()
|
35 |
+
|
36 |
+
self.normalizer = Normalizer(
|
37 |
+
os.path.join(os.path.dirname(__file__), "res", "homophones_map.json"),
|
38 |
+
logger,
|
39 |
+
)
|
40 |
+
with open(
|
41 |
+
os.path.join(os.path.dirname(__file__), "res", "sha256_map.json")
|
42 |
+
) as f:
|
43 |
+
self.sha256_map: Dict[str, str] = load(f)
|
44 |
+
|
45 |
+
self.context = GPT.Context()
|
46 |
+
|
47 |
+
def has_loaded(self, use_decoder=False):
|
48 |
+
not_finish = False
|
49 |
+
check_list = ["vocos", "gpt", "tokenizer", "embed"]
|
50 |
+
|
51 |
+
if use_decoder:
|
52 |
+
check_list.append("decoder")
|
53 |
+
else:
|
54 |
+
check_list.append("dvae")
|
55 |
+
|
56 |
+
for module in check_list:
|
57 |
+
if not hasattr(self, module):
|
58 |
+
self.logger.warning(f"{module} not initialized.")
|
59 |
+
not_finish = True
|
60 |
+
|
61 |
+
return not not_finish
|
62 |
+
|
63 |
+
def download_models(
|
64 |
+
self,
|
65 |
+
source: Literal["huggingface", "local", "custom"] = "local",
|
66 |
+
force_redownload=False,
|
67 |
+
custom_path: Optional[torch.serialization.FILE_LIKE] = None,
|
68 |
+
) -> Optional[str]:
|
69 |
+
if source == "local":
|
70 |
+
download_path = os.getcwd()
|
71 |
+
if (
|
72 |
+
not check_all_assets(Path(download_path), self.sha256_map, update=True)
|
73 |
+
or force_redownload
|
74 |
+
):
|
75 |
+
with tempfile.TemporaryDirectory() as tmp:
|
76 |
+
download_all_assets(tmpdir=tmp)
|
77 |
+
if not check_all_assets(
|
78 |
+
Path(download_path), self.sha256_map, update=False
|
79 |
+
):
|
80 |
+
self.logger.error(
|
81 |
+
"download to local path %s failed.", download_path
|
82 |
+
)
|
83 |
+
return None
|
84 |
+
elif source == "huggingface":
|
85 |
+
hf_home = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
|
86 |
+
try:
|
87 |
+
download_path = get_latest_modified_file(
|
88 |
+
os.path.join(hf_home, "hub/models--2Noise--ChatTTS/snapshots")
|
89 |
+
)
|
90 |
+
except:
|
91 |
+
download_path = None
|
92 |
+
if download_path is None or force_redownload:
|
93 |
+
self.logger.log(
|
94 |
+
logging.INFO,
|
95 |
+
f"download from HF: https://huggingface.co/2Noise/ChatTTS",
|
96 |
+
)
|
97 |
+
try:
|
98 |
+
download_path = snapshot_download(
|
99 |
+
repo_id="2Noise/ChatTTS",
|
100 |
+
allow_patterns=["*.pt", "*.yaml", "*.json", "*.safetensors"],
|
101 |
+
)
|
102 |
+
except:
|
103 |
+
download_path = None
|
104 |
+
else:
|
105 |
+
self.logger.log(
|
106 |
+
logging.INFO, f"load latest snapshot from cache: {download_path}"
|
107 |
+
)
|
108 |
+
if download_path is None:
|
109 |
+
self.logger.error("download from huggingface failed.")
|
110 |
+
return None
|
111 |
+
elif source == "custom":
|
112 |
+
self.logger.log(logging.INFO, f"try to load from local: {custom_path}")
|
113 |
+
if not check_all_assets(Path(custom_path), self.sha256_map, update=False):
|
114 |
+
self.logger.error("check models in custom path %s failed.", custom_path)
|
115 |
+
return None
|
116 |
+
download_path = custom_path
|
117 |
+
|
118 |
+
return download_path
|
119 |
+
|
120 |
+
def load(
|
121 |
+
self,
|
122 |
+
source: Literal["huggingface", "local", "custom"] = "local",
|
123 |
+
force_redownload=False,
|
124 |
+
compile: bool = False,
|
125 |
+
custom_path: Optional[torch.serialization.FILE_LIKE] = None,
|
126 |
+
device: Optional[torch.device] = None,
|
127 |
+
coef: Optional[torch.Tensor] = None,
|
128 |
+
use_flash_attn=False,
|
129 |
+
use_vllm=False,
|
130 |
+
experimental: bool = False,
|
131 |
+
) -> bool:
|
132 |
+
download_path = self.download_models(source, force_redownload, custom_path)
|
133 |
+
if download_path is None:
|
134 |
+
return False
|
135 |
+
return self._load(
|
136 |
+
device=device,
|
137 |
+
compile=compile,
|
138 |
+
coef=coef,
|
139 |
+
use_flash_attn=use_flash_attn,
|
140 |
+
use_vllm=use_vllm,
|
141 |
+
experimental=experimental,
|
142 |
+
**{
|
143 |
+
k: os.path.join(download_path, v)
|
144 |
+
for k, v in asdict(self.config.path).items()
|
145 |
+
},
|
146 |
+
)
|
147 |
+
|
148 |
+
def unload(self):
|
149 |
+
logger = self.logger
|
150 |
+
self.normalizer.destroy()
|
151 |
+
del self.normalizer
|
152 |
+
del self.sha256_map
|
153 |
+
del_list = ["vocos", "gpt", "decoder", "dvae", "tokenizer", "embed"]
|
154 |
+
for module in del_list:
|
155 |
+
if hasattr(self, module):
|
156 |
+
delattr(self, module)
|
157 |
+
self.__init__(logger)
|
158 |
+
|
159 |
+
def sample_random_speaker(self) -> str:
|
160 |
+
return self.speaker.sample_random()
|
161 |
+
|
162 |
+
def sample_audio_speaker(self, wav: Union[np.ndarray, torch.Tensor]) -> str:
|
163 |
+
return self.speaker.encode_prompt(self.dvae.sample_audio(wav))
|
164 |
+
|
165 |
+
@dataclass(repr=False, eq=False)
|
166 |
+
class RefineTextParams:
|
167 |
+
prompt: str = ""
|
168 |
+
top_P: float = 0.7
|
169 |
+
top_K: int = 20
|
170 |
+
temperature: float = 0.7
|
171 |
+
repetition_penalty: float = 1.0
|
172 |
+
max_new_token: int = 384
|
173 |
+
min_new_token: int = 0
|
174 |
+
show_tqdm: bool = True
|
175 |
+
ensure_non_empty: bool = True
|
176 |
+
manual_seed: Optional[int] = None
|
177 |
+
|
178 |
+
@dataclass(repr=False, eq=False)
|
179 |
+
class InferCodeParams(RefineTextParams):
|
180 |
+
prompt: str = "[speed_5]"
|
181 |
+
spk_emb: Optional[str] = None
|
182 |
+
spk_smp: Optional[str] = None
|
183 |
+
txt_smp: Optional[str] = None
|
184 |
+
temperature: float = 0.3
|
185 |
+
repetition_penalty: float = 1.05
|
186 |
+
max_new_token: int = 2048
|
187 |
+
stream_batch: int = 24
|
188 |
+
stream_speed: int = 12000
|
189 |
+
pass_first_n_batches: int = 2
|
190 |
+
|
191 |
+
def infer(
|
192 |
+
self,
|
193 |
+
text,
|
194 |
+
stream=False,
|
195 |
+
lang=None,
|
196 |
+
skip_refine_text=False,
|
197 |
+
refine_text_only=False,
|
198 |
+
use_decoder=True,
|
199 |
+
do_text_normalization=True,
|
200 |
+
do_homophone_replacement=True,
|
201 |
+
params_refine_text=RefineTextParams(),
|
202 |
+
params_infer_code=InferCodeParams(),
|
203 |
+
):
|
204 |
+
self.context.set(False)
|
205 |
+
res_gen = self._infer(
|
206 |
+
text,
|
207 |
+
stream,
|
208 |
+
lang,
|
209 |
+
skip_refine_text,
|
210 |
+
refine_text_only,
|
211 |
+
use_decoder,
|
212 |
+
do_text_normalization,
|
213 |
+
do_homophone_replacement,
|
214 |
+
params_refine_text,
|
215 |
+
params_infer_code,
|
216 |
+
)
|
217 |
+
if stream:
|
218 |
+
return res_gen
|
219 |
+
else:
|
220 |
+
return next(res_gen)
|
221 |
+
|
222 |
+
def interrupt(self):
|
223 |
+
self.context.set(True)
|
224 |
+
|
225 |
+
@torch.no_grad()
|
226 |
+
def _load(
|
227 |
+
self,
|
228 |
+
vocos_ckpt_path: str = None,
|
229 |
+
dvae_ckpt_path: str = None,
|
230 |
+
gpt_ckpt_path: str = None,
|
231 |
+
embed_path: str = None,
|
232 |
+
decoder_ckpt_path: str = None,
|
233 |
+
tokenizer_path: str = None,
|
234 |
+
device: Optional[torch.device] = None,
|
235 |
+
compile: bool = False,
|
236 |
+
coef: Optional[str] = None,
|
237 |
+
use_flash_attn=False,
|
238 |
+
use_vllm=False,
|
239 |
+
experimental: bool = False,
|
240 |
+
):
|
241 |
+
if device is None:
|
242 |
+
device = select_device(experimental=experimental)
|
243 |
+
self.logger.info("use device %s", str(device))
|
244 |
+
self.device = device
|
245 |
+
self.device_gpt = device if "mps" not in str(device) else torch.device("cpu")
|
246 |
+
self.compile = compile
|
247 |
+
|
248 |
+
feature_extractor = instantiate_class(
|
249 |
+
args=(), init=asdict(self.config.vocos.feature_extractor)
|
250 |
+
)
|
251 |
+
backbone = instantiate_class(args=(), init=asdict(self.config.vocos.backbone))
|
252 |
+
head = instantiate_class(args=(), init=asdict(self.config.vocos.head))
|
253 |
+
vocos = (
|
254 |
+
Vocos(feature_extractor=feature_extractor, backbone=backbone, head=head)
|
255 |
+
.to(
|
256 |
+
# vocos on mps will crash, use cpu fallback
|
257 |
+
"cpu"
|
258 |
+
if "mps" in str(device)
|
259 |
+
else device
|
260 |
+
)
|
261 |
+
.eval()
|
262 |
+
)
|
263 |
+
assert vocos_ckpt_path, "vocos_ckpt_path should not be None"
|
264 |
+
vocos.load_state_dict(torch.load(vocos_ckpt_path, weights_only=True, mmap=True))
|
265 |
+
self.vocos = vocos
|
266 |
+
self.logger.log(logging.INFO, "vocos loaded.")
|
267 |
+
|
268 |
+
dvae = (
|
269 |
+
DVAE(
|
270 |
+
decoder_config=asdict(self.config.dvae.decoder),
|
271 |
+
encoder_config=asdict(self.config.dvae.encoder),
|
272 |
+
vq_config=asdict(self.config.dvae.vq),
|
273 |
+
dim=self.config.dvae.decoder.idim,
|
274 |
+
coef=coef,
|
275 |
+
device=device,
|
276 |
+
)
|
277 |
+
.to(device)
|
278 |
+
.eval()
|
279 |
+
)
|
280 |
+
coef = str(dvae)
|
281 |
+
assert dvae_ckpt_path, "dvae_ckpt_path should not be None"
|
282 |
+
dvae.load_state_dict(torch.load(dvae_ckpt_path, weights_only=True, mmap=True))
|
283 |
+
self.dvae = dvae
|
284 |
+
self.logger.log(logging.INFO, "dvae loaded.")
|
285 |
+
|
286 |
+
embed = Embed(
|
287 |
+
self.config.embed.hidden_size,
|
288 |
+
self.config.embed.num_audio_tokens,
|
289 |
+
self.config.embed.num_text_tokens,
|
290 |
+
self.config.embed.num_vq,
|
291 |
+
)
|
292 |
+
embed.from_pretrained(embed_path, device=device)
|
293 |
+
self.embed = embed.to(device)
|
294 |
+
self.logger.log(logging.INFO, "embed loaded.")
|
295 |
+
|
296 |
+
gpt = GPT(
|
297 |
+
gpt_config=asdict(self.config.gpt),
|
298 |
+
embed=self.embed,
|
299 |
+
use_flash_attn=use_flash_attn,
|
300 |
+
use_vllm=use_vllm,
|
301 |
+
device=device,
|
302 |
+
device_gpt=self.device_gpt,
|
303 |
+
logger=self.logger,
|
304 |
+
).eval()
|
305 |
+
assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
|
306 |
+
gpt.from_pretrained(gpt_ckpt_path, embed_path, experimental=experimental)
|
307 |
+
gpt.prepare(compile=compile and "cuda" in str(device))
|
308 |
+
self.gpt = gpt
|
309 |
+
self.logger.log(logging.INFO, "gpt loaded.")
|
310 |
+
|
311 |
+
self.speaker = Speaker(
|
312 |
+
self.config.gpt.hidden_size, self.config.spk_stat, device
|
313 |
+
)
|
314 |
+
self.logger.log(logging.INFO, "speaker loaded.")
|
315 |
+
|
316 |
+
decoder = (
|
317 |
+
DVAE(
|
318 |
+
decoder_config=asdict(self.config.decoder),
|
319 |
+
dim=self.config.decoder.idim,
|
320 |
+
coef=coef,
|
321 |
+
device=device,
|
322 |
+
)
|
323 |
+
.to(device)
|
324 |
+
.eval()
|
325 |
+
)
|
326 |
+
coef = str(decoder)
|
327 |
+
assert decoder_ckpt_path, "decoder_ckpt_path should not be None"
|
328 |
+
decoder.load_state_dict(
|
329 |
+
torch.load(decoder_ckpt_path, weights_only=True, mmap=True)
|
330 |
+
)
|
331 |
+
self.decoder = decoder
|
332 |
+
self.logger.log(logging.INFO, "decoder loaded.")
|
333 |
+
|
334 |
+
if tokenizer_path:
|
335 |
+
self.tokenizer = Tokenizer(tokenizer_path)
|
336 |
+
self.logger.log(logging.INFO, "tokenizer loaded.")
|
337 |
+
|
338 |
+
self.coef = coef
|
339 |
+
|
340 |
+
return self.has_loaded()
|
341 |
+
|
342 |
+
def _infer(
|
343 |
+
self,
|
344 |
+
text,
|
345 |
+
stream=False,
|
346 |
+
lang=None,
|
347 |
+
skip_refine_text=False,
|
348 |
+
refine_text_only=False,
|
349 |
+
use_decoder=True,
|
350 |
+
do_text_normalization=True,
|
351 |
+
do_homophone_replacement=True,
|
352 |
+
params_refine_text=RefineTextParams(),
|
353 |
+
params_infer_code=InferCodeParams(),
|
354 |
+
):
|
355 |
+
|
356 |
+
assert self.has_loaded(use_decoder=use_decoder)
|
357 |
+
|
358 |
+
if not isinstance(text, list):
|
359 |
+
text = [text]
|
360 |
+
|
361 |
+
text = [
|
362 |
+
self.normalizer(
|
363 |
+
t,
|
364 |
+
do_text_normalization,
|
365 |
+
do_homophone_replacement,
|
366 |
+
lang,
|
367 |
+
)
|
368 |
+
for t in text
|
369 |
+
]
|
370 |
+
|
371 |
+
self.logger.debug("normed texts %s", str(text))
|
372 |
+
|
373 |
+
if not skip_refine_text:
|
374 |
+
refined = self._refine_text(
|
375 |
+
text,
|
376 |
+
self.device,
|
377 |
+
params_refine_text,
|
378 |
+
)
|
379 |
+
text_tokens = refined.ids
|
380 |
+
text_tokens = [i[i.less(self.tokenizer.break_0_ids)] for i in text_tokens]
|
381 |
+
text = self.tokenizer.decode(text_tokens)
|
382 |
+
refined.destroy()
|
383 |
+
if refine_text_only:
|
384 |
+
yield text
|
385 |
+
return
|
386 |
+
|
387 |
+
if stream:
|
388 |
+
length = 0
|
389 |
+
pass_batch_count = 0
|
390 |
+
for result in self._infer_code(
|
391 |
+
text,
|
392 |
+
stream,
|
393 |
+
self.device,
|
394 |
+
use_decoder,
|
395 |
+
params_infer_code,
|
396 |
+
):
|
397 |
+
wavs = self._decode_to_wavs(
|
398 |
+
result.hiddens if use_decoder else result.ids,
|
399 |
+
use_decoder,
|
400 |
+
)
|
401 |
+
result.destroy()
|
402 |
+
if stream:
|
403 |
+
pass_batch_count += 1
|
404 |
+
if pass_batch_count <= params_infer_code.pass_first_n_batches:
|
405 |
+
continue
|
406 |
+
a = length
|
407 |
+
b = a + params_infer_code.stream_speed
|
408 |
+
if b > wavs.shape[1]:
|
409 |
+
b = wavs.shape[1]
|
410 |
+
new_wavs = wavs[:, a:b]
|
411 |
+
length = b
|
412 |
+
yield new_wavs
|
413 |
+
else:
|
414 |
+
yield wavs
|
415 |
+
if stream:
|
416 |
+
new_wavs = wavs[:, length:]
|
417 |
+
# Identify rows with non-zero elements using np.any
|
418 |
+
# keep_rows = np.any(array != 0, axis=1)
|
419 |
+
keep_cols = np.sum(new_wavs != 0, axis=0) > 0
|
420 |
+
# Filter both rows and columns using slicing
|
421 |
+
yield new_wavs[:][:, keep_cols]
|
422 |
+
|
423 |
+
@torch.inference_mode()
|
424 |
+
def _vocos_decode(self, spec: torch.Tensor) -> np.ndarray:
|
425 |
+
if "mps" in str(self.device):
|
426 |
+
return self.vocos.decode(spec.cpu()).cpu().numpy()
|
427 |
+
else:
|
428 |
+
return self.vocos.decode(spec).cpu().numpy()
|
429 |
+
|
430 |
+
@torch.inference_mode()
|
431 |
+
def _decode_to_wavs(
|
432 |
+
self,
|
433 |
+
result_list: List[torch.Tensor],
|
434 |
+
use_decoder: bool,
|
435 |
+
):
|
436 |
+
decoder = self.decoder if use_decoder else self.dvae
|
437 |
+
max_x_len = -1
|
438 |
+
if len(result_list) == 0:
|
439 |
+
return np.array([], dtype=np.float32)
|
440 |
+
for result in result_list:
|
441 |
+
if result.size(0) > max_x_len:
|
442 |
+
max_x_len = result.size(0)
|
443 |
+
batch_result = torch.zeros(
|
444 |
+
(len(result_list), result_list[0].size(1), max_x_len),
|
445 |
+
dtype=result_list[0].dtype,
|
446 |
+
device=result_list[0].device,
|
447 |
+
)
|
448 |
+
for i in range(len(result_list)):
|
449 |
+
src = result_list[i]
|
450 |
+
batch_result[i].narrow(1, 0, src.size(0)).copy_(src.permute(1, 0))
|
451 |
+
del src
|
452 |
+
del_all(result_list)
|
453 |
+
mel_specs = decoder(batch_result)
|
454 |
+
del batch_result
|
455 |
+
wavs = self._vocos_decode(mel_specs)
|
456 |
+
del mel_specs
|
457 |
+
return wavs
|
458 |
+
|
459 |
+
@torch.no_grad()
|
460 |
+
def _infer_code(
|
461 |
+
self,
|
462 |
+
text: Tuple[List[str], str],
|
463 |
+
stream: bool,
|
464 |
+
device: torch.device,
|
465 |
+
return_hidden: bool,
|
466 |
+
params: InferCodeParams,
|
467 |
+
):
|
468 |
+
|
469 |
+
gpt = self.gpt
|
470 |
+
|
471 |
+
if not isinstance(text, list):
|
472 |
+
text = [text]
|
473 |
+
|
474 |
+
assert len(text), "text should not be empty"
|
475 |
+
|
476 |
+
if not isinstance(params.temperature, list):
|
477 |
+
temperature = [params.temperature] * self.config.gpt.num_vq
|
478 |
+
else:
|
479 |
+
temperature = params.temperature
|
480 |
+
|
481 |
+
input_ids, attention_mask, text_mask = self.tokenizer.encode(
|
482 |
+
self.speaker.decorate_code_prompts(
|
483 |
+
text,
|
484 |
+
params.prompt,
|
485 |
+
params.txt_smp,
|
486 |
+
params.spk_emb,
|
487 |
+
),
|
488 |
+
self.config.gpt.num_vq,
|
489 |
+
prompt=(
|
490 |
+
self.speaker.decode_prompt(params.spk_smp)
|
491 |
+
if params.spk_smp is not None
|
492 |
+
else None
|
493 |
+
),
|
494 |
+
device=self.device_gpt,
|
495 |
+
)
|
496 |
+
start_idx = input_ids.shape[-2]
|
497 |
+
|
498 |
+
num_code = self.config.gpt.num_audio_tokens - 1
|
499 |
+
|
500 |
+
logits_warpers, logits_processors = gen_logits(
|
501 |
+
num_code=num_code,
|
502 |
+
top_P=params.top_P,
|
503 |
+
top_K=params.top_K,
|
504 |
+
repetition_penalty=params.repetition_penalty,
|
505 |
+
)
|
506 |
+
|
507 |
+
if gpt.is_vllm:
|
508 |
+
from .model.velocity import SamplingParams
|
509 |
+
|
510 |
+
sample_params = SamplingParams(
|
511 |
+
temperature=temperature,
|
512 |
+
max_new_token=params.max_new_token,
|
513 |
+
max_tokens=8192,
|
514 |
+
min_new_token=params.min_new_token,
|
515 |
+
logits_processors=(logits_processors, logits_warpers),
|
516 |
+
eos_token=num_code,
|
517 |
+
infer_text=False,
|
518 |
+
start_idx=start_idx,
|
519 |
+
)
|
520 |
+
input_ids = [i.tolist() for i in input_ids]
|
521 |
+
|
522 |
+
result = gpt.llm.generate(
|
523 |
+
None,
|
524 |
+
sample_params,
|
525 |
+
input_ids,
|
526 |
+
)
|
527 |
+
|
528 |
+
token_ids = []
|
529 |
+
hidden_states = []
|
530 |
+
for i in result:
|
531 |
+
token_ids.append(torch.tensor(i.outputs[0].token_ids))
|
532 |
+
hidden_states.append(
|
533 |
+
i.outputs[0].hidden_states.to(torch.float32).to(self.device)
|
534 |
+
)
|
535 |
+
|
536 |
+
del text_mask, input_ids
|
537 |
+
|
538 |
+
return [
|
539 |
+
GPT.GenerationOutputs(
|
540 |
+
ids=token_ids,
|
541 |
+
hiddens=hidden_states,
|
542 |
+
attentions=[],
|
543 |
+
),
|
544 |
+
]
|
545 |
+
|
546 |
+
emb = self.embed(input_ids, text_mask)
|
547 |
+
|
548 |
+
del text_mask
|
549 |
+
|
550 |
+
if params.spk_emb is not None:
|
551 |
+
self.speaker.apply(
|
552 |
+
emb,
|
553 |
+
params.spk_emb,
|
554 |
+
input_ids,
|
555 |
+
self.tokenizer.spk_emb_ids,
|
556 |
+
self.gpt.device_gpt,
|
557 |
+
)
|
558 |
+
|
559 |
+
result = gpt.generate(
|
560 |
+
emb,
|
561 |
+
input_ids,
|
562 |
+
temperature=torch.tensor(temperature, device=device),
|
563 |
+
eos_token=num_code,
|
564 |
+
attention_mask=attention_mask,
|
565 |
+
max_new_token=params.max_new_token,
|
566 |
+
min_new_token=params.min_new_token,
|
567 |
+
logits_processors=(*logits_processors, *logits_warpers),
|
568 |
+
infer_text=False,
|
569 |
+
return_hidden=return_hidden,
|
570 |
+
stream=stream,
|
571 |
+
show_tqdm=params.show_tqdm,
|
572 |
+
ensure_non_empty=params.ensure_non_empty,
|
573 |
+
stream_batch=params.stream_batch,
|
574 |
+
manual_seed=params.manual_seed,
|
575 |
+
context=self.context,
|
576 |
+
)
|
577 |
+
|
578 |
+
del emb, input_ids
|
579 |
+
|
580 |
+
return result
|
581 |
+
|
582 |
+
@torch.no_grad()
|
583 |
+
def _refine_text(
|
584 |
+
self,
|
585 |
+
text: str,
|
586 |
+
device: torch.device,
|
587 |
+
params: RefineTextParams,
|
588 |
+
):
|
589 |
+
|
590 |
+
gpt = self.gpt
|
591 |
+
|
592 |
+
if not isinstance(text, list):
|
593 |
+
text = [text]
|
594 |
+
|
595 |
+
input_ids, attention_mask, text_mask = self.tokenizer.encode(
|
596 |
+
self.speaker.decorate_text_prompts(text, params.prompt),
|
597 |
+
self.config.gpt.num_vq,
|
598 |
+
device=self.device_gpt,
|
599 |
+
)
|
600 |
+
|
601 |
+
logits_warpers, logits_processors = gen_logits(
|
602 |
+
num_code=self.tokenizer.len,
|
603 |
+
top_P=params.top_P,
|
604 |
+
top_K=params.top_K,
|
605 |
+
repetition_penalty=params.repetition_penalty,
|
606 |
+
)
|
607 |
+
|
608 |
+
if gpt.is_vllm:
|
609 |
+
from .model.velocity import SamplingParams
|
610 |
+
|
611 |
+
sample_params = SamplingParams(
|
612 |
+
repetition_penalty=params.repetition_penalty,
|
613 |
+
temperature=params.temperature,
|
614 |
+
top_p=params.top_P,
|
615 |
+
top_k=params.top_K,
|
616 |
+
max_new_token=params.max_new_token,
|
617 |
+
max_tokens=8192,
|
618 |
+
min_new_token=params.min_new_token,
|
619 |
+
logits_processors=(logits_processors, logits_warpers),
|
620 |
+
eos_token=self.tokenizer.eos_token,
|
621 |
+
infer_text=True,
|
622 |
+
start_idx=input_ids.shape[-2],
|
623 |
+
)
|
624 |
+
input_ids_list = [i.tolist() for i in input_ids]
|
625 |
+
del input_ids
|
626 |
+
|
627 |
+
result = gpt.llm.generate(
|
628 |
+
None, sample_params, input_ids_list, params.show_tqdm
|
629 |
+
)
|
630 |
+
token_ids = []
|
631 |
+
hidden_states = []
|
632 |
+
for i in result:
|
633 |
+
token_ids.append(torch.tensor(i.outputs[0].token_ids))
|
634 |
+
hidden_states.append(i.outputs[0].hidden_states)
|
635 |
+
|
636 |
+
del text_mask, input_ids_list, result
|
637 |
+
|
638 |
+
return GPT.GenerationOutputs(
|
639 |
+
ids=token_ids,
|
640 |
+
hiddens=hidden_states,
|
641 |
+
attentions=[],
|
642 |
+
)
|
643 |
+
|
644 |
+
emb = self.embed(input_ids, text_mask)
|
645 |
+
|
646 |
+
del text_mask
|
647 |
+
|
648 |
+
result = next(
|
649 |
+
gpt.generate(
|
650 |
+
emb,
|
651 |
+
input_ids,
|
652 |
+
temperature=torch.tensor([params.temperature], device=device),
|
653 |
+
eos_token=self.tokenizer.eos_token,
|
654 |
+
attention_mask=attention_mask,
|
655 |
+
max_new_token=params.max_new_token,
|
656 |
+
min_new_token=params.min_new_token,
|
657 |
+
logits_processors=(*logits_processors, *logits_warpers),
|
658 |
+
infer_text=True,
|
659 |
+
stream=False,
|
660 |
+
show_tqdm=params.show_tqdm,
|
661 |
+
ensure_non_empty=params.ensure_non_empty,
|
662 |
+
manual_seed=params.manual_seed,
|
663 |
+
context=self.context,
|
664 |
+
)
|
665 |
+
)
|
666 |
+
|
667 |
+
del emb, input_ids
|
668 |
+
|
669 |
+
return result
|
ChatTTS/model/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .dvae import DVAE
|
2 |
+
from .embed import Embed
|
3 |
+
from .gpt import GPT
|
4 |
+
from .processors import gen_logits
|
5 |
+
from .speaker import Speaker
|
6 |
+
from .tokenizer import Tokenizer
|
ChatTTS/model/cuda/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .te_llama import TELlamaModel
|
ChatTTS/model/cuda/patch.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
class LlamaRMSNorm(torch.nn.Module):
|
5 |
+
def __init__(self, hidden_size, eps=1e-6):
|
6 |
+
"""
|
7 |
+
LlamaRMSNorm is equivalent to T5LayerNorm
|
8 |
+
"""
|
9 |
+
super().__init__()
|
10 |
+
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
|
11 |
+
self.variance_epsilon = eps
|
12 |
+
|
13 |
+
def forward(self, hidden_states: torch.Tensor):
|
14 |
+
input_dtype = hidden_states.dtype
|
15 |
+
hidden_states = hidden_states.to(torch.float32)
|
16 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
17 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
18 |
+
return self.weight.to(hidden_states.device) * hidden_states.to(input_dtype)
|
ChatTTS/model/cuda/te_llama.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# See LICENSE for license information.
|
4 |
+
#
|
5 |
+
# From https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/te_llama/te_llama.py
|
6 |
+
#
|
7 |
+
# Edited by fumiama.
|
8 |
+
|
9 |
+
import re
|
10 |
+
from contextlib import contextmanager
|
11 |
+
from typing import Dict
|
12 |
+
|
13 |
+
import transformer_engine as te
|
14 |
+
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
|
15 |
+
|
16 |
+
import torch
|
17 |
+
|
18 |
+
import transformers
|
19 |
+
from transformers.models.llama.modeling_llama import (
|
20 |
+
LlamaModel,
|
21 |
+
LlamaConfig,
|
22 |
+
)
|
23 |
+
from transformers.modeling_utils import _load_state_dict_into_model
|
24 |
+
|
25 |
+
from .patch import LlamaRMSNorm
|
26 |
+
|
27 |
+
|
28 |
+
@contextmanager
|
29 |
+
def replace_decoder(te_decoder_cls, llama_rms_norm_cls):
|
30 |
+
"""
|
31 |
+
Replace `LlamaDecoderLayer` with custom `TELlamaDecoderLayer`.
|
32 |
+
"""
|
33 |
+
original_llama_decoder_cls = (
|
34 |
+
transformers.models.llama.modeling_llama.LlamaDecoderLayer
|
35 |
+
)
|
36 |
+
transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls
|
37 |
+
original_llama_rms_norm_cls = transformers.models.llama.modeling_llama.LlamaRMSNorm
|
38 |
+
transformers.models.llama.modeling_llama.LlamaRMSNorm = llama_rms_norm_cls
|
39 |
+
try:
|
40 |
+
yield
|
41 |
+
finally:
|
42 |
+
transformers.models.llama.modeling_llama.LlamaDecoderLayer = (
|
43 |
+
original_llama_decoder_cls
|
44 |
+
)
|
45 |
+
transformers.models.llama.modeling_llama.LlamaRMSNorm = (
|
46 |
+
original_llama_rms_norm_cls
|
47 |
+
)
|
48 |
+
|
49 |
+
|
50 |
+
class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
|
51 |
+
"""
|
52 |
+
Wrapper class over TE's `TransformerLayer`. This makes the wrapper very
|
53 |
+
similar to HF's `LlamaDecoderLayer` and easier to replace it in the code.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
config: LlamaConfig
|
57 |
+
args: positional args (for compatibility with `LlamaDecoderLayer`)
|
58 |
+
kwargs: keyword args (for compatibility with `LlamaDecoderLayer`)
|
59 |
+
"""
|
60 |
+
|
61 |
+
def __init__(self, config, *args, **kwargs):
|
62 |
+
super().__init__(
|
63 |
+
hidden_size=config.hidden_size,
|
64 |
+
ffn_hidden_size=config.intermediate_size,
|
65 |
+
num_attention_heads=config.num_attention_heads,
|
66 |
+
bias=False,
|
67 |
+
layernorm_epsilon=config.rms_norm_eps,
|
68 |
+
hidden_dropout=0,
|
69 |
+
attention_dropout=0,
|
70 |
+
fuse_qkv_params=False,
|
71 |
+
normalization="RMSNorm",
|
72 |
+
activation="swiglu",
|
73 |
+
attn_input_format="bshd",
|
74 |
+
num_gqa_groups=config.num_key_value_heads,
|
75 |
+
)
|
76 |
+
te_rope = RotaryPositionEmbedding(
|
77 |
+
config.hidden_size // config.num_attention_heads
|
78 |
+
)
|
79 |
+
self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()
|
80 |
+
|
81 |
+
def forward(self, hidden_states, *args, attention_mask, **kwargs):
|
82 |
+
"""
|
83 |
+
Custom forward to make sure we only pass relevant arguments to the
|
84 |
+
forward pass of the `TransformerLayer`. Also, make sure the output
|
85 |
+
format matches the output of the HF's `LlamaDecoderLayer`.
|
86 |
+
"""
|
87 |
+
return (
|
88 |
+
super().forward(
|
89 |
+
hidden_states,
|
90 |
+
attention_mask=attention_mask,
|
91 |
+
rotary_pos_emb=self.te_rope_emb,
|
92 |
+
),
|
93 |
+
)
|
94 |
+
|
95 |
+
|
96 |
+
class TELlamaModel:
|
97 |
+
"""
|
98 |
+
LM created with `LlamaModel`. The underlying `LlamaDecoderLayer`
|
99 |
+
class is monkey-patched with `TELlamaDecoderLayer` class before
|
100 |
+
initializing the causal LM with `LlamaModel`.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
config: LlamaConfig
|
104 |
+
"""
|
105 |
+
|
106 |
+
def __new__(cls, config: LlamaConfig):
|
107 |
+
with replace_decoder(
|
108 |
+
te_decoder_cls=TELlamaDecoderLayer, llama_rms_norm_cls=LlamaRMSNorm
|
109 |
+
):
|
110 |
+
model = LlamaModel(config)
|
111 |
+
return model
|
112 |
+
|
113 |
+
@classmethod
|
114 |
+
def from_state_dict(
|
115 |
+
cls,
|
116 |
+
state_dict: Dict[str, torch.Tensor],
|
117 |
+
config: LlamaConfig,
|
118 |
+
):
|
119 |
+
"""
|
120 |
+
Custom method adapted from `from_pretrained` method in HuggingFace
|
121 |
+
Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579
|
122 |
+
"""
|
123 |
+
|
124 |
+
vanilla_model = cls(config)
|
125 |
+
|
126 |
+
# replace_params copies parameters relevant only to TransformerEngine
|
127 |
+
_replace_params(state_dict, vanilla_model.state_dict(), config)
|
128 |
+
# _load_state_dict_into_model copies parameters other than those in TransformerEngine
|
129 |
+
_load_state_dict_into_model(vanilla_model, state_dict, start_prefix="")
|
130 |
+
|
131 |
+
return vanilla_model
|
132 |
+
|
133 |
+
|
134 |
+
def _replace_params(hf_state_dict, te_state_dict, config):
|
135 |
+
# collect all layer prefixes to update
|
136 |
+
all_layer_prefixes = set()
|
137 |
+
for param_key in hf_state_dict.keys():
|
138 |
+
layer_prefix_pat = "model.layers.\d+."
|
139 |
+
m = re.match(layer_prefix_pat, param_key)
|
140 |
+
if m is not None:
|
141 |
+
all_layer_prefixes.add(m.group())
|
142 |
+
|
143 |
+
for layer_prefix in all_layer_prefixes:
|
144 |
+
# When loading weights into models with less number of layers, skip the
|
145 |
+
# copy if the corresponding layer doesn't exist in HF model
|
146 |
+
if layer_prefix + "input_layernorm.weight" in hf_state_dict:
|
147 |
+
te_state_dict[
|
148 |
+
layer_prefix + "self_attention.layernorm_qkv.layer_norm_weight"
|
149 |
+
].data[:] = hf_state_dict[layer_prefix + "input_layernorm.weight"].data[:]
|
150 |
+
|
151 |
+
if layer_prefix + "self_attn.q_proj.weight" in hf_state_dict:
|
152 |
+
te_state_dict[
|
153 |
+
layer_prefix + "self_attention.layernorm_qkv.query_weight"
|
154 |
+
].data[:] = hf_state_dict[layer_prefix + "self_attn.q_proj.weight"].data[:]
|
155 |
+
|
156 |
+
if layer_prefix + "self_attn.k_proj.weight" in hf_state_dict:
|
157 |
+
te_state_dict[
|
158 |
+
layer_prefix + "self_attention.layernorm_qkv.key_weight"
|
159 |
+
].data[:] = hf_state_dict[layer_prefix + "self_attn.k_proj.weight"].data[:]
|
160 |
+
|
161 |
+
if layer_prefix + "self_attn.v_proj.weight" in hf_state_dict:
|
162 |
+
te_state_dict[
|
163 |
+
layer_prefix + "self_attention.layernorm_qkv.value_weight"
|
164 |
+
].data[:] = hf_state_dict[layer_prefix + "self_attn.v_proj.weight"].data[:]
|
165 |
+
|
166 |
+
if layer_prefix + "self_attn.o_proj.weight" in hf_state_dict:
|
167 |
+
te_state_dict[layer_prefix + "self_attention.proj.weight"].data[:] = (
|
168 |
+
hf_state_dict[layer_prefix + "self_attn.o_proj.weight"].data[:]
|
169 |
+
)
|
170 |
+
|
171 |
+
if layer_prefix + "post_attention_layernorm.weight" in hf_state_dict:
|
172 |
+
te_state_dict[layer_prefix + "layernorm_mlp.layer_norm_weight"].data[:] = (
|
173 |
+
hf_state_dict[layer_prefix + "post_attention_layernorm.weight"].data[:]
|
174 |
+
)
|
175 |
+
|
176 |
+
# It may happen that gate_proj.weight and up_proj.weight will be in the different files, so we need to
|
177 |
+
# load them separately.
|
178 |
+
if layer_prefix + "mlp.gate_proj.weight" in hf_state_dict:
|
179 |
+
te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[
|
180 |
+
: config.intermediate_size
|
181 |
+
] = hf_state_dict[layer_prefix + "mlp.gate_proj.weight"].data
|
182 |
+
|
183 |
+
if layer_prefix + "mlp.up_proj.weight" in hf_state_dict:
|
184 |
+
te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[
|
185 |
+
config.intermediate_size :
|
186 |
+
] = hf_state_dict[layer_prefix + "mlp.up_proj.weight"].data
|
187 |
+
|
188 |
+
if layer_prefix + "mlp.down_proj.weight" in hf_state_dict:
|
189 |
+
te_state_dict[layer_prefix + "layernorm_mlp.fc2_weight"].data[:] = (
|
190 |
+
hf_state_dict[layer_prefix + "mlp.down_proj.weight"].data[:]
|
191 |
+
)
|
192 |
+
return all_layer_prefixes
|
ChatTTS/model/dvae.py
ADDED
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import List, Optional, Literal, Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import pybase16384 as b14
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import torchaudio
|
10 |
+
from vector_quantize_pytorch import GroupedResidualFSQ
|
11 |
+
|
12 |
+
|
13 |
+
class ConvNeXtBlock(nn.Module):
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
dim: int,
|
17 |
+
intermediate_dim: int,
|
18 |
+
kernel: int,
|
19 |
+
dilation: int,
|
20 |
+
layer_scale_init_value: float = 1e-6,
|
21 |
+
):
|
22 |
+
# ConvNeXt Block copied from Vocos.
|
23 |
+
super().__init__()
|
24 |
+
self.dwconv = nn.Conv1d(
|
25 |
+
dim,
|
26 |
+
dim,
|
27 |
+
kernel_size=kernel,
|
28 |
+
padding=dilation * (kernel // 2),
|
29 |
+
dilation=dilation,
|
30 |
+
groups=dim,
|
31 |
+
) # depthwise conv
|
32 |
+
|
33 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
34 |
+
self.pwconv1 = nn.Linear(
|
35 |
+
dim, intermediate_dim
|
36 |
+
) # pointwise/1x1 convs, implemented with linear layers
|
37 |
+
self.act = nn.GELU()
|
38 |
+
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
39 |
+
self.gamma = (
|
40 |
+
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
|
41 |
+
if layer_scale_init_value > 0
|
42 |
+
else None
|
43 |
+
)
|
44 |
+
|
45 |
+
def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor:
|
46 |
+
residual = x
|
47 |
+
|
48 |
+
y = self.dwconv(x)
|
49 |
+
y.transpose_(1, 2) # (B, C, T) -> (B, T, C)
|
50 |
+
x = self.norm(y)
|
51 |
+
del y
|
52 |
+
y = self.pwconv1(x)
|
53 |
+
del x
|
54 |
+
x = self.act(y)
|
55 |
+
del y
|
56 |
+
y = self.pwconv2(x)
|
57 |
+
del x
|
58 |
+
if self.gamma is not None:
|
59 |
+
y *= self.gamma
|
60 |
+
y.transpose_(1, 2) # (B, T, C) -> (B, C, T)
|
61 |
+
|
62 |
+
x = y + residual
|
63 |
+
del y
|
64 |
+
|
65 |
+
return x
|
66 |
+
|
67 |
+
|
68 |
+
class GFSQ(nn.Module):
|
69 |
+
|
70 |
+
def __init__(
|
71 |
+
self, dim: int, levels: List[int], G: int, R: int, eps=1e-5, transpose=True
|
72 |
+
):
|
73 |
+
super(GFSQ, self).__init__()
|
74 |
+
self.quantizer = GroupedResidualFSQ(
|
75 |
+
dim=dim,
|
76 |
+
levels=list(levels),
|
77 |
+
num_quantizers=R,
|
78 |
+
groups=G,
|
79 |
+
)
|
80 |
+
self.n_ind = math.prod(levels)
|
81 |
+
self.eps = eps
|
82 |
+
self.transpose = transpose
|
83 |
+
self.G = G
|
84 |
+
self.R = R
|
85 |
+
|
86 |
+
def _embed(self, x: torch.Tensor):
|
87 |
+
if self.transpose:
|
88 |
+
x = x.transpose(1, 2)
|
89 |
+
"""
|
90 |
+
x = rearrange(
|
91 |
+
x, "b t (g r) -> g b t r", g = self.G, r = self.R,
|
92 |
+
)
|
93 |
+
"""
|
94 |
+
x = x.view(x.size(0), x.size(1), self.G, self.R).permute(2, 0, 1, 3)
|
95 |
+
feat = self.quantizer.get_output_from_indices(x)
|
96 |
+
return feat.transpose_(1, 2) if self.transpose else feat
|
97 |
+
|
98 |
+
def __call__(self, x: torch.Tensor) -> torch.Tensor:
|
99 |
+
return super().__call__(x)
|
100 |
+
|
101 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
102 |
+
if self.transpose:
|
103 |
+
x.transpose_(1, 2)
|
104 |
+
# feat, ind = self.quantizer(x)
|
105 |
+
_, ind = self.quantizer(x)
|
106 |
+
"""
|
107 |
+
ind = rearrange(
|
108 |
+
ind, "g b t r ->b t (g r)",
|
109 |
+
)
|
110 |
+
"""
|
111 |
+
ind = ind.permute(1, 2, 0, 3).contiguous()
|
112 |
+
ind = ind.view(ind.size(0), ind.size(1), -1)
|
113 |
+
"""
|
114 |
+
embed_onehot_tmp = F.one_hot(ind.long(), self.n_ind)
|
115 |
+
embed_onehot = embed_onehot_tmp.to(x.dtype)
|
116 |
+
del embed_onehot_tmp
|
117 |
+
e_mean = torch.mean(embed_onehot, dim=[0, 1])
|
118 |
+
# e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1)
|
119 |
+
torch.div(e_mean, (e_mean.sum(dim=1) + self.eps).unsqueeze(1), out=e_mean)
|
120 |
+
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1))
|
121 |
+
|
122 |
+
return
|
123 |
+
torch.zeros(perplexity.shape, dtype=x.dtype, device=x.device),
|
124 |
+
feat.transpose_(1, 2) if self.transpose else feat,
|
125 |
+
perplexity,
|
126 |
+
"""
|
127 |
+
return ind.transpose_(1, 2) if self.transpose else ind
|
128 |
+
|
129 |
+
|
130 |
+
class DVAEDecoder(nn.Module):
|
131 |
+
def __init__(
|
132 |
+
self,
|
133 |
+
idim: int,
|
134 |
+
odim: int,
|
135 |
+
n_layer=12,
|
136 |
+
bn_dim=64,
|
137 |
+
hidden=256,
|
138 |
+
kernel=7,
|
139 |
+
dilation=2,
|
140 |
+
up=False,
|
141 |
+
):
|
142 |
+
super().__init__()
|
143 |
+
self.up = up
|
144 |
+
self.conv_in = nn.Sequential(
|
145 |
+
nn.Conv1d(idim, bn_dim, 3, 1, 1),
|
146 |
+
nn.GELU(),
|
147 |
+
nn.Conv1d(bn_dim, hidden, 3, 1, 1),
|
148 |
+
)
|
149 |
+
self.decoder_block = nn.ModuleList(
|
150 |
+
[
|
151 |
+
ConvNeXtBlock(
|
152 |
+
hidden,
|
153 |
+
hidden * 4,
|
154 |
+
kernel,
|
155 |
+
dilation,
|
156 |
+
)
|
157 |
+
for _ in range(n_layer)
|
158 |
+
]
|
159 |
+
)
|
160 |
+
self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False)
|
161 |
+
|
162 |
+
def forward(self, x: torch.Tensor, conditioning=None) -> torch.Tensor:
|
163 |
+
# B, C, T
|
164 |
+
y = self.conv_in(x)
|
165 |
+
del x
|
166 |
+
for f in self.decoder_block:
|
167 |
+
y = f(y, conditioning)
|
168 |
+
|
169 |
+
x = self.conv_out(y)
|
170 |
+
del y
|
171 |
+
return x
|
172 |
+
|
173 |
+
|
174 |
+
class MelSpectrogramFeatures(torch.nn.Module):
|
175 |
+
def __init__(
|
176 |
+
self,
|
177 |
+
sample_rate=24000,
|
178 |
+
n_fft=1024,
|
179 |
+
hop_length=256,
|
180 |
+
n_mels=100,
|
181 |
+
padding: Literal["center", "same"] = "center",
|
182 |
+
device: torch.device = torch.device("cpu"),
|
183 |
+
):
|
184 |
+
super().__init__()
|
185 |
+
self.device = device
|
186 |
+
if padding not in ["center", "same"]:
|
187 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
188 |
+
self.padding = padding
|
189 |
+
self.mel_spec = torchaudio.transforms.MelSpectrogram(
|
190 |
+
sample_rate=sample_rate,
|
191 |
+
n_fft=n_fft,
|
192 |
+
hop_length=hop_length,
|
193 |
+
n_mels=n_mels,
|
194 |
+
center=padding == "center",
|
195 |
+
power=1,
|
196 |
+
)
|
197 |
+
|
198 |
+
def __call__(self, audio: torch.Tensor) -> torch.Tensor:
|
199 |
+
return super().__call__(audio)
|
200 |
+
|
201 |
+
def forward(self, audio: torch.Tensor) -> torch.Tensor:
|
202 |
+
audio = audio.to(self.device)
|
203 |
+
mel: torch.Tensor = self.mel_spec(audio)
|
204 |
+
features = torch.log(torch.clip(mel, min=1e-5))
|
205 |
+
return features
|
206 |
+
|
207 |
+
|
208 |
+
class DVAE(nn.Module):
|
209 |
+
def __init__(
|
210 |
+
self,
|
211 |
+
decoder_config: dict,
|
212 |
+
encoder_config: Optional[dict] = None,
|
213 |
+
vq_config: Optional[dict] = None,
|
214 |
+
dim=512,
|
215 |
+
coef: Optional[str] = None,
|
216 |
+
device: torch.device = torch.device("cpu"),
|
217 |
+
):
|
218 |
+
super().__init__()
|
219 |
+
if coef is None:
|
220 |
+
coef = torch.rand(100)
|
221 |
+
else:
|
222 |
+
coef = torch.from_numpy(
|
223 |
+
np.frombuffer(b14.decode_from_string(coef), dtype=np.float32).copy()
|
224 |
+
)
|
225 |
+
self.register_buffer("coef", coef.unsqueeze(0).unsqueeze_(2))
|
226 |
+
|
227 |
+
if encoder_config is not None:
|
228 |
+
self.downsample_conv = nn.Sequential(
|
229 |
+
nn.Conv1d(100, dim, 3, 1, 1),
|
230 |
+
nn.GELU(),
|
231 |
+
nn.Conv1d(dim, dim, 4, 2, 1),
|
232 |
+
nn.GELU(),
|
233 |
+
)
|
234 |
+
self.preprocessor_mel = MelSpectrogramFeatures(device=device)
|
235 |
+
self.encoder: Optional[DVAEDecoder] = DVAEDecoder(**encoder_config)
|
236 |
+
|
237 |
+
self.decoder = DVAEDecoder(**decoder_config)
|
238 |
+
self.out_conv = nn.Conv1d(dim, 100, 3, 1, 1, bias=False)
|
239 |
+
if vq_config is not None:
|
240 |
+
self.vq_layer = GFSQ(**vq_config)
|
241 |
+
else:
|
242 |
+
self.vq_layer = None
|
243 |
+
|
244 |
+
def __repr__(self) -> str:
|
245 |
+
return b14.encode_to_string(
|
246 |
+
self.coef.cpu().numpy().astype(np.float32).tobytes()
|
247 |
+
)
|
248 |
+
|
249 |
+
def __call__(
|
250 |
+
self, inp: torch.Tensor, mode: Literal["encode", "decode"] = "decode"
|
251 |
+
) -> torch.Tensor:
|
252 |
+
return super().__call__(inp, mode)
|
253 |
+
|
254 |
+
@torch.inference_mode()
|
255 |
+
def forward(
|
256 |
+
self, inp: torch.Tensor, mode: Literal["encode", "decode"] = "decode"
|
257 |
+
) -> torch.Tensor:
|
258 |
+
if mode == "encode" and hasattr(self, "encoder") and self.vq_layer is not None:
|
259 |
+
mel = self.preprocessor_mel(inp)
|
260 |
+
x: torch.Tensor = self.downsample_conv(
|
261 |
+
torch.div(mel, self.coef.view(100, 1).expand(mel.shape), out=mel),
|
262 |
+
).unsqueeze_(0)
|
263 |
+
del mel
|
264 |
+
x = self.encoder(x)
|
265 |
+
ind = self.vq_layer(x)
|
266 |
+
del x
|
267 |
+
return ind
|
268 |
+
|
269 |
+
if self.vq_layer is not None:
|
270 |
+
vq_feats = self.vq_layer._embed(inp)
|
271 |
+
else:
|
272 |
+
vq_feats = inp
|
273 |
+
|
274 |
+
vq_feats = (
|
275 |
+
vq_feats.view(
|
276 |
+
(vq_feats.size(0), 2, vq_feats.size(1) // 2, vq_feats.size(2)),
|
277 |
+
)
|
278 |
+
.permute(0, 2, 3, 1)
|
279 |
+
.flatten(2)
|
280 |
+
)
|
281 |
+
|
282 |
+
dec_out = self.out_conv(
|
283 |
+
self.decoder(
|
284 |
+
x=vq_feats,
|
285 |
+
),
|
286 |
+
)
|
287 |
+
|
288 |
+
del vq_feats
|
289 |
+
|
290 |
+
return torch.mul(dec_out, self.coef, out=dec_out)
|
291 |
+
|
292 |
+
@torch.inference_mode()
|
293 |
+
def sample_audio(self, wav: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
|
294 |
+
if isinstance(wav, np.ndarray):
|
295 |
+
wav = torch.from_numpy(wav)
|
296 |
+
return self(wav, "encode").squeeze_(0)
|
ChatTTS/model/embed.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from safetensors.torch import safe_open
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.nn.utils.parametrizations import weight_norm
|
5 |
+
|
6 |
+
|
7 |
+
class Embed(nn.Module):
|
8 |
+
def __init__(
|
9 |
+
self, hidden_size: int, num_audio_tokens: int, num_text_tokens: int, num_vq=4
|
10 |
+
):
|
11 |
+
super().__init__()
|
12 |
+
|
13 |
+
self.num_vq = num_vq
|
14 |
+
self.num_audio_tokens = num_audio_tokens
|
15 |
+
|
16 |
+
self.model_dim = hidden_size
|
17 |
+
self.emb_code = nn.ModuleList(
|
18 |
+
[nn.Embedding(num_audio_tokens, self.model_dim) for _ in range(num_vq)],
|
19 |
+
)
|
20 |
+
self.emb_text = nn.Embedding(num_text_tokens, self.model_dim)
|
21 |
+
|
22 |
+
self.head_text = weight_norm(
|
23 |
+
nn.Linear(self.model_dim, num_text_tokens, bias=False),
|
24 |
+
name="weight",
|
25 |
+
)
|
26 |
+
self.head_code = nn.ModuleList(
|
27 |
+
[
|
28 |
+
weight_norm(
|
29 |
+
nn.Linear(self.model_dim, num_audio_tokens, bias=False),
|
30 |
+
name="weight",
|
31 |
+
)
|
32 |
+
for _ in range(self.num_vq)
|
33 |
+
],
|
34 |
+
)
|
35 |
+
|
36 |
+
@torch.inference_mode()
|
37 |
+
def from_pretrained(self, filename: str, device: torch.device):
|
38 |
+
state_dict_tensors = {}
|
39 |
+
with safe_open(filename, framework="pt") as f:
|
40 |
+
for k in f.keys():
|
41 |
+
state_dict_tensors[k] = f.get_tensor(k)
|
42 |
+
self.load_state_dict(state_dict_tensors)
|
43 |
+
self.to(device)
|
44 |
+
|
45 |
+
def __call__(
|
46 |
+
self, input_ids: torch.Tensor, text_mask: torch.Tensor
|
47 |
+
) -> torch.Tensor:
|
48 |
+
"""
|
49 |
+
get_emb
|
50 |
+
"""
|
51 |
+
return super().__call__(input_ids, text_mask)
|
52 |
+
|
53 |
+
@torch.inference_mode()
|
54 |
+
def forward(self, input_ids: torch.Tensor, text_mask: torch.Tensor) -> torch.Tensor:
|
55 |
+
"""
|
56 |
+
get_emb
|
57 |
+
"""
|
58 |
+
device = next(self.parameters()).device
|
59 |
+
emb_text: torch.Tensor = self.emb_text(
|
60 |
+
input_ids[text_mask].narrow(1, 0, 1).squeeze_(1).to(device)
|
61 |
+
)
|
62 |
+
|
63 |
+
text_mask_inv = text_mask.logical_not().to(device)
|
64 |
+
masked_input_ids: torch.Tensor = input_ids[text_mask_inv].to(device)
|
65 |
+
|
66 |
+
emb_code = [
|
67 |
+
self.emb_code[i](masked_input_ids[:, i]) for i in range(self.num_vq)
|
68 |
+
]
|
69 |
+
emb_code = torch.stack(emb_code, 2).sum(2)
|
70 |
+
|
71 |
+
emb = torch.zeros(
|
72 |
+
(input_ids.shape[:-1]) + (emb_text.shape[-1],),
|
73 |
+
device=emb_text.device,
|
74 |
+
dtype=emb_text.dtype,
|
75 |
+
)
|
76 |
+
emb[text_mask] = emb_text
|
77 |
+
emb[text_mask_inv] = emb_code.to(emb.dtype)
|
78 |
+
|
79 |
+
del emb_text, emb_code, text_mask_inv
|
80 |
+
|
81 |
+
return emb
|
ChatTTS/model/gpt.py
ADDED
@@ -0,0 +1,613 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import platform
|
2 |
+
from dataclasses import dataclass
|
3 |
+
import logging
|
4 |
+
from typing import Union, List, Optional, Tuple, Callable
|
5 |
+
import gc
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torch.nn.utils.parametrize as P
|
11 |
+
from tqdm import tqdm
|
12 |
+
from transformers import LlamaModel, LlamaConfig
|
13 |
+
from transformers.cache_utils import Cache
|
14 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast
|
15 |
+
from transformers.utils import is_flash_attn_2_available
|
16 |
+
|
17 |
+
from ..utils import del_all
|
18 |
+
from .embed import Embed
|
19 |
+
|
20 |
+
|
21 |
+
class GPT(nn.Module):
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
gpt_config: dict,
|
25 |
+
embed: Embed,
|
26 |
+
use_flash_attn=False,
|
27 |
+
use_vllm=False,
|
28 |
+
device=torch.device("cpu"),
|
29 |
+
device_gpt=torch.device("cpu"),
|
30 |
+
logger=logging.getLogger(__name__),
|
31 |
+
):
|
32 |
+
super().__init__()
|
33 |
+
|
34 |
+
self.logger = logger
|
35 |
+
|
36 |
+
self.device = device
|
37 |
+
self.device_gpt = device_gpt
|
38 |
+
|
39 |
+
self.generator = torch.Generator(device=device)
|
40 |
+
|
41 |
+
self.num_vq = int(gpt_config["num_vq"])
|
42 |
+
self.num_audio_tokens = int(gpt_config["num_audio_tokens"])
|
43 |
+
self.num_text_tokens = int(gpt_config["num_text_tokens"])
|
44 |
+
|
45 |
+
self.use_flash_attn = use_flash_attn
|
46 |
+
self.is_te_llama = False
|
47 |
+
self.is_vllm = use_vllm
|
48 |
+
|
49 |
+
if self.is_vllm:
|
50 |
+
return
|
51 |
+
|
52 |
+
self.llama_config = self._build_llama_config(gpt_config)
|
53 |
+
|
54 |
+
self.emb_code = [ec.__call__ for ec in embed.emb_code]
|
55 |
+
self.emb_text = embed.emb_text.__call__
|
56 |
+
self.head_text = embed.head_text.__call__
|
57 |
+
self.head_code = [hc.__call__ for hc in embed.head_code]
|
58 |
+
|
59 |
+
def from_pretrained(
|
60 |
+
self, gpt_folder: str, embed_file_path: str, experimental=False
|
61 |
+
):
|
62 |
+
if self.is_vllm and platform.system().lower() == "linux":
|
63 |
+
|
64 |
+
from .velocity import LLM
|
65 |
+
|
66 |
+
self.llm = LLM(
|
67 |
+
model=gpt_folder,
|
68 |
+
num_audio_tokens=self.num_audio_tokens,
|
69 |
+
num_text_tokens=self.num_text_tokens,
|
70 |
+
post_model_path=embed_file_path,
|
71 |
+
)
|
72 |
+
self.logger.info("vLLM model loaded")
|
73 |
+
return
|
74 |
+
|
75 |
+
self.gpt: LlamaModel = LlamaModel.from_pretrained(gpt_folder).to(
|
76 |
+
self.device_gpt
|
77 |
+
)
|
78 |
+
del self.gpt.embed_tokens
|
79 |
+
|
80 |
+
if (
|
81 |
+
experimental
|
82 |
+
and "cuda" in str(self.device_gpt)
|
83 |
+
and platform.system().lower() == "linux"
|
84 |
+
): # is TELlamaModel
|
85 |
+
try:
|
86 |
+
from .cuda import TELlamaModel
|
87 |
+
|
88 |
+
self.logger.warning(
|
89 |
+
"Linux with CUDA, try NVIDIA accelerated TELlamaModel because experimental is enabled"
|
90 |
+
)
|
91 |
+
state_dict = self.gpt.state_dict()
|
92 |
+
vanilla = TELlamaModel.from_state_dict(state_dict, self.llama_config)
|
93 |
+
# Force mem release. Taken from huggingface code
|
94 |
+
del state_dict, self.gpt
|
95 |
+
gc.collect()
|
96 |
+
self.gpt = vanilla
|
97 |
+
self.is_te_llama = True
|
98 |
+
except Exception as e:
|
99 |
+
self.logger.warning(
|
100 |
+
f"use default LlamaModel for importing TELlamaModel error: {e}"
|
101 |
+
)
|
102 |
+
|
103 |
+
class Context:
|
104 |
+
def __init__(self):
|
105 |
+
self._interrupt = False
|
106 |
+
|
107 |
+
def set(self, v: bool):
|
108 |
+
self._interrupt = v
|
109 |
+
|
110 |
+
def get(self) -> bool:
|
111 |
+
return self._interrupt
|
112 |
+
|
113 |
+
def _build_llama_config(
|
114 |
+
self,
|
115 |
+
config: dict,
|
116 |
+
) -> Tuple[LlamaModel, LlamaConfig]:
|
117 |
+
|
118 |
+
if self.use_flash_attn and is_flash_attn_2_available():
|
119 |
+
llama_config = LlamaConfig(
|
120 |
+
**config,
|
121 |
+
attn_implementation="flash_attention_2",
|
122 |
+
)
|
123 |
+
self.logger.warning(
|
124 |
+
"enabling flash_attention_2 may make gpt be even slower"
|
125 |
+
)
|
126 |
+
else:
|
127 |
+
llama_config = LlamaConfig(**config)
|
128 |
+
|
129 |
+
return llama_config
|
130 |
+
|
131 |
+
def prepare(self, compile=False):
|
132 |
+
if self.use_flash_attn and is_flash_attn_2_available():
|
133 |
+
self.gpt = self.gpt.to(dtype=torch.float16)
|
134 |
+
if compile and not self.is_te_llama and not self.is_vllm:
|
135 |
+
try:
|
136 |
+
self.compile(backend="inductor", dynamic=True)
|
137 |
+
self.gpt.compile(backend="inductor", dynamic=True)
|
138 |
+
except RuntimeError as e:
|
139 |
+
self.logger.warning(f"compile failed: {e}. fallback to normal mode.")
|
140 |
+
|
141 |
+
@dataclass(repr=False, eq=False)
|
142 |
+
class _GenerationInputs:
|
143 |
+
position_ids: torch.Tensor
|
144 |
+
cache_position: torch.Tensor
|
145 |
+
use_cache: bool
|
146 |
+
input_ids: Optional[torch.Tensor] = None
|
147 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
148 |
+
attention_mask: Optional[torch.Tensor] = None
|
149 |
+
inputs_embeds: Optional[torch.Tensor] = None
|
150 |
+
|
151 |
+
def to(self, device: torch.device, dtype: torch.dtype):
|
152 |
+
if self.attention_mask is not None:
|
153 |
+
self.attention_mask = self.attention_mask.to(device, dtype=dtype)
|
154 |
+
if self.position_ids is not None:
|
155 |
+
self.position_ids = self.position_ids.to(device, dtype=dtype)
|
156 |
+
if self.inputs_embeds is not None:
|
157 |
+
self.inputs_embeds = self.inputs_embeds.to(device, dtype=dtype)
|
158 |
+
if self.cache_position is not None:
|
159 |
+
self.cache_position = self.cache_position.to(device, dtype=dtype)
|
160 |
+
|
161 |
+
@torch.no_grad()
|
162 |
+
def _prepare_generation_inputs(
|
163 |
+
self,
|
164 |
+
input_ids: torch.Tensor,
|
165 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
166 |
+
attention_mask: Optional[torch.Tensor] = None,
|
167 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
168 |
+
cache_position: Optional[torch.Tensor] = None,
|
169 |
+
position_ids: Optional[torch.Tensor] = None,
|
170 |
+
use_cache=True,
|
171 |
+
) -> _GenerationInputs:
|
172 |
+
# With static cache, the `past_key_values` is None
|
173 |
+
# TODO joao: standardize interface for the different Cache classes and remove of this if
|
174 |
+
has_static_cache = False
|
175 |
+
if past_key_values is None:
|
176 |
+
if hasattr(self.gpt.layers[0], "self_attn"):
|
177 |
+
past_key_values = getattr(
|
178 |
+
self.gpt.layers[0].self_attn, "past_key_value", None
|
179 |
+
)
|
180 |
+
has_static_cache = past_key_values is not None
|
181 |
+
|
182 |
+
past_length = 0
|
183 |
+
if past_key_values is not None:
|
184 |
+
if isinstance(past_key_values, Cache):
|
185 |
+
past_length = (
|
186 |
+
int(cache_position[0])
|
187 |
+
if cache_position is not None
|
188 |
+
else past_key_values.get_seq_length()
|
189 |
+
)
|
190 |
+
max_cache_length = past_key_values.get_max_length()
|
191 |
+
cache_length = (
|
192 |
+
past_length
|
193 |
+
if max_cache_length is None
|
194 |
+
else min(max_cache_length, past_length)
|
195 |
+
)
|
196 |
+
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
|
197 |
+
else:
|
198 |
+
cache_length = past_length = past_key_values[0][0].shape[2]
|
199 |
+
max_cache_length = None
|
200 |
+
|
201 |
+
# Keep only the unprocessed tokens:
|
202 |
+
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
203 |
+
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
204 |
+
# input)
|
205 |
+
if (
|
206 |
+
attention_mask is not None
|
207 |
+
and attention_mask.shape[1] > input_ids.shape[1]
|
208 |
+
):
|
209 |
+
start = attention_mask.shape[1] - past_length
|
210 |
+
input_ids = input_ids.narrow(1, -start, start)
|
211 |
+
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
212 |
+
# input_ids based on the past_length.
|
213 |
+
elif past_length < input_ids.shape[1]:
|
214 |
+
input_ids = input_ids.narrow(
|
215 |
+
1, past_length, input_ids.size(1) - past_length
|
216 |
+
)
|
217 |
+
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
218 |
+
|
219 |
+
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
220 |
+
if (
|
221 |
+
max_cache_length is not None
|
222 |
+
and attention_mask is not None
|
223 |
+
and cache_length + input_ids.shape[1] > max_cache_length
|
224 |
+
):
|
225 |
+
attention_mask = attention_mask.narrow(
|
226 |
+
1, -max_cache_length, max_cache_length
|
227 |
+
)
|
228 |
+
|
229 |
+
if attention_mask is not None and position_ids is None:
|
230 |
+
# create position_ids on the fly for batch generation
|
231 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
232 |
+
position_ids.masked_fill_(attention_mask.eq(0), 1)
|
233 |
+
if past_key_values:
|
234 |
+
position_ids = position_ids.narrow(
|
235 |
+
1, -input_ids.shape[1], input_ids.shape[1]
|
236 |
+
)
|
237 |
+
|
238 |
+
input_length = (
|
239 |
+
position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
|
240 |
+
)
|
241 |
+
if cache_position is None:
|
242 |
+
cache_position = torch.arange(
|
243 |
+
past_length, past_length + input_length, device=input_ids.device
|
244 |
+
)
|
245 |
+
else:
|
246 |
+
cache_position = cache_position.narrow(0, -input_length, input_length)
|
247 |
+
|
248 |
+
if has_static_cache:
|
249 |
+
past_key_values = None
|
250 |
+
|
251 |
+
model_inputs = self._GenerationInputs(
|
252 |
+
position_ids=position_ids,
|
253 |
+
cache_position=cache_position,
|
254 |
+
use_cache=use_cache,
|
255 |
+
)
|
256 |
+
|
257 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
258 |
+
if inputs_embeds is not None and past_key_values is None:
|
259 |
+
model_inputs.inputs_embeds = inputs_embeds
|
260 |
+
else:
|
261 |
+
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
|
262 |
+
# recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
|
263 |
+
# TODO: use `next_tokens` directly instead.
|
264 |
+
model_inputs.input_ids = input_ids.contiguous()
|
265 |
+
|
266 |
+
model_inputs.past_key_values = past_key_values
|
267 |
+
model_inputs.attention_mask = attention_mask
|
268 |
+
|
269 |
+
return model_inputs
|
270 |
+
|
271 |
+
@dataclass(repr=False, eq=False)
|
272 |
+
class GenerationOutputs:
|
273 |
+
ids: List[torch.Tensor]
|
274 |
+
attentions: List[Optional[Tuple[torch.FloatTensor, ...]]]
|
275 |
+
hiddens: List[torch.Tensor]
|
276 |
+
|
277 |
+
def destroy(self):
|
278 |
+
del_all(self.ids)
|
279 |
+
del_all(self.attentions)
|
280 |
+
del_all(self.hiddens)
|
281 |
+
|
282 |
+
@torch.no_grad()
|
283 |
+
def _prepare_generation_outputs(
|
284 |
+
self,
|
285 |
+
inputs_ids: torch.Tensor,
|
286 |
+
start_idx: int,
|
287 |
+
end_idx: torch.Tensor,
|
288 |
+
attentions: List[Optional[Tuple[torch.FloatTensor, ...]]],
|
289 |
+
hiddens: List[torch.Tensor],
|
290 |
+
infer_text: bool,
|
291 |
+
) -> GenerationOutputs:
|
292 |
+
inputs_ids = [
|
293 |
+
inputs_ids[idx].narrow(0, start_idx, i) for idx, i in enumerate(end_idx)
|
294 |
+
]
|
295 |
+
if infer_text:
|
296 |
+
inputs_ids = [i.narrow(1, 0, 1).squeeze_(1) for i in inputs_ids]
|
297 |
+
|
298 |
+
if len(hiddens) > 0:
|
299 |
+
hiddens = torch.stack(hiddens, 1)
|
300 |
+
hiddens = [
|
301 |
+
hiddens[idx].narrow(0, 0, i) for idx, i in enumerate(end_idx.int())
|
302 |
+
]
|
303 |
+
|
304 |
+
return self.GenerationOutputs(
|
305 |
+
ids=inputs_ids,
|
306 |
+
attentions=attentions,
|
307 |
+
hiddens=hiddens,
|
308 |
+
)
|
309 |
+
|
310 |
+
@torch.no_grad()
|
311 |
+
def generate(
|
312 |
+
self,
|
313 |
+
emb: torch.Tensor,
|
314 |
+
inputs_ids: torch.Tensor,
|
315 |
+
temperature: torch.Tensor,
|
316 |
+
eos_token: Union[int, torch.Tensor],
|
317 |
+
attention_mask: Optional[torch.Tensor] = None,
|
318 |
+
max_new_token=2048,
|
319 |
+
min_new_token=0,
|
320 |
+
logits_processors: Tuple[
|
321 |
+
Callable[[torch.LongTensor, torch.FloatTensor], torch.FloatTensor]
|
322 |
+
] = (),
|
323 |
+
infer_text=False,
|
324 |
+
return_attn=False,
|
325 |
+
return_hidden=False,
|
326 |
+
stream=False,
|
327 |
+
show_tqdm=True,
|
328 |
+
ensure_non_empty=True,
|
329 |
+
stream_batch=24,
|
330 |
+
manual_seed: Optional[int] = None,
|
331 |
+
context=Context(),
|
332 |
+
):
|
333 |
+
|
334 |
+
attentions: List[Optional[Tuple[torch.FloatTensor, ...]]] = []
|
335 |
+
hiddens = []
|
336 |
+
stream_iter = 0
|
337 |
+
|
338 |
+
start_idx, end_idx = inputs_ids.shape[1], torch.zeros(
|
339 |
+
inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long
|
340 |
+
)
|
341 |
+
finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool()
|
342 |
+
|
343 |
+
old_temperature = temperature
|
344 |
+
|
345 |
+
temperature = (
|
346 |
+
temperature.unsqueeze(0)
|
347 |
+
.expand(inputs_ids.shape[0], -1)
|
348 |
+
.contiguous()
|
349 |
+
.view(-1, 1)
|
350 |
+
)
|
351 |
+
|
352 |
+
attention_mask_cache = torch.ones(
|
353 |
+
(
|
354 |
+
inputs_ids.shape[0],
|
355 |
+
inputs_ids.shape[1] + max_new_token,
|
356 |
+
),
|
357 |
+
dtype=torch.bool,
|
358 |
+
device=inputs_ids.device,
|
359 |
+
)
|
360 |
+
if attention_mask is not None:
|
361 |
+
attention_mask_cache.narrow(1, 0, attention_mask.shape[1]).copy_(
|
362 |
+
attention_mask
|
363 |
+
)
|
364 |
+
|
365 |
+
progress = inputs_ids.size(1)
|
366 |
+
# pre-allocate inputs_ids
|
367 |
+
inputs_ids_buf = torch.zeros(
|
368 |
+
inputs_ids.size(0),
|
369 |
+
progress + max_new_token,
|
370 |
+
inputs_ids.size(2),
|
371 |
+
dtype=inputs_ids.dtype,
|
372 |
+
device=inputs_ids.device,
|
373 |
+
)
|
374 |
+
inputs_ids_buf.narrow(1, 0, progress).copy_(inputs_ids)
|
375 |
+
del inputs_ids
|
376 |
+
inputs_ids = inputs_ids_buf.narrow(1, 0, progress)
|
377 |
+
|
378 |
+
pbar: Optional[tqdm] = None
|
379 |
+
|
380 |
+
if show_tqdm:
|
381 |
+
pbar = tqdm(
|
382 |
+
total=max_new_token,
|
383 |
+
desc="text" if infer_text else "code",
|
384 |
+
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]",
|
385 |
+
)
|
386 |
+
|
387 |
+
past_key_values = None
|
388 |
+
|
389 |
+
for i in range(max_new_token):
|
390 |
+
|
391 |
+
model_input = self._prepare_generation_inputs(
|
392 |
+
inputs_ids,
|
393 |
+
past_key_values,
|
394 |
+
attention_mask_cache.narrow(1, 0, inputs_ids.shape[1]),
|
395 |
+
use_cache=not self.is_te_llama,
|
396 |
+
)
|
397 |
+
|
398 |
+
if i > 0:
|
399 |
+
del emb
|
400 |
+
inputs_ids_emb = model_input.input_ids.to(self.device_gpt)
|
401 |
+
if infer_text:
|
402 |
+
emb: torch.Tensor = self.emb_text(inputs_ids_emb[:, :, 0])
|
403 |
+
else:
|
404 |
+
code_emb = [
|
405 |
+
self.emb_code[i](inputs_ids_emb[:, :, i])
|
406 |
+
for i in range(self.num_vq)
|
407 |
+
]
|
408 |
+
emb = torch.stack(code_emb, 3).sum(3)
|
409 |
+
del inputs_ids_emb, model_input.input_ids
|
410 |
+
model_input.inputs_embeds = emb
|
411 |
+
|
412 |
+
model_input.to(self.device_gpt, self.gpt.dtype)
|
413 |
+
|
414 |
+
outputs: BaseModelOutputWithPast = self.gpt(
|
415 |
+
attention_mask=model_input.attention_mask,
|
416 |
+
position_ids=model_input.position_ids,
|
417 |
+
past_key_values=model_input.past_key_values,
|
418 |
+
inputs_embeds=model_input.inputs_embeds,
|
419 |
+
use_cache=model_input.use_cache,
|
420 |
+
output_attentions=return_attn,
|
421 |
+
cache_position=model_input.cache_position,
|
422 |
+
)
|
423 |
+
del_all(model_input)
|
424 |
+
attentions.append(outputs.attentions)
|
425 |
+
hidden_states = outputs.last_hidden_state.to(
|
426 |
+
self.device, dtype=torch.float
|
427 |
+
) # 🐻
|
428 |
+
past_key_values = outputs.past_key_values
|
429 |
+
del_all(outputs)
|
430 |
+
if return_hidden:
|
431 |
+
hiddens.append(hidden_states.narrow(1, -1, 1).squeeze_(1))
|
432 |
+
|
433 |
+
with P.cached():
|
434 |
+
if infer_text:
|
435 |
+
logits: torch.Tensor = self.head_text(hidden_states)
|
436 |
+
else:
|
437 |
+
# logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3)
|
438 |
+
logits = torch.empty(
|
439 |
+
hidden_states.size(0),
|
440 |
+
hidden_states.size(1),
|
441 |
+
self.num_audio_tokens,
|
442 |
+
self.num_vq,
|
443 |
+
dtype=torch.float,
|
444 |
+
device=self.device,
|
445 |
+
)
|
446 |
+
for num_vq_iter in range(self.num_vq):
|
447 |
+
x: torch.Tensor = self.head_code[num_vq_iter](hidden_states)
|
448 |
+
logits[..., num_vq_iter] = x
|
449 |
+
del x
|
450 |
+
|
451 |
+
del hidden_states
|
452 |
+
|
453 |
+
# logits = logits[:, -1].float()
|
454 |
+
logits = logits.narrow(1, -1, 1).squeeze_(1).float()
|
455 |
+
|
456 |
+
if not infer_text:
|
457 |
+
# logits = rearrange(logits, "b c n -> (b n) c")
|
458 |
+
logits = logits.permute(0, 2, 1)
|
459 |
+
logits = logits.reshape(-1, logits.size(2))
|
460 |
+
# logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c")
|
461 |
+
inputs_ids_sliced = inputs_ids.narrow(
|
462 |
+
1,
|
463 |
+
start_idx,
|
464 |
+
inputs_ids.size(1) - start_idx,
|
465 |
+
).permute(0, 2, 1)
|
466 |
+
logits_token = inputs_ids_sliced.reshape(
|
467 |
+
inputs_ids_sliced.size(0) * inputs_ids_sliced.size(1),
|
468 |
+
-1,
|
469 |
+
).to(self.device)
|
470 |
+
del inputs_ids_sliced
|
471 |
+
else:
|
472 |
+
logits_token = (
|
473 |
+
inputs_ids.narrow(
|
474 |
+
1,
|
475 |
+
start_idx,
|
476 |
+
inputs_ids.size(1) - start_idx,
|
477 |
+
)
|
478 |
+
.narrow(2, 0, 1)
|
479 |
+
.to(self.device)
|
480 |
+
)
|
481 |
+
|
482 |
+
logits /= temperature
|
483 |
+
|
484 |
+
for logitsProcessors in logits_processors:
|
485 |
+
logits = logitsProcessors(logits_token, logits)
|
486 |
+
|
487 |
+
del logits_token
|
488 |
+
|
489 |
+
if i < min_new_token:
|
490 |
+
logits[:, eos_token] = -torch.inf
|
491 |
+
|
492 |
+
scores = F.softmax(logits, dim=-1)
|
493 |
+
|
494 |
+
del logits
|
495 |
+
|
496 |
+
if manual_seed is None:
|
497 |
+
idx_next = torch.multinomial(scores, num_samples=1).to(finish.device)
|
498 |
+
else:
|
499 |
+
idx_next = torch.multinomial(
|
500 |
+
scores,
|
501 |
+
num_samples=1,
|
502 |
+
generator=self.generator.manual_seed(manual_seed),
|
503 |
+
).to(finish.device)
|
504 |
+
|
505 |
+
del scores
|
506 |
+
|
507 |
+
if not infer_text:
|
508 |
+
# idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
|
509 |
+
idx_next = idx_next.view(-1, self.num_vq)
|
510 |
+
finish_or = idx_next.eq(eos_token).any(1)
|
511 |
+
finish.logical_or_(finish_or)
|
512 |
+
del finish_or
|
513 |
+
inputs_ids_buf.narrow(1, progress, 1).copy_(idx_next.unsqueeze_(1))
|
514 |
+
else:
|
515 |
+
finish_or = idx_next.eq(eos_token).any(1)
|
516 |
+
finish.logical_or_(finish_or)
|
517 |
+
del finish_or
|
518 |
+
inputs_ids_buf.narrow(1, progress, 1).copy_(
|
519 |
+
idx_next.unsqueeze_(-1).expand(-1, -1, self.num_vq),
|
520 |
+
)
|
521 |
+
|
522 |
+
if i == 0 and finish.any():
|
523 |
+
self.logger.warning(
|
524 |
+
"unexpected end at index %s",
|
525 |
+
str([unexpected_idx.item() for unexpected_idx in finish.nonzero()]),
|
526 |
+
)
|
527 |
+
if ensure_non_empty and manual_seed is None:
|
528 |
+
if show_tqdm:
|
529 |
+
pbar.close()
|
530 |
+
self.logger.warning("regenerate in order to ensure non-empty")
|
531 |
+
del_all(attentions)
|
532 |
+
del_all(hiddens)
|
533 |
+
del (
|
534 |
+
start_idx,
|
535 |
+
end_idx,
|
536 |
+
finish,
|
537 |
+
temperature,
|
538 |
+
attention_mask_cache,
|
539 |
+
past_key_values,
|
540 |
+
idx_next,
|
541 |
+
inputs_ids_buf,
|
542 |
+
)
|
543 |
+
new_gen = self.generate(
|
544 |
+
emb,
|
545 |
+
inputs_ids,
|
546 |
+
old_temperature,
|
547 |
+
eos_token,
|
548 |
+
attention_mask,
|
549 |
+
max_new_token,
|
550 |
+
min_new_token,
|
551 |
+
logits_processors,
|
552 |
+
infer_text,
|
553 |
+
return_attn,
|
554 |
+
return_hidden,
|
555 |
+
stream,
|
556 |
+
show_tqdm,
|
557 |
+
ensure_non_empty,
|
558 |
+
stream_batch,
|
559 |
+
manual_seed,
|
560 |
+
context,
|
561 |
+
)
|
562 |
+
for result in new_gen:
|
563 |
+
yield result
|
564 |
+
del inputs_ids
|
565 |
+
return
|
566 |
+
|
567 |
+
del idx_next
|
568 |
+
progress += 1
|
569 |
+
inputs_ids = inputs_ids_buf.narrow(1, 0, progress)
|
570 |
+
|
571 |
+
not_finished = finish.logical_not().to(end_idx.device)
|
572 |
+
end_idx.add_(not_finished.int())
|
573 |
+
stream_iter += not_finished.any().int()
|
574 |
+
if stream:
|
575 |
+
if stream_iter > 0 and stream_iter % stream_batch == 0:
|
576 |
+
self.logger.debug("yield stream result, end: %d", end_idx)
|
577 |
+
yield self._prepare_generation_outputs(
|
578 |
+
inputs_ids,
|
579 |
+
start_idx,
|
580 |
+
end_idx,
|
581 |
+
attentions,
|
582 |
+
hiddens,
|
583 |
+
infer_text,
|
584 |
+
)
|
585 |
+
del not_finished
|
586 |
+
|
587 |
+
if finish.all() or context.get():
|
588 |
+
break
|
589 |
+
|
590 |
+
if pbar is not None:
|
591 |
+
pbar.update(1)
|
592 |
+
|
593 |
+
if pbar is not None:
|
594 |
+
pbar.close()
|
595 |
+
|
596 |
+
if not finish.all():
|
597 |
+
if context.get():
|
598 |
+
self.logger.warning("generation is interrupted")
|
599 |
+
else:
|
600 |
+
self.logger.warning(
|
601 |
+
f"incomplete result. hit max_new_token: {max_new_token}"
|
602 |
+
)
|
603 |
+
|
604 |
+
del finish, inputs_ids_buf
|
605 |
+
|
606 |
+
yield self._prepare_generation_outputs(
|
607 |
+
inputs_ids,
|
608 |
+
start_idx,
|
609 |
+
end_idx,
|
610 |
+
attentions,
|
611 |
+
hiddens,
|
612 |
+
infer_text,
|
613 |
+
)
|
ChatTTS/model/processors.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
|
4 |
+
|
5 |
+
|
6 |
+
class CustomRepetitionPenaltyLogitsProcessorRepeat:
|
7 |
+
|
8 |
+
def __init__(self, penalty: float, max_input_ids: int, past_window: int):
|
9 |
+
if not isinstance(penalty, float) or not (penalty > 0):
|
10 |
+
raise ValueError(
|
11 |
+
f"`penalty` has to be a strictly positive float, but is {penalty}"
|
12 |
+
)
|
13 |
+
|
14 |
+
self.penalty = penalty
|
15 |
+
self.max_input_ids = max_input_ids
|
16 |
+
self.past_window = past_window
|
17 |
+
|
18 |
+
def __call__(
|
19 |
+
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
|
20 |
+
) -> torch.FloatTensor:
|
21 |
+
if input_ids.size(1) > self.past_window:
|
22 |
+
input_ids = input_ids.narrow(1, -self.past_window, self.past_window)
|
23 |
+
freq = F.one_hot(input_ids, scores.size(1)).sum(1)
|
24 |
+
if freq.size(0) > self.max_input_ids:
|
25 |
+
freq.narrow(
|
26 |
+
0, self.max_input_ids, freq.size(0) - self.max_input_ids
|
27 |
+
).zero_()
|
28 |
+
alpha = torch.pow(self.penalty, freq)
|
29 |
+
scores = scores.contiguous()
|
30 |
+
inp = scores.multiply(alpha)
|
31 |
+
oth = scores.divide(alpha)
|
32 |
+
con = scores < 0
|
33 |
+
out = torch.where(con, inp, oth)
|
34 |
+
del inp, oth, scores, con, alpha
|
35 |
+
return out
|
36 |
+
|
37 |
+
|
38 |
+
def gen_logits(
|
39 |
+
num_code: int,
|
40 |
+
top_P=0.7,
|
41 |
+
top_K=20,
|
42 |
+
repetition_penalty=1.0,
|
43 |
+
):
|
44 |
+
logits_warpers = []
|
45 |
+
if top_P is not None:
|
46 |
+
logits_warpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
|
47 |
+
if top_K is not None:
|
48 |
+
logits_warpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))
|
49 |
+
|
50 |
+
logits_processors = []
|
51 |
+
if repetition_penalty is not None and repetition_penalty != 1:
|
52 |
+
logits_processors.append(
|
53 |
+
CustomRepetitionPenaltyLogitsProcessorRepeat(
|
54 |
+
repetition_penalty, num_code, 16
|
55 |
+
)
|
56 |
+
)
|
57 |
+
|
58 |
+
return logits_warpers, logits_processors
|
ChatTTS/model/speaker.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import lzma
|
2 |
+
from typing import List, Optional, Union
|
3 |
+
|
4 |
+
import pybase16384 as b14
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
|
10 |
+
class Speaker:
|
11 |
+
def __init__(self, dim: int, spk_cfg: str, device=torch.device("cpu")) -> None:
|
12 |
+
spk_stat = torch.from_numpy(
|
13 |
+
np.frombuffer(b14.decode_from_string(spk_cfg), dtype=np.float16).copy()
|
14 |
+
).to(device=device)
|
15 |
+
self.std, self.mean = spk_stat.requires_grad_(False).chunk(2)
|
16 |
+
self.dim = dim
|
17 |
+
|
18 |
+
def sample_random(self) -> str:
|
19 |
+
return self._encode(self._sample_random())
|
20 |
+
|
21 |
+
@torch.inference_mode()
|
22 |
+
def apply(
|
23 |
+
self,
|
24 |
+
emb: torch.Tensor,
|
25 |
+
spk_emb: Union[str, torch.Tensor],
|
26 |
+
input_ids: torch.Tensor,
|
27 |
+
spk_emb_ids: int,
|
28 |
+
device: torch.device,
|
29 |
+
inplace: bool = True,
|
30 |
+
) -> torch.Tensor:
|
31 |
+
if isinstance(spk_emb, str):
|
32 |
+
spk_emb_tensor = torch.from_numpy(self._decode(spk_emb))
|
33 |
+
else:
|
34 |
+
spk_emb_tensor = spk_emb
|
35 |
+
n = (
|
36 |
+
F.normalize(
|
37 |
+
spk_emb_tensor,
|
38 |
+
p=2.0,
|
39 |
+
dim=0,
|
40 |
+
eps=1e-12,
|
41 |
+
)
|
42 |
+
.to(device)
|
43 |
+
.unsqueeze_(0)
|
44 |
+
.expand(emb.size(0), -1)
|
45 |
+
.unsqueeze_(1)
|
46 |
+
.expand(emb.shape)
|
47 |
+
)
|
48 |
+
cond = input_ids.narrow(-1, 0, 1).eq(spk_emb_ids).expand(emb.shape)
|
49 |
+
out = torch.where(cond, n, emb, out=emb if inplace else None)
|
50 |
+
if inplace:
|
51 |
+
del cond, n
|
52 |
+
return out
|
53 |
+
|
54 |
+
@staticmethod
|
55 |
+
@torch.no_grad()
|
56 |
+
def decorate_code_prompts(
|
57 |
+
text: List[str],
|
58 |
+
prompt: str,
|
59 |
+
txt_smp: Optional[str],
|
60 |
+
spk_emb: Optional[str],
|
61 |
+
) -> List[str]:
|
62 |
+
for i, t in enumerate(text):
|
63 |
+
text[i] = (
|
64 |
+
t.replace("[Stts]", "")
|
65 |
+
.replace("[spk_emb]", "")
|
66 |
+
.replace("[empty_spk]", "")
|
67 |
+
.strip()
|
68 |
+
)
|
69 |
+
"""
|
70 |
+
see https://github.com/2noise/ChatTTS/issues/459
|
71 |
+
"""
|
72 |
+
|
73 |
+
if prompt:
|
74 |
+
text = [prompt + i for i in text]
|
75 |
+
|
76 |
+
txt_smp = "" if txt_smp is None else txt_smp
|
77 |
+
if spk_emb is not None:
|
78 |
+
text = [f"[Stts][spk_emb]{txt_smp}{i}[Ptts]" for i in text]
|
79 |
+
else:
|
80 |
+
text = [f"[Stts][empty_spk]{txt_smp}{i}[Ptts]" for i in text]
|
81 |
+
|
82 |
+
return text
|
83 |
+
|
84 |
+
@staticmethod
|
85 |
+
@torch.no_grad()
|
86 |
+
def decorate_text_prompts(text: List[str], prompt: str) -> List[str]:
|
87 |
+
return [f"[Sbreak]{i}[Pbreak]{prompt}" for i in text]
|
88 |
+
|
89 |
+
@staticmethod
|
90 |
+
@torch.no_grad()
|
91 |
+
def encode_prompt(prompt: torch.Tensor) -> str:
|
92 |
+
arr: np.ndarray = prompt.cpu().numpy().astype(np.uint16)
|
93 |
+
shp = arr.shape
|
94 |
+
assert len(shp) == 2, "prompt must be a 2D tensor"
|
95 |
+
s = b14.encode_to_string(
|
96 |
+
np.array(shp, dtype="<u2").tobytes()
|
97 |
+
+ lzma.compress(
|
98 |
+
arr.astype("<u2").tobytes(),
|
99 |
+
format=lzma.FORMAT_RAW,
|
100 |
+
filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
|
101 |
+
),
|
102 |
+
)
|
103 |
+
del arr
|
104 |
+
return s
|
105 |
+
|
106 |
+
@staticmethod
|
107 |
+
@torch.no_grad()
|
108 |
+
def decode_prompt(prompt: str) -> torch.Tensor:
|
109 |
+
dec = b14.decode_from_string(prompt)
|
110 |
+
shp = np.frombuffer(dec[:4], dtype="<u2")
|
111 |
+
p = np.frombuffer(
|
112 |
+
lzma.decompress(
|
113 |
+
dec[4:],
|
114 |
+
format=lzma.FORMAT_RAW,
|
115 |
+
filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
|
116 |
+
),
|
117 |
+
dtype="<u2",
|
118 |
+
).copy()
|
119 |
+
del dec
|
120 |
+
return torch.from_numpy(p.astype(np.int32)).view(*shp)
|
121 |
+
|
122 |
+
@torch.no_grad()
|
123 |
+
def _sample_random(self) -> torch.Tensor:
|
124 |
+
spk = (
|
125 |
+
torch.randn(self.dim, device=self.std.device, dtype=self.std.dtype)
|
126 |
+
.mul_(self.std)
|
127 |
+
.add_(self.mean)
|
128 |
+
)
|
129 |
+
return spk
|
130 |
+
|
131 |
+
@staticmethod
|
132 |
+
@torch.no_grad()
|
133 |
+
def _encode(spk_emb: torch.Tensor) -> str:
|
134 |
+
arr: np.ndarray = spk_emb.to(dtype=torch.float16, device="cpu").numpy()
|
135 |
+
s = b14.encode_to_string(
|
136 |
+
lzma.compress(
|
137 |
+
arr.tobytes(),
|
138 |
+
format=lzma.FORMAT_RAW,
|
139 |
+
filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
|
140 |
+
),
|
141 |
+
)
|
142 |
+
del arr
|
143 |
+
return s
|
144 |
+
|
145 |
+
@staticmethod
|
146 |
+
def _decode(spk_emb: str) -> np.ndarray:
|
147 |
+
return np.frombuffer(
|
148 |
+
lzma.decompress(
|
149 |
+
b14.decode_from_string(spk_emb),
|
150 |
+
format=lzma.FORMAT_RAW,
|
151 |
+
filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
|
152 |
+
),
|
153 |
+
dtype=np.float16,
|
154 |
+
).copy()
|
ChatTTS/model/tokenizer.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
4 |
+
"""
|
5 |
+
https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning
|
6 |
+
"""
|
7 |
+
|
8 |
+
from typing import List, Tuple, Optional, Union
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from transformers import BertTokenizerFast
|
12 |
+
|
13 |
+
from ..utils import del_all
|
14 |
+
|
15 |
+
|
16 |
+
class Tokenizer:
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
tokenizer_path: torch.serialization.FILE_LIKE,
|
20 |
+
):
|
21 |
+
"""
|
22 |
+
tokenizer: BertTokenizerFast = torch.load(
|
23 |
+
tokenizer_path, map_location=device, mmap=True
|
24 |
+
)
|
25 |
+
# tokenizer.save_pretrained("asset/tokenizer", legacy_format=False)
|
26 |
+
"""
|
27 |
+
tokenizer: BertTokenizerFast = BertTokenizerFast.from_pretrained(tokenizer_path)
|
28 |
+
self._tokenizer = tokenizer
|
29 |
+
|
30 |
+
self.len = len(tokenizer)
|
31 |
+
self.spk_emb_ids = tokenizer.convert_tokens_to_ids("[spk_emb]")
|
32 |
+
self.break_0_ids = tokenizer.convert_tokens_to_ids("[break_0]")
|
33 |
+
self.eos_token = tokenizer.convert_tokens_to_ids("[Ebreak]")
|
34 |
+
|
35 |
+
@torch.inference_mode()
|
36 |
+
def encode(
|
37 |
+
self,
|
38 |
+
text: List[str],
|
39 |
+
num_vq: int,
|
40 |
+
prompt: Optional[torch.Tensor] = None,
|
41 |
+
device="cpu",
|
42 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
43 |
+
|
44 |
+
input_ids_lst = []
|
45 |
+
attention_mask_lst = []
|
46 |
+
max_input_ids_len = -1
|
47 |
+
max_attention_mask_len = -1
|
48 |
+
prompt_size = 0
|
49 |
+
|
50 |
+
if prompt is not None:
|
51 |
+
assert prompt.size(0) == num_vq, "prompt dim 0 must equal to num_vq"
|
52 |
+
prompt_size = prompt.size(1)
|
53 |
+
|
54 |
+
# avoid random speaker embedding of tokenizer in the other dims
|
55 |
+
for t in text:
|
56 |
+
x = self._tokenizer.encode_plus(
|
57 |
+
t, return_tensors="pt", add_special_tokens=False, padding=True
|
58 |
+
)
|
59 |
+
input_ids_lst.append(x["input_ids"].squeeze_(0))
|
60 |
+
attention_mask_lst.append(x["attention_mask"].squeeze_(0))
|
61 |
+
del_all(x)
|
62 |
+
ids_sz = input_ids_lst[-1].size(0)
|
63 |
+
if ids_sz > max_input_ids_len:
|
64 |
+
max_input_ids_len = ids_sz
|
65 |
+
attn_sz = attention_mask_lst[-1].size(0)
|
66 |
+
if attn_sz > max_attention_mask_len:
|
67 |
+
max_attention_mask_len = attn_sz
|
68 |
+
|
69 |
+
if prompt is not None:
|
70 |
+
max_input_ids_len += prompt_size
|
71 |
+
max_attention_mask_len += prompt_size
|
72 |
+
|
73 |
+
input_ids = torch.zeros(
|
74 |
+
len(input_ids_lst),
|
75 |
+
max_input_ids_len,
|
76 |
+
device=device,
|
77 |
+
dtype=input_ids_lst[0].dtype,
|
78 |
+
)
|
79 |
+
for i in range(len(input_ids_lst)):
|
80 |
+
input_ids.narrow(0, i, 1).narrow(
|
81 |
+
1,
|
82 |
+
max_input_ids_len - prompt_size - input_ids_lst[i].size(0),
|
83 |
+
input_ids_lst[i].size(0),
|
84 |
+
).copy_(
|
85 |
+
input_ids_lst[i]
|
86 |
+
) # left padding
|
87 |
+
del_all(input_ids_lst)
|
88 |
+
|
89 |
+
attention_mask = torch.zeros(
|
90 |
+
len(attention_mask_lst),
|
91 |
+
max_attention_mask_len,
|
92 |
+
device=device,
|
93 |
+
dtype=attention_mask_lst[0].dtype,
|
94 |
+
)
|
95 |
+
for i in range(len(attention_mask_lst)):
|
96 |
+
attn = attention_mask.narrow(0, i, 1)
|
97 |
+
attn.narrow(
|
98 |
+
1,
|
99 |
+
max_attention_mask_len - prompt_size - attention_mask_lst[i].size(0),
|
100 |
+
attention_mask_lst[i].size(0),
|
101 |
+
).copy_(
|
102 |
+
attention_mask_lst[i]
|
103 |
+
) # left padding
|
104 |
+
if prompt_size > 0:
|
105 |
+
attn.narrow(
|
106 |
+
1,
|
107 |
+
max_attention_mask_len - prompt_size,
|
108 |
+
prompt_size,
|
109 |
+
).fill_(1)
|
110 |
+
del_all(attention_mask_lst)
|
111 |
+
|
112 |
+
text_mask = attention_mask.bool()
|
113 |
+
new_input_ids = input_ids.unsqueeze_(-1).expand(-1, -1, num_vq).clone()
|
114 |
+
del input_ids
|
115 |
+
|
116 |
+
if prompt_size > 0:
|
117 |
+
text_mask.narrow(1, max_input_ids_len - prompt_size, prompt_size).fill_(0)
|
118 |
+
prompt_t = prompt.t().unsqueeze_(0).expand(new_input_ids.size(0), -1, -1)
|
119 |
+
new_input_ids.narrow(
|
120 |
+
1,
|
121 |
+
max_input_ids_len - prompt_size,
|
122 |
+
prompt_size,
|
123 |
+
).copy_(prompt_t)
|
124 |
+
del prompt_t
|
125 |
+
|
126 |
+
return new_input_ids, attention_mask, text_mask
|
127 |
+
|
128 |
+
@torch.inference_mode
|
129 |
+
def decode(
|
130 |
+
self,
|
131 |
+
sequences: Union[List[int], List[List[int]]],
|
132 |
+
skip_special_tokens: bool = False,
|
133 |
+
clean_up_tokenization_spaces: bool = None,
|
134 |
+
**kwargs,
|
135 |
+
):
|
136 |
+
return self._tokenizer.batch_decode(
|
137 |
+
sequences, skip_special_tokens, clean_up_tokenization_spaces, **kwargs
|
138 |
+
)
|
ChatTTS/model/velocity/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .llm import LLM
|
2 |
+
from .sampling_params import SamplingParams
|
ChatTTS/model/velocity/block_manager.py
ADDED
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""A block manager that manages token blocks."""
|
2 |
+
|
3 |
+
import enum
|
4 |
+
from typing import Dict, List, Optional, Set, Tuple
|
5 |
+
|
6 |
+
from vllm.block import PhysicalTokenBlock
|
7 |
+
from .sequence import Sequence, SequenceGroup, SequenceStatus
|
8 |
+
from vllm.utils import Device
|
9 |
+
|
10 |
+
# Mapping: logical block number -> physical block.
|
11 |
+
BlockTable = List[PhysicalTokenBlock]
|
12 |
+
|
13 |
+
|
14 |
+
class BlockAllocator:
|
15 |
+
"""Manages free physical token blocks for a device.
|
16 |
+
|
17 |
+
The allocator maintains a list of free blocks and allocates a block when
|
18 |
+
requested. When a block is freed, its reference count is decremented. If
|
19 |
+
the reference count becomes zero, the block is added back to the free list.
|
20 |
+
"""
|
21 |
+
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
device: Device,
|
25 |
+
block_size: int,
|
26 |
+
num_blocks: int,
|
27 |
+
) -> None:
|
28 |
+
self.device = device
|
29 |
+
self.block_size = block_size
|
30 |
+
self.num_blocks = num_blocks
|
31 |
+
|
32 |
+
# Initialize the free blocks.
|
33 |
+
self.free_blocks: BlockTable = []
|
34 |
+
for i in range(num_blocks):
|
35 |
+
block = PhysicalTokenBlock(
|
36 |
+
device=device, block_number=i, block_size=block_size
|
37 |
+
)
|
38 |
+
self.free_blocks.append(block)
|
39 |
+
|
40 |
+
def allocate(self) -> PhysicalTokenBlock:
|
41 |
+
if not self.free_blocks:
|
42 |
+
raise ValueError("Out of memory! No free blocks are available.")
|
43 |
+
block = self.free_blocks.pop()
|
44 |
+
block.ref_count = 1
|
45 |
+
return block
|
46 |
+
|
47 |
+
def free(self, block: PhysicalTokenBlock) -> None:
|
48 |
+
if block.ref_count == 0:
|
49 |
+
raise ValueError(f"Double free! {block} is already freed.")
|
50 |
+
block.ref_count -= 1
|
51 |
+
if block.ref_count == 0:
|
52 |
+
self.free_blocks.append(block)
|
53 |
+
|
54 |
+
def get_num_free_blocks(self) -> int:
|
55 |
+
return len(self.free_blocks)
|
56 |
+
|
57 |
+
|
58 |
+
class AllocStatus(enum.Enum):
|
59 |
+
"""Result for BlockSpaceManager.can_allocate
|
60 |
+
|
61 |
+
1. Ok: seq_group can be allocated now.
|
62 |
+
2. Later: seq_group cannot be allocated.
|
63 |
+
The capacity of allocator is larger than seq_group required.
|
64 |
+
3. Never: seq_group can never be allocated.
|
65 |
+
The seq_group is too large to allocated in GPU.
|
66 |
+
"""
|
67 |
+
|
68 |
+
OK = enum.auto()
|
69 |
+
LATER = enum.auto()
|
70 |
+
NEVER = enum.auto()
|
71 |
+
|
72 |
+
|
73 |
+
class BlockSpaceManager:
|
74 |
+
"""Manages the mapping between logical and physical token blocks."""
|
75 |
+
|
76 |
+
def __init__(
|
77 |
+
self,
|
78 |
+
block_size: int,
|
79 |
+
num_gpu_blocks: int,
|
80 |
+
num_cpu_blocks: int,
|
81 |
+
watermark: float = 0.01,
|
82 |
+
sliding_window: Optional[int] = None,
|
83 |
+
) -> None:
|
84 |
+
self.block_size = block_size
|
85 |
+
self.num_total_gpu_blocks = num_gpu_blocks
|
86 |
+
self.num_total_cpu_blocks = num_cpu_blocks
|
87 |
+
|
88 |
+
self.block_sliding_window = None
|
89 |
+
if sliding_window is not None:
|
90 |
+
assert sliding_window % block_size == 0, (sliding_window, block_size)
|
91 |
+
self.block_sliding_window = sliding_window // block_size
|
92 |
+
|
93 |
+
self.watermark = watermark
|
94 |
+
assert watermark >= 0.0
|
95 |
+
|
96 |
+
self.watermark_blocks = int(watermark * num_gpu_blocks)
|
97 |
+
self.gpu_allocator = BlockAllocator(Device.GPU, block_size, num_gpu_blocks)
|
98 |
+
self.cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks)
|
99 |
+
# Mapping: seq_id -> BlockTable.
|
100 |
+
self.block_tables: Dict[int, BlockTable] = {}
|
101 |
+
|
102 |
+
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
|
103 |
+
# FIXME(woosuk): Here we assume that all sequences in the group share
|
104 |
+
# the same prompt. This may not be true for preempted sequences.
|
105 |
+
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
|
106 |
+
num_required_blocks = len(seq.logical_token_blocks)
|
107 |
+
if self.block_sliding_window is not None:
|
108 |
+
num_required_blocks = min(num_required_blocks, self.block_sliding_window)
|
109 |
+
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
|
110 |
+
|
111 |
+
# Use watermark to avoid frequent cache eviction.
|
112 |
+
if self.num_total_gpu_blocks - num_required_blocks < self.watermark_blocks:
|
113 |
+
return AllocStatus.NEVER
|
114 |
+
if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
|
115 |
+
return AllocStatus.OK
|
116 |
+
else:
|
117 |
+
return AllocStatus.LATER
|
118 |
+
|
119 |
+
def allocate(self, seq_group: SequenceGroup) -> None:
|
120 |
+
# NOTE: Here we assume that all sequences in the group have the same
|
121 |
+
# prompt.
|
122 |
+
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
|
123 |
+
|
124 |
+
# Allocate new physical token blocks that will store the prompt tokens.
|
125 |
+
block_table: BlockTable = []
|
126 |
+
for logical_idx in range(len(seq.logical_token_blocks)):
|
127 |
+
if (
|
128 |
+
self.block_sliding_window is not None
|
129 |
+
and logical_idx >= self.block_sliding_window
|
130 |
+
):
|
131 |
+
block = block_table[logical_idx % self.block_sliding_window]
|
132 |
+
else:
|
133 |
+
block = self.gpu_allocator.allocate()
|
134 |
+
# Set the reference counts of the token blocks.
|
135 |
+
block.ref_count = seq_group.num_seqs()
|
136 |
+
block_table.append(block)
|
137 |
+
|
138 |
+
# Assign the block table for each sequence.
|
139 |
+
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
|
140 |
+
self.block_tables[seq.seq_id] = block_table.copy()
|
141 |
+
|
142 |
+
def can_append_slot(self, seq_group: SequenceGroup) -> bool:
|
143 |
+
# Simple heuristic: If there is at least one free block
|
144 |
+
# for each sequence, we can append.
|
145 |
+
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
|
146 |
+
num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
|
147 |
+
return num_seqs <= num_free_gpu_blocks
|
148 |
+
|
149 |
+
def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]:
|
150 |
+
"""Allocate a physical slot for a new token."""
|
151 |
+
logical_blocks = seq.logical_token_blocks
|
152 |
+
block_table = self.block_tables[seq.seq_id]
|
153 |
+
|
154 |
+
if len(block_table) < len(logical_blocks):
|
155 |
+
if (
|
156 |
+
self.block_sliding_window
|
157 |
+
and len(block_table) >= self.block_sliding_window
|
158 |
+
):
|
159 |
+
# re-use a block
|
160 |
+
block_table.append(
|
161 |
+
block_table[len(block_table) % self.block_sliding_window]
|
162 |
+
)
|
163 |
+
else:
|
164 |
+
# The sequence has a new logical block.
|
165 |
+
# Allocate a new physical block.
|
166 |
+
block = self.gpu_allocator.allocate()
|
167 |
+
block_table.append(block)
|
168 |
+
return None
|
169 |
+
|
170 |
+
# We want to append the token to the last physical block.
|
171 |
+
last_block = block_table[-1]
|
172 |
+
assert last_block.device == Device.GPU
|
173 |
+
if last_block.ref_count == 1:
|
174 |
+
# Not shared with other sequences. Appendable.
|
175 |
+
return None
|
176 |
+
else:
|
177 |
+
# The last block is shared with other sequences.
|
178 |
+
# Copy on Write: Allocate a new block and copy the tokens.
|
179 |
+
new_block = self.gpu_allocator.allocate()
|
180 |
+
block_table[-1] = new_block
|
181 |
+
self.gpu_allocator.free(last_block)
|
182 |
+
return last_block.block_number, new_block.block_number
|
183 |
+
|
184 |
+
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
|
185 |
+
# NOTE: fork does not allocate a new physical block.
|
186 |
+
# Thus, it is always safe from OOM.
|
187 |
+
src_block_table = self.block_tables[parent_seq.seq_id]
|
188 |
+
self.block_tables[child_seq.seq_id] = src_block_table.copy()
|
189 |
+
for block in src_block_table:
|
190 |
+
block.ref_count += 1
|
191 |
+
|
192 |
+
def _get_physical_blocks(
|
193 |
+
self, seq_group: SequenceGroup
|
194 |
+
) -> List[PhysicalTokenBlock]:
|
195 |
+
# NOTE: Here, we assume that the physical blocks are only shared by
|
196 |
+
# the sequences in the same group.
|
197 |
+
blocks: Set[PhysicalTokenBlock] = set()
|
198 |
+
for seq in seq_group.get_seqs():
|
199 |
+
if seq.is_finished():
|
200 |
+
continue
|
201 |
+
blocks.update(self.block_tables[seq.seq_id])
|
202 |
+
return list(blocks)
|
203 |
+
|
204 |
+
def can_swap_in(self, seq_group: SequenceGroup) -> bool:
|
205 |
+
blocks = self._get_physical_blocks(seq_group)
|
206 |
+
num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
|
207 |
+
num_free_blocks = self.gpu_allocator.get_num_free_blocks()
|
208 |
+
# NOTE: Conservatively, we assume that every sequence will allocate
|
209 |
+
# at least one free block right after the swap-in.
|
210 |
+
# NOTE: This should match the logic in can_append_slot().
|
211 |
+
num_required_blocks = len(blocks) + num_swapped_seqs
|
212 |
+
return num_free_blocks - num_required_blocks >= self.watermark_blocks
|
213 |
+
|
214 |
+
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
|
215 |
+
# CPU block -> GPU block.
|
216 |
+
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
|
217 |
+
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
|
218 |
+
new_block_table: BlockTable = []
|
219 |
+
block_table = self.block_tables[seq.seq_id]
|
220 |
+
|
221 |
+
for cpu_block in block_table:
|
222 |
+
if cpu_block in mapping:
|
223 |
+
gpu_block = mapping[cpu_block]
|
224 |
+
gpu_block.ref_count += 1
|
225 |
+
else:
|
226 |
+
gpu_block = self.gpu_allocator.allocate()
|
227 |
+
mapping[cpu_block] = gpu_block
|
228 |
+
new_block_table.append(gpu_block)
|
229 |
+
# Free the CPU block swapped in to GPU.
|
230 |
+
self.cpu_allocator.free(cpu_block)
|
231 |
+
self.block_tables[seq.seq_id] = new_block_table
|
232 |
+
|
233 |
+
block_number_mapping = {
|
234 |
+
cpu_block.block_number: gpu_block.block_number
|
235 |
+
for cpu_block, gpu_block in mapping.items()
|
236 |
+
}
|
237 |
+
return block_number_mapping
|
238 |
+
|
239 |
+
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
|
240 |
+
blocks = self._get_physical_blocks(seq_group)
|
241 |
+
return len(blocks) <= self.cpu_allocator.get_num_free_blocks()
|
242 |
+
|
243 |
+
def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
|
244 |
+
# GPU block -> CPU block.
|
245 |
+
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
|
246 |
+
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
247 |
+
new_block_table: BlockTable = []
|
248 |
+
block_table = self.block_tables[seq.seq_id]
|
249 |
+
|
250 |
+
for gpu_block in block_table:
|
251 |
+
if gpu_block in mapping:
|
252 |
+
cpu_block = mapping[gpu_block]
|
253 |
+
cpu_block.ref_count += 1
|
254 |
+
else:
|
255 |
+
cpu_block = self.cpu_allocator.allocate()
|
256 |
+
mapping[gpu_block] = cpu_block
|
257 |
+
new_block_table.append(cpu_block)
|
258 |
+
# Free the GPU block swapped out to CPU.
|
259 |
+
self.gpu_allocator.free(gpu_block)
|
260 |
+
self.block_tables[seq.seq_id] = new_block_table
|
261 |
+
|
262 |
+
block_number_mapping = {
|
263 |
+
gpu_block.block_number: cpu_block.block_number
|
264 |
+
for gpu_block, cpu_block in mapping.items()
|
265 |
+
}
|
266 |
+
return block_number_mapping
|
267 |
+
|
268 |
+
def _free_block_table(self, block_table: BlockTable) -> None:
|
269 |
+
for block in set(block_table):
|
270 |
+
if block.device == Device.GPU:
|
271 |
+
self.gpu_allocator.free(block)
|
272 |
+
else:
|
273 |
+
self.cpu_allocator.free(block)
|
274 |
+
|
275 |
+
def free(self, seq: Sequence) -> None:
|
276 |
+
if seq.seq_id not in self.block_tables:
|
277 |
+
# Already freed or haven't been scheduled yet.
|
278 |
+
return
|
279 |
+
block_table = self.block_tables[seq.seq_id]
|
280 |
+
self._free_block_table(block_table)
|
281 |
+
del self.block_tables[seq.seq_id]
|
282 |
+
|
283 |
+
def reset(self) -> None:
|
284 |
+
for block_table in self.block_tables.values():
|
285 |
+
self._free_block_table(block_table)
|
286 |
+
self.block_tables.clear()
|
287 |
+
|
288 |
+
def get_block_table(self, seq: Sequence) -> List[int]:
|
289 |
+
block_table = self.block_tables[seq.seq_id]
|
290 |
+
return [block.block_number for block in block_table]
|
291 |
+
|
292 |
+
def get_num_free_gpu_blocks(self) -> int:
|
293 |
+
return self.gpu_allocator.get_num_free_blocks()
|
294 |
+
|
295 |
+
def get_num_free_cpu_blocks(self) -> int:
|
296 |
+
return self.cpu_allocator.get_num_free_blocks()
|
ChatTTS/model/velocity/configs.py
ADDED
@@ -0,0 +1,865 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Union, Tuple
|
2 |
+
import os
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from transformers import PretrainedConfig
|
6 |
+
|
7 |
+
from vllm.logger import init_logger
|
8 |
+
from vllm.transformers_utils.config import get_config
|
9 |
+
from vllm.utils import get_cpu_memory, is_hip
|
10 |
+
|
11 |
+
import argparse
|
12 |
+
import dataclasses
|
13 |
+
from dataclasses import dataclass
|
14 |
+
|
15 |
+
|
16 |
+
logger = init_logger(__name__)
|
17 |
+
|
18 |
+
_GB = 1 << 30
|
19 |
+
|
20 |
+
|
21 |
+
class ModelConfig:
|
22 |
+
"""Configuration for the model.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
model: Name or path of the huggingface model to use.
|
26 |
+
tokenizer: Name or path of the huggingface tokenizer to use.
|
27 |
+
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
|
28 |
+
available, and "slow" will always use the slow tokenizer.
|
29 |
+
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
|
30 |
+
downloading the model and tokenizer.
|
31 |
+
download_dir: Directory to download and load the weights, default to the
|
32 |
+
default cache directory of huggingface.
|
33 |
+
load_format: The format of the model weights to load:
|
34 |
+
"auto" will try to load the weights in the safetensors format and
|
35 |
+
fall back to the pytorch bin format if safetensors format is
|
36 |
+
not available.
|
37 |
+
"pt" will load the weights in the pytorch bin format.
|
38 |
+
"safetensors" will load the weights in the safetensors format.
|
39 |
+
"npcache" will load the weights in pytorch format and store
|
40 |
+
a numpy cache to speed up the loading.
|
41 |
+
"dummy" will initialize the weights with random values, which is
|
42 |
+
mainly for profiling.
|
43 |
+
dtype: Data type for model weights and activations. The "auto" option
|
44 |
+
will use FP16 precision for FP32 and FP16 models, and BF16 precision
|
45 |
+
for BF16 models.
|
46 |
+
seed: Random seed for reproducibility.
|
47 |
+
revision: The specific model version to use. It can be a branch name,
|
48 |
+
a tag name, or a commit id. If unspecified, will use the default
|
49 |
+
version.
|
50 |
+
tokenizer_revision: The specific tokenizer version to use. It can be a
|
51 |
+
branch name, a tag name, or a commit id. If unspecified, will use
|
52 |
+
the default version.
|
53 |
+
max_model_len: Maximum length of a sequence (including prompt and
|
54 |
+
output). If None, will be derived from the model.
|
55 |
+
quantization: Quantization method that was used to quantize the model
|
56 |
+
weights. If None, we assume the model weights are not quantized.
|
57 |
+
enforce_eager: Whether to enforce eager execution. If True, we will
|
58 |
+
disable CUDA graph and always execute the model in eager mode.
|
59 |
+
If False, we will use CUDA graph and eager execution in hybrid.
|
60 |
+
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
|
61 |
+
When a sequence has context length larger than this, we fall back
|
62 |
+
to eager mode.
|
63 |
+
"""
|
64 |
+
|
65 |
+
def __init__(
|
66 |
+
self,
|
67 |
+
model: str,
|
68 |
+
tokenizer: str,
|
69 |
+
tokenizer_mode: str,
|
70 |
+
trust_remote_code: bool,
|
71 |
+
download_dir: Optional[str],
|
72 |
+
load_format: str,
|
73 |
+
dtype: Union[str, torch.dtype],
|
74 |
+
seed: int,
|
75 |
+
revision: Optional[str] = None,
|
76 |
+
tokenizer_revision: Optional[str] = None,
|
77 |
+
max_model_len: Optional[int] = None,
|
78 |
+
quantization: Optional[str] = None,
|
79 |
+
enforce_eager: bool = False,
|
80 |
+
max_context_len_to_capture: Optional[int] = None,
|
81 |
+
num_audio_tokens: int = 1024,
|
82 |
+
num_text_tokens: int = 80,
|
83 |
+
) -> None:
|
84 |
+
self.model = model
|
85 |
+
self.tokenizer = tokenizer
|
86 |
+
self.tokenizer_mode = tokenizer_mode
|
87 |
+
self.trust_remote_code = trust_remote_code
|
88 |
+
self.download_dir = download_dir
|
89 |
+
self.load_format = load_format
|
90 |
+
self.seed = seed
|
91 |
+
self.revision = revision
|
92 |
+
self.tokenizer_revision = tokenizer_revision
|
93 |
+
self.quantization = quantization
|
94 |
+
self.enforce_eager = enforce_eager
|
95 |
+
self.max_context_len_to_capture = max_context_len_to_capture
|
96 |
+
self.num_audio_tokens = num_audio_tokens
|
97 |
+
self.num_text_tokens = num_text_tokens
|
98 |
+
|
99 |
+
if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true":
|
100 |
+
# download model from ModelScope hub,
|
101 |
+
# lazy import so that modelscope is not required for normal use.
|
102 |
+
from modelscope.hub.snapshot_download import (
|
103 |
+
snapshot_download,
|
104 |
+
) # pylint: disable=C
|
105 |
+
|
106 |
+
model_path = snapshot_download(
|
107 |
+
model_id=model, cache_dir=download_dir, revision=revision
|
108 |
+
)
|
109 |
+
self.model = model_path
|
110 |
+
self.download_dir = model_path
|
111 |
+
self.tokenizer = model_path
|
112 |
+
|
113 |
+
self.hf_config = get_config(self.model, trust_remote_code, revision)
|
114 |
+
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
|
115 |
+
self.max_model_len = _get_and_verify_max_len(self.hf_config, max_model_len)
|
116 |
+
self._verify_load_format()
|
117 |
+
self._verify_tokenizer_mode()
|
118 |
+
self._verify_quantization()
|
119 |
+
self._verify_cuda_graph()
|
120 |
+
|
121 |
+
def _verify_load_format(self) -> None:
|
122 |
+
load_format = self.load_format.lower()
|
123 |
+
supported_load_format = ["auto", "pt", "safetensors", "npcache", "dummy"]
|
124 |
+
rocm_not_supported_load_format = []
|
125 |
+
if load_format not in supported_load_format:
|
126 |
+
raise ValueError(
|
127 |
+
f"Unknown load format: {self.load_format}. Must be one of "
|
128 |
+
"'auto', 'pt', 'safetensors', 'npcache', or 'dummy'."
|
129 |
+
)
|
130 |
+
if is_hip() and load_format in rocm_not_supported_load_format:
|
131 |
+
rocm_supported_load_format = [
|
132 |
+
f
|
133 |
+
for f in supported_load_format
|
134 |
+
if (f not in rocm_not_supported_load_format)
|
135 |
+
]
|
136 |
+
raise ValueError(
|
137 |
+
f"load format '{load_format}' is not supported in ROCm. "
|
138 |
+
f"Supported load format are "
|
139 |
+
f"{rocm_supported_load_format}"
|
140 |
+
)
|
141 |
+
|
142 |
+
# TODO: Remove this check once HF updates the pt weights of Mixtral.
|
143 |
+
architectures = getattr(self.hf_config, "architectures", [])
|
144 |
+
if "MixtralForCausalLM" in architectures and load_format == "pt":
|
145 |
+
raise ValueError(
|
146 |
+
"Currently, the 'pt' format is not supported for Mixtral. "
|
147 |
+
"Please use the 'safetensors' format instead. "
|
148 |
+
)
|
149 |
+
self.load_format = load_format
|
150 |
+
|
151 |
+
def _verify_tokenizer_mode(self) -> None:
|
152 |
+
tokenizer_mode = self.tokenizer_mode.lower()
|
153 |
+
if tokenizer_mode not in ["auto", "slow"]:
|
154 |
+
raise ValueError(
|
155 |
+
f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
|
156 |
+
"either 'auto' or 'slow'."
|
157 |
+
)
|
158 |
+
self.tokenizer_mode = tokenizer_mode
|
159 |
+
|
160 |
+
def _verify_quantization(self) -> None:
|
161 |
+
supported_quantization = ["awq", "gptq", "squeezellm"]
|
162 |
+
rocm_not_supported_quantization = ["awq"]
|
163 |
+
if self.quantization is not None:
|
164 |
+
self.quantization = self.quantization.lower()
|
165 |
+
|
166 |
+
# Parse quantization method from the HF model config, if available.
|
167 |
+
hf_quant_config = getattr(self.hf_config, "quantization_config", None)
|
168 |
+
if hf_quant_config is not None:
|
169 |
+
hf_quant_method = str(hf_quant_config["quant_method"]).lower()
|
170 |
+
if self.quantization is None:
|
171 |
+
self.quantization = hf_quant_method
|
172 |
+
elif self.quantization != hf_quant_method:
|
173 |
+
raise ValueError(
|
174 |
+
"Quantization method specified in the model config "
|
175 |
+
f"({hf_quant_method}) does not match the quantization "
|
176 |
+
f"method specified in the `quantization` argument "
|
177 |
+
f"({self.quantization})."
|
178 |
+
)
|
179 |
+
|
180 |
+
if self.quantization is not None:
|
181 |
+
if self.quantization not in supported_quantization:
|
182 |
+
raise ValueError(
|
183 |
+
f"Unknown quantization method: {self.quantization}. Must "
|
184 |
+
f"be one of {supported_quantization}."
|
185 |
+
)
|
186 |
+
if is_hip() and self.quantization in rocm_not_supported_quantization:
|
187 |
+
raise ValueError(
|
188 |
+
f"{self.quantization} quantization is currently not supported "
|
189 |
+
f"in ROCm."
|
190 |
+
)
|
191 |
+
logger.warning(
|
192 |
+
f"{self.quantization} quantization is not fully "
|
193 |
+
"optimized yet. The speed can be slower than "
|
194 |
+
"non-quantized models."
|
195 |
+
)
|
196 |
+
|
197 |
+
def _verify_cuda_graph(self) -> None:
|
198 |
+
if self.max_context_len_to_capture is None:
|
199 |
+
self.max_context_len_to_capture = self.max_model_len
|
200 |
+
self.max_context_len_to_capture = min(
|
201 |
+
self.max_context_len_to_capture, self.max_model_len
|
202 |
+
)
|
203 |
+
|
204 |
+
def verify_with_parallel_config(
|
205 |
+
self,
|
206 |
+
parallel_config: "ParallelConfig",
|
207 |
+
) -> None:
|
208 |
+
total_num_attention_heads = self.hf_config.num_attention_heads
|
209 |
+
tensor_parallel_size = parallel_config.tensor_parallel_size
|
210 |
+
if total_num_attention_heads % tensor_parallel_size != 0:
|
211 |
+
raise ValueError(
|
212 |
+
f"Total number of attention heads ({total_num_attention_heads})"
|
213 |
+
" must be divisible by tensor parallel size "
|
214 |
+
f"({tensor_parallel_size})."
|
215 |
+
)
|
216 |
+
|
217 |
+
total_num_hidden_layers = self.hf_config.num_hidden_layers
|
218 |
+
pipeline_parallel_size = parallel_config.pipeline_parallel_size
|
219 |
+
if total_num_hidden_layers % pipeline_parallel_size != 0:
|
220 |
+
raise ValueError(
|
221 |
+
f"Total number of hidden layers ({total_num_hidden_layers}) "
|
222 |
+
"must be divisible by pipeline parallel size "
|
223 |
+
f"({pipeline_parallel_size})."
|
224 |
+
)
|
225 |
+
|
226 |
+
def get_sliding_window(self) -> Optional[int]:
|
227 |
+
return getattr(self.hf_config, "sliding_window", None)
|
228 |
+
|
229 |
+
def get_vocab_size(self) -> int:
|
230 |
+
return self.hf_config.vocab_size
|
231 |
+
|
232 |
+
def get_hidden_size(self) -> int:
|
233 |
+
return self.hf_config.hidden_size
|
234 |
+
|
235 |
+
def get_head_size(self) -> int:
|
236 |
+
# FIXME(woosuk): This may not be true for all models.
|
237 |
+
return self.hf_config.hidden_size // self.hf_config.num_attention_heads
|
238 |
+
|
239 |
+
def get_total_num_kv_heads(self) -> int:
|
240 |
+
"""Returns the total number of KV heads."""
|
241 |
+
# For GPTBigCode & Falcon:
|
242 |
+
# NOTE: for falcon, when new_decoder_architecture is True, the
|
243 |
+
# multi_query flag is ignored and we use n_head_kv for the number of
|
244 |
+
# KV heads.
|
245 |
+
falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
|
246 |
+
new_decoder_arch_falcon = (
|
247 |
+
self.hf_config.model_type in falcon_model_types
|
248 |
+
and getattr(self.hf_config, "new_decoder_architecture", False)
|
249 |
+
)
|
250 |
+
if not new_decoder_arch_falcon and getattr(
|
251 |
+
self.hf_config, "multi_query", False
|
252 |
+
):
|
253 |
+
# Multi-query attention, only one KV head.
|
254 |
+
# Currently, tensor parallelism is not supported in this case.
|
255 |
+
return 1
|
256 |
+
|
257 |
+
attributes = [
|
258 |
+
# For Falcon:
|
259 |
+
"n_head_kv",
|
260 |
+
"num_kv_heads",
|
261 |
+
# For LLaMA-2:
|
262 |
+
"num_key_value_heads",
|
263 |
+
# For ChatGLM:
|
264 |
+
"multi_query_group_num",
|
265 |
+
]
|
266 |
+
for attr in attributes:
|
267 |
+
num_kv_heads = getattr(self.hf_config, attr, None)
|
268 |
+
if num_kv_heads is not None:
|
269 |
+
return num_kv_heads
|
270 |
+
|
271 |
+
# For non-grouped-query attention models, the number of KV heads is
|
272 |
+
# equal to the number of attention heads.
|
273 |
+
return self.hf_config.num_attention_heads
|
274 |
+
|
275 |
+
def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
|
276 |
+
"""Returns the number of KV heads per GPU."""
|
277 |
+
total_num_kv_heads = self.get_total_num_kv_heads()
|
278 |
+
# If tensor parallelism is used, we divide the number of KV heads by
|
279 |
+
# the tensor parallel size. We will replicate the KV heads in the
|
280 |
+
# case where the number of KV heads is smaller than the tensor
|
281 |
+
# parallel size so each GPU has at least one KV head.
|
282 |
+
return max(1, total_num_kv_heads // parallel_config.tensor_parallel_size)
|
283 |
+
|
284 |
+
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
|
285 |
+
total_num_hidden_layers = self.hf_config.num_hidden_layers
|
286 |
+
return total_num_hidden_layers // parallel_config.pipeline_parallel_size
|
287 |
+
|
288 |
+
|
289 |
+
class CacheConfig:
|
290 |
+
"""Configuration for the KV cache.
|
291 |
+
|
292 |
+
Args:
|
293 |
+
block_size: Size of a cache block in number of tokens.
|
294 |
+
gpu_memory_utilization: Fraction of GPU memory to use for the
|
295 |
+
vLLM execution.
|
296 |
+
swap_space: Size of the CPU swap space per GPU (in GiB).
|
297 |
+
"""
|
298 |
+
|
299 |
+
def __init__(
|
300 |
+
self,
|
301 |
+
block_size: int,
|
302 |
+
gpu_memory_utilization: float,
|
303 |
+
swap_space: int,
|
304 |
+
sliding_window: Optional[int] = None,
|
305 |
+
) -> None:
|
306 |
+
self.block_size = block_size
|
307 |
+
self.gpu_memory_utilization = gpu_memory_utilization
|
308 |
+
self.swap_space_bytes = swap_space * _GB
|
309 |
+
self.sliding_window = sliding_window
|
310 |
+
self._verify_args()
|
311 |
+
|
312 |
+
# Will be set after profiling.
|
313 |
+
self.num_gpu_blocks = None
|
314 |
+
self.num_cpu_blocks = None
|
315 |
+
|
316 |
+
def _verify_args(self) -> None:
|
317 |
+
if self.gpu_memory_utilization > 1.0:
|
318 |
+
raise ValueError(
|
319 |
+
"GPU memory utilization must be less than 1.0. Got "
|
320 |
+
f"{self.gpu_memory_utilization}."
|
321 |
+
)
|
322 |
+
|
323 |
+
def verify_with_parallel_config(
|
324 |
+
self,
|
325 |
+
parallel_config: "ParallelConfig",
|
326 |
+
) -> None:
|
327 |
+
total_cpu_memory = get_cpu_memory()
|
328 |
+
# FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
|
329 |
+
# group are in the same node. However, the GPUs may span multiple nodes.
|
330 |
+
num_gpus_per_node = parallel_config.tensor_parallel_size
|
331 |
+
cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
|
332 |
+
|
333 |
+
msg = (
|
334 |
+
f"{cpu_memory_usage / _GB:.2f} GiB out of "
|
335 |
+
f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
|
336 |
+
"allocated for the swap space."
|
337 |
+
)
|
338 |
+
if cpu_memory_usage > 0.7 * total_cpu_memory:
|
339 |
+
raise ValueError("Too large swap space. " + msg)
|
340 |
+
elif cpu_memory_usage > 0.4 * total_cpu_memory:
|
341 |
+
logger.warning("Possibly too large swap space. " + msg)
|
342 |
+
|
343 |
+
|
344 |
+
class ParallelConfig:
|
345 |
+
"""Configuration for the distributed execution.
|
346 |
+
|
347 |
+
Args:
|
348 |
+
pipeline_parallel_size: Number of pipeline parallel groups.
|
349 |
+
tensor_parallel_size: Number of tensor parallel groups.
|
350 |
+
worker_use_ray: Whether to use Ray for model workers. Will be set to
|
351 |
+
True if either pipeline_parallel_size or tensor_parallel_size is
|
352 |
+
greater than 1.
|
353 |
+
"""
|
354 |
+
|
355 |
+
def __init__(
|
356 |
+
self,
|
357 |
+
pipeline_parallel_size: int,
|
358 |
+
tensor_parallel_size: int,
|
359 |
+
worker_use_ray: bool,
|
360 |
+
max_parallel_loading_workers: Optional[int] = None,
|
361 |
+
) -> None:
|
362 |
+
self.pipeline_parallel_size = pipeline_parallel_size
|
363 |
+
self.tensor_parallel_size = tensor_parallel_size
|
364 |
+
self.worker_use_ray = worker_use_ray
|
365 |
+
self.max_parallel_loading_workers = max_parallel_loading_workers
|
366 |
+
|
367 |
+
self.world_size = pipeline_parallel_size * tensor_parallel_size
|
368 |
+
if self.world_size > 1:
|
369 |
+
self.worker_use_ray = True
|
370 |
+
self._verify_args()
|
371 |
+
|
372 |
+
def _verify_args(self) -> None:
|
373 |
+
if self.pipeline_parallel_size > 1:
|
374 |
+
raise NotImplementedError("Pipeline parallelism is not supported yet.")
|
375 |
+
|
376 |
+
|
377 |
+
class SchedulerConfig:
|
378 |
+
"""Scheduler configuration.
|
379 |
+
|
380 |
+
Args:
|
381 |
+
max_num_batched_tokens: Maximum number of tokens to be processed in
|
382 |
+
a single iteration.
|
383 |
+
max_num_seqs: Maximum number of sequences to be processed in a single
|
384 |
+
iteration.
|
385 |
+
max_model_len: Maximum length of a sequence (including prompt
|
386 |
+
and generated text).
|
387 |
+
max_paddings: Maximum number of paddings to be added to a batch.
|
388 |
+
"""
|
389 |
+
|
390 |
+
def __init__(
|
391 |
+
self,
|
392 |
+
max_num_batched_tokens: Optional[int],
|
393 |
+
max_num_seqs: int,
|
394 |
+
max_model_len: int,
|
395 |
+
max_paddings: int,
|
396 |
+
) -> None:
|
397 |
+
if max_num_batched_tokens is not None:
|
398 |
+
self.max_num_batched_tokens = max_num_batched_tokens
|
399 |
+
else:
|
400 |
+
# If max_model_len is too short, use 2048 as the default value for
|
401 |
+
# higher throughput.
|
402 |
+
self.max_num_batched_tokens = max(max_model_len, 2048)
|
403 |
+
self.max_num_seqs = max_num_seqs
|
404 |
+
self.max_model_len = max_model_len
|
405 |
+
self.max_paddings = max_paddings
|
406 |
+
self._verify_args()
|
407 |
+
|
408 |
+
def _verify_args(self) -> None:
|
409 |
+
if self.max_num_batched_tokens < self.max_model_len:
|
410 |
+
raise ValueError(
|
411 |
+
f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
|
412 |
+
f"smaller than max_model_len ({self.max_model_len}). "
|
413 |
+
"This effectively limits the maximum sequence length to "
|
414 |
+
"max_num_batched_tokens and makes vLLM reject longer "
|
415 |
+
"sequences. Please increase max_num_batched_tokens or "
|
416 |
+
"decrease max_model_len."
|
417 |
+
)
|
418 |
+
if self.max_num_batched_tokens < self.max_num_seqs:
|
419 |
+
raise ValueError(
|
420 |
+
f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
|
421 |
+
"be greater than or equal to max_num_seqs "
|
422 |
+
f"({self.max_num_seqs})."
|
423 |
+
)
|
424 |
+
|
425 |
+
|
426 |
+
_STR_DTYPE_TO_TORCH_DTYPE = {
|
427 |
+
"half": torch.float16,
|
428 |
+
"float16": torch.float16,
|
429 |
+
"float": torch.float32,
|
430 |
+
"float32": torch.float32,
|
431 |
+
"bfloat16": torch.bfloat16,
|
432 |
+
}
|
433 |
+
|
434 |
+
_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]
|
435 |
+
|
436 |
+
|
437 |
+
def _get_and_verify_dtype(
|
438 |
+
config: PretrainedConfig,
|
439 |
+
dtype: Union[str, torch.dtype],
|
440 |
+
) -> torch.dtype:
|
441 |
+
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
|
442 |
+
# because config.torch_dtype can be None.
|
443 |
+
config_dtype = getattr(config, "torch_dtype", None)
|
444 |
+
if config_dtype is None:
|
445 |
+
config_dtype = torch.float32
|
446 |
+
|
447 |
+
if isinstance(dtype, str):
|
448 |
+
dtype = dtype.lower()
|
449 |
+
if dtype == "auto":
|
450 |
+
if config_dtype == torch.float32:
|
451 |
+
# Following the common practice, we use float16 for float32
|
452 |
+
# models.
|
453 |
+
torch_dtype = torch.float16
|
454 |
+
else:
|
455 |
+
torch_dtype = config_dtype
|
456 |
+
else:
|
457 |
+
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
|
458 |
+
raise ValueError(f"Unknown dtype: {dtype}")
|
459 |
+
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
|
460 |
+
elif isinstance(dtype, torch.dtype):
|
461 |
+
torch_dtype = dtype
|
462 |
+
else:
|
463 |
+
raise ValueError(f"Unknown dtype: {dtype}")
|
464 |
+
|
465 |
+
if is_hip() and torch_dtype == torch.float32:
|
466 |
+
rocm_supported_dtypes = [
|
467 |
+
k
|
468 |
+
for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
|
469 |
+
if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
|
470 |
+
]
|
471 |
+
raise ValueError(
|
472 |
+
f"dtype '{dtype}' is not supported in ROCm. "
|
473 |
+
f"Supported dtypes are {rocm_supported_dtypes}"
|
474 |
+
)
|
475 |
+
|
476 |
+
# Verify the dtype.
|
477 |
+
if torch_dtype != config_dtype:
|
478 |
+
if torch_dtype == torch.float32:
|
479 |
+
# Upcasting to float32 is allowed.
|
480 |
+
pass
|
481 |
+
elif config_dtype == torch.float32:
|
482 |
+
# Downcasting from float32 to float16 or bfloat16 is allowed.
|
483 |
+
pass
|
484 |
+
else:
|
485 |
+
# Casting between float16 and bfloat16 is allowed with a warning.
|
486 |
+
logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
|
487 |
+
|
488 |
+
return torch_dtype
|
489 |
+
|
490 |
+
|
491 |
+
def _get_and_verify_max_len(
|
492 |
+
hf_config: PretrainedConfig,
|
493 |
+
max_model_len: Optional[int],
|
494 |
+
) -> int:
|
495 |
+
"""Get and verify the model's maximum length."""
|
496 |
+
derived_max_model_len = float("inf")
|
497 |
+
possible_keys = [
|
498 |
+
# OPT
|
499 |
+
"max_position_embeddings",
|
500 |
+
# GPT-2
|
501 |
+
"n_positions",
|
502 |
+
# MPT
|
503 |
+
"max_seq_len",
|
504 |
+
# ChatGLM2
|
505 |
+
"seq_length",
|
506 |
+
# Others
|
507 |
+
"max_sequence_length",
|
508 |
+
"max_seq_length",
|
509 |
+
"seq_len",
|
510 |
+
]
|
511 |
+
for key in possible_keys:
|
512 |
+
max_len_key = getattr(hf_config, key, None)
|
513 |
+
if max_len_key is not None:
|
514 |
+
derived_max_model_len = min(derived_max_model_len, max_len_key)
|
515 |
+
if derived_max_model_len == float("inf"):
|
516 |
+
if max_model_len is not None:
|
517 |
+
# If max_model_len is specified, we use it.
|
518 |
+
return max_model_len
|
519 |
+
|
520 |
+
default_max_len = 2048
|
521 |
+
logger.warning(
|
522 |
+
"The model's config.json does not contain any of the following "
|
523 |
+
"keys to determine the original maximum length of the model: "
|
524 |
+
f"{possible_keys}. Assuming the model's maximum length is "
|
525 |
+
f"{default_max_len}."
|
526 |
+
)
|
527 |
+
derived_max_model_len = default_max_len
|
528 |
+
|
529 |
+
rope_scaling = getattr(hf_config, "rope_scaling", None)
|
530 |
+
if rope_scaling is not None:
|
531 |
+
assert "factor" in rope_scaling
|
532 |
+
scaling_factor = rope_scaling["factor"]
|
533 |
+
if rope_scaling["type"] == "yarn":
|
534 |
+
derived_max_model_len = rope_scaling["original_max_position_embeddings"]
|
535 |
+
derived_max_model_len *= scaling_factor
|
536 |
+
|
537 |
+
if max_model_len is None:
|
538 |
+
max_model_len = derived_max_model_len
|
539 |
+
elif max_model_len > derived_max_model_len:
|
540 |
+
raise ValueError(
|
541 |
+
f"User-specified max_model_len ({max_model_len}) is greater than "
|
542 |
+
f"the derived max_model_len ({max_len_key}={derived_max_model_len}"
|
543 |
+
" in model's config.json). This may lead to incorrect model "
|
544 |
+
"outputs or CUDA errors. Make sure the value is correct and "
|
545 |
+
"within the model context size."
|
546 |
+
)
|
547 |
+
return int(max_model_len)
|
548 |
+
|
549 |
+
|
550 |
+
@dataclass
|
551 |
+
class EngineArgs:
|
552 |
+
"""Arguments for vLLM engine."""
|
553 |
+
|
554 |
+
model: str
|
555 |
+
tokenizer: Optional[str] = None
|
556 |
+
tokenizer_mode: str = "auto"
|
557 |
+
trust_remote_code: bool = False
|
558 |
+
download_dir: Optional[str] = None
|
559 |
+
load_format: str = "auto"
|
560 |
+
dtype: str = "auto"
|
561 |
+
seed: int = 0
|
562 |
+
max_model_len: Optional[int] = None
|
563 |
+
worker_use_ray: bool = False
|
564 |
+
pipeline_parallel_size: int = 1
|
565 |
+
tensor_parallel_size: int = 1
|
566 |
+
max_parallel_loading_workers: Optional[int] = None
|
567 |
+
block_size: int = 16
|
568 |
+
swap_space: int = 4 # GiB
|
569 |
+
gpu_memory_utilization: float = 0.90
|
570 |
+
max_num_batched_tokens: Optional[int] = None
|
571 |
+
max_num_seqs: int = 256
|
572 |
+
max_paddings: int = 256
|
573 |
+
disable_log_stats: bool = False
|
574 |
+
revision: Optional[str] = None
|
575 |
+
tokenizer_revision: Optional[str] = None
|
576 |
+
quantization: Optional[str] = None
|
577 |
+
enforce_eager: bool = False
|
578 |
+
max_context_len_to_capture: int = 8192
|
579 |
+
num_audio_tokens: int = 1024
|
580 |
+
num_text_tokens: int = 80
|
581 |
+
|
582 |
+
def __post_init__(self):
|
583 |
+
if self.tokenizer is None:
|
584 |
+
self.tokenizer = self.model
|
585 |
+
|
586 |
+
@staticmethod
|
587 |
+
def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
588 |
+
"""Shared CLI arguments for vLLM engine."""
|
589 |
+
|
590 |
+
# NOTE: If you update any of the arguments below, please also
|
591 |
+
# make sure to update docs/source/models/engine_args.rst
|
592 |
+
|
593 |
+
# Model arguments
|
594 |
+
parser.add_argument(
|
595 |
+
"--model",
|
596 |
+
type=str,
|
597 |
+
default="facebook/opt-125m",
|
598 |
+
help="name or path of the huggingface model to use",
|
599 |
+
)
|
600 |
+
parser.add_argument(
|
601 |
+
"--tokenizer",
|
602 |
+
type=str,
|
603 |
+
default=EngineArgs.tokenizer,
|
604 |
+
help="name or path of the huggingface tokenizer to use",
|
605 |
+
)
|
606 |
+
parser.add_argument(
|
607 |
+
"--revision",
|
608 |
+
type=str,
|
609 |
+
default=None,
|
610 |
+
help="the specific model version to use. It can be a branch "
|
611 |
+
"name, a tag name, or a commit id. If unspecified, will use "
|
612 |
+
"the default version.",
|
613 |
+
)
|
614 |
+
parser.add_argument(
|
615 |
+
"--tokenizer-revision",
|
616 |
+
type=str,
|
617 |
+
default=None,
|
618 |
+
help="the specific tokenizer version to use. It can be a branch "
|
619 |
+
"name, a tag name, or a commit id. If unspecified, will use "
|
620 |
+
"the default version.",
|
621 |
+
)
|
622 |
+
parser.add_argument(
|
623 |
+
"--tokenizer-mode",
|
624 |
+
type=str,
|
625 |
+
default=EngineArgs.tokenizer_mode,
|
626 |
+
choices=["auto", "slow"],
|
627 |
+
help='tokenizer mode. "auto" will use the fast '
|
628 |
+
'tokenizer if available, and "slow" will '
|
629 |
+
"always use the slow tokenizer.",
|
630 |
+
)
|
631 |
+
parser.add_argument(
|
632 |
+
"--trust-remote-code",
|
633 |
+
action="store_true",
|
634 |
+
help="trust remote code from huggingface",
|
635 |
+
)
|
636 |
+
parser.add_argument(
|
637 |
+
"--download-dir",
|
638 |
+
type=str,
|
639 |
+
default=EngineArgs.download_dir,
|
640 |
+
help="directory to download and load the weights, "
|
641 |
+
"default to the default cache dir of "
|
642 |
+
"huggingface",
|
643 |
+
)
|
644 |
+
parser.add_argument(
|
645 |
+
"--load-format",
|
646 |
+
type=str,
|
647 |
+
default=EngineArgs.load_format,
|
648 |
+
choices=["auto", "pt", "safetensors", "npcache", "dummy"],
|
649 |
+
help="The format of the model weights to load. "
|
650 |
+
'"auto" will try to load the weights in the safetensors format '
|
651 |
+
"and fall back to the pytorch bin format if safetensors format "
|
652 |
+
"is not available. "
|
653 |
+
'"pt" will load the weights in the pytorch bin format. '
|
654 |
+
'"safetensors" will load the weights in the safetensors format. '
|
655 |
+
'"npcache" will load the weights in pytorch format and store '
|
656 |
+
"a numpy cache to speed up the loading. "
|
657 |
+
'"dummy" will initialize the weights with random values, '
|
658 |
+
"which is mainly for profiling.",
|
659 |
+
)
|
660 |
+
parser.add_argument(
|
661 |
+
"--dtype",
|
662 |
+
type=str,
|
663 |
+
default=EngineArgs.dtype,
|
664 |
+
choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
|
665 |
+
help="data type for model weights and activations. "
|
666 |
+
'The "auto" option will use FP16 precision '
|
667 |
+
"for FP32 and FP16 models, and BF16 precision "
|
668 |
+
"for BF16 models.",
|
669 |
+
)
|
670 |
+
parser.add_argument(
|
671 |
+
"--max-model-len",
|
672 |
+
type=int,
|
673 |
+
default=None,
|
674 |
+
help="model context length. If unspecified, "
|
675 |
+
"will be automatically derived from the model.",
|
676 |
+
)
|
677 |
+
# Parallel arguments
|
678 |
+
parser.add_argument(
|
679 |
+
"--worker-use-ray",
|
680 |
+
action="store_true",
|
681 |
+
help="use Ray for distributed serving, will be "
|
682 |
+
"automatically set when using more than 1 GPU",
|
683 |
+
)
|
684 |
+
parser.add_argument(
|
685 |
+
"--pipeline-parallel-size",
|
686 |
+
"-pp",
|
687 |
+
type=int,
|
688 |
+
default=EngineArgs.pipeline_parallel_size,
|
689 |
+
help="number of pipeline stages",
|
690 |
+
)
|
691 |
+
parser.add_argument(
|
692 |
+
"--tensor-parallel-size",
|
693 |
+
"-tp",
|
694 |
+
type=int,
|
695 |
+
default=EngineArgs.tensor_parallel_size,
|
696 |
+
help="number of tensor parallel replicas",
|
697 |
+
)
|
698 |
+
parser.add_argument(
|
699 |
+
"--max-parallel-loading-workers",
|
700 |
+
type=int,
|
701 |
+
help="load model sequentially in multiple batches, "
|
702 |
+
"to avoid RAM OOM when using tensor "
|
703 |
+
"parallel and large models",
|
704 |
+
)
|
705 |
+
# KV cache arguments
|
706 |
+
parser.add_argument(
|
707 |
+
"--block-size",
|
708 |
+
type=int,
|
709 |
+
default=EngineArgs.block_size,
|
710 |
+
choices=[8, 16, 32],
|
711 |
+
help="token block size",
|
712 |
+
)
|
713 |
+
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
|
714 |
+
parser.add_argument(
|
715 |
+
"--seed", type=int, default=EngineArgs.seed, help="random seed"
|
716 |
+
)
|
717 |
+
parser.add_argument(
|
718 |
+
"--swap-space",
|
719 |
+
type=int,
|
720 |
+
default=EngineArgs.swap_space,
|
721 |
+
help="CPU swap space size (GiB) per GPU",
|
722 |
+
)
|
723 |
+
parser.add_argument(
|
724 |
+
"--gpu-memory-utilization",
|
725 |
+
type=float,
|
726 |
+
default=EngineArgs.gpu_memory_utilization,
|
727 |
+
help="the fraction of GPU memory to be used for "
|
728 |
+
"the model executor, which can range from 0 to 1."
|
729 |
+
"If unspecified, will use the default value of 0.9.",
|
730 |
+
)
|
731 |
+
parser.add_argument(
|
732 |
+
"--max-num-batched-tokens",
|
733 |
+
type=int,
|
734 |
+
default=EngineArgs.max_num_batched_tokens,
|
735 |
+
help="maximum number of batched tokens per " "iteration",
|
736 |
+
)
|
737 |
+
parser.add_argument(
|
738 |
+
"--max-num-seqs",
|
739 |
+
type=int,
|
740 |
+
default=EngineArgs.max_num_seqs,
|
741 |
+
help="maximum number of sequences per iteration",
|
742 |
+
)
|
743 |
+
parser.add_argument(
|
744 |
+
"--max-paddings",
|
745 |
+
type=int,
|
746 |
+
default=EngineArgs.max_paddings,
|
747 |
+
help="maximum number of paddings in a batch",
|
748 |
+
)
|
749 |
+
parser.add_argument(
|
750 |
+
"--disable-log-stats",
|
751 |
+
action="store_true",
|
752 |
+
help="disable logging statistics",
|
753 |
+
)
|
754 |
+
# Quantization settings.
|
755 |
+
parser.add_argument(
|
756 |
+
"--quantization",
|
757 |
+
"-q",
|
758 |
+
type=str,
|
759 |
+
choices=["awq", "gptq", "squeezellm", None],
|
760 |
+
default=None,
|
761 |
+
help="Method used to quantize the weights. If "
|
762 |
+
"None, we first check the `quantization_config` "
|
763 |
+
"attribute in the model config file. If that is "
|
764 |
+
"None, we assume the model weights are not "
|
765 |
+
"quantized and use `dtype` to determine the data "
|
766 |
+
"type of the weights.",
|
767 |
+
)
|
768 |
+
parser.add_argument(
|
769 |
+
"--enforce-eager",
|
770 |
+
action="store_true",
|
771 |
+
help="Always use eager-mode PyTorch. If False, "
|
772 |
+
"will use eager mode and CUDA graph in hybrid "
|
773 |
+
"for maximal performance and flexibility.",
|
774 |
+
)
|
775 |
+
parser.add_argument(
|
776 |
+
"--max-context-len-to-capture",
|
777 |
+
type=int,
|
778 |
+
default=EngineArgs.max_context_len_to_capture,
|
779 |
+
help="maximum context length covered by CUDA "
|
780 |
+
"graphs. When a sequence has context length "
|
781 |
+
"larger than this, we fall back to eager mode.",
|
782 |
+
)
|
783 |
+
return parser
|
784 |
+
|
785 |
+
@classmethod
|
786 |
+
def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs":
|
787 |
+
# Get the list of attributes of this dataclass.
|
788 |
+
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
789 |
+
# Set the attributes from the parsed arguments.
|
790 |
+
engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
|
791 |
+
return engine_args
|
792 |
+
|
793 |
+
def create_engine_configs(
|
794 |
+
self,
|
795 |
+
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
|
796 |
+
model_config = ModelConfig(
|
797 |
+
self.model,
|
798 |
+
self.tokenizer,
|
799 |
+
self.tokenizer_mode,
|
800 |
+
self.trust_remote_code,
|
801 |
+
self.download_dir,
|
802 |
+
self.load_format,
|
803 |
+
self.dtype,
|
804 |
+
self.seed,
|
805 |
+
self.revision,
|
806 |
+
self.tokenizer_revision,
|
807 |
+
self.max_model_len,
|
808 |
+
self.quantization,
|
809 |
+
self.enforce_eager,
|
810 |
+
self.max_context_len_to_capture,
|
811 |
+
self.num_audio_tokens,
|
812 |
+
self.num_text_tokens,
|
813 |
+
)
|
814 |
+
cache_config = CacheConfig(
|
815 |
+
self.block_size,
|
816 |
+
self.gpu_memory_utilization,
|
817 |
+
self.swap_space,
|
818 |
+
model_config.get_sliding_window(),
|
819 |
+
)
|
820 |
+
parallel_config = ParallelConfig(
|
821 |
+
self.pipeline_parallel_size,
|
822 |
+
self.tensor_parallel_size,
|
823 |
+
self.worker_use_ray,
|
824 |
+
self.max_parallel_loading_workers,
|
825 |
+
)
|
826 |
+
scheduler_config = SchedulerConfig(
|
827 |
+
self.max_num_batched_tokens,
|
828 |
+
self.max_num_seqs,
|
829 |
+
model_config.max_model_len,
|
830 |
+
self.max_paddings,
|
831 |
+
)
|
832 |
+
return model_config, cache_config, parallel_config, scheduler_config
|
833 |
+
|
834 |
+
|
835 |
+
@dataclass
|
836 |
+
class AsyncEngineArgs(EngineArgs):
|
837 |
+
"""Arguments for asynchronous vLLM engine."""
|
838 |
+
|
839 |
+
engine_use_ray: bool = False
|
840 |
+
disable_log_requests: bool = False
|
841 |
+
max_log_len: Optional[int] = None
|
842 |
+
|
843 |
+
@staticmethod
|
844 |
+
def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
845 |
+
parser = EngineArgs.add_cli_args(parser)
|
846 |
+
parser.add_argument(
|
847 |
+
"--engine-use-ray",
|
848 |
+
action="store_true",
|
849 |
+
help="use Ray to start the LLM engine in a "
|
850 |
+
"separate process as the server process.",
|
851 |
+
)
|
852 |
+
parser.add_argument(
|
853 |
+
"--disable-log-requests",
|
854 |
+
action="store_true",
|
855 |
+
help="disable logging requests",
|
856 |
+
)
|
857 |
+
parser.add_argument(
|
858 |
+
"--max-log-len",
|
859 |
+
type=int,
|
860 |
+
default=None,
|
861 |
+
help="max number of prompt characters or prompt "
|
862 |
+
"ID numbers being printed in log. "
|
863 |
+
"Default: unlimited.",
|
864 |
+
)
|
865 |
+
return parser
|
ChatTTS/model/velocity/llama.py
ADDED
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Adapted from
|
3 |
+
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
4 |
+
# Copyright 2023 The vLLM team.
|
5 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
6 |
+
#
|
7 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
8 |
+
# and OPT implementations in this library. It has been modified from its
|
9 |
+
# original forms to accommodate minor architectural differences compared
|
10 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
11 |
+
#
|
12 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
13 |
+
# you may not use this file except in compliance with the License.
|
14 |
+
# You may obtain a copy of the License at
|
15 |
+
#
|
16 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
17 |
+
#
|
18 |
+
# Unless required by applicable law or agreed to in writing, software
|
19 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
20 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
21 |
+
# See the License for the specific language governing permissions and
|
22 |
+
# limitations under the License.
|
23 |
+
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
24 |
+
from typing import Any, Dict, List, Optional, Tuple
|
25 |
+
|
26 |
+
import torch
|
27 |
+
from torch import nn
|
28 |
+
from transformers import LlamaConfig
|
29 |
+
|
30 |
+
from vllm.model_executor.input_metadata import InputMetadata
|
31 |
+
from vllm.model_executor.layers.activation import SiluAndMul
|
32 |
+
from vllm.model_executor.layers.attention import PagedAttention
|
33 |
+
from vllm.model_executor.layers.layernorm import RMSNorm
|
34 |
+
from vllm.model_executor.layers.linear import (
|
35 |
+
LinearMethodBase,
|
36 |
+
MergedColumnParallelLinear,
|
37 |
+
QKVParallelLinear,
|
38 |
+
RowParallelLinear,
|
39 |
+
)
|
40 |
+
from vllm.model_executor.layers.rotary_embedding import get_rope
|
41 |
+
from vllm.model_executor.layers.sampler import Sampler
|
42 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
43 |
+
VocabParallelEmbedding,
|
44 |
+
ParallelLMHead,
|
45 |
+
)
|
46 |
+
from vllm.model_executor.parallel_utils.parallel_state import (
|
47 |
+
get_tensor_model_parallel_world_size,
|
48 |
+
)
|
49 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
50 |
+
from vllm.model_executor.weight_utils import (
|
51 |
+
default_weight_loader,
|
52 |
+
hf_model_weights_iterator,
|
53 |
+
)
|
54 |
+
from vllm.sequence import SamplerOutput
|
55 |
+
|
56 |
+
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
57 |
+
|
58 |
+
|
59 |
+
class LlamaMLP(nn.Module):
|
60 |
+
|
61 |
+
def __init__(
|
62 |
+
self,
|
63 |
+
hidden_size: int,
|
64 |
+
intermediate_size: int,
|
65 |
+
hidden_act: str,
|
66 |
+
linear_method: Optional[LinearMethodBase] = None,
|
67 |
+
) -> None:
|
68 |
+
super().__init__()
|
69 |
+
self.gate_up_proj = MergedColumnParallelLinear(
|
70 |
+
hidden_size,
|
71 |
+
[intermediate_size] * 2,
|
72 |
+
bias=False,
|
73 |
+
linear_method=linear_method,
|
74 |
+
)
|
75 |
+
self.down_proj = RowParallelLinear(
|
76 |
+
intermediate_size, hidden_size, bias=False, linear_method=linear_method
|
77 |
+
)
|
78 |
+
if hidden_act != "silu":
|
79 |
+
raise ValueError(
|
80 |
+
f"Unsupported activation: {hidden_act}. "
|
81 |
+
"Only silu is supported for now."
|
82 |
+
)
|
83 |
+
self.act_fn = SiluAndMul()
|
84 |
+
|
85 |
+
def forward(self, x):
|
86 |
+
gate_up, _ = self.gate_up_proj(x)
|
87 |
+
x = self.act_fn(gate_up)
|
88 |
+
x, _ = self.down_proj(x)
|
89 |
+
return x
|
90 |
+
|
91 |
+
|
92 |
+
class LlamaAttention(nn.Module):
|
93 |
+
|
94 |
+
def __init__(
|
95 |
+
self,
|
96 |
+
hidden_size: int,
|
97 |
+
num_heads: int,
|
98 |
+
num_kv_heads: int,
|
99 |
+
rope_theta: float = 10000,
|
100 |
+
rope_scaling: Optional[Dict[str, Any]] = None,
|
101 |
+
max_position_embeddings: int = 8192,
|
102 |
+
linear_method: Optional[LinearMethodBase] = None,
|
103 |
+
) -> None:
|
104 |
+
super().__init__()
|
105 |
+
self.hidden_size = hidden_size
|
106 |
+
tp_size = get_tensor_model_parallel_world_size()
|
107 |
+
self.total_num_heads = num_heads
|
108 |
+
assert self.total_num_heads % tp_size == 0
|
109 |
+
self.num_heads = self.total_num_heads // tp_size
|
110 |
+
self.total_num_kv_heads = num_kv_heads
|
111 |
+
if self.total_num_kv_heads >= tp_size:
|
112 |
+
# Number of KV heads is greater than TP size, so we partition
|
113 |
+
# the KV heads across multiple tensor parallel GPUs.
|
114 |
+
assert self.total_num_kv_heads % tp_size == 0
|
115 |
+
else:
|
116 |
+
# Number of KV heads is less than TP size, so we replicate
|
117 |
+
# the KV heads across multiple tensor parallel GPUs.
|
118 |
+
assert tp_size % self.total_num_kv_heads == 0
|
119 |
+
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
120 |
+
self.head_dim = hidden_size // self.total_num_heads
|
121 |
+
self.q_size = self.num_heads * self.head_dim
|
122 |
+
self.kv_size = self.num_kv_heads * self.head_dim
|
123 |
+
self.scaling = self.head_dim**-0.5
|
124 |
+
self.rope_theta = rope_theta
|
125 |
+
self.max_position_embeddings = max_position_embeddings
|
126 |
+
|
127 |
+
self.qkv_proj = QKVParallelLinear(
|
128 |
+
hidden_size,
|
129 |
+
self.head_dim,
|
130 |
+
self.total_num_heads,
|
131 |
+
self.total_num_kv_heads,
|
132 |
+
bias=False,
|
133 |
+
linear_method=linear_method,
|
134 |
+
)
|
135 |
+
self.o_proj = RowParallelLinear(
|
136 |
+
self.total_num_heads * self.head_dim,
|
137 |
+
hidden_size,
|
138 |
+
bias=False,
|
139 |
+
linear_method=linear_method,
|
140 |
+
)
|
141 |
+
|
142 |
+
self.rotary_emb = get_rope(
|
143 |
+
self.head_dim,
|
144 |
+
rotary_dim=self.head_dim,
|
145 |
+
max_position=max_position_embeddings,
|
146 |
+
base=rope_theta,
|
147 |
+
rope_scaling=rope_scaling,
|
148 |
+
)
|
149 |
+
self.attn = PagedAttention(
|
150 |
+
self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads
|
151 |
+
)
|
152 |
+
|
153 |
+
def forward(
|
154 |
+
self,
|
155 |
+
positions: torch.Tensor,
|
156 |
+
hidden_states: torch.Tensor,
|
157 |
+
kv_cache: KVCache,
|
158 |
+
input_metadata: InputMetadata,
|
159 |
+
) -> torch.Tensor:
|
160 |
+
qkv, _ = self.qkv_proj(hidden_states)
|
161 |
+
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
162 |
+
q, k = self.rotary_emb(positions, q, k)
|
163 |
+
k_cache, v_cache = kv_cache
|
164 |
+
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
|
165 |
+
output, _ = self.o_proj(attn_output)
|
166 |
+
return output
|
167 |
+
|
168 |
+
|
169 |
+
class LlamaDecoderLayer(nn.Module):
|
170 |
+
|
171 |
+
def __init__(
|
172 |
+
self,
|
173 |
+
config: LlamaConfig,
|
174 |
+
linear_method: Optional[LinearMethodBase] = None,
|
175 |
+
) -> None:
|
176 |
+
super().__init__()
|
177 |
+
self.hidden_size = config.hidden_size
|
178 |
+
rope_theta = getattr(config, "rope_theta", 10000)
|
179 |
+
rope_scaling = getattr(config, "rope_scaling", None)
|
180 |
+
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
181 |
+
self.self_attn = LlamaAttention(
|
182 |
+
hidden_size=self.hidden_size,
|
183 |
+
num_heads=config.num_attention_heads,
|
184 |
+
num_kv_heads=config.num_key_value_heads,
|
185 |
+
rope_theta=rope_theta,
|
186 |
+
rope_scaling=rope_scaling,
|
187 |
+
max_position_embeddings=max_position_embeddings,
|
188 |
+
linear_method=linear_method,
|
189 |
+
)
|
190 |
+
self.mlp = LlamaMLP(
|
191 |
+
hidden_size=self.hidden_size,
|
192 |
+
intermediate_size=config.intermediate_size,
|
193 |
+
hidden_act=config.hidden_act,
|
194 |
+
linear_method=linear_method,
|
195 |
+
)
|
196 |
+
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
197 |
+
self.post_attention_layernorm = RMSNorm(
|
198 |
+
config.hidden_size, eps=config.rms_norm_eps
|
199 |
+
)
|
200 |
+
|
201 |
+
def forward(
|
202 |
+
self,
|
203 |
+
positions: torch.Tensor,
|
204 |
+
hidden_states: torch.Tensor,
|
205 |
+
kv_cache: KVCache,
|
206 |
+
input_metadata: InputMetadata,
|
207 |
+
residual: Optional[torch.Tensor],
|
208 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
209 |
+
# Self Attention
|
210 |
+
if residual is None:
|
211 |
+
residual = hidden_states
|
212 |
+
hidden_states = self.input_layernorm(hidden_states)
|
213 |
+
else:
|
214 |
+
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
215 |
+
hidden_states = self.self_attn(
|
216 |
+
positions=positions,
|
217 |
+
hidden_states=hidden_states,
|
218 |
+
kv_cache=kv_cache,
|
219 |
+
input_metadata=input_metadata,
|
220 |
+
)
|
221 |
+
|
222 |
+
# Fully Connected
|
223 |
+
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
224 |
+
hidden_states = self.mlp(hidden_states)
|
225 |
+
return hidden_states, residual
|
226 |
+
|
227 |
+
|
228 |
+
class LlamaModel(nn.Module):
|
229 |
+
|
230 |
+
def __init__(
|
231 |
+
self,
|
232 |
+
config: LlamaConfig,
|
233 |
+
linear_method: Optional[LinearMethodBase] = None,
|
234 |
+
) -> None:
|
235 |
+
super().__init__()
|
236 |
+
self.config = config
|
237 |
+
self.padding_idx = config.pad_token_id
|
238 |
+
self.vocab_size = config.vocab_size
|
239 |
+
self.embed_tokens = VocabParallelEmbedding(
|
240 |
+
config.vocab_size,
|
241 |
+
config.hidden_size,
|
242 |
+
)
|
243 |
+
self.layers = nn.ModuleList(
|
244 |
+
[
|
245 |
+
LlamaDecoderLayer(config, linear_method)
|
246 |
+
for _ in range(config.num_hidden_layers)
|
247 |
+
]
|
248 |
+
)
|
249 |
+
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
250 |
+
|
251 |
+
def forward(
|
252 |
+
self,
|
253 |
+
input_emb: torch.Tensor,
|
254 |
+
positions: torch.Tensor,
|
255 |
+
kv_caches: List[KVCache],
|
256 |
+
input_metadata: InputMetadata,
|
257 |
+
) -> torch.Tensor:
|
258 |
+
hidden_states = input_emb
|
259 |
+
residual = None
|
260 |
+
for i in range(len(self.layers)):
|
261 |
+
layer = self.layers[i]
|
262 |
+
hidden_states, residual = layer(
|
263 |
+
positions,
|
264 |
+
hidden_states,
|
265 |
+
kv_caches[i],
|
266 |
+
input_metadata,
|
267 |
+
residual,
|
268 |
+
)
|
269 |
+
hidden_states, _ = self.norm(hidden_states, residual)
|
270 |
+
return hidden_states
|
271 |
+
|
272 |
+
def load_weights(
|
273 |
+
self,
|
274 |
+
model_name_or_path: str,
|
275 |
+
cache_dir: Optional[str] = None,
|
276 |
+
load_format: str = "auto",
|
277 |
+
revision: Optional[str] = None,
|
278 |
+
):
|
279 |
+
stacked_params_mapping = [
|
280 |
+
# (param_name, shard_name, shard_id)
|
281 |
+
("qkv_proj", "q_proj", "q"),
|
282 |
+
("qkv_proj", "k_proj", "k"),
|
283 |
+
("qkv_proj", "v_proj", "v"),
|
284 |
+
("gate_up_proj", "gate_proj", 0),
|
285 |
+
("gate_up_proj", "up_proj", 1),
|
286 |
+
]
|
287 |
+
params_dict = dict(self.named_parameters())
|
288 |
+
for name, loaded_weight in hf_model_weights_iterator(
|
289 |
+
model_name_or_path, cache_dir, load_format, revision
|
290 |
+
):
|
291 |
+
if "rotary_emb.inv_freq" in name:
|
292 |
+
continue
|
293 |
+
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
294 |
+
# Models trained using ColossalAI may include these tensors in
|
295 |
+
# the checkpoint. Skip them.
|
296 |
+
continue
|
297 |
+
for param_name, weight_name, shard_id in stacked_params_mapping:
|
298 |
+
if weight_name not in name:
|
299 |
+
continue
|
300 |
+
name = name.replace(weight_name, param_name)
|
301 |
+
# Skip loading extra bias for GPTQ models.
|
302 |
+
if name.endswith(".bias") and name not in params_dict:
|
303 |
+
continue
|
304 |
+
param = params_dict[name]
|
305 |
+
weight_loader = param.weight_loader
|
306 |
+
weight_loader(param, loaded_weight, shard_id)
|
307 |
+
break
|
308 |
+
else:
|
309 |
+
# Skip loading extra bias for GPTQ models.
|
310 |
+
if name.endswith(".bias") and name not in params_dict:
|
311 |
+
continue
|
312 |
+
param = params_dict[name]
|
313 |
+
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
314 |
+
weight_loader(param, loaded_weight)
|
315 |
+
|
316 |
+
|
317 |
+
class LlamaForCausalLM(nn.Module):
|
318 |
+
|
319 |
+
def __init__(
|
320 |
+
self,
|
321 |
+
config: LlamaConfig,
|
322 |
+
linear_method: Optional[LinearMethodBase] = None,
|
323 |
+
) -> None:
|
324 |
+
super().__init__()
|
325 |
+
self.config = config
|
326 |
+
self.linear_method = linear_method
|
327 |
+
self.model = LlamaModel(config, linear_method)
|
328 |
+
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
329 |
+
self.sampler = Sampler(config.vocab_size)
|
330 |
+
|
331 |
+
def forward(
|
332 |
+
self,
|
333 |
+
input_ids: torch.Tensor,
|
334 |
+
positions: torch.Tensor,
|
335 |
+
kv_caches: List[KVCache],
|
336 |
+
input_metadata: InputMetadata,
|
337 |
+
) -> torch.Tensor:
|
338 |
+
hidden_states = self.model(input_ids, positions, kv_caches, input_metadata)
|
339 |
+
return hidden_states
|
340 |
+
|
341 |
+
def sample(
|
342 |
+
self,
|
343 |
+
hidden_states: torch.Tensor,
|
344 |
+
sampling_metadata: SamplingMetadata,
|
345 |
+
) -> Optional[SamplerOutput]:
|
346 |
+
next_tokens = self.sampler(
|
347 |
+
self.lm_head.weight, hidden_states, sampling_metadata
|
348 |
+
)
|
349 |
+
return next_tokens
|
350 |
+
|
351 |
+
def load_weights(
|
352 |
+
self,
|
353 |
+
model_name_or_path: str,
|
354 |
+
cache_dir: Optional[str] = None,
|
355 |
+
load_format: str = "auto",
|
356 |
+
revision: Optional[str] = None,
|
357 |
+
):
|
358 |
+
stacked_params_mapping = [
|
359 |
+
# (param_name, shard_name, shard_id)
|
360 |
+
("qkv_proj", "q_proj", "q"),
|
361 |
+
("qkv_proj", "k_proj", "k"),
|
362 |
+
("qkv_proj", "v_proj", "v"),
|
363 |
+
("gate_up_proj", "gate_proj", 0),
|
364 |
+
("gate_up_proj", "up_proj", 1),
|
365 |
+
]
|
366 |
+
params_dict = dict(self.named_parameters())
|
367 |
+
for name, loaded_weight in hf_model_weights_iterator(
|
368 |
+
model_name_or_path, cache_dir, load_format, revision
|
369 |
+
):
|
370 |
+
if "rotary_emb.inv_freq" in name:
|
371 |
+
continue
|
372 |
+
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
373 |
+
# Models trained using ColossalAI may include these tensors in
|
374 |
+
# the checkpoint. Skip them.
|
375 |
+
continue
|
376 |
+
for param_name, weight_name, shard_id in stacked_params_mapping:
|
377 |
+
if weight_name not in name:
|
378 |
+
continue
|
379 |
+
name = name.replace(weight_name, param_name)
|
380 |
+
# Skip loading extra bias for GPTQ models.
|
381 |
+
if name.endswith(".bias") and name not in params_dict:
|
382 |
+
continue
|
383 |
+
param = params_dict[name]
|
384 |
+
weight_loader = param.weight_loader
|
385 |
+
weight_loader(param, loaded_weight, shard_id)
|
386 |
+
break
|
387 |
+
else:
|
388 |
+
# Skip loading extra bias for GPTQ models.
|
389 |
+
if name.endswith(".bias") and name not in params_dict:
|
390 |
+
continue
|
391 |
+
param = params_dict[name]
|
392 |
+
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
393 |
+
weight_loader(param, loaded_weight)
|
ChatTTS/model/velocity/llm.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional, Union
|
2 |
+
|
3 |
+
from tqdm import tqdm
|
4 |
+
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
5 |
+
from vllm.utils import Counter
|
6 |
+
|
7 |
+
from .configs import EngineArgs
|
8 |
+
from .llm_engine import LLMEngine
|
9 |
+
from .output import RequestOutput
|
10 |
+
from .sampling_params import SamplingParams
|
11 |
+
|
12 |
+
|
13 |
+
class LLM:
|
14 |
+
"""An LLM for generating texts from given prompts and sampling parameters.
|
15 |
+
|
16 |
+
This class includes a tokenizer, a language model (possibly distributed
|
17 |
+
across multiple GPUs), and GPU memory space allocated for intermediate
|
18 |
+
states (aka KV cache). Given a batch of prompts and sampling parameters,
|
19 |
+
this class generates texts from the model, using an intelligent batching
|
20 |
+
mechanism and efficient memory management.
|
21 |
+
|
22 |
+
NOTE: This class is intended to be used for offline inference. For online
|
23 |
+
serving, use the `AsyncLLMEngine` class instead.
|
24 |
+
NOTE: For the comprehensive list of arguments, see `EngineArgs`.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
model: The name or path of a HuggingFace Transformers model.
|
28 |
+
tokenizer: The name or path of a HuggingFace Transformers tokenizer.
|
29 |
+
tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
|
30 |
+
if available, and "slow" will always use the slow tokenizer.
|
31 |
+
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
|
32 |
+
downloading the model and tokenizer.
|
33 |
+
tensor_parallel_size: The number of GPUs to use for distributed
|
34 |
+
execution with tensor parallelism.
|
35 |
+
dtype: The data type for the model weights and activations. Currently,
|
36 |
+
we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
|
37 |
+
the `torch_dtype` attribute specified in the model config file.
|
38 |
+
However, if the `torch_dtype` in the config is `float32`, we will
|
39 |
+
use `float16` instead.
|
40 |
+
quantization: The method used to quantize the model weights. Currently,
|
41 |
+
we support "awq", "gptq" and "squeezellm". If None, we first check
|
42 |
+
the `quantization_config` attribute in the model config file. If
|
43 |
+
that is None, we assume the model weights are not quantized and use
|
44 |
+
`dtype` to determine the data type of the weights.
|
45 |
+
revision: The specific model version to use. It can be a branch name,
|
46 |
+
a tag name, or a commit id.
|
47 |
+
tokenizer_revision: The specific tokenizer version to use. It can be a
|
48 |
+
branch name, a tag name, or a commit id.
|
49 |
+
seed: The seed to initialize the random number generator for sampling.
|
50 |
+
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
|
51 |
+
reserve for the model weights, activations, and KV cache. Higher
|
52 |
+
values will increase the KV cache size and thus improve the model's
|
53 |
+
throughput. However, if the value is too high, it may cause out-of-
|
54 |
+
memory (OOM) errors.
|
55 |
+
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
|
56 |
+
This can be used for temporarily storing the states of the requests
|
57 |
+
when their `best_of` sampling parameters are larger than 1. If all
|
58 |
+
requests will have `best_of=1`, you can safely set this to 0.
|
59 |
+
Otherwise, too small values may cause out-of-memory (OOM) errors.
|
60 |
+
enforce_eager: Whether to enforce eager execution. If True, we will
|
61 |
+
disable CUDA graph and always execute the model in eager mode.
|
62 |
+
If False, we will use CUDA graph and eager execution in hybrid.
|
63 |
+
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
|
64 |
+
When a sequence has context length larger than this, we fall back
|
65 |
+
to eager mode.
|
66 |
+
"""
|
67 |
+
|
68 |
+
def __init__(
|
69 |
+
self,
|
70 |
+
model: str,
|
71 |
+
tokenizer: Optional[str] = None,
|
72 |
+
tokenizer_mode: str = "auto",
|
73 |
+
trust_remote_code: bool = False,
|
74 |
+
tensor_parallel_size: int = 1,
|
75 |
+
dtype: str = "auto",
|
76 |
+
quantization: Optional[str] = None,
|
77 |
+
revision: Optional[str] = None,
|
78 |
+
tokenizer_revision: Optional[str] = None,
|
79 |
+
seed: int = 0,
|
80 |
+
gpu_memory_utilization: float = 0.9,
|
81 |
+
swap_space: int = 4,
|
82 |
+
enforce_eager: bool = False,
|
83 |
+
max_context_len_to_capture: int = 8192,
|
84 |
+
post_model_path: str = None,
|
85 |
+
num_audio_tokens: int = 0,
|
86 |
+
num_text_tokens: int = 0,
|
87 |
+
**kwargs,
|
88 |
+
) -> None:
|
89 |
+
if "disable_log_stats" not in kwargs:
|
90 |
+
kwargs["disable_log_stats"] = True
|
91 |
+
engine_args = EngineArgs(
|
92 |
+
model=model,
|
93 |
+
tokenizer=tokenizer,
|
94 |
+
tokenizer_mode=tokenizer_mode,
|
95 |
+
trust_remote_code=trust_remote_code,
|
96 |
+
tensor_parallel_size=tensor_parallel_size,
|
97 |
+
dtype=dtype,
|
98 |
+
quantization=quantization,
|
99 |
+
revision=revision,
|
100 |
+
tokenizer_revision=tokenizer_revision,
|
101 |
+
seed=seed,
|
102 |
+
gpu_memory_utilization=gpu_memory_utilization,
|
103 |
+
swap_space=swap_space,
|
104 |
+
enforce_eager=enforce_eager,
|
105 |
+
max_context_len_to_capture=max_context_len_to_capture,
|
106 |
+
num_audio_tokens=num_audio_tokens,
|
107 |
+
num_text_tokens=num_text_tokens,
|
108 |
+
**kwargs,
|
109 |
+
)
|
110 |
+
self.llm_engine = LLMEngine.from_engine_args(engine_args, post_model_path)
|
111 |
+
self.request_counter = Counter()
|
112 |
+
|
113 |
+
def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
|
114 |
+
return self.llm_engine.tokenizer
|
115 |
+
|
116 |
+
def set_tokenizer(
|
117 |
+
self,
|
118 |
+
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
119 |
+
) -> None:
|
120 |
+
self.llm_engine.tokenizer = tokenizer
|
121 |
+
|
122 |
+
def generate(
|
123 |
+
self,
|
124 |
+
prompts: Optional[Union[str, List[str]]] = None,
|
125 |
+
sampling_params: Optional[SamplingParams] = None,
|
126 |
+
prompt_token_ids: Optional[List[List[int]]] = None,
|
127 |
+
use_tqdm: bool = True,
|
128 |
+
) -> List[RequestOutput]:
|
129 |
+
"""Generates the completions for the input prompts.
|
130 |
+
|
131 |
+
NOTE: This class automatically batches the given prompts, considering
|
132 |
+
the memory constraint. For the best performance, put all of your prompts
|
133 |
+
into a single list and pass it to this method.
|
134 |
+
|
135 |
+
Args:
|
136 |
+
prompts: A list of prompts to generate completions for.
|
137 |
+
sampling_params: The sampling parameters for text generation. If
|
138 |
+
None, we use the default sampling parameters.
|
139 |
+
prompt_token_ids: A list of token IDs for the prompts. If None, we
|
140 |
+
use the tokenizer to convert the prompts to token IDs.
|
141 |
+
use_tqdm: Whether to use tqdm to display the progress bar.
|
142 |
+
|
143 |
+
Returns:
|
144 |
+
A list of `RequestOutput` objects containing the generated
|
145 |
+
completions in the same order as the input prompts.
|
146 |
+
"""
|
147 |
+
if prompts is None and prompt_token_ids is None:
|
148 |
+
raise ValueError("Either prompts or prompt_token_ids must be " "provided.")
|
149 |
+
if isinstance(prompts, str):
|
150 |
+
# Convert a single prompt to a list.
|
151 |
+
prompts = [prompts]
|
152 |
+
if (
|
153 |
+
prompts is not None
|
154 |
+
and prompt_token_ids is not None
|
155 |
+
and len(prompts) != len(prompt_token_ids)
|
156 |
+
):
|
157 |
+
raise ValueError(
|
158 |
+
"The lengths of prompts and prompt_token_ids " "must be the same."
|
159 |
+
)
|
160 |
+
if sampling_params is None:
|
161 |
+
# Use default sampling params.
|
162 |
+
sampling_params = SamplingParams()
|
163 |
+
|
164 |
+
# Add requests to the engine.
|
165 |
+
num_requests = len(prompts) if prompts is not None else len(prompt_token_ids)
|
166 |
+
for i in range(num_requests):
|
167 |
+
prompt = prompts[i] if prompts is not None else None
|
168 |
+
token_ids = None if prompt_token_ids is None else prompt_token_ids[i]
|
169 |
+
self._add_request(prompt, sampling_params, token_ids)
|
170 |
+
|
171 |
+
rtns = self._run_engine(use_tqdm)
|
172 |
+
for i, rtn in enumerate(rtns):
|
173 |
+
token_ids = rtn.outputs[0].token_ids
|
174 |
+
for j, token_id in enumerate(token_ids):
|
175 |
+
if len(token_id) == 1:
|
176 |
+
token_ids[j] = token_id[0]
|
177 |
+
else:
|
178 |
+
token_ids[j] = list(token_id)
|
179 |
+
|
180 |
+
return rtns
|
181 |
+
|
182 |
+
def _add_request(
|
183 |
+
self,
|
184 |
+
prompt: Optional[str],
|
185 |
+
sampling_params: SamplingParams,
|
186 |
+
prompt_token_ids: Optional[List[int]],
|
187 |
+
) -> None:
|
188 |
+
request_id = str(next(self.request_counter))
|
189 |
+
self.llm_engine.add_request(
|
190 |
+
request_id, prompt, sampling_params, prompt_token_ids
|
191 |
+
)
|
192 |
+
|
193 |
+
def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
|
194 |
+
# Initialize tqdm.
|
195 |
+
if use_tqdm:
|
196 |
+
num_requests = self.llm_engine.get_num_unfinished_requests()
|
197 |
+
pbar = tqdm(total=num_requests, desc="Processed prompts")
|
198 |
+
# Run the engine.
|
199 |
+
outputs: List[RequestOutput] = []
|
200 |
+
while self.llm_engine.has_unfinished_requests():
|
201 |
+
step_outputs = self.llm_engine.step()
|
202 |
+
for output in step_outputs:
|
203 |
+
if output.finished:
|
204 |
+
outputs.append(output)
|
205 |
+
if use_tqdm:
|
206 |
+
pbar.update(1)
|
207 |
+
if use_tqdm:
|
208 |
+
pbar.close()
|
209 |
+
# Sort the outputs by request ID.
|
210 |
+
# This is necessary because some requests may be finished earlier than
|
211 |
+
# its previous requests.
|
212 |
+
outputs = sorted(outputs, key=lambda x: int(x.request_id))
|
213 |
+
return outputs
|
ChatTTS/model/velocity/llm_engine.py
ADDED
@@ -0,0 +1,833 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from collections import defaultdict
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
|
6 |
+
|
7 |
+
from vllm.config import CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig
|
8 |
+
from .scheduler import Scheduler, SchedulerOutputs
|
9 |
+
from .configs import EngineArgs
|
10 |
+
from vllm.engine.metrics import record_metrics
|
11 |
+
from vllm.engine.ray_utils import RayWorkerVllm, initialize_cluster, ray
|
12 |
+
from vllm.logger import init_logger
|
13 |
+
from .output import RequestOutput
|
14 |
+
from .sampling_params import SamplingParams
|
15 |
+
from .sequence import (
|
16 |
+
SamplerOutput,
|
17 |
+
Sequence,
|
18 |
+
SequenceGroup,
|
19 |
+
SequenceGroupOutput,
|
20 |
+
SequenceOutput,
|
21 |
+
SequenceStatus,
|
22 |
+
)
|
23 |
+
from vllm.transformers_utils.tokenizer import detokenize_incrementally, get_tokenizer
|
24 |
+
from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port
|
25 |
+
import numpy as np
|
26 |
+
|
27 |
+
if ray:
|
28 |
+
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
29 |
+
|
30 |
+
if TYPE_CHECKING:
|
31 |
+
from ray.util.placement_group import PlacementGroup
|
32 |
+
|
33 |
+
logger = init_logger(__name__)
|
34 |
+
|
35 |
+
_LOGGING_INTERVAL_SEC = 5
|
36 |
+
|
37 |
+
|
38 |
+
class LLMEngine:
|
39 |
+
"""An LLM engine that receives requests and generates texts.
|
40 |
+
|
41 |
+
This is the main class for the vLLM engine. It receives requests
|
42 |
+
from clients and generates texts from the LLM. It includes a tokenizer, a
|
43 |
+
language model (possibly distributed across multiple GPUs), and GPU memory
|
44 |
+
space allocated for intermediate states (aka KV cache). This class utilizes
|
45 |
+
iteration-level scheduling and efficient memory management to maximize the
|
46 |
+
serving throughput.
|
47 |
+
|
48 |
+
The `LLM` class wraps this class for offline batched inference and the
|
49 |
+
`AsyncLLMEngine` class wraps this class for online serving.
|
50 |
+
|
51 |
+
NOTE: The config arguments are derived from the `EngineArgs` class. For the
|
52 |
+
comprehensive list of arguments, see `EngineArgs`.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
model_config: The configuration related to the LLM model.
|
56 |
+
cache_config: The configuration related to the KV cache memory
|
57 |
+
management.
|
58 |
+
parallel_config: The configuration related to distributed execution.
|
59 |
+
scheduler_config: The configuration related to the request scheduler.
|
60 |
+
placement_group: Ray placement group for distributed execution.
|
61 |
+
Required for distributed execution.
|
62 |
+
log_stats: Whether to log statistics.
|
63 |
+
"""
|
64 |
+
|
65 |
+
def __init__(
|
66 |
+
self,
|
67 |
+
model_config: ModelConfig,
|
68 |
+
cache_config: CacheConfig,
|
69 |
+
parallel_config: ParallelConfig,
|
70 |
+
scheduler_config: SchedulerConfig,
|
71 |
+
placement_group: Optional["PlacementGroup"],
|
72 |
+
post_model_path: str,
|
73 |
+
log_stats: bool,
|
74 |
+
) -> None:
|
75 |
+
logger.info(
|
76 |
+
"Initializing an LLM engine with config: "
|
77 |
+
f"model={model_config.model!r}, "
|
78 |
+
f"tokenizer={model_config.tokenizer!r}, "
|
79 |
+
f"tokenizer_mode={model_config.tokenizer_mode}, "
|
80 |
+
f"revision={model_config.revision}, "
|
81 |
+
f"tokenizer_revision={model_config.tokenizer_revision}, "
|
82 |
+
f"trust_remote_code={model_config.trust_remote_code}, "
|
83 |
+
f"dtype={model_config.dtype}, "
|
84 |
+
f"max_seq_len={model_config.max_model_len}, "
|
85 |
+
f"download_dir={model_config.download_dir!r}, "
|
86 |
+
f"load_format={model_config.load_format}, "
|
87 |
+
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
|
88 |
+
f"quantization={model_config.quantization}, "
|
89 |
+
f"enforce_eager={model_config.enforce_eager}, "
|
90 |
+
f"seed={model_config.seed}), "
|
91 |
+
f"post_model_path={post_model_path!r}"
|
92 |
+
)
|
93 |
+
# TODO(woosuk): Print more configs in debug mode.
|
94 |
+
|
95 |
+
self.model_config = model_config
|
96 |
+
self.cache_config = cache_config
|
97 |
+
self.parallel_config = parallel_config
|
98 |
+
self.scheduler_config = scheduler_config
|
99 |
+
self.log_stats = log_stats
|
100 |
+
self._verify_args()
|
101 |
+
self.post_model_path = post_model_path
|
102 |
+
self.seq_counter = Counter()
|
103 |
+
|
104 |
+
# Create the parallel GPU workers.
|
105 |
+
if self.parallel_config.worker_use_ray:
|
106 |
+
# Disable Ray usage stats collection.
|
107 |
+
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
|
108 |
+
if ray_usage != "1":
|
109 |
+
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
|
110 |
+
self._init_workers_ray(placement_group)
|
111 |
+
else:
|
112 |
+
self._init_workers()
|
113 |
+
|
114 |
+
# Profile the memory usage and initialize the cache.
|
115 |
+
self._init_cache()
|
116 |
+
|
117 |
+
# Create the scheduler.
|
118 |
+
self.scheduler = Scheduler(scheduler_config, cache_config)
|
119 |
+
|
120 |
+
# Logging.
|
121 |
+
self.last_logging_time = 0.0
|
122 |
+
# List of (timestamp, num_tokens)
|
123 |
+
self.num_prompt_tokens: List[Tuple[float, int]] = []
|
124 |
+
# List of (timestamp, num_tokens)
|
125 |
+
self.num_generation_tokens: List[Tuple[float, int]] = []
|
126 |
+
|
127 |
+
def _init_workers(self):
|
128 |
+
# Lazy import the Worker to avoid importing torch.cuda/xformers
|
129 |
+
# before CUDA_VISIBLE_DEVICES is set in the Worker
|
130 |
+
from .worker import Worker
|
131 |
+
|
132 |
+
assert (
|
133 |
+
self.parallel_config.world_size == 1
|
134 |
+
), "Ray is required if parallel_config.world_size > 1."
|
135 |
+
|
136 |
+
self.workers: List[Worker] = []
|
137 |
+
distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}"
|
138 |
+
self.driver_worker = Worker(
|
139 |
+
self.model_config,
|
140 |
+
self.parallel_config,
|
141 |
+
self.scheduler_config,
|
142 |
+
local_rank=0,
|
143 |
+
rank=0,
|
144 |
+
distributed_init_method=distributed_init_method,
|
145 |
+
is_driver_worker=True,
|
146 |
+
post_model_path=self.post_model_path,
|
147 |
+
)
|
148 |
+
self._run_workers("init_model")
|
149 |
+
self._run_workers("load_model")
|
150 |
+
|
151 |
+
def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs):
|
152 |
+
if self.parallel_config.tensor_parallel_size == 1:
|
153 |
+
num_gpus = self.cache_config.gpu_memory_utilization
|
154 |
+
else:
|
155 |
+
num_gpus = 1
|
156 |
+
|
157 |
+
self.driver_dummy_worker: RayWorkerVllm = None
|
158 |
+
self.workers: List[RayWorkerVllm] = []
|
159 |
+
|
160 |
+
driver_ip = get_ip()
|
161 |
+
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
|
162 |
+
if not bundle.get("GPU", 0):
|
163 |
+
continue
|
164 |
+
scheduling_strategy = PlacementGroupSchedulingStrategy(
|
165 |
+
placement_group=placement_group,
|
166 |
+
placement_group_capture_child_tasks=True,
|
167 |
+
placement_group_bundle_index=bundle_id,
|
168 |
+
)
|
169 |
+
worker = ray.remote(
|
170 |
+
num_cpus=0,
|
171 |
+
num_gpus=num_gpus,
|
172 |
+
scheduling_strategy=scheduling_strategy,
|
173 |
+
**ray_remote_kwargs,
|
174 |
+
)(RayWorkerVllm).remote(self.model_config.trust_remote_code)
|
175 |
+
|
176 |
+
worker_ip = ray.get(worker.get_node_ip.remote())
|
177 |
+
if worker_ip == driver_ip and self.driver_dummy_worker is None:
|
178 |
+
# If the worker is on the same node as the driver, we use it
|
179 |
+
# as the resource holder for the driver process.
|
180 |
+
self.driver_dummy_worker = worker
|
181 |
+
else:
|
182 |
+
self.workers.append(worker)
|
183 |
+
|
184 |
+
if self.driver_dummy_worker is None:
|
185 |
+
raise ValueError(
|
186 |
+
"Ray does not allocate any GPUs on the driver node. Consider "
|
187 |
+
"adjusting the Ray placement group or running the driver on a "
|
188 |
+
"GPU node."
|
189 |
+
)
|
190 |
+
|
191 |
+
driver_node_id, driver_gpu_ids = ray.get(
|
192 |
+
self.driver_dummy_worker.get_node_and_gpu_ids.remote()
|
193 |
+
)
|
194 |
+
worker_node_and_gpu_ids = ray.get(
|
195 |
+
[worker.get_node_and_gpu_ids.remote() for worker in self.workers]
|
196 |
+
)
|
197 |
+
|
198 |
+
node_workers = defaultdict(list)
|
199 |
+
node_gpus = defaultdict(list)
|
200 |
+
|
201 |
+
node_workers[driver_node_id].append(0)
|
202 |
+
node_gpus[driver_node_id].extend(driver_gpu_ids)
|
203 |
+
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids, start=1):
|
204 |
+
node_workers[node_id].append(i)
|
205 |
+
node_gpus[node_id].extend(gpu_ids)
|
206 |
+
for node_id, gpu_ids in node_gpus.items():
|
207 |
+
node_gpus[node_id] = sorted(gpu_ids)
|
208 |
+
|
209 |
+
# Set CUDA_VISIBLE_DEVICES for the driver.
|
210 |
+
set_cuda_visible_devices(node_gpus[driver_node_id])
|
211 |
+
for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
|
212 |
+
worker.set_cuda_visible_devices.remote(node_gpus[node_id])
|
213 |
+
|
214 |
+
distributed_init_method = f"tcp://{driver_ip}:{get_open_port()}"
|
215 |
+
|
216 |
+
# Lazy import the Worker to avoid importing torch.cuda/xformers
|
217 |
+
# before CUDA_VISIBLE_DEVICES is set in the Worker
|
218 |
+
from vllm.worker.worker import Worker
|
219 |
+
|
220 |
+
# Initialize torch distributed process group for the workers.
|
221 |
+
model_config = copy.deepcopy(self.model_config)
|
222 |
+
parallel_config = copy.deepcopy(self.parallel_config)
|
223 |
+
scheduler_config = copy.deepcopy(self.scheduler_config)
|
224 |
+
|
225 |
+
for rank, (worker, (node_id, _)) in enumerate(
|
226 |
+
zip(self.workers, worker_node_and_gpu_ids), start=1
|
227 |
+
):
|
228 |
+
local_rank = node_workers[node_id].index(rank)
|
229 |
+
worker.init_worker.remote(
|
230 |
+
lambda rank=rank, local_rank=local_rank: Worker(
|
231 |
+
model_config,
|
232 |
+
parallel_config,
|
233 |
+
scheduler_config,
|
234 |
+
local_rank,
|
235 |
+
rank,
|
236 |
+
distributed_init_method,
|
237 |
+
)
|
238 |
+
)
|
239 |
+
|
240 |
+
driver_rank = 0
|
241 |
+
driver_local_rank = node_workers[driver_node_id].index(driver_rank)
|
242 |
+
self.driver_worker = Worker(
|
243 |
+
model_config,
|
244 |
+
parallel_config,
|
245 |
+
scheduler_config,
|
246 |
+
driver_local_rank,
|
247 |
+
driver_rank,
|
248 |
+
distributed_init_method,
|
249 |
+
is_driver_worker=True,
|
250 |
+
)
|
251 |
+
|
252 |
+
self._run_workers("init_model")
|
253 |
+
self._run_workers(
|
254 |
+
"load_model",
|
255 |
+
max_concurrent_workers=self.parallel_config.max_parallel_loading_workers,
|
256 |
+
)
|
257 |
+
|
258 |
+
def _verify_args(self) -> None:
|
259 |
+
self.model_config.verify_with_parallel_config(self.parallel_config)
|
260 |
+
self.cache_config.verify_with_parallel_config(self.parallel_config)
|
261 |
+
|
262 |
+
def _init_cache(self) -> None:
|
263 |
+
"""Profiles the memory usage and initializes the KV cache."""
|
264 |
+
# Get the maximum number of blocks that can be allocated on GPU and CPU.
|
265 |
+
num_blocks = self._run_workers(
|
266 |
+
"profile_num_available_blocks",
|
267 |
+
block_size=self.cache_config.block_size,
|
268 |
+
gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
|
269 |
+
cpu_swap_space=self.cache_config.swap_space_bytes,
|
270 |
+
)
|
271 |
+
|
272 |
+
# Since we use a shared centralized controller, we take the minimum
|
273 |
+
# number of blocks across all workers to make sure all the memory
|
274 |
+
# operators can be applied to all workers.
|
275 |
+
num_gpu_blocks = min(b[0] for b in num_blocks)
|
276 |
+
num_cpu_blocks = min(b[1] for b in num_blocks)
|
277 |
+
# FIXME(woosuk): Change to debug log.
|
278 |
+
logger.info(
|
279 |
+
f"# GPU blocks: {num_gpu_blocks}, " f"# CPU blocks: {num_cpu_blocks}"
|
280 |
+
)
|
281 |
+
|
282 |
+
if num_gpu_blocks <= 0:
|
283 |
+
raise ValueError(
|
284 |
+
"No available memory for the cache blocks. "
|
285 |
+
"Try increasing `gpu_memory_utilization` when "
|
286 |
+
"initializing the engine."
|
287 |
+
)
|
288 |
+
max_seq_len = self.cache_config.block_size * num_gpu_blocks
|
289 |
+
if self.model_config.max_model_len > max_seq_len:
|
290 |
+
raise ValueError(
|
291 |
+
f"The model's max seq len ({self.model_config.max_model_len}) "
|
292 |
+
"is larger than the maximum number of tokens that can be "
|
293 |
+
f"stored in KV cache ({max_seq_len}). Try increasing "
|
294 |
+
"`gpu_memory_utilization` or decreasing `max_model_len` when "
|
295 |
+
"initializing the engine."
|
296 |
+
)
|
297 |
+
|
298 |
+
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
299 |
+
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
300 |
+
|
301 |
+
# Initialize the cache.
|
302 |
+
self._run_workers("init_cache_engine", cache_config=self.cache_config)
|
303 |
+
# Warm up the model. This includes capturing the model into CUDA graph
|
304 |
+
# if enforce_eager is False.
|
305 |
+
self._run_workers("warm_up_model")
|
306 |
+
|
307 |
+
@classmethod
|
308 |
+
def from_engine_args(
|
309 |
+
cls, engine_args: EngineArgs, post_model_path=None
|
310 |
+
) -> "LLMEngine":
|
311 |
+
"""Creates an LLM engine from the engine arguments."""
|
312 |
+
# Create the engine configs.
|
313 |
+
engine_configs = engine_args.create_engine_configs()
|
314 |
+
parallel_config = engine_configs[2]
|
315 |
+
# Initialize the cluster.
|
316 |
+
placement_group = initialize_cluster(parallel_config)
|
317 |
+
# Create the LLM engine.
|
318 |
+
engine = cls(
|
319 |
+
*engine_configs,
|
320 |
+
placement_group,
|
321 |
+
log_stats=not engine_args.disable_log_stats,
|
322 |
+
post_model_path=post_model_path,
|
323 |
+
)
|
324 |
+
return engine
|
325 |
+
|
326 |
+
def add_request(
|
327 |
+
self,
|
328 |
+
request_id: str,
|
329 |
+
prompt: Optional[str],
|
330 |
+
sampling_params: SamplingParams,
|
331 |
+
prompt_token_ids: Optional[List[int]] = None,
|
332 |
+
arrival_time: Optional[float] = None,
|
333 |
+
) -> None:
|
334 |
+
"""Add a request to the engine's request pool.
|
335 |
+
|
336 |
+
The request is added to the request pool and will be processed by the
|
337 |
+
scheduler as `engine.step()` is called. The exact scheduling policy is
|
338 |
+
determined by the scheduler.
|
339 |
+
|
340 |
+
Args:
|
341 |
+
request_id: The unique ID of the request.
|
342 |
+
prompt: The prompt string. Can be None if prompt_token_ids is
|
343 |
+
provided.
|
344 |
+
sampling_params: The sampling parameters for text generation.
|
345 |
+
prompt_token_ids: The token IDs of the prompt. If None, we
|
346 |
+
use the tokenizer to convert the prompts to token IDs.
|
347 |
+
arrival_time: The arrival time of the request. If None, we use
|
348 |
+
the current monotonic time.
|
349 |
+
"""
|
350 |
+
if arrival_time is None:
|
351 |
+
arrival_time = time.monotonic()
|
352 |
+
|
353 |
+
assert prompt_token_ids is not None, "prompt_token_ids must be provided"
|
354 |
+
# Create the sequences.
|
355 |
+
block_size = self.cache_config.block_size
|
356 |
+
seq_id = next(self.seq_counter)
|
357 |
+
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
|
358 |
+
|
359 |
+
# Create the sequence group.
|
360 |
+
seq_group = SequenceGroup(request_id, [seq], sampling_params, arrival_time)
|
361 |
+
|
362 |
+
# Add the sequence group to the scheduler.
|
363 |
+
self.scheduler.add_seq_group(seq_group)
|
364 |
+
|
365 |
+
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
|
366 |
+
"""Aborts a request(s) with the given ID.
|
367 |
+
|
368 |
+
Args:
|
369 |
+
request_id: The ID(s) of the request to abort.
|
370 |
+
"""
|
371 |
+
self.scheduler.abort_seq_group(request_id)
|
372 |
+
|
373 |
+
def get_model_config(self) -> ModelConfig:
|
374 |
+
"""Gets the model configuration."""
|
375 |
+
return self.model_config
|
376 |
+
|
377 |
+
def get_num_unfinished_requests(self) -> int:
|
378 |
+
"""Gets the number of unfinished requests."""
|
379 |
+
return self.scheduler.get_num_unfinished_seq_groups()
|
380 |
+
|
381 |
+
def has_unfinished_requests(self) -> bool:
|
382 |
+
"""Returns True if there are unfinished requests."""
|
383 |
+
return self.scheduler.has_unfinished_seqs()
|
384 |
+
|
385 |
+
def _check_beam_search_early_stopping(
|
386 |
+
self,
|
387 |
+
early_stopping: Union[bool, str],
|
388 |
+
sampling_params: SamplingParams,
|
389 |
+
best_running_seq: Sequence,
|
390 |
+
current_worst_seq: Sequence,
|
391 |
+
) -> bool:
|
392 |
+
assert sampling_params.use_beam_search
|
393 |
+
length_penalty = sampling_params.length_penalty
|
394 |
+
if early_stopping is True:
|
395 |
+
return True
|
396 |
+
|
397 |
+
current_worst_score = current_worst_seq.get_beam_search_score(
|
398 |
+
length_penalty=length_penalty, eos_token_id=self.tokenizer.eos_token_id
|
399 |
+
)
|
400 |
+
if early_stopping is False:
|
401 |
+
highest_attainable_score = best_running_seq.get_beam_search_score(
|
402 |
+
length_penalty=length_penalty, eos_token_id=self.tokenizer.eos_token_id
|
403 |
+
)
|
404 |
+
else:
|
405 |
+
assert early_stopping == "never"
|
406 |
+
if length_penalty > 0.0:
|
407 |
+
# If length_penalty > 0.0, beam search will prefer longer
|
408 |
+
# sequences. The highest attainable score calculation is
|
409 |
+
# based on the longest possible sequence length in this case.
|
410 |
+
max_possible_length = max(
|
411 |
+
best_running_seq.get_prompt_len() + sampling_params.max_tokens,
|
412 |
+
self.scheduler_config.max_model_len,
|
413 |
+
)
|
414 |
+
highest_attainable_score = best_running_seq.get_beam_search_score(
|
415 |
+
length_penalty=length_penalty,
|
416 |
+
eos_token_id=self.tokenizer.eos_token_id,
|
417 |
+
seq_len=max_possible_length,
|
418 |
+
)
|
419 |
+
else:
|
420 |
+
# Otherwise, beam search will prefer shorter sequences. The
|
421 |
+
# highest attainable score calculation is based on the current
|
422 |
+
# sequence length.
|
423 |
+
highest_attainable_score = best_running_seq.get_beam_search_score(
|
424 |
+
length_penalty=length_penalty,
|
425 |
+
eos_token_id=self.tokenizer.eos_token_id,
|
426 |
+
)
|
427 |
+
return current_worst_score >= highest_attainable_score
|
428 |
+
|
429 |
+
def _process_sequence_group_outputs(
|
430 |
+
self, seq_group: SequenceGroup, outputs: SequenceGroupOutput
|
431 |
+
) -> None:
|
432 |
+
# Process prompt logprobs
|
433 |
+
prompt_logprobs = outputs.prompt_logprobs
|
434 |
+
if prompt_logprobs is not None:
|
435 |
+
seq_group.prompt_logprobs = prompt_logprobs
|
436 |
+
|
437 |
+
# Process samples
|
438 |
+
samples = outputs.samples
|
439 |
+
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
|
440 |
+
existing_finished_seqs = seq_group.get_finished_seqs()
|
441 |
+
parent_child_dict = {parent_seq.seq_id: [] for parent_seq in parent_seqs}
|
442 |
+
for sample in samples:
|
443 |
+
parent_child_dict[sample.parent_seq_id].append(sample)
|
444 |
+
# List of (child, parent)
|
445 |
+
child_seqs: List[Tuple[Sequence, Sequence]] = []
|
446 |
+
|
447 |
+
# Process the child samples for each parent sequence
|
448 |
+
for parent in parent_seqs:
|
449 |
+
child_samples: List[SequenceOutput] = parent_child_dict[parent.seq_id]
|
450 |
+
if len(child_samples) == 0:
|
451 |
+
# This parent sequence has no children samples. Remove
|
452 |
+
# the parent sequence from the sequence group since it will
|
453 |
+
# not be used in the future iterations.
|
454 |
+
parent.status = SequenceStatus.FINISHED_ABORTED
|
455 |
+
seq_group.remove(parent.seq_id)
|
456 |
+
self.scheduler.free_seq(parent)
|
457 |
+
continue
|
458 |
+
# Fork the parent sequence if there are multiple child samples.
|
459 |
+
for child_sample in child_samples[:-1]:
|
460 |
+
new_child_seq_id = next(self.seq_counter)
|
461 |
+
child = parent.fork(new_child_seq_id)
|
462 |
+
child.append_token_id(
|
463 |
+
child_sample.output_token,
|
464 |
+
child_sample.logprobs,
|
465 |
+
child_sample.hidden_states,
|
466 |
+
child_sample.finished,
|
467 |
+
)
|
468 |
+
child_seqs.append((child, parent))
|
469 |
+
# Continue the parent sequence for the last child sample.
|
470 |
+
# We reuse the parent sequence here to reduce redundant memory
|
471 |
+
# copies, especially when using non-beam search sampling methods.
|
472 |
+
last_child_sample = child_samples[-1]
|
473 |
+
parent.append_token_id(
|
474 |
+
last_child_sample.output_token,
|
475 |
+
last_child_sample.logprobs,
|
476 |
+
last_child_sample.hidden_states,
|
477 |
+
last_child_sample.finished,
|
478 |
+
)
|
479 |
+
child_seqs.append((parent, parent))
|
480 |
+
|
481 |
+
for seq, _ in child_seqs:
|
482 |
+
# self._decode_sequence(seq, seq_group.sampling_params)
|
483 |
+
self._check_stop(seq, seq_group.sampling_params)
|
484 |
+
|
485 |
+
# Non-beam search case
|
486 |
+
if not seq_group.sampling_params.use_beam_search:
|
487 |
+
# For newly created child sequences, add them to the sequence group
|
488 |
+
# and fork them in block manager if they are not finished.
|
489 |
+
for seq, parent in child_seqs:
|
490 |
+
if seq is not parent:
|
491 |
+
seq_group.add(seq)
|
492 |
+
if not seq.is_finished():
|
493 |
+
self.scheduler.fork_seq(parent, seq)
|
494 |
+
|
495 |
+
# Free the finished and selected parent sequences' memory in block
|
496 |
+
# manager. Keep them in the sequence group as candidate output.
|
497 |
+
# NOTE: we need to fork the new sequences before freeing the
|
498 |
+
# old sequences.
|
499 |
+
for seq, parent in child_seqs:
|
500 |
+
if seq is parent and seq.is_finished():
|
501 |
+
self.scheduler.free_seq(seq)
|
502 |
+
return
|
503 |
+
|
504 |
+
# Beam search case
|
505 |
+
# Select the child sequences to keep in the sequence group.
|
506 |
+
selected_child_seqs = []
|
507 |
+
unselected_child_seqs = []
|
508 |
+
beam_width = seq_group.sampling_params.best_of
|
509 |
+
length_penalty = seq_group.sampling_params.length_penalty
|
510 |
+
|
511 |
+
# Select the newly finished sequences with the highest scores
|
512 |
+
# to replace existing finished sequences.
|
513 |
+
# Tuple of (seq, parent, is_new)
|
514 |
+
existing_finished_seqs = [(seq, None, False) for seq in existing_finished_seqs]
|
515 |
+
new_finished_seqs = [
|
516 |
+
(seq, parent, True) for seq, parent in child_seqs if seq.is_finished()
|
517 |
+
]
|
518 |
+
all_finished_seqs = existing_finished_seqs + new_finished_seqs
|
519 |
+
# Sort the finished sequences by their scores.
|
520 |
+
all_finished_seqs.sort(
|
521 |
+
key=lambda x: x[0].get_beam_search_score(
|
522 |
+
length_penalty=length_penalty, eos_token_id=self.tokenizer.eos_token_id
|
523 |
+
),
|
524 |
+
reverse=True,
|
525 |
+
)
|
526 |
+
for seq, parent, is_new in all_finished_seqs[:beam_width]:
|
527 |
+
if is_new:
|
528 |
+
# A newly generated child sequence finishes and has a high
|
529 |
+
# score, so we will add it into the sequence group.
|
530 |
+
selected_child_seqs.append((seq, parent))
|
531 |
+
for seq, parent, is_new in all_finished_seqs[beam_width:]:
|
532 |
+
if is_new:
|
533 |
+
# A newly generated child sequence finishes but has a low
|
534 |
+
# score, so we will not add it into the sequence group.
|
535 |
+
# Additionally, if this sequence is a continuation of a
|
536 |
+
# parent sequence, we will need remove the parent sequence
|
537 |
+
# from the sequence group.
|
538 |
+
unselected_child_seqs.append((seq, parent))
|
539 |
+
else:
|
540 |
+
# An existing finished sequence has a low score, so we will
|
541 |
+
# remove it from the sequence group.
|
542 |
+
seq_group.remove(seq.seq_id)
|
543 |
+
|
544 |
+
# select the top beam_width sequences from the running
|
545 |
+
# sequences for the next iteration to continue the beam
|
546 |
+
# search.
|
547 |
+
running_child_seqs = [
|
548 |
+
(seq, parent) for seq, parent in child_seqs if not seq.is_finished()
|
549 |
+
]
|
550 |
+
# Sort the running sequences by their scores.
|
551 |
+
running_child_seqs.sort(
|
552 |
+
key=lambda x: x[0].get_beam_search_score(
|
553 |
+
length_penalty=length_penalty, eos_token_id=self.tokenizer.eos_token_id
|
554 |
+
),
|
555 |
+
reverse=True,
|
556 |
+
)
|
557 |
+
|
558 |
+
# Check if we can stop the beam search.
|
559 |
+
if len(running_child_seqs) == 0:
|
560 |
+
# No running sequences, stop the beam search.
|
561 |
+
stop_beam_search = True
|
562 |
+
elif len(all_finished_seqs) < beam_width:
|
563 |
+
# Not enough finished sequences, continue the beam search.
|
564 |
+
stop_beam_search = False
|
565 |
+
else:
|
566 |
+
# Check the early stopping criteria
|
567 |
+
best_running_seq = running_child_seqs[0][0]
|
568 |
+
current_worst_seq = all_finished_seqs[beam_width - 1][0]
|
569 |
+
stop_beam_search = self._check_beam_search_early_stopping(
|
570 |
+
seq_group.sampling_params.early_stopping,
|
571 |
+
seq_group.sampling_params,
|
572 |
+
best_running_seq,
|
573 |
+
current_worst_seq,
|
574 |
+
)
|
575 |
+
|
576 |
+
if stop_beam_search:
|
577 |
+
# Stop the beam search and remove all the running sequences from
|
578 |
+
# the sequence group.
|
579 |
+
unselected_child_seqs.extend(running_child_seqs)
|
580 |
+
else:
|
581 |
+
# Continue the beam search and select the top beam_width sequences
|
582 |
+
# to continue the beam search.
|
583 |
+
selected_child_seqs.extend(running_child_seqs[:beam_width])
|
584 |
+
# The remaining running sequences will not be used in the next
|
585 |
+
# iteration. Again, if these sequences are continuations of
|
586 |
+
# parent sequences, we will need to remove the parent sequences
|
587 |
+
# from the sequence group.
|
588 |
+
unselected_child_seqs.extend(running_child_seqs[beam_width:])
|
589 |
+
|
590 |
+
# For newly created child sequences, add them to the sequence group
|
591 |
+
# and fork them in block manager if they are not finished.
|
592 |
+
for seq, parent in selected_child_seqs:
|
593 |
+
if seq is not parent:
|
594 |
+
seq_group.add(seq)
|
595 |
+
if not seq.is_finished():
|
596 |
+
self.scheduler.fork_seq(parent, seq)
|
597 |
+
|
598 |
+
# Free the finished and selected parent sequences' memory in block
|
599 |
+
# manager. Keep them in the sequence group as candidate output.
|
600 |
+
for seq, parent in selected_child_seqs:
|
601 |
+
if seq is parent and seq.is_finished():
|
602 |
+
self.scheduler.free_seq(seq)
|
603 |
+
|
604 |
+
# Remove the unselected parent sequences from the sequence group and
|
605 |
+
# free their memory in block manager.
|
606 |
+
for seq, parent in unselected_child_seqs:
|
607 |
+
if seq is parent:
|
608 |
+
# Remove the parent sequence if it is not selected for next
|
609 |
+
# iteration
|
610 |
+
seq_group.remove(seq.seq_id)
|
611 |
+
self.scheduler.free_seq(seq)
|
612 |
+
|
613 |
+
def _process_model_outputs(
|
614 |
+
self, output: SamplerOutput, scheduler_outputs: SchedulerOutputs
|
615 |
+
) -> List[RequestOutput]:
|
616 |
+
# Update the scheduled sequence groups with the model outputs.
|
617 |
+
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
|
618 |
+
for seq_group, outputs in zip(scheduled_seq_groups, output):
|
619 |
+
self._process_sequence_group_outputs(seq_group, outputs)
|
620 |
+
|
621 |
+
# Free the finished sequence groups.
|
622 |
+
self.scheduler.free_finished_seq_groups()
|
623 |
+
|
624 |
+
# Create the outputs.
|
625 |
+
request_outputs: List[RequestOutput] = []
|
626 |
+
for seq_group in scheduled_seq_groups + scheduler_outputs.ignored_seq_groups:
|
627 |
+
request_output = RequestOutput.from_seq_group(seq_group)
|
628 |
+
request_outputs.append(request_output)
|
629 |
+
|
630 |
+
if self.log_stats:
|
631 |
+
# Log the system stats.
|
632 |
+
self._log_system_stats(
|
633 |
+
scheduler_outputs.prompt_run, scheduler_outputs.num_batched_tokens
|
634 |
+
)
|
635 |
+
return request_outputs
|
636 |
+
|
637 |
+
def step(self) -> List[RequestOutput]:
|
638 |
+
"""Performs one decoding iteration and returns newly generated results.
|
639 |
+
|
640 |
+
This function performs one decoding iteration of the engine. It first
|
641 |
+
schedules the sequences to be executed in the next iteration and the
|
642 |
+
token blocks to be swapped in/out/copy. Then, it executes the model
|
643 |
+
and updates the scheduler with the model outputs. Finally, it decodes
|
644 |
+
the sequences and returns the newly generated results.
|
645 |
+
"""
|
646 |
+
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
|
647 |
+
|
648 |
+
if not scheduler_outputs.is_empty():
|
649 |
+
# Execute the model.
|
650 |
+
all_outputs = self._run_workers(
|
651 |
+
"execute_model",
|
652 |
+
driver_kwargs={
|
653 |
+
"seq_group_metadata_list": seq_group_metadata_list,
|
654 |
+
"blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
|
655 |
+
"blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
|
656 |
+
"blocks_to_copy": scheduler_outputs.blocks_to_copy,
|
657 |
+
},
|
658 |
+
)
|
659 |
+
|
660 |
+
# Only the driver worker returns the sampling results.
|
661 |
+
output = all_outputs[0]
|
662 |
+
else:
|
663 |
+
output = []
|
664 |
+
|
665 |
+
return self._process_model_outputs(output, scheduler_outputs)
|
666 |
+
|
667 |
+
def _log_system_stats(
|
668 |
+
self,
|
669 |
+
prompt_run: bool,
|
670 |
+
num_batched_tokens: int,
|
671 |
+
) -> None:
|
672 |
+
now = time.monotonic()
|
673 |
+
# Log the number of batched input tokens.
|
674 |
+
if prompt_run:
|
675 |
+
self.num_prompt_tokens.append((now, num_batched_tokens))
|
676 |
+
else:
|
677 |
+
self.num_generation_tokens.append((now, num_batched_tokens))
|
678 |
+
|
679 |
+
should_log = now - self.last_logging_time >= _LOGGING_INTERVAL_SEC
|
680 |
+
if not should_log:
|
681 |
+
return
|
682 |
+
|
683 |
+
# Discard the old stats.
|
684 |
+
self.num_prompt_tokens = [
|
685 |
+
(t, n) for t, n in self.num_prompt_tokens if now - t < _LOGGING_INTERVAL_SEC
|
686 |
+
]
|
687 |
+
self.num_generation_tokens = [
|
688 |
+
(t, n)
|
689 |
+
for t, n in self.num_generation_tokens
|
690 |
+
if now - t < _LOGGING_INTERVAL_SEC
|
691 |
+
]
|
692 |
+
|
693 |
+
if len(self.num_prompt_tokens) > 1:
|
694 |
+
total_num_tokens = sum(n for _, n in self.num_prompt_tokens[:-1])
|
695 |
+
window = now - self.num_prompt_tokens[0][0]
|
696 |
+
avg_prompt_throughput = total_num_tokens / window
|
697 |
+
else:
|
698 |
+
avg_prompt_throughput = 0.0
|
699 |
+
if len(self.num_generation_tokens) > 1:
|
700 |
+
total_num_tokens = sum(n for _, n in self.num_generation_tokens[:-1])
|
701 |
+
window = now - self.num_generation_tokens[0][0]
|
702 |
+
avg_generation_throughput = total_num_tokens / window
|
703 |
+
else:
|
704 |
+
avg_generation_throughput = 0.0
|
705 |
+
|
706 |
+
total_num_gpu_blocks = self.cache_config.num_gpu_blocks
|
707 |
+
num_free_gpu_blocks = self.scheduler.block_manager.get_num_free_gpu_blocks()
|
708 |
+
num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
|
709 |
+
gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
|
710 |
+
|
711 |
+
total_num_cpu_blocks = self.cache_config.num_cpu_blocks
|
712 |
+
if total_num_cpu_blocks > 0:
|
713 |
+
num_free_cpu_blocks = self.scheduler.block_manager.get_num_free_cpu_blocks()
|
714 |
+
num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
|
715 |
+
cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
|
716 |
+
else:
|
717 |
+
cpu_cache_usage = 0.0
|
718 |
+
|
719 |
+
record_metrics(
|
720 |
+
avg_prompt_throughput=avg_prompt_throughput,
|
721 |
+
avg_generation_throughput=avg_generation_throughput,
|
722 |
+
scheduler_running=len(self.scheduler.running),
|
723 |
+
scheduler_swapped=len(self.scheduler.swapped),
|
724 |
+
scheduler_waiting=len(self.scheduler.waiting),
|
725 |
+
gpu_cache_usage=gpu_cache_usage,
|
726 |
+
cpu_cache_usage=cpu_cache_usage,
|
727 |
+
)
|
728 |
+
|
729 |
+
logger.info(
|
730 |
+
"Avg prompt throughput: "
|
731 |
+
f"{avg_prompt_throughput:.1f} tokens/s, "
|
732 |
+
"Avg generation throughput: "
|
733 |
+
f"{avg_generation_throughput:.1f} tokens/s, "
|
734 |
+
f"Running: {len(self.scheduler.running)} reqs, "
|
735 |
+
f"Swapped: {len(self.scheduler.swapped)} reqs, "
|
736 |
+
f"Pending: {len(self.scheduler.waiting)} reqs, "
|
737 |
+
f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
|
738 |
+
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%"
|
739 |
+
)
|
740 |
+
self.last_logging_time = now
|
741 |
+
|
742 |
+
def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
|
743 |
+
"""Decodes the new token for a sequence."""
|
744 |
+
(new_tokens, new_output_text, prefix_offset, read_offset) = (
|
745 |
+
detokenize_incrementally(
|
746 |
+
self.tokenizer,
|
747 |
+
all_input_ids=seq.get_token_ids(),
|
748 |
+
prev_tokens=seq.tokens,
|
749 |
+
prefix_offset=seq.prefix_offset,
|
750 |
+
read_offset=seq.read_offset,
|
751 |
+
skip_special_tokens=prms.skip_special_tokens,
|
752 |
+
spaces_between_special_tokens=prms.spaces_between_special_tokens,
|
753 |
+
)
|
754 |
+
)
|
755 |
+
if seq.tokens is None:
|
756 |
+
seq.tokens = new_tokens
|
757 |
+
else:
|
758 |
+
seq.tokens.extend(new_tokens)
|
759 |
+
seq.prefix_offset = prefix_offset
|
760 |
+
seq.read_offset = read_offset
|
761 |
+
seq.output_text += new_output_text
|
762 |
+
|
763 |
+
def _check_stop(self, seq: Sequence, sampling_params: SamplingParams) -> None:
|
764 |
+
"""Stop the finished sequences."""
|
765 |
+
for stop_str in sampling_params.stop:
|
766 |
+
if seq.output_text.endswith(stop_str):
|
767 |
+
if not sampling_params.include_stop_str_in_output:
|
768 |
+
# Truncate the output text so that the stop string is
|
769 |
+
# not included in the output.
|
770 |
+
seq.output_text = seq.output_text[: -len(stop_str)]
|
771 |
+
seq.status = SequenceStatus.FINISHED_STOPPED
|
772 |
+
return
|
773 |
+
if seq.data.finished:
|
774 |
+
seq.status = SequenceStatus.FINISHED_STOPPED
|
775 |
+
return
|
776 |
+
|
777 |
+
for token_id in seq.get_last_token_id():
|
778 |
+
if token_id == sampling_params.eos_token:
|
779 |
+
seq.status = SequenceStatus.FINISHED_STOPPED
|
780 |
+
return
|
781 |
+
|
782 |
+
# Check if the sequence has reached max_model_len.
|
783 |
+
if seq.get_len() > self.scheduler_config.max_model_len:
|
784 |
+
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
|
785 |
+
return
|
786 |
+
|
787 |
+
# Check if the sequence has reached max_tokens.
|
788 |
+
if seq.get_output_len() == sampling_params.max_tokens:
|
789 |
+
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
|
790 |
+
return
|
791 |
+
|
792 |
+
# Check if the sequence has generated the EOS token.
|
793 |
+
if (not sampling_params.ignore_eos) and seq.get_last_token_id()[
|
794 |
+
0
|
795 |
+
] == sampling_params.eos_token:
|
796 |
+
seq.status = SequenceStatus.FINISHED_STOPPED
|
797 |
+
return
|
798 |
+
|
799 |
+
def _run_workers(
|
800 |
+
self,
|
801 |
+
method: str,
|
802 |
+
*args,
|
803 |
+
driver_args: Optional[List[Any]] = None,
|
804 |
+
driver_kwargs: Optional[Dict[str, Any]] = None,
|
805 |
+
max_concurrent_workers: Optional[int] = None,
|
806 |
+
**kwargs,
|
807 |
+
) -> Any:
|
808 |
+
"""Runs the given method on all workers."""
|
809 |
+
|
810 |
+
if max_concurrent_workers:
|
811 |
+
raise NotImplementedError("max_concurrent_workers is not supported yet.")
|
812 |
+
|
813 |
+
# Start the ray workers first.
|
814 |
+
ray_worker_outputs = [
|
815 |
+
worker.execute_method.remote(method, *args, **kwargs)
|
816 |
+
for worker in self.workers
|
817 |
+
]
|
818 |
+
|
819 |
+
if driver_args is None:
|
820 |
+
driver_args = args
|
821 |
+
if driver_kwargs is None:
|
822 |
+
driver_kwargs = kwargs
|
823 |
+
|
824 |
+
# Start the driver worker after all the ray workers.
|
825 |
+
driver_worker_output = getattr(self.driver_worker, method)(
|
826 |
+
*driver_args, **driver_kwargs
|
827 |
+
)
|
828 |
+
|
829 |
+
# Get the results of the ray workers.
|
830 |
+
if self.workers:
|
831 |
+
ray_worker_outputs = ray.get(ray_worker_outputs)
|
832 |
+
|
833 |
+
return [driver_worker_output] + ray_worker_outputs
|
ChatTTS/model/velocity/model_loader.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Utilities for selecting and loading models."""
|
2 |
+
|
3 |
+
import contextlib
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from vllm.config import ModelConfig
|
9 |
+
from vllm.model_executor.models import ModelRegistry
|
10 |
+
from vllm.model_executor.weight_utils import get_quant_config, initialize_dummy_weights
|
11 |
+
|
12 |
+
from .llama import LlamaModel
|
13 |
+
|
14 |
+
|
15 |
+
@contextlib.contextmanager
|
16 |
+
def _set_default_torch_dtype(dtype: torch.dtype):
|
17 |
+
"""Sets the default torch dtype to the given dtype."""
|
18 |
+
old_dtype = torch.get_default_dtype()
|
19 |
+
torch.set_default_dtype(dtype)
|
20 |
+
yield
|
21 |
+
torch.set_default_dtype(old_dtype)
|
22 |
+
|
23 |
+
|
24 |
+
def get_model(model_config: ModelConfig) -> nn.Module:
|
25 |
+
# Get the (maybe quantized) linear method.
|
26 |
+
linear_method = None
|
27 |
+
if model_config.quantization is not None:
|
28 |
+
quant_config = get_quant_config(
|
29 |
+
model_config.quantization,
|
30 |
+
model_config.model,
|
31 |
+
model_config.hf_config,
|
32 |
+
model_config.download_dir,
|
33 |
+
)
|
34 |
+
capability = torch.cuda.get_device_capability()
|
35 |
+
capability = capability[0] * 10 + capability[1]
|
36 |
+
if capability < quant_config.get_min_capability():
|
37 |
+
raise ValueError(
|
38 |
+
f"The quantization method {model_config.quantization} is not "
|
39 |
+
"supported for the current GPU. "
|
40 |
+
f"Minimum capability: {quant_config.get_min_capability()}. "
|
41 |
+
f"Current capability: {capability}."
|
42 |
+
)
|
43 |
+
supported_dtypes = quant_config.get_supported_act_dtypes()
|
44 |
+
if model_config.dtype not in supported_dtypes:
|
45 |
+
raise ValueError(
|
46 |
+
f"{model_config.dtype} is not supported for quantization "
|
47 |
+
f"method {model_config.quantization}. Supported dtypes: "
|
48 |
+
f"{supported_dtypes}"
|
49 |
+
)
|
50 |
+
linear_method = quant_config.get_linear_method()
|
51 |
+
|
52 |
+
with _set_default_torch_dtype(model_config.dtype):
|
53 |
+
# Create a model instance.
|
54 |
+
# The weights will be initialized as empty tensors.
|
55 |
+
with torch.device("cuda"):
|
56 |
+
model = LlamaModel(model_config.hf_config, linear_method)
|
57 |
+
if model_config.load_format == "dummy":
|
58 |
+
# NOTE(woosuk): For accurate performance evaluation, we assign
|
59 |
+
# random values to the weights.
|
60 |
+
initialize_dummy_weights(model)
|
61 |
+
else:
|
62 |
+
# Load the weights from the cached or downloaded files.
|
63 |
+
model.load_weights(
|
64 |
+
model_config.model,
|
65 |
+
model_config.download_dir,
|
66 |
+
model_config.load_format,
|
67 |
+
model_config.revision,
|
68 |
+
)
|
69 |
+
return model.eval()
|
ChatTTS/model/velocity/model_runner.py
ADDED
@@ -0,0 +1,817 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
from typing import Dict, List, Optional, Tuple, Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from .configs import ModelConfig, ParallelConfig, SchedulerConfig
|
9 |
+
from vllm.logger import init_logger
|
10 |
+
from .model_loader import get_model
|
11 |
+
from vllm.model_executor import InputMetadata, SamplingMetadata
|
12 |
+
from vllm.model_executor.parallel_utils.communication_op import (
|
13 |
+
broadcast,
|
14 |
+
broadcast_object_list,
|
15 |
+
)
|
16 |
+
from .sampling_params import SamplingParams, SamplingType
|
17 |
+
from .sequence import (
|
18 |
+
SamplerOutput,
|
19 |
+
SequenceData,
|
20 |
+
SequenceGroupMetadata,
|
21 |
+
SequenceGroupOutput,
|
22 |
+
SequenceOutput,
|
23 |
+
)
|
24 |
+
from vllm.utils import in_wsl
|
25 |
+
from ..embed import Embed
|
26 |
+
from .sampler import Sampler
|
27 |
+
from safetensors.torch import safe_open
|
28 |
+
|
29 |
+
logger = init_logger(__name__)
|
30 |
+
|
31 |
+
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
32 |
+
_PAD_SLOT_ID = -1
|
33 |
+
# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
|
34 |
+
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
|
35 |
+
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
|
36 |
+
|
37 |
+
|
38 |
+
class ModelRunner:
|
39 |
+
|
40 |
+
def __init__(
|
41 |
+
self,
|
42 |
+
model_config: ModelConfig,
|
43 |
+
parallel_config: ParallelConfig,
|
44 |
+
scheduler_config: SchedulerConfig,
|
45 |
+
is_driver_worker: bool = False,
|
46 |
+
post_model_path: str = None,
|
47 |
+
):
|
48 |
+
self.model_config = model_config
|
49 |
+
self.parallel_config = parallel_config
|
50 |
+
self.scheduler_config = scheduler_config
|
51 |
+
self.is_driver_worker = is_driver_worker
|
52 |
+
self.post_model_path = post_model_path
|
53 |
+
|
54 |
+
# model_config can be None in tests/samplers/test_sampler.py.
|
55 |
+
# FIXME(woosuk): This is a hack to make the tests work. Refactor this.
|
56 |
+
self.sliding_window = (
|
57 |
+
model_config.get_sliding_window() if model_config is not None else None
|
58 |
+
)
|
59 |
+
self.model = None
|
60 |
+
self.block_size = None # Set after initial profiling.
|
61 |
+
|
62 |
+
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
|
63 |
+
self.graph_memory_pool = None # Set during graph capture.
|
64 |
+
|
65 |
+
self.max_context_len_to_capture = (
|
66 |
+
self.model_config.max_context_len_to_capture
|
67 |
+
if self.model_config is not None
|
68 |
+
else 0
|
69 |
+
)
|
70 |
+
# When using CUDA graph, the input block tables must be padded to
|
71 |
+
# max_context_len_to_capture. However, creating the block table in
|
72 |
+
# Python can be expensive. To optimize this, we cache the block table
|
73 |
+
# in numpy and only copy the actual input content at every iteration.
|
74 |
+
# The shape of the cached block table will be
|
75 |
+
# (max batch size to capture, max context len to capture / block size).
|
76 |
+
self.graph_block_tables = None # Set after initial profiling.
|
77 |
+
# cache in_wsl result
|
78 |
+
self.in_wsl = in_wsl()
|
79 |
+
|
80 |
+
def load_model(self) -> None:
|
81 |
+
self.model = get_model(self.model_config)
|
82 |
+
self.post_model = Embed(
|
83 |
+
self.model_config.get_hidden_size(),
|
84 |
+
self.model_config.num_audio_tokens,
|
85 |
+
self.model_config.num_text_tokens,
|
86 |
+
)
|
87 |
+
state_dict_tensors = {}
|
88 |
+
with safe_open(self.post_model_path, framework="pt", device=0) as f:
|
89 |
+
for k in f.keys():
|
90 |
+
state_dict_tensors[k] = f.get_tensor(k)
|
91 |
+
self.post_model.load_state_dict(state_dict_tensors)
|
92 |
+
self.post_model.to(next(self.model.parameters())).eval()
|
93 |
+
self.sampler = Sampler(self.post_model, self.model_config.num_audio_tokens, 4)
|
94 |
+
|
95 |
+
def set_block_size(self, block_size: int) -> None:
|
96 |
+
self.block_size = block_size
|
97 |
+
|
98 |
+
max_num_blocks = (
|
99 |
+
self.max_context_len_to_capture + block_size - 1
|
100 |
+
) // block_size
|
101 |
+
self.graph_block_tables = np.zeros(
|
102 |
+
(max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32
|
103 |
+
)
|
104 |
+
|
105 |
+
def _prepare_prompt(
|
106 |
+
self,
|
107 |
+
seq_group_metadata_list: List[SequenceGroupMetadata],
|
108 |
+
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int]]:
|
109 |
+
assert len(seq_group_metadata_list) > 0
|
110 |
+
input_tokens: List[List[int]] = []
|
111 |
+
input_positions: List[List[int]] = []
|
112 |
+
slot_mapping: List[List[int]] = []
|
113 |
+
|
114 |
+
prompt_lens: List[int] = []
|
115 |
+
for seq_group_metadata in seq_group_metadata_list:
|
116 |
+
assert seq_group_metadata.is_prompt
|
117 |
+
seq_ids = list(seq_group_metadata.seq_data.keys())
|
118 |
+
assert len(seq_ids) == 1
|
119 |
+
seq_id = seq_ids[0]
|
120 |
+
|
121 |
+
seq_data = seq_group_metadata.seq_data[seq_id]
|
122 |
+
prompt_tokens = seq_data.get_token_ids()
|
123 |
+
prompt_len = len(prompt_tokens)
|
124 |
+
prompt_lens.append(prompt_len)
|
125 |
+
|
126 |
+
input_tokens.append(prompt_tokens)
|
127 |
+
# NOTE(woosuk): Here we assume that the first token in the prompt
|
128 |
+
# is always the first token in the sequence.
|
129 |
+
input_positions.append(list(range(prompt_len)))
|
130 |
+
|
131 |
+
if seq_group_metadata.block_tables is None:
|
132 |
+
# During memory profiling, the block tables are not initialized
|
133 |
+
# yet. In this case, we just use a dummy slot mapping.
|
134 |
+
slot_mapping.append([_PAD_SLOT_ID] * prompt_len)
|
135 |
+
continue
|
136 |
+
|
137 |
+
# Compute the slot mapping.
|
138 |
+
slot_mapping.append([])
|
139 |
+
block_table = seq_group_metadata.block_tables[seq_id]
|
140 |
+
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
|
141 |
+
# where start_idx is max(0, prompt_len - sliding_window).
|
142 |
+
# For example, if the prompt len is 10, sliding window is 8, and
|
143 |
+
# block size is 4, the first two tokens are masked and the slot
|
144 |
+
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
|
145 |
+
start_idx = 0
|
146 |
+
if self.sliding_window is not None:
|
147 |
+
start_idx = max(0, prompt_len - self.sliding_window)
|
148 |
+
for i in range(prompt_len):
|
149 |
+
if i < start_idx:
|
150 |
+
slot_mapping[-1].append(_PAD_SLOT_ID)
|
151 |
+
continue
|
152 |
+
|
153 |
+
block_number = block_table[i // self.block_size]
|
154 |
+
block_offset = i % self.block_size
|
155 |
+
slot = block_number * self.block_size + block_offset
|
156 |
+
slot_mapping[-1].append(slot)
|
157 |
+
|
158 |
+
max_prompt_len = max(prompt_lens)
|
159 |
+
input_tokens = _make_tensor_with_pad(
|
160 |
+
input_tokens, max_prompt_len, pad=0, dtype=torch.long
|
161 |
+
)
|
162 |
+
input_positions = _make_tensor_with_pad(
|
163 |
+
input_positions, max_prompt_len, pad=0, dtype=torch.long
|
164 |
+
)
|
165 |
+
slot_mapping = _make_tensor_with_pad(
|
166 |
+
slot_mapping, max_prompt_len, pad=_PAD_SLOT_ID, dtype=torch.long
|
167 |
+
)
|
168 |
+
|
169 |
+
input_metadata = InputMetadata(
|
170 |
+
is_prompt=True,
|
171 |
+
slot_mapping=slot_mapping,
|
172 |
+
max_context_len=None,
|
173 |
+
context_lens=None,
|
174 |
+
block_tables=None,
|
175 |
+
use_cuda_graph=False,
|
176 |
+
)
|
177 |
+
return input_tokens, input_positions, input_metadata, prompt_lens
|
178 |
+
|
179 |
+
def _prepare_decode(
|
180 |
+
self,
|
181 |
+
seq_group_metadata_list: List[SequenceGroupMetadata],
|
182 |
+
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
|
183 |
+
assert len(seq_group_metadata_list) > 0
|
184 |
+
input_tokens: List[List[int]] = []
|
185 |
+
input_positions: List[List[int]] = []
|
186 |
+
slot_mapping: List[List[int]] = []
|
187 |
+
context_lens: List[int] = []
|
188 |
+
block_tables: List[List[int]] = []
|
189 |
+
|
190 |
+
for seq_group_metadata in seq_group_metadata_list:
|
191 |
+
assert not seq_group_metadata.is_prompt
|
192 |
+
|
193 |
+
seq_ids = list(seq_group_metadata.seq_data.keys())
|
194 |
+
for seq_id in seq_ids:
|
195 |
+
seq_data = seq_group_metadata.seq_data[seq_id]
|
196 |
+
generation_token = seq_data.get_last_token_id()
|
197 |
+
input_tokens.append([generation_token])
|
198 |
+
|
199 |
+
seq_len = seq_data.get_len()
|
200 |
+
position = seq_len - 1
|
201 |
+
input_positions.append([position])
|
202 |
+
|
203 |
+
context_len = (
|
204 |
+
seq_len
|
205 |
+
if self.sliding_window is None
|
206 |
+
else min(seq_len, self.sliding_window)
|
207 |
+
)
|
208 |
+
context_lens.append(context_len)
|
209 |
+
|
210 |
+
block_table = seq_group_metadata.block_tables[seq_id]
|
211 |
+
block_number = block_table[position // self.block_size]
|
212 |
+
block_offset = position % self.block_size
|
213 |
+
slot = block_number * self.block_size + block_offset
|
214 |
+
slot_mapping.append([slot])
|
215 |
+
|
216 |
+
if self.sliding_window is not None:
|
217 |
+
sliding_window_blocks = self.sliding_window // self.block_size
|
218 |
+
block_table = block_table[-sliding_window_blocks:]
|
219 |
+
block_tables.append(block_table)
|
220 |
+
|
221 |
+
batch_size = len(input_tokens)
|
222 |
+
max_context_len = max(context_lens)
|
223 |
+
use_captured_graph = (
|
224 |
+
not self.model_config.enforce_eager
|
225 |
+
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
|
226 |
+
and max_context_len <= self.max_context_len_to_capture
|
227 |
+
)
|
228 |
+
if use_captured_graph:
|
229 |
+
# Pad the input tokens, positions, and slot mapping to match the
|
230 |
+
# batch size of the captured graph.
|
231 |
+
graph_batch_size = _get_graph_batch_size(batch_size)
|
232 |
+
assert graph_batch_size >= batch_size
|
233 |
+
for _ in range(graph_batch_size - batch_size):
|
234 |
+
input_tokens.append([])
|
235 |
+
input_positions.append([])
|
236 |
+
slot_mapping.append([])
|
237 |
+
context_lens.append(1)
|
238 |
+
block_tables.append([])
|
239 |
+
batch_size = graph_batch_size
|
240 |
+
|
241 |
+
input_tokens = _make_tensor_with_pad(
|
242 |
+
input_tokens, max_len=1, pad=0, dtype=torch.long, device="cuda"
|
243 |
+
)
|
244 |
+
input_positions = _make_tensor_with_pad(
|
245 |
+
input_positions, max_len=1, pad=0, dtype=torch.long, device="cuda"
|
246 |
+
)
|
247 |
+
slot_mapping = _make_tensor_with_pad(
|
248 |
+
slot_mapping, max_len=1, pad=_PAD_SLOT_ID, dtype=torch.long, device="cuda"
|
249 |
+
)
|
250 |
+
context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
|
251 |
+
|
252 |
+
if use_captured_graph:
|
253 |
+
# The shape of graph_block_tables is
|
254 |
+
# [max batch size, max context len // block size].
|
255 |
+
input_block_tables = self.graph_block_tables[:batch_size]
|
256 |
+
for i, block_table in enumerate(block_tables):
|
257 |
+
if block_table:
|
258 |
+
input_block_tables[i, : len(block_table)] = block_table
|
259 |
+
block_tables = torch.tensor(input_block_tables, device="cuda")
|
260 |
+
else:
|
261 |
+
block_tables = _make_tensor_with_pad(
|
262 |
+
block_tables,
|
263 |
+
max_len=max_context_len,
|
264 |
+
pad=0,
|
265 |
+
dtype=torch.int,
|
266 |
+
device="cuda",
|
267 |
+
)
|
268 |
+
|
269 |
+
input_metadata = InputMetadata(
|
270 |
+
is_prompt=False,
|
271 |
+
slot_mapping=slot_mapping,
|
272 |
+
max_context_len=max_context_len,
|
273 |
+
context_lens=context_lens,
|
274 |
+
block_tables=block_tables,
|
275 |
+
use_cuda_graph=use_captured_graph,
|
276 |
+
)
|
277 |
+
return input_tokens, input_positions, input_metadata
|
278 |
+
|
279 |
+
def _prepare_sample(
|
280 |
+
self,
|
281 |
+
seq_group_metadata_list: List[SequenceGroupMetadata],
|
282 |
+
prompt_lens: List[int],
|
283 |
+
) -> SamplingMetadata:
|
284 |
+
seq_groups: List[Tuple[List[int], SamplingParams]] = []
|
285 |
+
selected_token_indices: List[int] = []
|
286 |
+
selected_token_start_idx = 0
|
287 |
+
categorized_sample_indices = {t: [] for t in SamplingType}
|
288 |
+
categorized_sample_indices_start_idx = 0
|
289 |
+
|
290 |
+
max_prompt_len = max(prompt_lens) if prompt_lens else 1
|
291 |
+
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
|
292 |
+
seq_ids = list(seq_group_metadata.seq_data.keys())
|
293 |
+
sampling_params = seq_group_metadata.sampling_params
|
294 |
+
seq_groups.append((seq_ids, sampling_params))
|
295 |
+
|
296 |
+
if seq_group_metadata.is_prompt:
|
297 |
+
assert len(seq_ids) == 1
|
298 |
+
prompt_len = prompt_lens[i]
|
299 |
+
if sampling_params.prompt_logprobs is not None:
|
300 |
+
# NOTE: prompt token positions do not need sample, skip
|
301 |
+
categorized_sample_indices_start_idx += prompt_len - 1
|
302 |
+
|
303 |
+
categorized_sample_indices[sampling_params.sampling_type].append(
|
304 |
+
categorized_sample_indices_start_idx
|
305 |
+
)
|
306 |
+
categorized_sample_indices_start_idx += 1
|
307 |
+
|
308 |
+
if sampling_params.prompt_logprobs is not None:
|
309 |
+
selected_token_indices.extend(
|
310 |
+
range(
|
311 |
+
selected_token_start_idx,
|
312 |
+
selected_token_start_idx + prompt_len - 1,
|
313 |
+
)
|
314 |
+
)
|
315 |
+
selected_token_indices.append(selected_token_start_idx + prompt_len - 1)
|
316 |
+
selected_token_start_idx += max_prompt_len
|
317 |
+
else:
|
318 |
+
num_seqs = len(seq_ids)
|
319 |
+
selected_token_indices.extend(
|
320 |
+
range(selected_token_start_idx, selected_token_start_idx + num_seqs)
|
321 |
+
)
|
322 |
+
selected_token_start_idx += num_seqs
|
323 |
+
|
324 |
+
categorized_sample_indices[sampling_params.sampling_type].extend(
|
325 |
+
range(
|
326 |
+
categorized_sample_indices_start_idx,
|
327 |
+
categorized_sample_indices_start_idx + num_seqs,
|
328 |
+
)
|
329 |
+
)
|
330 |
+
categorized_sample_indices_start_idx += num_seqs
|
331 |
+
|
332 |
+
selected_token_indices = _async_h2d(
|
333 |
+
selected_token_indices, dtype=torch.long, pin_memory=not self.in_wsl
|
334 |
+
)
|
335 |
+
categorized_sample_indices = {
|
336 |
+
t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl)
|
337 |
+
for t, seq_ids in categorized_sample_indices.items()
|
338 |
+
}
|
339 |
+
|
340 |
+
seq_data: Dict[int, SequenceData] = {}
|
341 |
+
for seq_group_metadata in seq_group_metadata_list:
|
342 |
+
seq_data.update(seq_group_metadata.seq_data)
|
343 |
+
|
344 |
+
sampling_metadata = SamplingMetadata(
|
345 |
+
seq_groups=seq_groups,
|
346 |
+
seq_data=seq_data,
|
347 |
+
prompt_lens=prompt_lens,
|
348 |
+
selected_token_indices=selected_token_indices,
|
349 |
+
categorized_sample_indices=categorized_sample_indices,
|
350 |
+
)
|
351 |
+
return sampling_metadata
|
352 |
+
|
353 |
+
def prepare_input_tensors(
|
354 |
+
self,
|
355 |
+
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
356 |
+
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata]:
|
357 |
+
if self.is_driver_worker:
|
358 |
+
# NOTE: We assume that all sequences in the group are all prompts or
|
359 |
+
# all decodes.
|
360 |
+
is_prompt = seq_group_metadata_list[0].is_prompt
|
361 |
+
# Prepare input tensors.
|
362 |
+
if is_prompt:
|
363 |
+
(input_tokens, input_positions, input_metadata, prompt_lens) = (
|
364 |
+
self._prepare_prompt(seq_group_metadata_list)
|
365 |
+
)
|
366 |
+
else:
|
367 |
+
(input_tokens, input_positions, input_metadata) = self._prepare_decode(
|
368 |
+
seq_group_metadata_list
|
369 |
+
)
|
370 |
+
prompt_lens = []
|
371 |
+
sampling_metadata = self._prepare_sample(
|
372 |
+
seq_group_metadata_list, prompt_lens
|
373 |
+
)
|
374 |
+
|
375 |
+
def get_size_or_none(x: Optional[torch.Tensor]):
|
376 |
+
return x.size() if x is not None else None
|
377 |
+
|
378 |
+
# Broadcast the input data. For input tensors, we first broadcast
|
379 |
+
# its shape and then broadcast the tensor to avoid high
|
380 |
+
# serialization cost.
|
381 |
+
py_data = {
|
382 |
+
"input_tokens_size": input_tokens.size(),
|
383 |
+
"input_positions_size": input_positions.size(),
|
384 |
+
"is_prompt": input_metadata.is_prompt,
|
385 |
+
"slot_mapping_size": get_size_or_none(input_metadata.slot_mapping),
|
386 |
+
"max_context_len": input_metadata.max_context_len,
|
387 |
+
"context_lens_size": get_size_or_none(input_metadata.context_lens),
|
388 |
+
"block_tables_size": get_size_or_none(input_metadata.block_tables),
|
389 |
+
"use_cuda_graph": input_metadata.use_cuda_graph,
|
390 |
+
"selected_token_indices_size": sampling_metadata.selected_token_indices.size(),
|
391 |
+
}
|
392 |
+
broadcast_object_list([py_data], src=0)
|
393 |
+
# TODO(zhuohan): Combine the broadcasts or set async_op=True.
|
394 |
+
broadcast(input_tokens, src=0)
|
395 |
+
broadcast(input_positions, src=0)
|
396 |
+
if input_metadata.slot_mapping is not None:
|
397 |
+
broadcast(input_metadata.slot_mapping, src=0)
|
398 |
+
if input_metadata.context_lens is not None:
|
399 |
+
broadcast(input_metadata.context_lens, src=0)
|
400 |
+
if input_metadata.block_tables is not None:
|
401 |
+
broadcast(input_metadata.block_tables, src=0)
|
402 |
+
broadcast(sampling_metadata.selected_token_indices, src=0)
|
403 |
+
else:
|
404 |
+
receving_list = [None]
|
405 |
+
broadcast_object_list(receving_list, src=0)
|
406 |
+
py_data = receving_list[0]
|
407 |
+
input_tokens = torch.empty(
|
408 |
+
*py_data["input_tokens_size"], dtype=torch.long, device="cuda"
|
409 |
+
)
|
410 |
+
broadcast(input_tokens, src=0)
|
411 |
+
input_positions = torch.empty(
|
412 |
+
*py_data["input_positions_size"], dtype=torch.long, device="cuda"
|
413 |
+
)
|
414 |
+
broadcast(input_positions, src=0)
|
415 |
+
if py_data["slot_mapping_size"] is not None:
|
416 |
+
slot_mapping = torch.empty(
|
417 |
+
*py_data["slot_mapping_size"], dtype=torch.long, device="cuda"
|
418 |
+
)
|
419 |
+
broadcast(slot_mapping, src=0)
|
420 |
+
else:
|
421 |
+
slot_mapping = None
|
422 |
+
if py_data["context_lens_size"] is not None:
|
423 |
+
context_lens = torch.empty(
|
424 |
+
*py_data["context_lens_size"], dtype=torch.int, device="cuda"
|
425 |
+
)
|
426 |
+
broadcast(context_lens, src=0)
|
427 |
+
else:
|
428 |
+
context_lens = None
|
429 |
+
if py_data["block_tables_size"] is not None:
|
430 |
+
block_tables = torch.empty(
|
431 |
+
*py_data["block_tables_size"], dtype=torch.int, device="cuda"
|
432 |
+
)
|
433 |
+
broadcast(block_tables, src=0)
|
434 |
+
else:
|
435 |
+
block_tables = None
|
436 |
+
selected_token_indices = torch.empty(
|
437 |
+
*py_data["selected_token_indices_size"], dtype=torch.long, device="cuda"
|
438 |
+
)
|
439 |
+
broadcast(selected_token_indices, src=0)
|
440 |
+
input_metadata = InputMetadata(
|
441 |
+
is_prompt=py_data["is_prompt"],
|
442 |
+
slot_mapping=slot_mapping,
|
443 |
+
max_context_len=py_data["max_context_len"],
|
444 |
+
context_lens=context_lens,
|
445 |
+
block_tables=block_tables,
|
446 |
+
use_cuda_graph=py_data["use_cuda_graph"],
|
447 |
+
)
|
448 |
+
sampling_metadata = SamplingMetadata(
|
449 |
+
seq_groups=None,
|
450 |
+
seq_data=None,
|
451 |
+
prompt_lens=None,
|
452 |
+
selected_token_indices=selected_token_indices,
|
453 |
+
categorized_sample_indices=None,
|
454 |
+
perform_sampling=False,
|
455 |
+
)
|
456 |
+
|
457 |
+
return input_tokens, input_positions, input_metadata, sampling_metadata
|
458 |
+
|
459 |
+
@torch.inference_mode()
|
460 |
+
def execute_model(
|
461 |
+
self,
|
462 |
+
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
463 |
+
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
|
464 |
+
) -> Optional[SamplerOutput]:
|
465 |
+
input_tokens, input_positions, input_metadata, sampling_metadata = (
|
466 |
+
self.prepare_input_tensors(seq_group_metadata_list)
|
467 |
+
)
|
468 |
+
# print(sampling_metadata.seq_data)
|
469 |
+
seq_groups = []
|
470 |
+
input_tokens_history = []
|
471 |
+
for i, rtn in enumerate(sampling_metadata.seq_groups):
|
472 |
+
seq_groups.append(rtn[0][0])
|
473 |
+
tokens_history = sampling_metadata.seq_data[rtn[0][0]].output_token_ids
|
474 |
+
if len(tokens_history) >= 1:
|
475 |
+
if len(tokens_history[0]) == 1:
|
476 |
+
tokens_history = [token[0] for token in tokens_history]
|
477 |
+
else:
|
478 |
+
tokens_history = [list(token) for token in tokens_history]
|
479 |
+
input_tokens_history.append(tokens_history)
|
480 |
+
input_tokens_history = torch.tensor(input_tokens_history).to(
|
481 |
+
input_tokens.device
|
482 |
+
)
|
483 |
+
# token_ids = rtn.outputs[0].token_ids
|
484 |
+
# for j, token_id in enumerate(token_ids):
|
485 |
+
# if len(token_id) == 1:
|
486 |
+
# token_ids[j] = token_id[0]
|
487 |
+
# else:
|
488 |
+
# token_ids[j] = list(token_id)
|
489 |
+
|
490 |
+
# Execute the model.
|
491 |
+
# print("it1",input_tokens)
|
492 |
+
if len(input_tokens.shape) == 2:
|
493 |
+
input_tokens = input_tokens.unsqueeze(2).repeat(1, 1, 4)
|
494 |
+
if len(input_tokens_history.shape) == 2:
|
495 |
+
input_tokens_history = input_tokens_history.unsqueeze(2).repeat(1, 1, 4)
|
496 |
+
# print(input_tokens_history.shape)
|
497 |
+
# print("it2",input_tokens.shape)
|
498 |
+
text_mask = input_tokens != 0
|
499 |
+
text_mask = text_mask[:, :, 0]
|
500 |
+
|
501 |
+
if input_metadata.use_cuda_graph:
|
502 |
+
graph_batch_size = input_tokens.shape[0]
|
503 |
+
model_executable = self.graph_runners[graph_batch_size]
|
504 |
+
else:
|
505 |
+
model_executable = self.model
|
506 |
+
|
507 |
+
infer_text = sampling_metadata.seq_groups[0][1].infer_text
|
508 |
+
temperture = sampling_metadata.seq_groups[0][1].temperature
|
509 |
+
if not infer_text:
|
510 |
+
temperture = torch.tensor(temperture).to(input_tokens.device)
|
511 |
+
logits_processors, logits_warpers = sampling_metadata.seq_groups[0][
|
512 |
+
1
|
513 |
+
].logits_processors
|
514 |
+
# print(logits_processors, logits_warpers)
|
515 |
+
min_new_token = sampling_metadata.seq_groups[0][1].min_new_token
|
516 |
+
eos_token = sampling_metadata.seq_groups[0][1].eos_token
|
517 |
+
start_idx = sampling_metadata.seq_groups[0][1].start_idx
|
518 |
+
if input_tokens.shape[-2] == 1:
|
519 |
+
if infer_text:
|
520 |
+
input_emb: torch.Tensor = self.post_model.emb_text(
|
521 |
+
input_tokens[:, :, 0]
|
522 |
+
)
|
523 |
+
else:
|
524 |
+
code_emb = [
|
525 |
+
self.post_model.emb_code[i](input_tokens[:, :, i])
|
526 |
+
for i in range(self.post_model.num_vq)
|
527 |
+
]
|
528 |
+
input_emb = torch.stack(code_emb, 3).sum(3)
|
529 |
+
start_idx = (
|
530 |
+
input_tokens_history.shape[-2] - 1
|
531 |
+
if input_tokens_history.shape[-2] > 0
|
532 |
+
else 0
|
533 |
+
)
|
534 |
+
else:
|
535 |
+
input_emb = self.post_model(input_tokens, text_mask)
|
536 |
+
# print(input_emb.shape)
|
537 |
+
hidden_states = model_executable(
|
538 |
+
input_emb=input_emb,
|
539 |
+
positions=input_positions,
|
540 |
+
kv_caches=kv_caches,
|
541 |
+
input_metadata=input_metadata,
|
542 |
+
)
|
543 |
+
# print(hidden_states.shape)
|
544 |
+
# print(input_tokens)
|
545 |
+
B_NO_PAD = input_tokens_history.shape[0]
|
546 |
+
input_tokens = input_tokens[:B_NO_PAD, :, :]
|
547 |
+
hidden_states = hidden_states[:B_NO_PAD, :, :]
|
548 |
+
idx_next, logprob, finish = self.sampler.sample(
|
549 |
+
inputs_ids=(
|
550 |
+
input_tokens
|
551 |
+
if input_tokens_history.shape[-2] == 0
|
552 |
+
else input_tokens_history
|
553 |
+
),
|
554 |
+
hidden_states=hidden_states,
|
555 |
+
infer_text=infer_text,
|
556 |
+
temperature=temperture,
|
557 |
+
logits_processors=logits_processors,
|
558 |
+
logits_warpers=logits_warpers,
|
559 |
+
min_new_token=min_new_token,
|
560 |
+
now_length=1,
|
561 |
+
eos_token=eos_token,
|
562 |
+
start_idx=start_idx,
|
563 |
+
)
|
564 |
+
# print(logprob.shape, idx_next.shape)
|
565 |
+
if len(logprob.shape) == 2:
|
566 |
+
logprob = logprob[:, None, :]
|
567 |
+
logprob = torch.gather(logprob, -1, idx_next.transpose(-1, -2))[:, :, 0]
|
568 |
+
# print("测试",idx_next.shape, logprob.shape)
|
569 |
+
# Sample the next token.
|
570 |
+
# output = self.model.sample(
|
571 |
+
# hidden_states=hidden_states,
|
572 |
+
# sampling_metadata=sampling_metadata,
|
573 |
+
# )
|
574 |
+
results = []
|
575 |
+
for i in range(idx_next.shape[0]):
|
576 |
+
idx_next_i = idx_next[i, 0, :].tolist()
|
577 |
+
logprob_i = logprob[i].tolist()
|
578 |
+
tmp_hidden_states = hidden_states[i]
|
579 |
+
if input_tokens[i].shape[-2] != 1:
|
580 |
+
tmp_hidden_states = tmp_hidden_states[-1:, :]
|
581 |
+
result = SequenceGroupOutput(
|
582 |
+
samples=[
|
583 |
+
SequenceOutput(
|
584 |
+
parent_seq_id=seq_groups[i],
|
585 |
+
logprobs={tuple(idx_next_i): logprob_i},
|
586 |
+
output_token=tuple(idx_next_i),
|
587 |
+
hidden_states=tmp_hidden_states,
|
588 |
+
finished=finish[i].item(),
|
589 |
+
),
|
590 |
+
],
|
591 |
+
prompt_logprobs=None,
|
592 |
+
)
|
593 |
+
results.append(result)
|
594 |
+
# print(results)
|
595 |
+
# print(idx_next, idx_next.shape, logprob.shape)
|
596 |
+
return results
|
597 |
+
|
598 |
+
@torch.inference_mode()
|
599 |
+
def profile_run(self) -> None:
|
600 |
+
# Enable top-k sampling to reflect the accurate memory usage.
|
601 |
+
vocab_size = self.model_config.get_vocab_size()
|
602 |
+
sampling_params = SamplingParams(
|
603 |
+
top_p=0.99, top_k=vocab_size - 1, infer_text=True
|
604 |
+
)
|
605 |
+
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
|
606 |
+
max_num_seqs = self.scheduler_config.max_num_seqs
|
607 |
+
|
608 |
+
# Profile memory usage with max_num_sequences sequences and the total
|
609 |
+
# number of tokens equal to max_num_batched_tokens.
|
610 |
+
seqs: List[SequenceGroupMetadata] = []
|
611 |
+
for group_id in range(max_num_seqs):
|
612 |
+
seq_len = max_num_batched_tokens // max_num_seqs + (
|
613 |
+
group_id < max_num_batched_tokens % max_num_seqs
|
614 |
+
)
|
615 |
+
seq_data = SequenceData([0] * seq_len)
|
616 |
+
seq = SequenceGroupMetadata(
|
617 |
+
request_id=str(group_id),
|
618 |
+
is_prompt=True,
|
619 |
+
seq_data={group_id: seq_data},
|
620 |
+
sampling_params=sampling_params,
|
621 |
+
block_tables=None,
|
622 |
+
)
|
623 |
+
seqs.append(seq)
|
624 |
+
|
625 |
+
# Run the model with the dummy inputs.
|
626 |
+
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
627 |
+
kv_caches = [(None, None)] * num_layers
|
628 |
+
self.execute_model(seqs, kv_caches)
|
629 |
+
torch.cuda.synchronize()
|
630 |
+
return
|
631 |
+
|
632 |
+
@torch.inference_mode()
|
633 |
+
def capture_model(self, kv_caches: List[KVCache]) -> None:
|
634 |
+
assert not self.model_config.enforce_eager
|
635 |
+
logger.info(
|
636 |
+
"Capturing the model for CUDA graphs. This may lead to "
|
637 |
+
"unexpected consequences if the model is not static. To "
|
638 |
+
"run the model in eager mode, set 'enforce_eager=True' or "
|
639 |
+
"use '--enforce-eager' in the CLI."
|
640 |
+
)
|
641 |
+
logger.info(
|
642 |
+
"CUDA graphs can take additional 1~3 GiB memory per GPU. "
|
643 |
+
"If you are running out of memory, consider decreasing "
|
644 |
+
"`gpu_memory_utilization` or enforcing eager mode."
|
645 |
+
)
|
646 |
+
start_time = time.perf_counter()
|
647 |
+
|
648 |
+
# Prepare dummy inputs. These will be reused for all batch sizes.
|
649 |
+
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
|
650 |
+
input_emb = torch.zeros(
|
651 |
+
max_batch_size,
|
652 |
+
1,
|
653 |
+
self.model_config.get_hidden_size(),
|
654 |
+
dtype=next(self.model.parameters()).dtype,
|
655 |
+
).cuda()
|
656 |
+
input_positions = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda()
|
657 |
+
slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda()
|
658 |
+
slot_mapping.fill_(_PAD_SLOT_ID)
|
659 |
+
context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
|
660 |
+
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
|
661 |
+
|
662 |
+
# NOTE: Capturing the largest batch size first may help reduce the
|
663 |
+
# memory usage of CUDA graph.
|
664 |
+
for batch_size in reversed(_BATCH_SIZES_TO_CAPTURE):
|
665 |
+
# Create dummy input_metadata.
|
666 |
+
input_metadata = InputMetadata(
|
667 |
+
is_prompt=False,
|
668 |
+
slot_mapping=slot_mapping[:batch_size],
|
669 |
+
max_context_len=self.max_context_len_to_capture,
|
670 |
+
context_lens=context_lens[:batch_size],
|
671 |
+
block_tables=block_tables[:batch_size],
|
672 |
+
use_cuda_graph=True,
|
673 |
+
)
|
674 |
+
|
675 |
+
graph_runner = CUDAGraphRunner(self.model)
|
676 |
+
graph_runner.capture(
|
677 |
+
input_emb[:batch_size],
|
678 |
+
input_positions[:batch_size],
|
679 |
+
kv_caches,
|
680 |
+
input_metadata,
|
681 |
+
memory_pool=self.graph_memory_pool,
|
682 |
+
)
|
683 |
+
self.graph_memory_pool = graph_runner.graph.pool()
|
684 |
+
self.graph_runners[batch_size] = graph_runner
|
685 |
+
|
686 |
+
end_time = time.perf_counter()
|
687 |
+
elapsed_time = end_time - start_time
|
688 |
+
# This usually takes < 10 seconds.
|
689 |
+
logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")
|
690 |
+
|
691 |
+
|
692 |
+
class CUDAGraphRunner:
|
693 |
+
|
694 |
+
def __init__(self, model: nn.Module):
|
695 |
+
self.model = model
|
696 |
+
self.graph = None
|
697 |
+
self.input_buffers: Dict[str, torch.Tensor] = {}
|
698 |
+
self.output_buffers: Dict[str, torch.Tensor] = {}
|
699 |
+
|
700 |
+
def capture(
|
701 |
+
self,
|
702 |
+
input_emb: torch.Tensor,
|
703 |
+
positions: torch.Tensor,
|
704 |
+
kv_caches: List[KVCache],
|
705 |
+
input_metadata: InputMetadata,
|
706 |
+
memory_pool,
|
707 |
+
) -> None:
|
708 |
+
assert self.graph is None
|
709 |
+
# Run the model once without capturing the graph.
|
710 |
+
# This is to make sure that the captured graph does not include the
|
711 |
+
# kernel launches for initial benchmarking (e.g., Triton autotune).
|
712 |
+
self.model(
|
713 |
+
input_emb,
|
714 |
+
positions,
|
715 |
+
kv_caches,
|
716 |
+
input_metadata,
|
717 |
+
)
|
718 |
+
torch.cuda.synchronize()
|
719 |
+
|
720 |
+
# Capture the graph.
|
721 |
+
self.graph = torch.cuda.CUDAGraph()
|
722 |
+
with torch.cuda.graph(self.graph, pool=memory_pool):
|
723 |
+
hidden_states = self.model(
|
724 |
+
input_emb,
|
725 |
+
positions,
|
726 |
+
kv_caches,
|
727 |
+
input_metadata,
|
728 |
+
)
|
729 |
+
torch.cuda.synchronize()
|
730 |
+
|
731 |
+
# Save the input and output buffers.
|
732 |
+
self.input_buffers = {
|
733 |
+
"input_emb": input_emb,
|
734 |
+
"positions": positions,
|
735 |
+
"kv_caches": kv_caches,
|
736 |
+
"slot_mapping": input_metadata.slot_mapping,
|
737 |
+
"context_lens": input_metadata.context_lens,
|
738 |
+
"block_tables": input_metadata.block_tables,
|
739 |
+
}
|
740 |
+
self.output_buffers = {"hidden_states": hidden_states}
|
741 |
+
return
|
742 |
+
|
743 |
+
def forward(
|
744 |
+
self,
|
745 |
+
input_emb: torch.Tensor,
|
746 |
+
positions: torch.Tensor,
|
747 |
+
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
|
748 |
+
input_metadata: InputMetadata,
|
749 |
+
) -> torch.Tensor:
|
750 |
+
# KV caches are fixed tensors, so we don't need to copy them.
|
751 |
+
del kv_caches
|
752 |
+
|
753 |
+
# Copy the input tensors to the input buffers.
|
754 |
+
self.input_buffers["input_emb"].copy_(input_emb, non_blocking=True)
|
755 |
+
self.input_buffers["positions"].copy_(positions, non_blocking=True)
|
756 |
+
self.input_buffers["slot_mapping"].copy_(
|
757 |
+
input_metadata.slot_mapping, non_blocking=True
|
758 |
+
)
|
759 |
+
self.input_buffers["context_lens"].copy_(
|
760 |
+
input_metadata.context_lens, non_blocking=True
|
761 |
+
)
|
762 |
+
self.input_buffers["block_tables"].copy_(
|
763 |
+
input_metadata.block_tables, non_blocking=True
|
764 |
+
)
|
765 |
+
|
766 |
+
# Run the graph.
|
767 |
+
self.graph.replay()
|
768 |
+
|
769 |
+
# Return the output tensor.
|
770 |
+
return self.output_buffers["hidden_states"]
|
771 |
+
|
772 |
+
def __call__(self, *args, **kwargs):
|
773 |
+
return self.forward(*args, **kwargs)
|
774 |
+
|
775 |
+
|
776 |
+
def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
|
777 |
+
assert len(x) <= max_len
|
778 |
+
if len(x) == max_len:
|
779 |
+
return list(x)
|
780 |
+
return list(x) + [pad] * (max_len - len(x))
|
781 |
+
|
782 |
+
|
783 |
+
def _make_tensor_with_pad(
|
784 |
+
x: List[List[int]],
|
785 |
+
max_len: int,
|
786 |
+
pad: int,
|
787 |
+
dtype: torch.dtype,
|
788 |
+
device: Union[str, torch.device] = "cuda",
|
789 |
+
pin_memory: bool = False,
|
790 |
+
) -> torch.Tensor:
|
791 |
+
padded_x = []
|
792 |
+
for x_i in x:
|
793 |
+
pad_i = pad
|
794 |
+
if isinstance(x[0][0], tuple):
|
795 |
+
pad_i = (0,) * len(x[0][0])
|
796 |
+
padded_x.append(_pad_to_max(x_i, max_len, pad_i))
|
797 |
+
|
798 |
+
return torch.tensor(
|
799 |
+
padded_x,
|
800 |
+
dtype=dtype,
|
801 |
+
device=device,
|
802 |
+
pin_memory=pin_memory and str(device) == "cpu",
|
803 |
+
)
|
804 |
+
|
805 |
+
|
806 |
+
def _get_graph_batch_size(batch_size: int) -> int:
|
807 |
+
if batch_size <= 2:
|
808 |
+
return batch_size
|
809 |
+
elif batch_size <= 4:
|
810 |
+
return 4
|
811 |
+
else:
|
812 |
+
return (batch_size + 7) // 8 * 8
|
813 |
+
|
814 |
+
|
815 |
+
def _async_h2d(data: list, dtype, pin_memory):
|
816 |
+
t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory)
|
817 |
+
return t.to(device="cuda", non_blocking=True)
|
ChatTTS/model/velocity/output.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from .sequence import (
|
5 |
+
PromptLogprobs,
|
6 |
+
SampleLogprobs,
|
7 |
+
SequenceGroup,
|
8 |
+
SequenceStatus,
|
9 |
+
)
|
10 |
+
|
11 |
+
|
12 |
+
class CompletionOutput:
|
13 |
+
"""The output data of one completion output of a request.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
index: The index of the output in the request.
|
17 |
+
text: The generated output text.
|
18 |
+
token_ids: The token IDs of the generated output text.
|
19 |
+
cumulative_logprob: The cumulative log probability of the generated
|
20 |
+
output text.
|
21 |
+
logprobs: The log probabilities of the top probability words at each
|
22 |
+
position if the logprobs are requested.
|
23 |
+
finish_reason: The reason why the sequence is finished.
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
index: int,
|
29 |
+
text: str,
|
30 |
+
token_ids: List[int],
|
31 |
+
cumulative_logprob: float,
|
32 |
+
logprobs: Optional[SampleLogprobs],
|
33 |
+
finish_reason: Optional[str] = None,
|
34 |
+
hidden_states: Optional[torch.Tensor] = None,
|
35 |
+
) -> None:
|
36 |
+
self.index = index
|
37 |
+
self.text = text
|
38 |
+
self.token_ids = token_ids
|
39 |
+
self.cumulative_logprob = cumulative_logprob
|
40 |
+
self.logprobs = logprobs
|
41 |
+
self.finish_reason = finish_reason
|
42 |
+
self.hidden_states = hidden_states
|
43 |
+
|
44 |
+
def finished(self) -> bool:
|
45 |
+
return self.finish_reason is not None
|
46 |
+
|
47 |
+
def __repr__(self) -> str:
|
48 |
+
return (
|
49 |
+
f"CompletionOutput(index={self.index}, "
|
50 |
+
f"text={self.text!r}, "
|
51 |
+
f"token_ids={self.token_ids}, "
|
52 |
+
f"cumulative_logprob={self.cumulative_logprob}, "
|
53 |
+
f"logprobs={self.logprobs}, "
|
54 |
+
f"finish_reason={self.finish_reason}, "
|
55 |
+
f"hidden_states={self.hidden_states.shape if self.hidden_states is not None else None})"
|
56 |
+
)
|
57 |
+
|
58 |
+
|
59 |
+
class RequestOutput:
|
60 |
+
"""The output data of a request to the LLM.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
request_id: The unique ID of the request.
|
64 |
+
prompt: The prompt string of the request.
|
65 |
+
prompt_token_ids: The token IDs of the prompt.
|
66 |
+
prompt_logprobs: The log probabilities to return per prompt token.
|
67 |
+
outputs: The output sequences of the request.
|
68 |
+
finished: Whether the whole request is finished.
|
69 |
+
"""
|
70 |
+
|
71 |
+
def __init__(
|
72 |
+
self,
|
73 |
+
request_id: str,
|
74 |
+
prompt: str,
|
75 |
+
prompt_token_ids: List[int],
|
76 |
+
prompt_logprobs: Optional[PromptLogprobs],
|
77 |
+
outputs: List[CompletionOutput],
|
78 |
+
finished: bool,
|
79 |
+
) -> None:
|
80 |
+
self.request_id = request_id
|
81 |
+
self.prompt = prompt
|
82 |
+
self.prompt_token_ids = prompt_token_ids
|
83 |
+
self.prompt_logprobs = prompt_logprobs
|
84 |
+
self.outputs = outputs
|
85 |
+
self.finished = finished
|
86 |
+
|
87 |
+
@classmethod
|
88 |
+
def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
|
89 |
+
# Get the top-n sequences.
|
90 |
+
n = seq_group.sampling_params.n
|
91 |
+
seqs = seq_group.get_seqs()
|
92 |
+
if seq_group.sampling_params.use_beam_search:
|
93 |
+
sorting_key = lambda seq: seq.get_beam_search_score(
|
94 |
+
seq_group.sampling_params.length_penalty
|
95 |
+
)
|
96 |
+
else:
|
97 |
+
sorting_key = lambda seq: seq.get_cumulative_logprob()
|
98 |
+
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
|
99 |
+
top_n_seqs = sorted_seqs[:n]
|
100 |
+
|
101 |
+
# Create the outputs.
|
102 |
+
outputs: List[CompletionOutput] = []
|
103 |
+
for seq in top_n_seqs:
|
104 |
+
logprobs = seq.output_logprobs
|
105 |
+
if seq_group.sampling_params.logprobs is None:
|
106 |
+
# NOTE: We need to take care of this case because the sequence
|
107 |
+
# always has the logprobs of the sampled tokens even if the
|
108 |
+
# logprobs are not requested.
|
109 |
+
logprobs = None
|
110 |
+
finshed_reason = SequenceStatus.get_finished_reason(seq.status)
|
111 |
+
output = CompletionOutput(
|
112 |
+
seqs.index(seq),
|
113 |
+
seq.output_text,
|
114 |
+
seq.get_output_token_ids(),
|
115 |
+
seq.get_cumulative_logprob(),
|
116 |
+
logprobs,
|
117 |
+
finshed_reason,
|
118 |
+
seq.data.hidden_states,
|
119 |
+
)
|
120 |
+
outputs.append(output)
|
121 |
+
|
122 |
+
# Every sequence in the sequence group should have the same prompt.
|
123 |
+
prompt = seq_group.prompt
|
124 |
+
prompt_token_ids = seq_group.prompt_token_ids
|
125 |
+
prompt_logprobs = seq_group.prompt_logprobs
|
126 |
+
finished = seq_group.is_finished()
|
127 |
+
return cls(
|
128 |
+
seq_group.request_id,
|
129 |
+
prompt,
|
130 |
+
prompt_token_ids,
|
131 |
+
prompt_logprobs,
|
132 |
+
outputs,
|
133 |
+
finished,
|
134 |
+
)
|
135 |
+
|
136 |
+
def __repr__(self) -> str:
|
137 |
+
return (
|
138 |
+
f"RequestOutput(request_id={self.request_id}, "
|
139 |
+
f"prompt={self.prompt!r}, "
|
140 |
+
f"prompt_token_ids={self.prompt_token_ids}, "
|
141 |
+
f"prompt_logprobs={self.prompt_logprobs}, "
|
142 |
+
f"outputs={self.outputs}, "
|
143 |
+
f"finished={self.finished})"
|
144 |
+
)
|
ChatTTS/model/velocity/sampler.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.functional import F
|
3 |
+
from typing import List, Callable
|
4 |
+
|
5 |
+
from ..embed import Embed
|
6 |
+
|
7 |
+
|
8 |
+
class Sampler:
|
9 |
+
def __init__(self, post_model: Embed, num_audio_tokens: int, num_vq: int):
|
10 |
+
self.post_model = post_model
|
11 |
+
self.device = next(self.post_model.parameters()).device
|
12 |
+
self.num_audio_tokens = num_audio_tokens
|
13 |
+
self.num_vq = num_vq
|
14 |
+
|
15 |
+
def sample(
|
16 |
+
self,
|
17 |
+
inputs_ids: torch.Tensor,
|
18 |
+
hidden_states: torch.Tensor,
|
19 |
+
infer_text: bool = False,
|
20 |
+
temperature: torch.Tensor = 1.0,
|
21 |
+
logits_processors: List[Callable] = [
|
22 |
+
lambda logits_token, logits: logits,
|
23 |
+
],
|
24 |
+
logits_warpers: List[Callable] = [
|
25 |
+
lambda logits_token, logits: logits,
|
26 |
+
],
|
27 |
+
min_new_token: int = 0,
|
28 |
+
now_length: int = 0,
|
29 |
+
eos_token: int = 0,
|
30 |
+
start_idx: int = 0,
|
31 |
+
):
|
32 |
+
# print(inputs_ids.shape)
|
33 |
+
B = hidden_states.shape[0]
|
34 |
+
|
35 |
+
end_idx = torch.zeros(
|
36 |
+
inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long
|
37 |
+
)
|
38 |
+
finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool()
|
39 |
+
if not infer_text:
|
40 |
+
temperature = (
|
41 |
+
temperature.unsqueeze(0)
|
42 |
+
.expand(inputs_ids.shape[0], -1)
|
43 |
+
.contiguous()
|
44 |
+
.view(-1, 1)
|
45 |
+
)
|
46 |
+
|
47 |
+
if infer_text:
|
48 |
+
logits: torch.Tensor = self.post_model.head_text(hidden_states)
|
49 |
+
else:
|
50 |
+
# logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3)
|
51 |
+
logits = torch.empty(
|
52 |
+
hidden_states.size(0),
|
53 |
+
hidden_states.size(1),
|
54 |
+
self.num_audio_tokens,
|
55 |
+
self.num_vq,
|
56 |
+
dtype=torch.float,
|
57 |
+
device=self.device,
|
58 |
+
)
|
59 |
+
for num_vq_iter in range(self.num_vq):
|
60 |
+
x: torch.Tensor = self.post_model.head_code[num_vq_iter](hidden_states)
|
61 |
+
logits[..., num_vq_iter] = x
|
62 |
+
del x
|
63 |
+
|
64 |
+
del hidden_states
|
65 |
+
|
66 |
+
# logits = logits[:, -1].float()
|
67 |
+
logits = logits.narrow(1, -1, 1).squeeze_(1).float()
|
68 |
+
|
69 |
+
if not infer_text:
|
70 |
+
# logits = rearrange(logits, "b c n -> (b n) c")
|
71 |
+
logits = logits.permute(0, 2, 1)
|
72 |
+
logits = logits.reshape(-1, logits.size(2))
|
73 |
+
# logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c")
|
74 |
+
inputs_ids_sliced = inputs_ids[:, start_idx:].permute(0, 2, 1)
|
75 |
+
logits_token = inputs_ids_sliced.reshape(
|
76 |
+
inputs_ids_sliced.size(0) * inputs_ids_sliced.size(1),
|
77 |
+
-1,
|
78 |
+
).to(self.device)
|
79 |
+
else:
|
80 |
+
logits_token = inputs_ids[:, start_idx:, 0].to(self.device)
|
81 |
+
|
82 |
+
logits /= temperature
|
83 |
+
|
84 |
+
for logitsProcessors in logits_processors:
|
85 |
+
logits = logitsProcessors(logits_token, logits)
|
86 |
+
|
87 |
+
for logitsWarpers in logits_warpers:
|
88 |
+
logits = logitsWarpers(logits_token, logits)
|
89 |
+
|
90 |
+
del logits_token
|
91 |
+
|
92 |
+
if now_length < min_new_token:
|
93 |
+
logits[:, eos_token] = -torch.inf
|
94 |
+
|
95 |
+
scores = F.softmax(logits, dim=-1)
|
96 |
+
idx_next = torch.multinomial(scores, num_samples=1).to(finish.device)
|
97 |
+
if not infer_text:
|
98 |
+
scores = scores.reshape(B, -1, scores.shape[-1])
|
99 |
+
if not infer_text:
|
100 |
+
# idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
|
101 |
+
idx_next = idx_next.view(-1, self.num_vq)
|
102 |
+
finish_or = idx_next.eq(eos_token).any(1)
|
103 |
+
finish.logical_or_(finish_or)
|
104 |
+
del finish_or
|
105 |
+
else:
|
106 |
+
finish_or = idx_next.eq(eos_token).any(1)
|
107 |
+
finish.logical_or_(finish_or)
|
108 |
+
del finish_or
|
109 |
+
|
110 |
+
del inputs_ids
|
111 |
+
|
112 |
+
not_finished = finish.logical_not().to(end_idx.device)
|
113 |
+
|
114 |
+
end_idx.add_(not_finished.int())
|
115 |
+
idx_next = idx_next[:, None, :]
|
116 |
+
return (
|
117 |
+
idx_next,
|
118 |
+
torch.log(scores),
|
119 |
+
finish,
|
120 |
+
)
|
ChatTTS/model/velocity/sampling_params.py
ADDED
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Sampling parameters for text generation."""
|
2 |
+
|
3 |
+
from enum import IntEnum
|
4 |
+
from functools import cached_property
|
5 |
+
from typing import Callable, List, Optional, Union
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
_SAMPLING_EPS = 1e-5
|
10 |
+
|
11 |
+
|
12 |
+
class SamplingType(IntEnum):
|
13 |
+
GREEDY = 0
|
14 |
+
RANDOM = 1
|
15 |
+
BEAM = 2
|
16 |
+
|
17 |
+
|
18 |
+
LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
|
19 |
+
"""LogitsProcessor is a function that takes a list of previously generated
|
20 |
+
tokens and a tensor of the logits for the next token, and returns a modified
|
21 |
+
tensor of logits to sample from."""
|
22 |
+
|
23 |
+
|
24 |
+
class SamplingParams:
|
25 |
+
"""Sampling parameters for text generation.
|
26 |
+
|
27 |
+
Overall, we follow the sampling parameters from the OpenAI text completion
|
28 |
+
API (https://platform.openai.com/docs/api-reference/completions/create).
|
29 |
+
In addition, we support beam search, which is not supported by OpenAI.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
n: Number of output sequences to return for the given prompt.
|
33 |
+
best_of: Number of output sequences that are generated from the prompt.
|
34 |
+
From these `best_of` sequences, the top `n` sequences are returned.
|
35 |
+
`best_of` must be greater than or equal to `n`. This is treated as
|
36 |
+
the beam width when `use_beam_search` is True. By default, `best_of`
|
37 |
+
is set to `n`.
|
38 |
+
presence_penalty: Float that penalizes new tokens based on whether they
|
39 |
+
appear in the generated text so far. Values > 0 encourage the model
|
40 |
+
to use new tokens, while values < 0 encourage the model to repeat
|
41 |
+
tokens.
|
42 |
+
frequency_penalty: Float that penalizes new tokens based on their
|
43 |
+
frequency in the generated text so far. Values > 0 encourage the
|
44 |
+
model to use new tokens, while values < 0 encourage the model to
|
45 |
+
repeat tokens.
|
46 |
+
repetition_penalty: Float that penalizes new tokens based on whether
|
47 |
+
they appear in the prompt and the generated text so far. Values > 1
|
48 |
+
encourage the model to use new tokens, while values < 1 encourage
|
49 |
+
the model to repeat tokens.
|
50 |
+
temperature: Float that controls the randomness of the sampling. Lower
|
51 |
+
values make the model more deterministic, while higher values make
|
52 |
+
the model more random. Zero means greedy sampling.
|
53 |
+
top_p: Float that controls the cumulative probability of the top tokens
|
54 |
+
to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
|
55 |
+
top_k: Integer that controls the number of top tokens to consider. Set
|
56 |
+
to -1 to consider all tokens.
|
57 |
+
min_p: Float that represents the minimum probability for a token to be
|
58 |
+
considered, relative to the probability of the most likely token.
|
59 |
+
Must be in [0, 1]. Set to 0 to disable this.
|
60 |
+
use_beam_search: Whether to use beam search instead of sampling.
|
61 |
+
length_penalty: Float that penalizes sequences based on their length.
|
62 |
+
Used in beam search.
|
63 |
+
early_stopping: Controls the stopping condition for beam search. It
|
64 |
+
accepts the following values: `True`, where the generation stops as
|
65 |
+
soon as there are `best_of` complete candidates; `False`, where an
|
66 |
+
heuristic is applied and the generation stops when is it very
|
67 |
+
unlikely to find better candidates; `"never"`, where the beam search
|
68 |
+
procedure only stops when there cannot be better candidates
|
69 |
+
(canonical beam search algorithm).
|
70 |
+
stop: List of strings that stop the generation when they are generated.
|
71 |
+
The returned output will not contain the stop strings.
|
72 |
+
stop_token_ids: List of tokens that stop the generation when they are
|
73 |
+
generated. The returned output will contain the stop tokens unless
|
74 |
+
the stop tokens are special tokens.
|
75 |
+
include_stop_str_in_output: Whether to include the stop strings in output
|
76 |
+
text. Defaults to False.
|
77 |
+
ignore_eos: Whether to ignore the EOS token and continue generating
|
78 |
+
tokens after the EOS token is generated.
|
79 |
+
max_tokens: Maximum number of tokens to generate per output sequence.
|
80 |
+
logprobs: Number of log probabilities to return per output token.
|
81 |
+
Note that the implementation follows the OpenAI API: The return
|
82 |
+
result includes the log probabilities on the `logprobs` most likely
|
83 |
+
tokens, as well the chosen tokens. The API will always return the
|
84 |
+
log probability of the sampled token, so there may be up to
|
85 |
+
`logprobs+1` elements in the response.
|
86 |
+
prompt_logprobs: Number of log probabilities to return per prompt token.
|
87 |
+
skip_special_tokens: Whether to skip special tokens in the output.
|
88 |
+
spaces_between_special_tokens: Whether to add spaces between special
|
89 |
+
tokens in the output. Defaults to True.
|
90 |
+
logits_processors: List of functions that modify logits based on
|
91 |
+
previously generated tokens.
|
92 |
+
"""
|
93 |
+
|
94 |
+
def __init__(
|
95 |
+
self,
|
96 |
+
n: int = 1,
|
97 |
+
best_of: Optional[int] = None,
|
98 |
+
presence_penalty: float = 0.0,
|
99 |
+
frequency_penalty: float = 0.0,
|
100 |
+
repetition_penalty: float = 1.0,
|
101 |
+
temperature: float = 1.0,
|
102 |
+
top_p: float = 1.0,
|
103 |
+
top_k: int = -1,
|
104 |
+
min_p: float = 0.0,
|
105 |
+
use_beam_search: bool = False,
|
106 |
+
length_penalty: float = 1.0,
|
107 |
+
early_stopping: Union[bool, str] = False,
|
108 |
+
stop: Optional[Union[str, List[str]]] = None,
|
109 |
+
stop_token_ids: Optional[List[int]] = None,
|
110 |
+
include_stop_str_in_output: bool = False,
|
111 |
+
ignore_eos: bool = False,
|
112 |
+
max_tokens: int = 16,
|
113 |
+
logprobs: Optional[int] = None,
|
114 |
+
prompt_logprobs: Optional[int] = None,
|
115 |
+
skip_special_tokens: bool = True,
|
116 |
+
spaces_between_special_tokens: bool = True,
|
117 |
+
logits_processors: Optional[List[LogitsProcessor]] = (
|
118 |
+
[
|
119 |
+
lambda logits_token, logits: logits,
|
120 |
+
],
|
121 |
+
[
|
122 |
+
lambda logits_token, logits: logits,
|
123 |
+
],
|
124 |
+
),
|
125 |
+
min_new_token: int = 0,
|
126 |
+
max_new_token: int = 8192,
|
127 |
+
infer_text: bool = False,
|
128 |
+
eos_token: int = 0,
|
129 |
+
spk_emb: str = None,
|
130 |
+
start_idx: int = 0,
|
131 |
+
) -> None:
|
132 |
+
self.n = n
|
133 |
+
self.best_of = best_of if best_of is not None else n
|
134 |
+
self.presence_penalty = presence_penalty
|
135 |
+
self.frequency_penalty = frequency_penalty
|
136 |
+
self.repetition_penalty = repetition_penalty
|
137 |
+
self.temperature = temperature
|
138 |
+
self.top_p = top_p
|
139 |
+
self.top_k = top_k
|
140 |
+
self.min_p = min_p
|
141 |
+
self.use_beam_search = use_beam_search
|
142 |
+
self.length_penalty = length_penalty
|
143 |
+
self.early_stopping = early_stopping
|
144 |
+
self.min_new_token = min_new_token
|
145 |
+
self.max_new_token = max_new_token
|
146 |
+
self.infer_text = infer_text
|
147 |
+
self.eos_token = eos_token
|
148 |
+
self.spk_emb = spk_emb
|
149 |
+
self.start_idx = start_idx
|
150 |
+
if stop is None:
|
151 |
+
self.stop = []
|
152 |
+
elif isinstance(stop, str):
|
153 |
+
self.stop = [stop]
|
154 |
+
else:
|
155 |
+
self.stop = list(stop)
|
156 |
+
if stop_token_ids is None:
|
157 |
+
self.stop_token_ids = []
|
158 |
+
else:
|
159 |
+
self.stop_token_ids = list(stop_token_ids)
|
160 |
+
self.ignore_eos = ignore_eos
|
161 |
+
self.max_tokens = max_tokens
|
162 |
+
self.logprobs = logprobs
|
163 |
+
self.prompt_logprobs = prompt_logprobs
|
164 |
+
self.skip_special_tokens = skip_special_tokens
|
165 |
+
self.spaces_between_special_tokens = spaces_between_special_tokens
|
166 |
+
self.logits_processors = logits_processors
|
167 |
+
self.include_stop_str_in_output = include_stop_str_in_output
|
168 |
+
self._verify_args()
|
169 |
+
if self.use_beam_search:
|
170 |
+
self._verify_beam_search()
|
171 |
+
else:
|
172 |
+
self._verify_non_beam_search()
|
173 |
+
# if self.temperature < _SAMPLING_EPS:
|
174 |
+
# # Zero temperature means greedy sampling.
|
175 |
+
# self.top_p = 1.0
|
176 |
+
# self.top_k = -1
|
177 |
+
# self.min_p = 0.0
|
178 |
+
# self._verify_greedy_sampling()
|
179 |
+
|
180 |
+
def _verify_args(self) -> None:
|
181 |
+
if self.n < 1:
|
182 |
+
raise ValueError(f"n must be at least 1, got {self.n}.")
|
183 |
+
if self.best_of < self.n:
|
184 |
+
raise ValueError(
|
185 |
+
f"best_of must be greater than or equal to n, "
|
186 |
+
f"got n={self.n} and best_of={self.best_of}."
|
187 |
+
)
|
188 |
+
if not -2.0 <= self.presence_penalty <= 2.0:
|
189 |
+
raise ValueError(
|
190 |
+
"presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}."
|
191 |
+
)
|
192 |
+
if not -2.0 <= self.frequency_penalty <= 2.0:
|
193 |
+
raise ValueError(
|
194 |
+
"frequency_penalty must be in [-2, 2], got "
|
195 |
+
f"{self.frequency_penalty}."
|
196 |
+
)
|
197 |
+
if not 0.0 < self.repetition_penalty <= 2.0:
|
198 |
+
raise ValueError(
|
199 |
+
"repetition_penalty must be in (0, 2], got "
|
200 |
+
f"{self.repetition_penalty}."
|
201 |
+
)
|
202 |
+
# if self.temperature < 0.0:
|
203 |
+
# raise ValueError(
|
204 |
+
# f"temperature must be non-negative, got {self.temperature}.")
|
205 |
+
if not 0.0 < self.top_p <= 1.0:
|
206 |
+
raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
|
207 |
+
if self.top_k < -1 or self.top_k == 0:
|
208 |
+
raise ValueError(
|
209 |
+
f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}."
|
210 |
+
)
|
211 |
+
if not 0.0 <= self.min_p <= 1.0:
|
212 |
+
raise ValueError("min_p must be in [0, 1], got " f"{self.min_p}.")
|
213 |
+
if self.max_tokens < 1:
|
214 |
+
raise ValueError(f"max_tokens must be at least 1, got {self.max_tokens}.")
|
215 |
+
if self.logprobs is not None and self.logprobs < 0:
|
216 |
+
raise ValueError(f"logprobs must be non-negative, got {self.logprobs}.")
|
217 |
+
if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
|
218 |
+
raise ValueError(
|
219 |
+
f"prompt_logprobs must be non-negative, got " f"{self.prompt_logprobs}."
|
220 |
+
)
|
221 |
+
|
222 |
+
def _verify_beam_search(self) -> None:
|
223 |
+
if self.best_of == 1:
|
224 |
+
raise ValueError(
|
225 |
+
"best_of must be greater than 1 when using beam "
|
226 |
+
f"search. Got {self.best_of}."
|
227 |
+
)
|
228 |
+
if self.temperature > _SAMPLING_EPS:
|
229 |
+
raise ValueError("temperature must be 0 when using beam search.")
|
230 |
+
if self.top_p < 1.0 - _SAMPLING_EPS:
|
231 |
+
raise ValueError("top_p must be 1 when using beam search.")
|
232 |
+
if self.top_k != -1:
|
233 |
+
raise ValueError("top_k must be -1 when using beam search.")
|
234 |
+
if self.early_stopping not in [True, False, "never"]:
|
235 |
+
raise ValueError(
|
236 |
+
f"early_stopping must be True, False, or 'never', "
|
237 |
+
f"got {self.early_stopping}."
|
238 |
+
)
|
239 |
+
|
240 |
+
def _verify_non_beam_search(self) -> None:
|
241 |
+
if self.early_stopping is not False:
|
242 |
+
raise ValueError(
|
243 |
+
"early_stopping is not effective and must be "
|
244 |
+
"False when not using beam search."
|
245 |
+
)
|
246 |
+
if (
|
247 |
+
self.length_penalty < 1.0 - _SAMPLING_EPS
|
248 |
+
or self.length_penalty > 1.0 + _SAMPLING_EPS
|
249 |
+
):
|
250 |
+
raise ValueError(
|
251 |
+
"length_penalty is not effective and must be the "
|
252 |
+
"default value of 1.0 when not using beam search."
|
253 |
+
)
|
254 |
+
|
255 |
+
def _verify_greedy_sampling(self) -> None:
|
256 |
+
if self.best_of > 1:
|
257 |
+
raise ValueError(
|
258 |
+
"best_of must be 1 when using greedy sampling." f"Got {self.best_of}."
|
259 |
+
)
|
260 |
+
|
261 |
+
@cached_property
|
262 |
+
def sampling_type(self) -> SamplingType:
|
263 |
+
if self.use_beam_search:
|
264 |
+
return SamplingType.BEAM
|
265 |
+
# if self.temperature < _SAMPLING_EPS:
|
266 |
+
# return SamplingType.GREEDY
|
267 |
+
return SamplingType.RANDOM
|
268 |
+
|
269 |
+
def __repr__(self) -> str:
|
270 |
+
return (
|
271 |
+
f"SamplingParams(n={self.n}, "
|
272 |
+
f"best_of={self.best_of}, "
|
273 |
+
f"presence_penalty={self.presence_penalty}, "
|
274 |
+
f"frequency_penalty={self.frequency_penalty}, "
|
275 |
+
f"repetition_penalty={self.repetition_penalty}, "
|
276 |
+
f"temperature={self.temperature}, "
|
277 |
+
f"top_p={self.top_p}, "
|
278 |
+
f"top_k={self.top_k}, "
|
279 |
+
f"min_p={self.min_p}, "
|
280 |
+
f"use_beam_search={self.use_beam_search}, "
|
281 |
+
f"length_penalty={self.length_penalty}, "
|
282 |
+
f"early_stopping={self.early_stopping}, "
|
283 |
+
f"stop={self.stop}, "
|
284 |
+
f"stop_token_ids={self.stop_token_ids}, "
|
285 |
+
f"include_stop_str_in_output={self.include_stop_str_in_output}, "
|
286 |
+
f"ignore_eos={self.ignore_eos}, "
|
287 |
+
f"max_tokens={self.max_tokens}, "
|
288 |
+
f"logprobs={self.logprobs}, "
|
289 |
+
f"prompt_logprobs={self.prompt_logprobs}, "
|
290 |
+
f"skip_special_tokens={self.skip_special_tokens}, "
|
291 |
+
"spaces_between_special_tokens="
|
292 |
+
f"{self.spaces_between_special_tokens}), "
|
293 |
+
f"max_new_token={self.max_new_token}), "
|
294 |
+
f"min_new_token={self.min_new_token}), "
|
295 |
+
f"infer_text={self.infer_text})"
|
296 |
+
)
|
ChatTTS/model/velocity/scheduler.py
ADDED
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import enum
|
2 |
+
import time
|
3 |
+
from typing import Dict, Iterable, List, Optional, Tuple, Union
|
4 |
+
|
5 |
+
from vllm.config import CacheConfig, SchedulerConfig
|
6 |
+
from .block_manager import AllocStatus, BlockSpaceManager
|
7 |
+
from vllm.core.policy import PolicyFactory
|
8 |
+
from vllm.logger import init_logger
|
9 |
+
from .sequence import (
|
10 |
+
Sequence,
|
11 |
+
SequenceData,
|
12 |
+
SequenceGroup,
|
13 |
+
SequenceGroupMetadata,
|
14 |
+
SequenceStatus,
|
15 |
+
)
|
16 |
+
|
17 |
+
logger = init_logger(__name__)
|
18 |
+
|
19 |
+
|
20 |
+
class PreemptionMode(enum.Enum):
|
21 |
+
"""Preemption modes.
|
22 |
+
|
23 |
+
1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
|
24 |
+
and swap them back in when the sequences are resumed.
|
25 |
+
2. Recomputation: Discard the blocks of the preempted sequences and
|
26 |
+
recompute them when the sequences are resumed, treating the sequences as
|
27 |
+
new prompts.
|
28 |
+
"""
|
29 |
+
|
30 |
+
SWAP = enum.auto()
|
31 |
+
RECOMPUTE = enum.auto()
|
32 |
+
|
33 |
+
|
34 |
+
class SchedulerOutputs:
|
35 |
+
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
scheduled_seq_groups: List[SequenceGroup],
|
39 |
+
prompt_run: bool,
|
40 |
+
num_batched_tokens: int,
|
41 |
+
blocks_to_swap_in: Dict[int, int],
|
42 |
+
blocks_to_swap_out: Dict[int, int],
|
43 |
+
blocks_to_copy: Dict[int, List[int]],
|
44 |
+
ignored_seq_groups: List[SequenceGroup],
|
45 |
+
) -> None:
|
46 |
+
self.scheduled_seq_groups = scheduled_seq_groups
|
47 |
+
self.prompt_run = prompt_run
|
48 |
+
self.num_batched_tokens = num_batched_tokens
|
49 |
+
self.blocks_to_swap_in = blocks_to_swap_in
|
50 |
+
self.blocks_to_swap_out = blocks_to_swap_out
|
51 |
+
self.blocks_to_copy = blocks_to_copy
|
52 |
+
# Swap in and swap out should never happen at the same time.
|
53 |
+
assert not (blocks_to_swap_in and blocks_to_swap_out)
|
54 |
+
self.ignored_seq_groups = ignored_seq_groups
|
55 |
+
|
56 |
+
def is_empty(self) -> bool:
|
57 |
+
# NOTE: We do not consider the ignored sequence groups.
|
58 |
+
return (
|
59 |
+
not self.scheduled_seq_groups
|
60 |
+
and not self.blocks_to_swap_in
|
61 |
+
and not self.blocks_to_swap_out
|
62 |
+
and not self.blocks_to_copy
|
63 |
+
)
|
64 |
+
|
65 |
+
|
66 |
+
class Scheduler:
|
67 |
+
|
68 |
+
def __init__(
|
69 |
+
self,
|
70 |
+
scheduler_config: SchedulerConfig,
|
71 |
+
cache_config: CacheConfig,
|
72 |
+
) -> None:
|
73 |
+
self.scheduler_config = scheduler_config
|
74 |
+
self.cache_config = cache_config
|
75 |
+
|
76 |
+
self.prompt_limit = min(
|
77 |
+
self.scheduler_config.max_model_len,
|
78 |
+
self.scheduler_config.max_num_batched_tokens,
|
79 |
+
)
|
80 |
+
|
81 |
+
# Instantiate the scheduling policy.
|
82 |
+
self.policy = PolicyFactory.get_policy(policy_name="fcfs")
|
83 |
+
# Create the block space manager.
|
84 |
+
self.block_manager = BlockSpaceManager(
|
85 |
+
block_size=self.cache_config.block_size,
|
86 |
+
num_gpu_blocks=self.cache_config.num_gpu_blocks,
|
87 |
+
num_cpu_blocks=self.cache_config.num_cpu_blocks,
|
88 |
+
sliding_window=self.cache_config.sliding_window,
|
89 |
+
)
|
90 |
+
|
91 |
+
# TODO(zhuohan): Use deque instead of list for better performance.
|
92 |
+
# Sequence groups in the WAITING state.
|
93 |
+
self.waiting: List[SequenceGroup] = []
|
94 |
+
# Sequence groups in the RUNNING state.
|
95 |
+
self.running: List[SequenceGroup] = []
|
96 |
+
# Sequence groups in the SWAPPED state.
|
97 |
+
self.swapped: List[SequenceGroup] = []
|
98 |
+
|
99 |
+
def add_seq_group(self, seq_group: SequenceGroup) -> None:
|
100 |
+
# Add sequence groups to the waiting queue.
|
101 |
+
self.waiting.append(seq_group)
|
102 |
+
|
103 |
+
def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
|
104 |
+
if isinstance(request_id, str):
|
105 |
+
request_id = (request_id,)
|
106 |
+
request_ids = set(request_id)
|
107 |
+
for state_queue in [self.waiting, self.running, self.swapped]:
|
108 |
+
# We need to reverse the list as we are removing elements
|
109 |
+
# from it as we iterate over it. If we don't do it,
|
110 |
+
# indices will get messed up and we will skip over elements.
|
111 |
+
for seq_group in reversed(state_queue):
|
112 |
+
if seq_group.request_id in request_ids:
|
113 |
+
# Remove the sequence group from the state queue.
|
114 |
+
state_queue.remove(seq_group)
|
115 |
+
for seq in seq_group.get_seqs():
|
116 |
+
if seq.is_finished():
|
117 |
+
continue
|
118 |
+
seq.status = SequenceStatus.FINISHED_ABORTED
|
119 |
+
self.free_seq(seq)
|
120 |
+
request_ids.remove(seq_group.request_id)
|
121 |
+
if not request_ids:
|
122 |
+
return
|
123 |
+
|
124 |
+
def has_unfinished_seqs(self) -> bool:
|
125 |
+
return self.waiting or self.running or self.swapped
|
126 |
+
|
127 |
+
def get_num_unfinished_seq_groups(self) -> int:
|
128 |
+
return len(self.waiting) + len(self.running) + len(self.swapped)
|
129 |
+
|
130 |
+
def _schedule(self) -> SchedulerOutputs:
|
131 |
+
# Blocks that need to be swaped or copied before model execution.
|
132 |
+
blocks_to_swap_in: Dict[int, int] = {}
|
133 |
+
blocks_to_swap_out: Dict[int, int] = {}
|
134 |
+
blocks_to_copy: Dict[int, List[int]] = {}
|
135 |
+
|
136 |
+
# Fix the current time.
|
137 |
+
now = time.monotonic()
|
138 |
+
|
139 |
+
# Join waiting sequences if possible.
|
140 |
+
if not self.swapped:
|
141 |
+
ignored_seq_groups: List[SequenceGroup] = []
|
142 |
+
scheduled: List[SequenceGroup] = []
|
143 |
+
# The total number of sequences on the fly, including the
|
144 |
+
# requests in the generation phase.
|
145 |
+
num_curr_seqs = sum(
|
146 |
+
seq_group.get_max_num_running_seqs() for seq_group in self.running
|
147 |
+
)
|
148 |
+
seq_lens: List[int] = []
|
149 |
+
|
150 |
+
# Optimization: We do not sort the waiting queue since the preempted
|
151 |
+
# sequence groups are added to the front and the new sequence groups
|
152 |
+
# are added to the back.
|
153 |
+
while self.waiting:
|
154 |
+
seq_group = self.waiting[0]
|
155 |
+
|
156 |
+
waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
|
157 |
+
assert len(waiting_seqs) == 1, (
|
158 |
+
"Waiting sequence group should have only one prompt " "sequence."
|
159 |
+
)
|
160 |
+
num_prompt_tokens = waiting_seqs[0].get_len()
|
161 |
+
if num_prompt_tokens > self.prompt_limit:
|
162 |
+
logger.warning(
|
163 |
+
f"Input prompt ({num_prompt_tokens} tokens) is too long"
|
164 |
+
f" and exceeds limit of {self.prompt_limit}"
|
165 |
+
)
|
166 |
+
for seq in waiting_seqs:
|
167 |
+
seq.status = SequenceStatus.FINISHED_IGNORED
|
168 |
+
ignored_seq_groups.append(seq_group)
|
169 |
+
self.waiting.pop(0)
|
170 |
+
continue
|
171 |
+
|
172 |
+
# If the sequence group cannot be allocated, stop.
|
173 |
+
can_allocate = self.block_manager.can_allocate(seq_group)
|
174 |
+
if can_allocate == AllocStatus.LATER:
|
175 |
+
break
|
176 |
+
elif can_allocate == AllocStatus.NEVER:
|
177 |
+
logger.warning(
|
178 |
+
f"Input prompt ({num_prompt_tokens} tokens) is too long"
|
179 |
+
f" and exceeds the capacity of block_manager"
|
180 |
+
)
|
181 |
+
for seq in waiting_seqs:
|
182 |
+
seq.status = SequenceStatus.FINISHED_IGNORED
|
183 |
+
ignored_seq_groups.append(seq_group)
|
184 |
+
self.waiting.pop(0)
|
185 |
+
continue
|
186 |
+
|
187 |
+
# If the number of batched tokens exceeds the limit, stop.
|
188 |
+
new_seq_lens = seq_lens + [num_prompt_tokens]
|
189 |
+
num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
|
190 |
+
if num_batched_tokens > self.scheduler_config.max_num_batched_tokens:
|
191 |
+
break
|
192 |
+
|
193 |
+
# The total number of sequences in the RUNNING state should not
|
194 |
+
# exceed the maximum number of sequences.
|
195 |
+
num_new_seqs = seq_group.get_max_num_running_seqs()
|
196 |
+
if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
|
197 |
+
break
|
198 |
+
|
199 |
+
num_paddings = num_batched_tokens - sum(new_seq_lens)
|
200 |
+
if num_paddings > self.scheduler_config.max_paddings:
|
201 |
+
break
|
202 |
+
seq_lens = new_seq_lens
|
203 |
+
|
204 |
+
seq_group = self.waiting.pop(0)
|
205 |
+
self._allocate(seq_group)
|
206 |
+
self.running.append(seq_group)
|
207 |
+
num_curr_seqs += num_new_seqs
|
208 |
+
scheduled.append(seq_group)
|
209 |
+
|
210 |
+
if scheduled or ignored_seq_groups:
|
211 |
+
scheduler_outputs = SchedulerOutputs(
|
212 |
+
scheduled_seq_groups=scheduled,
|
213 |
+
prompt_run=True,
|
214 |
+
num_batched_tokens=len(seq_lens) * max(seq_lens) if seq_lens else 0,
|
215 |
+
blocks_to_swap_in=blocks_to_swap_in,
|
216 |
+
blocks_to_swap_out=blocks_to_swap_out,
|
217 |
+
blocks_to_copy=blocks_to_copy,
|
218 |
+
ignored_seq_groups=ignored_seq_groups,
|
219 |
+
)
|
220 |
+
return scheduler_outputs
|
221 |
+
|
222 |
+
# NOTE(woosuk): Preemption happens only when there is no available slot
|
223 |
+
# to keep all the sequence groups in the RUNNING state.
|
224 |
+
# In this case, the policy is responsible for deciding which sequence
|
225 |
+
# groups to preempt.
|
226 |
+
self.running = self.policy.sort_by_priority(now, self.running)
|
227 |
+
|
228 |
+
# Reserve new token slots for the running sequence groups.
|
229 |
+
running: List[SequenceGroup] = []
|
230 |
+
preempted: List[SequenceGroup] = []
|
231 |
+
while self.running:
|
232 |
+
seq_group = self.running.pop(0)
|
233 |
+
while not self.block_manager.can_append_slot(seq_group):
|
234 |
+
if self.running:
|
235 |
+
# Preempt the lowest-priority sequence groups.
|
236 |
+
victim_seq_group = self.running.pop(-1)
|
237 |
+
self._preempt(victim_seq_group, blocks_to_swap_out)
|
238 |
+
preempted.append(victim_seq_group)
|
239 |
+
else:
|
240 |
+
# No other sequence groups can be preempted.
|
241 |
+
# Preempt the current sequence group.
|
242 |
+
self._preempt(seq_group, blocks_to_swap_out)
|
243 |
+
preempted.append(seq_group)
|
244 |
+
break
|
245 |
+
else:
|
246 |
+
# Append new slots to the sequence group.
|
247 |
+
self._append_slot(seq_group, blocks_to_copy)
|
248 |
+
running.append(seq_group)
|
249 |
+
self.running = running
|
250 |
+
|
251 |
+
# Swap in the sequence groups in the SWAPPED state if possible.
|
252 |
+
self.swapped = self.policy.sort_by_priority(now, self.swapped)
|
253 |
+
if not preempted:
|
254 |
+
num_curr_seqs = sum(
|
255 |
+
seq_group.get_max_num_running_seqs() for seq_group in self.running
|
256 |
+
)
|
257 |
+
|
258 |
+
while self.swapped:
|
259 |
+
seq_group = self.swapped[0]
|
260 |
+
# If the sequence group cannot be swapped in, stop.
|
261 |
+
if not self.block_manager.can_swap_in(seq_group):
|
262 |
+
break
|
263 |
+
|
264 |
+
# The total number of sequences in the RUNNING state should not
|
265 |
+
# exceed the maximum number of sequences.
|
266 |
+
num_new_seqs = seq_group.get_max_num_running_seqs()
|
267 |
+
if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
|
268 |
+
break
|
269 |
+
|
270 |
+
seq_group = self.swapped.pop(0)
|
271 |
+
self._swap_in(seq_group, blocks_to_swap_in)
|
272 |
+
self._append_slot(seq_group, blocks_to_copy)
|
273 |
+
num_curr_seqs += num_new_seqs
|
274 |
+
self.running.append(seq_group)
|
275 |
+
|
276 |
+
# Each sequence in the generation phase only takes one token slot.
|
277 |
+
# Therefore, the number of batched tokens is equal to the number of
|
278 |
+
# sequences in the RUNNING state.
|
279 |
+
num_batched_tokens = sum(
|
280 |
+
seq_group.num_seqs(status=SequenceStatus.RUNNING)
|
281 |
+
for seq_group in self.running
|
282 |
+
)
|
283 |
+
|
284 |
+
scheduler_outputs = SchedulerOutputs(
|
285 |
+
scheduled_seq_groups=self.running,
|
286 |
+
prompt_run=False,
|
287 |
+
num_batched_tokens=num_batched_tokens,
|
288 |
+
blocks_to_swap_in=blocks_to_swap_in,
|
289 |
+
blocks_to_swap_out=blocks_to_swap_out,
|
290 |
+
blocks_to_copy=blocks_to_copy,
|
291 |
+
ignored_seq_groups=[],
|
292 |
+
)
|
293 |
+
return scheduler_outputs
|
294 |
+
|
295 |
+
def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
|
296 |
+
# Schedule sequence groups.
|
297 |
+
# This function call changes the internal states of the scheduler
|
298 |
+
# such as self.running, self.swapped, and self.waiting.
|
299 |
+
scheduler_outputs = self._schedule()
|
300 |
+
|
301 |
+
# Create input data structures.
|
302 |
+
seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
303 |
+
for seq_group in scheduler_outputs.scheduled_seq_groups:
|
304 |
+
seq_data: Dict[int, SequenceData] = {}
|
305 |
+
block_tables: Dict[int, List[int]] = {}
|
306 |
+
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
307 |
+
seq_id = seq.seq_id
|
308 |
+
seq_data[seq_id] = seq.data
|
309 |
+
block_tables[seq_id] = self.block_manager.get_block_table(seq)
|
310 |
+
|
311 |
+
seq_group_metadata = SequenceGroupMetadata(
|
312 |
+
request_id=seq_group.request_id,
|
313 |
+
is_prompt=scheduler_outputs.prompt_run,
|
314 |
+
seq_data=seq_data,
|
315 |
+
sampling_params=seq_group.sampling_params,
|
316 |
+
block_tables=block_tables,
|
317 |
+
)
|
318 |
+
seq_group_metadata_list.append(seq_group_metadata)
|
319 |
+
return seq_group_metadata_list, scheduler_outputs
|
320 |
+
|
321 |
+
def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
|
322 |
+
self.block_manager.fork(parent_seq, child_seq)
|
323 |
+
|
324 |
+
def free_seq(self, seq: Sequence) -> None:
|
325 |
+
self.block_manager.free(seq)
|
326 |
+
|
327 |
+
def free_finished_seq_groups(self) -> None:
|
328 |
+
self.running = [
|
329 |
+
seq_group for seq_group in self.running if not seq_group.is_finished()
|
330 |
+
]
|
331 |
+
|
332 |
+
def _allocate(self, seq_group: SequenceGroup) -> None:
|
333 |
+
self.block_manager.allocate(seq_group)
|
334 |
+
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
|
335 |
+
seq.status = SequenceStatus.RUNNING
|
336 |
+
|
337 |
+
def _append_slot(
|
338 |
+
self,
|
339 |
+
seq_group: SequenceGroup,
|
340 |
+
blocks_to_copy: Dict[int, List[int]],
|
341 |
+
) -> None:
|
342 |
+
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
343 |
+
ret = self.block_manager.append_slot(seq)
|
344 |
+
if ret is not None:
|
345 |
+
src_block, dst_block = ret
|
346 |
+
if src_block in blocks_to_copy:
|
347 |
+
blocks_to_copy[src_block].append(dst_block)
|
348 |
+
else:
|
349 |
+
blocks_to_copy[src_block] = [dst_block]
|
350 |
+
|
351 |
+
def _preempt(
|
352 |
+
self,
|
353 |
+
seq_group: SequenceGroup,
|
354 |
+
blocks_to_swap_out: Dict[int, int],
|
355 |
+
preemption_mode: Optional[PreemptionMode] = None,
|
356 |
+
) -> None:
|
357 |
+
# If preemption mode is not specified, we determine the mode as follows:
|
358 |
+
# We use recomputation by default since it incurs lower overhead than
|
359 |
+
# swapping. However, when the sequence group has multiple sequences
|
360 |
+
# (e.g., beam search), recomputation is not currently supported. In
|
361 |
+
# such a case, we use swapping instead.
|
362 |
+
# FIXME(woosuk): This makes our scheduling policy a bit bizarre.
|
363 |
+
# As swapped sequences are prioritized over waiting sequences,
|
364 |
+
# sequence groups with multiple sequences are implicitly prioritized
|
365 |
+
# over sequence groups with a single sequence.
|
366 |
+
# TODO(woosuk): Support recomputation for sequence groups with multiple
|
367 |
+
# sequences. This may require a more sophisticated CUDA kernel.
|
368 |
+
if preemption_mode is None:
|
369 |
+
if seq_group.get_max_num_running_seqs() == 1:
|
370 |
+
preemption_mode = PreemptionMode.RECOMPUTE
|
371 |
+
else:
|
372 |
+
preemption_mode = PreemptionMode.SWAP
|
373 |
+
if preemption_mode == PreemptionMode.RECOMPUTE:
|
374 |
+
self._preempt_by_recompute(seq_group)
|
375 |
+
elif preemption_mode == PreemptionMode.SWAP:
|
376 |
+
self._preempt_by_swap(seq_group, blocks_to_swap_out)
|
377 |
+
else:
|
378 |
+
raise AssertionError("Invalid preemption mode.")
|
379 |
+
|
380 |
+
def _preempt_by_recompute(
|
381 |
+
self,
|
382 |
+
seq_group: SequenceGroup,
|
383 |
+
) -> None:
|
384 |
+
seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
|
385 |
+
assert len(seqs) == 1
|
386 |
+
for seq in seqs:
|
387 |
+
seq.status = SequenceStatus.WAITING
|
388 |
+
self.block_manager.free(seq)
|
389 |
+
# NOTE: For FCFS, we insert the preempted sequence group to the front
|
390 |
+
# of the waiting queue.
|
391 |
+
self.waiting.insert(0, seq_group)
|
392 |
+
|
393 |
+
def _preempt_by_swap(
|
394 |
+
self,
|
395 |
+
seq_group: SequenceGroup,
|
396 |
+
blocks_to_swap_out: Dict[int, int],
|
397 |
+
) -> None:
|
398 |
+
self._swap_out(seq_group, blocks_to_swap_out)
|
399 |
+
self.swapped.append(seq_group)
|
400 |
+
|
401 |
+
def _swap_in(
|
402 |
+
self,
|
403 |
+
seq_group: SequenceGroup,
|
404 |
+
blocks_to_swap_in: Dict[int, int],
|
405 |
+
) -> None:
|
406 |
+
mapping = self.block_manager.swap_in(seq_group)
|
407 |
+
blocks_to_swap_in.update(mapping)
|
408 |
+
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
|
409 |
+
seq.status = SequenceStatus.RUNNING
|
410 |
+
|
411 |
+
def _swap_out(
|
412 |
+
self,
|
413 |
+
seq_group: SequenceGroup,
|
414 |
+
blocks_to_swap_out: Dict[int, int],
|
415 |
+
) -> None:
|
416 |
+
if not self.block_manager.can_swap_out(seq_group):
|
417 |
+
# FIXME(woosuk): Abort the sequence group instead of aborting the
|
418 |
+
# entire engine.
|
419 |
+
raise RuntimeError(
|
420 |
+
"Aborted due to the lack of CPU swap space. Please increase "
|
421 |
+
"the swap space to avoid this error."
|
422 |
+
)
|
423 |
+
mapping = self.block_manager.swap_out(seq_group)
|
424 |
+
blocks_to_swap_out.update(mapping)
|
425 |
+
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
426 |
+
seq.status = SequenceStatus.SWAPPED
|
ChatTTS/model/velocity/sequence.py
ADDED
@@ -0,0 +1,450 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Sequence and its related classes."""
|
2 |
+
|
3 |
+
import copy
|
4 |
+
import enum
|
5 |
+
from typing import Dict, List, Optional, Union
|
6 |
+
import torch
|
7 |
+
from vllm.block import LogicalTokenBlock
|
8 |
+
from .sampling_params import SamplingParams
|
9 |
+
|
10 |
+
PromptLogprobs = List[Optional[Dict[int, float]]]
|
11 |
+
SampleLogprobs = List[Dict[int, float]]
|
12 |
+
|
13 |
+
|
14 |
+
class SequenceStatus(enum.Enum):
|
15 |
+
"""Status of a sequence."""
|
16 |
+
|
17 |
+
WAITING = enum.auto()
|
18 |
+
RUNNING = enum.auto()
|
19 |
+
SWAPPED = enum.auto()
|
20 |
+
FINISHED_STOPPED = enum.auto()
|
21 |
+
FINISHED_LENGTH_CAPPED = enum.auto()
|
22 |
+
FINISHED_ABORTED = enum.auto()
|
23 |
+
FINISHED_IGNORED = enum.auto()
|
24 |
+
|
25 |
+
@staticmethod
|
26 |
+
def is_finished(status: "SequenceStatus") -> bool:
|
27 |
+
return status in [
|
28 |
+
SequenceStatus.FINISHED_STOPPED,
|
29 |
+
SequenceStatus.FINISHED_LENGTH_CAPPED,
|
30 |
+
SequenceStatus.FINISHED_ABORTED,
|
31 |
+
SequenceStatus.FINISHED_IGNORED,
|
32 |
+
]
|
33 |
+
|
34 |
+
@staticmethod
|
35 |
+
def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
|
36 |
+
if status == SequenceStatus.FINISHED_STOPPED:
|
37 |
+
finish_reason = "stop"
|
38 |
+
elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
|
39 |
+
finish_reason = "length"
|
40 |
+
elif status == SequenceStatus.FINISHED_ABORTED:
|
41 |
+
finish_reason = "abort"
|
42 |
+
elif status == SequenceStatus.FINISHED_IGNORED:
|
43 |
+
# The ignored sequences are the sequences whose prompt lengths
|
44 |
+
# are longer than the model's length cap. Therefore, the stop
|
45 |
+
# reason should also be "length" as in OpenAI API.
|
46 |
+
finish_reason = "length"
|
47 |
+
else:
|
48 |
+
finish_reason = None
|
49 |
+
return finish_reason
|
50 |
+
|
51 |
+
|
52 |
+
class SequenceData:
|
53 |
+
"""Data associated with a sequence.
|
54 |
+
|
55 |
+
|
56 |
+
Args:
|
57 |
+
prompt_token_ids: The token IDs of the prompt.
|
58 |
+
|
59 |
+
Attributes:
|
60 |
+
prompt_token_ids: The token IDs of the prompt.
|
61 |
+
output_token_ids: The token IDs of the output.
|
62 |
+
cumulative_logprob: The cumulative log probability of the output.
|
63 |
+
"""
|
64 |
+
|
65 |
+
def __init__(
|
66 |
+
self,
|
67 |
+
prompt_token_ids: List[int],
|
68 |
+
) -> None:
|
69 |
+
self.prompt_token_ids = prompt_token_ids
|
70 |
+
self.output_token_ids: List[int] = []
|
71 |
+
self.cumulative_logprob = 0.0
|
72 |
+
self.hidden_states: Optional[torch.Tensor] = None
|
73 |
+
self.finished = False
|
74 |
+
|
75 |
+
def append_token_id(self, token_id: int, logprob: float) -> None:
|
76 |
+
if isinstance(self.cumulative_logprob, float):
|
77 |
+
self.cumulative_logprob = [
|
78 |
+
0.0,
|
79 |
+
] * len(logprob)
|
80 |
+
self.output_token_ids.append(token_id)
|
81 |
+
for i in range(len(self.cumulative_logprob)):
|
82 |
+
self.cumulative_logprob[i] += logprob[i]
|
83 |
+
|
84 |
+
def append_hidden_states(self, hidden_states: torch.Tensor) -> None:
|
85 |
+
if self.hidden_states is None:
|
86 |
+
self.hidden_states = hidden_states
|
87 |
+
else:
|
88 |
+
self.hidden_states = torch.cat([self.hidden_states, hidden_states], dim=0)
|
89 |
+
|
90 |
+
def get_len(self) -> int:
|
91 |
+
return len(self.output_token_ids) + len(self.prompt_token_ids)
|
92 |
+
|
93 |
+
def get_prompt_len(self) -> int:
|
94 |
+
return len(self.prompt_token_ids)
|
95 |
+
|
96 |
+
def get_output_len(self) -> int:
|
97 |
+
return len(self.output_token_ids)
|
98 |
+
|
99 |
+
def get_token_ids(self) -> List[int]:
|
100 |
+
return self.prompt_token_ids + self.output_token_ids
|
101 |
+
|
102 |
+
def get_last_token_id(self) -> int:
|
103 |
+
if not self.output_token_ids:
|
104 |
+
return self.prompt_token_ids[-1]
|
105 |
+
return self.output_token_ids[-1]
|
106 |
+
|
107 |
+
def __repr__(self) -> str:
|
108 |
+
return (
|
109 |
+
f"SequenceData("
|
110 |
+
f"prompt_token_ids={self.prompt_token_ids}, "
|
111 |
+
f"output_token_ids={self.output_token_ids}, "
|
112 |
+
f"cumulative_logprob={self.cumulative_logprob}), "
|
113 |
+
f"hidden_states={self.hidden_states.shape if self.hidden_states is not None else None}, "
|
114 |
+
f"finished={self.finished})"
|
115 |
+
)
|
116 |
+
|
117 |
+
|
118 |
+
class Sequence:
|
119 |
+
"""Stores the data, status, and block information of a sequence.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
seq_id: The ID of the sequence.
|
123 |
+
prompt: The prompt of the sequence.
|
124 |
+
prompt_token_ids: The token IDs of the prompt.
|
125 |
+
block_size: The block size of the sequence. Should be the same as the
|
126 |
+
block size used by the block manager and cache engine.
|
127 |
+
"""
|
128 |
+
|
129 |
+
def __init__(
|
130 |
+
self,
|
131 |
+
seq_id: int,
|
132 |
+
prompt: str,
|
133 |
+
prompt_token_ids: List[int],
|
134 |
+
block_size: int,
|
135 |
+
) -> None:
|
136 |
+
self.seq_id = seq_id
|
137 |
+
self.prompt = prompt
|
138 |
+
self.block_size = block_size
|
139 |
+
|
140 |
+
self.data = SequenceData(prompt_token_ids)
|
141 |
+
self.output_logprobs: SampleLogprobs = []
|
142 |
+
self.output_text = ""
|
143 |
+
|
144 |
+
self.logical_token_blocks: List[LogicalTokenBlock] = []
|
145 |
+
# Initialize the logical token blocks with the prompt token ids.
|
146 |
+
self._append_tokens_to_blocks(prompt_token_ids)
|
147 |
+
self.status = SequenceStatus.WAITING
|
148 |
+
|
149 |
+
# Used for incremental detokenization
|
150 |
+
self.prefix_offset = 0
|
151 |
+
self.read_offset = 0
|
152 |
+
# Input + output tokens
|
153 |
+
self.tokens: Optional[List[str]] = None
|
154 |
+
|
155 |
+
def _append_logical_block(self) -> None:
|
156 |
+
block = LogicalTokenBlock(
|
157 |
+
block_number=len(self.logical_token_blocks),
|
158 |
+
block_size=self.block_size,
|
159 |
+
)
|
160 |
+
self.logical_token_blocks.append(block)
|
161 |
+
|
162 |
+
def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
|
163 |
+
cursor = 0
|
164 |
+
while cursor < len(token_ids):
|
165 |
+
if not self.logical_token_blocks:
|
166 |
+
self._append_logical_block()
|
167 |
+
|
168 |
+
last_block = self.logical_token_blocks[-1]
|
169 |
+
if last_block.is_full():
|
170 |
+
self._append_logical_block()
|
171 |
+
last_block = self.logical_token_blocks[-1]
|
172 |
+
|
173 |
+
num_empty_slots = last_block.get_num_empty_slots()
|
174 |
+
last_block.append_tokens(token_ids[cursor : cursor + num_empty_slots])
|
175 |
+
cursor += num_empty_slots
|
176 |
+
|
177 |
+
def append_token_id(
|
178 |
+
self,
|
179 |
+
token_id: int,
|
180 |
+
logprobs: Dict[int, float],
|
181 |
+
hidden_states: Optional[torch.Tensor] = None,
|
182 |
+
finished: bool = False,
|
183 |
+
) -> None:
|
184 |
+
assert token_id in logprobs
|
185 |
+
self._append_tokens_to_blocks([token_id])
|
186 |
+
self.output_logprobs.append(logprobs)
|
187 |
+
self.data.append_token_id(token_id, logprobs[token_id])
|
188 |
+
self.data.append_hidden_states(hidden_states)
|
189 |
+
self.data.finished = finished
|
190 |
+
|
191 |
+
def get_len(self) -> int:
|
192 |
+
return self.data.get_len()
|
193 |
+
|
194 |
+
def get_prompt_len(self) -> int:
|
195 |
+
return self.data.get_prompt_len()
|
196 |
+
|
197 |
+
def get_output_len(self) -> int:
|
198 |
+
return self.data.get_output_len()
|
199 |
+
|
200 |
+
def get_token_ids(self) -> List[int]:
|
201 |
+
return self.data.get_token_ids()
|
202 |
+
|
203 |
+
def get_last_token_id(self) -> int:
|
204 |
+
return self.data.get_last_token_id()
|
205 |
+
|
206 |
+
def get_output_token_ids(self) -> List[int]:
|
207 |
+
return self.data.output_token_ids
|
208 |
+
|
209 |
+
def get_cumulative_logprob(self) -> float:
|
210 |
+
return self.data.cumulative_logprob
|
211 |
+
|
212 |
+
def get_beam_search_score(
|
213 |
+
self,
|
214 |
+
length_penalty: float = 0.0,
|
215 |
+
seq_len: Optional[int] = None,
|
216 |
+
eos_token_id: Optional[int] = None,
|
217 |
+
) -> float:
|
218 |
+
"""Calculate the beam search score with length penalty.
|
219 |
+
|
220 |
+
Adapted from
|
221 |
+
|
222 |
+
https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
|
223 |
+
"""
|
224 |
+
if seq_len is None:
|
225 |
+
seq_len = self.get_len()
|
226 |
+
# NOTE: HF implementation does not count the EOS token
|
227 |
+
# towards the length, we align with that here for testing.
|
228 |
+
if eos_token_id is not None and self.get_last_token_id() == eos_token_id:
|
229 |
+
seq_len -= 1
|
230 |
+
return self.get_cumulative_logprob() / (seq_len**length_penalty)
|
231 |
+
|
232 |
+
def is_finished(self) -> bool:
|
233 |
+
return SequenceStatus.is_finished(self.status)
|
234 |
+
|
235 |
+
def fork(self, new_seq_id: int) -> "Sequence":
|
236 |
+
new_seq = copy.deepcopy(self)
|
237 |
+
new_seq.seq_id = new_seq_id
|
238 |
+
return new_seq
|
239 |
+
|
240 |
+
def __repr__(self) -> str:
|
241 |
+
return (
|
242 |
+
f"Sequence(seq_id={self.seq_id}, "
|
243 |
+
f"status={self.status.name}, "
|
244 |
+
f"num_blocks={len(self.logical_token_blocks)})"
|
245 |
+
)
|
246 |
+
|
247 |
+
|
248 |
+
class SequenceGroup:
|
249 |
+
"""A group of sequences that are generated from the same prompt.
|
250 |
+
|
251 |
+
Args:
|
252 |
+
request_id: The ID of the request.
|
253 |
+
seqs: The list of sequences.
|
254 |
+
sampling_params: The sampling parameters used to generate the outputs.
|
255 |
+
arrival_time: The arrival time of the request.
|
256 |
+
"""
|
257 |
+
|
258 |
+
def __init__(
|
259 |
+
self,
|
260 |
+
request_id: str,
|
261 |
+
seqs: List[Sequence],
|
262 |
+
sampling_params: SamplingParams,
|
263 |
+
arrival_time: float,
|
264 |
+
) -> None:
|
265 |
+
self.request_id = request_id
|
266 |
+
self.seqs_dict = {seq.seq_id: seq for seq in seqs}
|
267 |
+
self.sampling_params = sampling_params
|
268 |
+
self.arrival_time = arrival_time
|
269 |
+
self.prompt_logprobs: Optional[PromptLogprobs] = None
|
270 |
+
|
271 |
+
@property
|
272 |
+
def prompt(self) -> str:
|
273 |
+
# All sequences in the group should have the same prompt.
|
274 |
+
# We use the prompt of an arbitrary sequence.
|
275 |
+
return next(iter(self.seqs_dict.values())).prompt
|
276 |
+
|
277 |
+
@property
|
278 |
+
def prompt_token_ids(self) -> List[int]:
|
279 |
+
# All sequences in the group should have the same prompt.
|
280 |
+
# We use the prompt of an arbitrary sequence.
|
281 |
+
return next(iter(self.seqs_dict.values())).data.prompt_token_ids
|
282 |
+
|
283 |
+
def get_max_num_running_seqs(self) -> int:
|
284 |
+
"""The maximum number of sequences running in parallel in the remaining
|
285 |
+
lifetime of the request."""
|
286 |
+
if self.sampling_params.use_beam_search:
|
287 |
+
# For beam search, maximally there will always be `best_of` beam
|
288 |
+
# candidates running in the future.
|
289 |
+
return self.sampling_params.best_of
|
290 |
+
else:
|
291 |
+
if self.sampling_params.best_of > self.num_seqs():
|
292 |
+
# At prompt stage, the sequence group is not yet filled up
|
293 |
+
# and only have one sequence running. However, in the
|
294 |
+
# generation stage, we will have `best_of` sequences running.
|
295 |
+
return self.sampling_params.best_of
|
296 |
+
# At sampling stages, return the number of actual sequences
|
297 |
+
# that are not finished yet.
|
298 |
+
return self.num_unfinished_seqs()
|
299 |
+
|
300 |
+
def get_seqs(
|
301 |
+
self,
|
302 |
+
status: Optional[SequenceStatus] = None,
|
303 |
+
) -> List[Sequence]:
|
304 |
+
if status is None:
|
305 |
+
return list(self.seqs_dict.values())
|
306 |
+
else:
|
307 |
+
return [seq for seq in self.seqs_dict.values() if seq.status == status]
|
308 |
+
|
309 |
+
def get_unfinished_seqs(self) -> List[Sequence]:
|
310 |
+
return [seq for seq in self.seqs_dict.values() if not seq.is_finished()]
|
311 |
+
|
312 |
+
def get_finished_seqs(self) -> List[Sequence]:
|
313 |
+
return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
|
314 |
+
|
315 |
+
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
|
316 |
+
return len(self.get_seqs(status))
|
317 |
+
|
318 |
+
def num_unfinished_seqs(self) -> int:
|
319 |
+
return len(self.get_unfinished_seqs())
|
320 |
+
|
321 |
+
def num_finished_seqs(self) -> int:
|
322 |
+
return len(self.get_finished_seqs())
|
323 |
+
|
324 |
+
def find(self, seq_id: int) -> Sequence:
|
325 |
+
if seq_id not in self.seqs_dict:
|
326 |
+
raise ValueError(f"Sequence {seq_id} not found.")
|
327 |
+
return self.seqs_dict[seq_id]
|
328 |
+
|
329 |
+
def add(self, seq: Sequence) -> None:
|
330 |
+
if seq.seq_id in self.seqs_dict:
|
331 |
+
raise ValueError(f"Sequence {seq.seq_id} already exists.")
|
332 |
+
self.seqs_dict[seq.seq_id] = seq
|
333 |
+
|
334 |
+
def remove(self, seq_id: int) -> None:
|
335 |
+
if seq_id not in self.seqs_dict:
|
336 |
+
raise ValueError(f"Sequence {seq_id} not found.")
|
337 |
+
del self.seqs_dict[seq_id]
|
338 |
+
|
339 |
+
def is_finished(self) -> bool:
|
340 |
+
return all(seq.is_finished() for seq in self.get_seqs())
|
341 |
+
|
342 |
+
def __repr__(self) -> str:
|
343 |
+
return (
|
344 |
+
f"SequenceGroup(request_id={self.request_id}, "
|
345 |
+
f"sampling_params={self.sampling_params}, "
|
346 |
+
f"num_seqs={len(self.seqs_dict)})"
|
347 |
+
)
|
348 |
+
|
349 |
+
|
350 |
+
class SequenceGroupMetadata:
|
351 |
+
"""Metadata for a sequence group. Used to create `InputMetadata`.
|
352 |
+
|
353 |
+
|
354 |
+
Args:
|
355 |
+
request_id: The ID of the request.
|
356 |
+
is_prompt: Whether the request is at prompt stage.
|
357 |
+
seq_data: The sequence data. (Seq id -> sequence data)
|
358 |
+
sampling_params: The sampling parameters used to generate the outputs.
|
359 |
+
block_tables: The block tables. (Seq id -> list of physical block
|
360 |
+
numbers)
|
361 |
+
"""
|
362 |
+
|
363 |
+
def __init__(
|
364 |
+
self,
|
365 |
+
request_id: str,
|
366 |
+
is_prompt: bool,
|
367 |
+
seq_data: Dict[int, SequenceData],
|
368 |
+
sampling_params: SamplingParams,
|
369 |
+
block_tables: Dict[int, List[int]],
|
370 |
+
) -> None:
|
371 |
+
self.request_id = request_id
|
372 |
+
self.is_prompt = is_prompt
|
373 |
+
self.seq_data = seq_data
|
374 |
+
self.sampling_params = sampling_params
|
375 |
+
self.block_tables = block_tables
|
376 |
+
|
377 |
+
|
378 |
+
class SequenceOutput:
|
379 |
+
"""The model output associated with a sequence.
|
380 |
+
|
381 |
+
Args:
|
382 |
+
parent_seq_id: The ID of the parent sequence (for forking in beam
|
383 |
+
search).
|
384 |
+
output_token: The output token ID.
|
385 |
+
logprobs: The logprobs of the output token.
|
386 |
+
(Token id -> logP(x_i+1 | x_0, ..., x_i))
|
387 |
+
"""
|
388 |
+
|
389 |
+
def __init__(
|
390 |
+
self,
|
391 |
+
parent_seq_id: int,
|
392 |
+
output_token: int,
|
393 |
+
logprobs: Dict[int, float],
|
394 |
+
hidden_states: Optional[torch.Tensor] = None,
|
395 |
+
finished: bool = False,
|
396 |
+
) -> None:
|
397 |
+
self.parent_seq_id = parent_seq_id
|
398 |
+
self.output_token = output_token
|
399 |
+
self.logprobs = logprobs
|
400 |
+
self.finished = finished
|
401 |
+
self.hidden_states = hidden_states
|
402 |
+
|
403 |
+
def __repr__(self) -> str:
|
404 |
+
return (
|
405 |
+
f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
|
406 |
+
f"output_token={self.output_token}, "
|
407 |
+
f"logprobs={self.logprobs}),"
|
408 |
+
f"finished={self.finished}),"
|
409 |
+
f"hidden_states={self.hidden_states.shape if self.hidden_states is not None else None}"
|
410 |
+
)
|
411 |
+
|
412 |
+
def __eq__(self, other: object) -> bool:
|
413 |
+
if not isinstance(other, SequenceOutput):
|
414 |
+
raise NotImplementedError()
|
415 |
+
return (
|
416 |
+
self.parent_seq_id == other.parent_seq_id
|
417 |
+
and self.output_token == other.output_token
|
418 |
+
and self.logprobs == other.logprobs
|
419 |
+
)
|
420 |
+
|
421 |
+
|
422 |
+
class SequenceGroupOutput:
|
423 |
+
"""The model output associated with a sequence group."""
|
424 |
+
|
425 |
+
def __init__(
|
426 |
+
self,
|
427 |
+
samples: List[SequenceOutput],
|
428 |
+
prompt_logprobs: Optional[PromptLogprobs],
|
429 |
+
) -> None:
|
430 |
+
self.samples = samples
|
431 |
+
self.prompt_logprobs = prompt_logprobs
|
432 |
+
|
433 |
+
def __repr__(self) -> str:
|
434 |
+
return (
|
435 |
+
f"SequenceGroupOutput(samples={self.samples}, "
|
436 |
+
f"prompt_logprobs={self.prompt_logprobs})"
|
437 |
+
)
|
438 |
+
|
439 |
+
def __eq__(self, other: object) -> bool:
|
440 |
+
if not isinstance(other, SequenceGroupOutput):
|
441 |
+
raise NotImplementedError()
|
442 |
+
return (
|
443 |
+
self.samples == other.samples
|
444 |
+
and self.prompt_logprobs == other.prompt_logprobs
|
445 |
+
)
|
446 |
+
|
447 |
+
|
448 |
+
# For each sequence group, we generate a list of SequenceOutput object,
|
449 |
+
# each of which contains one possible candidate for the next token.
|
450 |
+
SamplerOutput = List[SequenceGroupOutput]
|
ChatTTS/model/velocity/worker.py
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""A GPU worker class."""
|
2 |
+
|
3 |
+
import os
|
4 |
+
from typing import Dict, List, Optional, Tuple
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.distributed
|
8 |
+
|
9 |
+
from vllm.config import CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig
|
10 |
+
from vllm.model_executor import set_random_seed
|
11 |
+
from vllm.model_executor.parallel_utils.communication_op import broadcast_object_list
|
12 |
+
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
|
13 |
+
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
14 |
+
from vllm.worker.cache_engine import CacheEngine
|
15 |
+
|
16 |
+
from .model_runner import ModelRunner
|
17 |
+
|
18 |
+
|
19 |
+
class Worker:
|
20 |
+
"""A worker class that executes (a partition of) the model on a GPU.
|
21 |
+
|
22 |
+
Each worker is associated with a single GPU. The worker is responsible for
|
23 |
+
maintaining the KV cache and executing the model on the GPU. In case of
|
24 |
+
distributed inference, each worker is assigned a partition of the model.
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
model_config: ModelConfig,
|
30 |
+
parallel_config: ParallelConfig,
|
31 |
+
scheduler_config: SchedulerConfig,
|
32 |
+
local_rank: int,
|
33 |
+
rank: int,
|
34 |
+
distributed_init_method: str,
|
35 |
+
post_model_path: str,
|
36 |
+
is_driver_worker: bool = False,
|
37 |
+
) -> None:
|
38 |
+
self.model_config = model_config
|
39 |
+
self.parallel_config = parallel_config
|
40 |
+
self.scheduler_config = scheduler_config
|
41 |
+
self.local_rank = local_rank
|
42 |
+
self.rank = rank
|
43 |
+
self.distributed_init_method = distributed_init_method
|
44 |
+
self.is_driver_worker = is_driver_worker
|
45 |
+
self.post_model_path = post_model_path
|
46 |
+
|
47 |
+
if self.is_driver_worker:
|
48 |
+
assert self.rank == 0, "The driver worker must have rank 0."
|
49 |
+
|
50 |
+
self.model_runner = ModelRunner(
|
51 |
+
model_config,
|
52 |
+
parallel_config,
|
53 |
+
scheduler_config,
|
54 |
+
is_driver_worker,
|
55 |
+
post_model_path,
|
56 |
+
)
|
57 |
+
# Uninitialized cache engine. Will be initialized by
|
58 |
+
# self.init_cache_engine().
|
59 |
+
self.cache_config = None
|
60 |
+
self.cache_engine = None
|
61 |
+
self.cache_events = None
|
62 |
+
self.gpu_cache = None
|
63 |
+
|
64 |
+
def init_model(self) -> None:
|
65 |
+
# torch.distributed.all_reduce does not free the input tensor until
|
66 |
+
# the synchronization point. This causes the memory usage to grow
|
67 |
+
# as the number of all_reduce calls increases. This env var disables
|
68 |
+
# this behavior.
|
69 |
+
# Related issue:
|
70 |
+
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
|
71 |
+
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
|
72 |
+
|
73 |
+
# This env var set by Ray causes exceptions with graph building.
|
74 |
+
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
|
75 |
+
self.device = torch.device(f"cuda:{self.local_rank}")
|
76 |
+
torch.cuda.set_device(self.device)
|
77 |
+
|
78 |
+
_check_if_gpu_supports_dtype(self.model_config.dtype)
|
79 |
+
|
80 |
+
# Initialize the distributed environment.
|
81 |
+
_init_distributed_environment(
|
82 |
+
self.parallel_config, self.rank, self.distributed_init_method
|
83 |
+
)
|
84 |
+
|
85 |
+
# Initialize the model.
|
86 |
+
set_random_seed(self.model_config.seed)
|
87 |
+
|
88 |
+
def load_model(self):
|
89 |
+
self.model_runner.load_model()
|
90 |
+
|
91 |
+
@torch.inference_mode()
|
92 |
+
def profile_num_available_blocks(
|
93 |
+
self,
|
94 |
+
block_size: int,
|
95 |
+
gpu_memory_utilization: float,
|
96 |
+
cpu_swap_space: int,
|
97 |
+
) -> Tuple[int, int]:
|
98 |
+
# Profile the memory usage of the model and get the maximum number of
|
99 |
+
# cache blocks that can be allocated with the remaining free memory.
|
100 |
+
torch.cuda.empty_cache()
|
101 |
+
|
102 |
+
# Execute a forward pass with dummy inputs to profile the memory usage
|
103 |
+
# of the model.
|
104 |
+
self.model_runner.profile_run()
|
105 |
+
|
106 |
+
# Calculate the number of blocks that can be allocated with the
|
107 |
+
# profiled peak memory.
|
108 |
+
torch.cuda.synchronize()
|
109 |
+
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
|
110 |
+
peak_memory = total_gpu_memory - free_gpu_memory
|
111 |
+
|
112 |
+
cache_block_size = CacheEngine.get_cache_block_size(
|
113 |
+
block_size, self.model_config, self.parallel_config
|
114 |
+
)
|
115 |
+
num_gpu_blocks = int(
|
116 |
+
(total_gpu_memory * gpu_memory_utilization - peak_memory)
|
117 |
+
// cache_block_size
|
118 |
+
)
|
119 |
+
num_cpu_blocks = int(cpu_swap_space // cache_block_size)
|
120 |
+
num_gpu_blocks = max(num_gpu_blocks, 0)
|
121 |
+
num_cpu_blocks = max(num_cpu_blocks, 0)
|
122 |
+
torch.cuda.empty_cache()
|
123 |
+
return num_gpu_blocks, num_cpu_blocks
|
124 |
+
|
125 |
+
def init_cache_engine(self, cache_config: CacheConfig) -> None:
|
126 |
+
self.cache_config = cache_config
|
127 |
+
self.cache_engine = CacheEngine(
|
128 |
+
self.cache_config, self.model_config, self.parallel_config
|
129 |
+
)
|
130 |
+
self.cache_events = self.cache_engine.events
|
131 |
+
self.gpu_cache = self.cache_engine.gpu_cache
|
132 |
+
self.model_runner.set_block_size(self.cache_engine.block_size)
|
133 |
+
|
134 |
+
def warm_up_model(self) -> None:
|
135 |
+
if not self.model_config.enforce_eager:
|
136 |
+
self.model_runner.capture_model(self.gpu_cache)
|
137 |
+
# Reset the seed to ensure that the random state is not affected by
|
138 |
+
# the model initialization and profiling.
|
139 |
+
set_random_seed(self.model_config.seed)
|
140 |
+
|
141 |
+
def cache_swap(
|
142 |
+
self,
|
143 |
+
blocks_to_swap_in: Dict[int, int],
|
144 |
+
blocks_to_swap_out: Dict[int, int],
|
145 |
+
blocks_to_copy: Dict[int, List[int]],
|
146 |
+
) -> None:
|
147 |
+
# Issue cache operations.
|
148 |
+
issued_cache_op = False
|
149 |
+
if blocks_to_swap_in:
|
150 |
+
self.cache_engine.swap_in(blocks_to_swap_in)
|
151 |
+
issued_cache_op = True
|
152 |
+
if blocks_to_swap_out:
|
153 |
+
self.cache_engine.swap_out(blocks_to_swap_out)
|
154 |
+
issued_cache_op = True
|
155 |
+
if blocks_to_copy:
|
156 |
+
self.cache_engine.copy(blocks_to_copy)
|
157 |
+
issued_cache_op = True
|
158 |
+
|
159 |
+
cache_events = self.cache_events if issued_cache_op else None
|
160 |
+
|
161 |
+
# Wait for cache operations to finish.
|
162 |
+
# TODO(woosuk): Profile swapping overhead and optimize if needed.
|
163 |
+
if cache_events is not None:
|
164 |
+
for event in cache_events:
|
165 |
+
event.wait()
|
166 |
+
|
167 |
+
@torch.inference_mode()
|
168 |
+
def execute_model(
|
169 |
+
self,
|
170 |
+
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
|
171 |
+
blocks_to_swap_in: Optional[Dict[int, int]] = None,
|
172 |
+
blocks_to_swap_out: Optional[Dict[int, int]] = None,
|
173 |
+
blocks_to_copy: Optional[Dict[int, List[int]]] = None,
|
174 |
+
) -> Optional[SamplerOutput]:
|
175 |
+
if self.is_driver_worker:
|
176 |
+
assert seq_group_metadata_list is not None
|
177 |
+
num_seq_groups = len(seq_group_metadata_list)
|
178 |
+
assert blocks_to_swap_in is not None
|
179 |
+
assert blocks_to_swap_out is not None
|
180 |
+
assert blocks_to_copy is not None
|
181 |
+
block_swapping_info = [
|
182 |
+
blocks_to_swap_in,
|
183 |
+
blocks_to_swap_out,
|
184 |
+
blocks_to_copy,
|
185 |
+
]
|
186 |
+
broadcast_object_list([num_seq_groups] + block_swapping_info, src=0)
|
187 |
+
else:
|
188 |
+
# num_seq_groups, blocks_to_swap_in, blocks_to_swap_out,
|
189 |
+
# blocks_to_copy (4 elements)
|
190 |
+
recv_data = [None] * 4
|
191 |
+
broadcast_object_list(recv_data, src=0)
|
192 |
+
num_seq_groups = recv_data[0]
|
193 |
+
block_swapping_info = recv_data[1:]
|
194 |
+
|
195 |
+
self.cache_swap(*block_swapping_info)
|
196 |
+
|
197 |
+
# If there is no input, we don't need to execute the model.
|
198 |
+
if num_seq_groups == 0:
|
199 |
+
return {}
|
200 |
+
|
201 |
+
output = self.model_runner.execute_model(
|
202 |
+
seq_group_metadata_list, self.gpu_cache
|
203 |
+
)
|
204 |
+
return output
|
205 |
+
|
206 |
+
|
207 |
+
def _init_distributed_environment(
|
208 |
+
parallel_config: ParallelConfig,
|
209 |
+
rank: int,
|
210 |
+
distributed_init_method: Optional[str] = None,
|
211 |
+
) -> None:
|
212 |
+
"""Initialize the distributed environment."""
|
213 |
+
if torch.distributed.is_initialized():
|
214 |
+
torch_world_size = torch.distributed.get_world_size()
|
215 |
+
if torch_world_size != parallel_config.world_size:
|
216 |
+
raise RuntimeError(
|
217 |
+
"torch.distributed is already initialized but the torch world "
|
218 |
+
"size does not match parallel_config.world_size "
|
219 |
+
f"({torch_world_size} vs. {parallel_config.world_size})."
|
220 |
+
)
|
221 |
+
elif not distributed_init_method:
|
222 |
+
raise ValueError(
|
223 |
+
"distributed_init_method must be set if torch.distributed "
|
224 |
+
"is not already initialized"
|
225 |
+
)
|
226 |
+
else:
|
227 |
+
torch.distributed.init_process_group(
|
228 |
+
backend="nccl",
|
229 |
+
world_size=parallel_config.world_size,
|
230 |
+
rank=rank,
|
231 |
+
init_method=distributed_init_method,
|
232 |
+
)
|
233 |
+
|
234 |
+
# A small all_reduce for warmup.
|
235 |
+
torch.distributed.all_reduce(torch.zeros(1).cuda())
|
236 |
+
initialize_model_parallel(
|
237 |
+
parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size
|
238 |
+
)
|
239 |
+
|
240 |
+
|
241 |
+
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
|
242 |
+
# Check if the GPU supports the dtype.
|
243 |
+
if torch_dtype == torch.bfloat16:
|
244 |
+
compute_capability = torch.cuda.get_device_capability()
|
245 |
+
if compute_capability[0] < 8:
|
246 |
+
gpu_name = torch.cuda.get_device_name()
|
247 |
+
raise ValueError(
|
248 |
+
"Bfloat16 is only supported on GPUs with compute capability "
|
249 |
+
f"of at least 8.0. Your {gpu_name} GPU has compute capability "
|
250 |
+
f"{compute_capability[0]}.{compute_capability[1]}."
|
251 |
+
)
|
ChatTTS/norm.py
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import re
|
4 |
+
from typing import Dict, Tuple, List, Literal, Callable, Optional
|
5 |
+
import sys
|
6 |
+
|
7 |
+
from numba import jit
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
from .utils import del_all
|
11 |
+
|
12 |
+
|
13 |
+
@jit
|
14 |
+
def _find_index(table: np.ndarray, val: np.uint16):
|
15 |
+
for i in range(table.size):
|
16 |
+
if table[i] == val:
|
17 |
+
return i
|
18 |
+
return -1
|
19 |
+
|
20 |
+
|
21 |
+
@jit
|
22 |
+
def _fast_replace(
|
23 |
+
table: np.ndarray, text: bytes
|
24 |
+
) -> Tuple[np.ndarray, List[Tuple[str, str]]]:
|
25 |
+
result = np.frombuffer(text, dtype=np.uint16).copy()
|
26 |
+
replaced_words = []
|
27 |
+
for i in range(result.size):
|
28 |
+
ch = result[i]
|
29 |
+
p = _find_index(table[0], ch)
|
30 |
+
if p >= 0:
|
31 |
+
repl_char = table[1][p]
|
32 |
+
result[i] = repl_char
|
33 |
+
replaced_words.append((chr(ch), chr(repl_char)))
|
34 |
+
return result, replaced_words
|
35 |
+
|
36 |
+
|
37 |
+
@jit
|
38 |
+
def _split_tags(text: str) -> Tuple[List[str], List[str]]:
|
39 |
+
texts: List[str] = []
|
40 |
+
tags: List[str] = []
|
41 |
+
current_text = ""
|
42 |
+
current_tag = ""
|
43 |
+
for c in text:
|
44 |
+
if c == "[":
|
45 |
+
texts.append(current_text)
|
46 |
+
current_text = ""
|
47 |
+
current_tag = c
|
48 |
+
elif current_tag != "":
|
49 |
+
current_tag += c
|
50 |
+
else:
|
51 |
+
current_text += c
|
52 |
+
if c == "]":
|
53 |
+
tags.append(current_tag)
|
54 |
+
current_tag = ""
|
55 |
+
if current_text != "":
|
56 |
+
texts.append(current_text)
|
57 |
+
return texts, tags
|
58 |
+
|
59 |
+
|
60 |
+
@jit
|
61 |
+
def _combine_tags(texts: List[str], tags: List[str]) -> str:
|
62 |
+
text = ""
|
63 |
+
for t in texts:
|
64 |
+
tg = ""
|
65 |
+
if len(tags) > 0:
|
66 |
+
tg = tags.pop(0)
|
67 |
+
text += t + tg
|
68 |
+
return text
|
69 |
+
|
70 |
+
|
71 |
+
class Normalizer:
|
72 |
+
def __init__(self, map_file_path: str, logger=logging.getLogger(__name__)):
|
73 |
+
self.logger = logger
|
74 |
+
self.normalizers: Dict[str, Callable[[str], str]] = {}
|
75 |
+
self.homophones_map = self._load_homophones_map(map_file_path)
|
76 |
+
"""
|
77 |
+
homophones_map
|
78 |
+
|
79 |
+
Replace the mispronounced characters with correctly pronounced ones.
|
80 |
+
|
81 |
+
Creation process of homophones_map.json:
|
82 |
+
|
83 |
+
1. Establish a word corpus using the [Tencent AI Lab Embedding Corpora v0.2.0 large] with 12 million entries. After cleaning, approximately 1.8 million entries remain. Use ChatTTS to infer the text.
|
84 |
+
2. Record discrepancies between the inferred and input text, identifying about 180,000 misread words.
|
85 |
+
3. Create a pinyin to common characters mapping using correctly read characters by ChatTTS.
|
86 |
+
4. For each discrepancy, extract the correct pinyin using [python-pinyin] and find homophones with the correct pronunciation from the mapping.
|
87 |
+
|
88 |
+
Thanks to:
|
89 |
+
[Tencent AI Lab Embedding Corpora for Chinese and English Words and Phrases](https://ai.tencent.com/ailab/nlp/en/embedding.html)
|
90 |
+
[python-pinyin](https://github.com/mozillazg/python-pinyin)
|
91 |
+
|
92 |
+
"""
|
93 |
+
self.coding = "utf-16-le" if sys.byteorder == "little" else "utf-16-be"
|
94 |
+
self.reject_pattern = re.compile(r"[^\u4e00-\u9fffA-Za-z,。、,\. ]")
|
95 |
+
self.sub_pattern = re.compile(r"\[[\w_]+\]")
|
96 |
+
self.chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]")
|
97 |
+
self.english_word_pattern = re.compile(r"\b[A-Za-z]+\b")
|
98 |
+
self.character_simplifier = str.maketrans(
|
99 |
+
{
|
100 |
+
":": ",",
|
101 |
+
";": ",",
|
102 |
+
"!": "。",
|
103 |
+
"(": ",",
|
104 |
+
")": ",",
|
105 |
+
"【": ",",
|
106 |
+
"】": ",",
|
107 |
+
"『": ",",
|
108 |
+
"』": ",",
|
109 |
+
"「": ",",
|
110 |
+
"」": ",",
|
111 |
+
"《": ",",
|
112 |
+
"》": ",",
|
113 |
+
"-": ",",
|
114 |
+
":": ",",
|
115 |
+
";": ",",
|
116 |
+
"!": ".",
|
117 |
+
"(": ",",
|
118 |
+
")": ",",
|
119 |
+
# "[": ",",
|
120 |
+
# "]": ",",
|
121 |
+
">": ",",
|
122 |
+
"<": ",",
|
123 |
+
"-": ",",
|
124 |
+
}
|
125 |
+
)
|
126 |
+
self.halfwidth_2_fullwidth = str.maketrans(
|
127 |
+
{
|
128 |
+
"!": "!",
|
129 |
+
'"': "“",
|
130 |
+
"'": "‘",
|
131 |
+
"#": "#",
|
132 |
+
"$": "$",
|
133 |
+
"%": "%",
|
134 |
+
"&": "&",
|
135 |
+
"(": "(",
|
136 |
+
")": ")",
|
137 |
+
",": ",",
|
138 |
+
"-": "-",
|
139 |
+
"*": "*",
|
140 |
+
"+": "+",
|
141 |
+
".": "。",
|
142 |
+
"/": "/",
|
143 |
+
":": ":",
|
144 |
+
";": ";",
|
145 |
+
"<": "<",
|
146 |
+
"=": "=",
|
147 |
+
">": ">",
|
148 |
+
"?": "?",
|
149 |
+
"@": "@",
|
150 |
+
# '[': '[',
|
151 |
+
"\\": "\",
|
152 |
+
# ']': ']',
|
153 |
+
"^": "^",
|
154 |
+
# '_': '_',
|
155 |
+
"`": "`",
|
156 |
+
"{": "{",
|
157 |
+
"|": "|",
|
158 |
+
"}": "}",
|
159 |
+
"~": "~",
|
160 |
+
}
|
161 |
+
)
|
162 |
+
|
163 |
+
def __call__(
|
164 |
+
self,
|
165 |
+
text: str,
|
166 |
+
do_text_normalization=True,
|
167 |
+
do_homophone_replacement=True,
|
168 |
+
lang: Optional[Literal["zh", "en"]] = None,
|
169 |
+
) -> str:
|
170 |
+
if do_text_normalization:
|
171 |
+
_lang = self._detect_language(text) if lang is None else lang
|
172 |
+
if _lang in self.normalizers:
|
173 |
+
texts, tags = _split_tags(text)
|
174 |
+
self.logger.debug("split texts %s, tags %s", str(texts), str(tags))
|
175 |
+
texts = [self.normalizers[_lang](t) for t in texts]
|
176 |
+
self.logger.debug("normed texts %s", str(texts))
|
177 |
+
text = _combine_tags(texts, tags) if len(tags) > 0 else texts[0]
|
178 |
+
self.logger.debug("combined text %s", text)
|
179 |
+
if _lang == "zh":
|
180 |
+
text = self._apply_half2full_map(text)
|
181 |
+
invalid_characters = self._count_invalid_characters(text)
|
182 |
+
if len(invalid_characters):
|
183 |
+
self.logger.warning(f"found invalid characters: {invalid_characters}")
|
184 |
+
text = self._apply_character_map(text)
|
185 |
+
if do_homophone_replacement:
|
186 |
+
arr, replaced_words = _fast_replace(
|
187 |
+
self.homophones_map,
|
188 |
+
text.encode(self.coding),
|
189 |
+
)
|
190 |
+
if replaced_words:
|
191 |
+
text = arr.tobytes().decode(self.coding)
|
192 |
+
repl_res = ", ".join([f"{_[0]}->{_[1]}" for _ in replaced_words])
|
193 |
+
self.logger.info(f"replace homophones: {repl_res}")
|
194 |
+
if len(invalid_characters):
|
195 |
+
texts, tags = _split_tags(text)
|
196 |
+
self.logger.debug("split texts %s, tags %s", str(texts), str(tags))
|
197 |
+
texts = [self.reject_pattern.sub("", t) for t in texts]
|
198 |
+
self.logger.debug("normed texts %s", str(texts))
|
199 |
+
text = _combine_tags(texts, tags) if len(tags) > 0 else texts[0]
|
200 |
+
self.logger.debug("combined text %s", text)
|
201 |
+
return text
|
202 |
+
|
203 |
+
def register(self, name: str, normalizer: Callable[[str], str]) -> bool:
|
204 |
+
if name in self.normalizers:
|
205 |
+
self.logger.warning(f"name {name} has been registered")
|
206 |
+
return False
|
207 |
+
try:
|
208 |
+
val = normalizer("test string 测试字符串")
|
209 |
+
if not isinstance(val, str):
|
210 |
+
self.logger.warning("normalizer must have caller type (str) -> str")
|
211 |
+
return False
|
212 |
+
except Exception as e:
|
213 |
+
self.logger.warning(e)
|
214 |
+
return False
|
215 |
+
self.normalizers[name] = normalizer
|
216 |
+
return True
|
217 |
+
|
218 |
+
def unregister(self, name: str):
|
219 |
+
if name in self.normalizers:
|
220 |
+
del self.normalizers[name]
|
221 |
+
|
222 |
+
def destroy(self):
|
223 |
+
del_all(self.normalizers)
|
224 |
+
del self.homophones_map
|
225 |
+
|
226 |
+
def _load_homophones_map(self, map_file_path: str) -> np.ndarray:
|
227 |
+
with open(map_file_path, "r", encoding="utf-8") as f:
|
228 |
+
homophones_map: Dict[str, str] = json.load(f)
|
229 |
+
map = np.empty((2, len(homophones_map)), dtype=np.uint32)
|
230 |
+
for i, k in enumerate(homophones_map.keys()):
|
231 |
+
map[:, i] = (ord(k), ord(homophones_map[k]))
|
232 |
+
del homophones_map
|
233 |
+
return map
|
234 |
+
|
235 |
+
def _count_invalid_characters(self, s: str):
|
236 |
+
s = self.sub_pattern.sub("", s)
|
237 |
+
non_alphabetic_chinese_chars = self.reject_pattern.findall(s)
|
238 |
+
return set(non_alphabetic_chinese_chars)
|
239 |
+
|
240 |
+
def _apply_half2full_map(self, text: str) -> str:
|
241 |
+
return text.translate(self.halfwidth_2_fullwidth)
|
242 |
+
|
243 |
+
def _apply_character_map(self, text: str) -> str:
|
244 |
+
return text.translate(self.character_simplifier)
|
245 |
+
|
246 |
+
def _detect_language(self, sentence: str) -> Literal["zh", "en"]:
|
247 |
+
chinese_chars = self.chinese_char_pattern.findall(sentence)
|
248 |
+
english_words = self.english_word_pattern.findall(sentence)
|
249 |
+
|
250 |
+
if len(chinese_chars) > len(english_words):
|
251 |
+
return "zh"
|
252 |
+
else:
|
253 |
+
return "en"
|
ChatTTS/res/__init__.py
ADDED
File without changes
|
ChatTTS/res/homophones_map.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
ChatTTS/res/sha256_map.json
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"sha256_asset_Decoder_pt" : "9964e36e840f0e3a748c5f716fe6de6490d2135a5f5155f4a642d51860e2ec38",
|
3 |
+
"sha256_asset_DVAE_full_pt" : "553eb75763511e23f3e5f86303e2163c5ca775489d637fb635d979c8ae58bbe5",
|
4 |
+
"sha256_asset_Embed_safetensors" : "2ff0be7134934155741b643b74e32fb6bf3eec41257984459b2ed60cdb4c48b0",
|
5 |
+
"sha256_asset_Vocos_pt" : "09a670eda1c08b740013679c7a90ebb7f1a97646ea7673069a6838e6b51d6c58",
|
6 |
+
|
7 |
+
"sha256_asset_gpt_config_json" : "0aaa1ecd96c49ad4f473459eb1982fa7ad79fa5de08cde2781bf6ad1f9a0c236",
|
8 |
+
"sha256_asset_gpt_model_safetensors" : "cd0806fd971f52f6a22c923ec64982b305e817bcc41ca83417fcf9141b984a0f",
|
9 |
+
|
10 |
+
"sha256_asset_tokenizer_special_tokens_map_json": "bd0ac9d9bb1657996b5c5fbcaa7d80f8de530d01a283da97f89deae5b1b8d011",
|
11 |
+
"sha256_asset_tokenizer_tokenizer_config_json" : "43e9d658b554fa5ee8d8e1d763349323bfef1ed7a89c0794220ab8861387d421",
|
12 |
+
"sha256_asset_tokenizer_tokenizer_json" : "843838a64e121e23e774cc75874c6fe862198d9f7dd43747914633a8fd89c20e"
|
13 |
+
}
|
ChatTTS/utils/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .dl import check_all_assets, download_all_assets
|
2 |
+
from .gpu import select_device
|
3 |
+
from .io import get_latest_modified_file, del_all
|
4 |
+
from .log import logger
|
ChatTTS/utils/dl.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
import hashlib
|
4 |
+
import requests
|
5 |
+
from io import BytesIO
|
6 |
+
from typing import Dict, Tuple, Optional
|
7 |
+
from mmap import mmap, ACCESS_READ
|
8 |
+
|
9 |
+
from .log import logger
|
10 |
+
|
11 |
+
|
12 |
+
def sha256(fileno: int) -> str:
|
13 |
+
data = mmap(fileno, 0, access=ACCESS_READ)
|
14 |
+
h = hashlib.sha256(data).hexdigest()
|
15 |
+
del data
|
16 |
+
return h
|
17 |
+
|
18 |
+
|
19 |
+
def check_model(
|
20 |
+
dir_name: Path, model_name: str, hash: str, remove_incorrect=False
|
21 |
+
) -> bool:
|
22 |
+
target = dir_name / model_name
|
23 |
+
relname = target.as_posix()
|
24 |
+
logger.get_logger().debug(f"checking {relname}...")
|
25 |
+
if not os.path.exists(target):
|
26 |
+
logger.get_logger().info(f"{target} not exist.")
|
27 |
+
return False
|
28 |
+
with open(target, "rb") as f:
|
29 |
+
digest = sha256(f.fileno())
|
30 |
+
bakfile = f"{target}.bak"
|
31 |
+
if digest != hash:
|
32 |
+
logger.get_logger().warning(f"{target} sha256 hash mismatch.")
|
33 |
+
logger.get_logger().info(f"expected: {hash}")
|
34 |
+
logger.get_logger().info(f"real val: {digest}")
|
35 |
+
if remove_incorrect:
|
36 |
+
if not os.path.exists(bakfile):
|
37 |
+
os.rename(str(target), bakfile)
|
38 |
+
else:
|
39 |
+
os.remove(str(target))
|
40 |
+
return False
|
41 |
+
if remove_incorrect and os.path.exists(bakfile):
|
42 |
+
os.remove(bakfile)
|
43 |
+
return True
|
44 |
+
|
45 |
+
|
46 |
+
def check_folder(
|
47 |
+
base_dir: Path,
|
48 |
+
*innder_dirs: str,
|
49 |
+
names: Tuple[str],
|
50 |
+
sha256_map: Dict[str, str],
|
51 |
+
update=False,
|
52 |
+
) -> bool:
|
53 |
+
key = "sha256_"
|
54 |
+
current_dir = base_dir
|
55 |
+
for d in innder_dirs:
|
56 |
+
current_dir /= d
|
57 |
+
key += f"{d}_"
|
58 |
+
|
59 |
+
for model in names:
|
60 |
+
menv = model.replace(".", "_")
|
61 |
+
if not check_model(current_dir, model, sha256_map[f"{key}{menv}"], update):
|
62 |
+
return False
|
63 |
+
return True
|
64 |
+
|
65 |
+
|
66 |
+
def check_all_assets(base_dir: Path, sha256_map: Dict[str, str], update=False) -> bool:
|
67 |
+
logger.get_logger().info("checking assets...")
|
68 |
+
|
69 |
+
if not check_folder(
|
70 |
+
base_dir,
|
71 |
+
"asset",
|
72 |
+
names=(
|
73 |
+
"Decoder.pt",
|
74 |
+
"DVAE_full.pt",
|
75 |
+
"Embed.safetensors",
|
76 |
+
"Vocos.pt",
|
77 |
+
),
|
78 |
+
sha256_map=sha256_map,
|
79 |
+
update=update,
|
80 |
+
):
|
81 |
+
return False
|
82 |
+
|
83 |
+
if not check_folder(
|
84 |
+
base_dir,
|
85 |
+
"asset",
|
86 |
+
"gpt",
|
87 |
+
names=(
|
88 |
+
"config.json",
|
89 |
+
"model.safetensors",
|
90 |
+
),
|
91 |
+
sha256_map=sha256_map,
|
92 |
+
update=update,
|
93 |
+
):
|
94 |
+
return False
|
95 |
+
|
96 |
+
if not check_folder(
|
97 |
+
base_dir,
|
98 |
+
"asset",
|
99 |
+
"tokenizer",
|
100 |
+
names=(
|
101 |
+
"special_tokens_map.json",
|
102 |
+
"tokenizer_config.json",
|
103 |
+
"tokenizer.json",
|
104 |
+
),
|
105 |
+
sha256_map=sha256_map,
|
106 |
+
update=update,
|
107 |
+
):
|
108 |
+
return False
|
109 |
+
|
110 |
+
logger.get_logger().info("all assets are already latest.")
|
111 |
+
return True
|
112 |
+
|
113 |
+
|
114 |
+
def download_and_extract_tar_gz(
|
115 |
+
url: str, folder: str, headers: Optional[Dict[str, str]] = None
|
116 |
+
):
|
117 |
+
import tarfile
|
118 |
+
|
119 |
+
logger.get_logger().info(f"downloading {url}")
|
120 |
+
response = requests.get(url, headers=headers, stream=True, timeout=(10, 3))
|
121 |
+
with BytesIO() as out_file:
|
122 |
+
out_file.write(response.content)
|
123 |
+
out_file.seek(0)
|
124 |
+
logger.get_logger().info(f"downloaded.")
|
125 |
+
with tarfile.open(fileobj=out_file, mode="r:gz") as tar:
|
126 |
+
tar.extractall(folder)
|
127 |
+
logger.get_logger().info(f"extracted into {folder}")
|
128 |
+
|
129 |
+
|
130 |
+
def download_and_extract_zip(
|
131 |
+
url: str, folder: str, headers: Optional[Dict[str, str]] = None
|
132 |
+
):
|
133 |
+
import zipfile
|
134 |
+
|
135 |
+
logger.get_logger().info(f"downloading {url}")
|
136 |
+
response = requests.get(url, headers=headers, stream=True, timeout=(10, 3))
|
137 |
+
with BytesIO() as out_file:
|
138 |
+
out_file.write(response.content)
|
139 |
+
out_file.seek(0)
|
140 |
+
logger.get_logger().info(f"downloaded.")
|
141 |
+
with zipfile.ZipFile(out_file) as zip_ref:
|
142 |
+
zip_ref.extractall(folder)
|
143 |
+
logger.get_logger().info(f"extracted into {folder}")
|
144 |
+
|
145 |
+
|
146 |
+
def download_dns_yaml(url: str, folder: str, headers: Dict[str, str]):
|
147 |
+
logger.get_logger().info(f"downloading {url}")
|
148 |
+
response = requests.get(url, headers=headers, stream=True, timeout=(100, 3))
|
149 |
+
with open(os.path.join(folder, "dns.yaml"), "wb") as out_file:
|
150 |
+
out_file.write(response.content)
|
151 |
+
logger.get_logger().info(f"downloaded into {folder}")
|
152 |
+
|
153 |
+
|
154 |
+
def download_all_assets(tmpdir: str, version="0.2.8"):
|
155 |
+
import subprocess
|
156 |
+
import platform
|
157 |
+
|
158 |
+
archs = {
|
159 |
+
"aarch64": "arm64",
|
160 |
+
"armv8l": "arm64",
|
161 |
+
"arm64": "arm64",
|
162 |
+
"x86": "386",
|
163 |
+
"i386": "386",
|
164 |
+
"i686": "386",
|
165 |
+
"386": "386",
|
166 |
+
"x86_64": "amd64",
|
167 |
+
"x64": "amd64",
|
168 |
+
"amd64": "amd64",
|
169 |
+
}
|
170 |
+
system_type = platform.system().lower()
|
171 |
+
architecture = platform.machine().lower()
|
172 |
+
is_win = system_type == "windows"
|
173 |
+
|
174 |
+
architecture = archs.get(architecture, None)
|
175 |
+
if not architecture:
|
176 |
+
logger.get_logger().error(f"architecture {architecture} is not supported")
|
177 |
+
exit(1)
|
178 |
+
try:
|
179 |
+
BASE_URL = "https://github.com/fumiama/RVC-Models-Downloader/releases/download/"
|
180 |
+
suffix = "zip" if is_win else "tar.gz"
|
181 |
+
RVCMD_URL = BASE_URL + f"v{version}/rvcmd_{system_type}_{architecture}.{suffix}"
|
182 |
+
cmdfile = os.path.join(tmpdir, "rvcmd")
|
183 |
+
if is_win:
|
184 |
+
download_and_extract_zip(RVCMD_URL, tmpdir)
|
185 |
+
cmdfile += ".exe"
|
186 |
+
else:
|
187 |
+
download_and_extract_tar_gz(RVCMD_URL, tmpdir)
|
188 |
+
os.chmod(cmdfile, 0o755)
|
189 |
+
subprocess.run([cmdfile, "-notui", "-w", "0", "assets/chtts"])
|
190 |
+
except Exception:
|
191 |
+
BASE_URL = (
|
192 |
+
"https://gitea.seku.su/fumiama/RVC-Models-Downloader/releases/download/"
|
193 |
+
)
|
194 |
+
suffix = "zip" if is_win else "tar.gz"
|
195 |
+
RVCMD_URL = BASE_URL + f"v{version}/rvcmd_{system_type}_{architecture}.{suffix}"
|
196 |
+
download_dns_yaml(
|
197 |
+
"https://gitea.seku.su/fumiama/RVC-Models-Downloader/raw/branch/main/dns.yaml",
|
198 |
+
tmpdir,
|
199 |
+
headers={
|
200 |
+
"user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36 Edg/128.0.0.0"
|
201 |
+
},
|
202 |
+
)
|
203 |
+
cmdfile = os.path.join(tmpdir, "rvcmd")
|
204 |
+
if is_win:
|
205 |
+
download_and_extract_zip(RVCMD_URL, tmpdir)
|
206 |
+
cmdfile += ".exe"
|
207 |
+
else:
|
208 |
+
download_and_extract_tar_gz(RVCMD_URL, tmpdir)
|
209 |
+
os.chmod(cmdfile, 0o755)
|
210 |
+
subprocess.run(
|
211 |
+
[
|
212 |
+
cmdfile,
|
213 |
+
"-notui",
|
214 |
+
"-w",
|
215 |
+
"0",
|
216 |
+
"-dns",
|
217 |
+
os.path.join(tmpdir, "dns.yaml"),
|
218 |
+
"assets/chtts",
|
219 |
+
]
|
220 |
+
)
|
ChatTTS/utils/gpu.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from .log import logger
|
4 |
+
|
5 |
+
|
6 |
+
def select_device(min_memory=2047, experimental=False):
|
7 |
+
if torch.cuda.is_available():
|
8 |
+
selected_gpu = 0
|
9 |
+
max_free_memory = -1
|
10 |
+
for i in range(torch.cuda.device_count()):
|
11 |
+
props = torch.cuda.get_device_properties(i)
|
12 |
+
free_memory = props.total_memory - torch.cuda.memory_reserved(i)
|
13 |
+
if max_free_memory < free_memory:
|
14 |
+
selected_gpu = i
|
15 |
+
max_free_memory = free_memory
|
16 |
+
free_memory_mb = max_free_memory / (1024 * 1024)
|
17 |
+
if free_memory_mb < min_memory:
|
18 |
+
logger.get_logger().warning(
|
19 |
+
f"GPU {selected_gpu} has {round(free_memory_mb, 2)} MB memory left. Switching to CPU."
|
20 |
+
)
|
21 |
+
device = torch.device("cpu")
|
22 |
+
else:
|
23 |
+
device = torch.device(f"cuda:{selected_gpu}")
|
24 |
+
elif torch.backends.mps.is_available():
|
25 |
+
"""
|
26 |
+
Currently MPS is slower than CPU while needs more memory and core utility,
|
27 |
+
so only enable this for experimental use.
|
28 |
+
"""
|
29 |
+
if experimental:
|
30 |
+
# For Apple M1/M2 chips with Metal Performance Shaders
|
31 |
+
logger.get_logger().warning("experimantal: found apple GPU, using MPS.")
|
32 |
+
device = torch.device("mps")
|
33 |
+
else:
|
34 |
+
logger.get_logger().info("found Apple GPU, but use CPU.")
|
35 |
+
device = torch.device("cpu")
|
36 |
+
else:
|
37 |
+
logger.get_logger().warning("no GPU found, use CPU instead")
|
38 |
+
device = torch.device("cpu")
|
39 |
+
|
40 |
+
return device
|
ChatTTS/utils/io.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
from typing import Union
|
4 |
+
from dataclasses import is_dataclass
|
5 |
+
|
6 |
+
from .log import logger
|
7 |
+
|
8 |
+
|
9 |
+
def get_latest_modified_file(directory):
|
10 |
+
|
11 |
+
files = [os.path.join(directory, f) for f in os.listdir(directory)]
|
12 |
+
if not files:
|
13 |
+
logger.get_logger().log(
|
14 |
+
logging.WARNING, f"no files found in the directory: {directory}"
|
15 |
+
)
|
16 |
+
return None
|
17 |
+
latest_file = max(files, key=os.path.getmtime)
|
18 |
+
|
19 |
+
return latest_file
|
20 |
+
|
21 |
+
|
22 |
+
def del_all(d: Union[dict, list]):
|
23 |
+
if is_dataclass(d):
|
24 |
+
for k in list(vars(d).keys()):
|
25 |
+
x = getattr(d, k)
|
26 |
+
if isinstance(x, dict) or isinstance(x, list) or is_dataclass(x):
|
27 |
+
del_all(x)
|
28 |
+
del x
|
29 |
+
delattr(d, k)
|
30 |
+
elif isinstance(d, dict):
|
31 |
+
lst = list(d.keys())
|
32 |
+
for k in lst:
|
33 |
+
x = d.pop(k)
|
34 |
+
if isinstance(x, dict) or isinstance(x, list) or is_dataclass(x):
|
35 |
+
del_all(x)
|
36 |
+
del x
|
37 |
+
elif isinstance(d, list):
|
38 |
+
while len(d):
|
39 |
+
x = d.pop()
|
40 |
+
if isinstance(x, dict) or isinstance(x, list) or is_dataclass(x):
|
41 |
+
del_all(x)
|
42 |
+
del x
|
43 |
+
else:
|
44 |
+
del d
|
ChatTTS/utils/log.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
|
5 |
+
class Logger:
|
6 |
+
def __init__(self, logger=logging.getLogger(Path(__file__).parent.name)):
|
7 |
+
self.logger = logger
|
8 |
+
|
9 |
+
def set_logger(self, logger: logging.Logger):
|
10 |
+
self.logger = logger
|
11 |
+
|
12 |
+
def get_logger(self) -> logging.Logger:
|
13 |
+
return self.logger
|
14 |
+
|
15 |
+
|
16 |
+
logger = Logger()
|
Dockerfile
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.9
|
2 |
+
|
3 |
+
RUN useradd -m -u 1000 user
|
4 |
+
USER user
|
5 |
+
ENV PATH="/home/user/.local/bin:$PATH"
|
6 |
+
|
7 |
+
WORKDIR /app
|
8 |
+
|
9 |
+
COPY --chown=user ./requirements.txt requirements.txt
|
10 |
+
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
11 |
+
|
12 |
+
COPY --chown=user . /app
|
13 |
+
CMD ["fastapi", "dev", "examples/api/main.py", "--host", "0.0.0.0", "--port", "7860"]
|
LICENSE
ADDED
@@ -0,0 +1,661 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
GNU AFFERO GENERAL PUBLIC LICENSE
|
2 |
+
Version 3, 19 November 2007
|
3 |
+
|
4 |
+
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
5 |
+
Everyone is permitted to copy and distribute verbatim copies
|
6 |
+
of this license document, but changing it is not allowed.
|
7 |
+
|
8 |
+
Preamble
|
9 |
+
|
10 |
+
The GNU Affero General Public License is a free, copyleft license for
|
11 |
+
software and other kinds of works, specifically designed to ensure
|
12 |
+
cooperation with the community in the case of network server software.
|
13 |
+
|
14 |
+
The licenses for most software and other practical works are designed
|
15 |
+
to take away your freedom to share and change the works. By contrast,
|
16 |
+
our General Public Licenses are intended to guarantee your freedom to
|
17 |
+
share and change all versions of a program--to make sure it remains free
|
18 |
+
software for all its users.
|
19 |
+
|
20 |
+
When we speak of free software, we are referring to freedom, not
|
21 |
+
price. Our General Public Licenses are designed to make sure that you
|
22 |
+
have the freedom to distribute copies of free software (and charge for
|
23 |
+
them if you wish), that you receive source code or can get it if you
|
24 |
+
want it, that you can change the software or use pieces of it in new
|
25 |
+
free programs, and that you know you can do these things.
|
26 |
+
|
27 |
+
Developers that use our General Public Licenses protect your rights
|
28 |
+
with two steps: (1) assert copyright on the software, and (2) offer
|
29 |
+
you this License which gives you legal permission to copy, distribute
|
30 |
+
and/or modify the software.
|
31 |
+
|
32 |
+
A secondary benefit of defending all users' freedom is that
|
33 |
+
improvements made in alternate versions of the program, if they
|
34 |
+
receive widespread use, become available for other developers to
|
35 |
+
incorporate. Many developers of free software are heartened and
|
36 |
+
encouraged by the resulting cooperation. However, in the case of
|
37 |
+
software used on network servers, this result may fail to come about.
|
38 |
+
The GNU General Public License permits making a modified version and
|
39 |
+
letting the public access it on a server without ever releasing its
|
40 |
+
source code to the public.
|
41 |
+
|
42 |
+
The GNU Affero General Public License is designed specifically to
|
43 |
+
ensure that, in such cases, the modified source code becomes available
|
44 |
+
to the community. It requires the operator of a network server to
|
45 |
+
provide the source code of the modified version running there to the
|
46 |
+
users of that server. Therefore, public use of a modified version, on
|
47 |
+
a publicly accessible server, gives the public access to the source
|
48 |
+
code of the modified version.
|
49 |
+
|
50 |
+
An older license, called the Affero General Public License and
|
51 |
+
published by Affero, was designed to accomplish similar goals. This is
|
52 |
+
a different license, not a version of the Affero GPL, but Affero has
|
53 |
+
released a new version of the Affero GPL which permits relicensing under
|
54 |
+
this license.
|
55 |
+
|
56 |
+
The precise terms and conditions for copying, distribution and
|
57 |
+
modification follow.
|
58 |
+
|
59 |
+
TERMS AND CONDITIONS
|
60 |
+
|
61 |
+
0. Definitions.
|
62 |
+
|
63 |
+
"This License" refers to version 3 of the GNU Affero General Public License.
|
64 |
+
|
65 |
+
"Copyright" also means copyright-like laws that apply to other kinds of
|
66 |
+
works, such as semiconductor masks.
|
67 |
+
|
68 |
+
"The Program" refers to any copyrightable work licensed under this
|
69 |
+
License. Each licensee is addressed as "you". "Licensees" and
|
70 |
+
"recipients" may be individuals or organizations.
|
71 |
+
|
72 |
+
To "modify" a work means to copy from or adapt all or part of the work
|
73 |
+
in a fashion requiring copyright permission, other than the making of an
|
74 |
+
exact copy. The resulting work is called a "modified version" of the
|
75 |
+
earlier work or a work "based on" the earlier work.
|
76 |
+
|
77 |
+
A "covered work" means either the unmodified Program or a work based
|
78 |
+
on the Program.
|
79 |
+
|
80 |
+
To "propagate" a work means to do anything with it that, without
|
81 |
+
permission, would make you directly or secondarily liable for
|
82 |
+
infringement under applicable copyright law, except executing it on a
|
83 |
+
computer or modifying a private copy. Propagation includes copying,
|
84 |
+
distribution (with or without modification), making available to the
|
85 |
+
public, and in some countries other activities as well.
|
86 |
+
|
87 |
+
To "convey" a work means any kind of propagation that enables other
|
88 |
+
parties to make or receive copies. Mere interaction with a user through
|
89 |
+
a computer network, with no transfer of a copy, is not conveying.
|
90 |
+
|
91 |
+
An interactive user interface displays "Appropriate Legal Notices"
|
92 |
+
to the extent that it includes a convenient and prominently visible
|
93 |
+
feature that (1) displays an appropriate copyright notice, and (2)
|
94 |
+
tells the user that there is no warranty for the work (except to the
|
95 |
+
extent that warranties are provided), that licensees may convey the
|
96 |
+
work under this License, and how to view a copy of this License. If
|
97 |
+
the interface presents a list of user commands or options, such as a
|
98 |
+
menu, a prominent item in the list meets this criterion.
|
99 |
+
|
100 |
+
1. Source Code.
|
101 |
+
|
102 |
+
The "source code" for a work means the preferred form of the work
|
103 |
+
for making modifications to it. "Object code" means any non-source
|
104 |
+
form of a work.
|
105 |
+
|
106 |
+
A "Standard Interface" means an interface that either is an official
|
107 |
+
standard defined by a recognized standards body, or, in the case of
|
108 |
+
interfaces specified for a particular programming language, one that
|
109 |
+
is widely used among developers working in that language.
|
110 |
+
|
111 |
+
The "System Libraries" of an executable work include anything, other
|
112 |
+
than the work as a whole, that (a) is included in the normal form of
|
113 |
+
packaging a Major Component, but which is not part of that Major
|
114 |
+
Component, and (b) serves only to enable use of the work with that
|
115 |
+
Major Component, or to implement a Standard Interface for which an
|
116 |
+
implementation is available to the public in source code form. A
|
117 |
+
"Major Component", in this context, means a major essential component
|
118 |
+
(kernel, window system, and so on) of the specific operating system
|
119 |
+
(if any) on which the executable work runs, or a compiler used to
|
120 |
+
produce the work, or an object code interpreter used to run it.
|
121 |
+
|
122 |
+
The "Corresponding Source" for a work in object code form means all
|
123 |
+
the source code needed to generate, install, and (for an executable
|
124 |
+
work) run the object code and to modify the work, including scripts to
|
125 |
+
control those activities. However, it does not include the work's
|
126 |
+
System Libraries, or general-purpose tools or generally available free
|
127 |
+
programs which are used unmodified in performing those activities but
|
128 |
+
which are not part of the work. For example, Corresponding Source
|
129 |
+
includes interface definition files associated with source files for
|
130 |
+
the work, and the source code for shared libraries and dynamically
|
131 |
+
linked subprograms that the work is specifically designed to require,
|
132 |
+
such as by intimate data communication or control flow between those
|
133 |
+
subprograms and other parts of the work.
|
134 |
+
|
135 |
+
The Corresponding Source need not include anything that users
|
136 |
+
can regenerate automatically from other parts of the Corresponding
|
137 |
+
Source.
|
138 |
+
|
139 |
+
The Corresponding Source for a work in source code form is that
|
140 |
+
same work.
|
141 |
+
|
142 |
+
2. Basic Permissions.
|
143 |
+
|
144 |
+
All rights granted under this License are granted for the term of
|
145 |
+
copyright on the Program, and are irrevocable provided the stated
|
146 |
+
conditions are met. This License explicitly affirms your unlimited
|
147 |
+
permission to run the unmodified Program. The output from running a
|
148 |
+
covered work is covered by this License only if the output, given its
|
149 |
+
content, constitutes a covered work. This License acknowledges your
|
150 |
+
rights of fair use or other equivalent, as provided by copyright law.
|
151 |
+
|
152 |
+
You may make, run and propagate covered works that you do not
|
153 |
+
convey, without conditions so long as your license otherwise remains
|
154 |
+
in force. You may convey covered works to others for the sole purpose
|
155 |
+
of having them make modifications exclusively for you, or provide you
|
156 |
+
with facilities for running those works, provided that you comply with
|
157 |
+
the terms of this License in conveying all material for which you do
|
158 |
+
not control copyright. Those thus making or running the covered works
|
159 |
+
for you must do so exclusively on your behalf, under your direction
|
160 |
+
and control, on terms that prohibit them from making any copies of
|
161 |
+
your copyrighted material outside their relationship with you.
|
162 |
+
|
163 |
+
Conveying under any other circumstances is permitted solely under
|
164 |
+
the conditions stated below. Sublicensing is not allowed; section 10
|
165 |
+
makes it unnecessary.
|
166 |
+
|
167 |
+
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
168 |
+
|
169 |
+
No covered work shall be deemed part of an effective technological
|
170 |
+
measure under any applicable law fulfilling obligations under article
|
171 |
+
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
172 |
+
similar laws prohibiting or restricting circumvention of such
|
173 |
+
measures.
|
174 |
+
|
175 |
+
When you convey a covered work, you waive any legal power to forbid
|
176 |
+
circumvention of technological measures to the extent such circumvention
|
177 |
+
is effected by exercising rights under this License with respect to
|
178 |
+
the covered work, and you disclaim any intention to limit operation or
|
179 |
+
modification of the work as a means of enforcing, against the work's
|
180 |
+
users, your or third parties' legal rights to forbid circumvention of
|
181 |
+
technological measures.
|
182 |
+
|
183 |
+
4. Conveying Verbatim Copies.
|
184 |
+
|
185 |
+
You may convey verbatim copies of the Program's source code as you
|
186 |
+
receive it, in any medium, provided that you conspicuously and
|
187 |
+
appropriately publish on each copy an appropriate copyright notice;
|
188 |
+
keep intact all notices stating that this License and any
|
189 |
+
non-permissive terms added in accord with section 7 apply to the code;
|
190 |
+
keep intact all notices of the absence of any warranty; and give all
|
191 |
+
recipients a copy of this License along with the Program.
|
192 |
+
|
193 |
+
You may charge any price or no price for each copy that you convey,
|
194 |
+
and you may offer support or warranty protection for a fee.
|
195 |
+
|
196 |
+
5. Conveying Modified Source Versions.
|
197 |
+
|
198 |
+
You may convey a work based on the Program, or the modifications to
|
199 |
+
produce it from the Program, in the form of source code under the
|
200 |
+
terms of section 4, provided that you also meet all of these conditions:
|
201 |
+
|
202 |
+
a) The work must carry prominent notices stating that you modified
|
203 |
+
it, and giving a relevant date.
|
204 |
+
|
205 |
+
b) The work must carry prominent notices stating that it is
|
206 |
+
released under this License and any conditions added under section
|
207 |
+
7. This requirement modifies the requirement in section 4 to
|
208 |
+
"keep intact all notices".
|
209 |
+
|
210 |
+
c) You must license the entire work, as a whole, under this
|
211 |
+
License to anyone who comes into possession of a copy. This
|
212 |
+
License will therefore apply, along with any applicable section 7
|
213 |
+
additional terms, to the whole of the work, and all its parts,
|
214 |
+
regardless of how they are packaged. This License gives no
|
215 |
+
permission to license the work in any other way, but it does not
|
216 |
+
invalidate such permission if you have separately received it.
|
217 |
+
|
218 |
+
d) If the work has interactive user interfaces, each must display
|
219 |
+
Appropriate Legal Notices; however, if the Program has interactive
|
220 |
+
interfaces that do not display Appropriate Legal Notices, your
|
221 |
+
work need not make them do so.
|
222 |
+
|
223 |
+
A compilation of a covered work with other separate and independent
|
224 |
+
works, which are not by their nature extensions of the covered work,
|
225 |
+
and which are not combined with it such as to form a larger program,
|
226 |
+
in or on a volume of a storage or distribution medium, is called an
|
227 |
+
"aggregate" if the compilation and its resulting copyright are not
|
228 |
+
used to limit the access or legal rights of the compilation's users
|
229 |
+
beyond what the individual works permit. Inclusion of a covered work
|
230 |
+
in an aggregate does not cause this License to apply to the other
|
231 |
+
parts of the aggregate.
|
232 |
+
|
233 |
+
6. Conveying Non-Source Forms.
|
234 |
+
|
235 |
+
You may convey a covered work in object code form under the terms
|
236 |
+
of sections 4 and 5, provided that you also convey the
|
237 |
+
machine-readable Corresponding Source under the terms of this License,
|
238 |
+
in one of these ways:
|
239 |
+
|
240 |
+
a) Convey the object code in, or embodied in, a physical product
|
241 |
+
(including a physical distribution medium), accompanied by the
|
242 |
+
Corresponding Source fixed on a durable physical medium
|
243 |
+
customarily used for software interchange.
|
244 |
+
|
245 |
+
b) Convey the object code in, or embodied in, a physical product
|
246 |
+
(including a physical distribution medium), accompanied by a
|
247 |
+
written offer, valid for at least three years and valid for as
|
248 |
+
long as you offer spare parts or customer support for that product
|
249 |
+
model, to give anyone who possesses the object code either (1) a
|
250 |
+
copy of the Corresponding Source for all the software in the
|
251 |
+
product that is covered by this License, on a durable physical
|
252 |
+
medium customarily used for software interchange, for a price no
|
253 |
+
more than your reasonable cost of physically performing this
|
254 |
+
conveying of source, or (2) access to copy the
|
255 |
+
Corresponding Source from a network server at no charge.
|
256 |
+
|
257 |
+
c) Convey individual copies of the object code with a copy of the
|
258 |
+
written offer to provide the Corresponding Source. This
|
259 |
+
alternative is allowed only occasionally and noncommercially, and
|
260 |
+
only if you received the object code with such an offer, in accord
|
261 |
+
with subsection 6b.
|
262 |
+
|
263 |
+
d) Convey the object code by offering access from a designated
|
264 |
+
place (gratis or for a charge), and offer equivalent access to the
|
265 |
+
Corresponding Source in the same way through the same place at no
|
266 |
+
further charge. You need not require recipients to copy the
|
267 |
+
Corresponding Source along with the object code. If the place to
|
268 |
+
copy the object code is a network server, the Corresponding Source
|
269 |
+
may be on a different server (operated by you or a third party)
|
270 |
+
that supports equivalent copying facilities, provided you maintain
|
271 |
+
clear directions next to the object code saying where to find the
|
272 |
+
Corresponding Source. Regardless of what server hosts the
|
273 |
+
Corresponding Source, you remain obligated to ensure that it is
|
274 |
+
available for as long as needed to satisfy these requirements.
|
275 |
+
|
276 |
+
e) Convey the object code using peer-to-peer transmission, provided
|
277 |
+
you inform other peers where the object code and Corresponding
|
278 |
+
Source of the work are being offered to the general public at no
|
279 |
+
charge under subsection 6d.
|
280 |
+
|
281 |
+
A separable portion of the object code, whose source code is excluded
|
282 |
+
from the Corresponding Source as a System Library, need not be
|
283 |
+
included in conveying the object code work.
|
284 |
+
|
285 |
+
A "User Product" is either (1) a "consumer product", which means any
|
286 |
+
tangible personal property which is normally used for personal, family,
|
287 |
+
or household purposes, or (2) anything designed or sold for incorporation
|
288 |
+
into a dwelling. In determining whether a product is a consumer product,
|
289 |
+
doubtful cases shall be resolved in favor of coverage. For a particular
|
290 |
+
product received by a particular user, "normally used" refers to a
|
291 |
+
typical or common use of that class of product, regardless of the status
|
292 |
+
of the particular user or of the way in which the particular user
|
293 |
+
actually uses, or expects or is expected to use, the product. A product
|
294 |
+
is a consumer product regardless of whether the product has substantial
|
295 |
+
commercial, industrial or non-consumer uses, unless such uses represent
|
296 |
+
the only significant mode of use of the product.
|
297 |
+
|
298 |
+
"Installation Information" for a User Product means any methods,
|
299 |
+
procedures, authorization keys, or other information required to install
|
300 |
+
and execute modified versions of a covered work in that User Product from
|
301 |
+
a modified version of its Corresponding Source. The information must
|
302 |
+
suffice to ensure that the continued functioning of the modified object
|
303 |
+
code is in no case prevented or interfered with solely because
|
304 |
+
modification has been made.
|
305 |
+
|
306 |
+
If you convey an object code work under this section in, or with, or
|
307 |
+
specifically for use in, a User Product, and the conveying occurs as
|
308 |
+
part of a transaction in which the right of possession and use of the
|
309 |
+
User Product is transferred to the recipient in perpetuity or for a
|
310 |
+
fixed term (regardless of how the transaction is characterized), the
|
311 |
+
Corresponding Source conveyed under this section must be accompanied
|
312 |
+
by the Installation Information. But this requirement does not apply
|
313 |
+
if neither you nor any third party retains the ability to install
|
314 |
+
modified object code on the User Product (for example, the work has
|
315 |
+
been installed in ROM).
|
316 |
+
|
317 |
+
The requirement to provide Installation Information does not include a
|
318 |
+
requirement to continue to provide support service, warranty, or updates
|
319 |
+
for a work that has been modified or installed by the recipient, or for
|
320 |
+
the User Product in which it has been modified or installed. Access to a
|
321 |
+
network may be denied when the modification itself materially and
|
322 |
+
adversely affects the operation of the network or violates the rules and
|
323 |
+
protocols for communication across the network.
|
324 |
+
|
325 |
+
Corresponding Source conveyed, and Installation Information provided,
|
326 |
+
in accord with this section must be in a format that is publicly
|
327 |
+
documented (and with an implementation available to the public in
|
328 |
+
source code form), and must require no special password or key for
|
329 |
+
unpacking, reading or copying.
|
330 |
+
|
331 |
+
7. Additional Terms.
|
332 |
+
|
333 |
+
"Additional permissions" are terms that supplement the terms of this
|
334 |
+
License by making exceptions from one or more of its conditions.
|
335 |
+
Additional permissions that are applicable to the entire Program shall
|
336 |
+
be treated as though they were included in this License, to the extent
|
337 |
+
that they are valid under applicable law. If additional permissions
|
338 |
+
apply only to part of the Program, that part may be used separately
|
339 |
+
under those permissions, but the entire Program remains governed by
|
340 |
+
this License without regard to the additional permissions.
|
341 |
+
|
342 |
+
When you convey a copy of a covered work, you may at your option
|
343 |
+
remove any additional permissions from that copy, or from any part of
|
344 |
+
it. (Additional permissions may be written to require their own
|
345 |
+
removal in certain cases when you modify the work.) You may place
|
346 |
+
additional permissions on material, added by you to a covered work,
|
347 |
+
for which you have or can give appropriate copyright permission.
|
348 |
+
|
349 |
+
Notwithstanding any other provision of this License, for material you
|
350 |
+
add to a covered work, you may (if authorized by the copyright holders of
|
351 |
+
that material) supplement the terms of this License with terms:
|
352 |
+
|
353 |
+
a) Disclaiming warranty or limiting liability differently from the
|
354 |
+
terms of sections 15 and 16 of this License; or
|
355 |
+
|
356 |
+
b) Requiring preservation of specified reasonable legal notices or
|
357 |
+
author attributions in that material or in the Appropriate Legal
|
358 |
+
Notices displayed by works containing it; or
|
359 |
+
|
360 |
+
c) Prohibiting misrepresentation of the origin of that material, or
|
361 |
+
requiring that modified versions of such material be marked in
|
362 |
+
reasonable ways as different from the original version; or
|
363 |
+
|
364 |
+
d) Limiting the use for publicity purposes of names of licensors or
|
365 |
+
authors of the material; or
|
366 |
+
|
367 |
+
e) Declining to grant rights under trademark law for use of some
|
368 |
+
trade names, trademarks, or service marks; or
|
369 |
+
|
370 |
+
f) Requiring indemnification of licensors and authors of that
|
371 |
+
material by anyone who conveys the material (or modified versions of
|
372 |
+
it) with contractual assumptions of liability to the recipient, for
|
373 |
+
any liability that these contractual assumptions directly impose on
|
374 |
+
those licensors and authors.
|
375 |
+
|
376 |
+
All other non-permissive additional terms are considered "further
|
377 |
+
restrictions" within the meaning of section 10. If the Program as you
|
378 |
+
received it, or any part of it, contains a notice stating that it is
|
379 |
+
governed by this License along with a term that is a further
|
380 |
+
restriction, you may remove that term. If a license document contains
|
381 |
+
a further restriction but permits relicensing or conveying under this
|
382 |
+
License, you may add to a covered work material governed by the terms
|
383 |
+
of that license document, provided that the further restriction does
|
384 |
+
not survive such relicensing or conveying.
|
385 |
+
|
386 |
+
If you add terms to a covered work in accord with this section, you
|
387 |
+
must place, in the relevant source files, a statement of the
|
388 |
+
additional terms that apply to those files, or a notice indicating
|
389 |
+
where to find the applicable terms.
|
390 |
+
|
391 |
+
Additional terms, permissive or non-permissive, may be stated in the
|
392 |
+
form of a separately written license, or stated as exceptions;
|
393 |
+
the above requirements apply either way.
|
394 |
+
|
395 |
+
8. Termination.
|
396 |
+
|
397 |
+
You may not propagate or modify a covered work except as expressly
|
398 |
+
provided under this License. Any attempt otherwise to propagate or
|
399 |
+
modify it is void, and will automatically terminate your rights under
|
400 |
+
this License (including any patent licenses granted under the third
|
401 |
+
paragraph of section 11).
|
402 |
+
|
403 |
+
However, if you cease all violation of this License, then your
|
404 |
+
license from a particular copyright holder is reinstated (a)
|
405 |
+
provisionally, unless and until the copyright holder explicitly and
|
406 |
+
finally terminates your license, and (b) permanently, if the copyright
|
407 |
+
holder fails to notify you of the violation by some reasonable means
|
408 |
+
prior to 60 days after the cessation.
|
409 |
+
|
410 |
+
Moreover, your license from a particular copyright holder is
|
411 |
+
reinstated permanently if the copyright holder notifies you of the
|
412 |
+
violation by some reasonable means, this is the first time you have
|
413 |
+
received notice of violation of this License (for any work) from that
|
414 |
+
copyright holder, and you cure the violation prior to 30 days after
|
415 |
+
your receipt of the notice.
|
416 |
+
|
417 |
+
Termination of your rights under this section does not terminate the
|
418 |
+
licenses of parties who have received copies or rights from you under
|
419 |
+
this License. If your rights have been terminated and not permanently
|
420 |
+
reinstated, you do not qualify to receive new licenses for the same
|
421 |
+
material under section 10.
|
422 |
+
|
423 |
+
9. Acceptance Not Required for Having Copies.
|
424 |
+
|
425 |
+
You are not required to accept this License in order to receive or
|
426 |
+
run a copy of the Program. Ancillary propagation of a covered work
|
427 |
+
occurring solely as a consequence of using peer-to-peer transmission
|
428 |
+
to receive a copy likewise does not require acceptance. However,
|
429 |
+
nothing other than this License grants you permission to propagate or
|
430 |
+
modify any covered work. These actions infringe copyright if you do
|
431 |
+
not accept this License. Therefore, by modifying or propagating a
|
432 |
+
covered work, you indicate your acceptance of this License to do so.
|
433 |
+
|
434 |
+
10. Automatic Licensing of Downstream Recipients.
|
435 |
+
|
436 |
+
Each time you convey a covered work, the recipient automatically
|
437 |
+
receives a license from the original licensors, to run, modify and
|
438 |
+
propagate that work, subject to this License. You are not responsible
|
439 |
+
for enforcing compliance by third parties with this License.
|
440 |
+
|
441 |
+
An "entity transaction" is a transaction transferring control of an
|
442 |
+
organization, or substantially all assets of one, or subdividing an
|
443 |
+
organization, or merging organizations. If propagation of a covered
|
444 |
+
work results from an entity transaction, each party to that
|
445 |
+
transaction who receives a copy of the work also receives whatever
|
446 |
+
licenses to the work the party's predecessor in interest had or could
|
447 |
+
give under the previous paragraph, plus a right to possession of the
|
448 |
+
Corresponding Source of the work from the predecessor in interest, if
|
449 |
+
the predecessor has it or can get it with reasonable efforts.
|
450 |
+
|
451 |
+
You may not impose any further restrictions on the exercise of the
|
452 |
+
rights granted or affirmed under this License. For example, you may
|
453 |
+
not impose a license fee, royalty, or other charge for exercise of
|
454 |
+
rights granted under this License, and you may not initiate litigation
|
455 |
+
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
456 |
+
any patent claim is infringed by making, using, selling, offering for
|
457 |
+
sale, or importing the Program or any portion of it.
|
458 |
+
|
459 |
+
11. Patents.
|
460 |
+
|
461 |
+
A "contributor" is a copyright holder who authorizes use under this
|
462 |
+
License of the Program or a work on which the Program is based. The
|
463 |
+
work thus licensed is called the contributor's "contributor version".
|
464 |
+
|
465 |
+
A contributor's "essential patent claims" are all patent claims
|
466 |
+
owned or controlled by the contributor, whether already acquired or
|
467 |
+
hereafter acquired, that would be infringed by some manner, permitted
|
468 |
+
by this License, of making, using, or selling its contributor version,
|
469 |
+
but do not include claims that would be infringed only as a
|
470 |
+
consequence of further modification of the contributor version. For
|
471 |
+
purposes of this definition, "control" includes the right to grant
|
472 |
+
patent sublicenses in a manner consistent with the requirements of
|
473 |
+
this License.
|
474 |
+
|
475 |
+
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
476 |
+
patent license under the contributor's essential patent claims, to
|
477 |
+
make, use, sell, offer for sale, import and otherwise run, modify and
|
478 |
+
propagate the contents of its contributor version.
|
479 |
+
|
480 |
+
In the following three paragraphs, a "patent license" is any express
|
481 |
+
agreement or commitment, however denominated, not to enforce a patent
|
482 |
+
(such as an express permission to practice a patent or covenant not to
|
483 |
+
sue for patent infringement). To "grant" such a patent license to a
|
484 |
+
party means to make such an agreement or commitment not to enforce a
|
485 |
+
patent against the party.
|
486 |
+
|
487 |
+
If you convey a covered work, knowingly relying on a patent license,
|
488 |
+
and the Corresponding Source of the work is not available for anyone
|
489 |
+
to copy, free of charge and under the terms of this License, through a
|
490 |
+
publicly available network server or other readily accessible means,
|
491 |
+
then you must either (1) cause the Corresponding Source to be so
|
492 |
+
available, or (2) arrange to deprive yourself of the benefit of the
|
493 |
+
patent license for this particular work, or (3) arrange, in a manner
|
494 |
+
consistent with the requirements of this License, to extend the patent
|
495 |
+
license to downstream recipients. "Knowingly relying" means you have
|
496 |
+
actual knowledge that, but for the patent license, your conveying the
|
497 |
+
covered work in a country, or your recipient's use of the covered work
|
498 |
+
in a country, would infringe one or more identifiable patents in that
|
499 |
+
country that you have reason to believe are valid.
|
500 |
+
|
501 |
+
If, pursuant to or in connection with a single transaction or
|
502 |
+
arrangement, you convey, or propagate by procuring conveyance of, a
|
503 |
+
covered work, and grant a patent license to some of the parties
|
504 |
+
receiving the covered work authorizing them to use, propagate, modify
|
505 |
+
or convey a specific copy of the covered work, then the patent license
|
506 |
+
you grant is automatically extended to all recipients of the covered
|
507 |
+
work and works based on it.
|
508 |
+
|
509 |
+
A patent license is "discriminatory" if it does not include within
|
510 |
+
the scope of its coverage, prohibits the exercise of, or is
|
511 |
+
conditioned on the non-exercise of one or more of the rights that are
|
512 |
+
specifically granted under this License. You may not convey a covered
|
513 |
+
work if you are a party to an arrangement with a third party that is
|
514 |
+
in the business of distributing software, under which you make payment
|
515 |
+
to the third party based on the extent of your activity of conveying
|
516 |
+
the work, and under which the third party grants, to any of the
|
517 |
+
parties who would receive the covered work from you, a discriminatory
|
518 |
+
patent license (a) in connection with copies of the covered work
|
519 |
+
conveyed by you (or copies made from those copies), or (b) primarily
|
520 |
+
for and in connection with specific products or compilations that
|
521 |
+
contain the covered work, unless you entered into that arrangement,
|
522 |
+
or that patent license was granted, prior to 28 March 2007.
|
523 |
+
|
524 |
+
Nothing in this License shall be construed as excluding or limiting
|
525 |
+
any implied license or other defenses to infringement that may
|
526 |
+
otherwise be available to you under applicable patent law.
|
527 |
+
|
528 |
+
12. No Surrender of Others' Freedom.
|
529 |
+
|
530 |
+
If conditions are imposed on you (whether by court order, agreement or
|
531 |
+
otherwise) that contradict the conditions of this License, they do not
|
532 |
+
excuse you from the conditions of this License. If you cannot convey a
|
533 |
+
covered work so as to satisfy simultaneously your obligations under this
|
534 |
+
License and any other pertinent obligations, then as a consequence you may
|
535 |
+
not convey it at all. For example, if you agree to terms that obligate you
|
536 |
+
to collect a royalty for further conveying from those to whom you convey
|
537 |
+
the Program, the only way you could satisfy both those terms and this
|
538 |
+
License would be to refrain entirely from conveying the Program.
|
539 |
+
|
540 |
+
13. Remote Network Interaction; Use with the GNU General Public License.
|
541 |
+
|
542 |
+
Notwithstanding any other provision of this License, if you modify the
|
543 |
+
Program, your modified version must prominently offer all users
|
544 |
+
interacting with it remotely through a computer network (if your version
|
545 |
+
supports such interaction) an opportunity to receive the Corresponding
|
546 |
+
Source of your version by providing access to the Corresponding Source
|
547 |
+
from a network server at no charge, through some standard or customary
|
548 |
+
means of facilitating copying of software. This Corresponding Source
|
549 |
+
shall include the Corresponding Source for any work covered by version 3
|
550 |
+
of the GNU General Public License that is incorporated pursuant to the
|
551 |
+
following paragraph.
|
552 |
+
|
553 |
+
Notwithstanding any other provision of this License, you have
|
554 |
+
permission to link or combine any covered work with a work licensed
|
555 |
+
under version 3 of the GNU General Public License into a single
|
556 |
+
combined work, and to convey the resulting work. The terms of this
|
557 |
+
License will continue to apply to the part which is the covered work,
|
558 |
+
but the work with which it is combined will remain governed by version
|
559 |
+
3 of the GNU General Public License.
|
560 |
+
|
561 |
+
14. Revised Versions of this License.
|
562 |
+
|
563 |
+
The Free Software Foundation may publish revised and/or new versions of
|
564 |
+
the GNU Affero General Public License from time to time. Such new versions
|
565 |
+
will be similar in spirit to the present version, but may differ in detail to
|
566 |
+
address new problems or concerns.
|
567 |
+
|
568 |
+
Each version is given a distinguishing version number. If the
|
569 |
+
Program specifies that a certain numbered version of the GNU Affero General
|
570 |
+
Public License "or any later version" applies to it, you have the
|
571 |
+
option of following the terms and conditions either of that numbered
|
572 |
+
version or of any later version published by the Free Software
|
573 |
+
Foundation. If the Program does not specify a version number of the
|
574 |
+
GNU Affero General Public License, you may choose any version ever published
|
575 |
+
by the Free Software Foundation.
|
576 |
+
|
577 |
+
If the Program specifies that a proxy can decide which future
|
578 |
+
versions of the GNU Affero General Public License can be used, that proxy's
|
579 |
+
public statement of acceptance of a version permanently authorizes you
|
580 |
+
to choose that version for the Program.
|
581 |
+
|
582 |
+
Later license versions may give you additional or different
|
583 |
+
permissions. However, no additional obligations are imposed on any
|
584 |
+
author or copyright holder as a result of your choosing to follow a
|
585 |
+
later version.
|
586 |
+
|
587 |
+
15. Disclaimer of Warranty.
|
588 |
+
|
589 |
+
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
590 |
+
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
591 |
+
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
592 |
+
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
593 |
+
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
594 |
+
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
595 |
+
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
596 |
+
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
597 |
+
|
598 |
+
16. Limitation of Liability.
|
599 |
+
|
600 |
+
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
601 |
+
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
602 |
+
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
603 |
+
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
604 |
+
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
605 |
+
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
606 |
+
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
607 |
+
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
608 |
+
SUCH DAMAGES.
|
609 |
+
|
610 |
+
17. Interpretation of Sections 15 and 16.
|
611 |
+
|
612 |
+
If the disclaimer of warranty and limitation of liability provided
|
613 |
+
above cannot be given local legal effect according to their terms,
|
614 |
+
reviewing courts shall apply local law that most closely approximates
|
615 |
+
an absolute waiver of all civil liability in connection with the
|
616 |
+
Program, unless a warranty or assumption of liability accompanies a
|
617 |
+
copy of the Program in return for a fee.
|
618 |
+
|
619 |
+
END OF TERMS AND CONDITIONS
|
620 |
+
|
621 |
+
How to Apply These Terms to Your New Programs
|
622 |
+
|
623 |
+
If you develop a new program, and you want it to be of the greatest
|
624 |
+
possible use to the public, the best way to achieve this is to make it
|
625 |
+
free software which everyone can redistribute and change under these terms.
|
626 |
+
|
627 |
+
To do so, attach the following notices to the program. It is safest
|
628 |
+
to attach them to the start of each source file to most effectively
|
629 |
+
state the exclusion of warranty; and each file should have at least
|
630 |
+
the "copyright" line and a pointer to where the full notice is found.
|
631 |
+
|
632 |
+
<one line to give the program's name and a brief idea of what it does.>
|
633 |
+
Copyright (C) <year> <name of author>
|
634 |
+
|
635 |
+
This program is free software: you can redistribute it and/or modify
|
636 |
+
it under the terms of the GNU Affero General Public License as published
|
637 |
+
by the Free Software Foundation, either version 3 of the License, or
|
638 |
+
(at your option) any later version.
|
639 |
+
|
640 |
+
This program is distributed in the hope that it will be useful,
|
641 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
642 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
643 |
+
GNU Affero General Public License for more details.
|
644 |
+
|
645 |
+
You should have received a copy of the GNU Affero General Public License
|
646 |
+
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
647 |
+
|
648 |
+
Also add information on how to contact you by electronic and paper mail.
|
649 |
+
|
650 |
+
If your software can interact with users remotely through a computer
|
651 |
+
network, you should also make sure that it provides a way for users to
|
652 |
+
get its source. For example, if your program is a web application, its
|
653 |
+
interface could display a "Source" link that leads users to an archive
|
654 |
+
of the code. There are many ways you could offer source, and different
|
655 |
+
solutions will be better for different programs; see section 13 for the
|
656 |
+
specific requirements.
|
657 |
+
|
658 |
+
You should also get your employer (if you work as a programmer) or school,
|
659 |
+
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
660 |
+
For more information on this, and how to apply and follow the GNU AGPL, see
|
661 |
+
<https://www.gnu.org/licenses/>.
|
docs/cn/README.md
ADDED
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div align="center">
|
2 |
+
|
3 |
+
<a href="https://trendshift.io/repositories/10489" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10489" alt="2noise%2FChatTTS | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
4 |
+
|
5 |
+
# ChatTTS
|
6 |
+
一款适用于日常对话的生成式语音模型。
|
7 |
+
|
8 |
+
[![Licence](https://img.shields.io/github/license/2noise/ChatTTS?style=for-the-badge)](https://github.com/2noise/ChatTTS/blob/main/LICENSE)
|
9 |
+
[![PyPI](https://img.shields.io/pypi/v/ChatTTS.svg?style=for-the-badge&color=green)](https://pypi.org/project/ChatTTS)
|
10 |
+
|
11 |
+
[![Huggingface](https://img.shields.io/badge/🤗%20-Models-yellow.svg?style=for-the-badge)](https://huggingface.co/2Noise/ChatTTS)
|
12 |
+
[![Open In Colab](https://img.shields.io/badge/Colab-F9AB00?style=for-the-badge&logo=googlecolab&color=525252)](https://colab.research.google.com/github/2noise/ChatTTS/blob/main/examples/ipynb/colab.ipynb)
|
13 |
+
[![Discord](https://img.shields.io/badge/Discord-7289DA?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/Ud5Jxgx5yD)
|
14 |
+
|
15 |
+
[**English**](../../README.md) | **简体中文** | [**日本語**](../jp/README.md) | [**Русский**](../ru/README.md) | [**Español**](../es/README.md) | [**Français**](../fr/README.md)
|
16 |
+
|
17 |
+
</div>
|
18 |
+
|
19 |
+
> [!NOTE]
|
20 |
+
> 注意此版本可能不是最新版,所有内容请以英文版为准。
|
21 |
+
|
22 |
+
## 简介
|
23 |
+
|
24 |
+
> [!Note]
|
25 |
+
> 这个仓库包含算法架构和一些简单的示例。
|
26 |
+
|
27 |
+
> [!Tip]
|
28 |
+
> 由本仓库衍生出的用户端产品,请参见由社区维护的索引仓库 [Awesome-ChatTTS](https://github.com/libukai/Awesome-ChatTTS)。
|
29 |
+
|
30 |
+
ChatTTS 是一款专门为对话场景(例如 LLM 助手)设计的文本转语音模型。
|
31 |
+
|
32 |
+
### 支持的语种
|
33 |
+
|
34 |
+
- [x] 英语
|
35 |
+
- [x] 中文
|
36 |
+
- [ ] 敬请期待...
|
37 |
+
|
38 |
+
### 亮点
|
39 |
+
|
40 |
+
> 你可以参考 **[Bilibili](https://www.bilibili.com/video/BV1zn4y1o7iV)** 上的这个视频,了解本项目的详细情况。
|
41 |
+
|
42 |
+
1. **对话式 TTS**: ChatTTS 针对对话式任务进行了优化,能够实现自然且富有表现力的合成语音。它支持多个说话者,便于生成互动式对话。
|
43 |
+
2. **精细的控制**: 该模型可以预测和控制精细的韵律特征,包括笑声、停顿和插入语。
|
44 |
+
3. **更好的韵律**: ChatTTS 在韵律方面超越了大多数开源 TTS 模型。我们提供预训练模型以支持进一步的研究和开发。
|
45 |
+
|
46 |
+
### 数据集和模型
|
47 |
+
|
48 |
+
- 主模型使用了 100,000+ 小时的中文和英文音频数据进行训练。
|
49 |
+
- **[HuggingFace](https://huggingface.co/2Noise/ChatTTS)** 上的开源版本是一个在 40,000 小时数据上进行无监督微调的预训练模型。
|
50 |
+
|
51 |
+
### 路线图
|
52 |
+
|
53 |
+
- [x] 开源 4 万小时基础模型和 spk_stats 文件。
|
54 |
+
- [x] 支持流式语音输出。
|
55 |
+
- [ ] 开源具有多情感控制功能的 4 万小时版本。
|
56 |
+
- [ ] ChatTTS.cpp (欢迎在 2noise 组织中新建仓库)。
|
57 |
+
|
58 |
+
### 免责声明
|
59 |
+
|
60 |
+
> [!Important]
|
61 |
+
> 此仓库仅供学术用途。
|
62 |
+
|
63 |
+
本项目旨在用于教育和研究目的,不适用于任何商业或法律目的。作者不保证信息的准确性、完整性和可靠性。此仓库中使用的信息和数据仅供学术和研究目的。数据来自公开来源,作者不声称对数据拥有任何所有权或版权。
|
64 |
+
|
65 |
+
ChatTTS 是一款强大的文本转语音系统。但是,负责任和道德地使用这项技术非常重要。为了限制 ChatTTS 的使用,我们在 40,000 小时模型的训练过程中添加了少量高频噪声,并使用 MP3 格式尽可能压缩音频质量,以防止恶意行为者将其用于犯罪目的。同时,我们内部训练了一个检测模型,并计划在未来开源它。
|
66 |
+
|
67 |
+
### 联系方式
|
68 |
+
|
69 |
+
> 欢迎随时提交 GitHub issues/PRs。
|
70 |
+
|
71 |
+
#### 合作洽谈
|
72 |
+
|
73 |
+
如需就模型和路线图进行合作洽谈,请发送邮件至 **[email protected]**。
|
74 |
+
|
75 |
+
#### 线上讨论
|
76 |
+
|
77 |
+
##### 1. 官方 QQ 群
|
78 |
+
|
79 |
+
- **群 1**, 808364215 (已满)
|
80 |
+
- **群 2**, 230696694 (已满)
|
81 |
+
- **群 3**, 933639842 (已满)
|
82 |
+
- **群 4**, 608667975
|
83 |
+
|
84 |
+
##### 2. Discord
|
85 |
+
|
86 |
+
点击加入 [Discord](https://discord.gg/Ud5Jxgx5yD)。
|
87 |
+
|
88 |
+
## 体验教程
|
89 |
+
|
90 |
+
### 克隆仓库
|
91 |
+
|
92 |
+
```bash
|
93 |
+
git clone https://github.com/2noise/ChatTTS
|
94 |
+
cd ChatTTS
|
95 |
+
```
|
96 |
+
|
97 |
+
### 安装依赖
|
98 |
+
|
99 |
+
#### 1. 直接安装
|
100 |
+
|
101 |
+
```bash
|
102 |
+
pip install --upgrade -r requirements.txt
|
103 |
+
```
|
104 |
+
|
105 |
+
#### 2. 使用 conda 安装
|
106 |
+
|
107 |
+
```bash
|
108 |
+
conda create -n chattts
|
109 |
+
conda activate chattts
|
110 |
+
pip install -r requirements.txt
|
111 |
+
```
|
112 |
+
|
113 |
+
#### 可选 : 如果使用 NVIDIA GPU(仅限 Linux),可安装 TransformerEngine。
|
114 |
+
|
115 |
+
> [!Note]
|
116 |
+
> 安装过程可能耗时很长。
|
117 |
+
|
118 |
+
> [!Warning]
|
119 |
+
> TransformerEngine 的适配目前正在开发中,运行时可能会遇到较多问题。仅推荐出于开发目的安装。
|
120 |
+
|
121 |
+
```bash
|
122 |
+
pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
|
123 |
+
```
|
124 |
+
|
125 |
+
#### 可选 : 安装 FlashAttention-2 (主要适用于 NVIDIA GPU)
|
126 |
+
|
127 |
+
> [!Note]
|
128 |
+
> 支持设备列表详见 [Hugging Face Doc](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2).
|
129 |
+
|
130 |
+
```bash
|
131 |
+
pip install flash-attn --no-build-isolation
|
132 |
+
```
|
133 |
+
|
134 |
+
### 快速启动
|
135 |
+
|
136 |
+
> 确保在执行以下命令时,处于项目根目录下。
|
137 |
+
|
138 |
+
#### 1. WebUI 可视化界面
|
139 |
+
|
140 |
+
```bash
|
141 |
+
python examples/web/webui.py
|
142 |
+
```
|
143 |
+
|
144 |
+
#### 2. 命令行交互
|
145 |
+
|
146 |
+
> 生成的音频将保存至 `./output_audio_n.mp3`
|
147 |
+
|
148 |
+
```bash
|
149 |
+
python examples/cmd/run.py "Your text 1." "Your text 2."
|
150 |
+
```
|
151 |
+
|
152 |
+
## 开发教程
|
153 |
+
|
154 |
+
### 安装 Python 包
|
155 |
+
|
156 |
+
1. 从 PyPI 安装稳定版
|
157 |
+
|
158 |
+
```bash
|
159 |
+
pip install ChatTTS
|
160 |
+
```
|
161 |
+
|
162 |
+
2. 从 GitHub 安装最新版
|
163 |
+
|
164 |
+
```bash
|
165 |
+
pip install git+https://github.com/2noise/ChatTTS
|
166 |
+
```
|
167 |
+
|
168 |
+
3. 从本地文件夹安装开发版
|
169 |
+
|
170 |
+
```bash
|
171 |
+
pip install -e .
|
172 |
+
```
|
173 |
+
|
174 |
+
### 基础用法
|
175 |
+
|
176 |
+
```python
|
177 |
+
import ChatTTS
|
178 |
+
import torch
|
179 |
+
import torchaudio
|
180 |
+
|
181 |
+
chat = ChatTTS.Chat()
|
182 |
+
chat.load(compile=False) # Set to True for better performance
|
183 |
+
|
184 |
+
texts = ["PUT YOUR 1st TEXT HERE", "PUT YOUR 2nd TEXT HERE"]
|
185 |
+
|
186 |
+
wavs = chat.infer(texts)
|
187 |
+
|
188 |
+
torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
|
189 |
+
```
|
190 |
+
|
191 |
+
### 进阶用法
|
192 |
+
|
193 |
+
```python
|
194 |
+
###################################
|
195 |
+
# Sample a speaker from Gaussian.
|
196 |
+
|
197 |
+
rand_spk = chat.sample_random_speaker()
|
198 |
+
print(rand_spk) # save it for later timbre recovery
|
199 |
+
|
200 |
+
params_infer_code = ChatTTS.Chat.InferCodeParams(
|
201 |
+
spk_emb = rand_spk, # add sampled speaker
|
202 |
+
temperature = .3, # using custom temperature
|
203 |
+
top_P = 0.7, # top P decode
|
204 |
+
top_K = 20, # top K decode
|
205 |
+
)
|
206 |
+
|
207 |
+
###################################
|
208 |
+
# For sentence level manual control.
|
209 |
+
|
210 |
+
# use oral_(0-9), laugh_(0-2), break_(0-7)
|
211 |
+
# to generate special token in text to synthesize.
|
212 |
+
params_refine_text = ChatTTS.Chat.RefineTextParams(
|
213 |
+
prompt='[oral_2][laugh_0][break_6]',
|
214 |
+
)
|
215 |
+
|
216 |
+
wavs = chat.infer(
|
217 |
+
texts,
|
218 |
+
params_refine_text=params_refine_text,
|
219 |
+
params_infer_code=params_infer_code,
|
220 |
+
)
|
221 |
+
|
222 |
+
###################################
|
223 |
+
# For word level manual control.
|
224 |
+
|
225 |
+
text = 'What is [uv_break]your favorite english food?[laugh][lbreak]'
|
226 |
+
wavs = chat.infer(text, skip_refine_text=True, params_refine_text=params_refine_text, params_infer_code=params_infer_code)
|
227 |
+
torchaudio.save("output2.wav", torch.from_numpy(wavs[0]), 24000)
|
228 |
+
```
|
229 |
+
|
230 |
+
<details open>
|
231 |
+
<summary><h4>示例: 自我介绍</h4></summary>
|
232 |
+
|
233 |
+
```python
|
234 |
+
inputs_en = """
|
235 |
+
chatTTS is a text to speech model designed for dialogue applications.
|
236 |
+
[uv_break]it supports mixed language input [uv_break]and offers multi speaker
|
237 |
+
capabilities with precise control over prosodic elements like
|
238 |
+
[uv_break]laughter[uv_break][laugh], [uv_break]pauses, [uv_break]and intonation.
|
239 |
+
[uv_break]it delivers natural and expressive speech,[uv_break]so please
|
240 |
+
[uv_break] use the project responsibly at your own risk.[uv_break]
|
241 |
+
""".replace('\n', '') # English is still experimental.
|
242 |
+
|
243 |
+
params_refine_text = ChatTTS.Chat.RefineTextParams(
|
244 |
+
prompt='[oral_2][laugh_0][break_4]',
|
245 |
+
)
|
246 |
+
|
247 |
+
audio_array_en = chat.infer(inputs_en, params_refine_text=params_refine_text)
|
248 |
+
torchaudio.save("output3.wav", torch.from_numpy(audio_array_en[0]), 24000)
|
249 |
+
```
|
250 |
+
|
251 |
+
<table>
|
252 |
+
<tr>
|
253 |
+
<td align="center">
|
254 |
+
|
255 |
+
**男性音色**
|
256 |
+
|
257 |
+
</td>
|
258 |
+
<td align="center">
|
259 |
+
|
260 |
+
**女性音色**
|
261 |
+
|
262 |
+
</td>
|
263 |
+
</tr>
|
264 |
+
<tr>
|
265 |
+
<td align="center">
|
266 |
+
|
267 |
+
[男性音色](https://github.com/2noise/ChatTTS/assets/130631963/e0f51251-db7f-4d39-a0e9-3e095bb65de1)
|
268 |
+
|
269 |
+
</td>
|
270 |
+
<td align="center">
|
271 |
+
|
272 |
+
[女性音色](https://github.com/2noise/ChatTTS/assets/130631963/f5dcdd01-1091-47c5-8241-c4f6aaaa8bbd)
|
273 |
+
|
274 |
+
</td>
|
275 |
+
</tr>
|
276 |
+
</table>
|
277 |
+
|
278 |
+
</details>
|
279 |
+
|
280 |
+
## 常见问题
|
281 |
+
|
282 |
+
#### 1. 我需要多少 VRAM? 推理速度如何?
|
283 |
+
|
284 |
+
对于 30 秒的音频片段,至少需要 4GB 的 GPU 内存。 对于 4090 GPU,它可以每秒生成大约 7 个语义 token 对应的音频。实时因子 (RTF) 约为 0.3。
|
285 |
+
|
286 |
+
#### 2. 模型稳定性不够好,存在多个说话者或音频质量差等问题。
|
287 |
+
|
288 |
+
这是一个通常发生在自回归模型(例如 bark 和 valle)中的问题,通常很难避免。可以尝试多个样本以找到合适的结果。
|
289 |
+
|
290 |
+
#### 3. 除了笑声,我们还能控制其他东西吗?我们能控制其他情绪吗?
|
291 |
+
|
292 |
+
在当前发布的模型中,可用的 token 级控制单元是 `[laugh]`, `[uv_break]` 和 `[lbreak]`。未来的版本中,我们可能会开源具有更多情绪控制功能的模型。
|
293 |
+
|
294 |
+
## 致谢
|
295 |
+
|
296 |
+
- [bark](https://github.com/suno-ai/bark), [XTTSv2](https://github.com/coqui-ai/TTS) 和 [valle](https://arxiv.org/abs/2301.02111) 通过自回归式系统展示了非凡的 TTS 效果。
|
297 |
+
- [fish-speech](https://github.com/fishaudio/fish-speech) 揭示了 GVQ 作为 LLM 建模的音频分词器的能力。
|
298 |
+
- [vocos](https://github.com/gemelo-ai/vocos) vocos 被用作预训练声码器。
|
299 |
+
|
300 |
+
## 特别鸣谢
|
301 |
+
|
302 |
+
- [wlu-audio lab](https://audio.westlake.edu.cn/) 对于早期算法实验的支持。
|
303 |
+
|
304 |
+
## 贡献者列表
|
305 |
+
|
306 |
+
[![contributors](https://contrib.rocks/image?repo=2noise/ChatTTS)](https://github.com/2noise/ChatTTS/graphs/contributors)
|
307 |
+
|
308 |
+
## 项目浏览量
|
309 |
+
|
310 |
+
<div align="center">
|
311 |
+
|
312 |
+
![counter](https://counter.seku.su/cmoe?name=chattts&theme=mbs)
|
313 |
+
|
314 |
+
</div>
|
docs/es/README.md
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div align="center">
|
2 |
+
|
3 |
+
<a href="https://trendshift.io/repositories/10489" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10489" alt="2noise%2FChatTTS | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
4 |
+
|
5 |
+
# ChatTTS
|
6 |
+
Un modelo de generación de voz para la conversación diaria.
|
7 |
+
|
8 |
+
[![Licence](https://img.shields.io/github/license/2noise/ChatTTS?style=for-the-badge)](https://github.com/2noise/ChatTTS/blob/main/LICENSE)
|
9 |
+
|
10 |
+
[![Huggingface](https://img.shields.io/badge/🤗%20-Models-yellow.svg?style=for-the-badge)](https://huggingface.co/2Noise/ChatTTS)
|
11 |
+
[![Open In Colab](https://img.shields.io/badge/Colab-F9AB00?style=for-the-badge&logo=googlecolab&color=525252)](https://colab.research.google.com/github/2noise/ChatTTS/blob/main/examples/ipynb/colab.ipynb)
|
12 |
+
|
13 |
+
[**English**](../../README.md) | [**简体中文**](../cn/README.md) | [**日本語**](../jp/README.md) | [**Русский**](../ru/README.md) | **Español**
|
14 |
+
| [**Français**](../fr/README.md)
|
15 |
+
</div>
|
16 |
+
|
17 |
+
> [!NOTE]
|
18 |
+
> Atención, es posible que esta versión no sea la última. Por favor, consulte la versión en inglés para conocer todo el contenido.
|
19 |
+
|
20 |
+
## Introducción
|
21 |
+
|
22 |
+
ChatTTS es un modelo de texto a voz diseñado específicamente para escenarios conversacionales como LLM assistant.
|
23 |
+
|
24 |
+
### Idiomas Soportados
|
25 |
+
|
26 |
+
- [x] Inglés
|
27 |
+
- [x] Chino
|
28 |
+
- [ ] Manténganse al tanto...
|
29 |
+
|
30 |
+
### Aspectos Destacados
|
31 |
+
|
32 |
+
> Puede consultar **[este video en Bilibili](https://www.bilibili.com/video/BV1zn4y1o7iV)** para obtener una descripción detallada.
|
33 |
+
|
34 |
+
1. **TTS Conversacional**: ChatTTS está optimizado para tareas conversacionales, logrando una síntesis de voz natural y expresiva. Soporta múltiples hablantes, lo que facilita la generación de diálogos interactivos.
|
35 |
+
2. **Control Finas**: Este modelo puede predecir y controlar características detalladas de la prosodia, incluyendo risas, pausas e interjecciones.
|
36 |
+
3. **Mejor Prosodia**: ChatTTS supera a la mayoría de los modelos TTS de código abierto en cuanto a prosodia. Ofrecemos modelos preentrenados para apoyar estudios y desarrollos adicionales.
|
37 |
+
|
38 |
+
### Conjunto de Datos & Modelo
|
39 |
+
|
40 |
+
- El modelo principal se entrena con más de 100.000 horas de datos de audio en chino e inglés.
|
41 |
+
- La versión de código abierto en **[HuggingFace](https://huggingface.co/2Noise/ChatTTS)** es un modelo preentrenado con 40.000 horas, sin SFT.
|
42 |
+
|
43 |
+
### Hoja de Ruta
|
44 |
+
|
45 |
+
- [x] Publicar el modelo base de 40k horas y el archivo spk_stats como código abierto
|
46 |
+
- [ ] Publicar los códigos de codificador VQ y entrenamiento de Lora como código abierto
|
47 |
+
- [ ] Generación de audio en streaming sin refinar el texto
|
48 |
+
- [ ] Publicar la versión de 40k horas con control de múltiples emociones como código abierto
|
49 |
+
- [ ] ¿ChatTTS.cpp? (Se aceptan PR o un nuevo repositorio)
|
50 |
+
|
51 |
+
### Descargo de Responsabilidad
|
52 |
+
|
53 |
+
> [!Important]
|
54 |
+
> Este repositorio es sólo para fines académicos.
|
55 |
+
|
56 |
+
Este proyecto está destinado a fines educativos y estudios, y no es adecuado para ningún propósito comercial o legal. El autor no garantiza la exactitud, integridad o fiabilidad de la información. La información y los datos utilizados en este repositorio son únicamente para fines académicos y de investigación. Los datos provienen de fuentes públicas, y el autor no reclama ningún derecho de propiedad o copyright sobre ellos.
|
57 |
+
|
58 |
+
ChatTTS es un potente sistema de conversión de texto a voz. Sin embargo, es crucial utilizar esta tecnología de manera responsable y ética. Para limitar el uso de ChatTTS, hemos añadido una pequeña cantidad de ruido de alta frecuencia durante el proceso de entrenamiento del modelo de 40.000 horas y hemos comprimido la calidad del audio en formato MP3 tanto como sea posible para evitar que actores malintencionados lo usen con fines delictivos. Además, hemos entrenado internamente un modelo de detección y planeamos hacerlo de código abierto en el futuro.
|
59 |
+
|
60 |
+
### Contacto
|
61 |
+
|
62 |
+
> No dudes en enviar issues/PRs de GitHub.
|
63 |
+
|
64 |
+
#### Consultas Formales
|
65 |
+
|
66 |
+
Si desea discutir la cooperación sobre modelos y hojas de ruta, envíe un correo electrónico a **[email protected]**.
|
67 |
+
|
68 |
+
#### Chat en Línea
|
69 |
+
|
70 |
+
##### 1. Grupo QQ (Aplicación Social China)
|
71 |
+
|
72 |
+
- **Grupo 1**, 808364215 (Lleno)
|
73 |
+
- **Grupo 2**, 230696694 (Lleno)
|
74 |
+
- **Grupo 3**, 933639842
|
75 |
+
|
76 |
+
## Instalación (En Proceso)
|
77 |
+
|
78 |
+
> Se cargará en pypi pronto según https://github.com/2noise/ChatTTS/issues/269.
|
79 |
+
|
80 |
+
```bash
|
81 |
+
pip install git+https://github.com/2noise/ChatTTS
|
82 |
+
```
|
83 |
+
|
84 |
+
## Inicio
|
85 |
+
### Clonar el repositorio
|
86 |
+
```bash
|
87 |
+
git clone https://github.com/2noise/ChatTTS
|
88 |
+
cd ChatTTS
|
89 |
+
```
|
90 |
+
|
91 |
+
### Requerimientos de instalación
|
92 |
+
#### 1. Instalar directamente
|
93 |
+
```bash
|
94 |
+
pip install --upgrade -r requirements.txt
|
95 |
+
```
|
96 |
+
|
97 |
+
#### 2. Instalar desde conda
|
98 |
+
```bash
|
99 |
+
conda create -n chattts
|
100 |
+
conda activate chattts
|
101 |
+
pip install -r requirements.txt
|
102 |
+
```
|
103 |
+
|
104 |
+
### Inicio Rápido
|
105 |
+
#### 1. Iniciar la interfaz de usuario web (WebUI)
|
106 |
+
```bash
|
107 |
+
python examples/web/webui.py
|
108 |
+
```
|
109 |
+
|
110 |
+
#### 2. Inferir por línea de comando
|
111 |
+
> Guardará el audio en `./output_audio_xxx.wav`
|
112 |
+
|
113 |
+
```bash
|
114 |
+
python examples/cmd/run.py "Please input your text."
|
115 |
+
```
|
116 |
+
|
117 |
+
### Básico
|
118 |
+
|
119 |
+
```python
|
120 |
+
import ChatTTS
|
121 |
+
from IPython.display import Audio
|
122 |
+
import torchaudio
|
123 |
+
import torch
|
124 |
+
|
125 |
+
chat = ChatTTS.Chat()
|
126 |
+
chat.load(compile=False) # Set to True for better performance
|
127 |
+
|
128 |
+
texts = ["PUT YOUR TEXT HERE",]
|
129 |
+
|
130 |
+
wavs = chat.infer(texts)
|
131 |
+
|
132 |
+
torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
|
133 |
+
```
|
134 |
+
|
135 |
+
### Avanzado
|
136 |
+
|
137 |
+
```python
|
138 |
+
###################################
|
139 |
+
# Sample a speaker from Gaussian.
|
140 |
+
|
141 |
+
rand_spk = chat.sample_random_speaker()
|
142 |
+
print(rand_spk) # save it for later timbre recovery
|
143 |
+
|
144 |
+
params_infer_code = ChatTTS.Chat.InferCodeParams(
|
145 |
+
spk_emb = rand_spk, # add sampled speaker
|
146 |
+
temperature = .3, # using custom temperature
|
147 |
+
top_P = 0.7, # top P decode
|
148 |
+
top_K = 20, # top K decode
|
149 |
+
)
|
150 |
+
|
151 |
+
###################################
|
152 |
+
# For sentence level manual control.
|
153 |
+
|
154 |
+
# use oral_(0-9), laugh_(0-2), break_(0-7)
|
155 |
+
# to generate special token in text to synthesize.
|
156 |
+
params_refine_text = ChatTTS.Chat.RefineTextParams(
|
157 |
+
prompt='[oral_2][laugh_0][break_6]',
|
158 |
+
)
|
159 |
+
|
160 |
+
wavs = chat.infer(
|
161 |
+
texts,
|
162 |
+
params_refine_text=params_refine_text,
|
163 |
+
params_infer_code=params_infer_code,
|
164 |
+
)
|
165 |
+
|
166 |
+
###################################
|
167 |
+
# For word level manual control.
|
168 |
+
text = 'What is [uv_break]your favorite english food?[laugh][lbreak]'
|
169 |
+
wavs = chat.infer(text, skip_refine_text=True, params_refine_text=params_refine_text, params_infer_code=params_infer_code)
|
170 |
+
torchaudio.save("output2.wav", torch.from_numpy(wavs[0]), 24000)
|
171 |
+
```
|
172 |
+
|
173 |
+
<details open>
|
174 |
+
<summary><h4>Ejemplo: auto presentación</h4></summary>
|
175 |
+
|
176 |
+
```python
|
177 |
+
inputs_en = """
|
178 |
+
chat T T S is a text to speech model designed for dialogue applications.
|
179 |
+
[uv_break]it supports mixed language input [uv_break]and offers multi speaker
|
180 |
+
capabilities with precise control over prosodic elements [laugh]like like
|
181 |
+
[uv_break]laughter[laugh], [uv_break]pauses, [uv_break]and intonation.
|
182 |
+
[uv_break]it delivers natural and expressive speech,[uv_break]so please
|
183 |
+
[uv_break] use the project responsibly at your own risk.[uv_break]
|
184 |
+
""".replace('\n', '') # English is still experimental.
|
185 |
+
|
186 |
+
params_refine_text = ChatTTS.Chat.RefineTextParams(
|
187 |
+
prompt='[oral_2][laugh_0][break_4]',
|
188 |
+
)
|
189 |
+
|
190 |
+
audio_array_en = chat.infer(inputs_en, params_refine_text=params_refine_text)
|
191 |
+
torchaudio.save("output3.wav", torch.from_numpy(audio_array_en[0]), 24000)
|
192 |
+
```
|
193 |
+
|
194 |
+
<table>
|
195 |
+
<tr>
|
196 |
+
<td align="center">
|
197 |
+
|
198 |
+
**altavoz masculino**
|
199 |
+
|
200 |
+
</td>
|
201 |
+
<td align="center">
|
202 |
+
|
203 |
+
**altavoz femenino**
|
204 |
+
|
205 |
+
</td>
|
206 |
+
</tr>
|
207 |
+
<tr>
|
208 |
+
<td align="center">
|
209 |
+
|
210 |
+
[male speaker](https://github.com/2noise/ChatTTS/assets/130631963/e0f51251-db7f-4d39-a0e9-3e095bb65de1)
|
211 |
+
|
212 |
+
</td>
|
213 |
+
<td align="center">
|
214 |
+
|
215 |
+
[female speaker](https://github.com/2noise/ChatTTS/assets/130631963/f5dcdd01-1091-47c5-8241-c4f6aaaa8bbd)
|
216 |
+
|
217 |
+
</td>
|
218 |
+
</tr>
|
219 |
+
</table>
|
220 |
+
|
221 |
+
|
222 |
+
</details>
|
223 |
+
|
224 |
+
## Preguntas y Respuestas
|
225 |
+
|
226 |
+
#### 1. ¿Cuánta memoria gráfica de acceso aleatorio necesito? ¿Qué tal inferir la velocidad?
|
227 |
+
Para un clip de audio de 30 segundos, se requieren al menos 4 GB de memoria de GPU. Para la GPU 4090, puede generar audio correspondiente a aproximadamente 7 tokens semánticos por segundo. El Factor en Tiempo Real (RTF) es aproximadamente 0,3.
|
228 |
+
|
229 |
+
#### 2. La estabilidad del modelo no es lo suficientemente buena y existen problemas como varios altavoces o mala calidad del sonido.
|
230 |
+
|
231 |
+
Este es un problema común en los modelos autorregresivos (para bark y valle). Generalmente es difícil de evitar. Puede probar varias muestras para encontrar resultados adecuados.
|
232 |
+
|
233 |
+
#### 3. ¿Podemos controlar algo más que la risa? ¿Podemos controlar otras emociones?
|
234 |
+
|
235 |
+
En el modelo lanzado actualmente, las únicas unidades de control a nivel de token son `[risa]`, `[uv_break]` y `[lbreak]`. En una versión futura, es posible que abramos el código fuente del modelo con capacidades adicionales de control de emociones.
|
236 |
+
|
237 |
+
## Agradecimientos
|
238 |
+
- [bark](https://github.com/suno-ai/bark), [XTTSv2](https://github.com/coqui-ai/TTS) y [valle](https://arxiv.org/abs/2301.02111) demuestran un resultado TTS notable mediante un sistema de estilo autorregresivo.
|
239 |
+
- [fish-speech](https://github.com/fishaudio/fish-speech) revela las capacidades de GVQ como tokenizador de audio para el modelado LLM.
|
240 |
+
- [vocos](https://github.com/gemelo-ai/vocos) se utiliza como codificador de voz previamente entrenado.
|
241 |
+
|
242 |
+
## Agradecimiento Especial
|
243 |
+
- [wlu-audio lab](https://audio.westlake.edu.cn/) para experimentos iniciales del algoritmo.
|
244 |
+
|
245 |
+
## Recursos Relacionados
|
246 |
+
- [Awesome-ChatTTS](https://github.com/libukai/Awesome-ChatTTS)
|
247 |
+
|
248 |
+
## Gracias a todos los contribuyentes por sus esfuerzos.
|
249 |
+
[![contributors](https://contrib.rocks/image?repo=2noise/ChatTTS)](https://github.com/2noise/ChatTTS/graphs/contributors)
|
250 |
+
|
251 |
+
<div align="center">
|
252 |
+
|
253 |
+
![counter](https://counter.seku.su/cmoe?name=chattts&theme=mbs)
|
254 |
+
|
255 |
+
</div>
|
docs/fr/README.md
ADDED
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div align="center">
|
2 |
+
|
3 |
+
<a href="https://trendshift.io/repositories/10489" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10489" alt="2noise%2FChatTTS | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
4 |
+
|
5 |
+
# ChatTTS
|
6 |
+
Un modèle de parole génératif pour le dialogue quotidien.
|
7 |
+
|
8 |
+
[![Licence](https://img.shields.io/github/license/2noise/ChatTTS?style=for-the-badge)](https://github.com/2noise/ChatTTS/blob/main/LICENSE)
|
9 |
+
[![PyPI](https://img.shields.io/pypi/v/ChatTTS.svg?style=for-the-badge&color=green)](https://pypi.org/project/ChatTTS)
|
10 |
+
|
11 |
+
[![Huggingface](https://img.shields.io/badge/🤗%20-Models-yellow.svg?style=for-the-badge)](https://huggingface.co/2Noise/ChatTTS)
|
12 |
+
[![Open In Colab](https://img.shields.io/badge/Colab-F9AB00?style=for-the-badge&logo=googlecolab&color=525252)](https://colab.research.google.com/github/2noise/ChatTTS/blob/main/examples/ipynb/colab.ipynb)
|
13 |
+
[![Discord](https://img.shields.io/badge/Discord-7289DA?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/Ud5Jxgx5yD)
|
14 |
+
|
15 |
+
[**English**](../../README.md) | [**简体中文**](../cn/README.md) | [**日本語**](../jp/README.md) | [**Русский**](../ru/README.md) | [**Español**](../es/README.md)| **Français**
|
16 |
+
|
17 |
+
</div>
|
18 |
+
|
19 |
+
## Introduction
|
20 |
+
> [!Note]
|
21 |
+
> Ce dépôt contient l'infrastructure de l'algorithme et quelques exemples simples.
|
22 |
+
|
23 |
+
> [!Tip]
|
24 |
+
> Pour les produits finaux étendus pour les utilisateurs, veuillez consulter le dépôt index [Awesome-ChatTTS](https://github.com/libukai/Awesome-ChatTTS/tree/en) maintenu par la communauté.
|
25 |
+
|
26 |
+
ChatTTS est un modèle de synthèse vocale conçu spécifiquement pour les scénarios de dialogue tels que les assistants LLM.
|
27 |
+
|
28 |
+
### Langues prises en charge
|
29 |
+
- [x] Anglais
|
30 |
+
- [x] Chinois
|
31 |
+
- [ ] À venir...
|
32 |
+
|
33 |
+
### Points forts
|
34 |
+
> Vous pouvez vous référer à **[cette vidéo sur Bilibili](https://www.bilibili.com/video/BV1zn4y1o7iV)** pour une description détaillée.
|
35 |
+
|
36 |
+
1. **Synthèse vocale conversationnelle**: ChatTTS est optimisé pour les tâches basées sur le dialogue, permettant une synthèse vocale naturelle et expressive. Il prend en charge plusieurs locuteurs, facilitant les conversations interactives.
|
37 |
+
2. **Contrôle granulaire**: Le modèle peut prédire et contrôler des caractéristiques prosodiques fines, y compris le rire, les pauses et les interjections.
|
38 |
+
3. **Meilleure prosodie**: ChatTTS dépasse la plupart des modèles TTS open-source en termes de prosodie. Nous fournissons des modèles pré-entraînés pour soutenir la recherche et le développement.
|
39 |
+
|
40 |
+
### Dataset & Modèle
|
41 |
+
- Le modèle principal est entraîné avec des données audio en chinois et en anglais de plus de 100 000 heures.
|
42 |
+
- La version open-source sur **[HuggingFace](https://huggingface.co/2Noise/ChatTTS)** est un modèle pré-entraîné de 40 000 heures sans SFT.
|
43 |
+
|
44 |
+
### Roadmap
|
45 |
+
- [x] Open-source du modèle de base de 40k heures et du fichier spk_stats.
|
46 |
+
- [x] Génération audio en streaming.
|
47 |
+
- [ ] Open-source de la version 40k heures avec contrôle multi-émotions.
|
48 |
+
- [ ] ChatTTS.cpp (nouveau dépôt dans l'organisation `2noise` est bienvenu)
|
49 |
+
|
50 |
+
### Avertissement
|
51 |
+
> [!Important]
|
52 |
+
> Ce dépôt est à des fins académiques uniquement.
|
53 |
+
|
54 |
+
Il est destiné à un usage éducatif et de recherche, et ne doit pas être utilisé à des fins commerciales ou légales. Les auteurs ne garantissent pas l'exactitude, l'exhaustivité ou la fiabilité des informations. Les informations et les données utilisées dans ce dépôt sont à des fins académiques et de recherche uniquement. Les données obtenues à partir de sources accessibles au public, et les auteurs ne revendiquent aucun droit de propriété ou de copyright sur les données.
|
55 |
+
|
56 |
+
ChatTTS est un système de synthèse vocale puissant. Cependant, il est très important d'utiliser cette technologie de manière responsable et éthique. Pour limiter l'utilisation de ChatTTS, nous avons ajouté une petite quantité de bruit haute fréquence pendant l'entraînement du modèle de 40 000 heures et compressé la qualité audio autant que possible en utilisant le format MP3, pour empêcher les acteurs malveillants de l'utiliser potentiellement à des fins criminelles. En même temps, nous avons entraîné en interne un modèle de détection et prévoyons de l'open-source à l'avenir.
|
57 |
+
|
58 |
+
### Contact
|
59 |
+
> Les issues/PRs sur GitHub sont toujours les bienvenus.
|
60 |
+
|
61 |
+
#### Demandes formelles
|
62 |
+
Pour les demandes formelles concernant le modèle et la feuille de route, veuillez nous contacter à **[email protected]**.
|
63 |
+
|
64 |
+
#### Discussion en ligne
|
65 |
+
##### 1. Groupe QQ (application sociale chinoise)
|
66 |
+
- **Groupe 1**, 808364215 (Complet)
|
67 |
+
- **Groupe 2**, 230696694 (Complet)
|
68 |
+
- **Groupe 3**, 933639842 (Complet)
|
69 |
+
- **Groupe 4**, 608667975
|
70 |
+
|
71 |
+
##### 2. Serveur Discord
|
72 |
+
Rejoignez en cliquant [ici](https://discord.gg/Ud5Jxgx5yD).
|
73 |
+
|
74 |
+
## Pour commencer
|
75 |
+
### Cloner le dépôt
|
76 |
+
```bash
|
77 |
+
git clone https://github.com/2noise/ChatTTS
|
78 |
+
cd ChatTTS
|
79 |
+
```
|
80 |
+
|
81 |
+
### Installer les dépendances
|
82 |
+
#### 1. Installation directe
|
83 |
+
```bash
|
84 |
+
pip install --upgrade -r requirements.txt
|
85 |
+
```
|
86 |
+
|
87 |
+
#### 2. Installer depuis conda
|
88 |
+
```bash
|
89 |
+
conda create -n chattts
|
90 |
+
conda activate chattts
|
91 |
+
pip install -r requirements.txt
|
92 |
+
```
|
93 |
+
|
94 |
+
#### Optionnel : Installer TransformerEngine si vous utilisez un GPU NVIDIA (Linux uniquement)
|
95 |
+
> [!Note]
|
96 |
+
> Le processus d'installation est très lent.
|
97 |
+
|
98 |
+
> [!Warning]
|
99 |
+
> L'adaptation de TransformerEngine est actuellement en cours de développement et NE PEUT PAS fonctionner correctement pour le moment.
|
100 |
+
> Installez-le uniquement à des fins de développement.
|
101 |
+
|
102 |
+
```bash
|
103 |
+
pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
|
104 |
+
```
|
105 |
+
|
106 |
+
#### Optionnel : Installer FlashAttention-2 (principalement GPU NVIDIA)
|
107 |
+
> [!Note]
|
108 |
+
> Voir les appareils pris en charge dans la [documentation Hugging Face](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2).
|
109 |
+
|
110 |
+
> [!Warning]
|
111 |
+
> Actuellement, FlashAttention-2 ralentira la vitesse de génération selon [ce problème](https://github.com/huggingface/transformers/issues/26990).
|
112 |
+
> Installez-le uniquement à des fins de développement.
|
113 |
+
|
114 |
+
```bash
|
115 |
+
pip install flash-attn --no-build-isolation
|
116 |
+
```
|
117 |
+
|
118 |
+
### Démarrage rapide
|
119 |
+
> Assurez-vous que vous êtes dans le répertoire racine du projet lorsque vous exécutez ces commandes ci-dessous.
|
120 |
+
|
121 |
+
#### 1. Lancer WebUI
|
122 |
+
```bash
|
123 |
+
python examples/web/webui.py
|
124 |
+
```
|
125 |
+
|
126 |
+
#### 2. Inférence par ligne de commande
|
127 |
+
> Cela enregistrera l'audio sous ‘./output_audio_n.mp3’
|
128 |
+
|
129 |
+
```bash
|
130 |
+
python examples/cmd/run.py "Votre premier texte." "Votre deuxième texte."
|
131 |
+
```
|
132 |
+
|
133 |
+
## Installation
|
134 |
+
|
135 |
+
1. Installer la version stable depuis PyPI
|
136 |
+
```bash
|
137 |
+
pip install ChatTTS
|
138 |
+
```
|
139 |
+
|
140 |
+
2. Installer la dernière version depuis GitHub
|
141 |
+
```bash
|
142 |
+
pip install git+https://github.com/2noise/ChatTTS
|
143 |
+
```
|
144 |
+
|
145 |
+
3. Installer depuis le répertoire local en mode développement
|
146 |
+
```bash
|
147 |
+
pip install -e .
|
148 |
+
```
|
149 |
+
|
150 |
+
### Utilisation de base
|
151 |
+
|
152 |
+
```python
|
153 |
+
import ChatTTS
|
154 |
+
import torch
|
155 |
+
import torchaudio
|
156 |
+
|
157 |
+
chat = ChatTTS.Chat()
|
158 |
+
chat.load(compile=False) # Définissez sur True pour de meilleures performances
|
159 |
+
|
160 |
+
texts = ["METTEZ VOTRE PREMIER TEXTE ICI", "METTEZ VOTRE DEUXIÈME TEXTE ICI"]
|
161 |
+
|
162 |
+
wavs = chat.infer(texts)
|
163 |
+
|
164 |
+
torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
|
165 |
+
```
|
166 |
+
|
167 |
+
### Utilisation avancée
|
168 |
+
|
169 |
+
```python
|
170 |
+
###################################
|
171 |
+
# Échantillonner un locuteur à partir d'une distribution gaussienne.
|
172 |
+
|
173 |
+
rand_spk = chat.sample_random_speaker()
|
174 |
+
print(rand_spk) # sauvegardez-le pour une récupération ultérieure du timbre
|
175 |
+
|
176 |
+
params_infer_code = ChatTTS.Chat.InferCodeParams(
|
177 |
+
spk_emb = rand_spk, # ajouter le locuteur échantillonné
|
178 |
+
temperature = .3, # en utilisant une température personnalisée
|
179 |
+
top_P = 0.7, # top P décode
|
180 |
+
top_K = 20, # top K décode
|
181 |
+
)
|
182 |
+
|
183 |
+
###################################
|
184 |
+
# Pour le contrôle manuel au niveau des phrases.
|
185 |
+
|
186 |
+
# utilisez oral_(0-9), laugh_(0-2), break_(0-7)
|
187 |
+
# pour générer un token spécial dans le texte à synthétiser.
|
188 |
+
params_refine_text = ChatTTS.Chat.RefineTextParams(
|
189 |
+
prompt='[oral_2][laugh_0][break_6]',
|
190 |
+
)
|
191 |
+
|
192 |
+
wavs = chat.infer(
|
193 |
+
texts,
|
194 |
+
params_refine_text=params_refine_text,
|
195 |
+
params_infer_code=params_infer_code,
|
196 |
+
)
|
197 |
+
|
198 |
+
###################################
|
199 |
+
# Pour le contrôle manuel au niveau des mots.
|
200 |
+
|
201 |
+
text = 'Quel est [uv_break]votre plat anglais préféré?[laugh][lbreak]'
|
202 |
+
wavs = chat.infer(text, skip_refine_text=True, params_refine_text=params_refine_text, params_infer_code=params_infer_code)
|
203 |
+
torchaudio.save("output2.wav", torch.from_numpy(wavs[0]), 24000)
|
204 |
+
```
|
205 |
+
|
206 |
+
<details open>
|
207 |
+
<summary><h4>Exemple : auto-présentation</h4></summary>
|
208 |
+
|
209 |
+
```python
|
210 |
+
inputs_en = """
|
211 |
+
chat T T S est un modèle de synthèse vocale conçu pour les applications de dialogue.
|
212 |
+
[uv_break]il prend en charge les entrées en langues mixtes [uv_break]et offre des capacités multi-locuteurs
|
213 |
+
avec un contrôle précis des éléments prosodiques comme
|
214 |
+
[uv_break]le rire[uv_break][laugh], [uv_break]les pauses, [uv_break]et l'intonation.
|
215 |
+
[uv_break]il délivre une parole naturelle et expressive,[uv_break]donc veuillez
|
216 |
+
[uv_break]utiliser le projet de manière responsable à vos risques et périls.[uv_break]
|
217 |
+
""".replace('\n', '') # L'anglais est encore expérimental.
|
218 |
+
|
219 |
+
params_refine_text = ChatTTS.Chat.RefineTextParams(
|
220 |
+
prompt='[oral_2][laugh_0][break_4]',
|
221 |
+
)
|
222 |
+
|
223 |
+
audio_array_en = chat.infer(inputs_en, params_refine_text=params_refine_text)
|
224 |
+
torchaudio.save("output3.wav", torch.from_numpy(audio_array_en[0]), 24000)
|
225 |
+
```
|
226 |
+
|
227 |
+
<table>
|
228 |
+
<tr>
|
229 |
+
<td align="center">
|
230 |
+
|
231 |
+
**locuteur masculin**
|
232 |
+
|
233 |
+
</td>
|
234 |
+
<td align="center">
|
235 |
+
|
236 |
+
**locutrice féminine**
|
237 |
+
|
238 |
+
</td>
|
239 |
+
</tr>
|
240 |
+
<tr>
|
241 |
+
<td align="center">
|
242 |
+
|
243 |
+
[locuteur masculin](https://github.com/2noise/ChatTTS/assets/130631963/e0f51251-db7f-4d39-a0e9-3e095bb65de1)
|
244 |
+
|
245 |
+
</td>
|
246 |
+
<td align="center">
|
247 |
+
|
248 |
+
[locutrice féminine](https://github.com/2noise/ChatTTS/assets/130631963/f5dcdd01-1091-47c5-8241-c4f6aaaa8bbd)
|
249 |
+
|
250 |
+
</td>
|
251 |
+
</tr>
|
252 |
+
</table>
|
253 |
+
|
254 |
+
|
255 |
+
</details>
|
256 |
+
|
257 |
+
## FAQ
|
258 |
+
|
259 |
+
#### 1. De combien de VRAM ai-je besoin ? Quelle est la vitesse d'inférence ?
|
260 |
+
Pour un clip audio de 30 secondes, au moins 4 Go de mémoire GPU sont nécessaires. Pour le GPU 4090, il peut générer de l'audio correspondant à environ 7 tokens sémantiques par seconde. Le Facteur Temps Réel (RTF) est d'environ 0.3.
|
261 |
+
|
262 |
+
#### 2. La stabilité du modèle n'est pas suffisante, avec des problèmes tels que des locuteurs multiples ou une mauvaise qualité audio.
|
263 |
+
C'est un problème qui se produit généralement avec les modèles autoregressifs (pour bark et valle). Il est généralement difficile à éviter. On peut essayer plusieurs échantillons pour trouver un résultat approprié.
|
264 |
+
|
265 |
+
#### 3. En plus du rire, pouvons-nous contrôler autre chose ? Pouvons-nous contrôler d'autres émotions ?
|
266 |
+
Dans le modèle actuellement publié, les seules unités de contrôle au niveau des tokens sont `[laugh]`, `[uv_break]`, et `[lbreak]`. Dans les futures versions, nous pourrions open-source des modèles avec des capacités de contrôle émotionnel supplémentaires.
|
267 |
+
|
268 |
+
## Remerciements
|
269 |
+
- [bark](https://github.com/suno-ai/bark), [XTTSv2](https://github.com/coqui-ai/TTS) et [valle](https://arxiv.org/abs/2301.02111) démontrent un résultat TTS remarquable par un système de style autoregressif.
|
270 |
+
- [fish-speech](https://github.com/fishaudio/fish-speech) révèle la capacité de GVQ en tant que tokenizer audio pour la modélisation LLM.
|
271 |
+
- [vocos](https://github.com/gemelo-ai/vocos) qui est utilisé comme vocodeur pré-entraîné.
|
272 |
+
|
273 |
+
## Appréciation spéciale
|
274 |
+
- [wlu-audio lab](https://audio.westlake.edu.cn/) pour les expériences d'algorithme précoce.
|
275 |
+
|
276 |
+
## Merci à tous les contributeurs pour leurs efforts
|
277 |
+
[![contributors](https://contrib.rocks/image?repo=2noise/ChatTTS)](https://github.com/2noise/ChatTTS/graphs/contributors)
|
278 |
+
|
279 |
+
<div align="center">
|
280 |
+
|
281 |
+
![counter](https://counter.seku.su/cmoe?name=chattts&theme=mbs)
|
282 |
+
|
283 |
+
</div>
|
docs/jp/README.md
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ChatTTS
|
2 |
+
> [!NOTE]
|
3 |
+
> 以下の内容は最新情報ではない可能性がありますのでご了承ください。全ての内容は英語版に基準することになります。
|
4 |
+
|
5 |
+
[![Huggingface](https://img.shields.io/badge/🤗%20-Models-yellow.svg?style=for-the-badge)](https://huggingface.co/2Noise/ChatTTS)
|
6 |
+
|
7 |
+
[**English**](../../README.md) | [**简体中文**](../cn/README.md) | **日本語** | [**Русский**](../ru/README.md) | [**Español**](../es/README.md) | [**Français**](../fr/README.md)
|
8 |
+
|
9 |
+
ChatTTSは、LLMアシスタントなどの対話シナリオ用に特別に設計されたテキストから音声へのモデルです。英語と中国語の両方をサポートしています。私たちのモデルは、中国語と英語で構成される100,000時間以上でトレーニングされています。**[HuggingFace](https://huggingface.co/2Noise/ChatTTS)**でオープンソース化されているバージョンは、40,000時間の事前トレーニングモデルで、SFTは行われていません。
|
10 |
+
|
11 |
+
モデルやロードマップについての正式なお問い合わせは、**[email protected]**までご連絡ください。QQグループ:808364215に参加してディスカッションすることもできます。GitHubでの問題提起も歓迎します。
|
12 |
+
|
13 |
+
---
|
14 |
+
## ハイライト
|
15 |
+
1. **会話型TTS**: ChatTTSは対話ベースのタスクに最適化されており、自然で表現豊かな音声合成を実現します。複数の話者をサポートし、対話型の会話を容易にします。
|
16 |
+
2. **細かい制御**: このモデルは、笑い、一時停止、間投詞などの細かい韻律特徴を予測および制御することができます。
|
17 |
+
3. **より良い韻律**: ChatTTSは、韻律の面でほとんどのオープンソースTTSモデルを超えています。さらなる研究と開発をサポートするために、事前トレーニングされたモデルを提供しています。
|
18 |
+
|
19 |
+
モデルの詳細な説明については、**[Bilibiliのビデオ](https://www.bilibili.com/video/BV1zn4y1o7iV)**を参照してください。
|
20 |
+
|
21 |
+
---
|
22 |
+
|
23 |
+
## 免責事項
|
24 |
+
|
25 |
+
このリポジトリは学術目的のみのためです。教育および研究用途にのみ使用され、商業的または法的な目的には使用されません。著者は情報の正確性、完全性、または信頼性を保証しません。このリポジトリで使用される情報およびデータは、学術および研究目的のみのためのものです。データは公開されているソースから取得され、著者はデータに対する所有権または著作権を主張しません。
|
26 |
+
|
27 |
+
ChatTTSは強力なテキストから音声へのシステムです。しかし、この技術を責任を持って、倫理的に利用することが非常に重要です。ChatTTSの使用を制限するために、40,000時間のモデルのトレーニング中に少量の高周波ノイズを追加し、MP3形式を使用して音質を可能な限り圧縮しました。これは、悪意のあるアクターが潜在的に犯罪目的で使用することを防ぐためです。同時に、私たちは内部的に検出モデルをトレーニングしており、将来的にオープンソース化する予定です。
|
28 |
+
|
29 |
+
---
|
30 |
+
## 使用方法
|
31 |
+
|
32 |
+
<h4>基本的な使用方法</h4>
|
33 |
+
|
34 |
+
```python
|
35 |
+
import ChatTTS
|
36 |
+
from IPython.display import Audio
|
37 |
+
import torch
|
38 |
+
|
39 |
+
chat = ChatTTS.Chat()
|
40 |
+
chat.load(compile=False) # より良いパフォーマンスのためにTrueに設定
|
41 |
+
|
42 |
+
texts = ["ここにテキストを入力してください",]
|
43 |
+
|
44 |
+
wavs = chat.infer(texts, )
|
45 |
+
|
46 |
+
torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
|
47 |
+
```
|
48 |
+
|
49 |
+
<h4>高度な使用方法</h4>
|
50 |
+
|
51 |
+
```python
|
52 |
+
###################################
|
53 |
+
# ガウス分布から話者をサンプリングします。
|
54 |
+
|
55 |
+
rand_spk = chat.sample_random_speaker()
|
56 |
+
print(rand_spk) # save it for later timbre recovery
|
57 |
+
|
58 |
+
params_infer_code = {
|
59 |
+
'spk_emb': rand_spk, # サンプリングされた話者を追加
|
60 |
+
'temperature': .3, # カスタム温度を使用
|
61 |
+
'top_P': 0.7, # トップPデコード
|
62 |
+
'top_K': 20, # トップKデコード
|
63 |
+
}
|
64 |
+
|
65 |
+
###################################
|
66 |
+
# 文レベルの手動制御のために。
|
67 |
+
|
68 |
+
# 特別なトークンを生成するためにテキストにoral_(0-9)、laugh_(0-2)、break_(0-7)を使用します。
|
69 |
+
params_refine_text = {
|
70 |
+
'prompt': '[oral_2][laugh_0][break_6]'
|
71 |
+
}
|
72 |
+
|
73 |
+
wav = chat.infer(texts, params_refine_text=params_refine_text, params_infer_code=params_infer_code)
|
74 |
+
|
75 |
+
###################################
|
76 |
+
# 単語レベルの手動制御のために。
|
77 |
+
text = 'あなたの好きな英語の食べ物は何ですか?[uv_break][laugh][lbreak]'
|
78 |
+
wav = chat.infer(text, skip_refine_text=True, params_refine_text=params_refine_text, params_infer_code=params_infer_code)
|
79 |
+
torchaudio.save("output2.wav", torch.from_numpy(wavs[0]), 24000)
|
80 |
+
```
|
81 |
+
|
82 |
+
<details open>
|
83 |
+
<summary><h4>例:自己紹介</h4></summary>
|
84 |
+
|
85 |
+
```python
|
86 |
+
inputs_jp = """
|
87 |
+
ChatTTSは、対話アプリケーション用に設計されたテキストから音声へのモデルです。
|
88 |
+
[uv_break]混合言語入力をサポートし[uv_break]、韻律要素[laugh]の正確な制御を提供します
|
89 |
+
[uv_break]笑い[laugh]、[uv_break]一時停止、[uv_break]およびイントネーション。[uv_break]自然で表現豊かな音声を提供します
|
90 |
+
[uv_break]したがって、自己責任でプロジェクトを責任を持って使用してください。[uv_break]
|
91 |
+
""".replace('\n', '') # 英語はまだ実験的です。
|
92 |
+
|
93 |
+
params_refine_text = {
|
94 |
+
'prompt': '[oral_2][laugh_0][break_4]'
|
95 |
+
}
|
96 |
+
audio_array_jp = chat.infer(inputs_jp, params_refine_text=params_refine_text)
|
97 |
+
torchaudio.save("output3.wav", torch.from_numpy(audio_array_jp[0]), 24000)
|
98 |
+
```
|
99 |
+
[男性話者](https://github.com/2noise/ChatTTS/assets/130631963/e0f51251-db7f-4d39-a0e9-3e095bb65de1)
|
100 |
+
|
101 |
+
[女性話者](https://github.com/2noise/ChatTTS/assets/130631963/f5dcdd01-1091-47c5-8241-c4f6aaaa8bbd)
|
102 |
+
</details>
|
103 |
+
|
104 |
+
---
|
105 |
+
## ロードマップ
|
106 |
+
- [x] 40k時間のベースモデルとspk_statsファイルをオープンソース化
|
107 |
+
- [ ] VQエンコーダーとLoraトレーニングコードをオープンソース化
|
108 |
+
- [ ] テキストをリファインせずにストリーミングオーディオ生成*
|
109 |
+
- [ ] 複数の感情制御を備えた40k時間バージョンをオープンソース化
|
110 |
+
- [ ] ChatTTS.cppもしかしたら?(PRや新しいリポジトリが歓迎されます。)
|
111 |
+
|
112 |
+
----
|
113 |
+
## FAQ
|
114 |
+
|
115 |
+
##### VRAMはどれくらい必要ですか?推論速度はどうですか?
|
116 |
+
30秒のオーディオクリップには、少なくとも4GBのGPUメモリが必要です。4090 GPUの場合、約7つの意味トークンに対応するオーディオを1秒あたり生成できます。リアルタイムファクター(RTF)は約0.3です。
|
117 |
+
|
118 |
+
##### モデルの安定性が十分でなく、複数の話者や音質が悪いという問題があります。
|
119 |
+
|
120 |
+
これは、自己回帰モデル(barkおよびvalleの場合)で一般的に発生する問題です。一般的に避けるのは難しいです。複数のサンプルを試して、適切な結果を見つけることができます。
|
121 |
+
|
122 |
+
##### 笑い以外に何か制御できますか?他の感情を制御できますか?
|
123 |
+
|
124 |
+
現在リリースされているモデルでは、トークンレベルの制御ユニットは[laugh]、[uv_break]、および[lbreak]のみです。将来のバージョンでは、追加の感情制御機能を備えたモデルをオープンソース化する可能性があります。
|
125 |
+
|
126 |
+
---
|
127 |
+
## 謝辞
|
128 |
+
- [bark](https://github.com/suno-ai/bark)、[XTTSv2](https://github.com/coqui-ai/TTS)、および[valle](https://arxiv.org/abs/2301.02111)は、自己回帰型システムによる顕著なTTS結果を示しました。
|
129 |
+
- [fish-speech](https://github.com/fishaudio/fish-speech)は、LLMモデリングのためのオーディオトークナイザーとしてのGVQの能力を明らかにしました。
|
130 |
+
- 事前トレーニングされたボコーダーとして使用される[vocos](https://github.com/gemelo-ai/vocos)。
|
131 |
+
|
132 |
+
---
|
133 |
+
## 特別感謝
|
134 |
+
- 初期のアルゴリズム実験をサポートしてくれた[wlu-audio lab](https://audio.westlake.edu.cn/)。
|
docs/ru/README.md
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ChatTTS
|
2 |
+
> [!NOTE]
|
3 |
+
> Следующая информация может быть не самой последней, пожалуйста, смотрите английскую версию для актуальных данных.
|
4 |
+
|
5 |
+
[![Huggingface](https://img.shields.io/badge/🤗%20-Models-yellow.svg?style=for-the-badge)](https://huggingface.co/2Noise/ChatTTS)
|
6 |
+
|
7 |
+
[**English**](../../README.md) | [**简体中文**](../cn/README.md) | [**日本語**](../jp/README.md) | **Русский** | [**Español**](../es/README.md) | [**Français**](../fr/README.md)
|
8 |
+
|
9 |
+
ChatTTS - это модель преобразования текста в речь, специально разработанная для диалоговых сценариев, таких как помощник LLM. Она поддерживает как английский, так и китайский языки. Наша модель обучена на более чем 100 000 часах английского и китайского языков. Открытая версия на **[HuggingFace](https://huggingface.co/2Noise/ChatTTS)** - это предварительно обученная модель с 40 000 часами без SFT.
|
10 |
+
|
11 |
+
Для официальных запросов о модели и плане развития, пожалуйста, свяжитесь с нами по адресу **[email protected]**. Вы можете присоединиться к нашей группе QQ: 808364215 для обсуждения. Добавление вопросов на GitHub также приветствуется.
|
12 |
+
|
13 |
+
---
|
14 |
+
## Особенности
|
15 |
+
1. **Диалоговый TTS**: ChatTTS оптимизирован для задач, основанных на диалогах, что позволяет создавать натуральную и выразительную речь. Он поддерживает несколько говорящих, облегчая интерактивные беседы.
|
16 |
+
2. **Тонкий контроль**: Модель может предсказывать и контролировать тонкие просодические особенности, включая смех, паузы и вставные слова.
|
17 |
+
3. **Лучшая просодия**: ChatTTS превосходит большинство открытых моделей TTS с точки зрения просодии. Мы предоставляем предварительно обученные модели для поддержки дальнейших исследований и разработок.
|
18 |
+
|
19 |
+
Для подробного описания модели вы можете обратиться к **[видео на Bilibili](https://www.bilibili.com/video/BV1zn4y1o7iV)**
|
20 |
+
|
21 |
+
---
|
22 |
+
|
23 |
+
## Отказ от ответственности
|
24 |
+
|
25 |
+
Этот репозиторий предназначен только для академических целей. Он предназначен для образовательного и исследовательского использования и не должен использоваться в коммерческих или юридических целях. Авторы не гарантируют точность, полноту или надежность информации. Информация и данные, использованные в этом репозитории, предназначены только для академических и исследовательских целей. Данные получены из общедоступных источников, и авторы не заявляют о каких-либо правах собственности или авторских правах на данные.
|
26 |
+
|
27 |
+
ChatTTS - мощная система преобразования текста в речь. Однако очень важно использовать эту технологию ответственно и этично. Чтобы ограничить использование ChatTTS, мы добавили небольшое количество высокочастотного шума во время обучения модели на 40 000 часов и сжали качество аудио как можно больше с помощью формата MP3, чтобы предотвратить возможное использование злоумышленниками в преступных целях. В то же время мы внутренне обучили модель обнаружения и планируем открыть ее в будущем.
|
28 |
+
|
29 |
+
---
|
30 |
+
## Использование
|
31 |
+
|
32 |
+
<h4>Базовое использование</h4>
|
33 |
+
|
34 |
+
```python
|
35 |
+
import ChatTTS
|
36 |
+
from IPython.display import Audio
|
37 |
+
import torch
|
38 |
+
|
39 |
+
chat = ChatTTS.Chat()
|
40 |
+
chat.load(compile=False) # Установите значение True для лучшей производительности
|
41 |
+
|
42 |
+
texts = ["ВВЕДИ��Е ВАШ ТЕКСТ ЗДЕСЬ",]
|
43 |
+
|
44 |
+
wavs = chat.infer(texts)
|
45 |
+
|
46 |
+
torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
|
47 |
+
```
|
48 |
+
|
49 |
+
<h4>Продвинутое использование</h4>
|
50 |
+
|
51 |
+
```python
|
52 |
+
###################################
|
53 |
+
# Выборка говорящего из Гауссиана.
|
54 |
+
|
55 |
+
rand_spk = chat.sample_random_speaker()
|
56 |
+
print(rand_spk) # save it for later timbre recovery
|
57 |
+
|
58 |
+
params_infer_code = {
|
59 |
+
'spk_emb': rand_spk, # добавить выбранного говорящего
|
60 |
+
'temperature': .3, # использовать пользовательскую температуру
|
61 |
+
'top_P': 0.7, # декодирование top P
|
62 |
+
'top_K': 20, # декодирование top K
|
63 |
+
}
|
64 |
+
|
65 |
+
###################################
|
66 |
+
# Для контроля на уровне предложений.
|
67 |
+
|
68 |
+
# используйте oral_(0-9), laugh_(0-2), break_(0-7)
|
69 |
+
# для генерации специального токена в тексте для синтеза.
|
70 |
+
params_refine_text = {
|
71 |
+
'prompt': '[oral_2][laugh_0][break_6]'
|
72 |
+
}
|
73 |
+
|
74 |
+
wav = chat.infer(texts, params_refine_text=params_refine_text, params_infer_code=params_infer_code)
|
75 |
+
|
76 |
+
###################################
|
77 |
+
# Для контроля на уровне слов.
|
78 |
+
text = 'Какая ваша любимая английская еда?[uv_break]your favorite english food?[laugh][lbreak]'
|
79 |
+
wav = chat.infer(text, skip_refine_text=True, params_refine_text=params_refine_text, params_infer_code=params_infer_code)
|
80 |
+
torchaudio.save("output2.wav", torch.from_numpy(wavs[0]), 24000)
|
81 |
+
```
|
82 |
+
|
83 |
+
<details open>
|
84 |
+
<summary><h4>Пример: самопрезентация</h4></summary>
|
85 |
+
|
86 |
+
```python
|
87 |
+
inputs_ru = """
|
88 |
+
ChatTTS - это модель преобразования текста в речь, разработанная для диалоговых приложений.
|
89 |
+
[uv_break]Она поддерживает смешанный языковой ввод [uv_break]и предлагает возможности множественных говорящих
|
90 |
+
с точным контролем над просодическими элементами [laugh]как [uv_break]смех[laugh], [uv_break]паузы, [uv_break]и интонацию.
|
91 |
+
[uv_break]Она обеспечивает натуральную и выразительную речь,[uv_break]поэтому, пожалуйста,
|
92 |
+
[uv_break] используйте проект ответственно и на свой страх и риск.[uv_break]
|
93 |
+
""".replace('\n', '') # Русский язык все еще находится в экспериментальной стадии.
|
94 |
+
|
95 |
+
params_refine_text = {
|
96 |
+
'prompt': '[oral_2][laugh_0][break_4]'
|
97 |
+
}
|
98 |
+
audio_array_ru = chat.infer(inputs_ru, params_refine_text=params_refine_text)
|
99 |
+
torchaudio.save("output3.wav", torch.from_numpy(audio_array_ru[0]), 24000)
|
100 |
+
```
|
101 |
+
[мужской говорящий](https://github.com/2noise/ChatTTS/assets/130631963/e0f51251-db7f-4d39-a0e9-3e095bb65de1)
|
102 |
+
|
103 |
+
[женский говорящий](https://github.com/2noise/ChatTTS/assets/130631963/f5dcdd01-1091-47c5-8241-c4f6aaaa8bbd)
|
104 |
+
</details>
|
105 |
+
|
106 |
+
---
|
107 |
+
## План развития
|
108 |
+
- [x] Открыть исходный код базовой модели на 40 тысяч часов и файла spk_stats
|
109 |
+
- [ ] Открыть исходный код кодировщика VQ и кода обучения Lora
|
110 |
+
- [ ] Потоковая генерация аудио без уточнения текста*
|
111 |
+
- [ ] Открыть исходный код версии на 40 тысяч часов с управлением множественными эмоциями
|
112 |
+
- [ ] ChatTTS.cpp возможно? (PR или новый репозиторий приветствуются.)
|
113 |
+
|
114 |
+
----
|
115 |
+
## Часто задаваемые вопросы
|
116 |
+
|
117 |
+
##### Сколько VRAM мне нужно? Как насчет скорости инференса?
|
118 |
+
Для 30-секундного аудиоклипа требуется как минимум 4 ГБ памяти GPU. Для GPU 4090, он может генерировать аудио, соответствующее примерно 7 семантическим токенам в секунду. Фактор реального времени (RTF) составляет около 0.3.
|
119 |
+
|
120 |
+
##### Стабильность модели кажется недостаточно хорошей, возникают проблемы с множественными говорящими или плохим качеством аудио.
|
121 |
+
|
122 |
+
Это проблема, которая обычно возникает с авторегрессивными моделями (для bark и valle). Это обычно трудно избежать. Можно попробовать несколько образцов, чтобы найти подходящий результат.
|
123 |
+
|
124 |
+
##### Помимо смеха, можем ли мы контролировать что-то еще? Можем ли мы контролировать другие эмоции?
|
125 |
+
|
126 |
+
В текущей выпущенной модели единственными элементами управления на уровне токенов являются [laugh], [uv_break] и [lbreak]. В будущих версиях мы можем открыть модели с дополнительными возможностями контроля эмоций.
|
127 |
+
|
128 |
+
---
|
129 |
+
## Благодарности
|
130 |
+
- [bark](https://github.com/suno-ai/bark), [XTTSv2](https://github.com/coqui-ai/TTS) и [valle](https://arxiv.org/abs/2301.02111) демонстрируют замечательный результат TTS с помощью системы авторегрессивного стиля.
|
131 |
+
- [fish-speech](https://github.com/fishaudio/fish-speech) показывает возможности GVQ как аудио токенизатора для моделирования LLM.
|
132 |
+
- [vocos](https://github.com/gemelo-ai/vocos), который используется в качестве предварительно обученного вокодера.
|
133 |
+
|
134 |
+
---
|
135 |
+
## Особая благодарность
|
136 |
+
- [wlu-audio lab](https://audio.westlake.edu.cn/) за ранние эксперименты с алгоритмами.
|
examples/__init__.py
ADDED
File without changes
|
examples/api/README.md
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Generating voice with ChatTTS via API
|
2 |
+
|
3 |
+
## Install requirements
|
4 |
+
|
5 |
+
Install `FastAPI` and `requests`:
|
6 |
+
|
7 |
+
```
|
8 |
+
pip install -r examples/api/requirements.txt
|
9 |
+
```
|
10 |
+
|
11 |
+
## Run API server
|
12 |
+
|
13 |
+
```
|
14 |
+
fastapi dev examples/api/main.py --host 0.0.0.0 --port 8000
|
15 |
+
```
|
16 |
+
|
17 |
+
## Generate audio using requests
|
18 |
+
|
19 |
+
```
|
20 |
+
python examples/api/client.py
|
21 |
+
```
|
22 |
+
|
23 |
+
mp3 audio files will be saved to the `output` directory.
|
examples/api/client.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import os
|
3 |
+
import zipfile
|
4 |
+
from io import BytesIO
|
5 |
+
|
6 |
+
import requests
|
7 |
+
|
8 |
+
chattts_service_host = os.environ.get("CHATTTS_SERVICE_HOST", "localhost")
|
9 |
+
chattts_service_port = os.environ.get("CHATTTS_SERVICE_PORT", "8000")
|
10 |
+
|
11 |
+
CHATTTS_URL = f"http://{chattts_service_host}:{chattts_service_port}/generate_voice"
|
12 |
+
|
13 |
+
|
14 |
+
# main infer params
|
15 |
+
body = {
|
16 |
+
"text": [
|
17 |
+
"四川美食确实以辣闻名,但也有不辣的选择。",
|
18 |
+
"比如甜水面、赖汤圆、蛋烘糕、叶儿粑等,这些小吃口味温和,甜而不腻,也很受欢迎。",
|
19 |
+
],
|
20 |
+
"stream": False,
|
21 |
+
"lang": None,
|
22 |
+
"skip_refine_text": True,
|
23 |
+
"refine_text_only": False,
|
24 |
+
"use_decoder": True,
|
25 |
+
"audio_seed": 12345678,
|
26 |
+
"text_seed": 87654321,
|
27 |
+
"do_text_normalization": True,
|
28 |
+
"do_homophone_replacement": False,
|
29 |
+
}
|
30 |
+
|
31 |
+
# refine text params
|
32 |
+
params_refine_text = {
|
33 |
+
"prompt": "",
|
34 |
+
"top_P": 0.7,
|
35 |
+
"top_K": 20,
|
36 |
+
"temperature": 0.7,
|
37 |
+
"repetition_penalty": 1,
|
38 |
+
"max_new_token": 384,
|
39 |
+
"min_new_token": 0,
|
40 |
+
"show_tqdm": True,
|
41 |
+
"ensure_non_empty": True,
|
42 |
+
"stream_batch": 24,
|
43 |
+
}
|
44 |
+
body["params_refine_text"] = params_refine_text
|
45 |
+
|
46 |
+
# infer code params
|
47 |
+
params_infer_code = {
|
48 |
+
"prompt": "[speed_5]",
|
49 |
+
"top_P": 0.1,
|
50 |
+
"top_K": 20,
|
51 |
+
"temperature": 0.3,
|
52 |
+
"repetition_penalty": 1.05,
|
53 |
+
"max_new_token": 2048,
|
54 |
+
"min_new_token": 0,
|
55 |
+
"show_tqdm": True,
|
56 |
+
"ensure_non_empty": True,
|
57 |
+
"stream_batch": True,
|
58 |
+
"spk_emb": None,
|
59 |
+
}
|
60 |
+
body["params_infer_code"] = params_infer_code
|
61 |
+
|
62 |
+
|
63 |
+
try:
|
64 |
+
response = requests.post(CHATTTS_URL, json=body)
|
65 |
+
response.raise_for_status()
|
66 |
+
with zipfile.ZipFile(BytesIO(response.content), "r") as zip_ref:
|
67 |
+
# save files for each request in a different folder
|
68 |
+
dt = datetime.datetime.now()
|
69 |
+
ts = int(dt.timestamp())
|
70 |
+
tgt = f"./output/{ts}/"
|
71 |
+
os.makedirs(tgt, 0o755)
|
72 |
+
zip_ref.extractall(tgt)
|
73 |
+
print("Extracted files into", tgt)
|
74 |
+
|
75 |
+
except requests.exceptions.RequestException as e:
|
76 |
+
print(f"Request Error: {e}")
|
examples/api/main.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import zipfile
|
5 |
+
|
6 |
+
from fastapi import FastAPI
|
7 |
+
from fastapi.responses import StreamingResponse
|
8 |
+
|
9 |
+
|
10 |
+
if sys.platform == "darwin":
|
11 |
+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
12 |
+
|
13 |
+
now_dir = os.getcwd()
|
14 |
+
sys.path.append(now_dir)
|
15 |
+
|
16 |
+
from typing import Optional
|
17 |
+
|
18 |
+
import ChatTTS
|
19 |
+
|
20 |
+
from tools.audio import pcm_arr_to_mp3_view
|
21 |
+
from tools.logger import get_logger
|
22 |
+
import torch
|
23 |
+
|
24 |
+
|
25 |
+
from pydantic import BaseModel
|
26 |
+
|
27 |
+
|
28 |
+
logger = get_logger("Command")
|
29 |
+
|
30 |
+
app = FastAPI()
|
31 |
+
|
32 |
+
|
33 |
+
@app.on_event("startup")
|
34 |
+
async def startup_event():
|
35 |
+
global chat
|
36 |
+
|
37 |
+
chat = ChatTTS.Chat(get_logger("ChatTTS"))
|
38 |
+
logger.info("Initializing ChatTTS...")
|
39 |
+
if chat.load():
|
40 |
+
logger.info("Models loaded successfully.")
|
41 |
+
else:
|
42 |
+
logger.error("Models load failed.")
|
43 |
+
sys.exit(1)
|
44 |
+
|
45 |
+
|
46 |
+
class ChatTTSParams(BaseModel):
|
47 |
+
text: list[str]
|
48 |
+
stream: bool = False
|
49 |
+
lang: Optional[str] = None
|
50 |
+
skip_refine_text: bool = False
|
51 |
+
refine_text_only: bool = False
|
52 |
+
use_decoder: bool = True
|
53 |
+
do_text_normalization: bool = True
|
54 |
+
do_homophone_replacement: bool = False
|
55 |
+
params_refine_text: ChatTTS.Chat.RefineTextParams
|
56 |
+
params_infer_code: ChatTTS.Chat.InferCodeParams
|
57 |
+
|
58 |
+
|
59 |
+
@app.post("/generate_voice")
|
60 |
+
async def generate_voice(params: ChatTTSParams):
|
61 |
+
logger.info("Text input: %s", str(params.text))
|
62 |
+
|
63 |
+
# audio seed
|
64 |
+
if params.params_infer_code.manual_seed is not None:
|
65 |
+
torch.manual_seed(params.params_infer_code.manual_seed)
|
66 |
+
params.params_infer_code.spk_emb = chat.sample_random_speaker()
|
67 |
+
|
68 |
+
# text seed for text refining
|
69 |
+
if params.params_refine_text:
|
70 |
+
text = chat.infer(
|
71 |
+
text=params.text, skip_refine_text=False, refine_text_only=True
|
72 |
+
)
|
73 |
+
logger.info(f"Refined text: {text}")
|
74 |
+
else:
|
75 |
+
# no text refining
|
76 |
+
text = params.text
|
77 |
+
|
78 |
+
logger.info("Use speaker:")
|
79 |
+
logger.info(params.params_infer_code.spk_emb)
|
80 |
+
|
81 |
+
logger.info("Start voice inference.")
|
82 |
+
wavs = chat.infer(
|
83 |
+
text=text,
|
84 |
+
stream=params.stream,
|
85 |
+
lang=params.lang,
|
86 |
+
skip_refine_text=params.skip_refine_text,
|
87 |
+
use_decoder=params.use_decoder,
|
88 |
+
do_text_normalization=params.do_text_normalization,
|
89 |
+
do_homophone_replacement=params.do_homophone_replacement,
|
90 |
+
params_infer_code=params.params_infer_code,
|
91 |
+
params_refine_text=params.params_refine_text,
|
92 |
+
)
|
93 |
+
logger.info("Inference completed.")
|
94 |
+
|
95 |
+
# zip all of the audio files together
|
96 |
+
buf = io.BytesIO()
|
97 |
+
with zipfile.ZipFile(
|
98 |
+
buf, "a", compression=zipfile.ZIP_DEFLATED, allowZip64=False
|
99 |
+
) as f:
|
100 |
+
for idx, wav in enumerate(wavs):
|
101 |
+
f.writestr(f"{idx}.mp3", pcm_arr_to_mp3_view(wav))
|
102 |
+
logger.info("Audio generation successful.")
|
103 |
+
buf.seek(0)
|
104 |
+
|
105 |
+
response = StreamingResponse(buf, media_type="application/zip")
|
106 |
+
response.headers["Content-Disposition"] = "attachment; filename=audio_files.zip"
|
107 |
+
return response
|
examples/api/requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
fastapi
|
2 |
+
requests
|
examples/cmd/run.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
|
3 |
+
if sys.platform == "darwin":
|
4 |
+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
5 |
+
|
6 |
+
now_dir = os.getcwd()
|
7 |
+
sys.path.append(now_dir)
|
8 |
+
|
9 |
+
from typing import Optional, List
|
10 |
+
import argparse
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
import ChatTTS
|
15 |
+
|
16 |
+
from tools.logger import get_logger
|
17 |
+
from tools.audio import pcm_arr_to_mp3_view
|
18 |
+
from tools.normalizer.en import normalizer_en_nemo_text
|
19 |
+
from tools.normalizer.zh import normalizer_zh_tn
|
20 |
+
|
21 |
+
logger = get_logger("Command")
|
22 |
+
|
23 |
+
|
24 |
+
def save_mp3_file(wav, index):
|
25 |
+
data = pcm_arr_to_mp3_view(wav)
|
26 |
+
mp3_filename = f"output_audio_{index}.mp3"
|
27 |
+
with open(mp3_filename, "wb") as f:
|
28 |
+
f.write(data)
|
29 |
+
logger.info(f"Audio saved to {mp3_filename}")
|
30 |
+
|
31 |
+
|
32 |
+
def load_normalizer(chat: ChatTTS.Chat):
|
33 |
+
# try to load normalizer
|
34 |
+
try:
|
35 |
+
chat.normalizer.register("en", normalizer_en_nemo_text())
|
36 |
+
except ValueError as e:
|
37 |
+
logger.error(e)
|
38 |
+
except BaseException:
|
39 |
+
logger.warning("Package nemo_text_processing not found!")
|
40 |
+
logger.warning(
|
41 |
+
"Run: conda install -c conda-forge pynini=2.1.5 && pip install nemo_text_processing",
|
42 |
+
)
|
43 |
+
try:
|
44 |
+
chat.normalizer.register("zh", normalizer_zh_tn())
|
45 |
+
except ValueError as e:
|
46 |
+
logger.error(e)
|
47 |
+
except BaseException:
|
48 |
+
logger.warning("Package WeTextProcessing not found!")
|
49 |
+
logger.warning(
|
50 |
+
"Run: conda install -c conda-forge pynini=2.1.5 && pip install WeTextProcessing",
|
51 |
+
)
|
52 |
+
|
53 |
+
|
54 |
+
def main(
|
55 |
+
texts: List[str],
|
56 |
+
spk: Optional[str] = None,
|
57 |
+
stream: bool = False,
|
58 |
+
source: str = "local",
|
59 |
+
custom_path: str = "",
|
60 |
+
):
|
61 |
+
logger.info("Text input: %s", str(texts))
|
62 |
+
|
63 |
+
chat = ChatTTS.Chat(get_logger("ChatTTS"))
|
64 |
+
logger.info("Initializing ChatTTS...")
|
65 |
+
load_normalizer(chat)
|
66 |
+
|
67 |
+
is_load = False
|
68 |
+
if os.path.isdir(custom_path) and source == "custom":
|
69 |
+
is_load = chat.load(source="custom", custom_path=custom_path)
|
70 |
+
else:
|
71 |
+
is_load = chat.load(source=source)
|
72 |
+
|
73 |
+
if is_load:
|
74 |
+
logger.info("Models loaded successfully.")
|
75 |
+
else:
|
76 |
+
logger.error("Models load failed.")
|
77 |
+
sys.exit(1)
|
78 |
+
|
79 |
+
if spk is None:
|
80 |
+
spk = chat.sample_random_speaker()
|
81 |
+
logger.info("Use speaker:")
|
82 |
+
print(spk)
|
83 |
+
|
84 |
+
logger.info("Start inference.")
|
85 |
+
wavs = chat.infer(
|
86 |
+
texts,
|
87 |
+
stream,
|
88 |
+
params_infer_code=ChatTTS.Chat.InferCodeParams(
|
89 |
+
spk_emb=spk,
|
90 |
+
),
|
91 |
+
)
|
92 |
+
logger.info("Inference completed.")
|
93 |
+
# Save each generated wav file to a local file
|
94 |
+
if stream:
|
95 |
+
wavs_list = []
|
96 |
+
for index, wav in enumerate(wavs):
|
97 |
+
if stream:
|
98 |
+
for i, w in enumerate(wav):
|
99 |
+
save_mp3_file(w, (i + 1) * 1000 + index)
|
100 |
+
wavs_list.append(wav)
|
101 |
+
else:
|
102 |
+
save_mp3_file(wav, index)
|
103 |
+
if stream:
|
104 |
+
for index, wav in enumerate(np.concatenate(wavs_list, axis=1)):
|
105 |
+
save_mp3_file(wav, index)
|
106 |
+
logger.info("Audio generation successful.")
|
107 |
+
|
108 |
+
|
109 |
+
if __name__ == "__main__":
|
110 |
+
r"""
|
111 |
+
python -m examples.cmd.run \
|
112 |
+
--source custom --custom_path ../../models/2Noise/ChatTTS 你好喲 ":)"
|
113 |
+
"""
|
114 |
+
logger.info("Starting ChatTTS commandline demo...")
|
115 |
+
parser = argparse.ArgumentParser(
|
116 |
+
description="ChatTTS Command",
|
117 |
+
usage='[--spk xxx] [--stream] [--source ***] [--custom_path XXX] "Your text 1." " Your text 2."',
|
118 |
+
)
|
119 |
+
parser.add_argument(
|
120 |
+
"--spk",
|
121 |
+
help="Speaker (empty to sample a random one)",
|
122 |
+
type=Optional[str],
|
123 |
+
default=None,
|
124 |
+
)
|
125 |
+
parser.add_argument(
|
126 |
+
"--stream",
|
127 |
+
help="Use stream mode",
|
128 |
+
action="store_true",
|
129 |
+
)
|
130 |
+
parser.add_argument(
|
131 |
+
"--source",
|
132 |
+
help="source form [ huggingface(hf download), local(ckpt save to asset dir), custom(define) ]",
|
133 |
+
type=str,
|
134 |
+
default="local",
|
135 |
+
)
|
136 |
+
parser.add_argument(
|
137 |
+
"--custom_path",
|
138 |
+
help="custom defined model path(include asset ckpt dir)",
|
139 |
+
type=str,
|
140 |
+
default="",
|
141 |
+
)
|
142 |
+
parser.add_argument(
|
143 |
+
"texts",
|
144 |
+
help="Original text",
|
145 |
+
default=["YOUR TEXT HERE"],
|
146 |
+
nargs=argparse.REMAINDER,
|
147 |
+
)
|
148 |
+
args = parser.parse_args()
|
149 |
+
logger.info(args)
|
150 |
+
main(args.texts, args.spk, args.stream, args.source, args.custom_path)
|
151 |
+
logger.info("ChatTTS process finished.")
|