import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import requests
# Hugging Face Hub checkpoint used for generation (small causal LM).
model_name = "Writer/palmyra-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Prefer GPU when available; model weights and inputs are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
def get_movie_info(movie_title):
    """Look up *movie_title* on TMDb and return a one-line summary.

    Always returns a single string — on success a "Title/Year/Genre/link"
    summary, otherwise a "Movie not found" or "Error: ..." message — so the
    caller can embed the result directly into a prompt.
    """
    # NOTE(review): API key is hard-coded; move it to an environment
    # variable before publishing this file.
    api_key = "20e959f0f28e6b3e3de49c50f358538a"
    search_url = "https://api.themoviedb.org/3/search/movie"
    params = {
        "api_key": api_key,
        "query": movie_title,
        "language": "en-US",
        "page": 1,
    }
    try:
        # timeout added so a stalled TMDb request cannot hang the chatbot.
        search_response = requests.get(search_url, params=params, timeout=10)
        search_data = search_response.json()

        if not search_data.get("results"):
            # BUG FIX: the original returned a ("Movie not found", "") tuple
            # here while the success path returned a plain string; the caller
            # interpolates the result into an f-string, so a tuple leaked its
            # repr into the prompt. Always return a string.
            return "Movie not found"

        movie_id = search_data["results"][0]["id"]

        # Fetch detailed information using the movie ID.
        details_url = f"https://api.themoviedb.org/3/movie/{movie_id}"
        details_params = {
            "api_key": api_key,
            "language": "en-US",
        }
        details_response = requests.get(details_url, params=details_params, timeout=10)
        details_data = details_response.json()

        title = details_data.get("title", "Unknown Title")
        # BUG FIX: the original sliced the fallback too, so a missing
        # release_date produced "Unkn" ("Unknown Year"[:4]). Only take the
        # 4-char year prefix when a release date is actually present.
        release_date = details_data.get("release_date") or ""
        year = release_date[:4] if release_date else "Unknown Year"
        genre = ", ".join(g["name"] for g in details_data.get("genres", []))
        tmdb_link = f"https://www.themoviedb.org/movie/{movie_id}"

        return f"Title: {title}, Year: {year}, Genre: {genre}\nFind more info here: {tmdb_link}"
    except Exception as e:
        # Deliberate best-effort boundary: any network/JSON failure becomes
        # a message string (BUG FIX: original returned a tuple here too).
        return f"Error: {e}"
def generate_response(prompt):
    """Build a movie-assistant prompt for *prompt*, enrich it with TMDb
    metadata, and return the movie info plus the model's generated text.

    Returns a display string combining the raw TMDb summary and the
    decoded model output.
    """
    input_text_template = (
        "Hi! I am a gen AI bot powered by the Writer/palmyra-small model. "
        "I am here to give helpful, detailed, and polite answers to your movie inquiries.\n \n"
        f"USER: {prompt}\n \n"
        "Writer AI:"
    )
    # Enrich the prompt with real metadata so the small model has grounding.
    movie_info = get_movie_info(prompt)
    input_text_template += f" Movie Info: {movie_info}"

    model_inputs = tokenizer(input_text_template, return_tensors="pt").to(device)
    gen_conf = {
        "top_k": 20,
        # BUG FIX: the original used "max_length": 20, which counts the
        # *input* tokens as well; the prompt alone exceeds 20 tokens, so
        # generation stopped immediately. "max_new_tokens" limits only the
        # continuation, which is what the original comment intended
        # ("shortened to limit writer predictions").
        "max_new_tokens": 20,
        "temperature": 0.6,
        "do_sample": True,
        "eos_token_id": tokenizer.eos_token_id,
    }
    output = model.generate(**model_inputs, **gen_conf)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return f"Movie Info:\n{movie_info}\n\n Writer AI Generated Response:\n{generated_text}\n"
# Chat callback for gr.ChatInterface.
def chat_function(message, history):
    """Return the bot reply for *message*.

    BUG FIX: the original also did ``history.append([message, response])``,
    but gr.ChatInterface manages the conversation history itself; mutating
    the passed-in list duplicated turns. The callback only needs to return
    the reply string (``history`` stays in the signature because the
    interface passes it positionally).
    """
    return generate_response(message)
# Create the Gradio chat UI and launch it.
# (Also removes a stray "|" artifact that terminated the original launch line.)
chat_interface = gr.ChatInterface(
    chat_function,
    textbox=gr.Textbox(
        placeholder="Type in any movie title, e.g., Oppenheimer, Barbie, Poor Things ",
        container=False,
        scale=7,
    ),
    title="Palmyra-Small - Movie Chatbot ",
    description="This chatbot is powered by the Writer/Palmyra-small model and TMdb API. Type in any movie title and the chatbot will respond with the title, release date, genre, and link for more information. ",
    theme="soft",
    examples=["Oppenheimer", "Barbie", "Poor Things"],
)
# share=True exposes a temporary public URL (useful on hosted notebooks).
chat_interface.launch(share=True)