Blakus commited on
Commit
4ff535a
1 Parent(s): 9b668a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -17
app.py CHANGED
@@ -43,9 +43,6 @@ model.load_checkpoint(config, checkpoint_path=checkpoint_path, vocab_path=vocab_
43
 
44
  print("Modelo cargado en CPU")
45
 
46
- def split_text(text):
47
- return re.split(r'(?<=[.!?])\s+', text)
48
-
49
  def predict(prompt, language, reference_audio):
50
  try:
51
  if len(prompt) < 2 or len(prompt) > 600:
@@ -53,22 +50,15 @@ def predict(prompt, language, reference_audio):
53
 
54
  sentences = split_text(prompt)
55
 
56
- # Usar los parámetros del config.json
57
- temperature = config.model_args.get("temperature", 0.85)
58
- repetition_penalty = config.model_args.get("repetition_penalty", 2.0)
59
- length_penalty = config.model_args.get("length_penalty", 1.0)
60
- top_k = config.model_args.get("top_k", 50)
61
- top_p = config.model_args.get("top_p", 0.85)
62
-
63
- gpt_cond_len = config.model_args.get("gpt_cond_len", 12)
64
- gpt_cond_chunk_len = config.model_args.get("gpt_cond_chunk_len", 4)
65
- max_ref_len = config.model_args.get("max_ref_len", 10)
66
 
67
  gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
68
- audio_path=reference_audio,
69
- gpt_cond_len=gpt_cond_len,
70
- gpt_cond_chunk_len=gpt_cond_chunk_len,
71
- max_ref_len=max_ref_len
72
  )
73
 
74
  start_time = time.time()
 
43
 
44
  print("Modelo cargado en CPU")
45
 
 
 
 
46
  def predict(prompt, language, reference_audio):
47
  try:
48
  if len(prompt) < 2 or len(prompt) > 600:
 
50
 
51
  sentences = split_text(prompt)
52
 
53
+ # Obtener los parámetros de la configuración JSON
54
+ temperature = config.get("temperature", 0.85)
55
+ length_penalty = config.get("length_penalty", 1.0)
56
+ repetition_penalty = config.get("repetition_penalty", 2.0)
57
+ top_k = config.get("top_k", 50)
58
+ top_p = config.get("top_p", 0.85)
 
 
 
 
59
 
60
  gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
61
+ audio_path=reference_audio
 
 
 
62
  )
63
 
64
  start_time = time.time()