import logging
import os
import threading
from queue import Queue
from typing import Any, Dict, Generator, List

import numpy as np
import torch

from s2s_pipeline import (
    build_pipeline,
    initialize_queues_and_events,
    main,
    parse_arguments,
    rename_args,
    setup_logger,
)


class EndpointHandler:
    """Callable endpoint that feeds requests into the s2s pipeline and streams its output.

    On construction the handler parses the pipeline arguments, builds and
    starts the pipeline, and starts a background thread that moves every item
    from ``send_audio_chunks_queue`` into ``final_output_queue``.  Calling the
    handler enqueues one input and yields output chunks until the ``b"END"``
    marker is read from ``final_output_queue``.
    """

    # Marker that terminates an output stream on the pipeline queues.
    _END_MARKER = b"END"

    def __init__(self, path=""):
        """Build and start the pipeline and the output collector thread.

        Args:
            path: Not used by this handler.
                NOTE(review): presumably kept because the hosting framework
                passes a model path to the constructor -- confirm.
        """
        (
            self.module_kwargs,
            self.socket_receiver_kwargs,
            self.socket_sender_kwargs,
            self.vad_handler_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
        ) = parse_arguments()

        setup_logger(self.module_kwargs.log_level)

        # NOTE(review): rename_args presumably strips/rewrites the given
        # prefix on each kwargs object -- confirm against s2s_pipeline.
        for handler_kwargs, prefix in (
            (self.whisper_stt_handler_kwargs, "stt"),
            (self.paraformer_stt_handler_kwargs, "paraformer_stt"),
            (self.language_model_handler_kwargs, "lm"),
            (self.mlx_language_model_handler_kwargs, "mlx_lm"),
            (self.parler_tts_handler_kwargs, "tts"),
            (self.melo_tts_handler_kwargs, "melo"),
            (self.chat_tts_handler_kwargs, "chat_tts"),
        ):
            rename_args(handler_kwargs, prefix)

        self.queues_and_events = initialize_queues_and_events()

        self.pipeline_manager = build_pipeline(
            self.module_kwargs,
            self.socket_receiver_kwargs,
            self.socket_sender_kwargs,
            self.vad_handler_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
            self.queues_and_events,
        )
        self.pipeline_manager.start()

        # Queue that __call__ reads its output stream from.
        self.final_output_queue = Queue()

        # Set by cleanup(); tells the collector that the next END marker it
        # sees means "exit" rather than "one response finished".
        self._shutdown_requested = threading.Event()

        # Daemon thread: a handler that is never cleaned up must not keep the
        # interpreter alive on exit.
        self.output_collector_thread = threading.Thread(
            target=self._collect_output,
            daemon=True,
        )
        self.output_collector_thread.start()

    @classmethod
    def _is_end_marker(cls, item) -> bool:
        """Return True if *item* is the ``b"END"`` stream marker.

        The type check comes first because queue items may be NumPy arrays;
        comparing an array with ``==`` can yield an array whose truth value
        is ambiguous and raises ``ValueError`` inside an ``if``.
        """
        return isinstance(item, (bytes, bytearray)) and item == cls._END_MARKER

    def _collect_output(self):
        """Forward pipeline output to ``final_output_queue`` until shutdown.

        Every item, including END markers, is forwarded unchanged.  The loop
        keeps running after an END marker so later requests still receive
        output; it exits only on an END marker seen after ``cleanup()`` has
        requested shutdown.
        """
        pipeline_output = self.queues_and_events['send_audio_chunks_queue']
        while True:
            output = pipeline_output.get()
            self.final_output_queue.put(output)
            if self._is_end_marker(output) and self._shutdown_requested.is_set():
                break

    def __call__(self, data: Dict[str, Any]) -> Generator[Dict[str, Any], None, None]:
        """Enqueue one request and stream the resulting output chunks.

        This is a generator function: nothing is enqueued, and no error is
        raised, until the returned generator is first advanced.

        Args:
            data (Dict[str, Any]): Request payload.  ``"input_type"`` selects
                the route (``"speech"`` or ``"text"``, default ``"text"``) and
                ``"input"`` holds the payload (default ``""``).  Speech input
                must be a bytes-like object; it is read as 16-bit integers.

        Yields:
            Dict[str, Any]: ``{"output": chunk}`` for each item read from
            ``final_output_queue`` before the ``b"END"`` marker.

        Raises:
            ValueError: If ``input_type`` is neither ``"speech"`` nor
                ``"text"``, or if the speech buffer cannot be read as int16
                (for example, an odd number of bytes).
        """
        input_type = data.get("input_type", "text")
        input_data = data.get("input", "")

        if input_type == "speech":
            # The int16 round trip validates the buffer and normalises any
            # bytes-like object to plain bytes before it enters the pipeline.
            audio_array = np.frombuffer(input_data, dtype=np.int16)
            self.queues_and_events['recv_audio_chunks_queue'].put(audio_array.tobytes())
        elif input_type == "text":
            self.queues_and_events['text_prompt_queue'].put(input_data)
        else:
            raise ValueError(f"Unsupported input type: {input_type}")

        # Stream chunks until the end-of-stream marker arrives.
        while True:
            chunk = self.final_output_queue.get()
            if self._is_end_marker(chunk):
                break
            yield {"output": chunk}

    def cleanup(self):
        """Stop the pipeline and the output collector thread."""
        # Flag shutdown first so any END marker from here on ends the collector.
        self._shutdown_requested.set()
        try:
            self.pipeline_manager.stop()
        finally:
            # Wake the collector even if stopping the pipeline failed, so the
            # thread is not left blocked on an empty queue.
            self.queues_and_events['send_audio_chunks_queue'].put(self._END_MARKER)
            self.output_collector_thread.join()