import os
from PIL import Image
import random
import shutil
import datetime
import torchvision.transforms.functional as f
import torch
from typing import Optional, Tuple
from threading import Lock
from langchain import ConversationChain
from chat_anything.tts_talker.tts_edge import TTSTalker
from chat_anything.sad_talker.sad_talker import SadTalker
from chat_anything.chatbot.chat import load_chain
from chat_anything.chatbot.select import model_selection_chain
from chat_anything.chatbot.voice_select import voice_selection_chain
import gradio as gr

TALKING_HEAD_WIDTH = "350"
sadtalker_checkpoint_path = "MODELS/SadTalker"
config_path = "chat_anything/sad_talker/config"


class ChatWrapper:
    """Callable chat handler that can also voice and animate its replies.

    Holds one lock, taken for the duration of every ``__call__``, and one
    ``SadTalker`` instance constructed with ``lazy_load=True``.
    """

    def __init__(self):
        self.lock = Lock()
        self.sad_talker = SadTalker(
            sadtalker_checkpoint_path, config_path, lazy_load=True)

    def __call__(
        self,
        api_key: str,
        inp: str,
        history: Optional[Tuple[str, str]],
        chain: Optional[ConversationChain],
        speak_text: bool,
        talking_head: bool,
        uid: str,
        talker: None,
        fullbody: str,
    ):
        """Execute the chat functionality.

        Args:
            api_key: Not used by this method.
            inp: The user's message.
            history: Sequence of ``(user_text, reply_text)`` pairs, or a
                falsy value to start a fresh list.
            chain: Conversation chain used to produce the reply. When it is
                ``None`` a registration reminder is appended instead.
            speak_text: Whether to synthesise audio for the reply.
            talking_head: Whether to produce or point at a video.
            uid: Session identifier; used as a directory name under ``tmp``.
            talker: Object passed through to the speech helpers.
            fullbody: Truthy selects 'full' preprocessing for the video.

        Returns:
            A 7-tuple ``(history, history, html_video, temp_file,
            html_audio, temp_aud_file, "")``. The four media entries are
            ``None`` unless the corresponding branch produced them.
        """
        # Bound up front so every return path (including chain is None)
        # has values to hand back.
        html_video, temp_file, html_audio, temp_aud_file = None, None, None, None

        # The context manager releases the lock on every exit path; the
        # previous manual acquire/release leaked it when chain was None.
        with self.lock:
            history = history or []
            if chain is None:
                history.append((inp, "Please register with your API key first!"))
                return (history, history, html_video, temp_file,
                        html_audio, temp_aud_file, "")

            print("\n==== date/time: " + str(datetime.datetime.now()) + " ====")
            print("inp: " + inp)
            print("speak_text: ", speak_text)
            print("talking_head: ", talking_head)

            output = chain.predict(input=inp).strip()
            output = output.replace("\n", "\n\n")
            history.append((inp, output))

            if speak_text:
                if talking_head:
                    html_video, temp_file = self.do_html_video_speak(
                        talker, output, fullbody, uid)
                else:
                    html_audio, temp_aud_file = self.do_html_audio_speak(
                        talker, output, uid)
            elif talking_head:
                # No new speech: hand back the session's video directory.
                temp_file = os.path.join('tmp', uid, 'videos')
                html_video = create_html_video(
                    temp_file, TALKING_HEAD_WIDTH)

        return history, history, html_video, temp_file, html_audio, temp_aud_file, ""

    def do_html_audio_speak(self, talker, words_to_speak, uid):
        """Synthesise ``words_to_speak`` via ``talker.test``.

        Returns:
            ``(html_audio, audio_file_path)`` on success, or ``(None, None)``
            when wrapping the file in ``gr.File`` raises ``IOError``.
        """
        audio_path = os.path.join('tmp', uid, 'audios')
        print('uid:', uid, ":", words_to_speak)
        audo_file_path = talker.test(text=words_to_speak, audio_path=audio_path)
        try:
            temp_aud_file = gr.File(audo_file_path)
            # NOTE(review): this URL is computed but never embedded, and the
            # markup string below is empty -- it looks like an audio tag was
            # lost from this file; confirm against the upstream source.
            temp_aud_file_url = "/file=" + temp_aud_file.value['name']
        except IOError as error:
            # Could not access the file, exit gracefully
            print(error)
            return None, None
        html_audio = ''
        return html_audio, audo_file_path

    def do_html_video_speak(self, talker, words_to_speak, fullbody, uid):
        """Synthesise speech, animate the session face image, store the video.

        The face image is read from ``tmp/<uid>/images/test.jpg`` and the
        result is moved to ``tmp/<uid>/videos/tempfile.mp4``.

        Returns:
            The stored video path, twice (callers unpack two values).
        """
        preprocess = 'full' if fullbody else 'crop'
        print("success")

        video_path = os.path.join('tmp', uid, 'videos')
        os.makedirs(video_path, exist_ok=True)
        video_file_path = os.path.join(video_path, 'tempfile.mp4')

        _, audio_path = self.do_html_audio_speak(
            talker, words_to_speak, uid)
        face_file_path = os.path.join('tmp', uid, 'images', 'test.jpg')

        video = self.sad_talker.test(
            face_file_path, audio_path, preprocess, uid=uid)
        shutil.move(video, video_file_path)

        return video_file_path, video_file_path

    def generate_init_face_video(self, class_concept="clock", llm=None, uid=None,
                                 fullbody=None, ref_image=None, seed=None):
        """Build the chain, voice, face image and greeting video for a concept.

        Face generation plus the greeting video is attempted up to 8 times;
        after each failure a random word from a fixed list is prepended to
        ``class_concept`` and the strength is stepped by 0.04.

        Returns:
            ``(chain, memory, htm_video, talker)``.

        Raises:
            RuntimeError: If all 8 attempts fail (chained from the last
                error).
        """
        print('generate concept of', class_concept)
        print("=================================================")
        print('fullbody:', fullbody)
        print('uid:', uid)
        print("==================================================")
        chain, memory, personality_text = load_chain(llm, class_concept)

        # added for safe face generation
        print('generate concept of', class_concept)
        augment_word_list = ["Female ", "female ", "beautiful ", "small ", "cute "]
        first_sentence = "Hello, how are you doing today?"

        # Each selection chain is now queried once; the earlier duplicate
        # calls had their results overwritten before use.
        voice_conf, selected_voice = model_selection_chain(
            llm, personality_text, conf_file='resources/voices_edge.yaml')
        talker = TTSTalker(selected_voice=selected_voice,
                           gender=voice_conf['gender'],
                           language=voice_conf['language'])
        # use class concept to choose a generating model, otherwise crack down
        model_conf, selected_model = model_selection_chain(
            llm, class_concept, conf_file='resources/models.yaml')

        if ref_image is None:
            face_files = os.listdir(FACE_DIR)
            face_img_path = os.path.join(FACE_DIR, random.choice(face_files))
            ref_image = Image.open(face_img_path)

        print('loading face generating model')
        anything_facemaker = load_face_generator(
            model_dir=model_conf['model_dir'],
            lora_path=model_conf['lora_path'],
            prompt_template=model_conf['prompt_template'],
            negative_prompt=model_conf['negative_prompt'],
        )

        has_face = anything_facemaker.has_face(ref_image)
        init_strength = 1.0 if has_face else 0.85
        strength_retry_step = -0.04 if has_face else 0.04

        max_attempts = 8
        last_error = None
        for retry_cnt in range(max_attempts):
            try:
                generate_face_image(
                    anything_facemaker,
                    class_concept,
                    ref_image,
                    uid=uid,
                    strength=init_strength + retry_cnt * strength_retry_step,
                    # NOTE(review): the earlier code selected 0.5 when the
                    # retry count equalled 8, which the loop bound never
                    # allowed, so 0.3 was always used. Confirm whether a
                    # final-attempt bump was intended.
                    controlnet_conditioning_scale=0.3,
                    seed=seed,
                )
                self.do_html_video_speak(talker, first_sentence, fullbody, uid=uid)
                video_file_path = os.path.join('tmp', uid, 'videos/tempfile.mp4')
                htm_video = create_html_video(
                    video_file_path, TALKING_HEAD_WIDTH)
                return chain, memory, htm_video, talker
            except Exception as e:
                # Best-effort retry: nudge the prompt and try again.
                last_error = e
                class_concept = random.choice(augment_word_list) + class_concept
                print(e)

        raise RuntimeError(
            f"could not generate an initial talking-head video "
            f"after {max_attempts} attempts"
        ) from last_error

    def update_talking_head(self, widget, uid, state):
        """Return the new state and video reference for the talking-head toggle.

        A truthy ``widget`` yields ``(widget, <video dir for uid>)``;
        otherwise ``(None, "\\n")``.
        """
        if not widget:
            # NOTE(review): this placeholder is a bare newline; surrounding
            # markup may be missing from this file -- confirm upstream.
            return None, "\n"

        state = widget
        temp_file = os.path.join('tmp', uid, 'videos')
        video_html_talking_head = create_html_video(
            temp_file, TALKING_HEAD_WIDTH)
        return state, video_html_talking_head


