zhengr commited on
Commit
c02bdcd
1 Parent(s): 93e634f
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. ChatTTS/__init__.py +1 -0
  2. ChatTTS/config/__init__.py +1 -0
  3. ChatTTS/config/config.py +134 -0
  4. ChatTTS/core.py +669 -0
  5. ChatTTS/model/__init__.py +6 -0
  6. ChatTTS/model/cuda/__init__.py +1 -0
  7. ChatTTS/model/cuda/patch.py +18 -0
  8. ChatTTS/model/cuda/te_llama.py +192 -0
  9. ChatTTS/model/dvae.py +296 -0
  10. ChatTTS/model/embed.py +81 -0
  11. ChatTTS/model/gpt.py +613 -0
  12. ChatTTS/model/processors.py +58 -0
  13. ChatTTS/model/speaker.py +154 -0
  14. ChatTTS/model/tokenizer.py +138 -0
  15. ChatTTS/model/velocity/__init__.py +2 -0
  16. ChatTTS/model/velocity/block_manager.py +296 -0
  17. ChatTTS/model/velocity/configs.py +865 -0
  18. ChatTTS/model/velocity/llama.py +393 -0
  19. ChatTTS/model/velocity/llm.py +213 -0
  20. ChatTTS/model/velocity/llm_engine.py +833 -0
  21. ChatTTS/model/velocity/model_loader.py +69 -0
  22. ChatTTS/model/velocity/model_runner.py +817 -0
  23. ChatTTS/model/velocity/output.py +144 -0
  24. ChatTTS/model/velocity/sampler.py +120 -0
  25. ChatTTS/model/velocity/sampling_params.py +296 -0
  26. ChatTTS/model/velocity/scheduler.py +426 -0
  27. ChatTTS/model/velocity/sequence.py +450 -0
  28. ChatTTS/model/velocity/worker.py +251 -0
  29. ChatTTS/norm.py +253 -0
  30. ChatTTS/res/__init__.py +0 -0
  31. ChatTTS/res/homophones_map.json +0 -0
  32. ChatTTS/res/sha256_map.json +13 -0
  33. ChatTTS/utils/__init__.py +4 -0
  34. ChatTTS/utils/dl.py +220 -0
  35. ChatTTS/utils/gpu.py +40 -0
  36. ChatTTS/utils/io.py +44 -0
  37. ChatTTS/utils/log.py +16 -0
  38. Dockerfile +13 -0
  39. LICENSE +661 -0
  40. docs/cn/README.md +314 -0
  41. docs/es/README.md +255 -0
  42. docs/fr/README.md +283 -0
  43. docs/jp/README.md +134 -0
  44. docs/ru/README.md +136 -0
  45. examples/__init__.py +0 -0
  46. examples/api/README.md +23 -0
  47. examples/api/client.py +76 -0
  48. examples/api/main.py +107 -0
  49. examples/api/requirements.txt +2 -0
  50. examples/cmd/run.py +151 -0
ChatTTS/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .core import Chat
ChatTTS/config/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .config import Config
ChatTTS/config/config.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+
4
+ @dataclass(repr=False, eq=False)
5
+ class Path:
6
+ vocos_ckpt_path: str = "asset/Vocos.pt"
7
+ dvae_ckpt_path: str = "asset/DVAE_full.pt"
8
+ gpt_ckpt_path: str = "asset/gpt"
9
+ decoder_ckpt_path: str = "asset/Decoder.pt"
10
+ tokenizer_path: str = "asset/tokenizer"
11
+ embed_path: str = "asset/Embed.safetensors"
12
+
13
+
14
+ @dataclass(repr=False, eq=False)
15
+ class Decoder:
16
+ idim: int = 384
17
+ odim: int = 384
18
+ hidden: int = 512
19
+ n_layer: int = 12
20
+ bn_dim: int = 128
21
+
22
+
23
+ @dataclass(repr=False, eq=False)
24
+ class VQ:
25
+ dim: int = 1024
26
+ levels: tuple = (5, 5, 5, 5)
27
+ G: int = 2
28
+ R: int = 2
29
+
30
+
31
+ @dataclass(repr=False, eq=False)
32
+ class DVAE:
33
+ encoder: Decoder = Decoder(
34
+ idim=512,
35
+ odim=1024,
36
+ hidden=256,
37
+ n_layer=12,
38
+ bn_dim=128,
39
+ )
40
+ decoder: Decoder = Decoder(
41
+ idim=512,
42
+ odim=512,
43
+ hidden=256,
44
+ n_layer=12,
45
+ bn_dim=128,
46
+ )
47
+ vq: VQ = VQ()
48
+
49
+
50
+ @dataclass(repr=False, eq=False)
51
+ class GPT:
52
+ hidden_size: int = 768
53
+ intermediate_size: int = 3072
54
+ num_attention_heads: int = 12
55
+ num_hidden_layers: int = 20
56
+ use_cache: bool = False
57
+ max_position_embeddings: int = 4096
58
+
59
+ spk_emb_dim: int = 192
60
+ spk_KL: bool = False
61
+ num_audio_tokens: int = 626
62
+ num_text_tokens: int = 21178
63
+ num_vq: int = 4
64
+
65
+
66
+ @dataclass(repr=False, eq=False)
67
+ class Embed:
68
+ hidden_size: int = 768
69
+ num_audio_tokens: int = 626
70
+ num_text_tokens: int = 21178
71
+ num_vq: int = 4
72
+
73
+
74
+ @dataclass(repr=False, eq=False)
75
+ class FeatureExtractorInitArgs:
76
+ sample_rate: int = 24000
77
+ n_fft: int = 1024
78
+ hop_length: int = 256
79
+ n_mels: int = 100
80
+ padding: str = "center"
81
+
82
+
83
+ @dataclass(repr=False, eq=False)
84
+ class FeatureExtractor:
85
+ class_path: str = "vocos.feature_extractors.MelSpectrogramFeatures"
86
+ init_args: FeatureExtractorInitArgs = FeatureExtractorInitArgs()
87
+
88
+
89
+ @dataclass(repr=False, eq=False)
90
+ class BackboneInitArgs:
91
+ input_channels: int = 100
92
+ dim: int = 512
93
+ intermediate_dim: int = 1536
94
+ num_layers: int = 8
95
+
96
+
97
+ @dataclass(repr=False, eq=False)
98
+ class Backbone:
99
+ class_path: str = "vocos.models.VocosBackbone"
100
+ init_args: BackboneInitArgs = BackboneInitArgs()
101
+
102
+
103
+ @dataclass(repr=False, eq=False)
104
+ class FourierHeadInitArgs:
105
+ dim: int = 512
106
+ n_fft: int = 1024
107
+ hop_length: int = 256
108
+ padding: str = "center"
109
+
110
+
111
+ @dataclass(repr=False, eq=False)
112
+ class FourierHead:
113
+ class_path: str = "vocos.heads.ISTFTHead"
114
+ init_args: FourierHeadInitArgs = FourierHeadInitArgs()
115
+
116
+
117
+ @dataclass(repr=False, eq=False)
118
+ class Vocos:
119
+ feature_extractor: FeatureExtractor = FeatureExtractor()
120
+ backbone: Backbone = Backbone()
121
+ head: FourierHead = FourierHead()
122
+
123
+
124
+ @dataclass(repr=False, eq=False)
125
+ class Config:
126
+ path: Path = Path()
127
+ decoder: Decoder = Decoder()
128
+ dvae: DVAE = DVAE()
129
+ gpt: GPT = GPT()
130
+ embed: Embed = Embed()
131
+ vocos: Vocos = Vocos()
132
+ spk_stat: str = (
133
+ "愐穤巩噅廷戇笉屈癐媄垹垧帶爲漈塀殐慄亅倴庲舴猂瑈圐狴夥圓帍戛挠腉耐劤坽喳幾战謇聀崒栄呥倸庭燡欈杁襐褄乭埗幺爃弔摁斐捔兕佖廐舏竾豃磐姓趡佄幒爚欄豄讐皳訵仩帆投謌荃蝐叄圝伆幦抂茁呄掑斃讹傮庞爣蜀橁偐祄亥兡常爂欍扉丐浔佱僈強払伅扂蛐徴憍傞巀戺欀艂琐嗴啥値彷刂權穈扒卤俔贲庛初笂卄贐枴仭亁庛剎猢扃缐趤刁偵幪舏伌煁婐潤晍位弾舙茥穁葏蠣訑企庤刊笍橁溑僔云偁庯戚伍潉膐脴僵噔廃艅匊祂唐憴壝嗙席爥欁虁谐牴帽势弿牳蜁兀蛐傄喩丿帔刔圆衁廐罤庁促帙劢伈汄樐檄勵伴弝舑欍罅虐昴劭勅帜刼朊蕁虐蓴樑伫幨扑謪剀堐稴丵伱弐舮諸赁習俔容厱幫牶謃孄糐答嗝僊帜燲笄終瀒判久僤帘爴茇千孑冄凕佳引扐蜁歁缏裄剽儺恘爋朏眿廐呄塍嘇幻爱茠詁訐剴唭俐幾戊欀硁菐贄楕偒巡爀弎屄莐睳賙凶彎刅漄區唐溴剑劋庽舽猄煃跐夔惥伾庮舎伈罁垑坄怅业怯刁朇獁嶏覔坩俳巶爜朐潁崐萄俹凛常爺笌穀聐此夡倛帡刀匉終窏舣販侽怿扉伥贿憐忓謩姆幌犊漂慆癒却甝兎帼戏欅詂浐朔仹壭帰臷弎恇菐獤帡偖帘爞伅腂皐纤囅充幓戠伥灂丐訤戱倱弋爮嬌癁恐孄侥劬忶刓國詀桒古偩嘄庬戚茝赂监燤嘑勌幦舽持呂諐棤姑再底舡笍艃瀐孴倉傔弋爔猠乁濑塄偽嘧恂舛缇襃厐窴仡刱忕別漇穁岏缴廽价庌爊謈硄讑惤倁儂庭爋伇蝂嶐莔摝傠库刞茄歃戏薤伍伯廮创笠塄熐兴勽俄帅剉最腀砐敤卝侍弆戺朒虃旐蚄梕亖幔牻朣扅贐玔堝噅帡剌圅摀崐彤流僳庙爖嬇啁渐悤堁丛幆刧挜彃悐幤刹嚟恕芁看聀摐焔向乁帖爭欁癃糒圄弙佱廜戤謍婀咐昴焍亩廦艏拼謿芐癤怹兽幸舳朇畁喐稔毝丼弈懲挀譂勑哴啁伎常舭笯晁堑俄叩剔廟爍欦絁夒伤休傑廳戌蜅潆癐彴摑勯床刽欅艁砐忄搉从廡舊猥潂唐委仱僜廼爤朄呃弐礔滵垓幩爄挂筁乐籤刕凟幵爠弉���乑吴勥伖帪舩茆婁碐幤叭乢巜艳猁桀桐啄唩俊幍舮猀艅焐螔琽亀帋爜缅噃咐斤喩予幩爛笆摀浐猴依侹幃刕園慄蛐栤澹仑座爼謉桃慐浔斕偻幛懰嬓衁愐氄悅仿应芔漄衃敐謤傁匩幹抃圉癄廐裄屵噉幍利謍聂搐蛔嚙坍怗舁圐畃膐栄刵东巆戤諾呃偑媤嗨跞忶爝眄祂朒嶔僭劉忾刐匋癄袐翴珅僷廲芄茈恈皐擄崑伄廉牍匃剃犏澤唑丄庺戃伃煀某杄偙亽帴切缌罄挐尴噙倰带舞漄橄塐糴俩僯帀般漀坂栐更両俇廱舌猁慂拐偤嶱卶应刪眉獁茐伔嘅偺帟舊漂恀栐暄喡乞庙舆匂敀潑恔劑侖延戦盽怶唯慳蝘蟃孫娎益袰玍屃痶翮笪儚裀倹椌玻翀詵筽舘惯堿某侰晈藏缮詗廦夸妎瑻瀒裔媀憞唃冶璭狻渠荑奬熹茅愺氰菣滠翦岓褌泣崲嚭欓湒聙宺爄蛅愸庍匃帆誔穮懌蓪玷澌氋抌訙屌臞廛玸听屺希疭孝凂紋新煎彃膲跱尪懁眆窴珏卓揨菸紭概囥显壌榄垫嘮嬭覤媸侵佮烒耸觌婀秋狃帹葯訤桜糨笾腢伀肶悍炂艤禖岅臺惘梷瞍友盁佨岧憳瓧嘴汬藊愌蘤嶠硴绤蜲襏括勾谂縨妥蓪澭竭萢藜纞糲煮愆瀯孯琓罂諺塿燗狟弙衯揻縷丱糅臄梱瀮杰巳猙亊符胠匃泀廏圃膂蒃籏礩岈簹缌劺燲褡孓膜拔蠿觮呋煣厌尷熜論弲牭紫寊誃紀橴賬傸箍弚窃侫簲慯烣渽祌壓媥噜夽夛諛玹疮禄冪謇媽衤盰缺繑薫兾萧嵱打滽箺嚯凣狢蠜崼覽烸簶盯籓摀苶峸懗泲涻凮愳緗剋笔懆廡瞿椏礤惐藥崍腈烄伹亯昣翬褍絋桫僨吨莌丛矄蜞娈憊苆塁蓏嚢嫼绻崱婋囱蠸篯晣芀繼索兓僖誹岯圪褰蠇唓妷胅巁渮砛傈蝷嵚冃購赁峍裋荂舾符熻岳墩寮粃凲袑彚太绲头摯繳狁俥籌冝諝註坎幫擤詒宒凕賐唶梎噔弼課屿覍囨焬櫱撪蝮蝬簸懰櫫涺嵍睻屪翔峞慘滟熲昱军烊舿尦舄糖奁溏凂彆蝲糴禍困皻灏牋睒诙嶱臀开蓈眎腼丢纻廏憤嫖暭袭崲肸螛妒榗紉谨窮袃瑠聍绊腆亿冲葐喋縔詖岑兾给堸赏旻桀蛨媆訂峦紷敯囬偐筨岸焸拭笵殒哜墒萍屓娓諙械臮望摰芑寭准僞谹氍旋憢菮屃划欣瘫谎蘻哐繁籥禦僿誵皯墓燀縿笞熦绗稹榎矻綞蓓帡戓沺区才畃洊詪糐裶盰窶耎偌劂誐庩惝滜沺哮呃煐譠崄槀猄肼蔐擋湌蠺篃恥諌瞦宍堫挪裕崑慩狲悠煋仛愞砈粵八棁害楐妋萔貨尵奂苰怫誎傫岆蕯屇脉夈仆茎刓繸芺壸碗曛汁戭炻獻凉媁兎狜爴怰賃纎袏娷禃蓥膹薪渻罸窿粫凾褄舺窮墫干苊繁冏僮訸夯绛蓪虛羽慲烏憷趎睊蠰莍塞成廎盁欏喓蜮譤崆楁囘矇薭伣艘虝帴奮苢渶虎暣翐蝃尾稈糶瀴罐嵚氮葯笫慐棌悶炯竻爅们媡姢嫺窷刮歫劈裩屬椕賑蜹薊刲義哯尗褦瓀稾礋揣窼舫尋姁椄侸嗫珺修纘媃腽蛛稹梭呛瀈蘟縀礉論夵售主梮蠉娅娭裀誼嶭観枳倊簈褃擞綿催瞃溶苊笛襹櫲盅六囫獩佃粨慯瓢眸旱荃婨蔞岋祗墼焻网牻琖詆峋秉胳媴袭澓賢経稟壩胫碯偏囫嶎纆窈槊賐撹璬莃缘誾宭愊眗喷监劋萘訯總槿棭戾墮犄恌縈簍樥蛔杁袭嫛憫倆篏墵賈羯茎觳蒜致娢慄勒覸蘍曲栂葭宆妋皽缽免盳猼蔂糥觧烳檸佯憓煶蔐筼种繷琲膌塄剰讎対腕棥渽忲俛浪譬秛惛壒嘸淫冻曄睻砃奫貯庴爅粓脮脡娎妖峵蘲討惋泊蠀㴆"
134
+ )
ChatTTS/core.py ADDED
@@ -0,0 +1,669 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import tempfile
4
+ from dataclasses import dataclass, asdict
5
+ from typing import Literal, Optional, List, Tuple, Dict, Union
6
+ from json import load
7
+ from pathlib import Path
8
+
9
+ import numpy as np
10
+ import torch
11
+ from vocos import Vocos
12
+ from vocos.pretrained import instantiate_class
13
+ from huggingface_hub import snapshot_download
14
+
15
+ from .config import Config
16
+ from .model import DVAE, Embed, GPT, gen_logits, Tokenizer, Speaker
17
+ from .utils import (
18
+ check_all_assets,
19
+ download_all_assets,
20
+ select_device,
21
+ get_latest_modified_file,
22
+ del_all,
23
+ )
24
+ from .utils import logger as utils_logger
25
+
26
+ from .norm import Normalizer
27
+
28
+
29
+ class Chat:
30
+ def __init__(self, logger=logging.getLogger(__name__)):
31
+ self.logger = logger
32
+ utils_logger.set_logger(logger)
33
+
34
+ self.config = Config()
35
+
36
+ self.normalizer = Normalizer(
37
+ os.path.join(os.path.dirname(__file__), "res", "homophones_map.json"),
38
+ logger,
39
+ )
40
+ with open(
41
+ os.path.join(os.path.dirname(__file__), "res", "sha256_map.json")
42
+ ) as f:
43
+ self.sha256_map: Dict[str, str] = load(f)
44
+
45
+ self.context = GPT.Context()
46
+
47
+ def has_loaded(self, use_decoder=False):
48
+ not_finish = False
49
+ check_list = ["vocos", "gpt", "tokenizer", "embed"]
50
+
51
+ if use_decoder:
52
+ check_list.append("decoder")
53
+ else:
54
+ check_list.append("dvae")
55
+
56
+ for module in check_list:
57
+ if not hasattr(self, module):
58
+ self.logger.warning(f"{module} not initialized.")
59
+ not_finish = True
60
+
61
+ return not not_finish
62
+
63
+ def download_models(
64
+ self,
65
+ source: Literal["huggingface", "local", "custom"] = "local",
66
+ force_redownload=False,
67
+ custom_path: Optional[torch.serialization.FILE_LIKE] = None,
68
+ ) -> Optional[str]:
69
+ if source == "local":
70
+ download_path = os.getcwd()
71
+ if (
72
+ not check_all_assets(Path(download_path), self.sha256_map, update=True)
73
+ or force_redownload
74
+ ):
75
+ with tempfile.TemporaryDirectory() as tmp:
76
+ download_all_assets(tmpdir=tmp)
77
+ if not check_all_assets(
78
+ Path(download_path), self.sha256_map, update=False
79
+ ):
80
+ self.logger.error(
81
+ "download to local path %s failed.", download_path
82
+ )
83
+ return None
84
+ elif source == "huggingface":
85
+ hf_home = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
86
+ try:
87
+ download_path = get_latest_modified_file(
88
+ os.path.join(hf_home, "hub/models--2Noise--ChatTTS/snapshots")
89
+ )
90
+ except:
91
+ download_path = None
92
+ if download_path is None or force_redownload:
93
+ self.logger.log(
94
+ logging.INFO,
95
+ f"download from HF: https://huggingface.co/2Noise/ChatTTS",
96
+ )
97
+ try:
98
+ download_path = snapshot_download(
99
+ repo_id="2Noise/ChatTTS",
100
+ allow_patterns=["*.pt", "*.yaml", "*.json", "*.safetensors"],
101
+ )
102
+ except:
103
+ download_path = None
104
+ else:
105
+ self.logger.log(
106
+ logging.INFO, f"load latest snapshot from cache: {download_path}"
107
+ )
108
+ if download_path is None:
109
+ self.logger.error("download from huggingface failed.")
110
+ return None
111
+ elif source == "custom":
112
+ self.logger.log(logging.INFO, f"try to load from local: {custom_path}")
113
+ if not check_all_assets(Path(custom_path), self.sha256_map, update=False):
114
+ self.logger.error("check models in custom path %s failed.", custom_path)
115
+ return None
116
+ download_path = custom_path
117
+
118
+ return download_path
119
+
120
+ def load(
121
+ self,
122
+ source: Literal["huggingface", "local", "custom"] = "local",
123
+ force_redownload=False,
124
+ compile: bool = False,
125
+ custom_path: Optional[torch.serialization.FILE_LIKE] = None,
126
+ device: Optional[torch.device] = None,
127
+ coef: Optional[torch.Tensor] = None,
128
+ use_flash_attn=False,
129
+ use_vllm=False,
130
+ experimental: bool = False,
131
+ ) -> bool:
132
+ download_path = self.download_models(source, force_redownload, custom_path)
133
+ if download_path is None:
134
+ return False
135
+ return self._load(
136
+ device=device,
137
+ compile=compile,
138
+ coef=coef,
139
+ use_flash_attn=use_flash_attn,
140
+ use_vllm=use_vllm,
141
+ experimental=experimental,
142
+ **{
143
+ k: os.path.join(download_path, v)
144
+ for k, v in asdict(self.config.path).items()
145
+ },
146
+ )
147
+
148
+ def unload(self):
149
+ logger = self.logger
150
+ self.normalizer.destroy()
151
+ del self.normalizer
152
+ del self.sha256_map
153
+ del_list = ["vocos", "gpt", "decoder", "dvae", "tokenizer", "embed"]
154
+ for module in del_list:
155
+ if hasattr(self, module):
156
+ delattr(self, module)
157
+ self.__init__(logger)
158
+
159
+ def sample_random_speaker(self) -> str:
160
+ return self.speaker.sample_random()
161
+
162
+ def sample_audio_speaker(self, wav: Union[np.ndarray, torch.Tensor]) -> str:
163
+ return self.speaker.encode_prompt(self.dvae.sample_audio(wav))
164
+
165
+ @dataclass(repr=False, eq=False)
166
+ class RefineTextParams:
167
+ prompt: str = ""
168
+ top_P: float = 0.7
169
+ top_K: int = 20
170
+ temperature: float = 0.7
171
+ repetition_penalty: float = 1.0
172
+ max_new_token: int = 384
173
+ min_new_token: int = 0
174
+ show_tqdm: bool = True
175
+ ensure_non_empty: bool = True
176
+ manual_seed: Optional[int] = None
177
+
178
+ @dataclass(repr=False, eq=False)
179
+ class InferCodeParams(RefineTextParams):
180
+ prompt: str = "[speed_5]"
181
+ spk_emb: Optional[str] = None
182
+ spk_smp: Optional[str] = None
183
+ txt_smp: Optional[str] = None
184
+ temperature: float = 0.3
185
+ repetition_penalty: float = 1.05
186
+ max_new_token: int = 2048
187
+ stream_batch: int = 24
188
+ stream_speed: int = 12000
189
+ pass_first_n_batches: int = 2
190
+
191
+ def infer(
192
+ self,
193
+ text,
194
+ stream=False,
195
+ lang=None,
196
+ skip_refine_text=False,
197
+ refine_text_only=False,
198
+ use_decoder=True,
199
+ do_text_normalization=True,
200
+ do_homophone_replacement=True,
201
+ params_refine_text=RefineTextParams(),
202
+ params_infer_code=InferCodeParams(),
203
+ ):
204
+ self.context.set(False)
205
+ res_gen = self._infer(
206
+ text,
207
+ stream,
208
+ lang,
209
+ skip_refine_text,
210
+ refine_text_only,
211
+ use_decoder,
212
+ do_text_normalization,
213
+ do_homophone_replacement,
214
+ params_refine_text,
215
+ params_infer_code,
216
+ )
217
+ if stream:
218
+ return res_gen
219
+ else:
220
+ return next(res_gen)
221
+
222
+ def interrupt(self):
223
+ self.context.set(True)
224
+
225
+ @torch.no_grad()
226
+ def _load(
227
+ self,
228
+ vocos_ckpt_path: str = None,
229
+ dvae_ckpt_path: str = None,
230
+ gpt_ckpt_path: str = None,
231
+ embed_path: str = None,
232
+ decoder_ckpt_path: str = None,
233
+ tokenizer_path: str = None,
234
+ device: Optional[torch.device] = None,
235
+ compile: bool = False,
236
+ coef: Optional[str] = None,
237
+ use_flash_attn=False,
238
+ use_vllm=False,
239
+ experimental: bool = False,
240
+ ):
241
+ if device is None:
242
+ device = select_device(experimental=experimental)
243
+ self.logger.info("use device %s", str(device))
244
+ self.device = device
245
+ self.device_gpt = device if "mps" not in str(device) else torch.device("cpu")
246
+ self.compile = compile
247
+
248
+ feature_extractor = instantiate_class(
249
+ args=(), init=asdict(self.config.vocos.feature_extractor)
250
+ )
251
+ backbone = instantiate_class(args=(), init=asdict(self.config.vocos.backbone))
252
+ head = instantiate_class(args=(), init=asdict(self.config.vocos.head))
253
+ vocos = (
254
+ Vocos(feature_extractor=feature_extractor, backbone=backbone, head=head)
255
+ .to(
256
+ # vocos on mps will crash, use cpu fallback
257
+ "cpu"
258
+ if "mps" in str(device)
259
+ else device
260
+ )
261
+ .eval()
262
+ )
263
+ assert vocos_ckpt_path, "vocos_ckpt_path should not be None"
264
+ vocos.load_state_dict(torch.load(vocos_ckpt_path, weights_only=True, mmap=True))
265
+ self.vocos = vocos
266
+ self.logger.log(logging.INFO, "vocos loaded.")
267
+
268
+ dvae = (
269
+ DVAE(
270
+ decoder_config=asdict(self.config.dvae.decoder),
271
+ encoder_config=asdict(self.config.dvae.encoder),
272
+ vq_config=asdict(self.config.dvae.vq),
273
+ dim=self.config.dvae.decoder.idim,
274
+ coef=coef,
275
+ device=device,
276
+ )
277
+ .to(device)
278
+ .eval()
279
+ )
280
+ coef = str(dvae)
281
+ assert dvae_ckpt_path, "dvae_ckpt_path should not be None"
282
+ dvae.load_state_dict(torch.load(dvae_ckpt_path, weights_only=True, mmap=True))
283
+ self.dvae = dvae
284
+ self.logger.log(logging.INFO, "dvae loaded.")
285
+
286
+ embed = Embed(
287
+ self.config.embed.hidden_size,
288
+ self.config.embed.num_audio_tokens,
289
+ self.config.embed.num_text_tokens,
290
+ self.config.embed.num_vq,
291
+ )
292
+ embed.from_pretrained(embed_path, device=device)
293
+ self.embed = embed.to(device)
294
+ self.logger.log(logging.INFO, "embed loaded.")
295
+
296
+ gpt = GPT(
297
+ gpt_config=asdict(self.config.gpt),
298
+ embed=self.embed,
299
+ use_flash_attn=use_flash_attn,
300
+ use_vllm=use_vllm,
301
+ device=device,
302
+ device_gpt=self.device_gpt,
303
+ logger=self.logger,
304
+ ).eval()
305
+ assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
306
+ gpt.from_pretrained(gpt_ckpt_path, embed_path, experimental=experimental)
307
+ gpt.prepare(compile=compile and "cuda" in str(device))
308
+ self.gpt = gpt
309
+ self.logger.log(logging.INFO, "gpt loaded.")
310
+
311
+ self.speaker = Speaker(
312
+ self.config.gpt.hidden_size, self.config.spk_stat, device
313
+ )
314
+ self.logger.log(logging.INFO, "speaker loaded.")
315
+
316
+ decoder = (
317
+ DVAE(
318
+ decoder_config=asdict(self.config.decoder),
319
+ dim=self.config.decoder.idim,
320
+ coef=coef,
321
+ device=device,
322
+ )
323
+ .to(device)
324
+ .eval()
325
+ )
326
+ coef = str(decoder)
327
+ assert decoder_ckpt_path, "decoder_ckpt_path should not be None"
328
+ decoder.load_state_dict(
329
+ torch.load(decoder_ckpt_path, weights_only=True, mmap=True)
330
+ )
331
+ self.decoder = decoder
332
+ self.logger.log(logging.INFO, "decoder loaded.")
333
+
334
+ if tokenizer_path:
335
+ self.tokenizer = Tokenizer(tokenizer_path)
336
+ self.logger.log(logging.INFO, "tokenizer loaded.")
337
+
338
+ self.coef = coef
339
+
340
+ return self.has_loaded()
341
+
342
+ def _infer(
343
+ self,
344
+ text,
345
+ stream=False,
346
+ lang=None,
347
+ skip_refine_text=False,
348
+ refine_text_only=False,
349
+ use_decoder=True,
350
+ do_text_normalization=True,
351
+ do_homophone_replacement=True,
352
+ params_refine_text=RefineTextParams(),
353
+ params_infer_code=InferCodeParams(),
354
+ ):
355
+
356
+ assert self.has_loaded(use_decoder=use_decoder)
357
+
358
+ if not isinstance(text, list):
359
+ text = [text]
360
+
361
+ text = [
362
+ self.normalizer(
363
+ t,
364
+ do_text_normalization,
365
+ do_homophone_replacement,
366
+ lang,
367
+ )
368
+ for t in text
369
+ ]
370
+
371
+ self.logger.debug("normed texts %s", str(text))
372
+
373
+ if not skip_refine_text:
374
+ refined = self._refine_text(
375
+ text,
376
+ self.device,
377
+ params_refine_text,
378
+ )
379
+ text_tokens = refined.ids
380
+ text_tokens = [i[i.less(self.tokenizer.break_0_ids)] for i in text_tokens]
381
+ text = self.tokenizer.decode(text_tokens)
382
+ refined.destroy()
383
+ if refine_text_only:
384
+ yield text
385
+ return
386
+
387
+ if stream:
388
+ length = 0
389
+ pass_batch_count = 0
390
+ for result in self._infer_code(
391
+ text,
392
+ stream,
393
+ self.device,
394
+ use_decoder,
395
+ params_infer_code,
396
+ ):
397
+ wavs = self._decode_to_wavs(
398
+ result.hiddens if use_decoder else result.ids,
399
+ use_decoder,
400
+ )
401
+ result.destroy()
402
+ if stream:
403
+ pass_batch_count += 1
404
+ if pass_batch_count <= params_infer_code.pass_first_n_batches:
405
+ continue
406
+ a = length
407
+ b = a + params_infer_code.stream_speed
408
+ if b > wavs.shape[1]:
409
+ b = wavs.shape[1]
410
+ new_wavs = wavs[:, a:b]
411
+ length = b
412
+ yield new_wavs
413
+ else:
414
+ yield wavs
415
+ if stream:
416
+ new_wavs = wavs[:, length:]
417
+ # Identify rows with non-zero elements using np.any
418
+ # keep_rows = np.any(array != 0, axis=1)
419
+ keep_cols = np.sum(new_wavs != 0, axis=0) > 0
420
+ # Filter both rows and columns using slicing
421
+ yield new_wavs[:][:, keep_cols]
422
+
423
+ @torch.inference_mode()
424
+ def _vocos_decode(self, spec: torch.Tensor) -> np.ndarray:
425
+ if "mps" in str(self.device):
426
+ return self.vocos.decode(spec.cpu()).cpu().numpy()
427
+ else:
428
+ return self.vocos.decode(spec).cpu().numpy()
429
+
430
+ @torch.inference_mode()
431
+ def _decode_to_wavs(
432
+ self,
433
+ result_list: List[torch.Tensor],
434
+ use_decoder: bool,
435
+ ):
436
+ decoder = self.decoder if use_decoder else self.dvae
437
+ max_x_len = -1
438
+ if len(result_list) == 0:
439
+ return np.array([], dtype=np.float32)
440
+ for result in result_list:
441
+ if result.size(0) > max_x_len:
442
+ max_x_len = result.size(0)
443
+ batch_result = torch.zeros(
444
+ (len(result_list), result_list[0].size(1), max_x_len),
445
+ dtype=result_list[0].dtype,
446
+ device=result_list[0].device,
447
+ )
448
+ for i in range(len(result_list)):
449
+ src = result_list[i]
450
+ batch_result[i].narrow(1, 0, src.size(0)).copy_(src.permute(1, 0))
451
+ del src
452
+ del_all(result_list)
453
+ mel_specs = decoder(batch_result)
454
+ del batch_result
455
+ wavs = self._vocos_decode(mel_specs)
456
+ del mel_specs
457
+ return wavs
458
+
459
+ @torch.no_grad()
460
+ def _infer_code(
461
+ self,
462
+ text: Tuple[List[str], str],
463
+ stream: bool,
464
+ device: torch.device,
465
+ return_hidden: bool,
466
+ params: InferCodeParams,
467
+ ):
468
+
469
+ gpt = self.gpt
470
+
471
+ if not isinstance(text, list):
472
+ text = [text]
473
+
474
+ assert len(text), "text should not be empty"
475
+
476
+ if not isinstance(params.temperature, list):
477
+ temperature = [params.temperature] * self.config.gpt.num_vq
478
+ else:
479
+ temperature = params.temperature
480
+
481
+ input_ids, attention_mask, text_mask = self.tokenizer.encode(
482
+ self.speaker.decorate_code_prompts(
483
+ text,
484
+ params.prompt,
485
+ params.txt_smp,
486
+ params.spk_emb,
487
+ ),
488
+ self.config.gpt.num_vq,
489
+ prompt=(
490
+ self.speaker.decode_prompt(params.spk_smp)
491
+ if params.spk_smp is not None
492
+ else None
493
+ ),
494
+ device=self.device_gpt,
495
+ )
496
+ start_idx = input_ids.shape[-2]
497
+
498
+ num_code = self.config.gpt.num_audio_tokens - 1
499
+
500
+ logits_warpers, logits_processors = gen_logits(
501
+ num_code=num_code,
502
+ top_P=params.top_P,
503
+ top_K=params.top_K,
504
+ repetition_penalty=params.repetition_penalty,
505
+ )
506
+
507
+ if gpt.is_vllm:
508
+ from .model.velocity import SamplingParams
509
+
510
+ sample_params = SamplingParams(
511
+ temperature=temperature,
512
+ max_new_token=params.max_new_token,
513
+ max_tokens=8192,
514
+ min_new_token=params.min_new_token,
515
+ logits_processors=(logits_processors, logits_warpers),
516
+ eos_token=num_code,
517
+ infer_text=False,
518
+ start_idx=start_idx,
519
+ )
520
+ input_ids = [i.tolist() for i in input_ids]
521
+
522
+ result = gpt.llm.generate(
523
+ None,
524
+ sample_params,
525
+ input_ids,
526
+ )
527
+
528
+ token_ids = []
529
+ hidden_states = []
530
+ for i in result:
531
+ token_ids.append(torch.tensor(i.outputs[0].token_ids))
532
+ hidden_states.append(
533
+ i.outputs[0].hidden_states.to(torch.float32).to(self.device)
534
+ )
535
+
536
+ del text_mask, input_ids
537
+
538
+ return [
539
+ GPT.GenerationOutputs(
540
+ ids=token_ids,
541
+ hiddens=hidden_states,
542
+ attentions=[],
543
+ ),
544
+ ]
545
+
546
+ emb = self.embed(input_ids, text_mask)
547
+
548
+ del text_mask
549
+
550
+ if params.spk_emb is not None:
551
+ self.speaker.apply(
552
+ emb,
553
+ params.spk_emb,
554
+ input_ids,
555
+ self.tokenizer.spk_emb_ids,
556
+ self.gpt.device_gpt,
557
+ )
558
+
559
+ result = gpt.generate(
560
+ emb,
561
+ input_ids,
562
+ temperature=torch.tensor(temperature, device=device),
563
+ eos_token=num_code,
564
+ attention_mask=attention_mask,
565
+ max_new_token=params.max_new_token,
566
+ min_new_token=params.min_new_token,
567
+ logits_processors=(*logits_processors, *logits_warpers),
568
+ infer_text=False,
569
+ return_hidden=return_hidden,
570
+ stream=stream,
571
+ show_tqdm=params.show_tqdm,
572
+ ensure_non_empty=params.ensure_non_empty,
573
+ stream_batch=params.stream_batch,
574
+ manual_seed=params.manual_seed,
575
+ context=self.context,
576
+ )
577
+
578
+ del emb, input_ids
579
+
580
+ return result
581
+
582
+ @torch.no_grad()
583
+ def _refine_text(
584
+ self,
585
+ text: str,
586
+ device: torch.device,
587
+ params: RefineTextParams,
588
+ ):
589
+
590
+ gpt = self.gpt
591
+
592
+ if not isinstance(text, list):
593
+ text = [text]
594
+
595
+ input_ids, attention_mask, text_mask = self.tokenizer.encode(
596
+ self.speaker.decorate_text_prompts(text, params.prompt),
597
+ self.config.gpt.num_vq,
598
+ device=self.device_gpt,
599
+ )
600
+
601
+ logits_warpers, logits_processors = gen_logits(
602
+ num_code=self.tokenizer.len,
603
+ top_P=params.top_P,
604
+ top_K=params.top_K,
605
+ repetition_penalty=params.repetition_penalty,
606
+ )
607
+
608
+ if gpt.is_vllm:
609
+ from .model.velocity import SamplingParams
610
+
611
+ sample_params = SamplingParams(
612
+ repetition_penalty=params.repetition_penalty,
613
+ temperature=params.temperature,
614
+ top_p=params.top_P,
615
+ top_k=params.top_K,
616
+ max_new_token=params.max_new_token,
617
+ max_tokens=8192,
618
+ min_new_token=params.min_new_token,
619
+ logits_processors=(logits_processors, logits_warpers),
620
+ eos_token=self.tokenizer.eos_token,
621
+ infer_text=True,
622
+ start_idx=input_ids.shape[-2],
623
+ )
624
+ input_ids_list = [i.tolist() for i in input_ids]
625
+ del input_ids
626
+
627
+ result = gpt.llm.generate(
628
+ None, sample_params, input_ids_list, params.show_tqdm
629
+ )
630
+ token_ids = []
631
+ hidden_states = []
632
+ for i in result:
633
+ token_ids.append(torch.tensor(i.outputs[0].token_ids))
634
+ hidden_states.append(i.outputs[0].hidden_states)
635
+
636
+ del text_mask, input_ids_list, result
637
+
638
+ return GPT.GenerationOutputs(
639
+ ids=token_ids,
640
+ hiddens=hidden_states,
641
+ attentions=[],
642
+ )
643
+
644
+ emb = self.embed(input_ids, text_mask)
645
+
646
+ del text_mask
647
+
648
+ result = next(
649
+ gpt.generate(
650
+ emb,
651
+ input_ids,
652
+ temperature=torch.tensor([params.temperature], device=device),
653
+ eos_token=self.tokenizer.eos_token,
654
+ attention_mask=attention_mask,
655
+ max_new_token=params.max_new_token,
656
+ min_new_token=params.min_new_token,
657
+ logits_processors=(*logits_processors, *logits_warpers),
658
+ infer_text=True,
659
+ stream=False,
660
+ show_tqdm=params.show_tqdm,
661
+ ensure_non_empty=params.ensure_non_empty,
662
+ manual_seed=params.manual_seed,
663
+ context=self.context,
664
+ )
665
+ )
666
+
667
+ del emb, input_ids
668
+
669
+ return result
ChatTTS/model/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .dvae import DVAE
2
+ from .embed import Embed
3
+ from .gpt import GPT
4
+ from .processors import gen_logits
5
+ from .speaker import Speaker
6
+ from .tokenizer import Tokenizer
ChatTTS/model/cuda/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .te_llama import TELlamaModel
ChatTTS/model/cuda/patch.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ class LlamaRMSNorm(torch.nn.Module):
5
+ def __init__(self, hidden_size, eps=1e-6):
6
+ """
7
+ LlamaRMSNorm is equivalent to T5LayerNorm
8
+ """
9
+ super().__init__()
10
+ self.weight = torch.nn.Parameter(torch.ones(hidden_size))
11
+ self.variance_epsilon = eps
12
+
13
+ def forward(self, hidden_states: torch.Tensor):
14
+ input_dtype = hidden_states.dtype
15
+ hidden_states = hidden_states.to(torch.float32)
16
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
17
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
18
+ return self.weight.to(hidden_states.device) * hidden_states.to(input_dtype)
ChatTTS/model/cuda/te_llama.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # See LICENSE for license information.
4
+ #
5
+ # From https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/te_llama/te_llama.py
6
+ #
7
+ # Edited by fumiama.
8
+
9
+ import re
10
+ from contextlib import contextmanager
11
+ from typing import Dict
12
+
13
+ import transformer_engine as te
14
+ from transformer_engine.pytorch.attention import RotaryPositionEmbedding
15
+
16
+ import torch
17
+
18
+ import transformers
19
+ from transformers.models.llama.modeling_llama import (
20
+ LlamaModel,
21
+ LlamaConfig,
22
+ )
23
+ from transformers.modeling_utils import _load_state_dict_into_model
24
+
25
+ from .patch import LlamaRMSNorm
26
+
27
+
28
+ @contextmanager
29
+ def replace_decoder(te_decoder_cls, llama_rms_norm_cls):
30
+ """
31
+ Replace `LlamaDecoderLayer` with custom `TELlamaDecoderLayer`.
32
+ """
33
+ original_llama_decoder_cls = (
34
+ transformers.models.llama.modeling_llama.LlamaDecoderLayer
35
+ )
36
+ transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls
37
+ original_llama_rms_norm_cls = transformers.models.llama.modeling_llama.LlamaRMSNorm
38
+ transformers.models.llama.modeling_llama.LlamaRMSNorm = llama_rms_norm_cls
39
+ try:
40
+ yield
41
+ finally:
42
+ transformers.models.llama.modeling_llama.LlamaDecoderLayer = (
43
+ original_llama_decoder_cls
44
+ )
45
+ transformers.models.llama.modeling_llama.LlamaRMSNorm = (
46
+ original_llama_rms_norm_cls
47
+ )
48
+
49
+
50
+ class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
51
+ """
52
+ Wrapper class over TE's `TransformerLayer`. This makes the wrapper very
53
+ similar to HF's `LlamaDecoderLayer` and easier to replace it in the code.
54
+
55
+ Args:
56
+ config: LlamaConfig
57
+ args: positional args (for compatibility with `LlamaDecoderLayer`)
58
+ kwargs: keyword args (for compatibility with `LlamaDecoderLayer`)
59
+ """
60
+
61
+ def __init__(self, config, *args, **kwargs):
62
+ super().__init__(
63
+ hidden_size=config.hidden_size,
64
+ ffn_hidden_size=config.intermediate_size,
65
+ num_attention_heads=config.num_attention_heads,
66
+ bias=False,
67
+ layernorm_epsilon=config.rms_norm_eps,
68
+ hidden_dropout=0,
69
+ attention_dropout=0,
70
+ fuse_qkv_params=False,
71
+ normalization="RMSNorm",
72
+ activation="swiglu",
73
+ attn_input_format="bshd",
74
+ num_gqa_groups=config.num_key_value_heads,
75
+ )
76
+ te_rope = RotaryPositionEmbedding(
77
+ config.hidden_size // config.num_attention_heads
78
+ )
79
+ self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()
80
+
81
+ def forward(self, hidden_states, *args, attention_mask, **kwargs):
82
+ """
83
+ Custom forward to make sure we only pass relevant arguments to the
84
+ forward pass of the `TransformerLayer`. Also, make sure the output
85
+ format matches the output of the HF's `LlamaDecoderLayer`.
86
+ """
87
+ return (
88
+ super().forward(
89
+ hidden_states,
90
+ attention_mask=attention_mask,
91
+ rotary_pos_emb=self.te_rope_emb,
92
+ ),
93
+ )
94
+
95
+
96
+ class TELlamaModel:
97
+ """
98
+ LM created with `LlamaModel`. The underlying `LlamaDecoderLayer`
99
+ class is monkey-patched with `TELlamaDecoderLayer` class before
100
+ initializing the causal LM with `LlamaModel`.
101
+
102
+ Args:
103
+ config: LlamaConfig
104
+ """
105
+
106
+ def __new__(cls, config: LlamaConfig):
107
+ with replace_decoder(
108
+ te_decoder_cls=TELlamaDecoderLayer, llama_rms_norm_cls=LlamaRMSNorm
109
+ ):
110
+ model = LlamaModel(config)
111
+ return model
112
+
113
+ @classmethod
114
+ def from_state_dict(
115
+ cls,
116
+ state_dict: Dict[str, torch.Tensor],
117
+ config: LlamaConfig,
118
+ ):
119
+ """
120
+ Custom method adapted from `from_pretrained` method in HuggingFace
121
+ Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579
122
+ """
123
+
124
+ vanilla_model = cls(config)
125
+
126
+ # replace_params copies parameters relevant only to TransformerEngine
127
+ _replace_params(state_dict, vanilla_model.state_dict(), config)
128
+ # _load_state_dict_into_model copies parameters other than those in TransformerEngine
129
+ _load_state_dict_into_model(vanilla_model, state_dict, start_prefix="")
130
+
131
+ return vanilla_model
132
+
133
+
134
+ def _replace_params(hf_state_dict, te_state_dict, config):
135
+ # collect all layer prefixes to update
136
+ all_layer_prefixes = set()
137
+ for param_key in hf_state_dict.keys():
138
+ layer_prefix_pat = "model.layers.\d+."
139
+ m = re.match(layer_prefix_pat, param_key)
140
+ if m is not None:
141
+ all_layer_prefixes.add(m.group())
142
+
143
+ for layer_prefix in all_layer_prefixes:
144
+ # When loading weights into models with less number of layers, skip the
145
+ # copy if the corresponding layer doesn't exist in HF model
146
+ if layer_prefix + "input_layernorm.weight" in hf_state_dict:
147
+ te_state_dict[
148
+ layer_prefix + "self_attention.layernorm_qkv.layer_norm_weight"
149
+ ].data[:] = hf_state_dict[layer_prefix + "input_layernorm.weight"].data[:]
150
+
151
+ if layer_prefix + "self_attn.q_proj.weight" in hf_state_dict:
152
+ te_state_dict[
153
+ layer_prefix + "self_attention.layernorm_qkv.query_weight"
154
+ ].data[:] = hf_state_dict[layer_prefix + "self_attn.q_proj.weight"].data[:]
155
+
156
+ if layer_prefix + "self_attn.k_proj.weight" in hf_state_dict:
157
+ te_state_dict[
158
+ layer_prefix + "self_attention.layernorm_qkv.key_weight"
159
+ ].data[:] = hf_state_dict[layer_prefix + "self_attn.k_proj.weight"].data[:]
160
+
161
+ if layer_prefix + "self_attn.v_proj.weight" in hf_state_dict:
162
+ te_state_dict[
163
+ layer_prefix + "self_attention.layernorm_qkv.value_weight"
164
+ ].data[:] = hf_state_dict[layer_prefix + "self_attn.v_proj.weight"].data[:]
165
+
166
+ if layer_prefix + "self_attn.o_proj.weight" in hf_state_dict:
167
+ te_state_dict[layer_prefix + "self_attention.proj.weight"].data[:] = (
168
+ hf_state_dict[layer_prefix + "self_attn.o_proj.weight"].data[:]
169
+ )
170
+
171
+ if layer_prefix + "post_attention_layernorm.weight" in hf_state_dict:
172
+ te_state_dict[layer_prefix + "layernorm_mlp.layer_norm_weight"].data[:] = (
173
+ hf_state_dict[layer_prefix + "post_attention_layernorm.weight"].data[:]
174
+ )
175
+
176
+ # It may happen that gate_proj.weight and up_proj.weight will be in the different files, so we need to
177
+ # load them separately.
178
+ if layer_prefix + "mlp.gate_proj.weight" in hf_state_dict:
179
+ te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[
180
+ : config.intermediate_size
181
+ ] = hf_state_dict[layer_prefix + "mlp.gate_proj.weight"].data
182
+
183
+ if layer_prefix + "mlp.up_proj.weight" in hf_state_dict:
184
+ te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[
185
+ config.intermediate_size :
186
+ ] = hf_state_dict[layer_prefix + "mlp.up_proj.weight"].data
187
+
188
+ if layer_prefix + "mlp.down_proj.weight" in hf_state_dict:
189
+ te_state_dict[layer_prefix + "layernorm_mlp.fc2_weight"].data[:] = (
190
+ hf_state_dict[layer_prefix + "mlp.down_proj.weight"].data[:]
191
+ )
192
+ return all_layer_prefixes
ChatTTS/model/dvae.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List, Optional, Literal, Union
3
+
4
+ import numpy as np
5
+ import pybase16384 as b14
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torchaudio
10
+ from vector_quantize_pytorch import GroupedResidualFSQ
11
+
12
+
13
+ class ConvNeXtBlock(nn.Module):
14
+ def __init__(
15
+ self,
16
+ dim: int,
17
+ intermediate_dim: int,
18
+ kernel: int,
19
+ dilation: int,
20
+ layer_scale_init_value: float = 1e-6,
21
+ ):
22
+ # ConvNeXt Block copied from Vocos.
23
+ super().__init__()
24
+ self.dwconv = nn.Conv1d(
25
+ dim,
26
+ dim,
27
+ kernel_size=kernel,
28
+ padding=dilation * (kernel // 2),
29
+ dilation=dilation,
30
+ groups=dim,
31
+ ) # depthwise conv
32
+
33
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
34
+ self.pwconv1 = nn.Linear(
35
+ dim, intermediate_dim
36
+ ) # pointwise/1x1 convs, implemented with linear layers
37
+ self.act = nn.GELU()
38
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
39
+ self.gamma = (
40
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
41
+ if layer_scale_init_value > 0
42
+ else None
43
+ )
44
+
45
+ def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor:
46
+ residual = x
47
+
48
+ y = self.dwconv(x)
49
+ y.transpose_(1, 2) # (B, C, T) -> (B, T, C)
50
+ x = self.norm(y)
51
+ del y
52
+ y = self.pwconv1(x)
53
+ del x
54
+ x = self.act(y)
55
+ del y
56
+ y = self.pwconv2(x)
57
+ del x
58
+ if self.gamma is not None:
59
+ y *= self.gamma
60
+ y.transpose_(1, 2) # (B, T, C) -> (B, C, T)
61
+
62
+ x = y + residual
63
+ del y
64
+
65
+ return x
66
+
67
+
68
+ class GFSQ(nn.Module):
69
+
70
+ def __init__(
71
+ self, dim: int, levels: List[int], G: int, R: int, eps=1e-5, transpose=True
72
+ ):
73
+ super(GFSQ, self).__init__()
74
+ self.quantizer = GroupedResidualFSQ(
75
+ dim=dim,
76
+ levels=list(levels),
77
+ num_quantizers=R,
78
+ groups=G,
79
+ )
80
+ self.n_ind = math.prod(levels)
81
+ self.eps = eps
82
+ self.transpose = transpose
83
+ self.G = G
84
+ self.R = R
85
+
86
+ def _embed(self, x: torch.Tensor):
87
+ if self.transpose:
88
+ x = x.transpose(1, 2)
89
+ """
90
+ x = rearrange(
91
+ x, "b t (g r) -> g b t r", g = self.G, r = self.R,
92
+ )
93
+ """
94
+ x = x.view(x.size(0), x.size(1), self.G, self.R).permute(2, 0, 1, 3)
95
+ feat = self.quantizer.get_output_from_indices(x)
96
+ return feat.transpose_(1, 2) if self.transpose else feat
97
+
98
+ def __call__(self, x: torch.Tensor) -> torch.Tensor:
99
+ return super().__call__(x)
100
+
101
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
102
+ if self.transpose:
103
+ x.transpose_(1, 2)
104
+ # feat, ind = self.quantizer(x)
105
+ _, ind = self.quantizer(x)
106
+ """
107
+ ind = rearrange(
108
+ ind, "g b t r ->b t (g r)",
109
+ )
110
+ """
111
+ ind = ind.permute(1, 2, 0, 3).contiguous()
112
+ ind = ind.view(ind.size(0), ind.size(1), -1)
113
+ """
114
+ embed_onehot_tmp = F.one_hot(ind.long(), self.n_ind)
115
+ embed_onehot = embed_onehot_tmp.to(x.dtype)
116
+ del embed_onehot_tmp
117
+ e_mean = torch.mean(embed_onehot, dim=[0, 1])
118
+ # e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1)
119
+ torch.div(e_mean, (e_mean.sum(dim=1) + self.eps).unsqueeze(1), out=e_mean)
120
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1))
121
+
122
+ return
123
+ torch.zeros(perplexity.shape, dtype=x.dtype, device=x.device),
124
+ feat.transpose_(1, 2) if self.transpose else feat,
125
+ perplexity,
126
+ """
127
+ return ind.transpose_(1, 2) if self.transpose else ind
128
+
129
+
130
+ class DVAEDecoder(nn.Module):
131
+ def __init__(
132
+ self,
133
+ idim: int,
134
+ odim: int,
135
+ n_layer=12,
136
+ bn_dim=64,
137
+ hidden=256,
138
+ kernel=7,
139
+ dilation=2,
140
+ up=False,
141
+ ):
142
+ super().__init__()
143
+ self.up = up
144
+ self.conv_in = nn.Sequential(
145
+ nn.Conv1d(idim, bn_dim, 3, 1, 1),
146
+ nn.GELU(),
147
+ nn.Conv1d(bn_dim, hidden, 3, 1, 1),
148
+ )
149
+ self.decoder_block = nn.ModuleList(
150
+ [
151
+ ConvNeXtBlock(
152
+ hidden,
153
+ hidden * 4,
154
+ kernel,
155
+ dilation,
156
+ )
157
+ for _ in range(n_layer)
158
+ ]
159
+ )
160
+ self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False)
161
+
162
+ def forward(self, x: torch.Tensor, conditioning=None) -> torch.Tensor:
163
+ # B, C, T
164
+ y = self.conv_in(x)
165
+ del x
166
+ for f in self.decoder_block:
167
+ y = f(y, conditioning)
168
+
169
+ x = self.conv_out(y)
170
+ del y
171
+ return x
172
+
173
+
174
+ class MelSpectrogramFeatures(torch.nn.Module):
175
+ def __init__(
176
+ self,
177
+ sample_rate=24000,
178
+ n_fft=1024,
179
+ hop_length=256,
180
+ n_mels=100,
181
+ padding: Literal["center", "same"] = "center",
182
+ device: torch.device = torch.device("cpu"),
183
+ ):
184
+ super().__init__()
185
+ self.device = device
186
+ if padding not in ["center", "same"]:
187
+ raise ValueError("Padding must be 'center' or 'same'.")
188
+ self.padding = padding
189
+ self.mel_spec = torchaudio.transforms.MelSpectrogram(
190
+ sample_rate=sample_rate,
191
+ n_fft=n_fft,
192
+ hop_length=hop_length,
193
+ n_mels=n_mels,
194
+ center=padding == "center",
195
+ power=1,
196
+ )
197
+
198
+ def __call__(self, audio: torch.Tensor) -> torch.Tensor:
199
+ return super().__call__(audio)
200
+
201
+ def forward(self, audio: torch.Tensor) -> torch.Tensor:
202
+ audio = audio.to(self.device)
203
+ mel: torch.Tensor = self.mel_spec(audio)
204
+ features = torch.log(torch.clip(mel, min=1e-5))
205
+ return features
206
+
207
+
208
+ class DVAE(nn.Module):
209
+ def __init__(
210
+ self,
211
+ decoder_config: dict,
212
+ encoder_config: Optional[dict] = None,
213
+ vq_config: Optional[dict] = None,
214
+ dim=512,
215
+ coef: Optional[str] = None,
216
+ device: torch.device = torch.device("cpu"),
217
+ ):
218
+ super().__init__()
219
+ if coef is None:
220
+ coef = torch.rand(100)
221
+ else:
222
+ coef = torch.from_numpy(
223
+ np.frombuffer(b14.decode_from_string(coef), dtype=np.float32).copy()
224
+ )
225
+ self.register_buffer("coef", coef.unsqueeze(0).unsqueeze_(2))
226
+
227
+ if encoder_config is not None:
228
+ self.downsample_conv = nn.Sequential(
229
+ nn.Conv1d(100, dim, 3, 1, 1),
230
+ nn.GELU(),
231
+ nn.Conv1d(dim, dim, 4, 2, 1),
232
+ nn.GELU(),
233
+ )
234
+ self.preprocessor_mel = MelSpectrogramFeatures(device=device)
235
+ self.encoder: Optional[DVAEDecoder] = DVAEDecoder(**encoder_config)
236
+
237
+ self.decoder = DVAEDecoder(**decoder_config)
238
+ self.out_conv = nn.Conv1d(dim, 100, 3, 1, 1, bias=False)
239
+ if vq_config is not None:
240
+ self.vq_layer = GFSQ(**vq_config)
241
+ else:
242
+ self.vq_layer = None
243
+
244
+ def __repr__(self) -> str:
245
+ return b14.encode_to_string(
246
+ self.coef.cpu().numpy().astype(np.float32).tobytes()
247
+ )
248
+
249
+ def __call__(
250
+ self, inp: torch.Tensor, mode: Literal["encode", "decode"] = "decode"
251
+ ) -> torch.Tensor:
252
+ return super().__call__(inp, mode)
253
+
254
+ @torch.inference_mode()
255
+ def forward(
256
+ self, inp: torch.Tensor, mode: Literal["encode", "decode"] = "decode"
257
+ ) -> torch.Tensor:
258
+ if mode == "encode" and hasattr(self, "encoder") and self.vq_layer is not None:
259
+ mel = self.preprocessor_mel(inp)
260
+ x: torch.Tensor = self.downsample_conv(
261
+ torch.div(mel, self.coef.view(100, 1).expand(mel.shape), out=mel),
262
+ ).unsqueeze_(0)
263
+ del mel
264
+ x = self.encoder(x)
265
+ ind = self.vq_layer(x)
266
+ del x
267
+ return ind
268
+
269
+ if self.vq_layer is not None:
270
+ vq_feats = self.vq_layer._embed(inp)
271
+ else:
272
+ vq_feats = inp
273
+
274
+ vq_feats = (
275
+ vq_feats.view(
276
+ (vq_feats.size(0), 2, vq_feats.size(1) // 2, vq_feats.size(2)),
277
+ )
278
+ .permute(0, 2, 3, 1)
279
+ .flatten(2)
280
+ )
281
+
282
+ dec_out = self.out_conv(
283
+ self.decoder(
284
+ x=vq_feats,
285
+ ),
286
+ )
287
+
288
+ del vq_feats
289
+
290
+ return torch.mul(dec_out, self.coef, out=dec_out)
291
+
292
+ @torch.inference_mode()
293
+ def sample_audio(self, wav: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
294
+ if isinstance(wav, np.ndarray):
295
+ wav = torch.from_numpy(wav)
296
+ return self(wav, "encode").squeeze_(0)
ChatTTS/model/embed.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from safetensors.torch import safe_open
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn.utils.parametrizations import weight_norm
5
+
6
+
7
+ class Embed(nn.Module):
8
+ def __init__(
9
+ self, hidden_size: int, num_audio_tokens: int, num_text_tokens: int, num_vq=4
10
+ ):
11
+ super().__init__()
12
+
13
+ self.num_vq = num_vq
14
+ self.num_audio_tokens = num_audio_tokens
15
+
16
+ self.model_dim = hidden_size
17
+ self.emb_code = nn.ModuleList(
18
+ [nn.Embedding(num_audio_tokens, self.model_dim) for _ in range(num_vq)],
19
+ )
20
+ self.emb_text = nn.Embedding(num_text_tokens, self.model_dim)
21
+
22
+ self.head_text = weight_norm(
23
+ nn.Linear(self.model_dim, num_text_tokens, bias=False),
24
+ name="weight",
25
+ )
26
+ self.head_code = nn.ModuleList(
27
+ [
28
+ weight_norm(
29
+ nn.Linear(self.model_dim, num_audio_tokens, bias=False),
30
+ name="weight",
31
+ )
32
+ for _ in range(self.num_vq)
33
+ ],
34
+ )
35
+
36
+ @torch.inference_mode()
37
+ def from_pretrained(self, filename: str, device: torch.device):
38
+ state_dict_tensors = {}
39
+ with safe_open(filename, framework="pt") as f:
40
+ for k in f.keys():
41
+ state_dict_tensors[k] = f.get_tensor(k)
42
+ self.load_state_dict(state_dict_tensors)
43
+ self.to(device)
44
+
45
+ def __call__(
46
+ self, input_ids: torch.Tensor, text_mask: torch.Tensor
47
+ ) -> torch.Tensor:
48
+ """
49
+ get_emb
50
+ """
51
+ return super().__call__(input_ids, text_mask)
52
+
53
+ @torch.inference_mode()
54
+ def forward(self, input_ids: torch.Tensor, text_mask: torch.Tensor) -> torch.Tensor:
55
+ """
56
+ get_emb
57
+ """
58
+ device = next(self.parameters()).device
59
+ emb_text: torch.Tensor = self.emb_text(
60
+ input_ids[text_mask].narrow(1, 0, 1).squeeze_(1).to(device)
61
+ )
62
+
63
+ text_mask_inv = text_mask.logical_not().to(device)
64
+ masked_input_ids: torch.Tensor = input_ids[text_mask_inv].to(device)
65
+
66
+ emb_code = [
67
+ self.emb_code[i](masked_input_ids[:, i]) for i in range(self.num_vq)
68
+ ]
69
+ emb_code = torch.stack(emb_code, 2).sum(2)
70
+
71
+ emb = torch.zeros(
72
+ (input_ids.shape[:-1]) + (emb_text.shape[-1],),
73
+ device=emb_text.device,
74
+ dtype=emb_text.dtype,
75
+ )
76
+ emb[text_mask] = emb_text
77
+ emb[text_mask_inv] = emb_code.to(emb.dtype)
78
+
79
+ del emb_text, emb_code, text_mask_inv
80
+
81
+ return emb
ChatTTS/model/gpt.py ADDED
@@ -0,0 +1,613 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import platform
2
+ from dataclasses import dataclass
3
+ import logging
4
+ from typing import Union, List, Optional, Tuple, Callable
5
+ import gc
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torch.nn.utils.parametrize as P
11
+ from tqdm import tqdm
12
+ from transformers import LlamaModel, LlamaConfig
13
+ from transformers.cache_utils import Cache
14
+ from transformers.modeling_outputs import BaseModelOutputWithPast
15
+ from transformers.utils import is_flash_attn_2_available
16
+
17
+ from ..utils import del_all
18
+ from .embed import Embed
19
+
20
+
21
+ class GPT(nn.Module):
22
+ def __init__(
23
+ self,
24
+ gpt_config: dict,
25
+ embed: Embed,
26
+ use_flash_attn=False,
27
+ use_vllm=False,
28
+ device=torch.device("cpu"),
29
+ device_gpt=torch.device("cpu"),
30
+ logger=logging.getLogger(__name__),
31
+ ):
32
+ super().__init__()
33
+
34
+ self.logger = logger
35
+
36
+ self.device = device
37
+ self.device_gpt = device_gpt
38
+
39
+ self.generator = torch.Generator(device=device)
40
+
41
+ self.num_vq = int(gpt_config["num_vq"])
42
+ self.num_audio_tokens = int(gpt_config["num_audio_tokens"])
43
+ self.num_text_tokens = int(gpt_config["num_text_tokens"])
44
+
45
+ self.use_flash_attn = use_flash_attn
46
+ self.is_te_llama = False
47
+ self.is_vllm = use_vllm
48
+
49
+ if self.is_vllm:
50
+ return
51
+
52
+ self.llama_config = self._build_llama_config(gpt_config)
53
+
54
+ self.emb_code = [ec.__call__ for ec in embed.emb_code]
55
+ self.emb_text = embed.emb_text.__call__
56
+ self.head_text = embed.head_text.__call__
57
+ self.head_code = [hc.__call__ for hc in embed.head_code]
58
+
59
+ def from_pretrained(
60
+ self, gpt_folder: str, embed_file_path: str, experimental=False
61
+ ):
62
+ if self.is_vllm and platform.system().lower() == "linux":
63
+
64
+ from .velocity import LLM
65
+
66
+ self.llm = LLM(
67
+ model=gpt_folder,
68
+ num_audio_tokens=self.num_audio_tokens,
69
+ num_text_tokens=self.num_text_tokens,
70
+ post_model_path=embed_file_path,
71
+ )
72
+ self.logger.info("vLLM model loaded")
73
+ return
74
+
75
+ self.gpt: LlamaModel = LlamaModel.from_pretrained(gpt_folder).to(
76
+ self.device_gpt
77
+ )
78
+ del self.gpt.embed_tokens
79
+
80
+ if (
81
+ experimental
82
+ and "cuda" in str(self.device_gpt)
83
+ and platform.system().lower() == "linux"
84
+ ): # is TELlamaModel
85
+ try:
86
+ from .cuda import TELlamaModel
87
+
88
+ self.logger.warning(
89
+ "Linux with CUDA, try NVIDIA accelerated TELlamaModel because experimental is enabled"
90
+ )
91
+ state_dict = self.gpt.state_dict()
92
+ vanilla = TELlamaModel.from_state_dict(state_dict, self.llama_config)
93
+ # Force mem release. Taken from huggingface code
94
+ del state_dict, self.gpt
95
+ gc.collect()
96
+ self.gpt = vanilla
97
+ self.is_te_llama = True
98
+ except Exception as e:
99
+ self.logger.warning(
100
+ f"use default LlamaModel for importing TELlamaModel error: {e}"
101
+ )
102
+
103
+ class Context:
104
+ def __init__(self):
105
+ self._interrupt = False
106
+
107
+ def set(self, v: bool):
108
+ self._interrupt = v
109
+
110
+ def get(self) -> bool:
111
+ return self._interrupt
112
+
113
+ def _build_llama_config(
114
+ self,
115
+ config: dict,
116
+ ) -> Tuple[LlamaModel, LlamaConfig]:
117
+
118
+ if self.use_flash_attn and is_flash_attn_2_available():
119
+ llama_config = LlamaConfig(
120
+ **config,
121
+ attn_implementation="flash_attention_2",
122
+ )
123
+ self.logger.warning(
124
+ "enabling flash_attention_2 may make gpt be even slower"
125
+ )
126
+ else:
127
+ llama_config = LlamaConfig(**config)
128
+
129
+ return llama_config
130
+
131
+ def prepare(self, compile=False):
132
+ if self.use_flash_attn and is_flash_attn_2_available():
133
+ self.gpt = self.gpt.to(dtype=torch.float16)
134
+ if compile and not self.is_te_llama and not self.is_vllm:
135
+ try:
136
+ self.compile(backend="inductor", dynamic=True)
137
+ self.gpt.compile(backend="inductor", dynamic=True)
138
+ except RuntimeError as e:
139
+ self.logger.warning(f"compile failed: {e}. fallback to normal mode.")
140
+
141
+ @dataclass(repr=False, eq=False)
142
+ class _GenerationInputs:
143
+ position_ids: torch.Tensor
144
+ cache_position: torch.Tensor
145
+ use_cache: bool
146
+ input_ids: Optional[torch.Tensor] = None
147
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
148
+ attention_mask: Optional[torch.Tensor] = None
149
+ inputs_embeds: Optional[torch.Tensor] = None
150
+
151
+ def to(self, device: torch.device, dtype: torch.dtype):
152
+ if self.attention_mask is not None:
153
+ self.attention_mask = self.attention_mask.to(device, dtype=dtype)
154
+ if self.position_ids is not None:
155
+ self.position_ids = self.position_ids.to(device, dtype=dtype)
156
+ if self.inputs_embeds is not None:
157
+ self.inputs_embeds = self.inputs_embeds.to(device, dtype=dtype)
158
+ if self.cache_position is not None:
159
+ self.cache_position = self.cache_position.to(device, dtype=dtype)
160
+
161
+ @torch.no_grad()
162
+ def _prepare_generation_inputs(
163
+ self,
164
+ input_ids: torch.Tensor,
165
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
166
+ attention_mask: Optional[torch.Tensor] = None,
167
+ inputs_embeds: Optional[torch.Tensor] = None,
168
+ cache_position: Optional[torch.Tensor] = None,
169
+ position_ids: Optional[torch.Tensor] = None,
170
+ use_cache=True,
171
+ ) -> _GenerationInputs:
172
+ # With static cache, the `past_key_values` is None
173
+ # TODO joao: standardize interface for the different Cache classes and remove of this if
174
+ has_static_cache = False
175
+ if past_key_values is None:
176
+ if hasattr(self.gpt.layers[0], "self_attn"):
177
+ past_key_values = getattr(
178
+ self.gpt.layers[0].self_attn, "past_key_value", None
179
+ )
180
+ has_static_cache = past_key_values is not None
181
+
182
+ past_length = 0
183
+ if past_key_values is not None:
184
+ if isinstance(past_key_values, Cache):
185
+ past_length = (
186
+ int(cache_position[0])
187
+ if cache_position is not None
188
+ else past_key_values.get_seq_length()
189
+ )
190
+ max_cache_length = past_key_values.get_max_length()
191
+ cache_length = (
192
+ past_length
193
+ if max_cache_length is None
194
+ else min(max_cache_length, past_length)
195
+ )
196
+ # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
197
+ else:
198
+ cache_length = past_length = past_key_values[0][0].shape[2]
199
+ max_cache_length = None
200
+
201
+ # Keep only the unprocessed tokens:
202
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
203
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
204
+ # input)
205
+ if (
206
+ attention_mask is not None
207
+ and attention_mask.shape[1] > input_ids.shape[1]
208
+ ):
209
+ start = attention_mask.shape[1] - past_length
210
+ input_ids = input_ids.narrow(1, -start, start)
211
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
212
+ # input_ids based on the past_length.
213
+ elif past_length < input_ids.shape[1]:
214
+ input_ids = input_ids.narrow(
215
+ 1, past_length, input_ids.size(1) - past_length
216
+ )
217
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
218
+
219
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
220
+ if (
221
+ max_cache_length is not None
222
+ and attention_mask is not None
223
+ and cache_length + input_ids.shape[1] > max_cache_length
224
+ ):
225
+ attention_mask = attention_mask.narrow(
226
+ 1, -max_cache_length, max_cache_length
227
+ )
228
+
229
+ if attention_mask is not None and position_ids is None:
230
+ # create position_ids on the fly for batch generation
231
+ position_ids = attention_mask.long().cumsum(-1) - 1
232
+ position_ids.masked_fill_(attention_mask.eq(0), 1)
233
+ if past_key_values:
234
+ position_ids = position_ids.narrow(
235
+ 1, -input_ids.shape[1], input_ids.shape[1]
236
+ )
237
+
238
+ input_length = (
239
+ position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
240
+ )
241
+ if cache_position is None:
242
+ cache_position = torch.arange(
243
+ past_length, past_length + input_length, device=input_ids.device
244
+ )
245
+ else:
246
+ cache_position = cache_position.narrow(0, -input_length, input_length)
247
+
248
+ if has_static_cache:
249
+ past_key_values = None
250
+
251
+ model_inputs = self._GenerationInputs(
252
+ position_ids=position_ids,
253
+ cache_position=cache_position,
254
+ use_cache=use_cache,
255
+ )
256
+
257
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
258
+ if inputs_embeds is not None and past_key_values is None:
259
+ model_inputs.inputs_embeds = inputs_embeds
260
+ else:
261
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
262
+ # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
263
+ # TODO: use `next_tokens` directly instead.
264
+ model_inputs.input_ids = input_ids.contiguous()
265
+
266
+ model_inputs.past_key_values = past_key_values
267
+ model_inputs.attention_mask = attention_mask
268
+
269
+ return model_inputs
270
+
271
+ @dataclass(repr=False, eq=False)
272
+ class GenerationOutputs:
273
+ ids: List[torch.Tensor]
274
+ attentions: List[Optional[Tuple[torch.FloatTensor, ...]]]
275
+ hiddens: List[torch.Tensor]
276
+
277
+ def destroy(self):
278
+ del_all(self.ids)
279
+ del_all(self.attentions)
280
+ del_all(self.hiddens)
281
+
282
+ @torch.no_grad()
283
+ def _prepare_generation_outputs(
284
+ self,
285
+ inputs_ids: torch.Tensor,
286
+ start_idx: int,
287
+ end_idx: torch.Tensor,
288
+ attentions: List[Optional[Tuple[torch.FloatTensor, ...]]],
289
+ hiddens: List[torch.Tensor],
290
+ infer_text: bool,
291
+ ) -> GenerationOutputs:
292
+ inputs_ids = [
293
+ inputs_ids[idx].narrow(0, start_idx, i) for idx, i in enumerate(end_idx)
294
+ ]
295
+ if infer_text:
296
+ inputs_ids = [i.narrow(1, 0, 1).squeeze_(1) for i in inputs_ids]
297
+
298
+ if len(hiddens) > 0:
299
+ hiddens = torch.stack(hiddens, 1)
300
+ hiddens = [
301
+ hiddens[idx].narrow(0, 0, i) for idx, i in enumerate(end_idx.int())
302
+ ]
303
+
304
+ return self.GenerationOutputs(
305
+ ids=inputs_ids,
306
+ attentions=attentions,
307
+ hiddens=hiddens,
308
+ )
309
+
310
+ @torch.no_grad()
311
+ def generate(
312
+ self,
313
+ emb: torch.Tensor,
314
+ inputs_ids: torch.Tensor,
315
+ temperature: torch.Tensor,
316
+ eos_token: Union[int, torch.Tensor],
317
+ attention_mask: Optional[torch.Tensor] = None,
318
+ max_new_token=2048,
319
+ min_new_token=0,
320
+ logits_processors: Tuple[
321
+ Callable[[torch.LongTensor, torch.FloatTensor], torch.FloatTensor]
322
+ ] = (),
323
+ infer_text=False,
324
+ return_attn=False,
325
+ return_hidden=False,
326
+ stream=False,
327
+ show_tqdm=True,
328
+ ensure_non_empty=True,
329
+ stream_batch=24,
330
+ manual_seed: Optional[int] = None,
331
+ context=Context(),
332
+ ):
333
+
334
+ attentions: List[Optional[Tuple[torch.FloatTensor, ...]]] = []
335
+ hiddens = []
336
+ stream_iter = 0
337
+
338
+ start_idx, end_idx = inputs_ids.shape[1], torch.zeros(
339
+ inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long
340
+ )
341
+ finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool()
342
+
343
+ old_temperature = temperature
344
+
345
+ temperature = (
346
+ temperature.unsqueeze(0)
347
+ .expand(inputs_ids.shape[0], -1)
348
+ .contiguous()
349
+ .view(-1, 1)
350
+ )
351
+
352
+ attention_mask_cache = torch.ones(
353
+ (
354
+ inputs_ids.shape[0],
355
+ inputs_ids.shape[1] + max_new_token,
356
+ ),
357
+ dtype=torch.bool,
358
+ device=inputs_ids.device,
359
+ )
360
+ if attention_mask is not None:
361
+ attention_mask_cache.narrow(1, 0, attention_mask.shape[1]).copy_(
362
+ attention_mask
363
+ )
364
+
365
+ progress = inputs_ids.size(1)
366
+ # pre-allocate inputs_ids
367
+ inputs_ids_buf = torch.zeros(
368
+ inputs_ids.size(0),
369
+ progress + max_new_token,
370
+ inputs_ids.size(2),
371
+ dtype=inputs_ids.dtype,
372
+ device=inputs_ids.device,
373
+ )
374
+ inputs_ids_buf.narrow(1, 0, progress).copy_(inputs_ids)
375
+ del inputs_ids
376
+ inputs_ids = inputs_ids_buf.narrow(1, 0, progress)
377
+
378
+ pbar: Optional[tqdm] = None
379
+
380
+ if show_tqdm:
381
+ pbar = tqdm(
382
+ total=max_new_token,
383
+ desc="text" if infer_text else "code",
384
+ bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]",
385
+ )
386
+
387
+ past_key_values = None
388
+
389
+ for i in range(max_new_token):
390
+
391
+ model_input = self._prepare_generation_inputs(
392
+ inputs_ids,
393
+ past_key_values,
394
+ attention_mask_cache.narrow(1, 0, inputs_ids.shape[1]),
395
+ use_cache=not self.is_te_llama,
396
+ )
397
+
398
+ if i > 0:
399
+ del emb
400
+ inputs_ids_emb = model_input.input_ids.to(self.device_gpt)
401
+ if infer_text:
402
+ emb: torch.Tensor = self.emb_text(inputs_ids_emb[:, :, 0])
403
+ else:
404
+ code_emb = [
405
+ self.emb_code[i](inputs_ids_emb[:, :, i])
406
+ for i in range(self.num_vq)
407
+ ]
408
+ emb = torch.stack(code_emb, 3).sum(3)
409
+ del inputs_ids_emb, model_input.input_ids
410
+ model_input.inputs_embeds = emb
411
+
412
+ model_input.to(self.device_gpt, self.gpt.dtype)
413
+
414
+ outputs: BaseModelOutputWithPast = self.gpt(
415
+ attention_mask=model_input.attention_mask,
416
+ position_ids=model_input.position_ids,
417
+ past_key_values=model_input.past_key_values,
418
+ inputs_embeds=model_input.inputs_embeds,
419
+ use_cache=model_input.use_cache,
420
+ output_attentions=return_attn,
421
+ cache_position=model_input.cache_position,
422
+ )
423
+ del_all(model_input)
424
+ attentions.append(outputs.attentions)
425
+ hidden_states = outputs.last_hidden_state.to(
426
+ self.device, dtype=torch.float
427
+ ) # 🐻
428
+ past_key_values = outputs.past_key_values
429
+ del_all(outputs)
430
+ if return_hidden:
431
+ hiddens.append(hidden_states.narrow(1, -1, 1).squeeze_(1))
432
+
433
+ with P.cached():
434
+ if infer_text:
435
+ logits: torch.Tensor = self.head_text(hidden_states)
436
+ else:
437
+ # logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3)
438
+ logits = torch.empty(
439
+ hidden_states.size(0),
440
+ hidden_states.size(1),
441
+ self.num_audio_tokens,
442
+ self.num_vq,
443
+ dtype=torch.float,
444
+ device=self.device,
445
+ )
446
+ for num_vq_iter in range(self.num_vq):
447
+ x: torch.Tensor = self.head_code[num_vq_iter](hidden_states)
448
+ logits[..., num_vq_iter] = x
449
+ del x
450
+
451
+ del hidden_states
452
+
453
+ # logits = logits[:, -1].float()
454
+ logits = logits.narrow(1, -1, 1).squeeze_(1).float()
455
+
456
+ if not infer_text:
457
+ # logits = rearrange(logits, "b c n -> (b n) c")
458
+ logits = logits.permute(0, 2, 1)
459
+ logits = logits.reshape(-1, logits.size(2))
460
+ # logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c")
461
+ inputs_ids_sliced = inputs_ids.narrow(
462
+ 1,
463
+ start_idx,
464
+ inputs_ids.size(1) - start_idx,
465
+ ).permute(0, 2, 1)
466
+ logits_token = inputs_ids_sliced.reshape(
467
+ inputs_ids_sliced.size(0) * inputs_ids_sliced.size(1),
468
+ -1,
469
+ ).to(self.device)
470
+ del inputs_ids_sliced
471
+ else:
472
+ logits_token = (
473
+ inputs_ids.narrow(
474
+ 1,
475
+ start_idx,
476
+ inputs_ids.size(1) - start_idx,
477
+ )
478
+ .narrow(2, 0, 1)
479
+ .to(self.device)
480
+ )
481
+
482
+ logits /= temperature
483
+
484
+ for logitsProcessors in logits_processors:
485
+ logits = logitsProcessors(logits_token, logits)
486
+
487
+ del logits_token
488
+
489
+ if i < min_new_token:
490
+ logits[:, eos_token] = -torch.inf
491
+
492
+ scores = F.softmax(logits, dim=-1)
493
+
494
+ del logits
495
+
496
+ if manual_seed is None:
497
+ idx_next = torch.multinomial(scores, num_samples=1).to(finish.device)
498
+ else:
499
+ idx_next = torch.multinomial(
500
+ scores,
501
+ num_samples=1,
502
+ generator=self.generator.manual_seed(manual_seed),
503
+ ).to(finish.device)
504
+
505
+ del scores
506
+
507
+ if not infer_text:
508
+ # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
509
+ idx_next = idx_next.view(-1, self.num_vq)
510
+ finish_or = idx_next.eq(eos_token).any(1)
511
+ finish.logical_or_(finish_or)
512
+ del finish_or
513
+ inputs_ids_buf.narrow(1, progress, 1).copy_(idx_next.unsqueeze_(1))
514
+ else:
515
+ finish_or = idx_next.eq(eos_token).any(1)
516
+ finish.logical_or_(finish_or)
517
+ del finish_or
518
+ inputs_ids_buf.narrow(1, progress, 1).copy_(
519
+ idx_next.unsqueeze_(-1).expand(-1, -1, self.num_vq),
520
+ )
521
+
522
+ if i == 0 and finish.any():
523
+ self.logger.warning(
524
+ "unexpected end at index %s",
525
+ str([unexpected_idx.item() for unexpected_idx in finish.nonzero()]),
526
+ )
527
+ if ensure_non_empty and manual_seed is None:
528
+ if show_tqdm:
529
+ pbar.close()
530
+ self.logger.warning("regenerate in order to ensure non-empty")
531
+ del_all(attentions)
532
+ del_all(hiddens)
533
+ del (
534
+ start_idx,
535
+ end_idx,
536
+ finish,
537
+ temperature,
538
+ attention_mask_cache,
539
+ past_key_values,
540
+ idx_next,
541
+ inputs_ids_buf,
542
+ )
543
+ new_gen = self.generate(
544
+ emb,
545
+ inputs_ids,
546
+ old_temperature,
547
+ eos_token,
548
+ attention_mask,
549
+ max_new_token,
550
+ min_new_token,
551
+ logits_processors,
552
+ infer_text,
553
+ return_attn,
554
+ return_hidden,
555
+ stream,
556
+ show_tqdm,
557
+ ensure_non_empty,
558
+ stream_batch,
559
+ manual_seed,
560
+ context,
561
+ )
562
+ for result in new_gen:
563
+ yield result
564
+ del inputs_ids
565
+ return
566
+
567
+ del idx_next
568
+ progress += 1
569
+ inputs_ids = inputs_ids_buf.narrow(1, 0, progress)
570
+
571
+ not_finished = finish.logical_not().to(end_idx.device)
572
+ end_idx.add_(not_finished.int())
573
+ stream_iter += not_finished.any().int()
574
+ if stream:
575
+ if stream_iter > 0 and stream_iter % stream_batch == 0:
576
+ self.logger.debug("yield stream result, end: %d", end_idx)
577
+ yield self._prepare_generation_outputs(
578
+ inputs_ids,
579
+ start_idx,
580
+ end_idx,
581
+ attentions,
582
+ hiddens,
583
+ infer_text,
584
+ )
585
+ del not_finished
586
+
587
+ if finish.all() or context.get():
588
+ break
589
+
590
+ if pbar is not None:
591
+ pbar.update(1)
592
+
593
+ if pbar is not None:
594
+ pbar.close()
595
+
596
+ if not finish.all():
597
+ if context.get():
598
+ self.logger.warning("generation is interrupted")
599
+ else:
600
+ self.logger.warning(
601
+ f"incomplete result. hit max_new_token: {max_new_token}"
602
+ )
603
+
604
+ del finish, inputs_ids_buf
605
+
606
+ yield self._prepare_generation_outputs(
607
+ inputs_ids,
608
+ start_idx,
609
+ end_idx,
610
+ attentions,
611
+ hiddens,
612
+ infer_text,
613
+ )
ChatTTS/model/processors.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
4
+
5
+
6
+ class CustomRepetitionPenaltyLogitsProcessorRepeat:
7
+
8
+ def __init__(self, penalty: float, max_input_ids: int, past_window: int):
9
+ if not isinstance(penalty, float) or not (penalty > 0):
10
+ raise ValueError(
11
+ f"`penalty` has to be a strictly positive float, but is {penalty}"
12
+ )
13
+
14
+ self.penalty = penalty
15
+ self.max_input_ids = max_input_ids
16
+ self.past_window = past_window
17
+
18
+ def __call__(
19
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor
20
+ ) -> torch.FloatTensor:
21
+ if input_ids.size(1) > self.past_window:
22
+ input_ids = input_ids.narrow(1, -self.past_window, self.past_window)
23
+ freq = F.one_hot(input_ids, scores.size(1)).sum(1)
24
+ if freq.size(0) > self.max_input_ids:
25
+ freq.narrow(
26
+ 0, self.max_input_ids, freq.size(0) - self.max_input_ids
27
+ ).zero_()
28
+ alpha = torch.pow(self.penalty, freq)
29
+ scores = scores.contiguous()
30
+ inp = scores.multiply(alpha)
31
+ oth = scores.divide(alpha)
32
+ con = scores < 0
33
+ out = torch.where(con, inp, oth)
34
+ del inp, oth, scores, con, alpha
35
+ return out
36
+
37
+
38
+ def gen_logits(
39
+ num_code: int,
40
+ top_P=0.7,
41
+ top_K=20,
42
+ repetition_penalty=1.0,
43
+ ):
44
+ logits_warpers = []
45
+ if top_P is not None:
46
+ logits_warpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
47
+ if top_K is not None:
48
+ logits_warpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))
49
+
50
+ logits_processors = []
51
+ if repetition_penalty is not None and repetition_penalty != 1:
52
+ logits_processors.append(
53
+ CustomRepetitionPenaltyLogitsProcessorRepeat(
54
+ repetition_penalty, num_code, 16
55
+ )
56
+ )
57
+
58
+ return logits_warpers, logits_processors
ChatTTS/model/speaker.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import lzma
2
+ from typing import List, Optional, Union
3
+
4
+ import pybase16384 as b14
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+
10
+ class Speaker:
11
+ def __init__(self, dim: int, spk_cfg: str, device=torch.device("cpu")) -> None:
12
+ spk_stat = torch.from_numpy(
13
+ np.frombuffer(b14.decode_from_string(spk_cfg), dtype=np.float16).copy()
14
+ ).to(device=device)
15
+ self.std, self.mean = spk_stat.requires_grad_(False).chunk(2)
16
+ self.dim = dim
17
+
18
+ def sample_random(self) -> str:
19
+ return self._encode(self._sample_random())
20
+
21
+ @torch.inference_mode()
22
+ def apply(
23
+ self,
24
+ emb: torch.Tensor,
25
+ spk_emb: Union[str, torch.Tensor],
26
+ input_ids: torch.Tensor,
27
+ spk_emb_ids: int,
28
+ device: torch.device,
29
+ inplace: bool = True,
30
+ ) -> torch.Tensor:
31
+ if isinstance(spk_emb, str):
32
+ spk_emb_tensor = torch.from_numpy(self._decode(spk_emb))
33
+ else:
34
+ spk_emb_tensor = spk_emb
35
+ n = (
36
+ F.normalize(
37
+ spk_emb_tensor,
38
+ p=2.0,
39
+ dim=0,
40
+ eps=1e-12,
41
+ )
42
+ .to(device)
43
+ .unsqueeze_(0)
44
+ .expand(emb.size(0), -1)
45
+ .unsqueeze_(1)
46
+ .expand(emb.shape)
47
+ )
48
+ cond = input_ids.narrow(-1, 0, 1).eq(spk_emb_ids).expand(emb.shape)
49
+ out = torch.where(cond, n, emb, out=emb if inplace else None)
50
+ if inplace:
51
+ del cond, n
52
+ return out
53
+
54
+ @staticmethod
55
+ @torch.no_grad()
56
+ def decorate_code_prompts(
57
+ text: List[str],
58
+ prompt: str,
59
+ txt_smp: Optional[str],
60
+ spk_emb: Optional[str],
61
+ ) -> List[str]:
62
+ for i, t in enumerate(text):
63
+ text[i] = (
64
+ t.replace("[Stts]", "")
65
+ .replace("[spk_emb]", "")
66
+ .replace("[empty_spk]", "")
67
+ .strip()
68
+ )
69
+ """
70
+ see https://github.com/2noise/ChatTTS/issues/459
71
+ """
72
+
73
+ if prompt:
74
+ text = [prompt + i for i in text]
75
+
76
+ txt_smp = "" if txt_smp is None else txt_smp
77
+ if spk_emb is not None:
78
+ text = [f"[Stts][spk_emb]{txt_smp}{i}[Ptts]" for i in text]
79
+ else:
80
+ text = [f"[Stts][empty_spk]{txt_smp}{i}[Ptts]" for i in text]
81
+
82
+ return text
83
+
84
+ @staticmethod
85
+ @torch.no_grad()
86
+ def decorate_text_prompts(text: List[str], prompt: str) -> List[str]:
87
+ return [f"[Sbreak]{i}[Pbreak]{prompt}" for i in text]
88
+
89
+ @staticmethod
90
+ @torch.no_grad()
91
+ def encode_prompt(prompt: torch.Tensor) -> str:
92
+ arr: np.ndarray = prompt.cpu().numpy().astype(np.uint16)
93
+ shp = arr.shape
94
+ assert len(shp) == 2, "prompt must be a 2D tensor"
95
+ s = b14.encode_to_string(
96
+ np.array(shp, dtype="<u2").tobytes()
97
+ + lzma.compress(
98
+ arr.astype("<u2").tobytes(),
99
+ format=lzma.FORMAT_RAW,
100
+ filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
101
+ ),
102
+ )
103
+ del arr
104
+ return s
105
+
106
+ @staticmethod
107
+ @torch.no_grad()
108
+ def decode_prompt(prompt: str) -> torch.Tensor:
109
+ dec = b14.decode_from_string(prompt)
110
+ shp = np.frombuffer(dec[:4], dtype="<u2")
111
+ p = np.frombuffer(
112
+ lzma.decompress(
113
+ dec[4:],
114
+ format=lzma.FORMAT_RAW,
115
+ filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
116
+ ),
117
+ dtype="<u2",
118
+ ).copy()
119
+ del dec
120
+ return torch.from_numpy(p.astype(np.int32)).view(*shp)
121
+
122
+ @torch.no_grad()
123
+ def _sample_random(self) -> torch.Tensor:
124
+ spk = (
125
+ torch.randn(self.dim, device=self.std.device, dtype=self.std.dtype)
126
+ .mul_(self.std)
127
+ .add_(self.mean)
128
+ )
129
+ return spk
130
+
131
+ @staticmethod
132
+ @torch.no_grad()
133
+ def _encode(spk_emb: torch.Tensor) -> str:
134
+ arr: np.ndarray = spk_emb.to(dtype=torch.float16, device="cpu").numpy()
135
+ s = b14.encode_to_string(
136
+ lzma.compress(
137
+ arr.tobytes(),
138
+ format=lzma.FORMAT_RAW,
139
+ filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
140
+ ),
141
+ )
142
+ del arr
143
+ return s
144
+
145
+ @staticmethod
146
+ def _decode(spk_emb: str) -> np.ndarray:
147
+ return np.frombuffer(
148
+ lzma.decompress(
149
+ b14.decode_from_string(spk_emb),
150
+ format=lzma.FORMAT_RAW,
151
+ filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
152
+ ),
153
+ dtype=np.float16,
154
+ ).copy()
ChatTTS/model/tokenizer.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
4
+ """
5
+ https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning
6
+ """
7
+
8
+ from typing import List, Tuple, Optional, Union
9
+
10
+ import torch
11
+ from transformers import BertTokenizerFast
12
+
13
+ from ..utils import del_all
14
+
15
+
16
+ class Tokenizer:
17
+ def __init__(
18
+ self,
19
+ tokenizer_path: torch.serialization.FILE_LIKE,
20
+ ):
21
+ """
22
+ tokenizer: BertTokenizerFast = torch.load(
23
+ tokenizer_path, map_location=device, mmap=True
24
+ )
25
+ # tokenizer.save_pretrained("asset/tokenizer", legacy_format=False)
26
+ """
27
+ tokenizer: BertTokenizerFast = BertTokenizerFast.from_pretrained(tokenizer_path)
28
+ self._tokenizer = tokenizer
29
+
30
+ self.len = len(tokenizer)
31
+ self.spk_emb_ids = tokenizer.convert_tokens_to_ids("[spk_emb]")
32
+ self.break_0_ids = tokenizer.convert_tokens_to_ids("[break_0]")
33
+ self.eos_token = tokenizer.convert_tokens_to_ids("[Ebreak]")
34
+
35
+ @torch.inference_mode()
36
+ def encode(
37
+ self,
38
+ text: List[str],
39
+ num_vq: int,
40
+ prompt: Optional[torch.Tensor] = None,
41
+ device="cpu",
42
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
43
+
44
+ input_ids_lst = []
45
+ attention_mask_lst = []
46
+ max_input_ids_len = -1
47
+ max_attention_mask_len = -1
48
+ prompt_size = 0
49
+
50
+ if prompt is not None:
51
+ assert prompt.size(0) == num_vq, "prompt dim 0 must equal to num_vq"
52
+ prompt_size = prompt.size(1)
53
+
54
+ # avoid random speaker embedding of tokenizer in the other dims
55
+ for t in text:
56
+ x = self._tokenizer.encode_plus(
57
+ t, return_tensors="pt", add_special_tokens=False, padding=True
58
+ )
59
+ input_ids_lst.append(x["input_ids"].squeeze_(0))
60
+ attention_mask_lst.append(x["attention_mask"].squeeze_(0))
61
+ del_all(x)
62
+ ids_sz = input_ids_lst[-1].size(0)
63
+ if ids_sz > max_input_ids_len:
64
+ max_input_ids_len = ids_sz
65
+ attn_sz = attention_mask_lst[-1].size(0)
66
+ if attn_sz > max_attention_mask_len:
67
+ max_attention_mask_len = attn_sz
68
+
69
+ if prompt is not None:
70
+ max_input_ids_len += prompt_size
71
+ max_attention_mask_len += prompt_size
72
+
73
+ input_ids = torch.zeros(
74
+ len(input_ids_lst),
75
+ max_input_ids_len,
76
+ device=device,
77
+ dtype=input_ids_lst[0].dtype,
78
+ )
79
+ for i in range(len(input_ids_lst)):
80
+ input_ids.narrow(0, i, 1).narrow(
81
+ 1,
82
+ max_input_ids_len - prompt_size - input_ids_lst[i].size(0),
83
+ input_ids_lst[i].size(0),
84
+ ).copy_(
85
+ input_ids_lst[i]
86
+ ) # left padding
87
+ del_all(input_ids_lst)
88
+
89
+ attention_mask = torch.zeros(
90
+ len(attention_mask_lst),
91
+ max_attention_mask_len,
92
+ device=device,
93
+ dtype=attention_mask_lst[0].dtype,
94
+ )
95
+ for i in range(len(attention_mask_lst)):
96
+ attn = attention_mask.narrow(0, i, 1)
97
+ attn.narrow(
98
+ 1,
99
+ max_attention_mask_len - prompt_size - attention_mask_lst[i].size(0),
100
+ attention_mask_lst[i].size(0),
101
+ ).copy_(
102
+ attention_mask_lst[i]
103
+ ) # left padding
104
+ if prompt_size > 0:
105
+ attn.narrow(
106
+ 1,
107
+ max_attention_mask_len - prompt_size,
108
+ prompt_size,
109
+ ).fill_(1)
110
+ del_all(attention_mask_lst)
111
+
112
+ text_mask = attention_mask.bool()
113
+ new_input_ids = input_ids.unsqueeze_(-1).expand(-1, -1, num_vq).clone()
114
+ del input_ids
115
+
116
+ if prompt_size > 0:
117
+ text_mask.narrow(1, max_input_ids_len - prompt_size, prompt_size).fill_(0)
118
+ prompt_t = prompt.t().unsqueeze_(0).expand(new_input_ids.size(0), -1, -1)
119
+ new_input_ids.narrow(
120
+ 1,
121
+ max_input_ids_len - prompt_size,
122
+ prompt_size,
123
+ ).copy_(prompt_t)
124
+ del prompt_t
125
+
126
+ return new_input_ids, attention_mask, text_mask
127
+
128
+ @torch.inference_mode
129
+ def decode(
130
+ self,
131
+ sequences: Union[List[int], List[List[int]]],
132
+ skip_special_tokens: bool = False,
133
+ clean_up_tokenization_spaces: bool = None,
134
+ **kwargs,
135
+ ):
136
+ return self._tokenizer.batch_decode(
137
+ sequences, skip_special_tokens, clean_up_tokenization_spaces, **kwargs
138
+ )
ChatTTS/model/velocity/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .llm import LLM
2
+ from .sampling_params import SamplingParams
ChatTTS/model/velocity/block_manager.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A block manager that manages token blocks."""
2
+
3
+ import enum
4
+ from typing import Dict, List, Optional, Set, Tuple
5
+
6
+ from vllm.block import PhysicalTokenBlock
7
+ from .sequence import Sequence, SequenceGroup, SequenceStatus
8
+ from vllm.utils import Device
9
+
10
+ # Mapping: logical block number -> physical block.
11
+ BlockTable = List[PhysicalTokenBlock]
12
+
13
+
14
+ class BlockAllocator:
15
+ """Manages free physical token blocks for a device.
16
+
17
+ The allocator maintains a list of free blocks and allocates a block when
18
+ requested. When a block is freed, its reference count is decremented. If
19
+ the reference count becomes zero, the block is added back to the free list.
20
+ """
21
+
22
+ def __init__(
23
+ self,
24
+ device: Device,
25
+ block_size: int,
26
+ num_blocks: int,
27
+ ) -> None:
28
+ self.device = device
29
+ self.block_size = block_size
30
+ self.num_blocks = num_blocks
31
+
32
+ # Initialize the free blocks.
33
+ self.free_blocks: BlockTable = []
34
+ for i in range(num_blocks):
35
+ block = PhysicalTokenBlock(
36
+ device=device, block_number=i, block_size=block_size
37
+ )
38
+ self.free_blocks.append(block)
39
+
40
+ def allocate(self) -> PhysicalTokenBlock:
41
+ if not self.free_blocks:
42
+ raise ValueError("Out of memory! No free blocks are available.")
43
+ block = self.free_blocks.pop()
44
+ block.ref_count = 1
45
+ return block
46
+
47
+ def free(self, block: PhysicalTokenBlock) -> None:
48
+ if block.ref_count == 0:
49
+ raise ValueError(f"Double free! {block} is already freed.")
50
+ block.ref_count -= 1
51
+ if block.ref_count == 0:
52
+ self.free_blocks.append(block)
53
+
54
+ def get_num_free_blocks(self) -> int:
55
+ return len(self.free_blocks)
56
+
57
+
58
+ class AllocStatus(enum.Enum):
59
+ """Result for BlockSpaceManager.can_allocate
60
+
61
+ 1. Ok: seq_group can be allocated now.
62
+ 2. Later: seq_group cannot be allocated.
63
+ The capacity of allocator is larger than seq_group required.
64
+ 3. Never: seq_group can never be allocated.
65
+ The seq_group is too large to allocated in GPU.
66
+ """
67
+
68
+ OK = enum.auto()
69
+ LATER = enum.auto()
70
+ NEVER = enum.auto()
71
+
72
+
73
+ class BlockSpaceManager:
74
+ """Manages the mapping between logical and physical token blocks."""
75
+
76
+ def __init__(
77
+ self,
78
+ block_size: int,
79
+ num_gpu_blocks: int,
80
+ num_cpu_blocks: int,
81
+ watermark: float = 0.01,
82
+ sliding_window: Optional[int] = None,
83
+ ) -> None:
84
+ self.block_size = block_size
85
+ self.num_total_gpu_blocks = num_gpu_blocks
86
+ self.num_total_cpu_blocks = num_cpu_blocks
87
+
88
+ self.block_sliding_window = None
89
+ if sliding_window is not None:
90
+ assert sliding_window % block_size == 0, (sliding_window, block_size)
91
+ self.block_sliding_window = sliding_window // block_size
92
+
93
+ self.watermark = watermark
94
+ assert watermark >= 0.0
95
+
96
+ self.watermark_blocks = int(watermark * num_gpu_blocks)
97
+ self.gpu_allocator = BlockAllocator(Device.GPU, block_size, num_gpu_blocks)
98
+ self.cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks)
99
+ # Mapping: seq_id -> BlockTable.
100
+ self.block_tables: Dict[int, BlockTable] = {}
101
+
102
+ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
103
+ # FIXME(woosuk): Here we assume that all sequences in the group share
104
+ # the same prompt. This may not be true for preempted sequences.
105
+ seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
106
+ num_required_blocks = len(seq.logical_token_blocks)
107
+ if self.block_sliding_window is not None:
108
+ num_required_blocks = min(num_required_blocks, self.block_sliding_window)
109
+ num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
110
+
111
+ # Use watermark to avoid frequent cache eviction.
112
+ if self.num_total_gpu_blocks - num_required_blocks < self.watermark_blocks:
113
+ return AllocStatus.NEVER
114
+ if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
115
+ return AllocStatus.OK
116
+ else:
117
+ return AllocStatus.LATER
118
+
119
+ def allocate(self, seq_group: SequenceGroup) -> None:
120
+ # NOTE: Here we assume that all sequences in the group have the same
121
+ # prompt.
122
+ seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
123
+
124
+ # Allocate new physical token blocks that will store the prompt tokens.
125
+ block_table: BlockTable = []
126
+ for logical_idx in range(len(seq.logical_token_blocks)):
127
+ if (
128
+ self.block_sliding_window is not None
129
+ and logical_idx >= self.block_sliding_window
130
+ ):
131
+ block = block_table[logical_idx % self.block_sliding_window]
132
+ else:
133
+ block = self.gpu_allocator.allocate()
134
+ # Set the reference counts of the token blocks.
135
+ block.ref_count = seq_group.num_seqs()
136
+ block_table.append(block)
137
+
138
+ # Assign the block table for each sequence.
139
+ for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
140
+ self.block_tables[seq.seq_id] = block_table.copy()
141
+
142
+ def can_append_slot(self, seq_group: SequenceGroup) -> bool:
143
+ # Simple heuristic: If there is at least one free block
144
+ # for each sequence, we can append.
145
+ num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
146
+ num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
147
+ return num_seqs <= num_free_gpu_blocks
148
+
149
+ def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]:
150
+ """Allocate a physical slot for a new token."""
151
+ logical_blocks = seq.logical_token_blocks
152
+ block_table = self.block_tables[seq.seq_id]
153
+
154
+ if len(block_table) < len(logical_blocks):
155
+ if (
156
+ self.block_sliding_window
157
+ and len(block_table) >= self.block_sliding_window
158
+ ):
159
+ # re-use a block
160
+ block_table.append(
161
+ block_table[len(block_table) % self.block_sliding_window]
162
+ )
163
+ else:
164
+ # The sequence has a new logical block.
165
+ # Allocate a new physical block.
166
+ block = self.gpu_allocator.allocate()
167
+ block_table.append(block)
168
+ return None
169
+
170
+ # We want to append the token to the last physical block.
171
+ last_block = block_table[-1]
172
+ assert last_block.device == Device.GPU
173
+ if last_block.ref_count == 1:
174
+ # Not shared with other sequences. Appendable.
175
+ return None
176
+ else:
177
+ # The last block is shared with other sequences.
178
+ # Copy on Write: Allocate a new block and copy the tokens.
179
+ new_block = self.gpu_allocator.allocate()
180
+ block_table[-1] = new_block
181
+ self.gpu_allocator.free(last_block)
182
+ return last_block.block_number, new_block.block_number
183
+
184
+ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
185
+ # NOTE: fork does not allocate a new physical block.
186
+ # Thus, it is always safe from OOM.
187
+ src_block_table = self.block_tables[parent_seq.seq_id]
188
+ self.block_tables[child_seq.seq_id] = src_block_table.copy()
189
+ for block in src_block_table:
190
+ block.ref_count += 1
191
+
192
+ def _get_physical_blocks(
193
+ self, seq_group: SequenceGroup
194
+ ) -> List[PhysicalTokenBlock]:
195
+ # NOTE: Here, we assume that the physical blocks are only shared by
196
+ # the sequences in the same group.
197
+ blocks: Set[PhysicalTokenBlock] = set()
198
+ for seq in seq_group.get_seqs():
199
+ if seq.is_finished():
200
+ continue
201
+ blocks.update(self.block_tables[seq.seq_id])
202
+ return list(blocks)
203
+
204
+ def can_swap_in(self, seq_group: SequenceGroup) -> bool:
205
+ blocks = self._get_physical_blocks(seq_group)
206
+ num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
207
+ num_free_blocks = self.gpu_allocator.get_num_free_blocks()
208
+ # NOTE: Conservatively, we assume that every sequence will allocate
209
+ # at least one free block right after the swap-in.
210
+ # NOTE: This should match the logic in can_append_slot().
211
+ num_required_blocks = len(blocks) + num_swapped_seqs
212
+ return num_free_blocks - num_required_blocks >= self.watermark_blocks
213
+
214
+ def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
215
+ # CPU block -> GPU block.
216
+ mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
217
+ for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
218
+ new_block_table: BlockTable = []
219
+ block_table = self.block_tables[seq.seq_id]
220
+
221
+ for cpu_block in block_table:
222
+ if cpu_block in mapping:
223
+ gpu_block = mapping[cpu_block]
224
+ gpu_block.ref_count += 1
225
+ else:
226
+ gpu_block = self.gpu_allocator.allocate()
227
+ mapping[cpu_block] = gpu_block
228
+ new_block_table.append(gpu_block)
229
+ # Free the CPU block swapped in to GPU.
230
+ self.cpu_allocator.free(cpu_block)
231
+ self.block_tables[seq.seq_id] = new_block_table
232
+
233
+ block_number_mapping = {
234
+ cpu_block.block_number: gpu_block.block_number
235
+ for cpu_block, gpu_block in mapping.items()
236
+ }
237
+ return block_number_mapping
238
+
239
+ def can_swap_out(self, seq_group: SequenceGroup) -> bool:
240
+ blocks = self._get_physical_blocks(seq_group)
241
+ return len(blocks) <= self.cpu_allocator.get_num_free_blocks()
242
+
243
+ def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
244
+ # GPU block -> CPU block.
245
+ mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
246
+ for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
247
+ new_block_table: BlockTable = []
248
+ block_table = self.block_tables[seq.seq_id]
249
+
250
+ for gpu_block in block_table:
251
+ if gpu_block in mapping:
252
+ cpu_block = mapping[gpu_block]
253
+ cpu_block.ref_count += 1
254
+ else:
255
+ cpu_block = self.cpu_allocator.allocate()
256
+ mapping[gpu_block] = cpu_block
257
+ new_block_table.append(cpu_block)
258
+ # Free the GPU block swapped out to CPU.
259
+ self.gpu_allocator.free(gpu_block)
260
+ self.block_tables[seq.seq_id] = new_block_table
261
+
262
+ block_number_mapping = {
263
+ gpu_block.block_number: cpu_block.block_number
264
+ for gpu_block, cpu_block in mapping.items()
265
+ }
266
+ return block_number_mapping
267
+
268
+ def _free_block_table(self, block_table: BlockTable) -> None:
269
+ for block in set(block_table):
270
+ if block.device == Device.GPU:
271
+ self.gpu_allocator.free(block)
272
+ else:
273
+ self.cpu_allocator.free(block)
274
+
275
+ def free(self, seq: Sequence) -> None:
276
+ if seq.seq_id not in self.block_tables:
277
+ # Already freed or haven't been scheduled yet.
278
+ return
279
+ block_table = self.block_tables[seq.seq_id]
280
+ self._free_block_table(block_table)
281
+ del self.block_tables[seq.seq_id]
282
+
283
+ def reset(self) -> None:
284
+ for block_table in self.block_tables.values():
285
+ self._free_block_table(block_table)
286
+ self.block_tables.clear()
287
+
288
+ def get_block_table(self, seq: Sequence) -> List[int]:
289
+ block_table = self.block_tables[seq.seq_id]
290
+ return [block.block_number for block in block_table]
291
+
292
+ def get_num_free_gpu_blocks(self) -> int:
293
+ return self.gpu_allocator.get_num_free_blocks()
294
+
295
+ def get_num_free_cpu_blocks(self) -> int:
296
+ return self.cpu_allocator.get_num_free_blocks()
ChatTTS/model/velocity/configs.py ADDED
@@ -0,0 +1,865 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Union, Tuple
2
+ import os
3
+
4
+ import torch
5
+ from transformers import PretrainedConfig
6
+
7
+ from vllm.logger import init_logger
8
+ from vllm.transformers_utils.config import get_config
9
+ from vllm.utils import get_cpu_memory, is_hip
10
+
11
+ import argparse
12
+ import dataclasses
13
+ from dataclasses import dataclass
14
+
15
+
16
+ logger = init_logger(__name__)
17
+
18
+ _GB = 1 << 30
19
+
20
+
21
+ class ModelConfig:
22
+ """Configuration for the model.
23
+
24
+ Args:
25
+ model: Name or path of the huggingface model to use.
26
+ tokenizer: Name or path of the huggingface tokenizer to use.
27
+ tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
28
+ available, and "slow" will always use the slow tokenizer.
29
+ trust_remote_code: Trust remote code (e.g., from HuggingFace) when
30
+ downloading the model and tokenizer.
31
+ download_dir: Directory to download and load the weights, default to the
32
+ default cache directory of huggingface.
33
+ load_format: The format of the model weights to load:
34
+ "auto" will try to load the weights in the safetensors format and
35
+ fall back to the pytorch bin format if safetensors format is
36
+ not available.
37
+ "pt" will load the weights in the pytorch bin format.
38
+ "safetensors" will load the weights in the safetensors format.
39
+ "npcache" will load the weights in pytorch format and store
40
+ a numpy cache to speed up the loading.
41
+ "dummy" will initialize the weights with random values, which is
42
+ mainly for profiling.
43
+ dtype: Data type for model weights and activations. The "auto" option
44
+ will use FP16 precision for FP32 and FP16 models, and BF16 precision
45
+ for BF16 models.
46
+ seed: Random seed for reproducibility.
47
+ revision: The specific model version to use. It can be a branch name,
48
+ a tag name, or a commit id. If unspecified, will use the default
49
+ version.
50
+ tokenizer_revision: The specific tokenizer version to use. It can be a
51
+ branch name, a tag name, or a commit id. If unspecified, will use
52
+ the default version.
53
+ max_model_len: Maximum length of a sequence (including prompt and
54
+ output). If None, will be derived from the model.
55
+ quantization: Quantization method that was used to quantize the model
56
+ weights. If None, we assume the model weights are not quantized.
57
+ enforce_eager: Whether to enforce eager execution. If True, we will
58
+ disable CUDA graph and always execute the model in eager mode.
59
+ If False, we will use CUDA graph and eager execution in hybrid.
60
+ max_context_len_to_capture: Maximum context len covered by CUDA graphs.
61
+ When a sequence has context length larger than this, we fall back
62
+ to eager mode.
63
+ """
64
+
65
+ def __init__(
66
+ self,
67
+ model: str,
68
+ tokenizer: str,
69
+ tokenizer_mode: str,
70
+ trust_remote_code: bool,
71
+ download_dir: Optional[str],
72
+ load_format: str,
73
+ dtype: Union[str, torch.dtype],
74
+ seed: int,
75
+ revision: Optional[str] = None,
76
+ tokenizer_revision: Optional[str] = None,
77
+ max_model_len: Optional[int] = None,
78
+ quantization: Optional[str] = None,
79
+ enforce_eager: bool = False,
80
+ max_context_len_to_capture: Optional[int] = None,
81
+ num_audio_tokens: int = 1024,
82
+ num_text_tokens: int = 80,
83
+ ) -> None:
84
+ self.model = model
85
+ self.tokenizer = tokenizer
86
+ self.tokenizer_mode = tokenizer_mode
87
+ self.trust_remote_code = trust_remote_code
88
+ self.download_dir = download_dir
89
+ self.load_format = load_format
90
+ self.seed = seed
91
+ self.revision = revision
92
+ self.tokenizer_revision = tokenizer_revision
93
+ self.quantization = quantization
94
+ self.enforce_eager = enforce_eager
95
+ self.max_context_len_to_capture = max_context_len_to_capture
96
+ self.num_audio_tokens = num_audio_tokens
97
+ self.num_text_tokens = num_text_tokens
98
+
99
+ if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true":
100
+ # download model from ModelScope hub,
101
+ # lazy import so that modelscope is not required for normal use.
102
+ from modelscope.hub.snapshot_download import (
103
+ snapshot_download,
104
+ ) # pylint: disable=C
105
+
106
+ model_path = snapshot_download(
107
+ model_id=model, cache_dir=download_dir, revision=revision
108
+ )
109
+ self.model = model_path
110
+ self.download_dir = model_path
111
+ self.tokenizer = model_path
112
+
113
+ self.hf_config = get_config(self.model, trust_remote_code, revision)
114
+ self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
115
+ self.max_model_len = _get_and_verify_max_len(self.hf_config, max_model_len)
116
+ self._verify_load_format()
117
+ self._verify_tokenizer_mode()
118
+ self._verify_quantization()
119
+ self._verify_cuda_graph()
120
+
121
+ def _verify_load_format(self) -> None:
122
+ load_format = self.load_format.lower()
123
+ supported_load_format = ["auto", "pt", "safetensors", "npcache", "dummy"]
124
+ rocm_not_supported_load_format = []
125
+ if load_format not in supported_load_format:
126
+ raise ValueError(
127
+ f"Unknown load format: {self.load_format}. Must be one of "
128
+ "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'."
129
+ )
130
+ if is_hip() and load_format in rocm_not_supported_load_format:
131
+ rocm_supported_load_format = [
132
+ f
133
+ for f in supported_load_format
134
+ if (f not in rocm_not_supported_load_format)
135
+ ]
136
+ raise ValueError(
137
+ f"load format '{load_format}' is not supported in ROCm. "
138
+ f"Supported load format are "
139
+ f"{rocm_supported_load_format}"
140
+ )
141
+
142
+ # TODO: Remove this check once HF updates the pt weights of Mixtral.
143
+ architectures = getattr(self.hf_config, "architectures", [])
144
+ if "MixtralForCausalLM" in architectures and load_format == "pt":
145
+ raise ValueError(
146
+ "Currently, the 'pt' format is not supported for Mixtral. "
147
+ "Please use the 'safetensors' format instead. "
148
+ )
149
+ self.load_format = load_format
150
+
151
+ def _verify_tokenizer_mode(self) -> None:
152
+ tokenizer_mode = self.tokenizer_mode.lower()
153
+ if tokenizer_mode not in ["auto", "slow"]:
154
+ raise ValueError(
155
+ f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
156
+ "either 'auto' or 'slow'."
157
+ )
158
+ self.tokenizer_mode = tokenizer_mode
159
+
160
+ def _verify_quantization(self) -> None:
161
+ supported_quantization = ["awq", "gptq", "squeezellm"]
162
+ rocm_not_supported_quantization = ["awq"]
163
+ if self.quantization is not None:
164
+ self.quantization = self.quantization.lower()
165
+
166
+ # Parse quantization method from the HF model config, if available.
167
+ hf_quant_config = getattr(self.hf_config, "quantization_config", None)
168
+ if hf_quant_config is not None:
169
+ hf_quant_method = str(hf_quant_config["quant_method"]).lower()
170
+ if self.quantization is None:
171
+ self.quantization = hf_quant_method
172
+ elif self.quantization != hf_quant_method:
173
+ raise ValueError(
174
+ "Quantization method specified in the model config "
175
+ f"({hf_quant_method}) does not match the quantization "
176
+ f"method specified in the `quantization` argument "
177
+ f"({self.quantization})."
178
+ )
179
+
180
+ if self.quantization is not None:
181
+ if self.quantization not in supported_quantization:
182
+ raise ValueError(
183
+ f"Unknown quantization method: {self.quantization}. Must "
184
+ f"be one of {supported_quantization}."
185
+ )
186
+ if is_hip() and self.quantization in rocm_not_supported_quantization:
187
+ raise ValueError(
188
+ f"{self.quantization} quantization is currently not supported "
189
+ f"in ROCm."
190
+ )
191
+ logger.warning(
192
+ f"{self.quantization} quantization is not fully "
193
+ "optimized yet. The speed can be slower than "
194
+ "non-quantized models."
195
+ )
196
+
197
+ def _verify_cuda_graph(self) -> None:
198
+ if self.max_context_len_to_capture is None:
199
+ self.max_context_len_to_capture = self.max_model_len
200
+ self.max_context_len_to_capture = min(
201
+ self.max_context_len_to_capture, self.max_model_len
202
+ )
203
+
204
+ def verify_with_parallel_config(
205
+ self,
206
+ parallel_config: "ParallelConfig",
207
+ ) -> None:
208
+ total_num_attention_heads = self.hf_config.num_attention_heads
209
+ tensor_parallel_size = parallel_config.tensor_parallel_size
210
+ if total_num_attention_heads % tensor_parallel_size != 0:
211
+ raise ValueError(
212
+ f"Total number of attention heads ({total_num_attention_heads})"
213
+ " must be divisible by tensor parallel size "
214
+ f"({tensor_parallel_size})."
215
+ )
216
+
217
+ total_num_hidden_layers = self.hf_config.num_hidden_layers
218
+ pipeline_parallel_size = parallel_config.pipeline_parallel_size
219
+ if total_num_hidden_layers % pipeline_parallel_size != 0:
220
+ raise ValueError(
221
+ f"Total number of hidden layers ({total_num_hidden_layers}) "
222
+ "must be divisible by pipeline parallel size "
223
+ f"({pipeline_parallel_size})."
224
+ )
225
+
226
+ def get_sliding_window(self) -> Optional[int]:
227
+ return getattr(self.hf_config, "sliding_window", None)
228
+
229
+ def get_vocab_size(self) -> int:
230
+ return self.hf_config.vocab_size
231
+
232
+ def get_hidden_size(self) -> int:
233
+ return self.hf_config.hidden_size
234
+
235
+ def get_head_size(self) -> int:
236
+ # FIXME(woosuk): This may not be true for all models.
237
+ return self.hf_config.hidden_size // self.hf_config.num_attention_heads
238
+
239
+ def get_total_num_kv_heads(self) -> int:
240
+ """Returns the total number of KV heads."""
241
+ # For GPTBigCode & Falcon:
242
+ # NOTE: for falcon, when new_decoder_architecture is True, the
243
+ # multi_query flag is ignored and we use n_head_kv for the number of
244
+ # KV heads.
245
+ falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
246
+ new_decoder_arch_falcon = (
247
+ self.hf_config.model_type in falcon_model_types
248
+ and getattr(self.hf_config, "new_decoder_architecture", False)
249
+ )
250
+ if not new_decoder_arch_falcon and getattr(
251
+ self.hf_config, "multi_query", False
252
+ ):
253
+ # Multi-query attention, only one KV head.
254
+ # Currently, tensor parallelism is not supported in this case.
255
+ return 1
256
+
257
+ attributes = [
258
+ # For Falcon:
259
+ "n_head_kv",
260
+ "num_kv_heads",
261
+ # For LLaMA-2:
262
+ "num_key_value_heads",
263
+ # For ChatGLM:
264
+ "multi_query_group_num",
265
+ ]
266
+ for attr in attributes:
267
+ num_kv_heads = getattr(self.hf_config, attr, None)
268
+ if num_kv_heads is not None:
269
+ return num_kv_heads
270
+
271
+ # For non-grouped-query attention models, the number of KV heads is
272
+ # equal to the number of attention heads.
273
+ return self.hf_config.num_attention_heads
274
+
275
+ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
276
+ """Returns the number of KV heads per GPU."""
277
+ total_num_kv_heads = self.get_total_num_kv_heads()
278
+ # If tensor parallelism is used, we divide the number of KV heads by
279
+ # the tensor parallel size. We will replicate the KV heads in the
280
+ # case where the number of KV heads is smaller than the tensor
281
+ # parallel size so each GPU has at least one KV head.
282
+ return max(1, total_num_kv_heads // parallel_config.tensor_parallel_size)
283
+
284
+ def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
285
+ total_num_hidden_layers = self.hf_config.num_hidden_layers
286
+ return total_num_hidden_layers // parallel_config.pipeline_parallel_size
287
+
288
+
289
+ class CacheConfig:
290
+ """Configuration for the KV cache.
291
+
292
+ Args:
293
+ block_size: Size of a cache block in number of tokens.
294
+ gpu_memory_utilization: Fraction of GPU memory to use for the
295
+ vLLM execution.
296
+ swap_space: Size of the CPU swap space per GPU (in GiB).
297
+ """
298
+
299
+ def __init__(
300
+ self,
301
+ block_size: int,
302
+ gpu_memory_utilization: float,
303
+ swap_space: int,
304
+ sliding_window: Optional[int] = None,
305
+ ) -> None:
306
+ self.block_size = block_size
307
+ self.gpu_memory_utilization = gpu_memory_utilization
308
+ self.swap_space_bytes = swap_space * _GB
309
+ self.sliding_window = sliding_window
310
+ self._verify_args()
311
+
312
+ # Will be set after profiling.
313
+ self.num_gpu_blocks = None
314
+ self.num_cpu_blocks = None
315
+
316
+ def _verify_args(self) -> None:
317
+ if self.gpu_memory_utilization > 1.0:
318
+ raise ValueError(
319
+ "GPU memory utilization must be less than 1.0. Got "
320
+ f"{self.gpu_memory_utilization}."
321
+ )
322
+
323
+ def verify_with_parallel_config(
324
+ self,
325
+ parallel_config: "ParallelConfig",
326
+ ) -> None:
327
+ total_cpu_memory = get_cpu_memory()
328
+ # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
329
+ # group are in the same node. However, the GPUs may span multiple nodes.
330
+ num_gpus_per_node = parallel_config.tensor_parallel_size
331
+ cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
332
+
333
+ msg = (
334
+ f"{cpu_memory_usage / _GB:.2f} GiB out of "
335
+ f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
336
+ "allocated for the swap space."
337
+ )
338
+ if cpu_memory_usage > 0.7 * total_cpu_memory:
339
+ raise ValueError("Too large swap space. " + msg)
340
+ elif cpu_memory_usage > 0.4 * total_cpu_memory:
341
+ logger.warning("Possibly too large swap space. " + msg)
342
+
343
+
344
+ class ParallelConfig:
345
+ """Configuration for the distributed execution.
346
+
347
+ Args:
348
+ pipeline_parallel_size: Number of pipeline parallel groups.
349
+ tensor_parallel_size: Number of tensor parallel groups.
350
+ worker_use_ray: Whether to use Ray for model workers. Will be set to
351
+ True if either pipeline_parallel_size or tensor_parallel_size is
352
+ greater than 1.
353
+ """
354
+
355
+ def __init__(
356
+ self,
357
+ pipeline_parallel_size: int,
358
+ tensor_parallel_size: int,
359
+ worker_use_ray: bool,
360
+ max_parallel_loading_workers: Optional[int] = None,
361
+ ) -> None:
362
+ self.pipeline_parallel_size = pipeline_parallel_size
363
+ self.tensor_parallel_size = tensor_parallel_size
364
+ self.worker_use_ray = worker_use_ray
365
+ self.max_parallel_loading_workers = max_parallel_loading_workers
366
+
367
+ self.world_size = pipeline_parallel_size * tensor_parallel_size
368
+ if self.world_size > 1:
369
+ self.worker_use_ray = True
370
+ self._verify_args()
371
+
372
+ def _verify_args(self) -> None:
373
+ if self.pipeline_parallel_size > 1:
374
+ raise NotImplementedError("Pipeline parallelism is not supported yet.")
375
+
376
+
377
+ class SchedulerConfig:
378
+ """Scheduler configuration.
379
+
380
+ Args:
381
+ max_num_batched_tokens: Maximum number of tokens to be processed in
382
+ a single iteration.
383
+ max_num_seqs: Maximum number of sequences to be processed in a single
384
+ iteration.
385
+ max_model_len: Maximum length of a sequence (including prompt
386
+ and generated text).
387
+ max_paddings: Maximum number of paddings to be added to a batch.
388
+ """
389
+
390
+ def __init__(
391
+ self,
392
+ max_num_batched_tokens: Optional[int],
393
+ max_num_seqs: int,
394
+ max_model_len: int,
395
+ max_paddings: int,
396
+ ) -> None:
397
+ if max_num_batched_tokens is not None:
398
+ self.max_num_batched_tokens = max_num_batched_tokens
399
+ else:
400
+ # If max_model_len is too short, use 2048 as the default value for
401
+ # higher throughput.
402
+ self.max_num_batched_tokens = max(max_model_len, 2048)
403
+ self.max_num_seqs = max_num_seqs
404
+ self.max_model_len = max_model_len
405
+ self.max_paddings = max_paddings
406
+ self._verify_args()
407
+
408
+ def _verify_args(self) -> None:
409
+ if self.max_num_batched_tokens < self.max_model_len:
410
+ raise ValueError(
411
+ f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
412
+ f"smaller than max_model_len ({self.max_model_len}). "
413
+ "This effectively limits the maximum sequence length to "
414
+ "max_num_batched_tokens and makes vLLM reject longer "
415
+ "sequences. Please increase max_num_batched_tokens or "
416
+ "decrease max_model_len."
417
+ )
418
+ if self.max_num_batched_tokens < self.max_num_seqs:
419
+ raise ValueError(
420
+ f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
421
+ "be greater than or equal to max_num_seqs "
422
+ f"({self.max_num_seqs})."
423
+ )
424
+
425
+
426
+ _STR_DTYPE_TO_TORCH_DTYPE = {
427
+ "half": torch.float16,
428
+ "float16": torch.float16,
429
+ "float": torch.float32,
430
+ "float32": torch.float32,
431
+ "bfloat16": torch.bfloat16,
432
+ }
433
+
434
+ _ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]
435
+
436
+
437
+ def _get_and_verify_dtype(
438
+ config: PretrainedConfig,
439
+ dtype: Union[str, torch.dtype],
440
+ ) -> torch.dtype:
441
+ # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
442
+ # because config.torch_dtype can be None.
443
+ config_dtype = getattr(config, "torch_dtype", None)
444
+ if config_dtype is None:
445
+ config_dtype = torch.float32
446
+
447
+ if isinstance(dtype, str):
448
+ dtype = dtype.lower()
449
+ if dtype == "auto":
450
+ if config_dtype == torch.float32:
451
+ # Following the common practice, we use float16 for float32
452
+ # models.
453
+ torch_dtype = torch.float16
454
+ else:
455
+ torch_dtype = config_dtype
456
+ else:
457
+ if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
458
+ raise ValueError(f"Unknown dtype: {dtype}")
459
+ torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
460
+ elif isinstance(dtype, torch.dtype):
461
+ torch_dtype = dtype
462
+ else:
463
+ raise ValueError(f"Unknown dtype: {dtype}")
464
+
465
+ if is_hip() and torch_dtype == torch.float32:
466
+ rocm_supported_dtypes = [
467
+ k
468
+ for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
469
+ if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
470
+ ]
471
+ raise ValueError(
472
+ f"dtype '{dtype}' is not supported in ROCm. "
473
+ f"Supported dtypes are {rocm_supported_dtypes}"
474
+ )
475
+
476
+ # Verify the dtype.
477
+ if torch_dtype != config_dtype:
478
+ if torch_dtype == torch.float32:
479
+ # Upcasting to float32 is allowed.
480
+ pass
481
+ elif config_dtype == torch.float32:
482
+ # Downcasting from float32 to float16 or bfloat16 is allowed.
483
+ pass
484
+ else:
485
+ # Casting between float16 and bfloat16 is allowed with a warning.
486
+ logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
487
+
488
+ return torch_dtype
489
+
490
+
491
+ def _get_and_verify_max_len(
492
+ hf_config: PretrainedConfig,
493
+ max_model_len: Optional[int],
494
+ ) -> int:
495
+ """Get and verify the model's maximum length."""
496
+ derived_max_model_len = float("inf")
497
+ possible_keys = [
498
+ # OPT
499
+ "max_position_embeddings",
500
+ # GPT-2
501
+ "n_positions",
502
+ # MPT
503
+ "max_seq_len",
504
+ # ChatGLM2
505
+ "seq_length",
506
+ # Others
507
+ "max_sequence_length",
508
+ "max_seq_length",
509
+ "seq_len",
510
+ ]
511
+ for key in possible_keys:
512
+ max_len_key = getattr(hf_config, key, None)
513
+ if max_len_key is not None:
514
+ derived_max_model_len = min(derived_max_model_len, max_len_key)
515
+ if derived_max_model_len == float("inf"):
516
+ if max_model_len is not None:
517
+ # If max_model_len is specified, we use it.
518
+ return max_model_len
519
+
520
+ default_max_len = 2048
521
+ logger.warning(
522
+ "The model's config.json does not contain any of the following "
523
+ "keys to determine the original maximum length of the model: "
524
+ f"{possible_keys}. Assuming the model's maximum length is "
525
+ f"{default_max_len}."
526
+ )
527
+ derived_max_model_len = default_max_len
528
+
529
+ rope_scaling = getattr(hf_config, "rope_scaling", None)
530
+ if rope_scaling is not None:
531
+ assert "factor" in rope_scaling
532
+ scaling_factor = rope_scaling["factor"]
533
+ if rope_scaling["type"] == "yarn":
534
+ derived_max_model_len = rope_scaling["original_max_position_embeddings"]
535
+ derived_max_model_len *= scaling_factor
536
+
537
+ if max_model_len is None:
538
+ max_model_len = derived_max_model_len
539
+ elif max_model_len > derived_max_model_len:
540
+ raise ValueError(
541
+ f"User-specified max_model_len ({max_model_len}) is greater than "
542
+ f"the derived max_model_len ({max_len_key}={derived_max_model_len}"
543
+ " in model's config.json). This may lead to incorrect model "
544
+ "outputs or CUDA errors. Make sure the value is correct and "
545
+ "within the model context size."
546
+ )
547
+ return int(max_model_len)
548
+
549
+
550
+ @dataclass
551
+ class EngineArgs:
552
+ """Arguments for vLLM engine."""
553
+
554
+ model: str
555
+ tokenizer: Optional[str] = None
556
+ tokenizer_mode: str = "auto"
557
+ trust_remote_code: bool = False
558
+ download_dir: Optional[str] = None
559
+ load_format: str = "auto"
560
+ dtype: str = "auto"
561
+ seed: int = 0
562
+ max_model_len: Optional[int] = None
563
+ worker_use_ray: bool = False
564
+ pipeline_parallel_size: int = 1
565
+ tensor_parallel_size: int = 1
566
+ max_parallel_loading_workers: Optional[int] = None
567
+ block_size: int = 16
568
+ swap_space: int = 4 # GiB
569
+ gpu_memory_utilization: float = 0.90
570
+ max_num_batched_tokens: Optional[int] = None
571
+ max_num_seqs: int = 256
572
+ max_paddings: int = 256
573
+ disable_log_stats: bool = False
574
+ revision: Optional[str] = None
575
+ tokenizer_revision: Optional[str] = None
576
+ quantization: Optional[str] = None
577
+ enforce_eager: bool = False
578
+ max_context_len_to_capture: int = 8192
579
+ num_audio_tokens: int = 1024
580
+ num_text_tokens: int = 80
581
+
582
+ def __post_init__(self):
583
+ if self.tokenizer is None:
584
+ self.tokenizer = self.model
585
+
586
+ @staticmethod
587
+ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
588
+ """Shared CLI arguments for vLLM engine."""
589
+
590
+ # NOTE: If you update any of the arguments below, please also
591
+ # make sure to update docs/source/models/engine_args.rst
592
+
593
+ # Model arguments
594
+ parser.add_argument(
595
+ "--model",
596
+ type=str,
597
+ default="facebook/opt-125m",
598
+ help="name or path of the huggingface model to use",
599
+ )
600
+ parser.add_argument(
601
+ "--tokenizer",
602
+ type=str,
603
+ default=EngineArgs.tokenizer,
604
+ help="name or path of the huggingface tokenizer to use",
605
+ )
606
+ parser.add_argument(
607
+ "--revision",
608
+ type=str,
609
+ default=None,
610
+ help="the specific model version to use. It can be a branch "
611
+ "name, a tag name, or a commit id. If unspecified, will use "
612
+ "the default version.",
613
+ )
614
+ parser.add_argument(
615
+ "--tokenizer-revision",
616
+ type=str,
617
+ default=None,
618
+ help="the specific tokenizer version to use. It can be a branch "
619
+ "name, a tag name, or a commit id. If unspecified, will use "
620
+ "the default version.",
621
+ )
622
+ parser.add_argument(
623
+ "--tokenizer-mode",
624
+ type=str,
625
+ default=EngineArgs.tokenizer_mode,
626
+ choices=["auto", "slow"],
627
+ help='tokenizer mode. "auto" will use the fast '
628
+ 'tokenizer if available, and "slow" will '
629
+ "always use the slow tokenizer.",
630
+ )
631
+ parser.add_argument(
632
+ "--trust-remote-code",
633
+ action="store_true",
634
+ help="trust remote code from huggingface",
635
+ )
636
+ parser.add_argument(
637
+ "--download-dir",
638
+ type=str,
639
+ default=EngineArgs.download_dir,
640
+ help="directory to download and load the weights, "
641
+ "default to the default cache dir of "
642
+ "huggingface",
643
+ )
644
+ parser.add_argument(
645
+ "--load-format",
646
+ type=str,
647
+ default=EngineArgs.load_format,
648
+ choices=["auto", "pt", "safetensors", "npcache", "dummy"],
649
+ help="The format of the model weights to load. "
650
+ '"auto" will try to load the weights in the safetensors format '
651
+ "and fall back to the pytorch bin format if safetensors format "
652
+ "is not available. "
653
+ '"pt" will load the weights in the pytorch bin format. '
654
+ '"safetensors" will load the weights in the safetensors format. '
655
+ '"npcache" will load the weights in pytorch format and store '
656
+ "a numpy cache to speed up the loading. "
657
+ '"dummy" will initialize the weights with random values, '
658
+ "which is mainly for profiling.",
659
+ )
660
+ parser.add_argument(
661
+ "--dtype",
662
+ type=str,
663
+ default=EngineArgs.dtype,
664
+ choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
665
+ help="data type for model weights and activations. "
666
+ 'The "auto" option will use FP16 precision '
667
+ "for FP32 and FP16 models, and BF16 precision "
668
+ "for BF16 models.",
669
+ )
670
+ parser.add_argument(
671
+ "--max-model-len",
672
+ type=int,
673
+ default=None,
674
+ help="model context length. If unspecified, "
675
+ "will be automatically derived from the model.",
676
+ )
677
+ # Parallel arguments
678
+ parser.add_argument(
679
+ "--worker-use-ray",
680
+ action="store_true",
681
+ help="use Ray for distributed serving, will be "
682
+ "automatically set when using more than 1 GPU",
683
+ )
684
+ parser.add_argument(
685
+ "--pipeline-parallel-size",
686
+ "-pp",
687
+ type=int,
688
+ default=EngineArgs.pipeline_parallel_size,
689
+ help="number of pipeline stages",
690
+ )
691
+ parser.add_argument(
692
+ "--tensor-parallel-size",
693
+ "-tp",
694
+ type=int,
695
+ default=EngineArgs.tensor_parallel_size,
696
+ help="number of tensor parallel replicas",
697
+ )
698
+ parser.add_argument(
699
+ "--max-parallel-loading-workers",
700
+ type=int,
701
+ help="load model sequentially in multiple batches, "
702
+ "to avoid RAM OOM when using tensor "
703
+ "parallel and large models",
704
+ )
705
+ # KV cache arguments
706
+ parser.add_argument(
707
+ "--block-size",
708
+ type=int,
709
+ default=EngineArgs.block_size,
710
+ choices=[8, 16, 32],
711
+ help="token block size",
712
+ )
713
+ # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
714
+ parser.add_argument(
715
+ "--seed", type=int, default=EngineArgs.seed, help="random seed"
716
+ )
717
+ parser.add_argument(
718
+ "--swap-space",
719
+ type=int,
720
+ default=EngineArgs.swap_space,
721
+ help="CPU swap space size (GiB) per GPU",
722
+ )
723
+ parser.add_argument(
724
+ "--gpu-memory-utilization",
725
+ type=float,
726
+ default=EngineArgs.gpu_memory_utilization,
727
+ help="the fraction of GPU memory to be used for "
728
+ "the model executor, which can range from 0 to 1."
729
+ "If unspecified, will use the default value of 0.9.",
730
+ )
731
+ parser.add_argument(
732
+ "--max-num-batched-tokens",
733
+ type=int,
734
+ default=EngineArgs.max_num_batched_tokens,
735
+ help="maximum number of batched tokens per " "iteration",
736
+ )
737
+ parser.add_argument(
738
+ "--max-num-seqs",
739
+ type=int,
740
+ default=EngineArgs.max_num_seqs,
741
+ help="maximum number of sequences per iteration",
742
+ )
743
+ parser.add_argument(
744
+ "--max-paddings",
745
+ type=int,
746
+ default=EngineArgs.max_paddings,
747
+ help="maximum number of paddings in a batch",
748
+ )
749
+ parser.add_argument(
750
+ "--disable-log-stats",
751
+ action="store_true",
752
+ help="disable logging statistics",
753
+ )
754
+ # Quantization settings.
755
+ parser.add_argument(
756
+ "--quantization",
757
+ "-q",
758
+ type=str,
759
+ choices=["awq", "gptq", "squeezellm", None],
760
+ default=None,
761
+ help="Method used to quantize the weights. If "
762
+ "None, we first check the `quantization_config` "
763
+ "attribute in the model config file. If that is "
764
+ "None, we assume the model weights are not "
765
+ "quantized and use `dtype` to determine the data "
766
+ "type of the weights.",
767
+ )
768
+ parser.add_argument(
769
+ "--enforce-eager",
770
+ action="store_true",
771
+ help="Always use eager-mode PyTorch. If False, "
772
+ "will use eager mode and CUDA graph in hybrid "
773
+ "for maximal performance and flexibility.",
774
+ )
775
+ parser.add_argument(
776
+ "--max-context-len-to-capture",
777
+ type=int,
778
+ default=EngineArgs.max_context_len_to_capture,
779
+ help="maximum context length covered by CUDA "
780
+ "graphs. When a sequence has context length "
781
+ "larger than this, we fall back to eager mode.",
782
+ )
783
+ return parser
784
+
785
+ @classmethod
786
+ def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs":
787
+ # Get the list of attributes of this dataclass.
788
+ attrs = [attr.name for attr in dataclasses.fields(cls)]
789
+ # Set the attributes from the parsed arguments.
790
+ engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
791
+ return engine_args
792
+
793
+ def create_engine_configs(
794
+ self,
795
+ ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
796
+ model_config = ModelConfig(
797
+ self.model,
798
+ self.tokenizer,
799
+ self.tokenizer_mode,
800
+ self.trust_remote_code,
801
+ self.download_dir,
802
+ self.load_format,
803
+ self.dtype,
804
+ self.seed,
805
+ self.revision,
806
+ self.tokenizer_revision,
807
+ self.max_model_len,
808
+ self.quantization,
809
+ self.enforce_eager,
810
+ self.max_context_len_to_capture,
811
+ self.num_audio_tokens,
812
+ self.num_text_tokens,
813
+ )
814
+ cache_config = CacheConfig(
815
+ self.block_size,
816
+ self.gpu_memory_utilization,
817
+ self.swap_space,
818
+ model_config.get_sliding_window(),
819
+ )
820
+ parallel_config = ParallelConfig(
821
+ self.pipeline_parallel_size,
822
+ self.tensor_parallel_size,
823
+ self.worker_use_ray,
824
+ self.max_parallel_loading_workers,
825
+ )
826
+ scheduler_config = SchedulerConfig(
827
+ self.max_num_batched_tokens,
828
+ self.max_num_seqs,
829
+ model_config.max_model_len,
830
+ self.max_paddings,
831
+ )
832
+ return model_config, cache_config, parallel_config, scheduler_config
833
+
834
+
835
+ @dataclass
836
+ class AsyncEngineArgs(EngineArgs):
837
+ """Arguments for asynchronous vLLM engine."""
838
+
839
+ engine_use_ray: bool = False
840
+ disable_log_requests: bool = False
841
+ max_log_len: Optional[int] = None
842
+
843
+ @staticmethod
844
+ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
845
+ parser = EngineArgs.add_cli_args(parser)
846
+ parser.add_argument(
847
+ "--engine-use-ray",
848
+ action="store_true",
849
+ help="use Ray to start the LLM engine in a "
850
+ "separate process as the server process.",
851
+ )
852
+ parser.add_argument(
853
+ "--disable-log-requests",
854
+ action="store_true",
855
+ help="disable logging requests",
856
+ )
857
+ parser.add_argument(
858
+ "--max-log-len",
859
+ type=int,
860
+ default=None,
861
+ help="max number of prompt characters or prompt "
862
+ "ID numbers being printed in log. "
863
+ "Default: unlimited.",
864
+ )
865
+ return parser
ChatTTS/model/velocity/llama.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Adapted from
3
+ # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
4
+ # Copyright 2023 The vLLM team.
5
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
6
+ #
7
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
8
+ # and OPT implementations in this library. It has been modified from its
9
+ # original forms to accommodate minor architectural differences compared
10
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
11
+ #
12
+ # Licensed under the Apache License, Version 2.0 (the "License");
13
+ # you may not use this file except in compliance with the License.
14
+ # You may obtain a copy of the License at
15
+ #
16
+ # http://www.apache.org/licenses/LICENSE-2.0
17
+ #
18
+ # Unless required by applicable law or agreed to in writing, software
19
+ # distributed under the License is distributed on an "AS IS" BASIS,
20
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21
+ # See the License for the specific language governing permissions and
22
+ # limitations under the License.
23
+ """Inference-only LLaMA model compatible with HuggingFace weights."""
24
+ from typing import Any, Dict, List, Optional, Tuple
25
+
26
+ import torch
27
+ from torch import nn
28
+ from transformers import LlamaConfig
29
+
30
+ from vllm.model_executor.input_metadata import InputMetadata
31
+ from vllm.model_executor.layers.activation import SiluAndMul
32
+ from vllm.model_executor.layers.attention import PagedAttention
33
+ from vllm.model_executor.layers.layernorm import RMSNorm
34
+ from vllm.model_executor.layers.linear import (
35
+ LinearMethodBase,
36
+ MergedColumnParallelLinear,
37
+ QKVParallelLinear,
38
+ RowParallelLinear,
39
+ )
40
+ from vllm.model_executor.layers.rotary_embedding import get_rope
41
+ from vllm.model_executor.layers.sampler import Sampler
42
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
43
+ VocabParallelEmbedding,
44
+ ParallelLMHead,
45
+ )
46
+ from vllm.model_executor.parallel_utils.parallel_state import (
47
+ get_tensor_model_parallel_world_size,
48
+ )
49
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
50
+ from vllm.model_executor.weight_utils import (
51
+ default_weight_loader,
52
+ hf_model_weights_iterator,
53
+ )
54
+ from vllm.sequence import SamplerOutput
55
+
56
+ KVCache = Tuple[torch.Tensor, torch.Tensor]
57
+
58
+
59
+ class LlamaMLP(nn.Module):
60
+
61
+ def __init__(
62
+ self,
63
+ hidden_size: int,
64
+ intermediate_size: int,
65
+ hidden_act: str,
66
+ linear_method: Optional[LinearMethodBase] = None,
67
+ ) -> None:
68
+ super().__init__()
69
+ self.gate_up_proj = MergedColumnParallelLinear(
70
+ hidden_size,
71
+ [intermediate_size] * 2,
72
+ bias=False,
73
+ linear_method=linear_method,
74
+ )
75
+ self.down_proj = RowParallelLinear(
76
+ intermediate_size, hidden_size, bias=False, linear_method=linear_method
77
+ )
78
+ if hidden_act != "silu":
79
+ raise ValueError(
80
+ f"Unsupported activation: {hidden_act}. "
81
+ "Only silu is supported for now."
82
+ )
83
+ self.act_fn = SiluAndMul()
84
+
85
+ def forward(self, x):
86
+ gate_up, _ = self.gate_up_proj(x)
87
+ x = self.act_fn(gate_up)
88
+ x, _ = self.down_proj(x)
89
+ return x
90
+
91
+
92
+ class LlamaAttention(nn.Module):
93
+
94
+ def __init__(
95
+ self,
96
+ hidden_size: int,
97
+ num_heads: int,
98
+ num_kv_heads: int,
99
+ rope_theta: float = 10000,
100
+ rope_scaling: Optional[Dict[str, Any]] = None,
101
+ max_position_embeddings: int = 8192,
102
+ linear_method: Optional[LinearMethodBase] = None,
103
+ ) -> None:
104
+ super().__init__()
105
+ self.hidden_size = hidden_size
106
+ tp_size = get_tensor_model_parallel_world_size()
107
+ self.total_num_heads = num_heads
108
+ assert self.total_num_heads % tp_size == 0
109
+ self.num_heads = self.total_num_heads // tp_size
110
+ self.total_num_kv_heads = num_kv_heads
111
+ if self.total_num_kv_heads >= tp_size:
112
+ # Number of KV heads is greater than TP size, so we partition
113
+ # the KV heads across multiple tensor parallel GPUs.
114
+ assert self.total_num_kv_heads % tp_size == 0
115
+ else:
116
+ # Number of KV heads is less than TP size, so we replicate
117
+ # the KV heads across multiple tensor parallel GPUs.
118
+ assert tp_size % self.total_num_kv_heads == 0
119
+ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
120
+ self.head_dim = hidden_size // self.total_num_heads
121
+ self.q_size = self.num_heads * self.head_dim
122
+ self.kv_size = self.num_kv_heads * self.head_dim
123
+ self.scaling = self.head_dim**-0.5
124
+ self.rope_theta = rope_theta
125
+ self.max_position_embeddings = max_position_embeddings
126
+
127
+ self.qkv_proj = QKVParallelLinear(
128
+ hidden_size,
129
+ self.head_dim,
130
+ self.total_num_heads,
131
+ self.total_num_kv_heads,
132
+ bias=False,
133
+ linear_method=linear_method,
134
+ )
135
+ self.o_proj = RowParallelLinear(
136
+ self.total_num_heads * self.head_dim,
137
+ hidden_size,
138
+ bias=False,
139
+ linear_method=linear_method,
140
+ )
141
+
142
+ self.rotary_emb = get_rope(
143
+ self.head_dim,
144
+ rotary_dim=self.head_dim,
145
+ max_position=max_position_embeddings,
146
+ base=rope_theta,
147
+ rope_scaling=rope_scaling,
148
+ )
149
+ self.attn = PagedAttention(
150
+ self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads
151
+ )
152
+
153
+ def forward(
154
+ self,
155
+ positions: torch.Tensor,
156
+ hidden_states: torch.Tensor,
157
+ kv_cache: KVCache,
158
+ input_metadata: InputMetadata,
159
+ ) -> torch.Tensor:
160
+ qkv, _ = self.qkv_proj(hidden_states)
161
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
162
+ q, k = self.rotary_emb(positions, q, k)
163
+ k_cache, v_cache = kv_cache
164
+ attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
165
+ output, _ = self.o_proj(attn_output)
166
+ return output
167
+
168
+
169
+ class LlamaDecoderLayer(nn.Module):
170
+
171
+ def __init__(
172
+ self,
173
+ config: LlamaConfig,
174
+ linear_method: Optional[LinearMethodBase] = None,
175
+ ) -> None:
176
+ super().__init__()
177
+ self.hidden_size = config.hidden_size
178
+ rope_theta = getattr(config, "rope_theta", 10000)
179
+ rope_scaling = getattr(config, "rope_scaling", None)
180
+ max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
181
+ self.self_attn = LlamaAttention(
182
+ hidden_size=self.hidden_size,
183
+ num_heads=config.num_attention_heads,
184
+ num_kv_heads=config.num_key_value_heads,
185
+ rope_theta=rope_theta,
186
+ rope_scaling=rope_scaling,
187
+ max_position_embeddings=max_position_embeddings,
188
+ linear_method=linear_method,
189
+ )
190
+ self.mlp = LlamaMLP(
191
+ hidden_size=self.hidden_size,
192
+ intermediate_size=config.intermediate_size,
193
+ hidden_act=config.hidden_act,
194
+ linear_method=linear_method,
195
+ )
196
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
197
+ self.post_attention_layernorm = RMSNorm(
198
+ config.hidden_size, eps=config.rms_norm_eps
199
+ )
200
+
201
+ def forward(
202
+ self,
203
+ positions: torch.Tensor,
204
+ hidden_states: torch.Tensor,
205
+ kv_cache: KVCache,
206
+ input_metadata: InputMetadata,
207
+ residual: Optional[torch.Tensor],
208
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
209
+ # Self Attention
210
+ if residual is None:
211
+ residual = hidden_states
212
+ hidden_states = self.input_layernorm(hidden_states)
213
+ else:
214
+ hidden_states, residual = self.input_layernorm(hidden_states, residual)
215
+ hidden_states = self.self_attn(
216
+ positions=positions,
217
+ hidden_states=hidden_states,
218
+ kv_cache=kv_cache,
219
+ input_metadata=input_metadata,
220
+ )
221
+
222
+ # Fully Connected
223
+ hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
224
+ hidden_states = self.mlp(hidden_states)
225
+ return hidden_states, residual
226
+
227
+
228
+ class LlamaModel(nn.Module):
229
+
230
+ def __init__(
231
+ self,
232
+ config: LlamaConfig,
233
+ linear_method: Optional[LinearMethodBase] = None,
234
+ ) -> None:
235
+ super().__init__()
236
+ self.config = config
237
+ self.padding_idx = config.pad_token_id
238
+ self.vocab_size = config.vocab_size
239
+ self.embed_tokens = VocabParallelEmbedding(
240
+ config.vocab_size,
241
+ config.hidden_size,
242
+ )
243
+ self.layers = nn.ModuleList(
244
+ [
245
+ LlamaDecoderLayer(config, linear_method)
246
+ for _ in range(config.num_hidden_layers)
247
+ ]
248
+ )
249
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
250
+
251
+ def forward(
252
+ self,
253
+ input_emb: torch.Tensor,
254
+ positions: torch.Tensor,
255
+ kv_caches: List[KVCache],
256
+ input_metadata: InputMetadata,
257
+ ) -> torch.Tensor:
258
+ hidden_states = input_emb
259
+ residual = None
260
+ for i in range(len(self.layers)):
261
+ layer = self.layers[i]
262
+ hidden_states, residual = layer(
263
+ positions,
264
+ hidden_states,
265
+ kv_caches[i],
266
+ input_metadata,
267
+ residual,
268
+ )
269
+ hidden_states, _ = self.norm(hidden_states, residual)
270
+ return hidden_states
271
+
272
+ def load_weights(
273
+ self,
274
+ model_name_or_path: str,
275
+ cache_dir: Optional[str] = None,
276
+ load_format: str = "auto",
277
+ revision: Optional[str] = None,
278
+ ):
279
+ stacked_params_mapping = [
280
+ # (param_name, shard_name, shard_id)
281
+ ("qkv_proj", "q_proj", "q"),
282
+ ("qkv_proj", "k_proj", "k"),
283
+ ("qkv_proj", "v_proj", "v"),
284
+ ("gate_up_proj", "gate_proj", 0),
285
+ ("gate_up_proj", "up_proj", 1),
286
+ ]
287
+ params_dict = dict(self.named_parameters())
288
+ for name, loaded_weight in hf_model_weights_iterator(
289
+ model_name_or_path, cache_dir, load_format, revision
290
+ ):
291
+ if "rotary_emb.inv_freq" in name:
292
+ continue
293
+ if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
294
+ # Models trained using ColossalAI may include these tensors in
295
+ # the checkpoint. Skip them.
296
+ continue
297
+ for param_name, weight_name, shard_id in stacked_params_mapping:
298
+ if weight_name not in name:
299
+ continue
300
+ name = name.replace(weight_name, param_name)
301
+ # Skip loading extra bias for GPTQ models.
302
+ if name.endswith(".bias") and name not in params_dict:
303
+ continue
304
+ param = params_dict[name]
305
+ weight_loader = param.weight_loader
306
+ weight_loader(param, loaded_weight, shard_id)
307
+ break
308
+ else:
309
+ # Skip loading extra bias for GPTQ models.
310
+ if name.endswith(".bias") and name not in params_dict:
311
+ continue
312
+ param = params_dict[name]
313
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
314
+ weight_loader(param, loaded_weight)
315
+
316
+
317
+ class LlamaForCausalLM(nn.Module):
318
+
319
+ def __init__(
320
+ self,
321
+ config: LlamaConfig,
322
+ linear_method: Optional[LinearMethodBase] = None,
323
+ ) -> None:
324
+ super().__init__()
325
+ self.config = config
326
+ self.linear_method = linear_method
327
+ self.model = LlamaModel(config, linear_method)
328
+ self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
329
+ self.sampler = Sampler(config.vocab_size)
330
+
331
+ def forward(
332
+ self,
333
+ input_ids: torch.Tensor,
334
+ positions: torch.Tensor,
335
+ kv_caches: List[KVCache],
336
+ input_metadata: InputMetadata,
337
+ ) -> torch.Tensor:
338
+ hidden_states = self.model(input_ids, positions, kv_caches, input_metadata)
339
+ return hidden_states
340
+
341
+ def sample(
342
+ self,
343
+ hidden_states: torch.Tensor,
344
+ sampling_metadata: SamplingMetadata,
345
+ ) -> Optional[SamplerOutput]:
346
+ next_tokens = self.sampler(
347
+ self.lm_head.weight, hidden_states, sampling_metadata
348
+ )
349
+ return next_tokens
350
+
351
+ def load_weights(
352
+ self,
353
+ model_name_or_path: str,
354
+ cache_dir: Optional[str] = None,
355
+ load_format: str = "auto",
356
+ revision: Optional[str] = None,
357
+ ):
358
+ stacked_params_mapping = [
359
+ # (param_name, shard_name, shard_id)
360
+ ("qkv_proj", "q_proj", "q"),
361
+ ("qkv_proj", "k_proj", "k"),
362
+ ("qkv_proj", "v_proj", "v"),
363
+ ("gate_up_proj", "gate_proj", 0),
364
+ ("gate_up_proj", "up_proj", 1),
365
+ ]
366
+ params_dict = dict(self.named_parameters())
367
+ for name, loaded_weight in hf_model_weights_iterator(
368
+ model_name_or_path, cache_dir, load_format, revision
369
+ ):
370
+ if "rotary_emb.inv_freq" in name:
371
+ continue
372
+ if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
373
+ # Models trained using ColossalAI may include these tensors in
374
+ # the checkpoint. Skip them.
375
+ continue
376
+ for param_name, weight_name, shard_id in stacked_params_mapping:
377
+ if weight_name not in name:
378
+ continue
379
+ name = name.replace(weight_name, param_name)
380
+ # Skip loading extra bias for GPTQ models.
381
+ if name.endswith(".bias") and name not in params_dict:
382
+ continue
383
+ param = params_dict[name]
384
+ weight_loader = param.weight_loader
385
+ weight_loader(param, loaded_weight, shard_id)
386
+ break
387
+ else:
388
+ # Skip loading extra bias for GPTQ models.
389
+ if name.endswith(".bias") and name not in params_dict:
390
+ continue
391
+ param = params_dict[name]
392
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
393
+ weight_loader(param, loaded_weight)
ChatTTS/model/velocity/llm.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Union
2
+
3
+ from tqdm import tqdm
4
+ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
5
+ from vllm.utils import Counter
6
+
7
+ from .configs import EngineArgs
8
+ from .llm_engine import LLMEngine
9
+ from .output import RequestOutput
10
+ from .sampling_params import SamplingParams
11
+
12
+
13
+ class LLM:
14
+ """An LLM for generating texts from given prompts and sampling parameters.
15
+
16
+ This class includes a tokenizer, a language model (possibly distributed
17
+ across multiple GPUs), and GPU memory space allocated for intermediate
18
+ states (aka KV cache). Given a batch of prompts and sampling parameters,
19
+ this class generates texts from the model, using an intelligent batching
20
+ mechanism and efficient memory management.
21
+
22
+ NOTE: This class is intended to be used for offline inference. For online
23
+ serving, use the `AsyncLLMEngine` class instead.
24
+ NOTE: For the comprehensive list of arguments, see `EngineArgs`.
25
+
26
+ Args:
27
+ model: The name or path of a HuggingFace Transformers model.
28
+ tokenizer: The name or path of a HuggingFace Transformers tokenizer.
29
+ tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
30
+ if available, and "slow" will always use the slow tokenizer.
31
+ trust_remote_code: Trust remote code (e.g., from HuggingFace) when
32
+ downloading the model and tokenizer.
33
+ tensor_parallel_size: The number of GPUs to use for distributed
34
+ execution with tensor parallelism.
35
+ dtype: The data type for the model weights and activations. Currently,
36
+ we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
37
+ the `torch_dtype` attribute specified in the model config file.
38
+ However, if the `torch_dtype` in the config is `float32`, we will
39
+ use `float16` instead.
40
+ quantization: The method used to quantize the model weights. Currently,
41
+ we support "awq", "gptq" and "squeezellm". If None, we first check
42
+ the `quantization_config` attribute in the model config file. If
43
+ that is None, we assume the model weights are not quantized and use
44
+ `dtype` to determine the data type of the weights.
45
+ revision: The specific model version to use. It can be a branch name,
46
+ a tag name, or a commit id.
47
+ tokenizer_revision: The specific tokenizer version to use. It can be a
48
+ branch name, a tag name, or a commit id.
49
+ seed: The seed to initialize the random number generator for sampling.
50
+ gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
51
+ reserve for the model weights, activations, and KV cache. Higher
52
+ values will increase the KV cache size and thus improve the model's
53
+ throughput. However, if the value is too high, it may cause out-of-
54
+ memory (OOM) errors.
55
+ swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
56
+ This can be used for temporarily storing the states of the requests
57
+ when their `best_of` sampling parameters are larger than 1. If all
58
+ requests will have `best_of=1`, you can safely set this to 0.
59
+ Otherwise, too small values may cause out-of-memory (OOM) errors.
60
+ enforce_eager: Whether to enforce eager execution. If True, we will
61
+ disable CUDA graph and always execute the model in eager mode.
62
+ If False, we will use CUDA graph and eager execution in hybrid.
63
+ max_context_len_to_capture: Maximum context len covered by CUDA graphs.
64
+ When a sequence has context length larger than this, we fall back
65
+ to eager mode.
66
+ """
67
+
68
+ def __init__(
69
+ self,
70
+ model: str,
71
+ tokenizer: Optional[str] = None,
72
+ tokenizer_mode: str = "auto",
73
+ trust_remote_code: bool = False,
74
+ tensor_parallel_size: int = 1,
75
+ dtype: str = "auto",
76
+ quantization: Optional[str] = None,
77
+ revision: Optional[str] = None,
78
+ tokenizer_revision: Optional[str] = None,
79
+ seed: int = 0,
80
+ gpu_memory_utilization: float = 0.9,
81
+ swap_space: int = 4,
82
+ enforce_eager: bool = False,
83
+ max_context_len_to_capture: int = 8192,
84
+ post_model_path: str = None,
85
+ num_audio_tokens: int = 0,
86
+ num_text_tokens: int = 0,
87
+ **kwargs,
88
+ ) -> None:
89
+ if "disable_log_stats" not in kwargs:
90
+ kwargs["disable_log_stats"] = True
91
+ engine_args = EngineArgs(
92
+ model=model,
93
+ tokenizer=tokenizer,
94
+ tokenizer_mode=tokenizer_mode,
95
+ trust_remote_code=trust_remote_code,
96
+ tensor_parallel_size=tensor_parallel_size,
97
+ dtype=dtype,
98
+ quantization=quantization,
99
+ revision=revision,
100
+ tokenizer_revision=tokenizer_revision,
101
+ seed=seed,
102
+ gpu_memory_utilization=gpu_memory_utilization,
103
+ swap_space=swap_space,
104
+ enforce_eager=enforce_eager,
105
+ max_context_len_to_capture=max_context_len_to_capture,
106
+ num_audio_tokens=num_audio_tokens,
107
+ num_text_tokens=num_text_tokens,
108
+ **kwargs,
109
+ )
110
+ self.llm_engine = LLMEngine.from_engine_args(engine_args, post_model_path)
111
+ self.request_counter = Counter()
112
+
113
+ def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
114
+ return self.llm_engine.tokenizer
115
+
116
+ def set_tokenizer(
117
+ self,
118
+ tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
119
+ ) -> None:
120
+ self.llm_engine.tokenizer = tokenizer
121
+
122
+ def generate(
123
+ self,
124
+ prompts: Optional[Union[str, List[str]]] = None,
125
+ sampling_params: Optional[SamplingParams] = None,
126
+ prompt_token_ids: Optional[List[List[int]]] = None,
127
+ use_tqdm: bool = True,
128
+ ) -> List[RequestOutput]:
129
+ """Generates the completions for the input prompts.
130
+
131
+ NOTE: This class automatically batches the given prompts, considering
132
+ the memory constraint. For the best performance, put all of your prompts
133
+ into a single list and pass it to this method.
134
+
135
+ Args:
136
+ prompts: A list of prompts to generate completions for.
137
+ sampling_params: The sampling parameters for text generation. If
138
+ None, we use the default sampling parameters.
139
+ prompt_token_ids: A list of token IDs for the prompts. If None, we
140
+ use the tokenizer to convert the prompts to token IDs.
141
+ use_tqdm: Whether to use tqdm to display the progress bar.
142
+
143
+ Returns:
144
+ A list of `RequestOutput` objects containing the generated
145
+ completions in the same order as the input prompts.
146
+ """
147
+ if prompts is None and prompt_token_ids is None:
148
+ raise ValueError("Either prompts or prompt_token_ids must be " "provided.")
149
+ if isinstance(prompts, str):
150
+ # Convert a single prompt to a list.
151
+ prompts = [prompts]
152
+ if (
153
+ prompts is not None
154
+ and prompt_token_ids is not None
155
+ and len(prompts) != len(prompt_token_ids)
156
+ ):
157
+ raise ValueError(
158
+ "The lengths of prompts and prompt_token_ids " "must be the same."
159
+ )
160
+ if sampling_params is None:
161
+ # Use default sampling params.
162
+ sampling_params = SamplingParams()
163
+
164
+ # Add requests to the engine.
165
+ num_requests = len(prompts) if prompts is not None else len(prompt_token_ids)
166
+ for i in range(num_requests):
167
+ prompt = prompts[i] if prompts is not None else None
168
+ token_ids = None if prompt_token_ids is None else prompt_token_ids[i]
169
+ self._add_request(prompt, sampling_params, token_ids)
170
+
171
+ rtns = self._run_engine(use_tqdm)
172
+ for i, rtn in enumerate(rtns):
173
+ token_ids = rtn.outputs[0].token_ids
174
+ for j, token_id in enumerate(token_ids):
175
+ if len(token_id) == 1:
176
+ token_ids[j] = token_id[0]
177
+ else:
178
+ token_ids[j] = list(token_id)
179
+
180
+ return rtns
181
+
182
+ def _add_request(
183
+ self,
184
+ prompt: Optional[str],
185
+ sampling_params: SamplingParams,
186
+ prompt_token_ids: Optional[List[int]],
187
+ ) -> None:
188
+ request_id = str(next(self.request_counter))
189
+ self.llm_engine.add_request(
190
+ request_id, prompt, sampling_params, prompt_token_ids
191
+ )
192
+
193
+ def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
194
+ # Initialize tqdm.
195
+ if use_tqdm:
196
+ num_requests = self.llm_engine.get_num_unfinished_requests()
197
+ pbar = tqdm(total=num_requests, desc="Processed prompts")
198
+ # Run the engine.
199
+ outputs: List[RequestOutput] = []
200
+ while self.llm_engine.has_unfinished_requests():
201
+ step_outputs = self.llm_engine.step()
202
+ for output in step_outputs:
203
+ if output.finished:
204
+ outputs.append(output)
205
+ if use_tqdm:
206
+ pbar.update(1)
207
+ if use_tqdm:
208
+ pbar.close()
209
+ # Sort the outputs by request ID.
210
+ # This is necessary because some requests may be finished earlier than
211
+ # its previous requests.
212
+ outputs = sorted(outputs, key=lambda x: int(x.request_id))
213
+ return outputs
ChatTTS/model/velocity/llm_engine.py ADDED
@@ -0,0 +1,833 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from collections import defaultdict
3
+ import os
4
+ import time
5
+ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
6
+
7
+ from vllm.config import CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig
8
+ from .scheduler import Scheduler, SchedulerOutputs
9
+ from .configs import EngineArgs
10
+ from vllm.engine.metrics import record_metrics
11
+ from vllm.engine.ray_utils import RayWorkerVllm, initialize_cluster, ray
12
+ from vllm.logger import init_logger
13
+ from .output import RequestOutput
14
+ from .sampling_params import SamplingParams
15
+ from .sequence import (
16
+ SamplerOutput,
17
+ Sequence,
18
+ SequenceGroup,
19
+ SequenceGroupOutput,
20
+ SequenceOutput,
21
+ SequenceStatus,
22
+ )
23
+ from vllm.transformers_utils.tokenizer import detokenize_incrementally, get_tokenizer
24
+ from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port
25
+ import numpy as np
26
+
27
+ if ray:
28
+ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
29
+
30
+ if TYPE_CHECKING:
31
+ from ray.util.placement_group import PlacementGroup
32
+
33
+ logger = init_logger(__name__)
34
+
35
+ _LOGGING_INTERVAL_SEC = 5
36
+
37
+
38
+ class LLMEngine:
39
+ """An LLM engine that receives requests and generates texts.
40
+
41
+ This is the main class for the vLLM engine. It receives requests
42
+ from clients and generates texts from the LLM. It includes a tokenizer, a
43
+ language model (possibly distributed across multiple GPUs), and GPU memory
44
+ space allocated for intermediate states (aka KV cache). This class utilizes
45
+ iteration-level scheduling and efficient memory management to maximize the
46
+ serving throughput.
47
+
48
+ The `LLM` class wraps this class for offline batched inference and the
49
+ `AsyncLLMEngine` class wraps this class for online serving.
50
+
51
+ NOTE: The config arguments are derived from the `EngineArgs` class. For the
52
+ comprehensive list of arguments, see `EngineArgs`.
53
+
54
+ Args:
55
+ model_config: The configuration related to the LLM model.
56
+ cache_config: The configuration related to the KV cache memory
57
+ management.
58
+ parallel_config: The configuration related to distributed execution.
59
+ scheduler_config: The configuration related to the request scheduler.
60
+ placement_group: Ray placement group for distributed execution.
61
+ Required for distributed execution.
62
+ log_stats: Whether to log statistics.
63
+ """
64
+
65
+ def __init__(
66
+ self,
67
+ model_config: ModelConfig,
68
+ cache_config: CacheConfig,
69
+ parallel_config: ParallelConfig,
70
+ scheduler_config: SchedulerConfig,
71
+ placement_group: Optional["PlacementGroup"],
72
+ post_model_path: str,
73
+ log_stats: bool,
74
+ ) -> None:
75
+ logger.info(
76
+ "Initializing an LLM engine with config: "
77
+ f"model={model_config.model!r}, "
78
+ f"tokenizer={model_config.tokenizer!r}, "
79
+ f"tokenizer_mode={model_config.tokenizer_mode}, "
80
+ f"revision={model_config.revision}, "
81
+ f"tokenizer_revision={model_config.tokenizer_revision}, "
82
+ f"trust_remote_code={model_config.trust_remote_code}, "
83
+ f"dtype={model_config.dtype}, "
84
+ f"max_seq_len={model_config.max_model_len}, "
85
+ f"download_dir={model_config.download_dir!r}, "
86
+ f"load_format={model_config.load_format}, "
87
+ f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
88
+ f"quantization={model_config.quantization}, "
89
+ f"enforce_eager={model_config.enforce_eager}, "
90
+ f"seed={model_config.seed}), "
91
+ f"post_model_path={post_model_path!r}"
92
+ )
93
+ # TODO(woosuk): Print more configs in debug mode.
94
+
95
+ self.model_config = model_config
96
+ self.cache_config = cache_config
97
+ self.parallel_config = parallel_config
98
+ self.scheduler_config = scheduler_config
99
+ self.log_stats = log_stats
100
+ self._verify_args()
101
+ self.post_model_path = post_model_path
102
+ self.seq_counter = Counter()
103
+
104
+ # Create the parallel GPU workers.
105
+ if self.parallel_config.worker_use_ray:
106
+ # Disable Ray usage stats collection.
107
+ ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
108
+ if ray_usage != "1":
109
+ os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
110
+ self._init_workers_ray(placement_group)
111
+ else:
112
+ self._init_workers()
113
+
114
+ # Profile the memory usage and initialize the cache.
115
+ self._init_cache()
116
+
117
+ # Create the scheduler.
118
+ self.scheduler = Scheduler(scheduler_config, cache_config)
119
+
120
+ # Logging.
121
+ self.last_logging_time = 0.0
122
+ # List of (timestamp, num_tokens)
123
+ self.num_prompt_tokens: List[Tuple[float, int]] = []
124
+ # List of (timestamp, num_tokens)
125
+ self.num_generation_tokens: List[Tuple[float, int]] = []
126
+
127
+ def _init_workers(self):
128
+ # Lazy import the Worker to avoid importing torch.cuda/xformers
129
+ # before CUDA_VISIBLE_DEVICES is set in the Worker
130
+ from .worker import Worker
131
+
132
+ assert (
133
+ self.parallel_config.world_size == 1
134
+ ), "Ray is required if parallel_config.world_size > 1."
135
+
136
+ self.workers: List[Worker] = []
137
+ distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}"
138
+ self.driver_worker = Worker(
139
+ self.model_config,
140
+ self.parallel_config,
141
+ self.scheduler_config,
142
+ local_rank=0,
143
+ rank=0,
144
+ distributed_init_method=distributed_init_method,
145
+ is_driver_worker=True,
146
+ post_model_path=self.post_model_path,
147
+ )
148
+ self._run_workers("init_model")
149
+ self._run_workers("load_model")
150
+
151
+ def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs):
152
+ if self.parallel_config.tensor_parallel_size == 1:
153
+ num_gpus = self.cache_config.gpu_memory_utilization
154
+ else:
155
+ num_gpus = 1
156
+
157
+ self.driver_dummy_worker: RayWorkerVllm = None
158
+ self.workers: List[RayWorkerVllm] = []
159
+
160
+ driver_ip = get_ip()
161
+ for bundle_id, bundle in enumerate(placement_group.bundle_specs):
162
+ if not bundle.get("GPU", 0):
163
+ continue
164
+ scheduling_strategy = PlacementGroupSchedulingStrategy(
165
+ placement_group=placement_group,
166
+ placement_group_capture_child_tasks=True,
167
+ placement_group_bundle_index=bundle_id,
168
+ )
169
+ worker = ray.remote(
170
+ num_cpus=0,
171
+ num_gpus=num_gpus,
172
+ scheduling_strategy=scheduling_strategy,
173
+ **ray_remote_kwargs,
174
+ )(RayWorkerVllm).remote(self.model_config.trust_remote_code)
175
+
176
+ worker_ip = ray.get(worker.get_node_ip.remote())
177
+ if worker_ip == driver_ip and self.driver_dummy_worker is None:
178
+ # If the worker is on the same node as the driver, we use it
179
+ # as the resource holder for the driver process.
180
+ self.driver_dummy_worker = worker
181
+ else:
182
+ self.workers.append(worker)
183
+
184
+ if self.driver_dummy_worker is None:
185
+ raise ValueError(
186
+ "Ray does not allocate any GPUs on the driver node. Consider "
187
+ "adjusting the Ray placement group or running the driver on a "
188
+ "GPU node."
189
+ )
190
+
191
+ driver_node_id, driver_gpu_ids = ray.get(
192
+ self.driver_dummy_worker.get_node_and_gpu_ids.remote()
193
+ )
194
+ worker_node_and_gpu_ids = ray.get(
195
+ [worker.get_node_and_gpu_ids.remote() for worker in self.workers]
196
+ )
197
+
198
+ node_workers = defaultdict(list)
199
+ node_gpus = defaultdict(list)
200
+
201
+ node_workers[driver_node_id].append(0)
202
+ node_gpus[driver_node_id].extend(driver_gpu_ids)
203
+ for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids, start=1):
204
+ node_workers[node_id].append(i)
205
+ node_gpus[node_id].extend(gpu_ids)
206
+ for node_id, gpu_ids in node_gpus.items():
207
+ node_gpus[node_id] = sorted(gpu_ids)
208
+
209
+ # Set CUDA_VISIBLE_DEVICES for the driver.
210
+ set_cuda_visible_devices(node_gpus[driver_node_id])
211
+ for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
212
+ worker.set_cuda_visible_devices.remote(node_gpus[node_id])
213
+
214
+ distributed_init_method = f"tcp://{driver_ip}:{get_open_port()}"
215
+
216
+ # Lazy import the Worker to avoid importing torch.cuda/xformers
217
+ # before CUDA_VISIBLE_DEVICES is set in the Worker
218
+ from vllm.worker.worker import Worker
219
+
220
+ # Initialize torch distributed process group for the workers.
221
+ model_config = copy.deepcopy(self.model_config)
222
+ parallel_config = copy.deepcopy(self.parallel_config)
223
+ scheduler_config = copy.deepcopy(self.scheduler_config)
224
+
225
+ for rank, (worker, (node_id, _)) in enumerate(
226
+ zip(self.workers, worker_node_and_gpu_ids), start=1
227
+ ):
228
+ local_rank = node_workers[node_id].index(rank)
229
+ worker.init_worker.remote(
230
+ lambda rank=rank, local_rank=local_rank: Worker(
231
+ model_config,
232
+ parallel_config,
233
+ scheduler_config,
234
+ local_rank,
235
+ rank,
236
+ distributed_init_method,
237
+ )
238
+ )
239
+
240
+ driver_rank = 0
241
+ driver_local_rank = node_workers[driver_node_id].index(driver_rank)
242
+ self.driver_worker = Worker(
243
+ model_config,
244
+ parallel_config,
245
+ scheduler_config,
246
+ driver_local_rank,
247
+ driver_rank,
248
+ distributed_init_method,
249
+ is_driver_worker=True,
250
+ )
251
+
252
+ self._run_workers("init_model")
253
+ self._run_workers(
254
+ "load_model",
255
+ max_concurrent_workers=self.parallel_config.max_parallel_loading_workers,
256
+ )
257
+
258
+ def _verify_args(self) -> None:
259
+ self.model_config.verify_with_parallel_config(self.parallel_config)
260
+ self.cache_config.verify_with_parallel_config(self.parallel_config)
261
+
262
+ def _init_cache(self) -> None:
263
+ """Profiles the memory usage and initializes the KV cache."""
264
+ # Get the maximum number of blocks that can be allocated on GPU and CPU.
265
+ num_blocks = self._run_workers(
266
+ "profile_num_available_blocks",
267
+ block_size=self.cache_config.block_size,
268
+ gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
269
+ cpu_swap_space=self.cache_config.swap_space_bytes,
270
+ )
271
+
272
+ # Since we use a shared centralized controller, we take the minimum
273
+ # number of blocks across all workers to make sure all the memory
274
+ # operators can be applied to all workers.
275
+ num_gpu_blocks = min(b[0] for b in num_blocks)
276
+ num_cpu_blocks = min(b[1] for b in num_blocks)
277
+ # FIXME(woosuk): Change to debug log.
278
+ logger.info(
279
+ f"# GPU blocks: {num_gpu_blocks}, " f"# CPU blocks: {num_cpu_blocks}"
280
+ )
281
+
282
+ if num_gpu_blocks <= 0:
283
+ raise ValueError(
284
+ "No available memory for the cache blocks. "
285
+ "Try increasing `gpu_memory_utilization` when "
286
+ "initializing the engine."
287
+ )
288
+ max_seq_len = self.cache_config.block_size * num_gpu_blocks
289
+ if self.model_config.max_model_len > max_seq_len:
290
+ raise ValueError(
291
+ f"The model's max seq len ({self.model_config.max_model_len}) "
292
+ "is larger than the maximum number of tokens that can be "
293
+ f"stored in KV cache ({max_seq_len}). Try increasing "
294
+ "`gpu_memory_utilization` or decreasing `max_model_len` when "
295
+ "initializing the engine."
296
+ )
297
+
298
+ self.cache_config.num_gpu_blocks = num_gpu_blocks
299
+ self.cache_config.num_cpu_blocks = num_cpu_blocks
300
+
301
+ # Initialize the cache.
302
+ self._run_workers("init_cache_engine", cache_config=self.cache_config)
303
+ # Warm up the model. This includes capturing the model into CUDA graph
304
+ # if enforce_eager is False.
305
+ self._run_workers("warm_up_model")
306
+
307
+ @classmethod
308
+ def from_engine_args(
309
+ cls, engine_args: EngineArgs, post_model_path=None
310
+ ) -> "LLMEngine":
311
+ """Creates an LLM engine from the engine arguments."""
312
+ # Create the engine configs.
313
+ engine_configs = engine_args.create_engine_configs()
314
+ parallel_config = engine_configs[2]
315
+ # Initialize the cluster.
316
+ placement_group = initialize_cluster(parallel_config)
317
+ # Create the LLM engine.
318
+ engine = cls(
319
+ *engine_configs,
320
+ placement_group,
321
+ log_stats=not engine_args.disable_log_stats,
322
+ post_model_path=post_model_path,
323
+ )
324
+ return engine
325
+
326
+ def add_request(
327
+ self,
328
+ request_id: str,
329
+ prompt: Optional[str],
330
+ sampling_params: SamplingParams,
331
+ prompt_token_ids: Optional[List[int]] = None,
332
+ arrival_time: Optional[float] = None,
333
+ ) -> None:
334
+ """Add a request to the engine's request pool.
335
+
336
+ The request is added to the request pool and will be processed by the
337
+ scheduler as `engine.step()` is called. The exact scheduling policy is
338
+ determined by the scheduler.
339
+
340
+ Args:
341
+ request_id: The unique ID of the request.
342
+ prompt: The prompt string. Can be None if prompt_token_ids is
343
+ provided.
344
+ sampling_params: The sampling parameters for text generation.
345
+ prompt_token_ids: The token IDs of the prompt. If None, we
346
+ use the tokenizer to convert the prompts to token IDs.
347
+ arrival_time: The arrival time of the request. If None, we use
348
+ the current monotonic time.
349
+ """
350
+ if arrival_time is None:
351
+ arrival_time = time.monotonic()
352
+
353
+ assert prompt_token_ids is not None, "prompt_token_ids must be provided"
354
+ # Create the sequences.
355
+ block_size = self.cache_config.block_size
356
+ seq_id = next(self.seq_counter)
357
+ seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
358
+
359
+ # Create the sequence group.
360
+ seq_group = SequenceGroup(request_id, [seq], sampling_params, arrival_time)
361
+
362
+ # Add the sequence group to the scheduler.
363
+ self.scheduler.add_seq_group(seq_group)
364
+
365
+ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
366
+ """Aborts a request(s) with the given ID.
367
+
368
+ Args:
369
+ request_id: The ID(s) of the request to abort.
370
+ """
371
+ self.scheduler.abort_seq_group(request_id)
372
+
373
+ def get_model_config(self) -> ModelConfig:
374
+ """Gets the model configuration."""
375
+ return self.model_config
376
+
377
+ def get_num_unfinished_requests(self) -> int:
378
+ """Gets the number of unfinished requests."""
379
+ return self.scheduler.get_num_unfinished_seq_groups()
380
+
381
+ def has_unfinished_requests(self) -> bool:
382
+ """Returns True if there are unfinished requests."""
383
+ return self.scheduler.has_unfinished_seqs()
384
+
385
+ def _check_beam_search_early_stopping(
386
+ self,
387
+ early_stopping: Union[bool, str],
388
+ sampling_params: SamplingParams,
389
+ best_running_seq: Sequence,
390
+ current_worst_seq: Sequence,
391
+ ) -> bool:
392
+ assert sampling_params.use_beam_search
393
+ length_penalty = sampling_params.length_penalty
394
+ if early_stopping is True:
395
+ return True
396
+
397
+ current_worst_score = current_worst_seq.get_beam_search_score(
398
+ length_penalty=length_penalty, eos_token_id=self.tokenizer.eos_token_id
399
+ )
400
+ if early_stopping is False:
401
+ highest_attainable_score = best_running_seq.get_beam_search_score(
402
+ length_penalty=length_penalty, eos_token_id=self.tokenizer.eos_token_id
403
+ )
404
+ else:
405
+ assert early_stopping == "never"
406
+ if length_penalty > 0.0:
407
+ # If length_penalty > 0.0, beam search will prefer longer
408
+ # sequences. The highest attainable score calculation is
409
+ # based on the longest possible sequence length in this case.
410
+ max_possible_length = max(
411
+ best_running_seq.get_prompt_len() + sampling_params.max_tokens,
412
+ self.scheduler_config.max_model_len,
413
+ )
414
+ highest_attainable_score = best_running_seq.get_beam_search_score(
415
+ length_penalty=length_penalty,
416
+ eos_token_id=self.tokenizer.eos_token_id,
417
+ seq_len=max_possible_length,
418
+ )
419
+ else:
420
+ # Otherwise, beam search will prefer shorter sequences. The
421
+ # highest attainable score calculation is based on the current
422
+ # sequence length.
423
+ highest_attainable_score = best_running_seq.get_beam_search_score(
424
+ length_penalty=length_penalty,
425
+ eos_token_id=self.tokenizer.eos_token_id,
426
+ )
427
+ return current_worst_score >= highest_attainable_score
428
+
429
+ def _process_sequence_group_outputs(
430
+ self, seq_group: SequenceGroup, outputs: SequenceGroupOutput
431
+ ) -> None:
432
+ # Process prompt logprobs
433
+ prompt_logprobs = outputs.prompt_logprobs
434
+ if prompt_logprobs is not None:
435
+ seq_group.prompt_logprobs = prompt_logprobs
436
+
437
+ # Process samples
438
+ samples = outputs.samples
439
+ parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
440
+ existing_finished_seqs = seq_group.get_finished_seqs()
441
+ parent_child_dict = {parent_seq.seq_id: [] for parent_seq in parent_seqs}
442
+ for sample in samples:
443
+ parent_child_dict[sample.parent_seq_id].append(sample)
444
+ # List of (child, parent)
445
+ child_seqs: List[Tuple[Sequence, Sequence]] = []
446
+
447
+ # Process the child samples for each parent sequence
448
+ for parent in parent_seqs:
449
+ child_samples: List[SequenceOutput] = parent_child_dict[parent.seq_id]
450
+ if len(child_samples) == 0:
451
+ # This parent sequence has no children samples. Remove
452
+ # the parent sequence from the sequence group since it will
453
+ # not be used in the future iterations.
454
+ parent.status = SequenceStatus.FINISHED_ABORTED
455
+ seq_group.remove(parent.seq_id)
456
+ self.scheduler.free_seq(parent)
457
+ continue
458
+ # Fork the parent sequence if there are multiple child samples.
459
+ for child_sample in child_samples[:-1]:
460
+ new_child_seq_id = next(self.seq_counter)
461
+ child = parent.fork(new_child_seq_id)
462
+ child.append_token_id(
463
+ child_sample.output_token,
464
+ child_sample.logprobs,
465
+ child_sample.hidden_states,
466
+ child_sample.finished,
467
+ )
468
+ child_seqs.append((child, parent))
469
+ # Continue the parent sequence for the last child sample.
470
+ # We reuse the parent sequence here to reduce redundant memory
471
+ # copies, especially when using non-beam search sampling methods.
472
+ last_child_sample = child_samples[-1]
473
+ parent.append_token_id(
474
+ last_child_sample.output_token,
475
+ last_child_sample.logprobs,
476
+ last_child_sample.hidden_states,
477
+ last_child_sample.finished,
478
+ )
479
+ child_seqs.append((parent, parent))
480
+
481
+ for seq, _ in child_seqs:
482
+ # self._decode_sequence(seq, seq_group.sampling_params)
483
+ self._check_stop(seq, seq_group.sampling_params)
484
+
485
+ # Non-beam search case
486
+ if not seq_group.sampling_params.use_beam_search:
487
+ # For newly created child sequences, add them to the sequence group
488
+ # and fork them in block manager if they are not finished.
489
+ for seq, parent in child_seqs:
490
+ if seq is not parent:
491
+ seq_group.add(seq)
492
+ if not seq.is_finished():
493
+ self.scheduler.fork_seq(parent, seq)
494
+
495
+ # Free the finished and selected parent sequences' memory in block
496
+ # manager. Keep them in the sequence group as candidate output.
497
+ # NOTE: we need to fork the new sequences before freeing the
498
+ # old sequences.
499
+ for seq, parent in child_seqs:
500
+ if seq is parent and seq.is_finished():
501
+ self.scheduler.free_seq(seq)
502
+ return
503
+
504
+ # Beam search case
505
+ # Select the child sequences to keep in the sequence group.
506
+ selected_child_seqs = []
507
+ unselected_child_seqs = []
508
+ beam_width = seq_group.sampling_params.best_of
509
+ length_penalty = seq_group.sampling_params.length_penalty
510
+
511
+ # Select the newly finished sequences with the highest scores
512
+ # to replace existing finished sequences.
513
+ # Tuple of (seq, parent, is_new)
514
+ existing_finished_seqs = [(seq, None, False) for seq in existing_finished_seqs]
515
+ new_finished_seqs = [
516
+ (seq, parent, True) for seq, parent in child_seqs if seq.is_finished()
517
+ ]
518
+ all_finished_seqs = existing_finished_seqs + new_finished_seqs
519
+ # Sort the finished sequences by their scores.
520
+ all_finished_seqs.sort(
521
+ key=lambda x: x[0].get_beam_search_score(
522
+ length_penalty=length_penalty, eos_token_id=self.tokenizer.eos_token_id
523
+ ),
524
+ reverse=True,
525
+ )
526
+ for seq, parent, is_new in all_finished_seqs[:beam_width]:
527
+ if is_new:
528
+ # A newly generated child sequence finishes and has a high
529
+ # score, so we will add it into the sequence group.
530
+ selected_child_seqs.append((seq, parent))
531
+ for seq, parent, is_new in all_finished_seqs[beam_width:]:
532
+ if is_new:
533
+ # A newly generated child sequence finishes but has a low
534
+ # score, so we will not add it into the sequence group.
535
+ # Additionally, if this sequence is a continuation of a
536
+ # parent sequence, we will need remove the parent sequence
537
+ # from the sequence group.
538
+ unselected_child_seqs.append((seq, parent))
539
+ else:
540
+ # An existing finished sequence has a low score, so we will
541
+ # remove it from the sequence group.
542
+ seq_group.remove(seq.seq_id)
543
+
544
+ # select the top beam_width sequences from the running
545
+ # sequences for the next iteration to continue the beam
546
+ # search.
547
+ running_child_seqs = [
548
+ (seq, parent) for seq, parent in child_seqs if not seq.is_finished()
549
+ ]
550
+ # Sort the running sequences by their scores.
551
+ running_child_seqs.sort(
552
+ key=lambda x: x[0].get_beam_search_score(
553
+ length_penalty=length_penalty, eos_token_id=self.tokenizer.eos_token_id
554
+ ),
555
+ reverse=True,
556
+ )
557
+
558
+ # Check if we can stop the beam search.
559
+ if len(running_child_seqs) == 0:
560
+ # No running sequences, stop the beam search.
561
+ stop_beam_search = True
562
+ elif len(all_finished_seqs) < beam_width:
563
+ # Not enough finished sequences, continue the beam search.
564
+ stop_beam_search = False
565
+ else:
566
+ # Check the early stopping criteria
567
+ best_running_seq = running_child_seqs[0][0]
568
+ current_worst_seq = all_finished_seqs[beam_width - 1][0]
569
+ stop_beam_search = self._check_beam_search_early_stopping(
570
+ seq_group.sampling_params.early_stopping,
571
+ seq_group.sampling_params,
572
+ best_running_seq,
573
+ current_worst_seq,
574
+ )
575
+
576
+ if stop_beam_search:
577
+ # Stop the beam search and remove all the running sequences from
578
+ # the sequence group.
579
+ unselected_child_seqs.extend(running_child_seqs)
580
+ else:
581
+ # Continue the beam search and select the top beam_width sequences
582
+ # to continue the beam search.
583
+ selected_child_seqs.extend(running_child_seqs[:beam_width])
584
+ # The remaining running sequences will not be used in the next
585
+ # iteration. Again, if these sequences are continuations of
586
+ # parent sequences, we will need to remove the parent sequences
587
+ # from the sequence group.
588
+ unselected_child_seqs.extend(running_child_seqs[beam_width:])
589
+
590
+ # For newly created child sequences, add them to the sequence group
591
+ # and fork them in block manager if they are not finished.
592
+ for seq, parent in selected_child_seqs:
593
+ if seq is not parent:
594
+ seq_group.add(seq)
595
+ if not seq.is_finished():
596
+ self.scheduler.fork_seq(parent, seq)
597
+
598
+ # Free the finished and selected parent sequences' memory in block
599
+ # manager. Keep them in the sequence group as candidate output.
600
+ for seq, parent in selected_child_seqs:
601
+ if seq is parent and seq.is_finished():
602
+ self.scheduler.free_seq(seq)
603
+
604
+ # Remove the unselected parent sequences from the sequence group and
605
+ # free their memory in block manager.
606
+ for seq, parent in unselected_child_seqs:
607
+ if seq is parent:
608
+ # Remove the parent sequence if it is not selected for next
609
+ # iteration
610
+ seq_group.remove(seq.seq_id)
611
+ self.scheduler.free_seq(seq)
612
+
613
+ def _process_model_outputs(
614
+ self, output: SamplerOutput, scheduler_outputs: SchedulerOutputs
615
+ ) -> List[RequestOutput]:
616
+ # Update the scheduled sequence groups with the model outputs.
617
+ scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
618
+ for seq_group, outputs in zip(scheduled_seq_groups, output):
619
+ self._process_sequence_group_outputs(seq_group, outputs)
620
+
621
+ # Free the finished sequence groups.
622
+ self.scheduler.free_finished_seq_groups()
623
+
624
+ # Create the outputs.
625
+ request_outputs: List[RequestOutput] = []
626
+ for seq_group in scheduled_seq_groups + scheduler_outputs.ignored_seq_groups:
627
+ request_output = RequestOutput.from_seq_group(seq_group)
628
+ request_outputs.append(request_output)
629
+
630
+ if self.log_stats:
631
+ # Log the system stats.
632
+ self._log_system_stats(
633
+ scheduler_outputs.prompt_run, scheduler_outputs.num_batched_tokens
634
+ )
635
+ return request_outputs
636
+
637
+ def step(self) -> List[RequestOutput]:
638
+ """Performs one decoding iteration and returns newly generated results.
639
+
640
+ This function performs one decoding iteration of the engine. It first
641
+ schedules the sequences to be executed in the next iteration and the
642
+ token blocks to be swapped in/out/copy. Then, it executes the model
643
+ and updates the scheduler with the model outputs. Finally, it decodes
644
+ the sequences and returns the newly generated results.
645
+ """
646
+ seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
647
+
648
+ if not scheduler_outputs.is_empty():
649
+ # Execute the model.
650
+ all_outputs = self._run_workers(
651
+ "execute_model",
652
+ driver_kwargs={
653
+ "seq_group_metadata_list": seq_group_metadata_list,
654
+ "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
655
+ "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
656
+ "blocks_to_copy": scheduler_outputs.blocks_to_copy,
657
+ },
658
+ )
659
+
660
+ # Only the driver worker returns the sampling results.
661
+ output = all_outputs[0]
662
+ else:
663
+ output = []
664
+
665
+ return self._process_model_outputs(output, scheduler_outputs)
666
+
667
+ def _log_system_stats(
668
+ self,
669
+ prompt_run: bool,
670
+ num_batched_tokens: int,
671
+ ) -> None:
672
+ now = time.monotonic()
673
+ # Log the number of batched input tokens.
674
+ if prompt_run:
675
+ self.num_prompt_tokens.append((now, num_batched_tokens))
676
+ else:
677
+ self.num_generation_tokens.append((now, num_batched_tokens))
678
+
679
+ should_log = now - self.last_logging_time >= _LOGGING_INTERVAL_SEC
680
+ if not should_log:
681
+ return
682
+
683
+ # Discard the old stats.
684
+ self.num_prompt_tokens = [
685
+ (t, n) for t, n in self.num_prompt_tokens if now - t < _LOGGING_INTERVAL_SEC
686
+ ]
687
+ self.num_generation_tokens = [
688
+ (t, n)
689
+ for t, n in self.num_generation_tokens
690
+ if now - t < _LOGGING_INTERVAL_SEC
691
+ ]
692
+
693
+ if len(self.num_prompt_tokens) > 1:
694
+ total_num_tokens = sum(n for _, n in self.num_prompt_tokens[:-1])
695
+ window = now - self.num_prompt_tokens[0][0]
696
+ avg_prompt_throughput = total_num_tokens / window
697
+ else:
698
+ avg_prompt_throughput = 0.0
699
+ if len(self.num_generation_tokens) > 1:
700
+ total_num_tokens = sum(n for _, n in self.num_generation_tokens[:-1])
701
+ window = now - self.num_generation_tokens[0][0]
702
+ avg_generation_throughput = total_num_tokens / window
703
+ else:
704
+ avg_generation_throughput = 0.0
705
+
706
+ total_num_gpu_blocks = self.cache_config.num_gpu_blocks
707
+ num_free_gpu_blocks = self.scheduler.block_manager.get_num_free_gpu_blocks()
708
+ num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
709
+ gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
710
+
711
+ total_num_cpu_blocks = self.cache_config.num_cpu_blocks
712
+ if total_num_cpu_blocks > 0:
713
+ num_free_cpu_blocks = self.scheduler.block_manager.get_num_free_cpu_blocks()
714
+ num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
715
+ cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
716
+ else:
717
+ cpu_cache_usage = 0.0
718
+
719
+ record_metrics(
720
+ avg_prompt_throughput=avg_prompt_throughput,
721
+ avg_generation_throughput=avg_generation_throughput,
722
+ scheduler_running=len(self.scheduler.running),
723
+ scheduler_swapped=len(self.scheduler.swapped),
724
+ scheduler_waiting=len(self.scheduler.waiting),
725
+ gpu_cache_usage=gpu_cache_usage,
726
+ cpu_cache_usage=cpu_cache_usage,
727
+ )
728
+
729
+ logger.info(
730
+ "Avg prompt throughput: "
731
+ f"{avg_prompt_throughput:.1f} tokens/s, "
732
+ "Avg generation throughput: "
733
+ f"{avg_generation_throughput:.1f} tokens/s, "
734
+ f"Running: {len(self.scheduler.running)} reqs, "
735
+ f"Swapped: {len(self.scheduler.swapped)} reqs, "
736
+ f"Pending: {len(self.scheduler.waiting)} reqs, "
737
+ f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
738
+ f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%"
739
+ )
740
+ self.last_logging_time = now
741
+
742
+ def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
743
+ """Decodes the new token for a sequence."""
744
+ (new_tokens, new_output_text, prefix_offset, read_offset) = (
745
+ detokenize_incrementally(
746
+ self.tokenizer,
747
+ all_input_ids=seq.get_token_ids(),
748
+ prev_tokens=seq.tokens,
749
+ prefix_offset=seq.prefix_offset,
750
+ read_offset=seq.read_offset,
751
+ skip_special_tokens=prms.skip_special_tokens,
752
+ spaces_between_special_tokens=prms.spaces_between_special_tokens,
753
+ )
754
+ )
755
+ if seq.tokens is None:
756
+ seq.tokens = new_tokens
757
+ else:
758
+ seq.tokens.extend(new_tokens)
759
+ seq.prefix_offset = prefix_offset
760
+ seq.read_offset = read_offset
761
+ seq.output_text += new_output_text
762
+
763
+ def _check_stop(self, seq: Sequence, sampling_params: SamplingParams) -> None:
764
+ """Stop the finished sequences."""
765
+ for stop_str in sampling_params.stop:
766
+ if seq.output_text.endswith(stop_str):
767
+ if not sampling_params.include_stop_str_in_output:
768
+ # Truncate the output text so that the stop string is
769
+ # not included in the output.
770
+ seq.output_text = seq.output_text[: -len(stop_str)]
771
+ seq.status = SequenceStatus.FINISHED_STOPPED
772
+ return
773
+ if seq.data.finished:
774
+ seq.status = SequenceStatus.FINISHED_STOPPED
775
+ return
776
+
777
+ for token_id in seq.get_last_token_id():
778
+ if token_id == sampling_params.eos_token:
779
+ seq.status = SequenceStatus.FINISHED_STOPPED
780
+ return
781
+
782
+ # Check if the sequence has reached max_model_len.
783
+ if seq.get_len() > self.scheduler_config.max_model_len:
784
+ seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
785
+ return
786
+
787
+ # Check if the sequence has reached max_tokens.
788
+ if seq.get_output_len() == sampling_params.max_tokens:
789
+ seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
790
+ return
791
+
792
+ # Check if the sequence has generated the EOS token.
793
+ if (not sampling_params.ignore_eos) and seq.get_last_token_id()[
794
+ 0
795
+ ] == sampling_params.eos_token:
796
+ seq.status = SequenceStatus.FINISHED_STOPPED
797
+ return
798
+
799
+ def _run_workers(
800
+ self,
801
+ method: str,
802
+ *args,
803
+ driver_args: Optional[List[Any]] = None,
804
+ driver_kwargs: Optional[Dict[str, Any]] = None,
805
+ max_concurrent_workers: Optional[int] = None,
806
+ **kwargs,
807
+ ) -> Any:
808
+ """Runs the given method on all workers."""
809
+
810
+ if max_concurrent_workers:
811
+ raise NotImplementedError("max_concurrent_workers is not supported yet.")
812
+
813
+ # Start the ray workers first.
814
+ ray_worker_outputs = [
815
+ worker.execute_method.remote(method, *args, **kwargs)
816
+ for worker in self.workers
817
+ ]
818
+
819
+ if driver_args is None:
820
+ driver_args = args
821
+ if driver_kwargs is None:
822
+ driver_kwargs = kwargs
823
+
824
+ # Start the driver worker after all the ray workers.
825
+ driver_worker_output = getattr(self.driver_worker, method)(
826
+ *driver_args, **driver_kwargs
827
+ )
828
+
829
+ # Get the results of the ray workers.
830
+ if self.workers:
831
+ ray_worker_outputs = ray.get(ray_worker_outputs)
832
+
833
+ return [driver_worker_output] + ray_worker_outputs
ChatTTS/model/velocity/model_loader.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utilities for selecting and loading models."""
2
+
3
+ import contextlib
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from vllm.config import ModelConfig
9
+ from vllm.model_executor.models import ModelRegistry
10
+ from vllm.model_executor.weight_utils import get_quant_config, initialize_dummy_weights
11
+
12
+ from .llama import LlamaModel
13
+
14
+
15
+ @contextlib.contextmanager
16
+ def _set_default_torch_dtype(dtype: torch.dtype):
17
+ """Sets the default torch dtype to the given dtype."""
18
+ old_dtype = torch.get_default_dtype()
19
+ torch.set_default_dtype(dtype)
20
+ yield
21
+ torch.set_default_dtype(old_dtype)
22
+
23
+
24
+ def get_model(model_config: ModelConfig) -> nn.Module:
25
+ # Get the (maybe quantized) linear method.
26
+ linear_method = None
27
+ if model_config.quantization is not None:
28
+ quant_config = get_quant_config(
29
+ model_config.quantization,
30
+ model_config.model,
31
+ model_config.hf_config,
32
+ model_config.download_dir,
33
+ )
34
+ capability = torch.cuda.get_device_capability()
35
+ capability = capability[0] * 10 + capability[1]
36
+ if capability < quant_config.get_min_capability():
37
+ raise ValueError(
38
+ f"The quantization method {model_config.quantization} is not "
39
+ "supported for the current GPU. "
40
+ f"Minimum capability: {quant_config.get_min_capability()}. "
41
+ f"Current capability: {capability}."
42
+ )
43
+ supported_dtypes = quant_config.get_supported_act_dtypes()
44
+ if model_config.dtype not in supported_dtypes:
45
+ raise ValueError(
46
+ f"{model_config.dtype} is not supported for quantization "
47
+ f"method {model_config.quantization}. Supported dtypes: "
48
+ f"{supported_dtypes}"
49
+ )
50
+ linear_method = quant_config.get_linear_method()
51
+
52
+ with _set_default_torch_dtype(model_config.dtype):
53
+ # Create a model instance.
54
+ # The weights will be initialized as empty tensors.
55
+ with torch.device("cuda"):
56
+ model = LlamaModel(model_config.hf_config, linear_method)
57
+ if model_config.load_format == "dummy":
58
+ # NOTE(woosuk): For accurate performance evaluation, we assign
59
+ # random values to the weights.
60
+ initialize_dummy_weights(model)
61
+ else:
62
+ # Load the weights from the cached or downloaded files.
63
+ model.load_weights(
64
+ model_config.model,
65
+ model_config.download_dir,
66
+ model_config.load_format,
67
+ model_config.revision,
68
+ )
69
+ return model.eval()
ChatTTS/model/velocity/model_runner.py ADDED
@@ -0,0 +1,817 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from typing import Dict, List, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .configs import ModelConfig, ParallelConfig, SchedulerConfig
9
+ from vllm.logger import init_logger
10
+ from .model_loader import get_model
11
+ from vllm.model_executor import InputMetadata, SamplingMetadata
12
+ from vllm.model_executor.parallel_utils.communication_op import (
13
+ broadcast,
14
+ broadcast_object_list,
15
+ )
16
+ from .sampling_params import SamplingParams, SamplingType
17
+ from .sequence import (
18
+ SamplerOutput,
19
+ SequenceData,
20
+ SequenceGroupMetadata,
21
+ SequenceGroupOutput,
22
+ SequenceOutput,
23
+ )
24
+ from vllm.utils import in_wsl
25
+ from ..embed import Embed
26
+ from .sampler import Sampler
27
+ from safetensors.torch import safe_open
28
+
29
+ logger = init_logger(__name__)
30
+
31
+ KVCache = Tuple[torch.Tensor, torch.Tensor]
32
+ _PAD_SLOT_ID = -1
33
+ # Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
34
+ # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
35
+ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
36
+
37
+
38
+ class ModelRunner:
39
+
40
+ def __init__(
41
+ self,
42
+ model_config: ModelConfig,
43
+ parallel_config: ParallelConfig,
44
+ scheduler_config: SchedulerConfig,
45
+ is_driver_worker: bool = False,
46
+ post_model_path: str = None,
47
+ ):
48
+ self.model_config = model_config
49
+ self.parallel_config = parallel_config
50
+ self.scheduler_config = scheduler_config
51
+ self.is_driver_worker = is_driver_worker
52
+ self.post_model_path = post_model_path
53
+
54
+ # model_config can be None in tests/samplers/test_sampler.py.
55
+ # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
56
+ self.sliding_window = (
57
+ model_config.get_sliding_window() if model_config is not None else None
58
+ )
59
+ self.model = None
60
+ self.block_size = None # Set after initial profiling.
61
+
62
+ self.graph_runners: Dict[int, CUDAGraphRunner] = {}
63
+ self.graph_memory_pool = None # Set during graph capture.
64
+
65
+ self.max_context_len_to_capture = (
66
+ self.model_config.max_context_len_to_capture
67
+ if self.model_config is not None
68
+ else 0
69
+ )
70
+ # When using CUDA graph, the input block tables must be padded to
71
+ # max_context_len_to_capture. However, creating the block table in
72
+ # Python can be expensive. To optimize this, we cache the block table
73
+ # in numpy and only copy the actual input content at every iteration.
74
+ # The shape of the cached block table will be
75
+ # (max batch size to capture, max context len to capture / block size).
76
+ self.graph_block_tables = None # Set after initial profiling.
77
+ # cache in_wsl result
78
+ self.in_wsl = in_wsl()
79
+
80
+ def load_model(self) -> None:
81
+ self.model = get_model(self.model_config)
82
+ self.post_model = Embed(
83
+ self.model_config.get_hidden_size(),
84
+ self.model_config.num_audio_tokens,
85
+ self.model_config.num_text_tokens,
86
+ )
87
+ state_dict_tensors = {}
88
+ with safe_open(self.post_model_path, framework="pt", device=0) as f:
89
+ for k in f.keys():
90
+ state_dict_tensors[k] = f.get_tensor(k)
91
+ self.post_model.load_state_dict(state_dict_tensors)
92
+ self.post_model.to(next(self.model.parameters())).eval()
93
+ self.sampler = Sampler(self.post_model, self.model_config.num_audio_tokens, 4)
94
+
95
+ def set_block_size(self, block_size: int) -> None:
96
+ self.block_size = block_size
97
+
98
+ max_num_blocks = (
99
+ self.max_context_len_to_capture + block_size - 1
100
+ ) // block_size
101
+ self.graph_block_tables = np.zeros(
102
+ (max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32
103
+ )
104
+
105
+ def _prepare_prompt(
106
+ self,
107
+ seq_group_metadata_list: List[SequenceGroupMetadata],
108
+ ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int]]:
109
+ assert len(seq_group_metadata_list) > 0
110
+ input_tokens: List[List[int]] = []
111
+ input_positions: List[List[int]] = []
112
+ slot_mapping: List[List[int]] = []
113
+
114
+ prompt_lens: List[int] = []
115
+ for seq_group_metadata in seq_group_metadata_list:
116
+ assert seq_group_metadata.is_prompt
117
+ seq_ids = list(seq_group_metadata.seq_data.keys())
118
+ assert len(seq_ids) == 1
119
+ seq_id = seq_ids[0]
120
+
121
+ seq_data = seq_group_metadata.seq_data[seq_id]
122
+ prompt_tokens = seq_data.get_token_ids()
123
+ prompt_len = len(prompt_tokens)
124
+ prompt_lens.append(prompt_len)
125
+
126
+ input_tokens.append(prompt_tokens)
127
+ # NOTE(woosuk): Here we assume that the first token in the prompt
128
+ # is always the first token in the sequence.
129
+ input_positions.append(list(range(prompt_len)))
130
+
131
+ if seq_group_metadata.block_tables is None:
132
+ # During memory profiling, the block tables are not initialized
133
+ # yet. In this case, we just use a dummy slot mapping.
134
+ slot_mapping.append([_PAD_SLOT_ID] * prompt_len)
135
+ continue
136
+
137
+ # Compute the slot mapping.
138
+ slot_mapping.append([])
139
+ block_table = seq_group_metadata.block_tables[seq_id]
140
+ # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
141
+ # where start_idx is max(0, prompt_len - sliding_window).
142
+ # For example, if the prompt len is 10, sliding window is 8, and
143
+ # block size is 4, the first two tokens are masked and the slot
144
+ # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
145
+ start_idx = 0
146
+ if self.sliding_window is not None:
147
+ start_idx = max(0, prompt_len - self.sliding_window)
148
+ for i in range(prompt_len):
149
+ if i < start_idx:
150
+ slot_mapping[-1].append(_PAD_SLOT_ID)
151
+ continue
152
+
153
+ block_number = block_table[i // self.block_size]
154
+ block_offset = i % self.block_size
155
+ slot = block_number * self.block_size + block_offset
156
+ slot_mapping[-1].append(slot)
157
+
158
+ max_prompt_len = max(prompt_lens)
159
+ input_tokens = _make_tensor_with_pad(
160
+ input_tokens, max_prompt_len, pad=0, dtype=torch.long
161
+ )
162
+ input_positions = _make_tensor_with_pad(
163
+ input_positions, max_prompt_len, pad=0, dtype=torch.long
164
+ )
165
+ slot_mapping = _make_tensor_with_pad(
166
+ slot_mapping, max_prompt_len, pad=_PAD_SLOT_ID, dtype=torch.long
167
+ )
168
+
169
+ input_metadata = InputMetadata(
170
+ is_prompt=True,
171
+ slot_mapping=slot_mapping,
172
+ max_context_len=None,
173
+ context_lens=None,
174
+ block_tables=None,
175
+ use_cuda_graph=False,
176
+ )
177
+ return input_tokens, input_positions, input_metadata, prompt_lens
178
+
179
+ def _prepare_decode(
180
+ self,
181
+ seq_group_metadata_list: List[SequenceGroupMetadata],
182
+ ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
183
+ assert len(seq_group_metadata_list) > 0
184
+ input_tokens: List[List[int]] = []
185
+ input_positions: List[List[int]] = []
186
+ slot_mapping: List[List[int]] = []
187
+ context_lens: List[int] = []
188
+ block_tables: List[List[int]] = []
189
+
190
+ for seq_group_metadata in seq_group_metadata_list:
191
+ assert not seq_group_metadata.is_prompt
192
+
193
+ seq_ids = list(seq_group_metadata.seq_data.keys())
194
+ for seq_id in seq_ids:
195
+ seq_data = seq_group_metadata.seq_data[seq_id]
196
+ generation_token = seq_data.get_last_token_id()
197
+ input_tokens.append([generation_token])
198
+
199
+ seq_len = seq_data.get_len()
200
+ position = seq_len - 1
201
+ input_positions.append([position])
202
+
203
+ context_len = (
204
+ seq_len
205
+ if self.sliding_window is None
206
+ else min(seq_len, self.sliding_window)
207
+ )
208
+ context_lens.append(context_len)
209
+
210
+ block_table = seq_group_metadata.block_tables[seq_id]
211
+ block_number = block_table[position // self.block_size]
212
+ block_offset = position % self.block_size
213
+ slot = block_number * self.block_size + block_offset
214
+ slot_mapping.append([slot])
215
+
216
+ if self.sliding_window is not None:
217
+ sliding_window_blocks = self.sliding_window // self.block_size
218
+ block_table = block_table[-sliding_window_blocks:]
219
+ block_tables.append(block_table)
220
+
221
+ batch_size = len(input_tokens)
222
+ max_context_len = max(context_lens)
223
+ use_captured_graph = (
224
+ not self.model_config.enforce_eager
225
+ and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
226
+ and max_context_len <= self.max_context_len_to_capture
227
+ )
228
+ if use_captured_graph:
229
+ # Pad the input tokens, positions, and slot mapping to match the
230
+ # batch size of the captured graph.
231
+ graph_batch_size = _get_graph_batch_size(batch_size)
232
+ assert graph_batch_size >= batch_size
233
+ for _ in range(graph_batch_size - batch_size):
234
+ input_tokens.append([])
235
+ input_positions.append([])
236
+ slot_mapping.append([])
237
+ context_lens.append(1)
238
+ block_tables.append([])
239
+ batch_size = graph_batch_size
240
+
241
+ input_tokens = _make_tensor_with_pad(
242
+ input_tokens, max_len=1, pad=0, dtype=torch.long, device="cuda"
243
+ )
244
+ input_positions = _make_tensor_with_pad(
245
+ input_positions, max_len=1, pad=0, dtype=torch.long, device="cuda"
246
+ )
247
+ slot_mapping = _make_tensor_with_pad(
248
+ slot_mapping, max_len=1, pad=_PAD_SLOT_ID, dtype=torch.long, device="cuda"
249
+ )
250
+ context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
251
+
252
+ if use_captured_graph:
253
+ # The shape of graph_block_tables is
254
+ # [max batch size, max context len // block size].
255
+ input_block_tables = self.graph_block_tables[:batch_size]
256
+ for i, block_table in enumerate(block_tables):
257
+ if block_table:
258
+ input_block_tables[i, : len(block_table)] = block_table
259
+ block_tables = torch.tensor(input_block_tables, device="cuda")
260
+ else:
261
+ block_tables = _make_tensor_with_pad(
262
+ block_tables,
263
+ max_len=max_context_len,
264
+ pad=0,
265
+ dtype=torch.int,
266
+ device="cuda",
267
+ )
268
+
269
+ input_metadata = InputMetadata(
270
+ is_prompt=False,
271
+ slot_mapping=slot_mapping,
272
+ max_context_len=max_context_len,
273
+ context_lens=context_lens,
274
+ block_tables=block_tables,
275
+ use_cuda_graph=use_captured_graph,
276
+ )
277
+ return input_tokens, input_positions, input_metadata
278
+
279
+ def _prepare_sample(
280
+ self,
281
+ seq_group_metadata_list: List[SequenceGroupMetadata],
282
+ prompt_lens: List[int],
283
+ ) -> SamplingMetadata:
284
+ seq_groups: List[Tuple[List[int], SamplingParams]] = []
285
+ selected_token_indices: List[int] = []
286
+ selected_token_start_idx = 0
287
+ categorized_sample_indices = {t: [] for t in SamplingType}
288
+ categorized_sample_indices_start_idx = 0
289
+
290
+ max_prompt_len = max(prompt_lens) if prompt_lens else 1
291
+ for i, seq_group_metadata in enumerate(seq_group_metadata_list):
292
+ seq_ids = list(seq_group_metadata.seq_data.keys())
293
+ sampling_params = seq_group_metadata.sampling_params
294
+ seq_groups.append((seq_ids, sampling_params))
295
+
296
+ if seq_group_metadata.is_prompt:
297
+ assert len(seq_ids) == 1
298
+ prompt_len = prompt_lens[i]
299
+ if sampling_params.prompt_logprobs is not None:
300
+ # NOTE: prompt token positions do not need sample, skip
301
+ categorized_sample_indices_start_idx += prompt_len - 1
302
+
303
+ categorized_sample_indices[sampling_params.sampling_type].append(
304
+ categorized_sample_indices_start_idx
305
+ )
306
+ categorized_sample_indices_start_idx += 1
307
+
308
+ if sampling_params.prompt_logprobs is not None:
309
+ selected_token_indices.extend(
310
+ range(
311
+ selected_token_start_idx,
312
+ selected_token_start_idx + prompt_len - 1,
313
+ )
314
+ )
315
+ selected_token_indices.append(selected_token_start_idx + prompt_len - 1)
316
+ selected_token_start_idx += max_prompt_len
317
+ else:
318
+ num_seqs = len(seq_ids)
319
+ selected_token_indices.extend(
320
+ range(selected_token_start_idx, selected_token_start_idx + num_seqs)
321
+ )
322
+ selected_token_start_idx += num_seqs
323
+
324
+ categorized_sample_indices[sampling_params.sampling_type].extend(
325
+ range(
326
+ categorized_sample_indices_start_idx,
327
+ categorized_sample_indices_start_idx + num_seqs,
328
+ )
329
+ )
330
+ categorized_sample_indices_start_idx += num_seqs
331
+
332
+ selected_token_indices = _async_h2d(
333
+ selected_token_indices, dtype=torch.long, pin_memory=not self.in_wsl
334
+ )
335
+ categorized_sample_indices = {
336
+ t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl)
337
+ for t, seq_ids in categorized_sample_indices.items()
338
+ }
339
+
340
+ seq_data: Dict[int, SequenceData] = {}
341
+ for seq_group_metadata in seq_group_metadata_list:
342
+ seq_data.update(seq_group_metadata.seq_data)
343
+
344
+ sampling_metadata = SamplingMetadata(
345
+ seq_groups=seq_groups,
346
+ seq_data=seq_data,
347
+ prompt_lens=prompt_lens,
348
+ selected_token_indices=selected_token_indices,
349
+ categorized_sample_indices=categorized_sample_indices,
350
+ )
351
+ return sampling_metadata
352
+
353
+ def prepare_input_tensors(
354
+ self,
355
+ seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
356
+ ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata]:
357
+ if self.is_driver_worker:
358
+ # NOTE: We assume that all sequences in the group are all prompts or
359
+ # all decodes.
360
+ is_prompt = seq_group_metadata_list[0].is_prompt
361
+ # Prepare input tensors.
362
+ if is_prompt:
363
+ (input_tokens, input_positions, input_metadata, prompt_lens) = (
364
+ self._prepare_prompt(seq_group_metadata_list)
365
+ )
366
+ else:
367
+ (input_tokens, input_positions, input_metadata) = self._prepare_decode(
368
+ seq_group_metadata_list
369
+ )
370
+ prompt_lens = []
371
+ sampling_metadata = self._prepare_sample(
372
+ seq_group_metadata_list, prompt_lens
373
+ )
374
+
375
+ def get_size_or_none(x: Optional[torch.Tensor]):
376
+ return x.size() if x is not None else None
377
+
378
+ # Broadcast the input data. For input tensors, we first broadcast
379
+ # its shape and then broadcast the tensor to avoid high
380
+ # serialization cost.
381
+ py_data = {
382
+ "input_tokens_size": input_tokens.size(),
383
+ "input_positions_size": input_positions.size(),
384
+ "is_prompt": input_metadata.is_prompt,
385
+ "slot_mapping_size": get_size_or_none(input_metadata.slot_mapping),
386
+ "max_context_len": input_metadata.max_context_len,
387
+ "context_lens_size": get_size_or_none(input_metadata.context_lens),
388
+ "block_tables_size": get_size_or_none(input_metadata.block_tables),
389
+ "use_cuda_graph": input_metadata.use_cuda_graph,
390
+ "selected_token_indices_size": sampling_metadata.selected_token_indices.size(),
391
+ }
392
+ broadcast_object_list([py_data], src=0)
393
+ # TODO(zhuohan): Combine the broadcasts or set async_op=True.
394
+ broadcast(input_tokens, src=0)
395
+ broadcast(input_positions, src=0)
396
+ if input_metadata.slot_mapping is not None:
397
+ broadcast(input_metadata.slot_mapping, src=0)
398
+ if input_metadata.context_lens is not None:
399
+ broadcast(input_metadata.context_lens, src=0)
400
+ if input_metadata.block_tables is not None:
401
+ broadcast(input_metadata.block_tables, src=0)
402
+ broadcast(sampling_metadata.selected_token_indices, src=0)
403
+ else:
404
+ receving_list = [None]
405
+ broadcast_object_list(receving_list, src=0)
406
+ py_data = receving_list[0]
407
+ input_tokens = torch.empty(
408
+ *py_data["input_tokens_size"], dtype=torch.long, device="cuda"
409
+ )
410
+ broadcast(input_tokens, src=0)
411
+ input_positions = torch.empty(
412
+ *py_data["input_positions_size"], dtype=torch.long, device="cuda"
413
+ )
414
+ broadcast(input_positions, src=0)
415
+ if py_data["slot_mapping_size"] is not None:
416
+ slot_mapping = torch.empty(
417
+ *py_data["slot_mapping_size"], dtype=torch.long, device="cuda"
418
+ )
419
+ broadcast(slot_mapping, src=0)
420
+ else:
421
+ slot_mapping = None
422
+ if py_data["context_lens_size"] is not None:
423
+ context_lens = torch.empty(
424
+ *py_data["context_lens_size"], dtype=torch.int, device="cuda"
425
+ )
426
+ broadcast(context_lens, src=0)
427
+ else:
428
+ context_lens = None
429
+ if py_data["block_tables_size"] is not None:
430
+ block_tables = torch.empty(
431
+ *py_data["block_tables_size"], dtype=torch.int, device="cuda"
432
+ )
433
+ broadcast(block_tables, src=0)
434
+ else:
435
+ block_tables = None
436
+ selected_token_indices = torch.empty(
437
+ *py_data["selected_token_indices_size"], dtype=torch.long, device="cuda"
438
+ )
439
+ broadcast(selected_token_indices, src=0)
440
+ input_metadata = InputMetadata(
441
+ is_prompt=py_data["is_prompt"],
442
+ slot_mapping=slot_mapping,
443
+ max_context_len=py_data["max_context_len"],
444
+ context_lens=context_lens,
445
+ block_tables=block_tables,
446
+ use_cuda_graph=py_data["use_cuda_graph"],
447
+ )
448
+ sampling_metadata = SamplingMetadata(
449
+ seq_groups=None,
450
+ seq_data=None,
451
+ prompt_lens=None,
452
+ selected_token_indices=selected_token_indices,
453
+ categorized_sample_indices=None,
454
+ perform_sampling=False,
455
+ )
456
+
457
+ return input_tokens, input_positions, input_metadata, sampling_metadata
458
+
459
+ @torch.inference_mode()
460
+ def execute_model(
461
+ self,
462
+ seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
463
+ kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
464
+ ) -> Optional[SamplerOutput]:
465
+ input_tokens, input_positions, input_metadata, sampling_metadata = (
466
+ self.prepare_input_tensors(seq_group_metadata_list)
467
+ )
468
+ # print(sampling_metadata.seq_data)
469
+ seq_groups = []
470
+ input_tokens_history = []
471
+ for i, rtn in enumerate(sampling_metadata.seq_groups):
472
+ seq_groups.append(rtn[0][0])
473
+ tokens_history = sampling_metadata.seq_data[rtn[0][0]].output_token_ids
474
+ if len(tokens_history) >= 1:
475
+ if len(tokens_history[0]) == 1:
476
+ tokens_history = [token[0] for token in tokens_history]
477
+ else:
478
+ tokens_history = [list(token) for token in tokens_history]
479
+ input_tokens_history.append(tokens_history)
480
+ input_tokens_history = torch.tensor(input_tokens_history).to(
481
+ input_tokens.device
482
+ )
483
+ # token_ids = rtn.outputs[0].token_ids
484
+ # for j, token_id in enumerate(token_ids):
485
+ # if len(token_id) == 1:
486
+ # token_ids[j] = token_id[0]
487
+ # else:
488
+ # token_ids[j] = list(token_id)
489
+
490
+ # Execute the model.
491
+ # print("it1",input_tokens)
492
+ if len(input_tokens.shape) == 2:
493
+ input_tokens = input_tokens.unsqueeze(2).repeat(1, 1, 4)
494
+ if len(input_tokens_history.shape) == 2:
495
+ input_tokens_history = input_tokens_history.unsqueeze(2).repeat(1, 1, 4)
496
+ # print(input_tokens_history.shape)
497
+ # print("it2",input_tokens.shape)
498
+ text_mask = input_tokens != 0
499
+ text_mask = text_mask[:, :, 0]
500
+
501
+ if input_metadata.use_cuda_graph:
502
+ graph_batch_size = input_tokens.shape[0]
503
+ model_executable = self.graph_runners[graph_batch_size]
504
+ else:
505
+ model_executable = self.model
506
+
507
+ infer_text = sampling_metadata.seq_groups[0][1].infer_text
508
+ temperture = sampling_metadata.seq_groups[0][1].temperature
509
+ if not infer_text:
510
+ temperture = torch.tensor(temperture).to(input_tokens.device)
511
+ logits_processors, logits_warpers = sampling_metadata.seq_groups[0][
512
+ 1
513
+ ].logits_processors
514
+ # print(logits_processors, logits_warpers)
515
+ min_new_token = sampling_metadata.seq_groups[0][1].min_new_token
516
+ eos_token = sampling_metadata.seq_groups[0][1].eos_token
517
+ start_idx = sampling_metadata.seq_groups[0][1].start_idx
518
+ if input_tokens.shape[-2] == 1:
519
+ if infer_text:
520
+ input_emb: torch.Tensor = self.post_model.emb_text(
521
+ input_tokens[:, :, 0]
522
+ )
523
+ else:
524
+ code_emb = [
525
+ self.post_model.emb_code[i](input_tokens[:, :, i])
526
+ for i in range(self.post_model.num_vq)
527
+ ]
528
+ input_emb = torch.stack(code_emb, 3).sum(3)
529
+ start_idx = (
530
+ input_tokens_history.shape[-2] - 1
531
+ if input_tokens_history.shape[-2] > 0
532
+ else 0
533
+ )
534
+ else:
535
+ input_emb = self.post_model(input_tokens, text_mask)
536
+ # print(input_emb.shape)
537
+ hidden_states = model_executable(
538
+ input_emb=input_emb,
539
+ positions=input_positions,
540
+ kv_caches=kv_caches,
541
+ input_metadata=input_metadata,
542
+ )
543
+ # print(hidden_states.shape)
544
+ # print(input_tokens)
545
+ B_NO_PAD = input_tokens_history.shape[0]
546
+ input_tokens = input_tokens[:B_NO_PAD, :, :]
547
+ hidden_states = hidden_states[:B_NO_PAD, :, :]
548
+ idx_next, logprob, finish = self.sampler.sample(
549
+ inputs_ids=(
550
+ input_tokens
551
+ if input_tokens_history.shape[-2] == 0
552
+ else input_tokens_history
553
+ ),
554
+ hidden_states=hidden_states,
555
+ infer_text=infer_text,
556
+ temperature=temperture,
557
+ logits_processors=logits_processors,
558
+ logits_warpers=logits_warpers,
559
+ min_new_token=min_new_token,
560
+ now_length=1,
561
+ eos_token=eos_token,
562
+ start_idx=start_idx,
563
+ )
564
+ # print(logprob.shape, idx_next.shape)
565
+ if len(logprob.shape) == 2:
566
+ logprob = logprob[:, None, :]
567
+ logprob = torch.gather(logprob, -1, idx_next.transpose(-1, -2))[:, :, 0]
568
+ # print("测试",idx_next.shape, logprob.shape)
569
+ # Sample the next token.
570
+ # output = self.model.sample(
571
+ # hidden_states=hidden_states,
572
+ # sampling_metadata=sampling_metadata,
573
+ # )
574
+ results = []
575
+ for i in range(idx_next.shape[0]):
576
+ idx_next_i = idx_next[i, 0, :].tolist()
577
+ logprob_i = logprob[i].tolist()
578
+ tmp_hidden_states = hidden_states[i]
579
+ if input_tokens[i].shape[-2] != 1:
580
+ tmp_hidden_states = tmp_hidden_states[-1:, :]
581
+ result = SequenceGroupOutput(
582
+ samples=[
583
+ SequenceOutput(
584
+ parent_seq_id=seq_groups[i],
585
+ logprobs={tuple(idx_next_i): logprob_i},
586
+ output_token=tuple(idx_next_i),
587
+ hidden_states=tmp_hidden_states,
588
+ finished=finish[i].item(),
589
+ ),
590
+ ],
591
+ prompt_logprobs=None,
592
+ )
593
+ results.append(result)
594
+ # print(results)
595
+ # print(idx_next, idx_next.shape, logprob.shape)
596
+ return results
597
+
598
+ @torch.inference_mode()
599
+ def profile_run(self) -> None:
600
+ # Enable top-k sampling to reflect the accurate memory usage.
601
+ vocab_size = self.model_config.get_vocab_size()
602
+ sampling_params = SamplingParams(
603
+ top_p=0.99, top_k=vocab_size - 1, infer_text=True
604
+ )
605
+ max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
606
+ max_num_seqs = self.scheduler_config.max_num_seqs
607
+
608
+ # Profile memory usage with max_num_sequences sequences and the total
609
+ # number of tokens equal to max_num_batched_tokens.
610
+ seqs: List[SequenceGroupMetadata] = []
611
+ for group_id in range(max_num_seqs):
612
+ seq_len = max_num_batched_tokens // max_num_seqs + (
613
+ group_id < max_num_batched_tokens % max_num_seqs
614
+ )
615
+ seq_data = SequenceData([0] * seq_len)
616
+ seq = SequenceGroupMetadata(
617
+ request_id=str(group_id),
618
+ is_prompt=True,
619
+ seq_data={group_id: seq_data},
620
+ sampling_params=sampling_params,
621
+ block_tables=None,
622
+ )
623
+ seqs.append(seq)
624
+
625
+ # Run the model with the dummy inputs.
626
+ num_layers = self.model_config.get_num_layers(self.parallel_config)
627
+ kv_caches = [(None, None)] * num_layers
628
+ self.execute_model(seqs, kv_caches)
629
+ torch.cuda.synchronize()
630
+ return
631
+
632
+ @torch.inference_mode()
633
+ def capture_model(self, kv_caches: List[KVCache]) -> None:
634
+ assert not self.model_config.enforce_eager
635
+ logger.info(
636
+ "Capturing the model for CUDA graphs. This may lead to "
637
+ "unexpected consequences if the model is not static. To "
638
+ "run the model in eager mode, set 'enforce_eager=True' or "
639
+ "use '--enforce-eager' in the CLI."
640
+ )
641
+ logger.info(
642
+ "CUDA graphs can take additional 1~3 GiB memory per GPU. "
643
+ "If you are running out of memory, consider decreasing "
644
+ "`gpu_memory_utilization` or enforcing eager mode."
645
+ )
646
+ start_time = time.perf_counter()
647
+
648
+ # Prepare dummy inputs. These will be reused for all batch sizes.
649
+ max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
650
+ input_emb = torch.zeros(
651
+ max_batch_size,
652
+ 1,
653
+ self.model_config.get_hidden_size(),
654
+ dtype=next(self.model.parameters()).dtype,
655
+ ).cuda()
656
+ input_positions = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda()
657
+ slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda()
658
+ slot_mapping.fill_(_PAD_SLOT_ID)
659
+ context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
660
+ block_tables = torch.from_numpy(self.graph_block_tables).cuda()
661
+
662
+ # NOTE: Capturing the largest batch size first may help reduce the
663
+ # memory usage of CUDA graph.
664
+ for batch_size in reversed(_BATCH_SIZES_TO_CAPTURE):
665
+ # Create dummy input_metadata.
666
+ input_metadata = InputMetadata(
667
+ is_prompt=False,
668
+ slot_mapping=slot_mapping[:batch_size],
669
+ max_context_len=self.max_context_len_to_capture,
670
+ context_lens=context_lens[:batch_size],
671
+ block_tables=block_tables[:batch_size],
672
+ use_cuda_graph=True,
673
+ )
674
+
675
+ graph_runner = CUDAGraphRunner(self.model)
676
+ graph_runner.capture(
677
+ input_emb[:batch_size],
678
+ input_positions[:batch_size],
679
+ kv_caches,
680
+ input_metadata,
681
+ memory_pool=self.graph_memory_pool,
682
+ )
683
+ self.graph_memory_pool = graph_runner.graph.pool()
684
+ self.graph_runners[batch_size] = graph_runner
685
+
686
+ end_time = time.perf_counter()
687
+ elapsed_time = end_time - start_time
688
+ # This usually takes < 10 seconds.
689
+ logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")
690
+
691
+
692
+ class CUDAGraphRunner:
693
+
694
+ def __init__(self, model: nn.Module):
695
+ self.model = model
696
+ self.graph = None
697
+ self.input_buffers: Dict[str, torch.Tensor] = {}
698
+ self.output_buffers: Dict[str, torch.Tensor] = {}
699
+
700
+ def capture(
701
+ self,
702
+ input_emb: torch.Tensor,
703
+ positions: torch.Tensor,
704
+ kv_caches: List[KVCache],
705
+ input_metadata: InputMetadata,
706
+ memory_pool,
707
+ ) -> None:
708
+ assert self.graph is None
709
+ # Run the model once without capturing the graph.
710
+ # This is to make sure that the captured graph does not include the
711
+ # kernel launches for initial benchmarking (e.g., Triton autotune).
712
+ self.model(
713
+ input_emb,
714
+ positions,
715
+ kv_caches,
716
+ input_metadata,
717
+ )
718
+ torch.cuda.synchronize()
719
+
720
+ # Capture the graph.
721
+ self.graph = torch.cuda.CUDAGraph()
722
+ with torch.cuda.graph(self.graph, pool=memory_pool):
723
+ hidden_states = self.model(
724
+ input_emb,
725
+ positions,
726
+ kv_caches,
727
+ input_metadata,
728
+ )
729
+ torch.cuda.synchronize()
730
+
731
+ # Save the input and output buffers.
732
+ self.input_buffers = {
733
+ "input_emb": input_emb,
734
+ "positions": positions,
735
+ "kv_caches": kv_caches,
736
+ "slot_mapping": input_metadata.slot_mapping,
737
+ "context_lens": input_metadata.context_lens,
738
+ "block_tables": input_metadata.block_tables,
739
+ }
740
+ self.output_buffers = {"hidden_states": hidden_states}
741
+ return
742
+
743
+ def forward(
744
+ self,
745
+ input_emb: torch.Tensor,
746
+ positions: torch.Tensor,
747
+ kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
748
+ input_metadata: InputMetadata,
749
+ ) -> torch.Tensor:
750
+ # KV caches are fixed tensors, so we don't need to copy them.
751
+ del kv_caches
752
+
753
+ # Copy the input tensors to the input buffers.
754
+ self.input_buffers["input_emb"].copy_(input_emb, non_blocking=True)
755
+ self.input_buffers["positions"].copy_(positions, non_blocking=True)
756
+ self.input_buffers["slot_mapping"].copy_(
757
+ input_metadata.slot_mapping, non_blocking=True
758
+ )
759
+ self.input_buffers["context_lens"].copy_(
760
+ input_metadata.context_lens, non_blocking=True
761
+ )
762
+ self.input_buffers["block_tables"].copy_(
763
+ input_metadata.block_tables, non_blocking=True
764
+ )
765
+
766
+ # Run the graph.
767
+ self.graph.replay()
768
+
769
+ # Return the output tensor.
770
+ return self.output_buffers["hidden_states"]
771
+
772
+ def __call__(self, *args, **kwargs):
773
+ return self.forward(*args, **kwargs)
774
+
775
+
776
+ def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
777
+ assert len(x) <= max_len
778
+ if len(x) == max_len:
779
+ return list(x)
780
+ return list(x) + [pad] * (max_len - len(x))
781
+
782
+
783
+ def _make_tensor_with_pad(
784
+ x: List[List[int]],
785
+ max_len: int,
786
+ pad: int,
787
+ dtype: torch.dtype,
788
+ device: Union[str, torch.device] = "cuda",
789
+ pin_memory: bool = False,
790
+ ) -> torch.Tensor:
791
+ padded_x = []
792
+ for x_i in x:
793
+ pad_i = pad
794
+ if isinstance(x[0][0], tuple):
795
+ pad_i = (0,) * len(x[0][0])
796
+ padded_x.append(_pad_to_max(x_i, max_len, pad_i))
797
+
798
+ return torch.tensor(
799
+ padded_x,
800
+ dtype=dtype,
801
+ device=device,
802
+ pin_memory=pin_memory and str(device) == "cpu",
803
+ )
804
+
805
+
806
+ def _get_graph_batch_size(batch_size: int) -> int:
807
+ if batch_size <= 2:
808
+ return batch_size
809
+ elif batch_size <= 4:
810
+ return 4
811
+ else:
812
+ return (batch_size + 7) // 8 * 8
813
+
814
+
815
+ def _async_h2d(data: list, dtype, pin_memory):
816
+ t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory)
817
+ return t.to(device="cuda", non_blocking=True)
ChatTTS/model/velocity/output.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+ import torch
3
+
4
+ from .sequence import (
5
+ PromptLogprobs,
6
+ SampleLogprobs,
7
+ SequenceGroup,
8
+ SequenceStatus,
9
+ )
10
+
11
+
12
+ class CompletionOutput:
13
+ """The output data of one completion output of a request.
14
+
15
+ Args:
16
+ index: The index of the output in the request.
17
+ text: The generated output text.
18
+ token_ids: The token IDs of the generated output text.
19
+ cumulative_logprob: The cumulative log probability of the generated
20
+ output text.
21
+ logprobs: The log probabilities of the top probability words at each
22
+ position if the logprobs are requested.
23
+ finish_reason: The reason why the sequence is finished.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ index: int,
29
+ text: str,
30
+ token_ids: List[int],
31
+ cumulative_logprob: float,
32
+ logprobs: Optional[SampleLogprobs],
33
+ finish_reason: Optional[str] = None,
34
+ hidden_states: Optional[torch.Tensor] = None,
35
+ ) -> None:
36
+ self.index = index
37
+ self.text = text
38
+ self.token_ids = token_ids
39
+ self.cumulative_logprob = cumulative_logprob
40
+ self.logprobs = logprobs
41
+ self.finish_reason = finish_reason
42
+ self.hidden_states = hidden_states
43
+
44
+ def finished(self) -> bool:
45
+ return self.finish_reason is not None
46
+
47
+ def __repr__(self) -> str:
48
+ return (
49
+ f"CompletionOutput(index={self.index}, "
50
+ f"text={self.text!r}, "
51
+ f"token_ids={self.token_ids}, "
52
+ f"cumulative_logprob={self.cumulative_logprob}, "
53
+ f"logprobs={self.logprobs}, "
54
+ f"finish_reason={self.finish_reason}, "
55
+ f"hidden_states={self.hidden_states.shape if self.hidden_states is not None else None})"
56
+ )
57
+
58
+
59
+ class RequestOutput:
60
+ """The output data of a request to the LLM.
61
+
62
+ Args:
63
+ request_id: The unique ID of the request.
64
+ prompt: The prompt string of the request.
65
+ prompt_token_ids: The token IDs of the prompt.
66
+ prompt_logprobs: The log probabilities to return per prompt token.
67
+ outputs: The output sequences of the request.
68
+ finished: Whether the whole request is finished.
69
+ """
70
+
71
+ def __init__(
72
+ self,
73
+ request_id: str,
74
+ prompt: str,
75
+ prompt_token_ids: List[int],
76
+ prompt_logprobs: Optional[PromptLogprobs],
77
+ outputs: List[CompletionOutput],
78
+ finished: bool,
79
+ ) -> None:
80
+ self.request_id = request_id
81
+ self.prompt = prompt
82
+ self.prompt_token_ids = prompt_token_ids
83
+ self.prompt_logprobs = prompt_logprobs
84
+ self.outputs = outputs
85
+ self.finished = finished
86
+
87
+ @classmethod
88
+ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
89
+ # Get the top-n sequences.
90
+ n = seq_group.sampling_params.n
91
+ seqs = seq_group.get_seqs()
92
+ if seq_group.sampling_params.use_beam_search:
93
+ sorting_key = lambda seq: seq.get_beam_search_score(
94
+ seq_group.sampling_params.length_penalty
95
+ )
96
+ else:
97
+ sorting_key = lambda seq: seq.get_cumulative_logprob()
98
+ sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
99
+ top_n_seqs = sorted_seqs[:n]
100
+
101
+ # Create the outputs.
102
+ outputs: List[CompletionOutput] = []
103
+ for seq in top_n_seqs:
104
+ logprobs = seq.output_logprobs
105
+ if seq_group.sampling_params.logprobs is None:
106
+ # NOTE: We need to take care of this case because the sequence
107
+ # always has the logprobs of the sampled tokens even if the
108
+ # logprobs are not requested.
109
+ logprobs = None
110
+ finshed_reason = SequenceStatus.get_finished_reason(seq.status)
111
+ output = CompletionOutput(
112
+ seqs.index(seq),
113
+ seq.output_text,
114
+ seq.get_output_token_ids(),
115
+ seq.get_cumulative_logprob(),
116
+ logprobs,
117
+ finshed_reason,
118
+ seq.data.hidden_states,
119
+ )
120
+ outputs.append(output)
121
+
122
+ # Every sequence in the sequence group should have the same prompt.
123
+ prompt = seq_group.prompt
124
+ prompt_token_ids = seq_group.prompt_token_ids
125
+ prompt_logprobs = seq_group.prompt_logprobs
126
+ finished = seq_group.is_finished()
127
+ return cls(
128
+ seq_group.request_id,
129
+ prompt,
130
+ prompt_token_ids,
131
+ prompt_logprobs,
132
+ outputs,
133
+ finished,
134
+ )
135
+
136
+ def __repr__(self) -> str:
137
+ return (
138
+ f"RequestOutput(request_id={self.request_id}, "
139
+ f"prompt={self.prompt!r}, "
140
+ f"prompt_token_ids={self.prompt_token_ids}, "
141
+ f"prompt_logprobs={self.prompt_logprobs}, "
142
+ f"outputs={self.outputs}, "
143
+ f"finished={self.finished})"
144
+ )
ChatTTS/model/velocity/sampler.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.functional import F
3
+ from typing import List, Callable
4
+
5
+ from ..embed import Embed
6
+
7
+
8
+ class Sampler:
9
+ def __init__(self, post_model: Embed, num_audio_tokens: int, num_vq: int):
10
+ self.post_model = post_model
11
+ self.device = next(self.post_model.parameters()).device
12
+ self.num_audio_tokens = num_audio_tokens
13
+ self.num_vq = num_vq
14
+
15
+ def sample(
16
+ self,
17
+ inputs_ids: torch.Tensor,
18
+ hidden_states: torch.Tensor,
19
+ infer_text: bool = False,
20
+ temperature: torch.Tensor = 1.0,
21
+ logits_processors: List[Callable] = [
22
+ lambda logits_token, logits: logits,
23
+ ],
24
+ logits_warpers: List[Callable] = [
25
+ lambda logits_token, logits: logits,
26
+ ],
27
+ min_new_token: int = 0,
28
+ now_length: int = 0,
29
+ eos_token: int = 0,
30
+ start_idx: int = 0,
31
+ ):
32
+ # print(inputs_ids.shape)
33
+ B = hidden_states.shape[0]
34
+
35
+ end_idx = torch.zeros(
36
+ inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long
37
+ )
38
+ finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool()
39
+ if not infer_text:
40
+ temperature = (
41
+ temperature.unsqueeze(0)
42
+ .expand(inputs_ids.shape[0], -1)
43
+ .contiguous()
44
+ .view(-1, 1)
45
+ )
46
+
47
+ if infer_text:
48
+ logits: torch.Tensor = self.post_model.head_text(hidden_states)
49
+ else:
50
+ # logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3)
51
+ logits = torch.empty(
52
+ hidden_states.size(0),
53
+ hidden_states.size(1),
54
+ self.num_audio_tokens,
55
+ self.num_vq,
56
+ dtype=torch.float,
57
+ device=self.device,
58
+ )
59
+ for num_vq_iter in range(self.num_vq):
60
+ x: torch.Tensor = self.post_model.head_code[num_vq_iter](hidden_states)
61
+ logits[..., num_vq_iter] = x
62
+ del x
63
+
64
+ del hidden_states
65
+
66
+ # logits = logits[:, -1].float()
67
+ logits = logits.narrow(1, -1, 1).squeeze_(1).float()
68
+
69
+ if not infer_text:
70
+ # logits = rearrange(logits, "b c n -> (b n) c")
71
+ logits = logits.permute(0, 2, 1)
72
+ logits = logits.reshape(-1, logits.size(2))
73
+ # logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c")
74
+ inputs_ids_sliced = inputs_ids[:, start_idx:].permute(0, 2, 1)
75
+ logits_token = inputs_ids_sliced.reshape(
76
+ inputs_ids_sliced.size(0) * inputs_ids_sliced.size(1),
77
+ -1,
78
+ ).to(self.device)
79
+ else:
80
+ logits_token = inputs_ids[:, start_idx:, 0].to(self.device)
81
+
82
+ logits /= temperature
83
+
84
+ for logitsProcessors in logits_processors:
85
+ logits = logitsProcessors(logits_token, logits)
86
+
87
+ for logitsWarpers in logits_warpers:
88
+ logits = logitsWarpers(logits_token, logits)
89
+
90
+ del logits_token
91
+
92
+ if now_length < min_new_token:
93
+ logits[:, eos_token] = -torch.inf
94
+
95
+ scores = F.softmax(logits, dim=-1)
96
+ idx_next = torch.multinomial(scores, num_samples=1).to(finish.device)
97
+ if not infer_text:
98
+ scores = scores.reshape(B, -1, scores.shape[-1])
99
+ if not infer_text:
100
+ # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
101
+ idx_next = idx_next.view(-1, self.num_vq)
102
+ finish_or = idx_next.eq(eos_token).any(1)
103
+ finish.logical_or_(finish_or)
104
+ del finish_or
105
+ else:
106
+ finish_or = idx_next.eq(eos_token).any(1)
107
+ finish.logical_or_(finish_or)
108
+ del finish_or
109
+
110
+ del inputs_ids
111
+
112
+ not_finished = finish.logical_not().to(end_idx.device)
113
+
114
+ end_idx.add_(not_finished.int())
115
+ idx_next = idx_next[:, None, :]
116
+ return (
117
+ idx_next,
118
+ torch.log(scores),
119
+ finish,
120
+ )
ChatTTS/model/velocity/sampling_params.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Sampling parameters for text generation."""
2
+
3
+ from enum import IntEnum
4
+ from functools import cached_property
5
+ from typing import Callable, List, Optional, Union
6
+
7
+ import torch
8
+
9
+ _SAMPLING_EPS = 1e-5
10
+
11
+
12
+ class SamplingType(IntEnum):
13
+ GREEDY = 0
14
+ RANDOM = 1
15
+ BEAM = 2
16
+
17
+
18
+ LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
19
+ """LogitsProcessor is a function that takes a list of previously generated
20
+ tokens and a tensor of the logits for the next token, and returns a modified
21
+ tensor of logits to sample from."""
22
+
23
+
24
+ class SamplingParams:
25
+ """Sampling parameters for text generation.
26
+
27
+ Overall, we follow the sampling parameters from the OpenAI text completion
28
+ API (https://platform.openai.com/docs/api-reference/completions/create).
29
+ In addition, we support beam search, which is not supported by OpenAI.
30
+
31
+ Args:
32
+ n: Number of output sequences to return for the given prompt.
33
+ best_of: Number of output sequences that are generated from the prompt.
34
+ From these `best_of` sequences, the top `n` sequences are returned.
35
+ `best_of` must be greater than or equal to `n`. This is treated as
36
+ the beam width when `use_beam_search` is True. By default, `best_of`
37
+ is set to `n`.
38
+ presence_penalty: Float that penalizes new tokens based on whether they
39
+ appear in the generated text so far. Values > 0 encourage the model
40
+ to use new tokens, while values < 0 encourage the model to repeat
41
+ tokens.
42
+ frequency_penalty: Float that penalizes new tokens based on their
43
+ frequency in the generated text so far. Values > 0 encourage the
44
+ model to use new tokens, while values < 0 encourage the model to
45
+ repeat tokens.
46
+ repetition_penalty: Float that penalizes new tokens based on whether
47
+ they appear in the prompt and the generated text so far. Values > 1
48
+ encourage the model to use new tokens, while values < 1 encourage
49
+ the model to repeat tokens.
50
+ temperature: Float that controls the randomness of the sampling. Lower
51
+ values make the model more deterministic, while higher values make
52
+ the model more random. Zero means greedy sampling.
53
+ top_p: Float that controls the cumulative probability of the top tokens
54
+ to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
55
+ top_k: Integer that controls the number of top tokens to consider. Set
56
+ to -1 to consider all tokens.
57
+ min_p: Float that represents the minimum probability for a token to be
58
+ considered, relative to the probability of the most likely token.
59
+ Must be in [0, 1]. Set to 0 to disable this.
60
+ use_beam_search: Whether to use beam search instead of sampling.
61
+ length_penalty: Float that penalizes sequences based on their length.
62
+ Used in beam search.
63
+ early_stopping: Controls the stopping condition for beam search. It
64
+ accepts the following values: `True`, where the generation stops as
65
+ soon as there are `best_of` complete candidates; `False`, where an
66
+ heuristic is applied and the generation stops when is it very
67
+ unlikely to find better candidates; `"never"`, where the beam search
68
+ procedure only stops when there cannot be better candidates
69
+ (canonical beam search algorithm).
70
+ stop: List of strings that stop the generation when they are generated.
71
+ The returned output will not contain the stop strings.
72
+ stop_token_ids: List of tokens that stop the generation when they are
73
+ generated. The returned output will contain the stop tokens unless
74
+ the stop tokens are special tokens.
75
+ include_stop_str_in_output: Whether to include the stop strings in output
76
+ text. Defaults to False.
77
+ ignore_eos: Whether to ignore the EOS token and continue generating
78
+ tokens after the EOS token is generated.
79
+ max_tokens: Maximum number of tokens to generate per output sequence.
80
+ logprobs: Number of log probabilities to return per output token.
81
+ Note that the implementation follows the OpenAI API: The return
82
+ result includes the log probabilities on the `logprobs` most likely
83
+ tokens, as well the chosen tokens. The API will always return the
84
+ log probability of the sampled token, so there may be up to
85
+ `logprobs+1` elements in the response.
86
+ prompt_logprobs: Number of log probabilities to return per prompt token.
87
+ skip_special_tokens: Whether to skip special tokens in the output.
88
+ spaces_between_special_tokens: Whether to add spaces between special
89
+ tokens in the output. Defaults to True.
90
+ logits_processors: List of functions that modify logits based on
91
+ previously generated tokens.
92
+ """
93
+
94
+ def __init__(
95
+ self,
96
+ n: int = 1,
97
+ best_of: Optional[int] = None,
98
+ presence_penalty: float = 0.0,
99
+ frequency_penalty: float = 0.0,
100
+ repetition_penalty: float = 1.0,
101
+ temperature: float = 1.0,
102
+ top_p: float = 1.0,
103
+ top_k: int = -1,
104
+ min_p: float = 0.0,
105
+ use_beam_search: bool = False,
106
+ length_penalty: float = 1.0,
107
+ early_stopping: Union[bool, str] = False,
108
+ stop: Optional[Union[str, List[str]]] = None,
109
+ stop_token_ids: Optional[List[int]] = None,
110
+ include_stop_str_in_output: bool = False,
111
+ ignore_eos: bool = False,
112
+ max_tokens: int = 16,
113
+ logprobs: Optional[int] = None,
114
+ prompt_logprobs: Optional[int] = None,
115
+ skip_special_tokens: bool = True,
116
+ spaces_between_special_tokens: bool = True,
117
+ logits_processors: Optional[List[LogitsProcessor]] = (
118
+ [
119
+ lambda logits_token, logits: logits,
120
+ ],
121
+ [
122
+ lambda logits_token, logits: logits,
123
+ ],
124
+ ),
125
+ min_new_token: int = 0,
126
+ max_new_token: int = 8192,
127
+ infer_text: bool = False,
128
+ eos_token: int = 0,
129
+ spk_emb: str = None,
130
+ start_idx: int = 0,
131
+ ) -> None:
132
+ self.n = n
133
+ self.best_of = best_of if best_of is not None else n
134
+ self.presence_penalty = presence_penalty
135
+ self.frequency_penalty = frequency_penalty
136
+ self.repetition_penalty = repetition_penalty
137
+ self.temperature = temperature
138
+ self.top_p = top_p
139
+ self.top_k = top_k
140
+ self.min_p = min_p
141
+ self.use_beam_search = use_beam_search
142
+ self.length_penalty = length_penalty
143
+ self.early_stopping = early_stopping
144
+ self.min_new_token = min_new_token
145
+ self.max_new_token = max_new_token
146
+ self.infer_text = infer_text
147
+ self.eos_token = eos_token
148
+ self.spk_emb = spk_emb
149
+ self.start_idx = start_idx
150
+ if stop is None:
151
+ self.stop = []
152
+ elif isinstance(stop, str):
153
+ self.stop = [stop]
154
+ else:
155
+ self.stop = list(stop)
156
+ if stop_token_ids is None:
157
+ self.stop_token_ids = []
158
+ else:
159
+ self.stop_token_ids = list(stop_token_ids)
160
+ self.ignore_eos = ignore_eos
161
+ self.max_tokens = max_tokens
162
+ self.logprobs = logprobs
163
+ self.prompt_logprobs = prompt_logprobs
164
+ self.skip_special_tokens = skip_special_tokens
165
+ self.spaces_between_special_tokens = spaces_between_special_tokens
166
+ self.logits_processors = logits_processors
167
+ self.include_stop_str_in_output = include_stop_str_in_output
168
+ self._verify_args()
169
+ if self.use_beam_search:
170
+ self._verify_beam_search()
171
+ else:
172
+ self._verify_non_beam_search()
173
+ # if self.temperature < _SAMPLING_EPS:
174
+ # # Zero temperature means greedy sampling.
175
+ # self.top_p = 1.0
176
+ # self.top_k = -1
177
+ # self.min_p = 0.0
178
+ # self._verify_greedy_sampling()
179
+
180
+ def _verify_args(self) -> None:
181
+ if self.n < 1:
182
+ raise ValueError(f"n must be at least 1, got {self.n}.")
183
+ if self.best_of < self.n:
184
+ raise ValueError(
185
+ f"best_of must be greater than or equal to n, "
186
+ f"got n={self.n} and best_of={self.best_of}."
187
+ )
188
+ if not -2.0 <= self.presence_penalty <= 2.0:
189
+ raise ValueError(
190
+ "presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}."
191
+ )
192
+ if not -2.0 <= self.frequency_penalty <= 2.0:
193
+ raise ValueError(
194
+ "frequency_penalty must be in [-2, 2], got "
195
+ f"{self.frequency_penalty}."
196
+ )
197
+ if not 0.0 < self.repetition_penalty <= 2.0:
198
+ raise ValueError(
199
+ "repetition_penalty must be in (0, 2], got "
200
+ f"{self.repetition_penalty}."
201
+ )
202
+ # if self.temperature < 0.0:
203
+ # raise ValueError(
204
+ # f"temperature must be non-negative, got {self.temperature}.")
205
+ if not 0.0 < self.top_p <= 1.0:
206
+ raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
207
+ if self.top_k < -1 or self.top_k == 0:
208
+ raise ValueError(
209
+ f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}."
210
+ )
211
+ if not 0.0 <= self.min_p <= 1.0:
212
+ raise ValueError("min_p must be in [0, 1], got " f"{self.min_p}.")
213
+ if self.max_tokens < 1:
214
+ raise ValueError(f"max_tokens must be at least 1, got {self.max_tokens}.")
215
+ if self.logprobs is not None and self.logprobs < 0:
216
+ raise ValueError(f"logprobs must be non-negative, got {self.logprobs}.")
217
+ if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
218
+ raise ValueError(
219
+ f"prompt_logprobs must be non-negative, got " f"{self.prompt_logprobs}."
220
+ )
221
+
222
+ def _verify_beam_search(self) -> None:
223
+ if self.best_of == 1:
224
+ raise ValueError(
225
+ "best_of must be greater than 1 when using beam "
226
+ f"search. Got {self.best_of}."
227
+ )
228
+ if self.temperature > _SAMPLING_EPS:
229
+ raise ValueError("temperature must be 0 when using beam search.")
230
+ if self.top_p < 1.0 - _SAMPLING_EPS:
231
+ raise ValueError("top_p must be 1 when using beam search.")
232
+ if self.top_k != -1:
233
+ raise ValueError("top_k must be -1 when using beam search.")
234
+ if self.early_stopping not in [True, False, "never"]:
235
+ raise ValueError(
236
+ f"early_stopping must be True, False, or 'never', "
237
+ f"got {self.early_stopping}."
238
+ )
239
+
240
+ def _verify_non_beam_search(self) -> None:
241
+ if self.early_stopping is not False:
242
+ raise ValueError(
243
+ "early_stopping is not effective and must be "
244
+ "False when not using beam search."
245
+ )
246
+ if (
247
+ self.length_penalty < 1.0 - _SAMPLING_EPS
248
+ or self.length_penalty > 1.0 + _SAMPLING_EPS
249
+ ):
250
+ raise ValueError(
251
+ "length_penalty is not effective and must be the "
252
+ "default value of 1.0 when not using beam search."
253
+ )
254
+
255
+ def _verify_greedy_sampling(self) -> None:
256
+ if self.best_of > 1:
257
+ raise ValueError(
258
+ "best_of must be 1 when using greedy sampling." f"Got {self.best_of}."
259
+ )
260
+
261
+ @cached_property
262
+ def sampling_type(self) -> SamplingType:
263
+ if self.use_beam_search:
264
+ return SamplingType.BEAM
265
+ # if self.temperature < _SAMPLING_EPS:
266
+ # return SamplingType.GREEDY
267
+ return SamplingType.RANDOM
268
+
269
+ def __repr__(self) -> str:
270
+ return (
271
+ f"SamplingParams(n={self.n}, "
272
+ f"best_of={self.best_of}, "
273
+ f"presence_penalty={self.presence_penalty}, "
274
+ f"frequency_penalty={self.frequency_penalty}, "
275
+ f"repetition_penalty={self.repetition_penalty}, "
276
+ f"temperature={self.temperature}, "
277
+ f"top_p={self.top_p}, "
278
+ f"top_k={self.top_k}, "
279
+ f"min_p={self.min_p}, "
280
+ f"use_beam_search={self.use_beam_search}, "
281
+ f"length_penalty={self.length_penalty}, "
282
+ f"early_stopping={self.early_stopping}, "
283
+ f"stop={self.stop}, "
284
+ f"stop_token_ids={self.stop_token_ids}, "
285
+ f"include_stop_str_in_output={self.include_stop_str_in_output}, "
286
+ f"ignore_eos={self.ignore_eos}, "
287
+ f"max_tokens={self.max_tokens}, "
288
+ f"logprobs={self.logprobs}, "
289
+ f"prompt_logprobs={self.prompt_logprobs}, "
290
+ f"skip_special_tokens={self.skip_special_tokens}, "
291
+ "spaces_between_special_tokens="
292
+ f"{self.spaces_between_special_tokens}), "
293
+ f"max_new_token={self.max_new_token}), "
294
+ f"min_new_token={self.min_new_token}), "
295
+ f"infer_text={self.infer_text})"
296
+ )
ChatTTS/model/velocity/scheduler.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import enum
2
+ import time
3
+ from typing import Dict, Iterable, List, Optional, Tuple, Union
4
+
5
+ from vllm.config import CacheConfig, SchedulerConfig
6
+ from .block_manager import AllocStatus, BlockSpaceManager
7
+ from vllm.core.policy import PolicyFactory
8
+ from vllm.logger import init_logger
9
+ from .sequence import (
10
+ Sequence,
11
+ SequenceData,
12
+ SequenceGroup,
13
+ SequenceGroupMetadata,
14
+ SequenceStatus,
15
+ )
16
+
17
+ logger = init_logger(__name__)
18
+
19
+
20
+ class PreemptionMode(enum.Enum):
21
+ """Preemption modes.
22
+
23
+ 1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
24
+ and swap them back in when the sequences are resumed.
25
+ 2. Recomputation: Discard the blocks of the preempted sequences and
26
+ recompute them when the sequences are resumed, treating the sequences as
27
+ new prompts.
28
+ """
29
+
30
+ SWAP = enum.auto()
31
+ RECOMPUTE = enum.auto()
32
+
33
+
34
+ class SchedulerOutputs:
35
+
36
+ def __init__(
37
+ self,
38
+ scheduled_seq_groups: List[SequenceGroup],
39
+ prompt_run: bool,
40
+ num_batched_tokens: int,
41
+ blocks_to_swap_in: Dict[int, int],
42
+ blocks_to_swap_out: Dict[int, int],
43
+ blocks_to_copy: Dict[int, List[int]],
44
+ ignored_seq_groups: List[SequenceGroup],
45
+ ) -> None:
46
+ self.scheduled_seq_groups = scheduled_seq_groups
47
+ self.prompt_run = prompt_run
48
+ self.num_batched_tokens = num_batched_tokens
49
+ self.blocks_to_swap_in = blocks_to_swap_in
50
+ self.blocks_to_swap_out = blocks_to_swap_out
51
+ self.blocks_to_copy = blocks_to_copy
52
+ # Swap in and swap out should never happen at the same time.
53
+ assert not (blocks_to_swap_in and blocks_to_swap_out)
54
+ self.ignored_seq_groups = ignored_seq_groups
55
+
56
+ def is_empty(self) -> bool:
57
+ # NOTE: We do not consider the ignored sequence groups.
58
+ return (
59
+ not self.scheduled_seq_groups
60
+ and not self.blocks_to_swap_in
61
+ and not self.blocks_to_swap_out
62
+ and not self.blocks_to_copy
63
+ )
64
+
65
+
66
+ class Scheduler:
67
+
68
+ def __init__(
69
+ self,
70
+ scheduler_config: SchedulerConfig,
71
+ cache_config: CacheConfig,
72
+ ) -> None:
73
+ self.scheduler_config = scheduler_config
74
+ self.cache_config = cache_config
75
+
76
+ self.prompt_limit = min(
77
+ self.scheduler_config.max_model_len,
78
+ self.scheduler_config.max_num_batched_tokens,
79
+ )
80
+
81
+ # Instantiate the scheduling policy.
82
+ self.policy = PolicyFactory.get_policy(policy_name="fcfs")
83
+ # Create the block space manager.
84
+ self.block_manager = BlockSpaceManager(
85
+ block_size=self.cache_config.block_size,
86
+ num_gpu_blocks=self.cache_config.num_gpu_blocks,
87
+ num_cpu_blocks=self.cache_config.num_cpu_blocks,
88
+ sliding_window=self.cache_config.sliding_window,
89
+ )
90
+
91
+ # TODO(zhuohan): Use deque instead of list for better performance.
92
+ # Sequence groups in the WAITING state.
93
+ self.waiting: List[SequenceGroup] = []
94
+ # Sequence groups in the RUNNING state.
95
+ self.running: List[SequenceGroup] = []
96
+ # Sequence groups in the SWAPPED state.
97
+ self.swapped: List[SequenceGroup] = []
98
+
99
+ def add_seq_group(self, seq_group: SequenceGroup) -> None:
100
+ # Add sequence groups to the waiting queue.
101
+ self.waiting.append(seq_group)
102
+
103
+ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
104
+ if isinstance(request_id, str):
105
+ request_id = (request_id,)
106
+ request_ids = set(request_id)
107
+ for state_queue in [self.waiting, self.running, self.swapped]:
108
+ # We need to reverse the list as we are removing elements
109
+ # from it as we iterate over it. If we don't do it,
110
+ # indices will get messed up and we will skip over elements.
111
+ for seq_group in reversed(state_queue):
112
+ if seq_group.request_id in request_ids:
113
+ # Remove the sequence group from the state queue.
114
+ state_queue.remove(seq_group)
115
+ for seq in seq_group.get_seqs():
116
+ if seq.is_finished():
117
+ continue
118
+ seq.status = SequenceStatus.FINISHED_ABORTED
119
+ self.free_seq(seq)
120
+ request_ids.remove(seq_group.request_id)
121
+ if not request_ids:
122
+ return
123
+
124
+ def has_unfinished_seqs(self) -> bool:
125
+ return self.waiting or self.running or self.swapped
126
+
127
+ def get_num_unfinished_seq_groups(self) -> int:
128
+ return len(self.waiting) + len(self.running) + len(self.swapped)
129
+
130
+ def _schedule(self) -> SchedulerOutputs:
131
+ # Blocks that need to be swaped or copied before model execution.
132
+ blocks_to_swap_in: Dict[int, int] = {}
133
+ blocks_to_swap_out: Dict[int, int] = {}
134
+ blocks_to_copy: Dict[int, List[int]] = {}
135
+
136
+ # Fix the current time.
137
+ now = time.monotonic()
138
+
139
+ # Join waiting sequences if possible.
140
+ if not self.swapped:
141
+ ignored_seq_groups: List[SequenceGroup] = []
142
+ scheduled: List[SequenceGroup] = []
143
+ # The total number of sequences on the fly, including the
144
+ # requests in the generation phase.
145
+ num_curr_seqs = sum(
146
+ seq_group.get_max_num_running_seqs() for seq_group in self.running
147
+ )
148
+ seq_lens: List[int] = []
149
+
150
+ # Optimization: We do not sort the waiting queue since the preempted
151
+ # sequence groups are added to the front and the new sequence groups
152
+ # are added to the back.
153
+ while self.waiting:
154
+ seq_group = self.waiting[0]
155
+
156
+ waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
157
+ assert len(waiting_seqs) == 1, (
158
+ "Waiting sequence group should have only one prompt " "sequence."
159
+ )
160
+ num_prompt_tokens = waiting_seqs[0].get_len()
161
+ if num_prompt_tokens > self.prompt_limit:
162
+ logger.warning(
163
+ f"Input prompt ({num_prompt_tokens} tokens) is too long"
164
+ f" and exceeds limit of {self.prompt_limit}"
165
+ )
166
+ for seq in waiting_seqs:
167
+ seq.status = SequenceStatus.FINISHED_IGNORED
168
+ ignored_seq_groups.append(seq_group)
169
+ self.waiting.pop(0)
170
+ continue
171
+
172
+ # If the sequence group cannot be allocated, stop.
173
+ can_allocate = self.block_manager.can_allocate(seq_group)
174
+ if can_allocate == AllocStatus.LATER:
175
+ break
176
+ elif can_allocate == AllocStatus.NEVER:
177
+ logger.warning(
178
+ f"Input prompt ({num_prompt_tokens} tokens) is too long"
179
+ f" and exceeds the capacity of block_manager"
180
+ )
181
+ for seq in waiting_seqs:
182
+ seq.status = SequenceStatus.FINISHED_IGNORED
183
+ ignored_seq_groups.append(seq_group)
184
+ self.waiting.pop(0)
185
+ continue
186
+
187
+ # If the number of batched tokens exceeds the limit, stop.
188
+ new_seq_lens = seq_lens + [num_prompt_tokens]
189
+ num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
190
+ if num_batched_tokens > self.scheduler_config.max_num_batched_tokens:
191
+ break
192
+
193
+ # The total number of sequences in the RUNNING state should not
194
+ # exceed the maximum number of sequences.
195
+ num_new_seqs = seq_group.get_max_num_running_seqs()
196
+ if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
197
+ break
198
+
199
+ num_paddings = num_batched_tokens - sum(new_seq_lens)
200
+ if num_paddings > self.scheduler_config.max_paddings:
201
+ break
202
+ seq_lens = new_seq_lens
203
+
204
+ seq_group = self.waiting.pop(0)
205
+ self._allocate(seq_group)
206
+ self.running.append(seq_group)
207
+ num_curr_seqs += num_new_seqs
208
+ scheduled.append(seq_group)
209
+
210
+ if scheduled or ignored_seq_groups:
211
+ scheduler_outputs = SchedulerOutputs(
212
+ scheduled_seq_groups=scheduled,
213
+ prompt_run=True,
214
+ num_batched_tokens=len(seq_lens) * max(seq_lens) if seq_lens else 0,
215
+ blocks_to_swap_in=blocks_to_swap_in,
216
+ blocks_to_swap_out=blocks_to_swap_out,
217
+ blocks_to_copy=blocks_to_copy,
218
+ ignored_seq_groups=ignored_seq_groups,
219
+ )
220
+ return scheduler_outputs
221
+
222
+ # NOTE(woosuk): Preemption happens only when there is no available slot
223
+ # to keep all the sequence groups in the RUNNING state.
224
+ # In this case, the policy is responsible for deciding which sequence
225
+ # groups to preempt.
226
+ self.running = self.policy.sort_by_priority(now, self.running)
227
+
228
+ # Reserve new token slots for the running sequence groups.
229
+ running: List[SequenceGroup] = []
230
+ preempted: List[SequenceGroup] = []
231
+ while self.running:
232
+ seq_group = self.running.pop(0)
233
+ while not self.block_manager.can_append_slot(seq_group):
234
+ if self.running:
235
+ # Preempt the lowest-priority sequence groups.
236
+ victim_seq_group = self.running.pop(-1)
237
+ self._preempt(victim_seq_group, blocks_to_swap_out)
238
+ preempted.append(victim_seq_group)
239
+ else:
240
+ # No other sequence groups can be preempted.
241
+ # Preempt the current sequence group.
242
+ self._preempt(seq_group, blocks_to_swap_out)
243
+ preempted.append(seq_group)
244
+ break
245
+ else:
246
+ # Append new slots to the sequence group.
247
+ self._append_slot(seq_group, blocks_to_copy)
248
+ running.append(seq_group)
249
+ self.running = running
250
+
251
+ # Swap in the sequence groups in the SWAPPED state if possible.
252
+ self.swapped = self.policy.sort_by_priority(now, self.swapped)
253
+ if not preempted:
254
+ num_curr_seqs = sum(
255
+ seq_group.get_max_num_running_seqs() for seq_group in self.running
256
+ )
257
+
258
+ while self.swapped:
259
+ seq_group = self.swapped[0]
260
+ # If the sequence group cannot be swapped in, stop.
261
+ if not self.block_manager.can_swap_in(seq_group):
262
+ break
263
+
264
+ # The total number of sequences in the RUNNING state should not
265
+ # exceed the maximum number of sequences.
266
+ num_new_seqs = seq_group.get_max_num_running_seqs()
267
+ if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
268
+ break
269
+
270
+ seq_group = self.swapped.pop(0)
271
+ self._swap_in(seq_group, blocks_to_swap_in)
272
+ self._append_slot(seq_group, blocks_to_copy)
273
+ num_curr_seqs += num_new_seqs
274
+ self.running.append(seq_group)
275
+
276
+ # Each sequence in the generation phase only takes one token slot.
277
+ # Therefore, the number of batched tokens is equal to the number of
278
+ # sequences in the RUNNING state.
279
+ num_batched_tokens = sum(
280
+ seq_group.num_seqs(status=SequenceStatus.RUNNING)
281
+ for seq_group in self.running
282
+ )
283
+
284
+ scheduler_outputs = SchedulerOutputs(
285
+ scheduled_seq_groups=self.running,
286
+ prompt_run=False,
287
+ num_batched_tokens=num_batched_tokens,
288
+ blocks_to_swap_in=blocks_to_swap_in,
289
+ blocks_to_swap_out=blocks_to_swap_out,
290
+ blocks_to_copy=blocks_to_copy,
291
+ ignored_seq_groups=[],
292
+ )
293
+ return scheduler_outputs
294
+
295
+ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
296
+ # Schedule sequence groups.
297
+ # This function call changes the internal states of the scheduler
298
+ # such as self.running, self.swapped, and self.waiting.
299
+ scheduler_outputs = self._schedule()
300
+
301
+ # Create input data structures.
302
+ seq_group_metadata_list: List[SequenceGroupMetadata] = []
303
+ for seq_group in scheduler_outputs.scheduled_seq_groups:
304
+ seq_data: Dict[int, SequenceData] = {}
305
+ block_tables: Dict[int, List[int]] = {}
306
+ for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
307
+ seq_id = seq.seq_id
308
+ seq_data[seq_id] = seq.data
309
+ block_tables[seq_id] = self.block_manager.get_block_table(seq)
310
+
311
+ seq_group_metadata = SequenceGroupMetadata(
312
+ request_id=seq_group.request_id,
313
+ is_prompt=scheduler_outputs.prompt_run,
314
+ seq_data=seq_data,
315
+ sampling_params=seq_group.sampling_params,
316
+ block_tables=block_tables,
317
+ )
318
+ seq_group_metadata_list.append(seq_group_metadata)
319
+ return seq_group_metadata_list, scheduler_outputs
320
+
321
+ def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
322
+ self.block_manager.fork(parent_seq, child_seq)
323
+
324
+ def free_seq(self, seq: Sequence) -> None:
325
+ self.block_manager.free(seq)
326
+
327
+ def free_finished_seq_groups(self) -> None:
328
+ self.running = [
329
+ seq_group for seq_group in self.running if not seq_group.is_finished()
330
+ ]
331
+
332
+ def _allocate(self, seq_group: SequenceGroup) -> None:
333
+ self.block_manager.allocate(seq_group)
334
+ for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
335
+ seq.status = SequenceStatus.RUNNING
336
+
337
+ def _append_slot(
338
+ self,
339
+ seq_group: SequenceGroup,
340
+ blocks_to_copy: Dict[int, List[int]],
341
+ ) -> None:
342
+ for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
343
+ ret = self.block_manager.append_slot(seq)
344
+ if ret is not None:
345
+ src_block, dst_block = ret
346
+ if src_block in blocks_to_copy:
347
+ blocks_to_copy[src_block].append(dst_block)
348
+ else:
349
+ blocks_to_copy[src_block] = [dst_block]
350
+
351
+ def _preempt(
352
+ self,
353
+ seq_group: SequenceGroup,
354
+ blocks_to_swap_out: Dict[int, int],
355
+ preemption_mode: Optional[PreemptionMode] = None,
356
+ ) -> None:
357
+ # If preemption mode is not specified, we determine the mode as follows:
358
+ # We use recomputation by default since it incurs lower overhead than
359
+ # swapping. However, when the sequence group has multiple sequences
360
+ # (e.g., beam search), recomputation is not currently supported. In
361
+ # such a case, we use swapping instead.
362
+ # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
363
+ # As swapped sequences are prioritized over waiting sequences,
364
+ # sequence groups with multiple sequences are implicitly prioritized
365
+ # over sequence groups with a single sequence.
366
+ # TODO(woosuk): Support recomputation for sequence groups with multiple
367
+ # sequences. This may require a more sophisticated CUDA kernel.
368
+ if preemption_mode is None:
369
+ if seq_group.get_max_num_running_seqs() == 1:
370
+ preemption_mode = PreemptionMode.RECOMPUTE
371
+ else:
372
+ preemption_mode = PreemptionMode.SWAP
373
+ if preemption_mode == PreemptionMode.RECOMPUTE:
374
+ self._preempt_by_recompute(seq_group)
375
+ elif preemption_mode == PreemptionMode.SWAP:
376
+ self._preempt_by_swap(seq_group, blocks_to_swap_out)
377
+ else:
378
+ raise AssertionError("Invalid preemption mode.")
379
+
380
+ def _preempt_by_recompute(
381
+ self,
382
+ seq_group: SequenceGroup,
383
+ ) -> None:
384
+ seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
385
+ assert len(seqs) == 1
386
+ for seq in seqs:
387
+ seq.status = SequenceStatus.WAITING
388
+ self.block_manager.free(seq)
389
+ # NOTE: For FCFS, we insert the preempted sequence group to the front
390
+ # of the waiting queue.
391
+ self.waiting.insert(0, seq_group)
392
+
393
+ def _preempt_by_swap(
394
+ self,
395
+ seq_group: SequenceGroup,
396
+ blocks_to_swap_out: Dict[int, int],
397
+ ) -> None:
398
+ self._swap_out(seq_group, blocks_to_swap_out)
399
+ self.swapped.append(seq_group)
400
+
401
+ def _swap_in(
402
+ self,
403
+ seq_group: SequenceGroup,
404
+ blocks_to_swap_in: Dict[int, int],
405
+ ) -> None:
406
+ mapping = self.block_manager.swap_in(seq_group)
407
+ blocks_to_swap_in.update(mapping)
408
+ for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
409
+ seq.status = SequenceStatus.RUNNING
410
+
411
+ def _swap_out(
412
+ self,
413
+ seq_group: SequenceGroup,
414
+ blocks_to_swap_out: Dict[int, int],
415
+ ) -> None:
416
+ if not self.block_manager.can_swap_out(seq_group):
417
+ # FIXME(woosuk): Abort the sequence group instead of aborting the
418
+ # entire engine.
419
+ raise RuntimeError(
420
+ "Aborted due to the lack of CPU swap space. Please increase "
421
+ "the swap space to avoid this error."
422
+ )
423
+ mapping = self.block_manager.swap_out(seq_group)
424
+ blocks_to_swap_out.update(mapping)
425
+ for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
426
+ seq.status = SequenceStatus.SWAPPED
ChatTTS/model/velocity/sequence.py ADDED
@@ -0,0 +1,450 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Sequence and its related classes."""
2
+
3
+ import copy
4
+ import enum
5
+ from typing import Dict, List, Optional, Union
6
+ import torch
7
+ from vllm.block import LogicalTokenBlock
8
+ from .sampling_params import SamplingParams
9
+
10
+ PromptLogprobs = List[Optional[Dict[int, float]]]
11
+ SampleLogprobs = List[Dict[int, float]]
12
+
13
+
14
+ class SequenceStatus(enum.Enum):
15
+ """Status of a sequence."""
16
+
17
+ WAITING = enum.auto()
18
+ RUNNING = enum.auto()
19
+ SWAPPED = enum.auto()
20
+ FINISHED_STOPPED = enum.auto()
21
+ FINISHED_LENGTH_CAPPED = enum.auto()
22
+ FINISHED_ABORTED = enum.auto()
23
+ FINISHED_IGNORED = enum.auto()
24
+
25
+ @staticmethod
26
+ def is_finished(status: "SequenceStatus") -> bool:
27
+ return status in [
28
+ SequenceStatus.FINISHED_STOPPED,
29
+ SequenceStatus.FINISHED_LENGTH_CAPPED,
30
+ SequenceStatus.FINISHED_ABORTED,
31
+ SequenceStatus.FINISHED_IGNORED,
32
+ ]
33
+
34
+ @staticmethod
35
+ def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
36
+ if status == SequenceStatus.FINISHED_STOPPED:
37
+ finish_reason = "stop"
38
+ elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
39
+ finish_reason = "length"
40
+ elif status == SequenceStatus.FINISHED_ABORTED:
41
+ finish_reason = "abort"
42
+ elif status == SequenceStatus.FINISHED_IGNORED:
43
+ # The ignored sequences are the sequences whose prompt lengths
44
+ # are longer than the model's length cap. Therefore, the stop
45
+ # reason should also be "length" as in OpenAI API.
46
+ finish_reason = "length"
47
+ else:
48
+ finish_reason = None
49
+ return finish_reason
50
+
51
+
52
+ class SequenceData:
53
+ """Data associated with a sequence.
54
+
55
+
56
+ Args:
57
+ prompt_token_ids: The token IDs of the prompt.
58
+
59
+ Attributes:
60
+ prompt_token_ids: The token IDs of the prompt.
61
+ output_token_ids: The token IDs of the output.
62
+ cumulative_logprob: The cumulative log probability of the output.
63
+ """
64
+
65
+ def __init__(
66
+ self,
67
+ prompt_token_ids: List[int],
68
+ ) -> None:
69
+ self.prompt_token_ids = prompt_token_ids
70
+ self.output_token_ids: List[int] = []
71
+ self.cumulative_logprob = 0.0
72
+ self.hidden_states: Optional[torch.Tensor] = None
73
+ self.finished = False
74
+
75
+ def append_token_id(self, token_id: int, logprob: float) -> None:
76
+ if isinstance(self.cumulative_logprob, float):
77
+ self.cumulative_logprob = [
78
+ 0.0,
79
+ ] * len(logprob)
80
+ self.output_token_ids.append(token_id)
81
+ for i in range(len(self.cumulative_logprob)):
82
+ self.cumulative_logprob[i] += logprob[i]
83
+
84
+ def append_hidden_states(self, hidden_states: torch.Tensor) -> None:
85
+ if self.hidden_states is None:
86
+ self.hidden_states = hidden_states
87
+ else:
88
+ self.hidden_states = torch.cat([self.hidden_states, hidden_states], dim=0)
89
+
90
+ def get_len(self) -> int:
91
+ return len(self.output_token_ids) + len(self.prompt_token_ids)
92
+
93
+ def get_prompt_len(self) -> int:
94
+ return len(self.prompt_token_ids)
95
+
96
+ def get_output_len(self) -> int:
97
+ return len(self.output_token_ids)
98
+
99
+ def get_token_ids(self) -> List[int]:
100
+ return self.prompt_token_ids + self.output_token_ids
101
+
102
+ def get_last_token_id(self) -> int:
103
+ if not self.output_token_ids:
104
+ return self.prompt_token_ids[-1]
105
+ return self.output_token_ids[-1]
106
+
107
+ def __repr__(self) -> str:
108
+ return (
109
+ f"SequenceData("
110
+ f"prompt_token_ids={self.prompt_token_ids}, "
111
+ f"output_token_ids={self.output_token_ids}, "
112
+ f"cumulative_logprob={self.cumulative_logprob}), "
113
+ f"hidden_states={self.hidden_states.shape if self.hidden_states is not None else None}, "
114
+ f"finished={self.finished})"
115
+ )
116
+
117
+
118
+ class Sequence:
119
+ """Stores the data, status, and block information of a sequence.
120
+
121
+ Args:
122
+ seq_id: The ID of the sequence.
123
+ prompt: The prompt of the sequence.
124
+ prompt_token_ids: The token IDs of the prompt.
125
+ block_size: The block size of the sequence. Should be the same as the
126
+ block size used by the block manager and cache engine.
127
+ """
128
+
129
+ def __init__(
130
+ self,
131
+ seq_id: int,
132
+ prompt: str,
133
+ prompt_token_ids: List[int],
134
+ block_size: int,
135
+ ) -> None:
136
+ self.seq_id = seq_id
137
+ self.prompt = prompt
138
+ self.block_size = block_size
139
+
140
+ self.data = SequenceData(prompt_token_ids)
141
+ self.output_logprobs: SampleLogprobs = []
142
+ self.output_text = ""
143
+
144
+ self.logical_token_blocks: List[LogicalTokenBlock] = []
145
+ # Initialize the logical token blocks with the prompt token ids.
146
+ self._append_tokens_to_blocks(prompt_token_ids)
147
+ self.status = SequenceStatus.WAITING
148
+
149
+ # Used for incremental detokenization
150
+ self.prefix_offset = 0
151
+ self.read_offset = 0
152
+ # Input + output tokens
153
+ self.tokens: Optional[List[str]] = None
154
+
155
+ def _append_logical_block(self) -> None:
156
+ block = LogicalTokenBlock(
157
+ block_number=len(self.logical_token_blocks),
158
+ block_size=self.block_size,
159
+ )
160
+ self.logical_token_blocks.append(block)
161
+
162
+ def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
163
+ cursor = 0
164
+ while cursor < len(token_ids):
165
+ if not self.logical_token_blocks:
166
+ self._append_logical_block()
167
+
168
+ last_block = self.logical_token_blocks[-1]
169
+ if last_block.is_full():
170
+ self._append_logical_block()
171
+ last_block = self.logical_token_blocks[-1]
172
+
173
+ num_empty_slots = last_block.get_num_empty_slots()
174
+ last_block.append_tokens(token_ids[cursor : cursor + num_empty_slots])
175
+ cursor += num_empty_slots
176
+
177
+ def append_token_id(
178
+ self,
179
+ token_id: int,
180
+ logprobs: Dict[int, float],
181
+ hidden_states: Optional[torch.Tensor] = None,
182
+ finished: bool = False,
183
+ ) -> None:
184
+ assert token_id in logprobs
185
+ self._append_tokens_to_blocks([token_id])
186
+ self.output_logprobs.append(logprobs)
187
+ self.data.append_token_id(token_id, logprobs[token_id])
188
+ self.data.append_hidden_states(hidden_states)
189
+ self.data.finished = finished
190
+
191
+ def get_len(self) -> int:
192
+ return self.data.get_len()
193
+
194
+ def get_prompt_len(self) -> int:
195
+ return self.data.get_prompt_len()
196
+
197
+ def get_output_len(self) -> int:
198
+ return self.data.get_output_len()
199
+
200
+ def get_token_ids(self) -> List[int]:
201
+ return self.data.get_token_ids()
202
+
203
+ def get_last_token_id(self) -> int:
204
+ return self.data.get_last_token_id()
205
+
206
+ def get_output_token_ids(self) -> List[int]:
207
+ return self.data.output_token_ids
208
+
209
+ def get_cumulative_logprob(self) -> float:
210
+ return self.data.cumulative_logprob
211
+
212
+ def get_beam_search_score(
213
+ self,
214
+ length_penalty: float = 0.0,
215
+ seq_len: Optional[int] = None,
216
+ eos_token_id: Optional[int] = None,
217
+ ) -> float:
218
+ """Calculate the beam search score with length penalty.
219
+
220
+ Adapted from
221
+
222
+ https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
223
+ """
224
+ if seq_len is None:
225
+ seq_len = self.get_len()
226
+ # NOTE: HF implementation does not count the EOS token
227
+ # towards the length, we align with that here for testing.
228
+ if eos_token_id is not None and self.get_last_token_id() == eos_token_id:
229
+ seq_len -= 1
230
+ return self.get_cumulative_logprob() / (seq_len**length_penalty)
231
+
232
+ def is_finished(self) -> bool:
233
+ return SequenceStatus.is_finished(self.status)
234
+
235
+ def fork(self, new_seq_id: int) -> "Sequence":
236
+ new_seq = copy.deepcopy(self)
237
+ new_seq.seq_id = new_seq_id
238
+ return new_seq
239
+
240
+ def __repr__(self) -> str:
241
+ return (
242
+ f"Sequence(seq_id={self.seq_id}, "
243
+ f"status={self.status.name}, "
244
+ f"num_blocks={len(self.logical_token_blocks)})"
245
+ )
246
+
247
+
248
+ class SequenceGroup:
249
+ """A group of sequences that are generated from the same prompt.
250
+
251
+ Args:
252
+ request_id: The ID of the request.
253
+ seqs: The list of sequences.
254
+ sampling_params: The sampling parameters used to generate the outputs.
255
+ arrival_time: The arrival time of the request.
256
+ """
257
+
258
+ def __init__(
259
+ self,
260
+ request_id: str,
261
+ seqs: List[Sequence],
262
+ sampling_params: SamplingParams,
263
+ arrival_time: float,
264
+ ) -> None:
265
+ self.request_id = request_id
266
+ self.seqs_dict = {seq.seq_id: seq for seq in seqs}
267
+ self.sampling_params = sampling_params
268
+ self.arrival_time = arrival_time
269
+ self.prompt_logprobs: Optional[PromptLogprobs] = None
270
+
271
+ @property
272
+ def prompt(self) -> str:
273
+ # All sequences in the group should have the same prompt.
274
+ # We use the prompt of an arbitrary sequence.
275
+ return next(iter(self.seqs_dict.values())).prompt
276
+
277
+ @property
278
+ def prompt_token_ids(self) -> List[int]:
279
+ # All sequences in the group should have the same prompt.
280
+ # We use the prompt of an arbitrary sequence.
281
+ return next(iter(self.seqs_dict.values())).data.prompt_token_ids
282
+
283
+ def get_max_num_running_seqs(self) -> int:
284
+ """The maximum number of sequences running in parallel in the remaining
285
+ lifetime of the request."""
286
+ if self.sampling_params.use_beam_search:
287
+ # For beam search, maximally there will always be `best_of` beam
288
+ # candidates running in the future.
289
+ return self.sampling_params.best_of
290
+ else:
291
+ if self.sampling_params.best_of > self.num_seqs():
292
+ # At prompt stage, the sequence group is not yet filled up
293
+ # and only have one sequence running. However, in the
294
+ # generation stage, we will have `best_of` sequences running.
295
+ return self.sampling_params.best_of
296
+ # At sampling stages, return the number of actual sequences
297
+ # that are not finished yet.
298
+ return self.num_unfinished_seqs()
299
+
300
+ def get_seqs(
301
+ self,
302
+ status: Optional[SequenceStatus] = None,
303
+ ) -> List[Sequence]:
304
+ if status is None:
305
+ return list(self.seqs_dict.values())
306
+ else:
307
+ return [seq for seq in self.seqs_dict.values() if seq.status == status]
308
+
309
+ def get_unfinished_seqs(self) -> List[Sequence]:
310
+ return [seq for seq in self.seqs_dict.values() if not seq.is_finished()]
311
+
312
+ def get_finished_seqs(self) -> List[Sequence]:
313
+ return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
314
+
315
+ def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
316
+ return len(self.get_seqs(status))
317
+
318
+ def num_unfinished_seqs(self) -> int:
319
+ return len(self.get_unfinished_seqs())
320
+
321
+ def num_finished_seqs(self) -> int:
322
+ return len(self.get_finished_seqs())
323
+
324
+ def find(self, seq_id: int) -> Sequence:
325
+ if seq_id not in self.seqs_dict:
326
+ raise ValueError(f"Sequence {seq_id} not found.")
327
+ return self.seqs_dict[seq_id]
328
+
329
+ def add(self, seq: Sequence) -> None:
330
+ if seq.seq_id in self.seqs_dict:
331
+ raise ValueError(f"Sequence {seq.seq_id} already exists.")
332
+ self.seqs_dict[seq.seq_id] = seq
333
+
334
+ def remove(self, seq_id: int) -> None:
335
+ if seq_id not in self.seqs_dict:
336
+ raise ValueError(f"Sequence {seq_id} not found.")
337
+ del self.seqs_dict[seq_id]
338
+
339
+ def is_finished(self) -> bool:
340
+ return all(seq.is_finished() for seq in self.get_seqs())
341
+
342
+ def __repr__(self) -> str:
343
+ return (
344
+ f"SequenceGroup(request_id={self.request_id}, "
345
+ f"sampling_params={self.sampling_params}, "
346
+ f"num_seqs={len(self.seqs_dict)})"
347
+ )
348
+
349
+
350
+ class SequenceGroupMetadata:
351
+ """Metadata for a sequence group. Used to create `InputMetadata`.
352
+
353
+
354
+ Args:
355
+ request_id: The ID of the request.
356
+ is_prompt: Whether the request is at prompt stage.
357
+ seq_data: The sequence data. (Seq id -> sequence data)
358
+ sampling_params: The sampling parameters used to generate the outputs.
359
+ block_tables: The block tables. (Seq id -> list of physical block
360
+ numbers)
361
+ """
362
+
363
+ def __init__(
364
+ self,
365
+ request_id: str,
366
+ is_prompt: bool,
367
+ seq_data: Dict[int, SequenceData],
368
+ sampling_params: SamplingParams,
369
+ block_tables: Dict[int, List[int]],
370
+ ) -> None:
371
+ self.request_id = request_id
372
+ self.is_prompt = is_prompt
373
+ self.seq_data = seq_data
374
+ self.sampling_params = sampling_params
375
+ self.block_tables = block_tables
376
+
377
+
378
+ class SequenceOutput:
379
+ """The model output associated with a sequence.
380
+
381
+ Args:
382
+ parent_seq_id: The ID of the parent sequence (for forking in beam
383
+ search).
384
+ output_token: The output token ID.
385
+ logprobs: The logprobs of the output token.
386
+ (Token id -> logP(x_i+1 | x_0, ..., x_i))
387
+ """
388
+
389
+ def __init__(
390
+ self,
391
+ parent_seq_id: int,
392
+ output_token: int,
393
+ logprobs: Dict[int, float],
394
+ hidden_states: Optional[torch.Tensor] = None,
395
+ finished: bool = False,
396
+ ) -> None:
397
+ self.parent_seq_id = parent_seq_id
398
+ self.output_token = output_token
399
+ self.logprobs = logprobs
400
+ self.finished = finished
401
+ self.hidden_states = hidden_states
402
+
403
+ def __repr__(self) -> str:
404
+ return (
405
+ f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
406
+ f"output_token={self.output_token}, "
407
+ f"logprobs={self.logprobs}),"
408
+ f"finished={self.finished}),"
409
+ f"hidden_states={self.hidden_states.shape if self.hidden_states is not None else None}"
410
+ )
411
+
412
+ def __eq__(self, other: object) -> bool:
413
+ if not isinstance(other, SequenceOutput):
414
+ raise NotImplementedError()
415
+ return (
416
+ self.parent_seq_id == other.parent_seq_id
417
+ and self.output_token == other.output_token
418
+ and self.logprobs == other.logprobs
419
+ )
420
+
421
+
422
+ class SequenceGroupOutput:
423
+ """The model output associated with a sequence group."""
424
+
425
+ def __init__(
426
+ self,
427
+ samples: List[SequenceOutput],
428
+ prompt_logprobs: Optional[PromptLogprobs],
429
+ ) -> None:
430
+ self.samples = samples
431
+ self.prompt_logprobs = prompt_logprobs
432
+
433
+ def __repr__(self) -> str:
434
+ return (
435
+ f"SequenceGroupOutput(samples={self.samples}, "
436
+ f"prompt_logprobs={self.prompt_logprobs})"
437
+ )
438
+
439
+ def __eq__(self, other: object) -> bool:
440
+ if not isinstance(other, SequenceGroupOutput):
441
+ raise NotImplementedError()
442
+ return (
443
+ self.samples == other.samples
444
+ and self.prompt_logprobs == other.prompt_logprobs
445
+ )
446
+
447
+
448
+ # For each sequence group, we generate a list of SequenceOutput object,
449
+ # each of which contains one possible candidate for the next token.
450
+ SamplerOutput = List[SequenceGroupOutput]
ChatTTS/model/velocity/worker.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A GPU worker class."""
2
+
3
+ import os
4
+ from typing import Dict, List, Optional, Tuple
5
+
6
+ import torch
7
+ import torch.distributed
8
+
9
+ from vllm.config import CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig
10
+ from vllm.model_executor import set_random_seed
11
+ from vllm.model_executor.parallel_utils.communication_op import broadcast_object_list
12
+ from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
13
+ from vllm.sequence import SamplerOutput, SequenceGroupMetadata
14
+ from vllm.worker.cache_engine import CacheEngine
15
+
16
+ from .model_runner import ModelRunner
17
+
18
+
19
+ class Worker:
20
+ """A worker class that executes (a partition of) the model on a GPU.
21
+
22
+ Each worker is associated with a single GPU. The worker is responsible for
23
+ maintaining the KV cache and executing the model on the GPU. In case of
24
+ distributed inference, each worker is assigned a partition of the model.
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ model_config: ModelConfig,
30
+ parallel_config: ParallelConfig,
31
+ scheduler_config: SchedulerConfig,
32
+ local_rank: int,
33
+ rank: int,
34
+ distributed_init_method: str,
35
+ post_model_path: str,
36
+ is_driver_worker: bool = False,
37
+ ) -> None:
38
+ self.model_config = model_config
39
+ self.parallel_config = parallel_config
40
+ self.scheduler_config = scheduler_config
41
+ self.local_rank = local_rank
42
+ self.rank = rank
43
+ self.distributed_init_method = distributed_init_method
44
+ self.is_driver_worker = is_driver_worker
45
+ self.post_model_path = post_model_path
46
+
47
+ if self.is_driver_worker:
48
+ assert self.rank == 0, "The driver worker must have rank 0."
49
+
50
+ self.model_runner = ModelRunner(
51
+ model_config,
52
+ parallel_config,
53
+ scheduler_config,
54
+ is_driver_worker,
55
+ post_model_path,
56
+ )
57
+ # Uninitialized cache engine. Will be initialized by
58
+ # self.init_cache_engine().
59
+ self.cache_config = None
60
+ self.cache_engine = None
61
+ self.cache_events = None
62
+ self.gpu_cache = None
63
+
64
+ def init_model(self) -> None:
65
+ # torch.distributed.all_reduce does not free the input tensor until
66
+ # the synchronization point. This causes the memory usage to grow
67
+ # as the number of all_reduce calls increases. This env var disables
68
+ # this behavior.
69
+ # Related issue:
70
+ # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
71
+ os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
72
+
73
+ # This env var set by Ray causes exceptions with graph building.
74
+ os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
75
+ self.device = torch.device(f"cuda:{self.local_rank}")
76
+ torch.cuda.set_device(self.device)
77
+
78
+ _check_if_gpu_supports_dtype(self.model_config.dtype)
79
+
80
+ # Initialize the distributed environment.
81
+ _init_distributed_environment(
82
+ self.parallel_config, self.rank, self.distributed_init_method
83
+ )
84
+
85
+ # Initialize the model.
86
+ set_random_seed(self.model_config.seed)
87
+
88
+ def load_model(self):
89
+ self.model_runner.load_model()
90
+
91
+ @torch.inference_mode()
92
+ def profile_num_available_blocks(
93
+ self,
94
+ block_size: int,
95
+ gpu_memory_utilization: float,
96
+ cpu_swap_space: int,
97
+ ) -> Tuple[int, int]:
98
+ # Profile the memory usage of the model and get the maximum number of
99
+ # cache blocks that can be allocated with the remaining free memory.
100
+ torch.cuda.empty_cache()
101
+
102
+ # Execute a forward pass with dummy inputs to profile the memory usage
103
+ # of the model.
104
+ self.model_runner.profile_run()
105
+
106
+ # Calculate the number of blocks that can be allocated with the
107
+ # profiled peak memory.
108
+ torch.cuda.synchronize()
109
+ free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
110
+ peak_memory = total_gpu_memory - free_gpu_memory
111
+
112
+ cache_block_size = CacheEngine.get_cache_block_size(
113
+ block_size, self.model_config, self.parallel_config
114
+ )
115
+ num_gpu_blocks = int(
116
+ (total_gpu_memory * gpu_memory_utilization - peak_memory)
117
+ // cache_block_size
118
+ )
119
+ num_cpu_blocks = int(cpu_swap_space // cache_block_size)
120
+ num_gpu_blocks = max(num_gpu_blocks, 0)
121
+ num_cpu_blocks = max(num_cpu_blocks, 0)
122
+ torch.cuda.empty_cache()
123
+ return num_gpu_blocks, num_cpu_blocks
124
+
125
+ def init_cache_engine(self, cache_config: CacheConfig) -> None:
126
+ self.cache_config = cache_config
127
+ self.cache_engine = CacheEngine(
128
+ self.cache_config, self.model_config, self.parallel_config
129
+ )
130
+ self.cache_events = self.cache_engine.events
131
+ self.gpu_cache = self.cache_engine.gpu_cache
132
+ self.model_runner.set_block_size(self.cache_engine.block_size)
133
+
134
+ def warm_up_model(self) -> None:
135
+ if not self.model_config.enforce_eager:
136
+ self.model_runner.capture_model(self.gpu_cache)
137
+ # Reset the seed to ensure that the random state is not affected by
138
+ # the model initialization and profiling.
139
+ set_random_seed(self.model_config.seed)
140
+
141
+ def cache_swap(
142
+ self,
143
+ blocks_to_swap_in: Dict[int, int],
144
+ blocks_to_swap_out: Dict[int, int],
145
+ blocks_to_copy: Dict[int, List[int]],
146
+ ) -> None:
147
+ # Issue cache operations.
148
+ issued_cache_op = False
149
+ if blocks_to_swap_in:
150
+ self.cache_engine.swap_in(blocks_to_swap_in)
151
+ issued_cache_op = True
152
+ if blocks_to_swap_out:
153
+ self.cache_engine.swap_out(blocks_to_swap_out)
154
+ issued_cache_op = True
155
+ if blocks_to_copy:
156
+ self.cache_engine.copy(blocks_to_copy)
157
+ issued_cache_op = True
158
+
159
+ cache_events = self.cache_events if issued_cache_op else None
160
+
161
+ # Wait for cache operations to finish.
162
+ # TODO(woosuk): Profile swapping overhead and optimize if needed.
163
+ if cache_events is not None:
164
+ for event in cache_events:
165
+ event.wait()
166
+
167
+ @torch.inference_mode()
168
+ def execute_model(
169
+ self,
170
+ seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
171
+ blocks_to_swap_in: Optional[Dict[int, int]] = None,
172
+ blocks_to_swap_out: Optional[Dict[int, int]] = None,
173
+ blocks_to_copy: Optional[Dict[int, List[int]]] = None,
174
+ ) -> Optional[SamplerOutput]:
175
+ if self.is_driver_worker:
176
+ assert seq_group_metadata_list is not None
177
+ num_seq_groups = len(seq_group_metadata_list)
178
+ assert blocks_to_swap_in is not None
179
+ assert blocks_to_swap_out is not None
180
+ assert blocks_to_copy is not None
181
+ block_swapping_info = [
182
+ blocks_to_swap_in,
183
+ blocks_to_swap_out,
184
+ blocks_to_copy,
185
+ ]
186
+ broadcast_object_list([num_seq_groups] + block_swapping_info, src=0)
187
+ else:
188
+ # num_seq_groups, blocks_to_swap_in, blocks_to_swap_out,
189
+ # blocks_to_copy (4 elements)
190
+ recv_data = [None] * 4
191
+ broadcast_object_list(recv_data, src=0)
192
+ num_seq_groups = recv_data[0]
193
+ block_swapping_info = recv_data[1:]
194
+
195
+ self.cache_swap(*block_swapping_info)
196
+
197
+ # If there is no input, we don't need to execute the model.
198
+ if num_seq_groups == 0:
199
+ return {}
200
+
201
+ output = self.model_runner.execute_model(
202
+ seq_group_metadata_list, self.gpu_cache
203
+ )
204
+ return output
205
+
206
+
207
+ def _init_distributed_environment(
208
+ parallel_config: ParallelConfig,
209
+ rank: int,
210
+ distributed_init_method: Optional[str] = None,
211
+ ) -> None:
212
+ """Initialize the distributed environment."""
213
+ if torch.distributed.is_initialized():
214
+ torch_world_size = torch.distributed.get_world_size()
215
+ if torch_world_size != parallel_config.world_size:
216
+ raise RuntimeError(
217
+ "torch.distributed is already initialized but the torch world "
218
+ "size does not match parallel_config.world_size "
219
+ f"({torch_world_size} vs. {parallel_config.world_size})."
220
+ )
221
+ elif not distributed_init_method:
222
+ raise ValueError(
223
+ "distributed_init_method must be set if torch.distributed "
224
+ "is not already initialized"
225
+ )
226
+ else:
227
+ torch.distributed.init_process_group(
228
+ backend="nccl",
229
+ world_size=parallel_config.world_size,
230
+ rank=rank,
231
+ init_method=distributed_init_method,
232
+ )
233
+
234
+ # A small all_reduce for warmup.
235
+ torch.distributed.all_reduce(torch.zeros(1).cuda())
236
+ initialize_model_parallel(
237
+ parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size
238
+ )
239
+
240
+
241
+ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
242
+ # Check if the GPU supports the dtype.
243
+ if torch_dtype == torch.bfloat16:
244
+ compute_capability = torch.cuda.get_device_capability()
245
+ if compute_capability[0] < 8:
246
+ gpu_name = torch.cuda.get_device_name()
247
+ raise ValueError(
248
+ "Bfloat16 is only supported on GPUs with compute capability "
249
+ f"of at least 8.0. Your {gpu_name} GPU has compute capability "
250
+ f"{compute_capability[0]}.{compute_capability[1]}."
251
+ )
ChatTTS/norm.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import re
4
+ from typing import Dict, Tuple, List, Literal, Callable, Optional
5
+ import sys
6
+
7
+ from numba import jit
8
+ import numpy as np
9
+
10
+ from .utils import del_all
11
+
12
+
13
+ @jit
14
+ def _find_index(table: np.ndarray, val: np.uint16):
15
+ for i in range(table.size):
16
+ if table[i] == val:
17
+ return i
18
+ return -1
19
+
20
+
21
+ @jit
22
+ def _fast_replace(
23
+ table: np.ndarray, text: bytes
24
+ ) -> Tuple[np.ndarray, List[Tuple[str, str]]]:
25
+ result = np.frombuffer(text, dtype=np.uint16).copy()
26
+ replaced_words = []
27
+ for i in range(result.size):
28
+ ch = result[i]
29
+ p = _find_index(table[0], ch)
30
+ if p >= 0:
31
+ repl_char = table[1][p]
32
+ result[i] = repl_char
33
+ replaced_words.append((chr(ch), chr(repl_char)))
34
+ return result, replaced_words
35
+
36
+
37
+ @jit
38
+ def _split_tags(text: str) -> Tuple[List[str], List[str]]:
39
+ texts: List[str] = []
40
+ tags: List[str] = []
41
+ current_text = ""
42
+ current_tag = ""
43
+ for c in text:
44
+ if c == "[":
45
+ texts.append(current_text)
46
+ current_text = ""
47
+ current_tag = c
48
+ elif current_tag != "":
49
+ current_tag += c
50
+ else:
51
+ current_text += c
52
+ if c == "]":
53
+ tags.append(current_tag)
54
+ current_tag = ""
55
+ if current_text != "":
56
+ texts.append(current_text)
57
+ return texts, tags
58
+
59
+
60
+ @jit
61
+ def _combine_tags(texts: List[str], tags: List[str]) -> str:
62
+ text = ""
63
+ for t in texts:
64
+ tg = ""
65
+ if len(tags) > 0:
66
+ tg = tags.pop(0)
67
+ text += t + tg
68
+ return text
69
+
70
+
71
+ class Normalizer:
72
+ def __init__(self, map_file_path: str, logger=logging.getLogger(__name__)):
73
+ self.logger = logger
74
+ self.normalizers: Dict[str, Callable[[str], str]] = {}
75
+ self.homophones_map = self._load_homophones_map(map_file_path)
76
+ """
77
+ homophones_map
78
+
79
+ Replace the mispronounced characters with correctly pronounced ones.
80
+
81
+ Creation process of homophones_map.json:
82
+
83
+ 1. Establish a word corpus using the [Tencent AI Lab Embedding Corpora v0.2.0 large] with 12 million entries. After cleaning, approximately 1.8 million entries remain. Use ChatTTS to infer the text.
84
+ 2. Record discrepancies between the inferred and input text, identifying about 180,000 misread words.
85
+ 3. Create a pinyin to common characters mapping using correctly read characters by ChatTTS.
86
+ 4. For each discrepancy, extract the correct pinyin using [python-pinyin] and find homophones with the correct pronunciation from the mapping.
87
+
88
+ Thanks to:
89
+ [Tencent AI Lab Embedding Corpora for Chinese and English Words and Phrases](https://ai.tencent.com/ailab/nlp/en/embedding.html)
90
+ [python-pinyin](https://github.com/mozillazg/python-pinyin)
91
+
92
+ """
93
+ self.coding = "utf-16-le" if sys.byteorder == "little" else "utf-16-be"
94
+ self.reject_pattern = re.compile(r"[^\u4e00-\u9fffA-Za-z,。、,\. ]")
95
+ self.sub_pattern = re.compile(r"\[[\w_]+\]")
96
+ self.chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]")
97
+ self.english_word_pattern = re.compile(r"\b[A-Za-z]+\b")
98
+ self.character_simplifier = str.maketrans(
99
+ {
100
+ ":": ",",
101
+ ";": ",",
102
+ "!": "。",
103
+ "(": ",",
104
+ ")": ",",
105
+ "【": ",",
106
+ "】": ",",
107
+ "『": ",",
108
+ "』": ",",
109
+ "「": ",",
110
+ "」": ",",
111
+ "《": ",",
112
+ "》": ",",
113
+ "-": ",",
114
+ ":": ",",
115
+ ";": ",",
116
+ "!": ".",
117
+ "(": ",",
118
+ ")": ",",
119
+ # "[": ",",
120
+ # "]": ",",
121
+ ">": ",",
122
+ "<": ",",
123
+ "-": ",",
124
+ }
125
+ )
126
+ self.halfwidth_2_fullwidth = str.maketrans(
127
+ {
128
+ "!": "!",
129
+ '"': "“",
130
+ "'": "‘",
131
+ "#": "#",
132
+ "$": "$",
133
+ "%": "%",
134
+ "&": "&",
135
+ "(": "(",
136
+ ")": ")",
137
+ ",": ",",
138
+ "-": "-",
139
+ "*": "*",
140
+ "+": "+",
141
+ ".": "。",
142
+ "/": "/",
143
+ ":": ":",
144
+ ";": ";",
145
+ "<": "<",
146
+ "=": "=",
147
+ ">": ">",
148
+ "?": "?",
149
+ "@": "@",
150
+ # '[': '[',
151
+ "\\": "\",
152
+ # ']': ']',
153
+ "^": "^",
154
+ # '_': '_',
155
+ "`": "`",
156
+ "{": "{",
157
+ "|": "|",
158
+ "}": "}",
159
+ "~": "~",
160
+ }
161
+ )
162
+
163
+ def __call__(
164
+ self,
165
+ text: str,
166
+ do_text_normalization=True,
167
+ do_homophone_replacement=True,
168
+ lang: Optional[Literal["zh", "en"]] = None,
169
+ ) -> str:
170
+ if do_text_normalization:
171
+ _lang = self._detect_language(text) if lang is None else lang
172
+ if _lang in self.normalizers:
173
+ texts, tags = _split_tags(text)
174
+ self.logger.debug("split texts %s, tags %s", str(texts), str(tags))
175
+ texts = [self.normalizers[_lang](t) for t in texts]
176
+ self.logger.debug("normed texts %s", str(texts))
177
+ text = _combine_tags(texts, tags) if len(tags) > 0 else texts[0]
178
+ self.logger.debug("combined text %s", text)
179
+ if _lang == "zh":
180
+ text = self._apply_half2full_map(text)
181
+ invalid_characters = self._count_invalid_characters(text)
182
+ if len(invalid_characters):
183
+ self.logger.warning(f"found invalid characters: {invalid_characters}")
184
+ text = self._apply_character_map(text)
185
+ if do_homophone_replacement:
186
+ arr, replaced_words = _fast_replace(
187
+ self.homophones_map,
188
+ text.encode(self.coding),
189
+ )
190
+ if replaced_words:
191
+ text = arr.tobytes().decode(self.coding)
192
+ repl_res = ", ".join([f"{_[0]}->{_[1]}" for _ in replaced_words])
193
+ self.logger.info(f"replace homophones: {repl_res}")
194
+ if len(invalid_characters):
195
+ texts, tags = _split_tags(text)
196
+ self.logger.debug("split texts %s, tags %s", str(texts), str(tags))
197
+ texts = [self.reject_pattern.sub("", t) for t in texts]
198
+ self.logger.debug("normed texts %s", str(texts))
199
+ text = _combine_tags(texts, tags) if len(tags) > 0 else texts[0]
200
+ self.logger.debug("combined text %s", text)
201
+ return text
202
+
203
+ def register(self, name: str, normalizer: Callable[[str], str]) -> bool:
204
+ if name in self.normalizers:
205
+ self.logger.warning(f"name {name} has been registered")
206
+ return False
207
+ try:
208
+ val = normalizer("test string 测试字符串")
209
+ if not isinstance(val, str):
210
+ self.logger.warning("normalizer must have caller type (str) -> str")
211
+ return False
212
+ except Exception as e:
213
+ self.logger.warning(e)
214
+ return False
215
+ self.normalizers[name] = normalizer
216
+ return True
217
+
218
+ def unregister(self, name: str):
219
+ if name in self.normalizers:
220
+ del self.normalizers[name]
221
+
222
+ def destroy(self):
223
+ del_all(self.normalizers)
224
+ del self.homophones_map
225
+
226
+ def _load_homophones_map(self, map_file_path: str) -> np.ndarray:
227
+ with open(map_file_path, "r", encoding="utf-8") as f:
228
+ homophones_map: Dict[str, str] = json.load(f)
229
+ map = np.empty((2, len(homophones_map)), dtype=np.uint32)
230
+ for i, k in enumerate(homophones_map.keys()):
231
+ map[:, i] = (ord(k), ord(homophones_map[k]))
232
+ del homophones_map
233
+ return map
234
+
235
+ def _count_invalid_characters(self, s: str):
236
+ s = self.sub_pattern.sub("", s)
237
+ non_alphabetic_chinese_chars = self.reject_pattern.findall(s)
238
+ return set(non_alphabetic_chinese_chars)
239
+
240
+ def _apply_half2full_map(self, text: str) -> str:
241
+ return text.translate(self.halfwidth_2_fullwidth)
242
+
243
+ def _apply_character_map(self, text: str) -> str:
244
+ return text.translate(self.character_simplifier)
245
+
246
+ def _detect_language(self, sentence: str) -> Literal["zh", "en"]:
247
+ chinese_chars = self.chinese_char_pattern.findall(sentence)
248
+ english_words = self.english_word_pattern.findall(sentence)
249
+
250
+ if len(chinese_chars) > len(english_words):
251
+ return "zh"
252
+ else:
253
+ return "en"
ChatTTS/res/__init__.py ADDED
File without changes
ChatTTS/res/homophones_map.json ADDED
The diff for this file is too large to render. See raw diff
 
ChatTTS/res/sha256_map.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "sha256_asset_Decoder_pt" : "9964e36e840f0e3a748c5f716fe6de6490d2135a5f5155f4a642d51860e2ec38",
3
+ "sha256_asset_DVAE_full_pt" : "553eb75763511e23f3e5f86303e2163c5ca775489d637fb635d979c8ae58bbe5",
4
+ "sha256_asset_Embed_safetensors" : "2ff0be7134934155741b643b74e32fb6bf3eec41257984459b2ed60cdb4c48b0",
5
+ "sha256_asset_Vocos_pt" : "09a670eda1c08b740013679c7a90ebb7f1a97646ea7673069a6838e6b51d6c58",
6
+
7
+ "sha256_asset_gpt_config_json" : "0aaa1ecd96c49ad4f473459eb1982fa7ad79fa5de08cde2781bf6ad1f9a0c236",
8
+ "sha256_asset_gpt_model_safetensors" : "cd0806fd971f52f6a22c923ec64982b305e817bcc41ca83417fcf9141b984a0f",
9
+
10
+ "sha256_asset_tokenizer_special_tokens_map_json": "bd0ac9d9bb1657996b5c5fbcaa7d80f8de530d01a283da97f89deae5b1b8d011",
11
+ "sha256_asset_tokenizer_tokenizer_config_json" : "43e9d658b554fa5ee8d8e1d763349323bfef1ed7a89c0794220ab8861387d421",
12
+ "sha256_asset_tokenizer_tokenizer_json" : "843838a64e121e23e774cc75874c6fe862198d9f7dd43747914633a8fd89c20e"
13
+ }
ChatTTS/utils/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .dl import check_all_assets, download_all_assets
2
+ from .gpu import select_device
3
+ from .io import get_latest_modified_file, del_all
4
+ from .log import logger
ChatTTS/utils/dl.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ import hashlib
4
+ import requests
5
+ from io import BytesIO
6
+ from typing import Dict, Tuple, Optional
7
+ from mmap import mmap, ACCESS_READ
8
+
9
+ from .log import logger
10
+
11
+
12
+ def sha256(fileno: int) -> str:
13
+ data = mmap(fileno, 0, access=ACCESS_READ)
14
+ h = hashlib.sha256(data).hexdigest()
15
+ del data
16
+ return h
17
+
18
+
19
+ def check_model(
20
+ dir_name: Path, model_name: str, hash: str, remove_incorrect=False
21
+ ) -> bool:
22
+ target = dir_name / model_name
23
+ relname = target.as_posix()
24
+ logger.get_logger().debug(f"checking {relname}...")
25
+ if not os.path.exists(target):
26
+ logger.get_logger().info(f"{target} not exist.")
27
+ return False
28
+ with open(target, "rb") as f:
29
+ digest = sha256(f.fileno())
30
+ bakfile = f"{target}.bak"
31
+ if digest != hash:
32
+ logger.get_logger().warning(f"{target} sha256 hash mismatch.")
33
+ logger.get_logger().info(f"expected: {hash}")
34
+ logger.get_logger().info(f"real val: {digest}")
35
+ if remove_incorrect:
36
+ if not os.path.exists(bakfile):
37
+ os.rename(str(target), bakfile)
38
+ else:
39
+ os.remove(str(target))
40
+ return False
41
+ if remove_incorrect and os.path.exists(bakfile):
42
+ os.remove(bakfile)
43
+ return True
44
+
45
+
46
+ def check_folder(
47
+ base_dir: Path,
48
+ *innder_dirs: str,
49
+ names: Tuple[str],
50
+ sha256_map: Dict[str, str],
51
+ update=False,
52
+ ) -> bool:
53
+ key = "sha256_"
54
+ current_dir = base_dir
55
+ for d in innder_dirs:
56
+ current_dir /= d
57
+ key += f"{d}_"
58
+
59
+ for model in names:
60
+ menv = model.replace(".", "_")
61
+ if not check_model(current_dir, model, sha256_map[f"{key}{menv}"], update):
62
+ return False
63
+ return True
64
+
65
+
66
+ def check_all_assets(base_dir: Path, sha256_map: Dict[str, str], update=False) -> bool:
67
+ logger.get_logger().info("checking assets...")
68
+
69
+ if not check_folder(
70
+ base_dir,
71
+ "asset",
72
+ names=(
73
+ "Decoder.pt",
74
+ "DVAE_full.pt",
75
+ "Embed.safetensors",
76
+ "Vocos.pt",
77
+ ),
78
+ sha256_map=sha256_map,
79
+ update=update,
80
+ ):
81
+ return False
82
+
83
+ if not check_folder(
84
+ base_dir,
85
+ "asset",
86
+ "gpt",
87
+ names=(
88
+ "config.json",
89
+ "model.safetensors",
90
+ ),
91
+ sha256_map=sha256_map,
92
+ update=update,
93
+ ):
94
+ return False
95
+
96
+ if not check_folder(
97
+ base_dir,
98
+ "asset",
99
+ "tokenizer",
100
+ names=(
101
+ "special_tokens_map.json",
102
+ "tokenizer_config.json",
103
+ "tokenizer.json",
104
+ ),
105
+ sha256_map=sha256_map,
106
+ update=update,
107
+ ):
108
+ return False
109
+
110
+ logger.get_logger().info("all assets are already latest.")
111
+ return True
112
+
113
+
114
+ def download_and_extract_tar_gz(
115
+ url: str, folder: str, headers: Optional[Dict[str, str]] = None
116
+ ):
117
+ import tarfile
118
+
119
+ logger.get_logger().info(f"downloading {url}")
120
+ response = requests.get(url, headers=headers, stream=True, timeout=(10, 3))
121
+ with BytesIO() as out_file:
122
+ out_file.write(response.content)
123
+ out_file.seek(0)
124
+ logger.get_logger().info(f"downloaded.")
125
+ with tarfile.open(fileobj=out_file, mode="r:gz") as tar:
126
+ tar.extractall(folder)
127
+ logger.get_logger().info(f"extracted into {folder}")
128
+
129
+
130
+ def download_and_extract_zip(
131
+ url: str, folder: str, headers: Optional[Dict[str, str]] = None
132
+ ):
133
+ import zipfile
134
+
135
+ logger.get_logger().info(f"downloading {url}")
136
+ response = requests.get(url, headers=headers, stream=True, timeout=(10, 3))
137
+ with BytesIO() as out_file:
138
+ out_file.write(response.content)
139
+ out_file.seek(0)
140
+ logger.get_logger().info(f"downloaded.")
141
+ with zipfile.ZipFile(out_file) as zip_ref:
142
+ zip_ref.extractall(folder)
143
+ logger.get_logger().info(f"extracted into {folder}")
144
+
145
+
146
+ def download_dns_yaml(url: str, folder: str, headers: Dict[str, str]):
147
+ logger.get_logger().info(f"downloading {url}")
148
+ response = requests.get(url, headers=headers, stream=True, timeout=(100, 3))
149
+ with open(os.path.join(folder, "dns.yaml"), "wb") as out_file:
150
+ out_file.write(response.content)
151
+ logger.get_logger().info(f"downloaded into {folder}")
152
+
153
+
154
+ def download_all_assets(tmpdir: str, version="0.2.8"):
155
+ import subprocess
156
+ import platform
157
+
158
+ archs = {
159
+ "aarch64": "arm64",
160
+ "armv8l": "arm64",
161
+ "arm64": "arm64",
162
+ "x86": "386",
163
+ "i386": "386",
164
+ "i686": "386",
165
+ "386": "386",
166
+ "x86_64": "amd64",
167
+ "x64": "amd64",
168
+ "amd64": "amd64",
169
+ }
170
+ system_type = platform.system().lower()
171
+ architecture = platform.machine().lower()
172
+ is_win = system_type == "windows"
173
+
174
+ architecture = archs.get(architecture, None)
175
+ if not architecture:
176
+ logger.get_logger().error(f"architecture {architecture} is not supported")
177
+ exit(1)
178
+ try:
179
+ BASE_URL = "https://github.com/fumiama/RVC-Models-Downloader/releases/download/"
180
+ suffix = "zip" if is_win else "tar.gz"
181
+ RVCMD_URL = BASE_URL + f"v{version}/rvcmd_{system_type}_{architecture}.{suffix}"
182
+ cmdfile = os.path.join(tmpdir, "rvcmd")
183
+ if is_win:
184
+ download_and_extract_zip(RVCMD_URL, tmpdir)
185
+ cmdfile += ".exe"
186
+ else:
187
+ download_and_extract_tar_gz(RVCMD_URL, tmpdir)
188
+ os.chmod(cmdfile, 0o755)
189
+ subprocess.run([cmdfile, "-notui", "-w", "0", "assets/chtts"])
190
+ except Exception:
191
+ BASE_URL = (
192
+ "https://gitea.seku.su/fumiama/RVC-Models-Downloader/releases/download/"
193
+ )
194
+ suffix = "zip" if is_win else "tar.gz"
195
+ RVCMD_URL = BASE_URL + f"v{version}/rvcmd_{system_type}_{architecture}.{suffix}"
196
+ download_dns_yaml(
197
+ "https://gitea.seku.su/fumiama/RVC-Models-Downloader/raw/branch/main/dns.yaml",
198
+ tmpdir,
199
+ headers={
200
+ "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36 Edg/128.0.0.0"
201
+ },
202
+ )
203
+ cmdfile = os.path.join(tmpdir, "rvcmd")
204
+ if is_win:
205
+ download_and_extract_zip(RVCMD_URL, tmpdir)
206
+ cmdfile += ".exe"
207
+ else:
208
+ download_and_extract_tar_gz(RVCMD_URL, tmpdir)
209
+ os.chmod(cmdfile, 0o755)
210
+ subprocess.run(
211
+ [
212
+ cmdfile,
213
+ "-notui",
214
+ "-w",
215
+ "0",
216
+ "-dns",
217
+ os.path.join(tmpdir, "dns.yaml"),
218
+ "assets/chtts",
219
+ ]
220
+ )
ChatTTS/utils/gpu.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from .log import logger
4
+
5
+
6
+ def select_device(min_memory=2047, experimental=False):
7
+ if torch.cuda.is_available():
8
+ selected_gpu = 0
9
+ max_free_memory = -1
10
+ for i in range(torch.cuda.device_count()):
11
+ props = torch.cuda.get_device_properties(i)
12
+ free_memory = props.total_memory - torch.cuda.memory_reserved(i)
13
+ if max_free_memory < free_memory:
14
+ selected_gpu = i
15
+ max_free_memory = free_memory
16
+ free_memory_mb = max_free_memory / (1024 * 1024)
17
+ if free_memory_mb < min_memory:
18
+ logger.get_logger().warning(
19
+ f"GPU {selected_gpu} has {round(free_memory_mb, 2)} MB memory left. Switching to CPU."
20
+ )
21
+ device = torch.device("cpu")
22
+ else:
23
+ device = torch.device(f"cuda:{selected_gpu}")
24
+ elif torch.backends.mps.is_available():
25
+ """
26
+ Currently MPS is slower than CPU while needs more memory and core utility,
27
+ so only enable this for experimental use.
28
+ """
29
+ if experimental:
30
+ # For Apple M1/M2 chips with Metal Performance Shaders
31
+ logger.get_logger().warning("experimantal: found apple GPU, using MPS.")
32
+ device = torch.device("mps")
33
+ else:
34
+ logger.get_logger().info("found Apple GPU, but use CPU.")
35
+ device = torch.device("cpu")
36
+ else:
37
+ logger.get_logger().warning("no GPU found, use CPU instead")
38
+ device = torch.device("cpu")
39
+
40
+ return device
ChatTTS/utils/io.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from typing import Union
4
+ from dataclasses import is_dataclass
5
+
6
+ from .log import logger
7
+
8
+
9
+ def get_latest_modified_file(directory):
10
+
11
+ files = [os.path.join(directory, f) for f in os.listdir(directory)]
12
+ if not files:
13
+ logger.get_logger().log(
14
+ logging.WARNING, f"no files found in the directory: {directory}"
15
+ )
16
+ return None
17
+ latest_file = max(files, key=os.path.getmtime)
18
+
19
+ return latest_file
20
+
21
+
22
+ def del_all(d: Union[dict, list]):
23
+ if is_dataclass(d):
24
+ for k in list(vars(d).keys()):
25
+ x = getattr(d, k)
26
+ if isinstance(x, dict) or isinstance(x, list) or is_dataclass(x):
27
+ del_all(x)
28
+ del x
29
+ delattr(d, k)
30
+ elif isinstance(d, dict):
31
+ lst = list(d.keys())
32
+ for k in lst:
33
+ x = d.pop(k)
34
+ if isinstance(x, dict) or isinstance(x, list) or is_dataclass(x):
35
+ del_all(x)
36
+ del x
37
+ elif isinstance(d, list):
38
+ while len(d):
39
+ x = d.pop()
40
+ if isinstance(x, dict) or isinstance(x, list) or is_dataclass(x):
41
+ del_all(x)
42
+ del x
43
+ else:
44
+ del d
ChatTTS/utils/log.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from pathlib import Path
3
+
4
+
5
+ class Logger:
6
+ def __init__(self, logger=logging.getLogger(Path(__file__).parent.name)):
7
+ self.logger = logger
8
+
9
+ def set_logger(self, logger: logging.Logger):
10
+ self.logger = logger
11
+
12
+ def get_logger(self) -> logging.Logger:
13
+ return self.logger
14
+
15
+
16
+ logger = Logger()
Dockerfile ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9
2
+
3
+ RUN useradd -m -u 1000 user
4
+ USER user
5
+ ENV PATH="/home/user/.local/bin:$PATH"
6
+
7
+ WORKDIR /app
8
+
9
+ COPY --chown=user ./requirements.txt requirements.txt
10
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
11
+
12
+ COPY --chown=user . /app
13
+ CMD ["fastapi", "dev", "examples/api/main.py", "--host", "0.0.0.0", "--port", "7860"]
LICENSE ADDED
@@ -0,0 +1,661 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ GNU AFFERO GENERAL PUBLIC LICENSE
2
+ Version 3, 19 November 2007
3
+
4
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5
+ Everyone is permitted to copy and distribute verbatim copies
6
+ of this license document, but changing it is not allowed.
7
+
8
+ Preamble
9
+
10
+ The GNU Affero General Public License is a free, copyleft license for
11
+ software and other kinds of works, specifically designed to ensure
12
+ cooperation with the community in the case of network server software.
13
+
14
+ The licenses for most software and other practical works are designed
15
+ to take away your freedom to share and change the works. By contrast,
16
+ our General Public Licenses are intended to guarantee your freedom to
17
+ share and change all versions of a program--to make sure it remains free
18
+ software for all its users.
19
+
20
+ When we speak of free software, we are referring to freedom, not
21
+ price. Our General Public Licenses are designed to make sure that you
22
+ have the freedom to distribute copies of free software (and charge for
23
+ them if you wish), that you receive source code or can get it if you
24
+ want it, that you can change the software or use pieces of it in new
25
+ free programs, and that you know you can do these things.
26
+
27
+ Developers that use our General Public Licenses protect your rights
28
+ with two steps: (1) assert copyright on the software, and (2) offer
29
+ you this License which gives you legal permission to copy, distribute
30
+ and/or modify the software.
31
+
32
+ A secondary benefit of defending all users' freedom is that
33
+ improvements made in alternate versions of the program, if they
34
+ receive widespread use, become available for other developers to
35
+ incorporate. Many developers of free software are heartened and
36
+ encouraged by the resulting cooperation. However, in the case of
37
+ software used on network servers, this result may fail to come about.
38
+ The GNU General Public License permits making a modified version and
39
+ letting the public access it on a server without ever releasing its
40
+ source code to the public.
41
+
42
+ The GNU Affero General Public License is designed specifically to
43
+ ensure that, in such cases, the modified source code becomes available
44
+ to the community. It requires the operator of a network server to
45
+ provide the source code of the modified version running there to the
46
+ users of that server. Therefore, public use of a modified version, on
47
+ a publicly accessible server, gives the public access to the source
48
+ code of the modified version.
49
+
50
+ An older license, called the Affero General Public License and
51
+ published by Affero, was designed to accomplish similar goals. This is
52
+ a different license, not a version of the Affero GPL, but Affero has
53
+ released a new version of the Affero GPL which permits relicensing under
54
+ this license.
55
+
56
+ The precise terms and conditions for copying, distribution and
57
+ modification follow.
58
+
59
+ TERMS AND CONDITIONS
60
+
61
+ 0. Definitions.
62
+
63
+ "This License" refers to version 3 of the GNU Affero General Public License.
64
+
65
+ "Copyright" also means copyright-like laws that apply to other kinds of
66
+ works, such as semiconductor masks.
67
+
68
+ "The Program" refers to any copyrightable work licensed under this
69
+ License. Each licensee is addressed as "you". "Licensees" and
70
+ "recipients" may be individuals or organizations.
71
+
72
+ To "modify" a work means to copy from or adapt all or part of the work
73
+ in a fashion requiring copyright permission, other than the making of an
74
+ exact copy. The resulting work is called a "modified version" of the
75
+ earlier work or a work "based on" the earlier work.
76
+
77
+ A "covered work" means either the unmodified Program or a work based
78
+ on the Program.
79
+
80
+ To "propagate" a work means to do anything with it that, without
81
+ permission, would make you directly or secondarily liable for
82
+ infringement under applicable copyright law, except executing it on a
83
+ computer or modifying a private copy. Propagation includes copying,
84
+ distribution (with or without modification), making available to the
85
+ public, and in some countries other activities as well.
86
+
87
+ To "convey" a work means any kind of propagation that enables other
88
+ parties to make or receive copies. Mere interaction with a user through
89
+ a computer network, with no transfer of a copy, is not conveying.
90
+
91
+ An interactive user interface displays "Appropriate Legal Notices"
92
+ to the extent that it includes a convenient and prominently visible
93
+ feature that (1) displays an appropriate copyright notice, and (2)
94
+ tells the user that there is no warranty for the work (except to the
95
+ extent that warranties are provided), that licensees may convey the
96
+ work under this License, and how to view a copy of this License. If
97
+ the interface presents a list of user commands or options, such as a
98
+ menu, a prominent item in the list meets this criterion.
99
+
100
+ 1. Source Code.
101
+
102
+ The "source code" for a work means the preferred form of the work
103
+ for making modifications to it. "Object code" means any non-source
104
+ form of a work.
105
+
106
+ A "Standard Interface" means an interface that either is an official
107
+ standard defined by a recognized standards body, or, in the case of
108
+ interfaces specified for a particular programming language, one that
109
+ is widely used among developers working in that language.
110
+
111
+ The "System Libraries" of an executable work include anything, other
112
+ than the work as a whole, that (a) is included in the normal form of
113
+ packaging a Major Component, but which is not part of that Major
114
+ Component, and (b) serves only to enable use of the work with that
115
+ Major Component, or to implement a Standard Interface for which an
116
+ implementation is available to the public in source code form. A
117
+ "Major Component", in this context, means a major essential component
118
+ (kernel, window system, and so on) of the specific operating system
119
+ (if any) on which the executable work runs, or a compiler used to
120
+ produce the work, or an object code interpreter used to run it.
121
+
122
+ The "Corresponding Source" for a work in object code form means all
123
+ the source code needed to generate, install, and (for an executable
124
+ work) run the object code and to modify the work, including scripts to
125
+ control those activities. However, it does not include the work's
126
+ System Libraries, or general-purpose tools or generally available free
127
+ programs which are used unmodified in performing those activities but
128
+ which are not part of the work. For example, Corresponding Source
129
+ includes interface definition files associated with source files for
130
+ the work, and the source code for shared libraries and dynamically
131
+ linked subprograms that the work is specifically designed to require,
132
+ such as by intimate data communication or control flow between those
133
+ subprograms and other parts of the work.
134
+
135
+ The Corresponding Source need not include anything that users
136
+ can regenerate automatically from other parts of the Corresponding
137
+ Source.
138
+
139
+ The Corresponding Source for a work in source code form is that
140
+ same work.
141
+
142
+ 2. Basic Permissions.
143
+
144
+ All rights granted under this License are granted for the term of
145
+ copyright on the Program, and are irrevocable provided the stated
146
+ conditions are met. This License explicitly affirms your unlimited
147
+ permission to run the unmodified Program. The output from running a
148
+ covered work is covered by this License only if the output, given its
149
+ content, constitutes a covered work. This License acknowledges your
150
+ rights of fair use or other equivalent, as provided by copyright law.
151
+
152
+ You may make, run and propagate covered works that you do not
153
+ convey, without conditions so long as your license otherwise remains
154
+ in force. You may convey covered works to others for the sole purpose
155
+ of having them make modifications exclusively for you, or provide you
156
+ with facilities for running those works, provided that you comply with
157
+ the terms of this License in conveying all material for which you do
158
+ not control copyright. Those thus making or running the covered works
159
+ for you must do so exclusively on your behalf, under your direction
160
+ and control, on terms that prohibit them from making any copies of
161
+ your copyrighted material outside their relationship with you.
162
+
163
+ Conveying under any other circumstances is permitted solely under
164
+ the conditions stated below. Sublicensing is not allowed; section 10
165
+ makes it unnecessary.
166
+
167
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
168
+
169
+ No covered work shall be deemed part of an effective technological
170
+ measure under any applicable law fulfilling obligations under article
171
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
172
+ similar laws prohibiting or restricting circumvention of such
173
+ measures.
174
+
175
+ When you convey a covered work, you waive any legal power to forbid
176
+ circumvention of technological measures to the extent such circumvention
177
+ is effected by exercising rights under this License with respect to
178
+ the covered work, and you disclaim any intention to limit operation or
179
+ modification of the work as a means of enforcing, against the work's
180
+ users, your or third parties' legal rights to forbid circumvention of
181
+ technological measures.
182
+
183
+ 4. Conveying Verbatim Copies.
184
+
185
+ You may convey verbatim copies of the Program's source code as you
186
+ receive it, in any medium, provided that you conspicuously and
187
+ appropriately publish on each copy an appropriate copyright notice;
188
+ keep intact all notices stating that this License and any
189
+ non-permissive terms added in accord with section 7 apply to the code;
190
+ keep intact all notices of the absence of any warranty; and give all
191
+ recipients a copy of this License along with the Program.
192
+
193
+ You may charge any price or no price for each copy that you convey,
194
+ and you may offer support or warranty protection for a fee.
195
+
196
+ 5. Conveying Modified Source Versions.
197
+
198
+ You may convey a work based on the Program, or the modifications to
199
+ produce it from the Program, in the form of source code under the
200
+ terms of section 4, provided that you also meet all of these conditions:
201
+
202
+ a) The work must carry prominent notices stating that you modified
203
+ it, and giving a relevant date.
204
+
205
+ b) The work must carry prominent notices stating that it is
206
+ released under this License and any conditions added under section
207
+ 7. This requirement modifies the requirement in section 4 to
208
+ "keep intact all notices".
209
+
210
+ c) You must license the entire work, as a whole, under this
211
+ License to anyone who comes into possession of a copy. This
212
+ License will therefore apply, along with any applicable section 7
213
+ additional terms, to the whole of the work, and all its parts,
214
+ regardless of how they are packaged. This License gives no
215
+ permission to license the work in any other way, but it does not
216
+ invalidate such permission if you have separately received it.
217
+
218
+ d) If the work has interactive user interfaces, each must display
219
+ Appropriate Legal Notices; however, if the Program has interactive
220
+ interfaces that do not display Appropriate Legal Notices, your
221
+ work need not make them do so.
222
+
223
+ A compilation of a covered work with other separate and independent
224
+ works, which are not by their nature extensions of the covered work,
225
+ and which are not combined with it such as to form a larger program,
226
+ in or on a volume of a storage or distribution medium, is called an
227
+ "aggregate" if the compilation and its resulting copyright are not
228
+ used to limit the access or legal rights of the compilation's users
229
+ beyond what the individual works permit. Inclusion of a covered work
230
+ in an aggregate does not cause this License to apply to the other
231
+ parts of the aggregate.
232
+
233
+ 6. Conveying Non-Source Forms.
234
+
235
+ You may convey a covered work in object code form under the terms
236
+ of sections 4 and 5, provided that you also convey the
237
+ machine-readable Corresponding Source under the terms of this License,
238
+ in one of these ways:
239
+
240
+ a) Convey the object code in, or embodied in, a physical product
241
+ (including a physical distribution medium), accompanied by the
242
+ Corresponding Source fixed on a durable physical medium
243
+ customarily used for software interchange.
244
+
245
+ b) Convey the object code in, or embodied in, a physical product
246
+ (including a physical distribution medium), accompanied by a
247
+ written offer, valid for at least three years and valid for as
248
+ long as you offer spare parts or customer support for that product
249
+ model, to give anyone who possesses the object code either (1) a
250
+ copy of the Corresponding Source for all the software in the
251
+ product that is covered by this License, on a durable physical
252
+ medium customarily used for software interchange, for a price no
253
+ more than your reasonable cost of physically performing this
254
+ conveying of source, or (2) access to copy the
255
+ Corresponding Source from a network server at no charge.
256
+
257
+ c) Convey individual copies of the object code with a copy of the
258
+ written offer to provide the Corresponding Source. This
259
+ alternative is allowed only occasionally and noncommercially, and
260
+ only if you received the object code with such an offer, in accord
261
+ with subsection 6b.
262
+
263
+ d) Convey the object code by offering access from a designated
264
+ place (gratis or for a charge), and offer equivalent access to the
265
+ Corresponding Source in the same way through the same place at no
266
+ further charge. You need not require recipients to copy the
267
+ Corresponding Source along with the object code. If the place to
268
+ copy the object code is a network server, the Corresponding Source
269
+ may be on a different server (operated by you or a third party)
270
+ that supports equivalent copying facilities, provided you maintain
271
+ clear directions next to the object code saying where to find the
272
+ Corresponding Source. Regardless of what server hosts the
273
+ Corresponding Source, you remain obligated to ensure that it is
274
+ available for as long as needed to satisfy these requirements.
275
+
276
+ e) Convey the object code using peer-to-peer transmission, provided
277
+ you inform other peers where the object code and Corresponding
278
+ Source of the work are being offered to the general public at no
279
+ charge under subsection 6d.
280
+
281
+ A separable portion of the object code, whose source code is excluded
282
+ from the Corresponding Source as a System Library, need not be
283
+ included in conveying the object code work.
284
+
285
+ A "User Product" is either (1) a "consumer product", which means any
286
+ tangible personal property which is normally used for personal, family,
287
+ or household purposes, or (2) anything designed or sold for incorporation
288
+ into a dwelling. In determining whether a product is a consumer product,
289
+ doubtful cases shall be resolved in favor of coverage. For a particular
290
+ product received by a particular user, "normally used" refers to a
291
+ typical or common use of that class of product, regardless of the status
292
+ of the particular user or of the way in which the particular user
293
+ actually uses, or expects or is expected to use, the product. A product
294
+ is a consumer product regardless of whether the product has substantial
295
+ commercial, industrial or non-consumer uses, unless such uses represent
296
+ the only significant mode of use of the product.
297
+
298
+ "Installation Information" for a User Product means any methods,
299
+ procedures, authorization keys, or other information required to install
300
+ and execute modified versions of a covered work in that User Product from
301
+ a modified version of its Corresponding Source. The information must
302
+ suffice to ensure that the continued functioning of the modified object
303
+ code is in no case prevented or interfered with solely because
304
+ modification has been made.
305
+
306
+ If you convey an object code work under this section in, or with, or
307
+ specifically for use in, a User Product, and the conveying occurs as
308
+ part of a transaction in which the right of possession and use of the
309
+ User Product is transferred to the recipient in perpetuity or for a
310
+ fixed term (regardless of how the transaction is characterized), the
311
+ Corresponding Source conveyed under this section must be accompanied
312
+ by the Installation Information. But this requirement does not apply
313
+ if neither you nor any third party retains the ability to install
314
+ modified object code on the User Product (for example, the work has
315
+ been installed in ROM).
316
+
317
+ The requirement to provide Installation Information does not include a
318
+ requirement to continue to provide support service, warranty, or updates
319
+ for a work that has been modified or installed by the recipient, or for
320
+ the User Product in which it has been modified or installed. Access to a
321
+ network may be denied when the modification itself materially and
322
+ adversely affects the operation of the network or violates the rules and
323
+ protocols for communication across the network.
324
+
325
+ Corresponding Source conveyed, and Installation Information provided,
326
+ in accord with this section must be in a format that is publicly
327
+ documented (and with an implementation available to the public in
328
+ source code form), and must require no special password or key for
329
+ unpacking, reading or copying.
330
+
331
+ 7. Additional Terms.
332
+
333
+ "Additional permissions" are terms that supplement the terms of this
334
+ License by making exceptions from one or more of its conditions.
335
+ Additional permissions that are applicable to the entire Program shall
336
+ be treated as though they were included in this License, to the extent
337
+ that they are valid under applicable law. If additional permissions
338
+ apply only to part of the Program, that part may be used separately
339
+ under those permissions, but the entire Program remains governed by
340
+ this License without regard to the additional permissions.
341
+
342
+ When you convey a copy of a covered work, you may at your option
343
+ remove any additional permissions from that copy, or from any part of
344
+ it. (Additional permissions may be written to require their own
345
+ removal in certain cases when you modify the work.) You may place
346
+ additional permissions on material, added by you to a covered work,
347
+ for which you have or can give appropriate copyright permission.
348
+
349
+ Notwithstanding any other provision of this License, for material you
350
+ add to a covered work, you may (if authorized by the copyright holders of
351
+ that material) supplement the terms of this License with terms:
352
+
353
+ a) Disclaiming warranty or limiting liability differently from the
354
+ terms of sections 15 and 16 of this License; or
355
+
356
+ b) Requiring preservation of specified reasonable legal notices or
357
+ author attributions in that material or in the Appropriate Legal
358
+ Notices displayed by works containing it; or
359
+
360
+ c) Prohibiting misrepresentation of the origin of that material, or
361
+ requiring that modified versions of such material be marked in
362
+ reasonable ways as different from the original version; or
363
+
364
+ d) Limiting the use for publicity purposes of names of licensors or
365
+ authors of the material; or
366
+
367
+ e) Declining to grant rights under trademark law for use of some
368
+ trade names, trademarks, or service marks; or
369
+
370
+ f) Requiring indemnification of licensors and authors of that
371
+ material by anyone who conveys the material (or modified versions of
372
+ it) with contractual assumptions of liability to the recipient, for
373
+ any liability that these contractual assumptions directly impose on
374
+ those licensors and authors.
375
+
376
+ All other non-permissive additional terms are considered "further
377
+ restrictions" within the meaning of section 10. If the Program as you
378
+ received it, or any part of it, contains a notice stating that it is
379
+ governed by this License along with a term that is a further
380
+ restriction, you may remove that term. If a license document contains
381
+ a further restriction but permits relicensing or conveying under this
382
+ License, you may add to a covered work material governed by the terms
383
+ of that license document, provided that the further restriction does
384
+ not survive such relicensing or conveying.
385
+
386
+ If you add terms to a covered work in accord with this section, you
387
+ must place, in the relevant source files, a statement of the
388
+ additional terms that apply to those files, or a notice indicating
389
+ where to find the applicable terms.
390
+
391
+ Additional terms, permissive or non-permissive, may be stated in the
392
+ form of a separately written license, or stated as exceptions;
393
+ the above requirements apply either way.
394
+
395
+ 8. Termination.
396
+
397
+ You may not propagate or modify a covered work except as expressly
398
+ provided under this License. Any attempt otherwise to propagate or
399
+ modify it is void, and will automatically terminate your rights under
400
+ this License (including any patent licenses granted under the third
401
+ paragraph of section 11).
402
+
403
+ However, if you cease all violation of this License, then your
404
+ license from a particular copyright holder is reinstated (a)
405
+ provisionally, unless and until the copyright holder explicitly and
406
+ finally terminates your license, and (b) permanently, if the copyright
407
+ holder fails to notify you of the violation by some reasonable means
408
+ prior to 60 days after the cessation.
409
+
410
+ Moreover, your license from a particular copyright holder is
411
+ reinstated permanently if the copyright holder notifies you of the
412
+ violation by some reasonable means, this is the first time you have
413
+ received notice of violation of this License (for any work) from that
414
+ copyright holder, and you cure the violation prior to 30 days after
415
+ your receipt of the notice.
416
+
417
+ Termination of your rights under this section does not terminate the
418
+ licenses of parties who have received copies or rights from you under
419
+ this License. If your rights have been terminated and not permanently
420
+ reinstated, you do not qualify to receive new licenses for the same
421
+ material under section 10.
422
+
423
+ 9. Acceptance Not Required for Having Copies.
424
+
425
+ You are not required to accept this License in order to receive or
426
+ run a copy of the Program. Ancillary propagation of a covered work
427
+ occurring solely as a consequence of using peer-to-peer transmission
428
+ to receive a copy likewise does not require acceptance. However,
429
+ nothing other than this License grants you permission to propagate or
430
+ modify any covered work. These actions infringe copyright if you do
431
+ not accept this License. Therefore, by modifying or propagating a
432
+ covered work, you indicate your acceptance of this License to do so.
433
+
434
+ 10. Automatic Licensing of Downstream Recipients.
435
+
436
+ Each time you convey a covered work, the recipient automatically
437
+ receives a license from the original licensors, to run, modify and
438
+ propagate that work, subject to this License. You are not responsible
439
+ for enforcing compliance by third parties with this License.
440
+
441
+ An "entity transaction" is a transaction transferring control of an
442
+ organization, or substantially all assets of one, or subdividing an
443
+ organization, or merging organizations. If propagation of a covered
444
+ work results from an entity transaction, each party to that
445
+ transaction who receives a copy of the work also receives whatever
446
+ licenses to the work the party's predecessor in interest had or could
447
+ give under the previous paragraph, plus a right to possession of the
448
+ Corresponding Source of the work from the predecessor in interest, if
449
+ the predecessor has it or can get it with reasonable efforts.
450
+
451
+ You may not impose any further restrictions on the exercise of the
452
+ rights granted or affirmed under this License. For example, you may
453
+ not impose a license fee, royalty, or other charge for exercise of
454
+ rights granted under this License, and you may not initiate litigation
455
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
456
+ any patent claim is infringed by making, using, selling, offering for
457
+ sale, or importing the Program or any portion of it.
458
+
459
+ 11. Patents.
460
+
461
+ A "contributor" is a copyright holder who authorizes use under this
462
+ License of the Program or a work on which the Program is based. The
463
+ work thus licensed is called the contributor's "contributor version".
464
+
465
+ A contributor's "essential patent claims" are all patent claims
466
+ owned or controlled by the contributor, whether already acquired or
467
+ hereafter acquired, that would be infringed by some manner, permitted
468
+ by this License, of making, using, or selling its contributor version,
469
+ but do not include claims that would be infringed only as a
470
+ consequence of further modification of the contributor version. For
471
+ purposes of this definition, "control" includes the right to grant
472
+ patent sublicenses in a manner consistent with the requirements of
473
+ this License.
474
+
475
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
476
+ patent license under the contributor's essential patent claims, to
477
+ make, use, sell, offer for sale, import and otherwise run, modify and
478
+ propagate the contents of its contributor version.
479
+
480
+ In the following three paragraphs, a "patent license" is any express
481
+ agreement or commitment, however denominated, not to enforce a patent
482
+ (such as an express permission to practice a patent or covenant not to
483
+ sue for patent infringement). To "grant" such a patent license to a
484
+ party means to make such an agreement or commitment not to enforce a
485
+ patent against the party.
486
+
487
+ If you convey a covered work, knowingly relying on a patent license,
488
+ and the Corresponding Source of the work is not available for anyone
489
+ to copy, free of charge and under the terms of this License, through a
490
+ publicly available network server or other readily accessible means,
491
+ then you must either (1) cause the Corresponding Source to be so
492
+ available, or (2) arrange to deprive yourself of the benefit of the
493
+ patent license for this particular work, or (3) arrange, in a manner
494
+ consistent with the requirements of this License, to extend the patent
495
+ license to downstream recipients. "Knowingly relying" means you have
496
+ actual knowledge that, but for the patent license, your conveying the
497
+ covered work in a country, or your recipient's use of the covered work
498
+ in a country, would infringe one or more identifiable patents in that
499
+ country that you have reason to believe are valid.
500
+
501
+ If, pursuant to or in connection with a single transaction or
502
+ arrangement, you convey, or propagate by procuring conveyance of, a
503
+ covered work, and grant a patent license to some of the parties
504
+ receiving the covered work authorizing them to use, propagate, modify
505
+ or convey a specific copy of the covered work, then the patent license
506
+ you grant is automatically extended to all recipients of the covered
507
+ work and works based on it.
508
+
509
+ A patent license is "discriminatory" if it does not include within
510
+ the scope of its coverage, prohibits the exercise of, or is
511
+ conditioned on the non-exercise of one or more of the rights that are
512
+ specifically granted under this License. You may not convey a covered
513
+ work if you are a party to an arrangement with a third party that is
514
+ in the business of distributing software, under which you make payment
515
+ to the third party based on the extent of your activity of conveying
516
+ the work, and under which the third party grants, to any of the
517
+ parties who would receive the covered work from you, a discriminatory
518
+ patent license (a) in connection with copies of the covered work
519
+ conveyed by you (or copies made from those copies), or (b) primarily
520
+ for and in connection with specific products or compilations that
521
+ contain the covered work, unless you entered into that arrangement,
522
+ or that patent license was granted, prior to 28 March 2007.
523
+
524
+ Nothing in this License shall be construed as excluding or limiting
525
+ any implied license or other defenses to infringement that may
526
+ otherwise be available to you under applicable patent law.
527
+
528
+ 12. No Surrender of Others' Freedom.
529
+
530
+ If conditions are imposed on you (whether by court order, agreement or
531
+ otherwise) that contradict the conditions of this License, they do not
532
+ excuse you from the conditions of this License. If you cannot convey a
533
+ covered work so as to satisfy simultaneously your obligations under this
534
+ License and any other pertinent obligations, then as a consequence you may
535
+ not convey it at all. For example, if you agree to terms that obligate you
536
+ to collect a royalty for further conveying from those to whom you convey
537
+ the Program, the only way you could satisfy both those terms and this
538
+ License would be to refrain entirely from conveying the Program.
539
+
540
+ 13. Remote Network Interaction; Use with the GNU General Public License.
541
+
542
+ Notwithstanding any other provision of this License, if you modify the
543
+ Program, your modified version must prominently offer all users
544
+ interacting with it remotely through a computer network (if your version
545
+ supports such interaction) an opportunity to receive the Corresponding
546
+ Source of your version by providing access to the Corresponding Source
547
+ from a network server at no charge, through some standard or customary
548
+ means of facilitating copying of software. This Corresponding Source
549
+ shall include the Corresponding Source for any work covered by version 3
550
+ of the GNU General Public License that is incorporated pursuant to the
551
+ following paragraph.
552
+
553
+ Notwithstanding any other provision of this License, you have
554
+ permission to link or combine any covered work with a work licensed
555
+ under version 3 of the GNU General Public License into a single
556
+ combined work, and to convey the resulting work. The terms of this
557
+ License will continue to apply to the part which is the covered work,
558
+ but the work with which it is combined will remain governed by version
559
+ 3 of the GNU General Public License.
560
+
561
+ 14. Revised Versions of this License.
562
+
563
+ The Free Software Foundation may publish revised and/or new versions of
564
+ the GNU Affero General Public License from time to time. Such new versions
565
+ will be similar in spirit to the present version, but may differ in detail to
566
+ address new problems or concerns.
567
+
568
+ Each version is given a distinguishing version number. If the
569
+ Program specifies that a certain numbered version of the GNU Affero General
570
+ Public License "or any later version" applies to it, you have the
571
+ option of following the terms and conditions either of that numbered
572
+ version or of any later version published by the Free Software
573
+ Foundation. If the Program does not specify a version number of the
574
+ GNU Affero General Public License, you may choose any version ever published
575
+ by the Free Software Foundation.
576
+
577
+ If the Program specifies that a proxy can decide which future
578
+ versions of the GNU Affero General Public License can be used, that proxy's
579
+ public statement of acceptance of a version permanently authorizes you
580
+ to choose that version for the Program.
581
+
582
+ Later license versions may give you additional or different
583
+ permissions. However, no additional obligations are imposed on any
584
+ author or copyright holder as a result of your choosing to follow a
585
+ later version.
586
+
587
+ 15. Disclaimer of Warranty.
588
+
589
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
590
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
591
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
592
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
593
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
594
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
595
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
596
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
597
+
598
+ 16. Limitation of Liability.
599
+
600
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
601
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
602
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
603
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
604
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
605
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
606
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
607
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
608
+ SUCH DAMAGES.
609
+
610
+ 17. Interpretation of Sections 15 and 16.
611
+
612
+ If the disclaimer of warranty and limitation of liability provided
613
+ above cannot be given local legal effect according to their terms,
614
+ reviewing courts shall apply local law that most closely approximates
615
+ an absolute waiver of all civil liability in connection with the
616
+ Program, unless a warranty or assumption of liability accompanies a
617
+ copy of the Program in return for a fee.
618
+
619
+ END OF TERMS AND CONDITIONS
620
+
621
+ How to Apply These Terms to Your New Programs
622
+
623
+ If you develop a new program, and you want it to be of the greatest
624
+ possible use to the public, the best way to achieve this is to make it
625
+ free software which everyone can redistribute and change under these terms.
626
+
627
+ To do so, attach the following notices to the program. It is safest
628
+ to attach them to the start of each source file to most effectively
629
+ state the exclusion of warranty; and each file should have at least
630
+ the "copyright" line and a pointer to where the full notice is found.
631
+
632
+ <one line to give the program's name and a brief idea of what it does.>
633
+ Copyright (C) <year> <name of author>
634
+
635
+ This program is free software: you can redistribute it and/or modify
636
+ it under the terms of the GNU Affero General Public License as published
637
+ by the Free Software Foundation, either version 3 of the License, or
638
+ (at your option) any later version.
639
+
640
+ This program is distributed in the hope that it will be useful,
641
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
642
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
643
+ GNU Affero General Public License for more details.
644
+
645
+ You should have received a copy of the GNU Affero General Public License
646
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
647
+
648
+ Also add information on how to contact you by electronic and paper mail.
649
+
650
+ If your software can interact with users remotely through a computer
651
+ network, you should also make sure that it provides a way for users to
652
+ get its source. For example, if your program is a web application, its
653
+ interface could display a "Source" link that leads users to an archive
654
+ of the code. There are many ways you could offer source, and different
655
+ solutions will be better for different programs; see section 13 for the
656
+ specific requirements.
657
+
658
+ You should also get your employer (if you work as a programmer) or school,
659
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
660
+ For more information on this, and how to apply and follow the GNU AGPL, see
661
+ <https://www.gnu.org/licenses/>.
docs/cn/README.md ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+
3
+ <a href="https://trendshift.io/repositories/10489" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10489" alt="2noise%2FChatTTS | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
4
+
5
+ # ChatTTS
6
+ 一款适用于日常对话的生成式语音模型。
7
+
8
+ [![Licence](https://img.shields.io/github/license/2noise/ChatTTS?style=for-the-badge)](https://github.com/2noise/ChatTTS/blob/main/LICENSE)
9
+ [![PyPI](https://img.shields.io/pypi/v/ChatTTS.svg?style=for-the-badge&color=green)](https://pypi.org/project/ChatTTS)
10
+
11
+ [![Huggingface](https://img.shields.io/badge/🤗%20-Models-yellow.svg?style=for-the-badge)](https://huggingface.co/2Noise/ChatTTS)
12
+ [![Open In Colab](https://img.shields.io/badge/Colab-F9AB00?style=for-the-badge&logo=googlecolab&color=525252)](https://colab.research.google.com/github/2noise/ChatTTS/blob/main/examples/ipynb/colab.ipynb)
13
+ [![Discord](https://img.shields.io/badge/Discord-7289DA?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/Ud5Jxgx5yD)
14
+
15
+ [**English**](../../README.md) | **简体中文** | [**日本語**](../jp/README.md) | [**Русский**](../ru/README.md) | [**Español**](../es/README.md) | [**Français**](../fr/README.md)
16
+
17
+ </div>
18
+
19
+ > [!NOTE]
20
+ > 注意此版本可能不是最新版,所有内容请以英文版为准。
21
+
22
+ ## 简介
23
+
24
+ > [!Note]
25
+ > 这个仓库包含算法架构和一些简单的示例。
26
+
27
+ > [!Tip]
28
+ > 由本仓库衍生出的用户端产品,请参见由社区维护的索引仓库 [Awesome-ChatTTS](https://github.com/libukai/Awesome-ChatTTS)。
29
+
30
+ ChatTTS 是一款专门为对话场景(例如 LLM 助手)设计的文本转语音模型。
31
+
32
+ ### 支持的语种
33
+
34
+ - [x] 英语
35
+ - [x] 中文
36
+ - [ ] 敬请期待...
37
+
38
+ ### 亮点
39
+
40
+ > 你可以参考 **[Bilibili](https://www.bilibili.com/video/BV1zn4y1o7iV)** 上的这个视频,了解本项目的详细情况。
41
+
42
+ 1. **对话式 TTS**: ChatTTS 针对对话式任务进行了优化,能够实现自然且富有表现力的合成语音。它支持多个说话者,便于生成互动式对话。
43
+ 2. **精细的控制**: 该模型可以预测和控制精细的韵律特征,包括笑声、停顿和插入语。
44
+ 3. **更好的韵律**: ChatTTS 在韵律方面超越了大多数开源 TTS 模型。我们提供预训练模型以支持进一步的研究和开发。
45
+
46
+ ### 数据集和模型
47
+
48
+ - 主模型使用了 100,000+ 小时的中文和英文音频数据进行训练。
49
+ - **[HuggingFace](https://huggingface.co/2Noise/ChatTTS)** 上的开源版本是一个在 40,000 小时数据上进行无监督微调的预训练模型。
50
+
51
+ ### 路线图
52
+
53
+ - [x] 开源 4 万小时基础模型和 spk_stats 文件。
54
+ - [x] 支持流式语音输出。
55
+ - [ ] 开源具有多情感控制功能的 4 万小时版本。
56
+ - [ ] ChatTTS.cpp (欢迎在 2noise 组织中新建仓库)。
57
+
58
+ ### 免责声明
59
+
60
+ > [!Important]
61
+ > 此仓库仅供学术用途。
62
+
63
+ 本项目旨在用于教育和研究目的,不适用于任何商业或法律目的。作者不保证信息的准确性、完整性和可靠性。此仓库中使用的信息和数据仅供学术和研究目的。数据来自公开来源,作者不声称对数据拥有任何所有权或版权。
64
+
65
+ ChatTTS 是一款强大的文本转语音系统。但是,负责任和道德地使用这项技术非常重要。为了限制 ChatTTS 的使用,我们在 40,000 小时模型的训练过程中添加了少量高频噪声,并使用 MP3 格式尽可能压缩音频质量,以防止恶意行为者将其用于犯罪目的。同时,我们内部训练了一个检测模型,并计划在未来开源它。
66
+
67
+ ### 联系方式
68
+
69
+ > 欢迎随时提交 GitHub issues/PRs。
70
+
71
+ #### 合作洽谈
72
+
73
+ 如需就模型和路线图进行合作洽谈,请发送邮件至 **[email protected]**。
74
+
75
+ #### 线上讨论
76
+
77
+ ##### 1. 官方 QQ 群
78
+
79
+ - **群 1**, 808364215 (已满)
80
+ - **群 2**, 230696694 (已满)
81
+ - **群 3**, 933639842 (已满)
82
+ - **群 4**, 608667975
83
+
84
+ ##### 2. Discord
85
+
86
+ 点击加入 [Discord](https://discord.gg/Ud5Jxgx5yD)。
87
+
88
+ ## 体验教程
89
+
90
+ ### 克隆仓库
91
+
92
+ ```bash
93
+ git clone https://github.com/2noise/ChatTTS
94
+ cd ChatTTS
95
+ ```
96
+
97
+ ### 安装依赖
98
+
99
+ #### 1. 直接安装
100
+
101
+ ```bash
102
+ pip install --upgrade -r requirements.txt
103
+ ```
104
+
105
+ #### 2. 使用 conda 安装
106
+
107
+ ```bash
108
+ conda create -n chattts
109
+ conda activate chattts
110
+ pip install -r requirements.txt
111
+ ```
112
+
113
+ #### 可选 : 如果使用 NVIDIA GPU(仅限 Linux),可安装 TransformerEngine。
114
+
115
+ > [!Note]
116
+ > 安装过程可能耗时很长。
117
+
118
+ > [!Warning]
119
+ > TransformerEngine 的适配目前正在开发中,运行时可能会遇到较多问题。仅推荐出于开发目的安装。
120
+
121
+ ```bash
122
+ pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
123
+ ```
124
+
125
+ #### 可选 : 安装 FlashAttention-2 (主要适用于 NVIDIA GPU)
126
+
127
+ > [!Note]
128
+ > 支持设备列表详见 [Hugging Face Doc](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2).
129
+
130
+ ```bash
131
+ pip install flash-attn --no-build-isolation
132
+ ```
133
+
134
+ ### 快速启动
135
+
136
+ > 确保在执行以下命令时,处于项目根目录下。
137
+
138
+ #### 1. WebUI 可视化界面
139
+
140
+ ```bash
141
+ python examples/web/webui.py
142
+ ```
143
+
144
+ #### 2. 命令行交互
145
+
146
+ > 生成的音频将保存至 `./output_audio_n.mp3`
147
+
148
+ ```bash
149
+ python examples/cmd/run.py "Your text 1." "Your text 2."
150
+ ```
151
+
152
+ ## 开发教程
153
+
154
+ ### 安装 Python 包
155
+
156
+ 1. 从 PyPI 安装稳定版
157
+
158
+ ```bash
159
+ pip install ChatTTS
160
+ ```
161
+
162
+ 2. 从 GitHub 安装最新版
163
+
164
+ ```bash
165
+ pip install git+https://github.com/2noise/ChatTTS
166
+ ```
167
+
168
+ 3. 从本地文件夹安装开发版
169
+
170
+ ```bash
171
+ pip install -e .
172
+ ```
173
+
174
+ ### 基础用法
175
+
176
+ ```python
177
+ import ChatTTS
178
+ import torch
179
+ import torchaudio
180
+
181
+ chat = ChatTTS.Chat()
182
+ chat.load(compile=False) # Set to True for better performance
183
+
184
+ texts = ["PUT YOUR 1st TEXT HERE", "PUT YOUR 2nd TEXT HERE"]
185
+
186
+ wavs = chat.infer(texts)
187
+
188
+ torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
189
+ ```
190
+
191
+ ### 进阶用法
192
+
193
+ ```python
194
+ ###################################
195
+ # Sample a speaker from Gaussian.
196
+
197
+ rand_spk = chat.sample_random_speaker()
198
+ print(rand_spk) # save it for later timbre recovery
199
+
200
+ params_infer_code = ChatTTS.Chat.InferCodeParams(
201
+ spk_emb = rand_spk, # add sampled speaker
202
+ temperature = .3, # using custom temperature
203
+ top_P = 0.7, # top P decode
204
+ top_K = 20, # top K decode
205
+ )
206
+
207
+ ###################################
208
+ # For sentence level manual control.
209
+
210
+ # use oral_(0-9), laugh_(0-2), break_(0-7)
211
+ # to generate special token in text to synthesize.
212
+ params_refine_text = ChatTTS.Chat.RefineTextParams(
213
+ prompt='[oral_2][laugh_0][break_6]',
214
+ )
215
+
216
+ wavs = chat.infer(
217
+ texts,
218
+ params_refine_text=params_refine_text,
219
+ params_infer_code=params_infer_code,
220
+ )
221
+
222
+ ###################################
223
+ # For word level manual control.
224
+
225
+ text = 'What is [uv_break]your favorite english food?[laugh][lbreak]'
226
+ wavs = chat.infer(text, skip_refine_text=True, params_refine_text=params_refine_text, params_infer_code=params_infer_code)
227
+ torchaudio.save("output2.wav", torch.from_numpy(wavs[0]), 24000)
228
+ ```
229
+
230
+ <details open>
231
+ <summary><h4>示例: 自我介绍</h4></summary>
232
+
233
+ ```python
234
+ inputs_en = """
235
+ chatTTS is a text to speech model designed for dialogue applications.
236
+ [uv_break]it supports mixed language input [uv_break]and offers multi speaker
237
+ capabilities with precise control over prosodic elements like
238
+ [uv_break]laughter[uv_break][laugh], [uv_break]pauses, [uv_break]and intonation.
239
+ [uv_break]it delivers natural and expressive speech,[uv_break]so please
240
+ [uv_break] use the project responsibly at your own risk.[uv_break]
241
+ """.replace('\n', '') # English is still experimental.
242
+
243
+ params_refine_text = ChatTTS.Chat.RefineTextParams(
244
+ prompt='[oral_2][laugh_0][break_4]',
245
+ )
246
+
247
+ audio_array_en = chat.infer(inputs_en, params_refine_text=params_refine_text)
248
+ torchaudio.save("output3.wav", torch.from_numpy(audio_array_en[0]), 24000)
249
+ ```
250
+
251
+ <table>
252
+ <tr>
253
+ <td align="center">
254
+
255
+ **男性音色**
256
+
257
+ </td>
258
+ <td align="center">
259
+
260
+ **女性音色**
261
+
262
+ </td>
263
+ </tr>
264
+ <tr>
265
+ <td align="center">
266
+
267
+ [男性音色](https://github.com/2noise/ChatTTS/assets/130631963/e0f51251-db7f-4d39-a0e9-3e095bb65de1)
268
+
269
+ </td>
270
+ <td align="center">
271
+
272
+ [女性音色](https://github.com/2noise/ChatTTS/assets/130631963/f5dcdd01-1091-47c5-8241-c4f6aaaa8bbd)
273
+
274
+ </td>
275
+ </tr>
276
+ </table>
277
+
278
+ </details>
279
+
280
+ ## 常见问题
281
+
282
+ #### 1. 我需要多少 VRAM? 推理速度如何?
283
+
284
+ 对于 30 秒的音频片段,至少需要 4GB 的 GPU 内存。 对于 4090 GPU,它可以每秒生成大约 7 个语义 token 对应的音频。实时因子 (RTF) 约为 0.3。
285
+
286
+ #### 2. 模型稳定性不够好,存在多个说话者或音频质量差等问题。
287
+
288
+ 这是一个通常发生在自回归模型(例如 bark 和 valle)中的问题,通常很难避免。可以尝试多个样本以找到合适的结果。
289
+
290
+ #### 3. 除了笑声,我们还能控制其他东西吗?我们能控制其他情绪吗?
291
+
292
+ 在当前发布的模型中,可用的 token 级控制单元是 `[laugh]`, `[uv_break]` 和 `[lbreak]`。未来的版本中,我们可能会开源具有更多情绪控制功能的模型。
293
+
294
+ ## 致谢
295
+
296
+ - [bark](https://github.com/suno-ai/bark), [XTTSv2](https://github.com/coqui-ai/TTS) 和 [valle](https://arxiv.org/abs/2301.02111) 通过自回归式系统展示了非凡的 TTS 效果。
297
+ - [fish-speech](https://github.com/fishaudio/fish-speech) 揭示了 GVQ 作为 LLM 建模的音频分词器的能力。
298
+ - [vocos](https://github.com/gemelo-ai/vocos) vocos 被用作预训练声码器。
299
+
300
+ ## 特别鸣谢
301
+
302
+ - [wlu-audio lab](https://audio.westlake.edu.cn/) 对于早期算法实验的支持。
303
+
304
+ ## 贡献者列表
305
+
306
+ [![contributors](https://contrib.rocks/image?repo=2noise/ChatTTS)](https://github.com/2noise/ChatTTS/graphs/contributors)
307
+
308
+ ## 项目浏览量
309
+
310
+ <div align="center">
311
+
312
+ ![counter](https://counter.seku.su/cmoe?name=chattts&theme=mbs)
313
+
314
+ </div>
docs/es/README.md ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+
3
+ <a href="https://trendshift.io/repositories/10489" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10489" alt="2noise%2FChatTTS | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
4
+
5
+ # ChatTTS
6
+ Un modelo de generación de voz para la conversación diaria.
7
+
8
+ [![Licence](https://img.shields.io/github/license/2noise/ChatTTS?style=for-the-badge)](https://github.com/2noise/ChatTTS/blob/main/LICENSE)
9
+
10
+ [![Huggingface](https://img.shields.io/badge/🤗%20-Models-yellow.svg?style=for-the-badge)](https://huggingface.co/2Noise/ChatTTS)
11
+ [![Open In Colab](https://img.shields.io/badge/Colab-F9AB00?style=for-the-badge&logo=googlecolab&color=525252)](https://colab.research.google.com/github/2noise/ChatTTS/blob/main/examples/ipynb/colab.ipynb)
12
+
13
+ [**English**](../../README.md) | [**简体中文**](../cn/README.md) | [**日本語**](../jp/README.md) | [**Русский**](../ru/README.md) | **Español**
14
+ | [**Français**](../fr/README.md)
15
+ </div>
16
+
17
+ > [!NOTE]
18
+ > Atención, es posible que esta versión no sea la última. Por favor, consulte la versión en inglés para conocer todo el contenido.
19
+
20
+ ## Introducción
21
+
22
+ ChatTTS es un modelo de texto a voz diseñado específicamente para escenarios conversacionales como LLM assistant.
23
+
24
+ ### Idiomas Soportados
25
+
26
+ - [x] Inglés
27
+ - [x] Chino
28
+ - [ ] Manténganse al tanto...
29
+
30
+ ### Aspectos Destacados
31
+
32
+ > Puede consultar **[este video en Bilibili](https://www.bilibili.com/video/BV1zn4y1o7iV)** para obtener una descripción detallada.
33
+
34
+ 1. **TTS Conversacional**: ChatTTS está optimizado para tareas conversacionales, logrando una síntesis de voz natural y expresiva. Soporta múltiples hablantes, lo que facilita la generación de diálogos interactivos.
35
+ 2. **Control Finas**: Este modelo puede predecir y controlar características detalladas de la prosodia, incluyendo risas, pausas e interjecciones.
36
+ 3. **Mejor Prosodia**: ChatTTS supera a la mayoría de los modelos TTS de código abierto en cuanto a prosodia. Ofrecemos modelos preentrenados para apoyar estudios y desarrollos adicionales.
37
+
38
+ ### Conjunto de Datos & Modelo
39
+
40
+ - El modelo principal se entrena con más de 100.000 horas de datos de audio en chino e inglés.
41
+ - La versión de código abierto en **[HuggingFace](https://huggingface.co/2Noise/ChatTTS)** es un modelo preentrenado con 40.000 horas, sin SFT.
42
+
43
+ ### Hoja de Ruta
44
+
45
+ - [x] Publicar el modelo base de 40k horas y el archivo spk_stats como código abierto
46
+ - [ ] Publicar los códigos de codificador VQ y entrenamiento de Lora como código abierto
47
+ - [ ] Generación de audio en streaming sin refinar el texto
48
+ - [ ] Publicar la versión de 40k horas con control de múltiples emociones como código abierto
49
+ - [ ] ¿ChatTTS.cpp? (Se aceptan PR o un nuevo repositorio)
50
+
51
+ ### Descargo de Responsabilidad
52
+
53
+ > [!Important]
54
+ > Este repositorio es sólo para fines académicos.
55
+
56
+ Este proyecto está destinado a fines educativos y estudios, y no es adecuado para ningún propósito comercial o legal. El autor no garantiza la exactitud, integridad o fiabilidad de la información. La información y los datos utilizados en este repositorio son únicamente para fines académicos y de investigación. Los datos provienen de fuentes públicas, y el autor no reclama ningún derecho de propiedad o copyright sobre ellos.
57
+
58
+ ChatTTS es un potente sistema de conversión de texto a voz. Sin embargo, es crucial utilizar esta tecnología de manera responsable y ética. Para limitar el uso de ChatTTS, hemos añadido una pequeña cantidad de ruido de alta frecuencia durante el proceso de entrenamiento del modelo de 40.000 horas y hemos comprimido la calidad del audio en formato MP3 tanto como sea posible para evitar que actores malintencionados lo usen con fines delictivos. Además, hemos entrenado internamente un modelo de detección y planeamos hacerlo de código abierto en el futuro.
59
+
60
+ ### Contacto
61
+
62
+ > No dudes en enviar issues/PRs de GitHub.
63
+
64
+ #### Consultas Formales
65
+
66
+ Si desea discutir la cooperación sobre modelos y hojas de ruta, envíe un correo electrónico a **[email protected]**.
67
+
68
+ #### Chat en Línea
69
+
70
+ ##### 1. Grupo QQ (Aplicación Social China)
71
+
72
+ - **Grupo 1**, 808364215 (Lleno)
73
+ - **Grupo 2**, 230696694 (Lleno)
74
+ - **Grupo 3**, 933639842
75
+
76
+ ## Instalación (En Proceso)
77
+
78
+ > Se cargará en pypi pronto según https://github.com/2noise/ChatTTS/issues/269.
79
+
80
+ ```bash
81
+ pip install git+https://github.com/2noise/ChatTTS
82
+ ```
83
+
84
+ ## Inicio
85
+ ### Clonar el repositorio
86
+ ```bash
87
+ git clone https://github.com/2noise/ChatTTS
88
+ cd ChatTTS
89
+ ```
90
+
91
+ ### Requerimientos de instalación
92
+ #### 1. Instalar directamente
93
+ ```bash
94
+ pip install --upgrade -r requirements.txt
95
+ ```
96
+
97
+ #### 2. Instalar desde conda
98
+ ```bash
99
+ conda create -n chattts
100
+ conda activate chattts
101
+ pip install -r requirements.txt
102
+ ```
103
+
104
+ ### Inicio Rápido
105
+ #### 1. Iniciar la interfaz de usuario web (WebUI)
106
+ ```bash
107
+ python examples/web/webui.py
108
+ ```
109
+
110
+ #### 2. Inferir por línea de comando
111
+ > Guardará el audio en `./output_audio_xxx.wav`
112
+
113
+ ```bash
114
+ python examples/cmd/run.py "Please input your text."
115
+ ```
116
+
117
+ ### Básico
118
+
119
+ ```python
120
+ import ChatTTS
121
+ from IPython.display import Audio
122
+ import torchaudio
123
+ import torch
124
+
125
+ chat = ChatTTS.Chat()
126
+ chat.load(compile=False) # Set to True for better performance
127
+
128
+ texts = ["PUT YOUR TEXT HERE",]
129
+
130
+ wavs = chat.infer(texts)
131
+
132
+ torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
133
+ ```
134
+
135
+ ### Avanzado
136
+
137
+ ```python
138
+ ###################################
139
+ # Sample a speaker from Gaussian.
140
+
141
+ rand_spk = chat.sample_random_speaker()
142
+ print(rand_spk) # save it for later timbre recovery
143
+
144
+ params_infer_code = ChatTTS.Chat.InferCodeParams(
145
+ spk_emb = rand_spk, # add sampled speaker
146
+ temperature = .3, # using custom temperature
147
+ top_P = 0.7, # top P decode
148
+ top_K = 20, # top K decode
149
+ )
150
+
151
+ ###################################
152
+ # For sentence level manual control.
153
+
154
+ # use oral_(0-9), laugh_(0-2), break_(0-7)
155
+ # to generate special token in text to synthesize.
156
+ params_refine_text = ChatTTS.Chat.RefineTextParams(
157
+ prompt='[oral_2][laugh_0][break_6]',
158
+ )
159
+
160
+ wavs = chat.infer(
161
+ texts,
162
+ params_refine_text=params_refine_text,
163
+ params_infer_code=params_infer_code,
164
+ )
165
+
166
+ ###################################
167
+ # For word level manual control.
168
+ text = 'What is [uv_break]your favorite english food?[laugh][lbreak]'
169
+ wavs = chat.infer(text, skip_refine_text=True, params_refine_text=params_refine_text, params_infer_code=params_infer_code)
170
+ torchaudio.save("output2.wav", torch.from_numpy(wavs[0]), 24000)
171
+ ```
172
+
173
+ <details open>
174
+ <summary><h4>Ejemplo: auto presentación</h4></summary>
175
+
176
+ ```python
177
+ inputs_en = """
178
+ chat T T S is a text to speech model designed for dialogue applications.
179
+ [uv_break]it supports mixed language input [uv_break]and offers multi speaker
180
+ capabilities with precise control over prosodic elements [laugh]like like
181
+ [uv_break]laughter[laugh], [uv_break]pauses, [uv_break]and intonation.
182
+ [uv_break]it delivers natural and expressive speech,[uv_break]so please
183
+ [uv_break] use the project responsibly at your own risk.[uv_break]
184
+ """.replace('\n', '') # English is still experimental.
185
+
186
+ params_refine_text = ChatTTS.Chat.RefineTextParams(
187
+ prompt='[oral_2][laugh_0][break_4]',
188
+ )
189
+
190
+ audio_array_en = chat.infer(inputs_en, params_refine_text=params_refine_text)
191
+ torchaudio.save("output3.wav", torch.from_numpy(audio_array_en[0]), 24000)
192
+ ```
193
+
194
+ <table>
195
+ <tr>
196
+ <td align="center">
197
+
198
+ **altavoz masculino**
199
+
200
+ </td>
201
+ <td align="center">
202
+
203
+ **altavoz femenino**
204
+
205
+ </td>
206
+ </tr>
207
+ <tr>
208
+ <td align="center">
209
+
210
+ [male speaker](https://github.com/2noise/ChatTTS/assets/130631963/e0f51251-db7f-4d39-a0e9-3e095bb65de1)
211
+
212
+ </td>
213
+ <td align="center">
214
+
215
+ [female speaker](https://github.com/2noise/ChatTTS/assets/130631963/f5dcdd01-1091-47c5-8241-c4f6aaaa8bbd)
216
+
217
+ </td>
218
+ </tr>
219
+ </table>
220
+
221
+
222
+ </details>
223
+
224
+ ## Preguntas y Respuestas
225
+
226
+ #### 1. ¿Cuánta memoria gráfica de acceso aleatorio necesito? ¿Qué tal inferir la velocidad?
227
+ Para un clip de audio de 30 segundos, se requieren al menos 4 GB de memoria de GPU. Para la GPU 4090, puede generar audio correspondiente a aproximadamente 7 tokens semánticos por segundo. El Factor en Tiempo Real (RTF) es aproximadamente 0,3.
228
+
229
+ #### 2. La estabilidad del modelo no es lo suficientemente buena y existen problemas como varios altavoces o mala calidad del sonido.
230
+
231
+ Este es un problema común en los modelos autorregresivos (para bark y valle). Generalmente es difícil de evitar. Puede probar varias muestras para encontrar resultados adecuados.
232
+
233
+ #### 3. ¿Podemos controlar algo más que la risa? ¿Podemos controlar otras emociones?
234
+
235
+ En el modelo lanzado actualmente, las únicas unidades de control a nivel de token son `[risa]`, `[uv_break]` y `[lbreak]`. En una versión futura, es posible que abramos el código fuente del modelo con capacidades adicionales de control de emociones.
236
+
237
+ ## Agradecimientos
238
+ - [bark](https://github.com/suno-ai/bark), [XTTSv2](https://github.com/coqui-ai/TTS) y [valle](https://arxiv.org/abs/2301.02111) demuestran un resultado TTS notable mediante un sistema de estilo autorregresivo.
239
+ - [fish-speech](https://github.com/fishaudio/fish-speech) revela las capacidades de GVQ como tokenizador de audio para el modelado LLM.
240
+ - [vocos](https://github.com/gemelo-ai/vocos) se utiliza como codificador de voz previamente entrenado.
241
+
242
+ ## Agradecimiento Especial
243
+ - [wlu-audio lab](https://audio.westlake.edu.cn/) para experimentos iniciales del algoritmo.
244
+
245
+ ## Recursos Relacionados
246
+ - [Awesome-ChatTTS](https://github.com/libukai/Awesome-ChatTTS)
247
+
248
+ ## Gracias a todos los contribuyentes por sus esfuerzos.
249
+ [![contributors](https://contrib.rocks/image?repo=2noise/ChatTTS)](https://github.com/2noise/ChatTTS/graphs/contributors)
250
+
251
+ <div align="center">
252
+
253
+ ![counter](https://counter.seku.su/cmoe?name=chattts&theme=mbs)
254
+
255
+ </div>
docs/fr/README.md ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+
3
+ <a href="https://trendshift.io/repositories/10489" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10489" alt="2noise%2FChatTTS | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
4
+
5
+ # ChatTTS
6
+ Un modèle de parole génératif pour le dialogue quotidien.
7
+
8
+ [![Licence](https://img.shields.io/github/license/2noise/ChatTTS?style=for-the-badge)](https://github.com/2noise/ChatTTS/blob/main/LICENSE)
9
+ [![PyPI](https://img.shields.io/pypi/v/ChatTTS.svg?style=for-the-badge&color=green)](https://pypi.org/project/ChatTTS)
10
+
11
+ [![Huggingface](https://img.shields.io/badge/🤗%20-Models-yellow.svg?style=for-the-badge)](https://huggingface.co/2Noise/ChatTTS)
12
+ [![Open In Colab](https://img.shields.io/badge/Colab-F9AB00?style=for-the-badge&logo=googlecolab&color=525252)](https://colab.research.google.com/github/2noise/ChatTTS/blob/main/examples/ipynb/colab.ipynb)
13
+ [![Discord](https://img.shields.io/badge/Discord-7289DA?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/Ud5Jxgx5yD)
14
+
15
+ [**English**](../../README.md) | [**简体中文**](../cn/README.md) | [**日本語**](../jp/README.md) | [**Русский**](../ru/README.md) | [**Español**](../es/README.md)| **Français**
16
+
17
+ </div>
18
+
19
+ ## Introduction
20
+ > [!Note]
21
+ > Ce dépôt contient l'infrastructure de l'algorithme et quelques exemples simples.
22
+
23
+ > [!Tip]
24
+ > Pour les produits finaux étendus pour les utilisateurs, veuillez consulter le dépôt index [Awesome-ChatTTS](https://github.com/libukai/Awesome-ChatTTS/tree/en) maintenu par la communauté.
25
+
26
+ ChatTTS est un modèle de synthèse vocale conçu spécifiquement pour les scénarios de dialogue tels que les assistants LLM.
27
+
28
+ ### Langues prises en charge
29
+ - [x] Anglais
30
+ - [x] Chinois
31
+ - [ ] À venir...
32
+
33
+ ### Points forts
34
+ > Vous pouvez vous référer à **[cette vidéo sur Bilibili](https://www.bilibili.com/video/BV1zn4y1o7iV)** pour une description détaillée.
35
+
36
+ 1. **Synthèse vocale conversationnelle**: ChatTTS est optimisé pour les tâches basées sur le dialogue, permettant une synthèse vocale naturelle et expressive. Il prend en charge plusieurs locuteurs, facilitant les conversations interactives.
37
+ 2. **Contrôle granulaire**: Le modèle peut prédire et contrôler des caractéristiques prosodiques fines, y compris le rire, les pauses et les interjections.
38
+ 3. **Meilleure prosodie**: ChatTTS dépasse la plupart des modèles TTS open-source en termes de prosodie. Nous fournissons des modèles pré-entraînés pour soutenir la recherche et le développement.
39
+
40
+ ### Dataset & Modèle
41
+ - Le modèle principal est entraîné avec des données audio en chinois et en anglais de plus de 100 000 heures.
42
+ - La version open-source sur **[HuggingFace](https://huggingface.co/2Noise/ChatTTS)** est un modèle pré-entraîné de 40 000 heures sans SFT.
43
+
44
+ ### Roadmap
45
+ - [x] Open-source du modèle de base de 40k heures et du fichier spk_stats.
46
+ - [x] Génération audio en streaming.
47
+ - [ ] Open-source de la version 40k heures avec contrôle multi-émotions.
48
+ - [ ] ChatTTS.cpp (nouveau dépôt dans l'organisation `2noise` est bienvenu)
49
+
50
+ ### Avertissement
51
+ > [!Important]
52
+ > Ce dépôt est à des fins académiques uniquement.
53
+
54
+ Il est destiné à un usage éducatif et de recherche, et ne doit pas être utilisé à des fins commerciales ou légales. Les auteurs ne garantissent pas l'exactitude, l'exhaustivité ou la fiabilité des informations. Les informations et les données utilisées dans ce dépôt sont à des fins académiques et de recherche uniquement. Les données obtenues à partir de sources accessibles au public, et les auteurs ne revendiquent aucun droit de propriété ou de copyright sur les données.
55
+
56
+ ChatTTS est un système de synthèse vocale puissant. Cependant, il est très important d'utiliser cette technologie de manière responsable et éthique. Pour limiter l'utilisation de ChatTTS, nous avons ajouté une petite quantité de bruit haute fréquence pendant l'entraînement du modèle de 40 000 heures et compressé la qualité audio autant que possible en utilisant le format MP3, pour empêcher les acteurs malveillants de l'utiliser potentiellement à des fins criminelles. En même temps, nous avons entraîné en interne un modèle de détection et prévoyons de l'open-source à l'avenir.
57
+
58
+ ### Contact
59
+ > Les issues/PRs sur GitHub sont toujours les bienvenus.
60
+
61
+ #### Demandes formelles
62
+ Pour les demandes formelles concernant le modèle et la feuille de route, veuillez nous contacter à **[email protected]**.
63
+
64
+ #### Discussion en ligne
65
+ ##### 1. Groupe QQ (application sociale chinoise)
66
+ - **Groupe 1**, 808364215 (Complet)
67
+ - **Groupe 2**, 230696694 (Complet)
68
+ - **Groupe 3**, 933639842 (Complet)
69
+ - **Groupe 4**, 608667975
70
+
71
+ ##### 2. Serveur Discord
72
+ Rejoignez en cliquant [ici](https://discord.gg/Ud5Jxgx5yD).
73
+
74
+ ## Pour commencer
75
+ ### Cloner le dépôt
76
+ ```bash
77
+ git clone https://github.com/2noise/ChatTTS
78
+ cd ChatTTS
79
+ ```
80
+
81
+ ### Installer les dépendances
82
+ #### 1. Installation directe
83
+ ```bash
84
+ pip install --upgrade -r requirements.txt
85
+ ```
86
+
87
+ #### 2. Installer depuis conda
88
+ ```bash
89
+ conda create -n chattts
90
+ conda activate chattts
91
+ pip install -r requirements.txt
92
+ ```
93
+
94
+ #### Optionnel : Installer TransformerEngine si vous utilisez un GPU NVIDIA (Linux uniquement)
95
+ > [!Note]
96
+ > Le processus d'installation est très lent.
97
+
98
+ > [!Warning]
99
+ > L'adaptation de TransformerEngine est actuellement en cours de développement et NE PEUT PAS fonctionner correctement pour le moment.
100
+ > Installez-le uniquement à des fins de développement.
101
+
102
+ ```bash
103
+ pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
104
+ ```
105
+
106
+ #### Optionnel : Installer FlashAttention-2 (principalement GPU NVIDIA)
107
+ > [!Note]
108
+ > Voir les appareils pris en charge dans la [documentation Hugging Face](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2).
109
+
110
+ > [!Warning]
111
+ > Actuellement, FlashAttention-2 ralentira la vitesse de génération selon [ce problème](https://github.com/huggingface/transformers/issues/26990).
112
+ > Installez-le uniquement à des fins de développement.
113
+
114
+ ```bash
115
+ pip install flash-attn --no-build-isolation
116
+ ```
117
+
118
+ ### Démarrage rapide
119
+ > Assurez-vous que vous êtes dans le répertoire racine du projet lorsque vous exécutez ces commandes ci-dessous.
120
+
121
+ #### 1. Lancer WebUI
122
+ ```bash
123
+ python examples/web/webui.py
124
+ ```
125
+
126
+ #### 2. Inférence par ligne de commande
127
+ > Cela enregistrera l'audio sous ‘./output_audio_n.mp3’
128
+
129
+ ```bash
130
+ python examples/cmd/run.py "Votre premier texte." "Votre deuxième texte."
131
+ ```
132
+
133
+ ## Installation
134
+
135
+ 1. Installer la version stable depuis PyPI
136
+ ```bash
137
+ pip install ChatTTS
138
+ ```
139
+
140
+ 2. Installer la dernière version depuis GitHub
141
+ ```bash
142
+ pip install git+https://github.com/2noise/ChatTTS
143
+ ```
144
+
145
+ 3. Installer depuis le répertoire local en mode développement
146
+ ```bash
147
+ pip install -e .
148
+ ```
149
+
150
+ ### Utilisation de base
151
+
152
+ ```python
153
+ import ChatTTS
154
+ import torch
155
+ import torchaudio
156
+
157
+ chat = ChatTTS.Chat()
158
+ chat.load(compile=False) # Définissez sur True pour de meilleures performances
159
+
160
+ texts = ["METTEZ VOTRE PREMIER TEXTE ICI", "METTEZ VOTRE DEUXIÈME TEXTE ICI"]
161
+
162
+ wavs = chat.infer(texts)
163
+
164
+ torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
165
+ ```
166
+
167
+ ### Utilisation avancée
168
+
169
+ ```python
170
+ ###################################
171
+ # Échantillonner un locuteur à partir d'une distribution gaussienne.
172
+
173
+ rand_spk = chat.sample_random_speaker()
174
+ print(rand_spk) # sauvegardez-le pour une récupération ultérieure du timbre
175
+
176
+ params_infer_code = ChatTTS.Chat.InferCodeParams(
177
+ spk_emb = rand_spk, # ajouter le locuteur échantillonné
178
+ temperature = .3, # en utilisant une température personnalisée
179
+ top_P = 0.7, # top P décode
180
+ top_K = 20, # top K décode
181
+ )
182
+
183
+ ###################################
184
+ # Pour le contrôle manuel au niveau des phrases.
185
+
186
+ # utilisez oral_(0-9), laugh_(0-2), break_(0-7)
187
+ # pour générer un token spécial dans le texte à synthétiser.
188
+ params_refine_text = ChatTTS.Chat.RefineTextParams(
189
+ prompt='[oral_2][laugh_0][break_6]',
190
+ )
191
+
192
+ wavs = chat.infer(
193
+ texts,
194
+ params_refine_text=params_refine_text,
195
+ params_infer_code=params_infer_code,
196
+ )
197
+
198
+ ###################################
199
+ # Pour le contrôle manuel au niveau des mots.
200
+
201
+ text = 'Quel est [uv_break]votre plat anglais préféré?[laugh][lbreak]'
202
+ wavs = chat.infer(text, skip_refine_text=True, params_refine_text=params_refine_text, params_infer_code=params_infer_code)
203
+ torchaudio.save("output2.wav", torch.from_numpy(wavs[0]), 24000)
204
+ ```
205
+
206
+ <details open>
207
+ <summary><h4>Exemple : auto-présentation</h4></summary>
208
+
209
+ ```python
210
+ inputs_en = """
211
+ chat T T S est un modèle de synthèse vocale conçu pour les applications de dialogue.
212
+ [uv_break]il prend en charge les entrées en langues mixtes [uv_break]et offre des capacités multi-locuteurs
213
+ avec un contrôle précis des éléments prosodiques comme
214
+ [uv_break]le rire[uv_break][laugh], [uv_break]les pauses, [uv_break]et l'intonation.
215
+ [uv_break]il délivre une parole naturelle et expressive,[uv_break]donc veuillez
216
+ [uv_break]utiliser le projet de manière responsable à vos risques et périls.[uv_break]
217
+ """.replace('\n', '') # L'anglais est encore expérimental.
218
+
219
+ params_refine_text = ChatTTS.Chat.RefineTextParams(
220
+ prompt='[oral_2][laugh_0][break_4]',
221
+ )
222
+
223
+ audio_array_en = chat.infer(inputs_en, params_refine_text=params_refine_text)
224
+ torchaudio.save("output3.wav", torch.from_numpy(audio_array_en[0]), 24000)
225
+ ```
226
+
227
+ <table>
228
+ <tr>
229
+ <td align="center">
230
+
231
+ **locuteur masculin**
232
+
233
+ </td>
234
+ <td align="center">
235
+
236
+ **locutrice féminine**
237
+
238
+ </td>
239
+ </tr>
240
+ <tr>
241
+ <td align="center">
242
+
243
+ [locuteur masculin](https://github.com/2noise/ChatTTS/assets/130631963/e0f51251-db7f-4d39-a0e9-3e095bb65de1)
244
+
245
+ </td>
246
+ <td align="center">
247
+
248
+ [locutrice féminine](https://github.com/2noise/ChatTTS/assets/130631963/f5dcdd01-1091-47c5-8241-c4f6aaaa8bbd)
249
+
250
+ </td>
251
+ </tr>
252
+ </table>
253
+
254
+
255
+ </details>
256
+
257
+ ## FAQ
258
+
259
+ #### 1. De combien de VRAM ai-je besoin ? Quelle est la vitesse d'inférence ?
260
+ Pour un clip audio de 30 secondes, au moins 4 Go de mémoire GPU sont nécessaires. Pour le GPU 4090, il peut générer de l'audio correspondant à environ 7 tokens sémantiques par seconde. Le Facteur Temps Réel (RTF) est d'environ 0.3.
261
+
262
+ #### 2. La stabilité du modèle n'est pas suffisante, avec des problèmes tels que des locuteurs multiples ou une mauvaise qualité audio.
263
+ C'est un problème qui se produit généralement avec les modèles autoregressifs (pour bark et valle). Il est généralement difficile à éviter. On peut essayer plusieurs échantillons pour trouver un résultat approprié.
264
+
265
+ #### 3. En plus du rire, pouvons-nous contrôler autre chose ? Pouvons-nous contrôler d'autres émotions ?
266
+ Dans le modèle actuellement publié, les seules unités de contrôle au niveau des tokens sont `[laugh]`, `[uv_break]`, et `[lbreak]`. Dans les futures versions, nous pourrions open-source des modèles avec des capacités de contrôle émotionnel supplémentaires.
267
+
268
+ ## Remerciements
269
+ - [bark](https://github.com/suno-ai/bark), [XTTSv2](https://github.com/coqui-ai/TTS) et [valle](https://arxiv.org/abs/2301.02111) démontrent un résultat TTS remarquable par un système de style autoregressif.
270
+ - [fish-speech](https://github.com/fishaudio/fish-speech) révèle la capacité de GVQ en tant que tokenizer audio pour la modélisation LLM.
271
+ - [vocos](https://github.com/gemelo-ai/vocos) qui est utilisé comme vocodeur pré-entraîné.
272
+
273
+ ## Appréciation spéciale
274
+ - [wlu-audio lab](https://audio.westlake.edu.cn/) pour les expériences d'algorithme précoce.
275
+
276
+ ## Merci à tous les contributeurs pour leurs efforts
277
+ [![contributors](https://contrib.rocks/image?repo=2noise/ChatTTS)](https://github.com/2noise/ChatTTS/graphs/contributors)
278
+
279
+ <div align="center">
280
+
281
+ ![counter](https://counter.seku.su/cmoe?name=chattts&theme=mbs)
282
+
283
+ </div>
docs/jp/README.md ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ChatTTS
2
+ > [!NOTE]
3
+ > 以下の内容は最新情報ではない可能性がありますのでご了承ください。全ての内容は英語版に基準することになります。
4
+
5
+ [![Huggingface](https://img.shields.io/badge/🤗%20-Models-yellow.svg?style=for-the-badge)](https://huggingface.co/2Noise/ChatTTS)
6
+
7
+ [**English**](../../README.md) | [**简体中文**](../cn/README.md) | **日本語** | [**Русский**](../ru/README.md) | [**Español**](../es/README.md) | [**Français**](../fr/README.md)
8
+
9
+ ChatTTSは、LLMアシスタントなどの対話シナリオ用に特別に設計されたテキストから音声へのモデルです。英語と中国語の両方をサポートしています。私たちのモデルは、中国語と英語で構成される100,000時間以上でトレーニングされています。**[HuggingFace](https://huggingface.co/2Noise/ChatTTS)**でオープンソース化されているバージョンは、40,000時間の事前トレーニングモデルで、SFTは行われていません。
10
+
11
+ モデルやロードマップについての正式なお問い合わせは、**[email protected]**までご連絡ください。QQグループ:808364215に参加してディスカッションすることもできます。GitHubでの問題提起も歓迎します。
12
+
13
+ ---
14
+ ## ハイライト
15
+ 1. **会話型TTS**: ChatTTSは対話ベースのタスクに最適化されており、自然で表現豊かな音声合成を実現します。複数の話者をサポートし、対話型の会話を容易にします。
16
+ 2. **細かい制御**: このモデルは、笑い、一時停止、間投詞などの細かい韻律特徴を予測および制御することができます。
17
+ 3. **より良い韻律**: ChatTTSは、韻律の面でほとんどのオープンソースTTSモデルを超えています。さらなる研究と開発をサポートするために、事前トレーニングされたモデルを提供しています。
18
+
19
+ モデルの詳細な説明については、**[Bilibiliのビデオ](https://www.bilibili.com/video/BV1zn4y1o7iV)**を参照してください。
20
+
21
+ ---
22
+
23
+ ## 免責事項
24
+
25
+ このリポジトリは学術目的のみのためです。教育および研究用途にのみ使用され、商業的または法的な目的には使用されません。著者は情報の正確性、完全性、または信頼性を保証しません。このリポジトリで使用される情報およびデータは、学術および研究目的のみのためのものです。データは公開されているソースから取得され、著者はデータに対する所有権または著作権を主張しません。
26
+
27
+ ChatTTSは強力なテキストから音声へのシステムです。しかし、この技術を責任を持って、倫理的に利用することが非常に重要です。ChatTTSの使用を制限するために、40,000時間のモデルのトレーニング中に少量の高周波ノイズを追加し、MP3形式を使用して音質を可能な限り圧縮しました。これは、悪意のあるアクターが潜在的に犯罪目的で使用することを防ぐためです。同時に、私たちは内部的に検出モデルをトレーニングしており、将来的にオープンソース化する予定です。
28
+
29
+ ---
30
+ ## 使用方法
31
+
32
+ <h4>基本的な使用方法</h4>
33
+
34
+ ```python
35
+ import ChatTTS
36
+ from IPython.display import Audio
37
+ import torch
38
+
39
+ chat = ChatTTS.Chat()
40
+ chat.load(compile=False) # より良いパフォーマンスのためにTrueに設定
41
+
42
+ texts = ["ここにテキストを入力してください",]
43
+
44
+ wavs = chat.infer(texts, )
45
+
46
+ torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
47
+ ```
48
+
49
+ <h4>高度な使用方法</h4>
50
+
51
+ ```python
52
+ ###################################
53
+ # ガウス分布から話者をサンプリングします。
54
+
55
+ rand_spk = chat.sample_random_speaker()
56
+ print(rand_spk) # save it for later timbre recovery
57
+
58
+ params_infer_code = {
59
+ 'spk_emb': rand_spk, # サンプリングされた話者を追加
60
+ 'temperature': .3, # カスタム温度を使用
61
+ 'top_P': 0.7, # トップPデコード
62
+ 'top_K': 20, # トップKデコード
63
+ }
64
+
65
+ ###################################
66
+ # 文レベルの手動制御のために。
67
+
68
+ # 特別なトークンを生成するためにテキストにoral_(0-9)、laugh_(0-2)、break_(0-7)を使用します。
69
+ params_refine_text = {
70
+ 'prompt': '[oral_2][laugh_0][break_6]'
71
+ }
72
+
73
+ wav = chat.infer(texts, params_refine_text=params_refine_text, params_infer_code=params_infer_code)
74
+
75
+ ###################################
76
+ # 単語レベルの手動制御のために。
77
+ text = 'あなたの好きな英語の食べ物は何ですか?[uv_break][laugh][lbreak]'
78
+ wav = chat.infer(text, skip_refine_text=True, params_refine_text=params_refine_text, params_infer_code=params_infer_code)
79
+ torchaudio.save("output2.wav", torch.from_numpy(wavs[0]), 24000)
80
+ ```
81
+
82
+ <details open>
83
+ <summary><h4>例:自己紹介</h4></summary>
84
+
85
+ ```python
86
+ inputs_jp = """
87
+ ChatTTSは、対話アプリケーション用に設計されたテキストから音声へのモデルです。
88
+ [uv_break]混合言語入力をサポートし[uv_break]、韻律要素[laugh]の正確な制御を提供します
89
+ [uv_break]笑い[laugh]、[uv_break]一時停止、[uv_break]およびイントネーション。[uv_break]自然で表現豊かな音声を提供します
90
+ [uv_break]したがって、自己責任でプロジェクトを責任を持って使用してください。[uv_break]
91
+ """.replace('\n', '') # 英語はまだ実験的です。
92
+
93
+ params_refine_text = {
94
+ 'prompt': '[oral_2][laugh_0][break_4]'
95
+ }
96
+ audio_array_jp = chat.infer(inputs_jp, params_refine_text=params_refine_text)
97
+ torchaudio.save("output3.wav", torch.from_numpy(audio_array_jp[0]), 24000)
98
+ ```
99
+ [男性話者](https://github.com/2noise/ChatTTS/assets/130631963/e0f51251-db7f-4d39-a0e9-3e095bb65de1)
100
+
101
+ [女性話者](https://github.com/2noise/ChatTTS/assets/130631963/f5dcdd01-1091-47c5-8241-c4f6aaaa8bbd)
102
+ </details>
103
+
104
+ ---
105
+ ## ロードマップ
106
+ - [x] 40k時間のベースモデルとspk_statsファイルをオープンソース化
107
+ - [ ] VQエンコーダーとLoraトレーニングコードをオープンソース化
108
+ - [ ] テキストをリファインせずにストリーミングオーディオ生成*
109
+ - [ ] 複数の感情制御を備えた40k時間バージョンをオープンソース化
110
+ - [ ] ChatTTS.cppもしかしたら?(PRや新しいリポジトリが歓迎されます。)
111
+
112
+ ----
113
+ ## FAQ
114
+
115
+ ##### VRAMはどれくらい必要ですか?推論速度はどうですか?
116
+ 30秒のオーディオクリップには、少なくとも4GBのGPUメモリが必要です。4090 GPUの場合、約7つの意味トークンに対応するオーディオを1秒あたり生成できます。リアルタイムファクター(RTF)は約0.3です。
117
+
118
+ ##### モデルの安定性が十分でなく、複数の話者や音質が悪いという問題があります。
119
+
120
+ これは、自己回帰モデル(barkおよびvalleの場合)で一般的に発生する問題です。一般的に避けるのは難しいです。複数のサンプルを試して、適切な結果を見つけることができます。
121
+
122
+ ##### 笑い以外に何か制御できますか?他の感情を制御できますか?
123
+
124
+ 現在リリースされているモデルでは、トークンレベルの制御ユニットは[laugh]、[uv_break]、および[lbreak]のみです。将来のバージョンでは、追加の感情制御機能を備えたモデルをオープンソース化する可能性があります。
125
+
126
+ ---
127
+ ## 謝辞
128
+ - [bark](https://github.com/suno-ai/bark)、[XTTSv2](https://github.com/coqui-ai/TTS)、および[valle](https://arxiv.org/abs/2301.02111)は、自己回帰型システムによる顕著なTTS結果を示しました。
129
+ - [fish-speech](https://github.com/fishaudio/fish-speech)は、LLMモデリングのためのオーディオトークナイザーとしてのGVQの能力を明らかにしました。
130
+ - 事前トレーニングされたボコーダーとして使用される[vocos](https://github.com/gemelo-ai/vocos)。
131
+
132
+ ---
133
+ ## 特別感謝
134
+ - 初期のアルゴリズム実験をサポートしてくれた[wlu-audio lab](https://audio.westlake.edu.cn/)。
docs/ru/README.md ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ChatTTS
2
+ > [!NOTE]
3
+ > Следующая информация может быть не самой последней, пожалуйста, смотрите английскую версию для актуальных данных.
4
+
5
+ [![Huggingface](https://img.shields.io/badge/🤗%20-Models-yellow.svg?style=for-the-badge)](https://huggingface.co/2Noise/ChatTTS)
6
+
7
+ [**English**](../../README.md) | [**简体中文**](../cn/README.md) | [**日本語**](../jp/README.md) | **Русский** | [**Español**](../es/README.md) | [**Français**](../fr/README.md)
8
+
9
+ ChatTTS - это модель преобразования текста в речь, специально разработанная для диалоговых сценариев, таких как помощник LLM. Она поддерживает как английский, так и китайский языки. Наша модель обучена на более чем 100 000 часах английского и китайского языков. Открытая версия на **[HuggingFace](https://huggingface.co/2Noise/ChatTTS)** - это предварительно обученная модель с 40 000 часами без SFT.
10
+
11
+ Для официальных запросов о модели и плане развития, пожалуйста, свяжитесь с нами по адресу **[email protected]**. Вы можете присоединиться к нашей группе QQ: 808364215 для обсуждения. Добавление вопросов на GitHub также приветствуется.
12
+
13
+ ---
14
+ ## Особенности
15
+ 1. **Диалоговый TTS**: ChatTTS оптимизирован для задач, основанных на диалогах, что позволяет создавать натуральную и выразительную речь. Он поддерживает несколько говорящих, облегчая интерактивные беседы.
16
+ 2. **Тонкий контроль**: Модель может предсказывать и контролировать тонкие просодические особенности, включая смех, паузы и вставные слова.
17
+ 3. **Лучшая просодия**: ChatTTS превосходит большинство открытых моделей TTS с точки зрения просодии. Мы предоставляем предварительно обученные модели для поддержки дальнейших исследований и разработок.
18
+
19
+ Для подробного описания модели вы можете обратиться к **[видео на Bilibili](https://www.bilibili.com/video/BV1zn4y1o7iV)**
20
+
21
+ ---
22
+
23
+ ## Отказ от ответственности
24
+
25
+ Этот репозиторий предназначен только для академических целей. Он предназначен для образовательного и исследовательского использования и не должен использоваться в коммерческих или юридических целях. Авторы не гарантируют точность, полноту или надежность информации. Информация и данные, использованные в этом репозитории, предназначены только для академических и исследовательских целей. Данные получены из общедоступных источников, и авторы не заявляют о каких-либо правах собственности или авторских правах на данные.
26
+
27
+ ChatTTS - мощная система преобразования текста в речь. Однако очень важно использовать эту технологию ответственно и этично. Чтобы ограничить использование ChatTTS, мы добавили небольшое количество высокочастотного шума во время обучения модели на 40 000 часов и сжали качество аудио как можно больше с помощью формата MP3, чтобы предотвратить возможное использование злоумышленниками в преступных целях. В то же время мы внутренне обучили модель обнаружения и планируем открыть ее в будущем.
28
+
29
+ ---
30
+ ## Использование
31
+
32
+ <h4>Базовое использование</h4>
33
+
34
+ ```python
35
+ import ChatTTS
36
+ from IPython.display import Audio
37
+ import torch
38
+
39
+ chat = ChatTTS.Chat()
40
+ chat.load(compile=False) # Установите значение True для лучшей производительности
41
+
42
+ texts = ["ВВЕДИ��Е ВАШ ТЕКСТ ЗДЕСЬ",]
43
+
44
+ wavs = chat.infer(texts)
45
+
46
+ torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
47
+ ```
48
+
49
+ <h4>Продвинутое использование</h4>
50
+
51
+ ```python
52
+ ###################################
53
+ # Выборка говорящего из Гауссиана.
54
+
55
+ rand_spk = chat.sample_random_speaker()
56
+ print(rand_spk) # save it for later timbre recovery
57
+
58
+ params_infer_code = {
59
+ 'spk_emb': rand_spk, # добавить выбранного говорящего
60
+ 'temperature': .3, # использовать пользовательскую температуру
61
+ 'top_P': 0.7, # декодирование top P
62
+ 'top_K': 20, # декодирование top K
63
+ }
64
+
65
+ ###################################
66
+ # Для контроля на уровне предложений.
67
+
68
+ # используйте oral_(0-9), laugh_(0-2), break_(0-7)
69
+ # для генерации специального токена в тексте для синтеза.
70
+ params_refine_text = {
71
+ 'prompt': '[oral_2][laugh_0][break_6]'
72
+ }
73
+
74
+ wav = chat.infer(texts, params_refine_text=params_refine_text, params_infer_code=params_infer_code)
75
+
76
+ ###################################
77
+ # Для контроля на уровне слов.
78
+ text = 'Какая ваша любимая английская еда?[uv_break]your favorite english food?[laugh][lbreak]'
79
+ wav = chat.infer(text, skip_refine_text=True, params_refine_text=params_refine_text, params_infer_code=params_infer_code)
80
+ torchaudio.save("output2.wav", torch.from_numpy(wavs[0]), 24000)
81
+ ```
82
+
83
+ <details open>
84
+ <summary><h4>Пример: самопрезентация</h4></summary>
85
+
86
+ ```python
87
+ inputs_ru = """
88
+ ChatTTS - это модель преобразования текста в речь, разработанная для диалоговых приложений.
89
+ [uv_break]Она поддерживает смешанный языковой ввод [uv_break]и предлагает возможности множественных говорящих
90
+ с точным контролем над просодическими элементами [laugh]как [uv_break]смех[laugh], [uv_break]паузы, [uv_break]и интонацию.
91
+ [uv_break]Она обеспечивает натуральную и выразительную речь,[uv_break]поэтому, пожалуйста,
92
+ [uv_break] используйте проект ответственно и на свой страх и риск.[uv_break]
93
+ """.replace('\n', '') # Русский язык все еще находится в экспериментальной стадии.
94
+
95
+ params_refine_text = {
96
+ 'prompt': '[oral_2][laugh_0][break_4]'
97
+ }
98
+ audio_array_ru = chat.infer(inputs_ru, params_refine_text=params_refine_text)
99
+ torchaudio.save("output3.wav", torch.from_numpy(audio_array_ru[0]), 24000)
100
+ ```
101
+ [мужской говорящий](https://github.com/2noise/ChatTTS/assets/130631963/e0f51251-db7f-4d39-a0e9-3e095bb65de1)
102
+
103
+ [женский говорящий](https://github.com/2noise/ChatTTS/assets/130631963/f5dcdd01-1091-47c5-8241-c4f6aaaa8bbd)
104
+ </details>
105
+
106
+ ---
107
+ ## План развития
108
+ - [x] Открыть исходный код базовой модели на 40 тысяч часов и файла spk_stats
109
+ - [ ] Открыть исходный код кодировщика VQ и кода обучения Lora
110
+ - [ ] Потоковая генерация аудио без уточнения текста*
111
+ - [ ] Открыть исходный код версии на 40 тысяч часов с управлением множественными эмоциями
112
+ - [ ] ChatTTS.cpp возможно? (PR или новый репозиторий приветствуются.)
113
+
114
+ ----
115
+ ## Часто задаваемые вопросы
116
+
117
+ ##### Сколько VRAM мне нужно? Как насчет скорости инференса?
118
+ Для 30-секундного аудиоклипа требуется как минимум 4 ГБ памяти GPU. Для GPU 4090, он может генерировать аудио, соответствующее примерно 7 семантическим токенам в секунду. Фактор реального времени (RTF) составляет около 0.3.
119
+
120
+ ##### Стабильность модели кажется недостаточно хорошей, возникают проблемы с множественными говорящими или плохим качеством аудио.
121
+
122
+ Это проблема, которая обычно возникает с авторегрессивными моделями (для bark и valle). Это обычно трудно избежать. Можно попробовать несколько образцов, чтобы найти подходящий результат.
123
+
124
+ ##### Помимо смеха, можем ли мы контролировать что-то еще? Можем ли мы контролировать другие эмоции?
125
+
126
+ В текущей выпущенной модели единственными элементами управления на уровне токенов являются [laugh], [uv_break] и [lbreak]. В будущих версиях мы можем открыть модели с дополнительными возможностями контроля эмоций.
127
+
128
+ ---
129
+ ## Благодарности
130
+ - [bark](https://github.com/suno-ai/bark), [XTTSv2](https://github.com/coqui-ai/TTS) и [valle](https://arxiv.org/abs/2301.02111) демонстрируют замечательный результат TTS с помощью системы авторегрессивного стиля.
131
+ - [fish-speech](https://github.com/fishaudio/fish-speech) показывает возможности GVQ как аудио токенизатора для моделирования LLM.
132
+ - [vocos](https://github.com/gemelo-ai/vocos), который используется в качестве предварительно обученного вокодера.
133
+
134
+ ---
135
+ ## Особая благодарность
136
+ - [wlu-audio lab](https://audio.westlake.edu.cn/) за ранние эксперименты с алгоритмами.
examples/__init__.py ADDED
File without changes
examples/api/README.md ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generating voice with ChatTTS via API
2
+
3
+ ## Install requirements
4
+
5
+ Install `FastAPI` and `requests`:
6
+
7
+ ```
8
+ pip install -r examples/api/requirements.txt
9
+ ```
10
+
11
+ ## Run API server
12
+
13
+ ```
14
+ fastapi dev examples/api/main.py --host 0.0.0.0 --port 8000
15
+ ```
16
+
17
+ ## Generate audio using requests
18
+
19
+ ```
20
+ python examples/api/client.py
21
+ ```
22
+
23
+ mp3 audio files will be saved to the `output` directory.
examples/api/client.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import os
3
+ import zipfile
4
+ from io import BytesIO
5
+
6
+ import requests
7
+
8
+ chattts_service_host = os.environ.get("CHATTTS_SERVICE_HOST", "localhost")
9
+ chattts_service_port = os.environ.get("CHATTTS_SERVICE_PORT", "8000")
10
+
11
+ CHATTTS_URL = f"http://{chattts_service_host}:{chattts_service_port}/generate_voice"
12
+
13
+
14
+ # main infer params
15
+ body = {
16
+ "text": [
17
+ "四川美食确实以辣闻名,但也有不辣的选择。",
18
+ "比如甜水面、赖汤圆、蛋烘糕、叶儿粑等,这些小吃口味温和,甜而不腻,也很受欢迎。",
19
+ ],
20
+ "stream": False,
21
+ "lang": None,
22
+ "skip_refine_text": True,
23
+ "refine_text_only": False,
24
+ "use_decoder": True,
25
+ "audio_seed": 12345678,
26
+ "text_seed": 87654321,
27
+ "do_text_normalization": True,
28
+ "do_homophone_replacement": False,
29
+ }
30
+
31
+ # refine text params
32
+ params_refine_text = {
33
+ "prompt": "",
34
+ "top_P": 0.7,
35
+ "top_K": 20,
36
+ "temperature": 0.7,
37
+ "repetition_penalty": 1,
38
+ "max_new_token": 384,
39
+ "min_new_token": 0,
40
+ "show_tqdm": True,
41
+ "ensure_non_empty": True,
42
+ "stream_batch": 24,
43
+ }
44
+ body["params_refine_text"] = params_refine_text
45
+
46
+ # infer code params
47
+ params_infer_code = {
48
+ "prompt": "[speed_5]",
49
+ "top_P": 0.1,
50
+ "top_K": 20,
51
+ "temperature": 0.3,
52
+ "repetition_penalty": 1.05,
53
+ "max_new_token": 2048,
54
+ "min_new_token": 0,
55
+ "show_tqdm": True,
56
+ "ensure_non_empty": True,
57
+ "stream_batch": True,
58
+ "spk_emb": None,
59
+ }
60
+ body["params_infer_code"] = params_infer_code
61
+
62
+
63
+ try:
64
+ response = requests.post(CHATTTS_URL, json=body)
65
+ response.raise_for_status()
66
+ with zipfile.ZipFile(BytesIO(response.content), "r") as zip_ref:
67
+ # save files for each request in a different folder
68
+ dt = datetime.datetime.now()
69
+ ts = int(dt.timestamp())
70
+ tgt = f"./output/{ts}/"
71
+ os.makedirs(tgt, 0o755)
72
+ zip_ref.extractall(tgt)
73
+ print("Extracted files into", tgt)
74
+
75
+ except requests.exceptions.RequestException as e:
76
+ print(f"Request Error: {e}")
examples/api/main.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import sys
4
+ import zipfile
5
+
6
+ from fastapi import FastAPI
7
+ from fastapi.responses import StreamingResponse
8
+
9
+
10
+ if sys.platform == "darwin":
11
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
12
+
13
+ now_dir = os.getcwd()
14
+ sys.path.append(now_dir)
15
+
16
+ from typing import Optional
17
+
18
+ import ChatTTS
19
+
20
+ from tools.audio import pcm_arr_to_mp3_view
21
+ from tools.logger import get_logger
22
+ import torch
23
+
24
+
25
+ from pydantic import BaseModel
26
+
27
+
28
+ logger = get_logger("Command")
29
+
30
+ app = FastAPI()
31
+
32
+
33
+ @app.on_event("startup")
34
+ async def startup_event():
35
+ global chat
36
+
37
+ chat = ChatTTS.Chat(get_logger("ChatTTS"))
38
+ logger.info("Initializing ChatTTS...")
39
+ if chat.load():
40
+ logger.info("Models loaded successfully.")
41
+ else:
42
+ logger.error("Models load failed.")
43
+ sys.exit(1)
44
+
45
+
46
+ class ChatTTSParams(BaseModel):
47
+ text: list[str]
48
+ stream: bool = False
49
+ lang: Optional[str] = None
50
+ skip_refine_text: bool = False
51
+ refine_text_only: bool = False
52
+ use_decoder: bool = True
53
+ do_text_normalization: bool = True
54
+ do_homophone_replacement: bool = False
55
+ params_refine_text: ChatTTS.Chat.RefineTextParams
56
+ params_infer_code: ChatTTS.Chat.InferCodeParams
57
+
58
+
59
+ @app.post("/generate_voice")
60
+ async def generate_voice(params: ChatTTSParams):
61
+ logger.info("Text input: %s", str(params.text))
62
+
63
+ # audio seed
64
+ if params.params_infer_code.manual_seed is not None:
65
+ torch.manual_seed(params.params_infer_code.manual_seed)
66
+ params.params_infer_code.spk_emb = chat.sample_random_speaker()
67
+
68
+ # text seed for text refining
69
+ if params.params_refine_text:
70
+ text = chat.infer(
71
+ text=params.text, skip_refine_text=False, refine_text_only=True
72
+ )
73
+ logger.info(f"Refined text: {text}")
74
+ else:
75
+ # no text refining
76
+ text = params.text
77
+
78
+ logger.info("Use speaker:")
79
+ logger.info(params.params_infer_code.spk_emb)
80
+
81
+ logger.info("Start voice inference.")
82
+ wavs = chat.infer(
83
+ text=text,
84
+ stream=params.stream,
85
+ lang=params.lang,
86
+ skip_refine_text=params.skip_refine_text,
87
+ use_decoder=params.use_decoder,
88
+ do_text_normalization=params.do_text_normalization,
89
+ do_homophone_replacement=params.do_homophone_replacement,
90
+ params_infer_code=params.params_infer_code,
91
+ params_refine_text=params.params_refine_text,
92
+ )
93
+ logger.info("Inference completed.")
94
+
95
+ # zip all of the audio files together
96
+ buf = io.BytesIO()
97
+ with zipfile.ZipFile(
98
+ buf, "a", compression=zipfile.ZIP_DEFLATED, allowZip64=False
99
+ ) as f:
100
+ for idx, wav in enumerate(wavs):
101
+ f.writestr(f"{idx}.mp3", pcm_arr_to_mp3_view(wav))
102
+ logger.info("Audio generation successful.")
103
+ buf.seek(0)
104
+
105
+ response = StreamingResponse(buf, media_type="application/zip")
106
+ response.headers["Content-Disposition"] = "attachment; filename=audio_files.zip"
107
+ return response
examples/api/requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ fastapi
2
+ requests
examples/cmd/run.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+
3
+ if sys.platform == "darwin":
4
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
5
+
6
+ now_dir = os.getcwd()
7
+ sys.path.append(now_dir)
8
+
9
+ from typing import Optional, List
10
+ import argparse
11
+
12
+ import numpy as np
13
+
14
+ import ChatTTS
15
+
16
+ from tools.logger import get_logger
17
+ from tools.audio import pcm_arr_to_mp3_view
18
+ from tools.normalizer.en import normalizer_en_nemo_text
19
+ from tools.normalizer.zh import normalizer_zh_tn
20
+
21
+ logger = get_logger("Command")
22
+
23
+
24
+ def save_mp3_file(wav, index):
25
+ data = pcm_arr_to_mp3_view(wav)
26
+ mp3_filename = f"output_audio_{index}.mp3"
27
+ with open(mp3_filename, "wb") as f:
28
+ f.write(data)
29
+ logger.info(f"Audio saved to {mp3_filename}")
30
+
31
+
32
+ def load_normalizer(chat: ChatTTS.Chat):
33
+ # try to load normalizer
34
+ try:
35
+ chat.normalizer.register("en", normalizer_en_nemo_text())
36
+ except ValueError as e:
37
+ logger.error(e)
38
+ except BaseException:
39
+ logger.warning("Package nemo_text_processing not found!")
40
+ logger.warning(
41
+ "Run: conda install -c conda-forge pynini=2.1.5 && pip install nemo_text_processing",
42
+ )
43
+ try:
44
+ chat.normalizer.register("zh", normalizer_zh_tn())
45
+ except ValueError as e:
46
+ logger.error(e)
47
+ except BaseException:
48
+ logger.warning("Package WeTextProcessing not found!")
49
+ logger.warning(
50
+ "Run: conda install -c conda-forge pynini=2.1.5 && pip install WeTextProcessing",
51
+ )
52
+
53
+
54
+ def main(
55
+ texts: List[str],
56
+ spk: Optional[str] = None,
57
+ stream: bool = False,
58
+ source: str = "local",
59
+ custom_path: str = "",
60
+ ):
61
+ logger.info("Text input: %s", str(texts))
62
+
63
+ chat = ChatTTS.Chat(get_logger("ChatTTS"))
64
+ logger.info("Initializing ChatTTS...")
65
+ load_normalizer(chat)
66
+
67
+ is_load = False
68
+ if os.path.isdir(custom_path) and source == "custom":
69
+ is_load = chat.load(source="custom", custom_path=custom_path)
70
+ else:
71
+ is_load = chat.load(source=source)
72
+
73
+ if is_load:
74
+ logger.info("Models loaded successfully.")
75
+ else:
76
+ logger.error("Models load failed.")
77
+ sys.exit(1)
78
+
79
+ if spk is None:
80
+ spk = chat.sample_random_speaker()
81
+ logger.info("Use speaker:")
82
+ print(spk)
83
+
84
+ logger.info("Start inference.")
85
+ wavs = chat.infer(
86
+ texts,
87
+ stream,
88
+ params_infer_code=ChatTTS.Chat.InferCodeParams(
89
+ spk_emb=spk,
90
+ ),
91
+ )
92
+ logger.info("Inference completed.")
93
+ # Save each generated wav file to a local file
94
+ if stream:
95
+ wavs_list = []
96
+ for index, wav in enumerate(wavs):
97
+ if stream:
98
+ for i, w in enumerate(wav):
99
+ save_mp3_file(w, (i + 1) * 1000 + index)
100
+ wavs_list.append(wav)
101
+ else:
102
+ save_mp3_file(wav, index)
103
+ if stream:
104
+ for index, wav in enumerate(np.concatenate(wavs_list, axis=1)):
105
+ save_mp3_file(wav, index)
106
+ logger.info("Audio generation successful.")
107
+
108
+
109
+ if __name__ == "__main__":
110
+ r"""
111
+ python -m examples.cmd.run \
112
+ --source custom --custom_path ../../models/2Noise/ChatTTS 你好喲 ":)"
113
+ """
114
+ logger.info("Starting ChatTTS commandline demo...")
115
+ parser = argparse.ArgumentParser(
116
+ description="ChatTTS Command",
117
+ usage='[--spk xxx] [--stream] [--source ***] [--custom_path XXX] "Your text 1." " Your text 2."',
118
+ )
119
+ parser.add_argument(
120
+ "--spk",
121
+ help="Speaker (empty to sample a random one)",
122
+ type=Optional[str],
123
+ default=None,
124
+ )
125
+ parser.add_argument(
126
+ "--stream",
127
+ help="Use stream mode",
128
+ action="store_true",
129
+ )
130
+ parser.add_argument(
131
+ "--source",
132
+ help="source form [ huggingface(hf download), local(ckpt save to asset dir), custom(define) ]",
133
+ type=str,
134
+ default="local",
135
+ )
136
+ parser.add_argument(
137
+ "--custom_path",
138
+ help="custom defined model path(include asset ckpt dir)",
139
+ type=str,
140
+ default="",
141
+ )
142
+ parser.add_argument(
143
+ "texts",
144
+ help="Original text",
145
+ default=["YOUR TEXT HERE"],
146
+ nargs=argparse.REMAINDER,
147
+ )
148
+ args = parser.parse_args()
149
+ logger.info(args)
150
+ main(args.texts, args.spk, args.stream, args.source, args.custom_path)
151
+ logger.info("ChatTTS process finished.")