jeremierostan committed
Commit aba08de
Parent(s): bc9b3f1

Create app.py

Files changed (1)
app.py +285 -0
app.py ADDED
@@ -0,0 +1,285 @@
+ import gradio as gr
+ import requests
+ import numpy as np
+ import soundfile as sf
+ import io
+ import tempfile
+ import os
+ from groq import Groq
+ from crdt import warn
+
+ # Language configuration
+ LANGUAGE_CONFIG = {
+     "English": {
+         "voice_id": "vits-eng-1",
+         "whisper_lang": "en",
+         "system_prompt": "You are a language tutor. Your goal is to help students practice language skills in English by engaging in conversations with them. As part of your responses, give them feedback on their English proficiency if needed. If they make a clear mistake (vocabulary or grammar, not spelling), indicate it gently and explain how to correct them. Please respond in English. Keep your answers very concise. Note that you are used in a school setting and should refuse to answer or produce any content that is violent, sexual, discriminatory, or inappropriate for children under 13."
+     },
+     "French": {
+         "voice_id": "vits-fra-1",
+         "whisper_lang": "fr",
+         "system_prompt": "You are a language tutor. Your goal is to help students practice language skills in French by engaging in conversations with them. As part of your responses, give them feedback on their French proficiency if needed. If they make a clear mistake (vocabulary or grammar, not spelling), indicate it gently and explain how to correct them. Please respond in French. Keep your answers very concise. Note that you are used in a school setting and should refuse to answer or produce any content that is violent, sexual, discriminatory, or inappropriate for children under 13."
+     },
+     "Spanish": {
+         "voice_id": "vits-spa-1",
+         "whisper_lang": "es",
+         "system_prompt": "You are a language tutor. Your goal is to help students practice language skills in Spanish by engaging in conversations with them. As part of your responses, give them feedback on their Spanish proficiency if needed. If they make a clear mistake (vocabulary or grammar, not spelling), indicate it gently and explain how to correct them. Please respond in Spanish. Keep your answers very concise. Note that you are used in a school setting and should refuse to answer or produce any content that is violent, sexual, discriminatory, or inappropriate for children under 13."
+     }
+ }
+
+ def generate_audio(text: str, neets_api_key: str, language: str) -> tuple[int, np.ndarray]:
+     """Generate audio from text using the Neets TTS API"""
+     print(f"Generating audio for text in {language}:", text)
+
+     response = None
+     try:
+         # Make request with simplified params
+         response = requests.post(
+             url="https://api.neets.ai/v1/tts",
+             headers={
+                 "Content-Type": "application/json",
+                 "X-API-Key": neets_api_key
+             },
+             json={
+                 "text": text,
+                 "voice_id": LANGUAGE_CONFIG[language]["voice_id"],
+                 "params": {
+                     "model": "vits"
+                 }
+             }
+         )
+
+         print(f"TTS Response status: {response.status_code}")
+
+         if response.status_code != 200:
+             print(f"TTS Response content: {response.content.decode()}")
+             raise ValueError(f"TTS API returned status code {response.status_code}")
+
+         # Save the audio to a temporary file
+         with tempfile.NamedTemporaryFile(suffix='.mp3', delete=False) as temp_file:
+             temp_file.write(response.content)
+             temp_path = temp_file.name
+
+         # Read the audio file
+         audio_data, sample_rate = sf.read(temp_path)
+
+         # Convert to mono if stereo
+         if len(audio_data.shape) > 1:
+             audio_data = np.mean(audio_data, axis=1)
+
+         # Convert to float32 if needed
+         if audio_data.dtype != np.float32:
+             audio_data = audio_data.astype(np.float32)
+
+         # Clean up the temporary file
+         os.unlink(temp_path)
+
+         return sample_rate, audio_data
+
+     except Exception as e:
+         print(f"Audio generation error: {str(e)}")
+         # Guard against the request itself failing before `response` is assigned
+         if response is not None and hasattr(response, 'content'):
+             print(f"Response content: {response.content[:100]}")  # Print first 100 bytes
+         raise
+
+ def transcribe_audio(audio_path: str, groq_api_key: str, language: str) -> str:
+     """Transcribe audio using Groq's Whisper API"""
+     client = Groq(api_key=groq_api_key)
+
+     try:
+         with open(audio_path, "rb") as audio_file:
+             transcription = client.audio.transcriptions.create(
+                 file=(audio_path, audio_file.read()),
+                 model="whisper-large-v3-turbo",
+                 response_format="json",
+                 language=LANGUAGE_CONFIG[language]["whisper_lang"],
+                 temperature=0.0
+             )
+         return transcription.text
+     except Exception as e:
+         print(f"Transcription error: {str(e)}")
+         raise
+
+ def chat_with_groq(messages: list, groq_api_key: str, language: str) -> str:
+     """Send chat request to Groq API"""
+     client = Groq(api_key=groq_api_key)
+
+     try:
+         response = client.chat.completions.create(
+             model="llama-3.2-90b-vision-preview",
+             messages=[
+                 {"role": "system", "content": LANGUAGE_CONFIG[language]["system_prompt"]},
+                 *messages
+             ],
+             temperature=0.7
+         )
+         return response.choices[0].message.content
+     except Exception as e:
+         print(f"Groq chat error: {str(e)}")
+         raise
+
+ def process_voice_message(audio, history, neets_api_key: str, groq_api_key: str,
+                           english: bool, french: bool, spanish: bool):
+     """Process recorded voice message"""
+     if not all([neets_api_key, groq_api_key]):
+         return "", [{"role": "assistant", "content": "Please provide all API keys."}], None
+
+     # Check language selection
+     selected_languages = []
+     if english: selected_languages.append("English")
+     if french: selected_languages.append("French")
+     if spanish: selected_languages.append("Spanish")
+
+     if not selected_languages:
+         return "", [{"role": "assistant", "content": "Please select at least one language."}], None
+     if len(selected_languages) > 1:
+         return "", [{"role": "assistant", "content": "Please select only one language."}], None
+
+     selected_language = selected_languages[0]
+     print(f"Selected language: {selected_language}")
+
+     try:
+         # Save the recorded audio to a temporary file
+         if isinstance(audio, tuple):
+             sr, data = audio
+             with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
+                 sf.write(temp_file.name, data, sr)
+                 temp_path = temp_file.name
+         else:
+             temp_path = audio
+
+         # Transcribe the audio using Groq
+         transcribed_text = transcribe_audio(temp_path, groq_api_key, selected_language)
+         print(f"Transcribed text: {transcribed_text}")
+
+         # Clean up temporary file if we created one
+         if isinstance(audio, tuple):
+             os.unlink(temp_path)
+
+         # Prepare messages
+         messages = []
+         for msg in (history or []):
+             if isinstance(msg, dict):
+                 messages.append({
+                     "role": msg["role"],
+                     "content": msg["content"]
+                 })
+
+         # Add transcribed message
+         messages.append({
+             "role": "user",
+             "content": transcribed_text
+         })
+
+         # Get chat completion from Groq
+         bot_response = chat_with_groq(messages, groq_api_key, selected_language)
+         print(f"Bot response: {bot_response}")
+
+         # Generate audio for the response
+         try:
+             sample_rate, audio_data = generate_audio(bot_response, neets_api_key, selected_language)
+             audio_output = (sample_rate, audio_data)
+         except Exception as audio_error:
+             print(f"Audio generation error: {audio_error}")
+             audio_output = None
+
+         # Update history with new messages
+         new_history = messages + [
+             {"role": "assistant", "content": bot_response}
+         ]
+
+         return transcribed_text, new_history, audio_output
+
+     except Exception as e:
+         print(f"Processing error: {str(e)}")
+         return "", [{"role": "assistant", "content": f"Error: {str(e)}"}], None
+
+ # Create Gradio interface
+ with gr.Blocks() as demo:
+     # Banner image displayed at the top of the app
+     gr.Image(value="https://i.postimg.cc/L830G7XS/ED-COACH.jpg", elem_id="top-image", show_label=False)
+     gr.Markdown("# Multilingual Tutor")
+
+     with gr.Row():
+         neets_api_key = gr.Textbox(
+             label="Neets API Key",
+             type="password",
+             placeholder="Enter your Neets API key here"
+         )
+         groq_api_key = gr.Textbox(
+             label="Groq API Key",
+             type="password",
+             placeholder="Enter your Groq API key here"
+         )
+
+     with gr.Row():
+         gr.Markdown("### Select Language")
+         english_checkbox = gr.Checkbox(label="English", value=True)
+         french_checkbox = gr.Checkbox(label="French")
+         spanish_checkbox = gr.Checkbox(label="Spanish")
+
+     chatbot = gr.Chatbot(
+         type="messages",
+         show_copy_button=True,
+         layout="bubble"
+     )
+
+     with gr.Row():
+         # Audio recorder for voice input
+         audio_input = gr.Audio(
+             label="Record Message",
+             sources=["microphone"],
+             type="numpy"
+         )
+
+         # Display transcribed text
+         transcribed_msg = gr.Textbox(
+             label="Transcribed Message",
+             placeholder="Your message will appear here after recording...",
+             interactive=False
+         )
+
+         # Audio output for bot response
+         audio_output = gr.Audio(
+             label="Bot Voice",
+             autoplay=True,
+             type="numpy"
+         )
+
+     # Add clear button
+     clear = gr.Button("Clear Conversation")
+
+     # Handle audio submission
+     audio_input.stop_recording(
+         process_voice_message,
+         inputs=[
+             audio_input, chatbot, neets_api_key, groq_api_key,
+             english_checkbox, french_checkbox, spanish_checkbox
+         ],
+         outputs=[transcribed_msg, chatbot, audio_output]
+     )
+
+     # Make checkboxes mutually exclusive: checking one clears the other two,
+     # while an unchecked value returns no-op updates so the handlers do not cascade
+     english_checkbox.change(
+         lambda x: (gr.update(value=False), gr.update(value=False)) if x else (gr.update(), gr.update()),
+         inputs=english_checkbox,
+         outputs=[french_checkbox, spanish_checkbox]
+     )
+     french_checkbox.change(
+         lambda x: (gr.update(value=False), gr.update(value=False)) if x else (gr.update(), gr.update()),
+         inputs=french_checkbox,
+         outputs=[english_checkbox, spanish_checkbox]
+     )
+     spanish_checkbox.change(
+         lambda x: (gr.update(value=False), gr.update(value=False)) if x else (gr.update(), gr.update()),
+         inputs=spanish_checkbox,
+         outputs=[english_checkbox, french_checkbox]
+     )
+
+     # Handle clear button
+     clear.click(
+         lambda: ("", [], None),
+         outputs=[transcribed_msg, chatbot, audio_output]
+     )
+     gr.Markdown(warn)
+
+ if __name__ == "__main__":
+     demo.launch()