def reset_memory(history, memory):
    """Clear ``memory`` in place and return a fresh, empty history.

    The same new list object is returned twice, followed by ``memory``.
    The incoming ``history`` is not modified.
    """
    memory.clear()
    emptied = []
    return emptied, emptied, memory
            

def create_html_video(file_name, width):
    """Return ``file_name`` unchanged; ``width`` is accepted but not used."""
    video_ref = file_name
    return video_ref


def create_html_audio(file_name):
    """Return the audio markup string for ``file_name``.

    When the file exists it is wrapped in a hidden ``gr.File`` and its
    served URL is looked up; the returned string is empty in both cases.
    """
    if not os.path.exists(file_name):
        return ''

    hidden_file = gr.File(file_name, visible=False)
    # NOTE(review): the URL is resolved but not embedded in the result --
    # the markup that used it appears to be missing; confirm upstream.
    served_url = "/file=" + hidden_file.value['name']
    return ''


def update_foo(widget, state):
    """Return ``widget`` when it is truthy, otherwise ``None``.

    The incoming ``state`` never reaches the result.
    """
    return widget if widget else None


# Pertains to question answering functionality
def update_use_embeddings(widget, state):
    """Return ``widget`` when it is truthy; fall through to ``None`` otherwise."""
    if not widget:
        return None
    return widget

