File size: 4,157 Bytes
d2b7e94
01e655b
 
d2b7e94
01e655b
 
d2b7e94
01e655b
 
 
 
 
 
da8d589
 
01e655b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c22399
 
 
 
 
 
 
 
01e655b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c22399
 
01e655b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da8d589
01e655b
 
 
 
 
 
 
da8d589
 
 
 
8c22399
01e655b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c22399
01e655b
 
 
 
 
 
8c22399
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from typing import List, Optional

import torch
from fastapi import HTTPException
from pydantic import BaseModel

from modules.api import utils as api_utils
from modules.api.Api import APIManager
from modules.speaker import speaker_mgr


class CreateSpeaker(BaseModel):
    """Request body for POST /v1/speaker/create.

    The create handler builds the speaker from ``tensor`` when it is a
    non-empty list of numbers, and otherwise falls back to ``seed``.
    """

    name: str
    gender: str
    describe: str
    # Optional (rather than ``list = None``) so an explicit JSON ``null`` is
    # accepted as "not supplied" instead of failing validation.
    tensor: Optional[list] = None
    seed: Optional[int] = None


class UpdateSpeaker(BaseModel):
    """Request body for POST /v1/speaker/update.

    ``name``, ``gender`` and ``describe`` overwrite the stored values.
    """

    id: str
    name: str
    gender: str
    describe: str
    # The update handler only replaces the embedding when this is a non-empty
    # list, so omitting it (or sending null / []) keeps the current embedding.
    # Previously this field was required even though empty was allowed.
    tensor: Optional[list] = None


class SpeakerDetail(BaseModel):
    """Request body for POST /v1/speaker/detail."""

    # Speaker id, resolved with speaker_mgr.get_speaker_by_id.
    id: str
    # Forwarded as ``to_json(with_emb=...)``; presumably controls whether the
    # embedding is included in the response — confirm against Speaker.to_json.
    with_emb: bool = False


class SpeakersUpdate(BaseModel):
    """Request body for POST /v1/speakers/update (batch edit).

    Each entry is an object with a required ``"id"`` key and optional
    ``"name"``, ``"gender"``, ``"describe"`` and ``"tensor"`` keys.
    """

    # Typed as a list of dicts so non-object entries are rejected during
    # validation (422) instead of crashing the handler on ``spk["id"]``.
    speakers: List[dict]


def setup(app: APIManager):
    """Register the speaker-management routes on ``app``.

    Routes:
        GET  /v1/speakers/list     -- every speaker serialized with ``to_json()``.
        POST /v1/speakers/refresh  -- calls ``speaker_mgr.refresh_speakers()``.
        POST /v1/speakers/update   -- batch-edit speakers, then ``save_all()``.
        POST /v1/speaker/create    -- create a speaker from a tensor or a seed.
        POST /v1/speaker/update    -- edit one speaker via ``update_speaker``.
        POST /v1/speaker/detail    -- one speaker, optionally with embedding.
    """

    def _list_to_emb(values):
        """Convert a JSON number array to a tensor, or return None.

        Returns None when ``values`` is not a non-empty list, which callers
        treat as "no tensor supplied". Raises HTTP 400 when ``torch.tensor``
        rejects the array (e.g. ragged or non-numeric), which would otherwise
        surface as an unhandled 500.
        """
        if not isinstance(values, list) or not values:
            return None
        try:
            return torch.tensor(values)
        except (TypeError, ValueError, RuntimeError) as err:
            raise HTTPException(
                status_code=400, detail=f"Invalid tensor: {err}"
            ) from err

    @app.get("/v1/speakers/list", response_model=api_utils.BaseResponse)
    async def list_speakers():
        return api_utils.success_response(
            [spk.to_json() for spk in speaker_mgr.list_speakers()]
        )

    @app.post("/v1/speakers/refresh", response_model=api_utils.BaseResponse)
    async def refresh_speakers():
        speaker_mgr.refresh_speakers()
        return api_utils.success_response(None)

    @app.post("/v1/speakers/update", response_model=api_utils.BaseResponse)
    async def update_speakers(request: SpeakersUpdate):
        # Phase 1: validate every entry and resolve its speaker before touching
        # anything, so a bad or unknown entry cannot leave earlier speakers
        # modified in memory but never saved.
        pending = []
        for spk in request.speakers:
            if not isinstance(spk, dict) or "id" not in spk:
                raise HTTPException(
                    status_code=400,
                    detail="Each speaker entry must be an object with an 'id'",
                )
            speaker = speaker_mgr.get_speaker_by_id(spk["id"])
            if speaker is None:
                raise HTTPException(
                    status_code=404, detail=f"Speaker not found: {spk['id']}"
                )
            pending.append((speaker, spk, _list_to_emb(spk.get("tensor"))))

        # Phase 2: apply the edits; keys that are absent keep their old value.
        for speaker, spk, emb in pending:
            speaker.name = spk.get("name", speaker.name)
            speaker.gender = spk.get("gender", speaker.gender)
            speaker.describe = spk.get("describe", speaker.describe)
            if emb is not None:
                speaker.emb = emb
        speaker_mgr.save_all()

        return api_utils.success_response(None)

    @app.post("/v1/speaker/create", response_model=api_utils.BaseResponse)
    async def create_speaker(request: CreateSpeaker):
        emb = _list_to_emb(request.tensor)
        if emb is not None:
            # from tensor (takes precedence over seed)
            speaker = speaker_mgr.create_speaker_from_tensor(
                tensor=emb,
                name=request.name,
                gender=request.gender,
                describe=request.describe,
            )
        elif request.seed is not None:
            # from seed; `is not None` so that seed=0 is accepted
            speaker = speaker_mgr.create_speaker_from_seed(
                seed=request.seed,
                name=request.name,
                gender=request.gender,
                describe=request.describe,
            )
        else:
            raise HTTPException(
                status_code=400, detail="Missing tensor or seed in request"
            )
        return api_utils.success_response(speaker.to_json())

    @app.post("/v1/speaker/update", response_model=api_utils.BaseResponse)
    async def update_speaker(request: UpdateSpeaker):
        speaker = speaker_mgr.get_speaker_by_id(request.id)
        if speaker is None:
            raise HTTPException(
                status_code=404, detail=f"Speaker not found: {request.id}"
            )
        # Convert before mutating so an invalid tensor leaves the speaker as-is.
        emb = _list_to_emb(request.tensor)
        speaker.name = request.name
        speaker.gender = request.gender
        speaker.describe = request.describe
        if emb is not None:
            speaker.emb = emb
        speaker_mgr.update_speaker(speaker)
        return api_utils.success_response(None)

    @app.post("/v1/speaker/detail", response_model=api_utils.BaseResponse)
    async def speaker_detail(request: SpeakerDetail):
        speaker = speaker_mgr.get_speaker_by_id(request.id)
        if speaker is None:
            raise HTTPException(
                status_code=404, detail=f"Speaker not found: {request.id}"
            )
        return api_utils.success_response(speaker.to_json(with_emb=request.with_emb))