# basic environment & OpenAI imports
import romajitable
import re
import os
import numpy as np
import logging
logging.getLogger('numba').setLevel(logging.WARNING)
import IPython.display as ipd
import torch
import commons
import utils
from models import SynthesizerTrn
from text.symbols import symbols
from text import text_to_sequence
import openai
import tkinter as tk
from tkinter import scrolledtext
import argparse
import time
import pickle
from scipy.io.wavfile import write
def get_args():
    parser = argparse.ArgumentParser(description='inference')
    parser.add_argument('--model', default='lovelive/G_936000.pth')
    parser.add_argument('--audio',
                        type=str,
                        help='path of the Live2D sound file to be replaced; copies are assumed to be named temp1.wav, temp2.wav, temp3.wav, ...',
                        default='path/to/temp.wav')
    parser.add_argument('--cfg', default="lovelive/config.json")
    parser.add_argument('--key', default="openai key",
                        help='OpenAI API key from platform.openai.com')
    args = parser.parse_args()
    return args
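# Example invocation (the script name inference.py is illustrative; the model/config/audio
# values below are simply the defaults defined above):
#   python inference.py --model lovelive/G_936000.pth --cfg lovelive/config.json \
#       --audio path/to/temp.wav --key <your OpenAI API key>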
args = get_args()
dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
hps_ms = utils.get_hparams_from_file(args.cfg)
# multi-speaker synthesizer
net_g_ms = SynthesizerTrn(
    len(symbols),
    hps_ms.data.filter_length // 2 + 1,
    hps_ms.train.segment_size // hps_ms.data.hop_length,
    n_speakers=hps_ms.data.n_speakers,
    **hps_ms.model).to(dev)
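# switch to inference mode and load the pretrained generator checkpoint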
_ = net_g_ms.eval()
_ = utils.load_checkpoint(args.model, net_g_ms, None)
# detect whether a string contains Japanese kana (code points U+3040–U+30FF)
def is_japanese(string):
    for ch in string:
        if 0x3040 < ord(ch) < 0x30FF:
            return True
    return False
def get_text(text, hps):
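    """Convert raw text into the integer symbol sequence (optionally interspersed with blanks) that the synthesizer expects."""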
    text_norm = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        text_norm = commons.intersperse(text_norm, 0)
    text_norm = torch.LongTensor(text_norm)
    return text_norm
def ttv(text):
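    """Synthesize `text` with the VITS model, save it as `args.audio` + '.wav', and create the resampled copies used by Live2D."""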
    text = text.replace('\n','').replace(' ','')
    text = f"[JA]{text}[JA]" if is_japanese(text) else f"[ZH]{text}[ZH]"
    speaker_id = 7
    stn_tst = get_text(text,hps_ms)
    t1 = time.time()
    with torch.no_grad():
        x_tst = stn_tst.unsqueeze(0).to(dev)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(dev)
        sid = torch.LongTensor([speaker_id]).to(dev)
        audio = net_g_ms.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.467, noise_scale_w=0.5, length_scale=1)[0][0,0].data.cpu().float().numpy()
    write(args.audio + '.wav',22050,audio)
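    # Up-sample the 22.05 kHz output to 44.1 kHz and fan it out to temp1.wav ... temp19.wav,
    # the copies the Live2D front end is assumed to play (see the --audio help text).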
    i = 0
    while i < 19:
        i +=1
        cmd = 'ffmpeg -y -i ' +  args.audio + '.wav' + ' -ar 44100 '+ args.audio.replace('temp','temp'+str(i))
        os.system(cmd)
    t2 = time.time()
    print("推理耗时:",(t2 - t1),"s")
openai.api_key = args.key
result_list = []
messages = []
read_log = input('Load previous log? (y/n) ')
if read_log == 'y':
    with open('log.pickle', 'rb') as f:
        messages = pickle.load(f)
    print('Most recent log:\n' + str(messages[-1]))
def send_message():
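    """Send the typed message to gpt-3.5-turbo, synthesize the reply with ttv(), and update the chat window."""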
    text = input_box.get("1.0", "end-1c")  # get the text the user typed
    messages.append({"role": "user", "content": text})
    chat = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=messages)
    reply = chat.choices[0].message.content
    ttv(reply)
    messages.append({"role": "assistant", "content": reply})
    print(messages[-1])
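    # Once the history reaches 12 entries, keep the first six messages as fixed context
    # and drop the oldest later exchange so the prompt does not grow without bound.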
    if len(messages) == 12:
        messages[6:10] = messages[8:]
        del messages[-2:]
    with open('log.pickle', 'wb') as f:
        pickle.dump(messages, f)
    chat_box.configure(state='normal') 
    chat_box.insert(tk.END, "You: " + text + "\n") 
    chat_box.insert(tk.END, "Tamao: " + reply + "\n") 
    chat_box.configure(state='disabled') 
    input_box.delete("1.0", tk.END) 

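# Minimal tkinter window: a read-only chat log on top, with a text input box and Send button at the bottom.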
root = tk.Tk()
root.title("Tamao")

chat_box = scrolledtext.ScrolledText(root, width=50, height=10)
chat_box.configure(state='disabled')
chat_box.pack(side=tk.TOP, fill=tk.BOTH, padx=10, pady=10, expand=True)

input_frame = tk.Frame(root)
input_frame.pack(side=tk.BOTTOM, fill=tk.X, padx=10, pady=10)
input_box = tk.Text(input_frame, height=3, width=50)
input_box.pack(side=tk.LEFT, fill=tk.X, padx=10, expand=True)
send_button = tk.Button(input_frame, text="Send", command=send_message)
send_button.pack(side=tk.RIGHT, padx=10)

root.mainloop()