ChatTTS-Forge / modules /speaker.py
zhzluke96
update
32b2aaa
raw
history blame
5.4 kB
import os
from typing import Union
from box import Box
import torch
from modules import models
from modules.utils.SeedContext import SeedContext
import uuid
def create_speaker_from_seed(seed):
    """Sample a speaker embedding from the ChatTTS model under a fixed seed.

    The model is obtained via ``models.load_chat_tts()`` *before* entering
    ``SeedContext(seed)``. Only ``sample_random_speaker()`` runs inside the
    seeded context. Presumably the intent is that the same seed reproduces
    the same embedding (TODO confirm against ``SeedContext``).

    Args:
        seed: value handed to ``SeedContext``.

    Returns:
        Whatever ``sample_random_speaker()`` returns (the speaker embedding).
    """
    tts_model = models.load_chat_tts()
    with SeedContext(seed):
        sampled = tts_model.sample_random_speaker()
    return sampled
class Speaker:
    """A voice preset: an embedding plus descriptive metadata.

    Speakers are persisted by pickling the whole object with ``torch.save``
    and restored with :meth:`from_file`.

    Attributes:
        id: ``uuid.UUID`` identifying the speaker; equality and hashing
            depend only on ``str(id)``.
        seed: seed the embedding was created from; this class uses ``-2``
            where no seed is available (``from_tensor``, ``fix``).
        name, gender, describe: free-form metadata.
        emb: the speaker embedding, ``None`` until assigned.
        tokens: placeholder list (see the TODO in ``__init__``).
    """

    @staticmethod
    def from_file(file_like):
        """Load a pickled ``Speaker`` from a path or file-like object.

        The tensor storage is mapped to CPU, and the result is passed through
        :meth:`fix` so objects saved by older versions of this class gain
        any attributes they are missing.

        NOTE(security): ``torch.load`` unpickles arbitrary Python objects;
        only load speaker files from trusted sources. Newer torch releases
        default to ``weights_only=True``, which refuses custom classes such
        as ``Speaker`` -- TODO confirm the torch version this runs against.
        """
        speaker = torch.load(file_like, map_location=torch.device("cpu"))
        speaker.fix()
        return speaker

    @staticmethod
    def from_tensor(tensor):
        """Wrap an existing embedding in a new ``Speaker`` with ``seed=-2``."""
        speaker = Speaker(seed=-2)
        speaker.emb = tensor
        return speaker

    def __init__(self, seed, name="", gender="", describe=""):
        self.id = uuid.uuid4()
        self.seed = seed
        self.name = name
        self.gender = gender
        self.describe = describe
        self.emb = None
        # TODO replace emb => tokens
        self.tokens = []

    def to_json(self, with_emb=False):
        """Return the speaker's metadata as a ``Box``.

        Args:
            with_emb: when true, include the embedding as nested lists via
                ``emb.tolist()``. The ``"emb"`` entry is ``None`` when this
                is false or when the speaker has no embedding yet.

        Returns:
            A ``Box`` with keys id, seed, name, gender, describe, emb.
        """
        # Guard against emb being None (its initial value in __init__);
        # calling .tolist() on it would raise AttributeError.
        emb = None
        if with_emb and self.emb is not None:
            emb = self.emb.tolist()
        return Box(
            id=str(self.id),
            seed=self.seed,
            name=self.name,
            gender=self.gender,
            describe=self.describe,
            emb=emb,
        )

    def fix(self):
        """Backfill attributes missing from an instance (e.g. an old pickle).

        Unpickling restores ``__dict__`` without calling ``__init__``, so
        attributes added to the class later may be absent.

        Returns:
            True if at least one attribute was added, False otherwise.
        """
        # Factories (not values) so every call gets a fresh uuid / list.
        # NOTE(review): "gender" defaults to "*" here but "" in __init__.
        defaults = {
            "id": uuid.uuid4,
            "seed": lambda: -2,
            "name": str,
            "gender": lambda: "*",
            "describe": str,
            "emb": lambda: None,
            "tokens": list,
        }
        is_update = False
        for attr, make_default in defaults.items():
            if attr not in self.__dict__:
                setattr(self, attr, make_default())
                is_update = True
        return is_update

    def __hash__(self):
        # Must agree with __eq__: both depend only on str(self.id).
        return hash(str(self.id))

    def __eq__(self, other):
        if not isinstance(other, Speaker):
            # Defer to the other operand; for unrelated types ``==`` still
            # evaluates to False.
            return NotImplemented
        return str(self.id) == str(other.id)
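# Each speaker is stored as one embedding file (.pt).
# Managing speakers means managing every speaker file under ./data/speakers/.
# The manager can: create a speaker from a seed,
# refresh the list by re-reading all speakers from disk,
# and list all loaded speakers.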
class SpeakerManager:
    """Manage the speakers stored as ``.pt`` files in ``./data/speakers/``.

    ``self.speakers`` maps each file name in ``speaker_dir`` (e.g.
    ``"alice.pt"``) to the ``Speaker`` loaded from it. Methods that create or
    update a speaker write the file and then reload the whole directory.
    """

    def __init__(self):
        self.speakers = {}
        self.speaker_dir = "./data/speakers/"
        # Reads every speaker file at construction time.
        self.refresh_speakers()

    def _path(self, filename):
        """Return the path of ``filename`` inside ``speaker_dir``."""
        return os.path.join(self.speaker_dir, filename)

    def refresh_speakers(self):
        """Reload every ``.pt`` file in ``speaker_dir`` into ``self.speakers``.

        The mapping is rebuilt off to the side and swapped in at the end, so
        if ``os.listdir`` or ``Speaker.from_file`` raises, ``self.speakers``
        keeps its previous contents.
        """
        speakers = {}
        # os.listdir returns entries in arbitrary order; sort so that
        # list_speakers() is deterministic.
        for speaker_file in sorted(os.listdir(self.speaker_dir)):
            if speaker_file.endswith(".pt"):
                speakers[speaker_file] = Speaker.from_file(self._path(speaker_file))
        # Drop entries whose file disappeared after it was loaded. Iterate
        # over a snapshot of the keys: deleting from a dict while iterating
        # over it raises RuntimeError.
        for fname in list(speakers):
            if not os.path.exists(self._path(fname)):
                del speakers[fname]
        self.speakers = speakers

    def list_speakers(self):
        """Return the loaded speakers as a list, in the order they were loaded."""
        return list(self.speakers.values())

    def create_speaker_from_seed(self, seed, name="", gender="", describe=""):
        """Create a speaker from ``seed``, save it as ``<name>.pt`` and reload.

        Args:
            seed: passed to the module-level ``create_speaker_from_seed``.
            name: speaker name and file stem; defaults to ``str(seed)``.
            gender, describe: metadata stored on the speaker.

        Returns:
            The newly created ``Speaker``.
        """
        if name == "":
            # str() so the file name can be built from an int seed
            # (``seed + ".pt"`` would raise TypeError).
            name = str(seed)
        speaker = Speaker(seed, name=name, gender=gender, describe=describe)
        # Resolves to the module-level function, not this method.
        speaker.emb = create_speaker_from_seed(seed)
        torch.save(speaker, self._path(f"{name}.pt"))
        self.refresh_speakers()
        return speaker

    def create_speaker_from_tensor(
        self, tensor, filename="", name="", gender="", describe=""
    ):
        """Create a speaker from an existing embedding and save it.

        Args:
            tensor: a ``torch.Tensor`` (stored as-is) or a list (converted
                with ``torch.tensor``). Any other type leaves ``emb`` None.
            filename: file stem; defaults to ``name``, and to the speaker's
                id when both are empty.
            name, gender, describe: metadata stored on the speaker.

        Returns:
            The newly created ``Speaker``.
        """
        if filename == "":
            filename = name
        speaker = Speaker(seed=-2, name=name, gender=gender, describe=describe)
        if filename == "":
            # Avoid writing a nameless ".pt" file that the next such call
            # would silently overwrite.
            filename = str(speaker.id)
        if isinstance(tensor, torch.Tensor):
            speaker.emb = tensor
        elif isinstance(tensor, list):
            speaker.emb = torch.tensor(tensor)
        torch.save(speaker, self._path(filename + ".pt"))
        self.refresh_speakers()
        return speaker

    def get_speaker(self, name) -> Union["Speaker", None]:
        """Return the first loaded speaker whose ``name`` matches, else None."""
        return next(
            (spk for spk in self.speakers.values() if spk.name == name), None
        )

    def get_speaker_by_id(self, id) -> Union["Speaker", None]:
        """Return the speaker whose id matches ``id`` (compared as strings)."""
        # ``id`` shadows the builtin but is kept for keyword callers.
        key = str(id)
        return next(
            (spk for spk in self.speakers.values() if str(spk.id) == key), None
        )

    def get_speaker_filename(self, id: str):
        """Return the file name holding the speaker with ``id``, or None."""
        key = str(id)
        return next(
            (fname for fname, spk in self.speakers.items() if str(spk.id) == key),
            None,
        )

    def update_speaker(self, speaker: "Speaker"):
        """Overwrite the file of an already-loaded speaker and reload.

        Raises:
            ValueError: if no loaded speaker has ``speaker.id``.
        """
        filename = self.get_speaker_filename(speaker.id)
        if filename is None:
            raise ValueError("Speaker not found for update")
        torch.save(speaker, self._path(filename))
        self.refresh_speakers()
        return speaker

    def save_all(self):
        """Write every loaded speaker back to the file it was loaded from."""
        # Iterate the mapping directly: linear time, and two files holding
        # the same id each keep their own speaker.
        for fname, spk in self.speakers.items():
            torch.save(spk, self._path(fname))

    def __len__(self):
        return len(self.speakers)
# Module-level manager instance shared by every importer of this module.
# Constructing it calls refresh_speakers(), so importing this module reads
# the speaker directory immediately.
speaker_mgr: SpeakerManager = SpeakerManager()