# Image-generation helpers.


def load_face_generator(model_dir, lora_path, prompt_template, negative_prompt):
    """Construct a ``LongPromptControlGenerator`` and load its model.

    The face ControlNet directory and the landmark detector file are taken
    from the local ``MODELS`` folder. ``model_dir`` has to be a model
    derived from stable diffusion v1-5 (per the original author's note).
    ``load_model`` is called with ``safety_checker=None``.

    Returns:
        The loaded generator instance.
    """
    from chat_anything.face_generator.long_prompt_control_generator import LongPromptControlGenerator

    weights_root = "MODELS"
    controlnet_dir = os.path.join(
        weights_root, "Face-Landmark-ControlNet", "models_for_diffusers")
    landmark_model = os.path.join(
        weights_root, "SadTalker", "shape_predictor_68_face_landmarks.dat")

    init_kwargs = {
        "model_dir": model_dir,
        "lora_path": lora_path,
        "prompt_template": prompt_template,
        "negative_prompt": negative_prompt,
        "face_control_dir": controlnet_dir,
        "face_detect_path": landmark_model,
    }
    facemaker = LongPromptControlGenerator(**init_kwargs)
    facemaker.load_model(safety_checker=None)
    return facemaker



FACE_DIR = "resources/images/faces"


def generate_face_image(
        anything_facemaker,
        class_concept,
        face_img_pil,
        uid=None,
        controlnet_conditioning_scale=1.0,
        strength=0.95,
        seed=42,
    ):
    """Generate a face image for ``class_concept`` and save it for the session.

    The reference image is resized and center-cropped to 512 and converted
    to RGB, the prompt is built from ``anything_facemaker.prompt_template``,
    and the result is written to ``tmp/<uid>/images/test.jpg``.

    Args:
        anything_facemaker: Generator exposing ``prompt_template``,
            ``face_control_pipe.device`` and ``face_control_generate``.
        class_concept: Text substituted into the prompt template.
        face_img_pil: Reference image (converted to RGB here).
        uid: Session identifier used as a directory name. Despite the
            ``None`` default, ``os.path.join`` needs a string here.
        controlnet_conditioning_scale: Forwarded to ``face_control_generate``.
        strength: Forwarded with ``do_inversion=True``; ``None`` selects the
            ``do_inversion=False`` call without a strength.
        seed: Integer seed for the torch generator. ``None`` draws a fresh
            seed instead of failing in ``manual_seed``.

    Returns:
        The generated image object (also saved to disk).
    """
    face_img_pil = f.center_crop(
        f.resize(face_img_pil, 512), 512).convert('RGB')
    prompt = anything_facemaker.prompt_template.format(class_concept)

    generator = torch.Generator(device=anything_facemaker.face_control_pipe.device)
    if seed is None:
        # manual_seed(None) raises; callers in this file forward seed=None,
        # so draw a random seed and report the value actually used.
        seed = generator.seed()
    else:
        generator.manual_seed(seed)
    print('USING SEED:', seed)

    shared_kwargs = dict(
        prompt=prompt,
        face_img_pil=face_img_pil,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        generator=generator,
    )
    if strength is None:
        init_face_pil = anything_facemaker.face_control_generate(
            do_inversion=False, **shared_kwargs)
    else:
        init_face_pil = anything_facemaker.face_control_generate(
            do_inversion=True, strength=strength, **shared_kwargs)
    print('succeeded generating face image')

    face_path = os.path.join('tmp', uid, 'images')
    os.makedirs(face_path, exist_ok=True)
    # TODO: reproduce the images for return, shouldn't use the filesystem
    face_file_path = os.path.join(face_path, 'test.jpg')
    init_face_pil.save(face_file_path)
    return init_face_pil