Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -14,26 +14,31 @@ gpt2_model = AutoModelForCausalLM.from_pretrained("gpt2-medium")
|
|
14 |
|
15 |
emotion_classifier = pipeline("text-classification", model="bhadresh-savani/distilbert-base-uncased-emotion", return_all_scores=True)
|
16 |
|
17 |
-
# Function to generate a comforting story using GPT-2
|
18 |
def generate_story(theme):
|
19 |
# A detailed prompt for generating a comforting story about the selected theme
|
20 |
story_prompt = f"Write a comforting, detailed, and heartwarming story about {theme}. The story should include a character who faces a tough challenge, finds hope, and ultimately overcomes the situation with a positive resolution."
|
21 |
|
22 |
-
# Generate story using GPT-2
|
23 |
input_ids = gpt2_tokenizer.encode(story_prompt, return_tensors='pt')
|
24 |
|
25 |
story_ids = gpt2_model.generate(
|
26 |
input_ids,
|
27 |
-
max_length=
|
28 |
-
temperature=0.
|
29 |
-
top_p=0.9,
|
30 |
-
|
|
|
31 |
num_return_sequences=1
|
32 |
)
|
33 |
|
34 |
# Decode the generated text
|
35 |
story = gpt2_tokenizer.decode(story_ids[0], skip_special_tokens=True)
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
def generate_response(user_input):
|
39 |
# Limit user input length to prevent overflow issues
|
|
|
14 |
|
15 |
# Emotion classifier for user input; return_all_scores=True returns a score for
# every emotion label (sadness, joy, love, anger, fear, surprise) instead of
# only the top prediction.
# NOTE(review): return_all_scores is deprecated in recent transformers releases
# in favor of top_k=None — confirm the pinned transformers version before changing.
emotion_classifier = pipeline("text-classification", model="bhadresh-savani/distilbert-base-uncased-emotion", return_all_scores=True)
|
16 |
|
|
|
17 |
def generate_story(theme):
    """Generate a comforting, heartwarming short story about *theme* with GPT-2.

    Parameters
    ----------
    theme : str
        Topic the story should revolve around; interpolated into the prompt.

    Returns
    -------
    str
        The generated story with the seed prompt stripped from the output.
    """
    # A detailed prompt steering GPT-2 toward a hopeful arc with a positive resolution.
    story_prompt = f"Write a comforting, detailed, and heartwarming story about {theme}. The story should include a character who faces a tough challenge, finds hope, and ultimately overcomes the situation with a positive resolution."

    # Generate story using GPT-2 with adjusted parameters
    input_ids = gpt2_tokenizer.encode(story_prompt, return_tensors='pt')

    story_ids = gpt2_model.generate(
        input_ids,
        max_length=450,          # Generate slightly shorter but focused stories
        do_sample=True,          # FIX: without do_sample=True, transformers uses greedy
                                 # decoding and silently ignores temperature/top_p/top_k.
        temperature=0.7,         # Balanced creativity without too much randomness
        top_p=0.9,               # Encourage diversity in output (nucleus sampling)
        top_k=50,                # Limit to more probable words
        repetition_penalty=1.2,  # Prevent repetitive patterns
        num_return_sequences=1,
        pad_token_id=gpt2_tokenizer.eos_token_id,  # GPT-2 has no pad token; use EOS
                                                   # to silence the generation warning.
    )

    # Decode the generated text
    story = gpt2_tokenizer.decode(story_ids[0], skip_special_tokens=True)

    # Clean up the generated story by removing the initial prompt, which GPT-2
    # echoes back at the start of its output.
    cleaned_response = story.replace(story_prompt, "").strip()

    return cleaned_response
|
41 |
+
|
42 |
|
43 |
def generate_response(user_input):
|
44 |
# Limit user input length to prevent overflow issues
|