ychenhq's picture
Initial test
0bab51a verified
raw
history blame
12.6 kB
import os
import sys
import gradio as gr
import math
import matplotlib.pyplot as plt
import requests
import fileinput
import firebase_admin
from firebase_admin import credentials
from firebase_admin import firestore
import gradio as gr
import json
import math
import requests
basedir = "/content/drive/MyDrive/FYP/Code/VideoCrafter"
# if you change the variables here, remember to change the "name" in .sh file
vidOut = "results/10videos"
uvqOut = "results/modified_prompts_eval"
evalOut = "evaluation_results"
num_of_vid = 3
vid_length = 2
uvq_threshold = 3.8
fps = 24
# Generate the scores in csv files
def genScore():
for i in range(1, num_of_vid+1):
fileindex = f"{i:04d}"
os.system(
f'python3 ./uvq/uvq_main.py --input_files="{fileindex},2,{basedir}/{vidOut}/{fileindex}.mp4" --output_dir {uvqOut} --model_dir ./uvq/models'
)
def getScore(filename):
# MOS_score defines the output of the uvq score
lines = str(filename).split('\n')
last_line = lines[-1]
MOS_score = last_line.split(',')[-1]
MOS_score = MOS_score[:-2]
return MOS_score
# MOS_score defines the Mean Opinion Score of prediction, if the video's MOS exceeds the threshold then we directly use this video
def chooseBestVideo():
MOS_score_high = 0
preferred_output = ""
chosen_idx = 0
for i in range(1, num_of_vid+1):
'''We loop thru this current processed video'''
filedir = f"{i:04d}"
filename = f"{i:04d}_uvq.csv"
with open(os.path.join(basedir, uvqOut, filedir, filename), 'r') as file:
MOS = file.read().strip()
MOS_score = getScore(MOS)
print("Video Index:", f"{i:04d}", "Score:", MOS_score)
# if the MOS_score is higher than the previous video, we choose this video as our preferred video output
if float(MOS_score) > float(MOS_score_high) or float(MOS_score) > uvq_threshold:
MOS_score_high = MOS_score
preferred_output = filename
chosen_idx = i
if float(MOS_score) > uvq_threshold:
break
return chosen_idx
# print(MOS_score_high)
# print(preferred_output)
def extract_scores_from_json(json_path):
with open(json_path) as file:
data = json.load(file)
for key, value in data.items():
if isinstance(value, list) and len(value) > 1 and isinstance(value[0], float):
motion_score = value[0]
return motion_score
def VBench_eval(vid_filename):
# vid_filename: video filename without .mp4
os.system(
f'python VBench/evaluate.py --dimension "motion_smoothness" --videos_path {os.path.join(basedir, vidOut, vid_filename)}.mp4 --custom_input --output_filename {vid_filename}'
)
eval_file_path = os.path.join(
basedir, evalOut, f"{vid_filename}_eval_results.json")
motion_score = extract_scores_from_json(eval_file_path)
return motion_score
def interpolation(chosen_idx, fps):
vid_filename = f"{chosen_idx:04d}.mp4"
os.chdir("/content/drive/MyDrive/FYP/Code/VideoCrafter/ECCV2022-RIFE")
os.system(
f'python3 inference_video.py --exp=2 --video={os.path.join(basedir, vidOut, vid_filename)} --fps {fps}'
)
os.chdir("/content/drive/MyDrive/FYP/Code/VideoCrafter")
out_name = f"{chosen_idx:04d}_4X_{fps}fps.mp4"
return out_name
# call the GPT API here
def call_gpt_api(prompt, isSentence=False):
api_key = "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"
response = requests.post(
'https://api.openai.com/v1/chat/completions',
headers={
'Content-Type': 'application/json',
'Authorization': f'Bearer {api_key}'
},
json={
'messages': [{'role': 'system', 'content': 'You are a helpful assistant.'}, {'role': 'user', 'content': prompt}],
'model': 'gpt-3.5-turbo',
# 'prompt': prompt,
'temperature': 0.4,
'max_tokens': 200
})
response_json = response.json()
choices = response_json['choices']
contents = [choice['message']['content'] for choice in choices]
contents = [
sentence for sublist in contents for sentence in sublist.split('\n')]
# Remove the leading number and dot from each sentence
sentences = [content.lstrip('1234567890.- ') for content in contents]
if len(sentences) > 2 and isSentence:
sentences = sentences[1:]
return sentences
# Initialize Firebase Admin SDK
cred = credentials.Certificate(
f"{basedir}/final-year-project-443dd-df6f48af0796.json")
firebase_admin.initialize_app(cred)
# Initialize Firestore client
db = firestore.client()
def retrieve_user_feedback():
# Retrieve user feedback from Firestore
feedback_collection = db.collection("user_feedbacks")
feedback_docs = feedback_collection.get()
feedback_text = []
experience = []
for doc in feedback_docs:
data = doc.to_dict()
feedback_text.append(data.get('feedback_text', None))
experience.append(data.get('experience', None))
return feedback_text, experience
feedback_text, experience = retrieve_user_feedback()
# print("Feedback Text:", feedback_text)
# print("Experience:", experience)
def store_user_feedback(feedback_text, experience):
# Get a reference to the Firestore collection
feedback_collection = db.collection("user_feedbacks")
# Create a new document with feedback_text and experience fields
feedback_collection.add({
'feedback_text': feedback_text,
'experience': experience
})
return
t2v_examples = [
['A tiger walks in the forest, photorealistic, 4k, high definition'],
['an elephant is walking under the sea, 4K, high definition'],
['an astronaut riding a horse in outer space'],
['a monkey is playing a piano'],
['A fire is burning on a candle'],
['a horse is drinking in the river'],
['Robot dancing in times square'],
]
def generate_output(input_text, output_video_1, fps, examples):
def generate_output_fn(input_text, output_video_1, fps, examples):
if input_text == "":
return input_text, output_video_1, examples
output = call_gpt_api(
prompt=f"Generate 2 similar prompts and add some reasonable words to the given prompt and not change the meaning, each within 30 words: {input_text}", isSentence=True)
output.append(input_text)
with open(f"{basedir}/prompts/test_prompts.txt", 'w') as file:
for i, sentence in enumerate(output):
if i < len(output) - 1:
file.write(sentence + '\n')
else:
file.write(sentence)
os.system(
f'sh {os.path.join(basedir, "scripts", "run_text2video.sh")}')
# Connect the video output and return the video corresponding link
genScore()
chosen_idx = chooseBestVideo()
chosen_vid_path = interpolation(chosen_idx, fps)
chosen_vid_path = f"{basedir}/{vidOut}/{chosen_vid_path}"
# chosen_vid_path = "/content/drive/MyDrive/FYP/Code/VideoCrafter/results/cat/0002_4X_16fps.mp4"
output_video_1 = gr.Video(
value=chosen_vid_path, show_download_button=True)
examples_list = call_gpt_api(
prompt=f"Generate 5 similar prompts that makes a storyline coming after the given input, each within 10 words: {input_text}")
examples = []
for prompt in examples_list:
examples.append([prompt])
input_text = ""
return input_text, output_video_1, examples
return generate_output_fn(input_text, output_video_1, fps, examples)
def t2v_demo(result_dir='./tmp/'):
with gr.Blocks() as videocrafter_iface:
gr.Markdown("<div align='center'> <h2> VideoCraftXtend: AI-Enhanced Text-to-Video Generation with Extended Length and Enhanced Motion Smoothness </span> </h2> </div>")
# Initialize values for video length and fps
video_len_value = 5.0
def update_fps(video_len, fps):
fps_value = 80 / video_len
return f"<div justify-content: 'center'; text-align='center'> <h6> FPS (frames per second) : {int(fps_value)} </span> </h6> </div>"
def load_example(example_id):
return example_id[0]
def update_feedback(value, text):
labels = ['Positive', 'Neutral', 'Negative']
colors = ['#66c2a5', '#fc8d62', '#8da0cb']
if value != '':
store_user_feedback(value, text)
user_satisfaction.append(value)
value = ''
if text != '':
user_feedback.append(text)
text = ''
user_feedback, user_satisfaction = retrieve_user_feedback()
sizes = [user_satisfaction.count('Positive'), user_satisfaction.count(
'Neutral'), user_satisfaction.count('Negative')]
plt.pie(sizes, labels=labels, autopct='%1.1f%%',
startangle=140, colors=colors)
plt.axis('equal')
return plt
with gr.Tab(label="Text2Video"):
with gr.Column():
with gr.Row():
with gr.Column():
input_text = gr.Text(
placeholder=t2v_examples[2], label='Please input your prompt here.')
with gr.Row():
examples = gr.Dataset(samples=t2v_examples, components=[
input_text], label='Sample prompts that can be used to form a storyline.')
with gr.Column():
gr.Markdown(
"<div align='center'> <h4> Modify video length and the corresponding fps will be shown on the right. </span> </h4> </div>")
with gr.Row():
video_len = gr.Slider(minimum=4.0, maximum=10.0, step=1, label='Video Length',
value=video_len_value, elem_id="video_len", interactive=True)
fps = gr.Markdown(
elem_id="fps", value=f"<div> <h6> FPS (frames per second) : 16</span> </h6> </div>")
send_btn = gr.Button("Send")
with gr.Column():
with gr.Tab(label='Result'):
with gr.Row():
output_video_1 = gr.Video(
value="/content/drive/MyDrive/FYP/Code/VideoCrafter/results/10videos/0009.mp4", show_download_button=True)
video_len.change(update_fps, inputs=[video_len, fps], outputs=fps)
# fps.change(update_video_len_slider, inputs = fps, outputs = video_len)
examples.click(load_example, inputs=[
examples], outputs=[input_text])
send_btn.click(
fn=generate_output,
inputs=[input_text, output_video_1, fps, examples],
outputs=[input_text, output_video_1, examples],
)
with gr.Tab(label="Feedback"):
with gr.Column():
with gr.Column():
with gr.Row():
feedback_value = gr.Radio(
['Positive', 'Neutral', 'Negative'], label="How is your experience?")
feedback_text = gr.Textbox(
placeholder="Enter feedback here", label="Feedback Text")
with gr.Row():
cancel_btn = gr.Button("Clear")
submit_btn = gr.Button("Submit")
with gr.Row():
pie_chart = gr.Plot(value=update_feedback(
'', ''), label="Feedback Pie Chart")
with gr.Column():
gr.Markdown(
"<div align='center'> <h4> Feedbacks from users: </span> </h4> </div>")
feedback_text_display = [gr.Markdown(
feedback, label="User Feedback") for feedback in retrieve_user_feedback()[0]]
submit_btn.click(fn=update_feedback, inputs=[
feedback_value, feedback_text], outputs=[pie_chart])
return videocrafter_iface
if __name__ == "__main__":
result_dir = os.path.join('./', 'results')
t2v_iface = t2v_demo(result_dir)
t2v_iface.queue(max_size=10)
t2v_iface.launch(debug=True)