Spaces:
Sleeping
Sleeping
Commit
•
bddd843
1
Parent(s):
fcf0aa2
Correct prompt padding side (#1)
Browse files- Correct prompt padding side (037776cca036c2b340673b03e3f25470c913938e)
- Update app.py (627dc63ff0fb3ceb1448818235c9f532da50a2b7)
Co-authored-by: Yoach Lacombe <[email protected]>
app.py
CHANGED
@@ -29,7 +29,8 @@ model = ParlerTTSForConditionalGeneration.from_pretrained(
|
|
29 |
|
30 |
client = InferenceClient()
|
31 |
|
32 |
-
|
|
|
33 |
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
|
34 |
|
35 |
SAMPLE_RATE = feature_extractor.sampling_rate
|
@@ -87,8 +88,8 @@ def generate_base(subject, setting):
|
|
87 |
|
88 |
gr.Info("Generating Audio")
|
89 |
description = "Jenny speaks at an average pace with a calm delivery in a very confined sounding environment with clear audio quality."
|
90 |
-
story_tokens =
|
91 |
-
description_tokens =
|
92 |
speech_output = model.generate(input_ids=description_tokens, prompt_input_ids=story_tokens)
|
93 |
speech_output = [output.cpu().numpy() for output in speech_output]
|
94 |
gr.Info("Generated Audio")
|
|
|
29 |
|
30 |
client = InferenceClient()
|
31 |
|
32 |
+
description_tokenizer = AutoTokenizer.from_pretrained(repo_id)
|
33 |
+
prompt_tokenizer = AutoTokenizer.from_pretrained(repo_id, padding_side="left")
|
34 |
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
|
35 |
|
36 |
SAMPLE_RATE = feature_extractor.sampling_rate
|
|
|
88 |
|
89 |
gr.Info("Generating Audio")
|
90 |
description = "Jenny speaks at an average pace with a calm delivery in a very confined sounding environment with clear audio quality."
|
91 |
+
story_tokens = prompt_tokenizer(model_input_tokens, return_tensors="pt", padding=True).input_ids.to(device)
|
92 |
+
description_tokens = description_tokenizer([description for _ in range(len(model_input_tokens))], return_tensors="pt").input_ids.to(device)
|
93 |
speech_output = model.generate(input_ids=description_tokens, prompt_input_ids=story_tokens)
|
94 |
speech_output = [output.cpu().numpy() for output in speech_output]
|
95 |
gr.Info("Generated Audio")
|