File size: 3,698 Bytes
d903275
0a2c880
2ff483e
727c299
2ff483e
d903275
 
 
 
2ff483e
7b6b4a2
727c299
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b6b4a2
727c299
 
7b6b4a2
727c299
 
 
 
97b7fcf
2aece5c
 
727c299
 
2aece5c
5bd8a1a
727c299
2aece5c
7b6b4a2
0a2c880
 
f31f247
7ee935b
 
f31f247
0a2c880
3b25749
727c299
dc5958a
7b6b4a2
 
2aece5c
7b6b4a2
0a2c880
 
 
 
2aece5c
0a2c880
 
 
 
 
 
 
 
b99a30a
f31f247
0a2c880
8139f95
 
2aece5c
 
 
5bd8a1a
8139f95
c34dd87
 
7ee935b
0f3f1e9
7ee935b
c34dd87
 
6eb4256
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import requests

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hugging Face checkpoint used for both the tokenizer and the causal-LM weights.
model_name = "Writer/palmyra-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to(device)

def _tmdb_get(path, api_key, timeout, **params):
    """GET ``https://api.themoviedb.org/3/<path>`` and return the decoded JSON body.

    Raises ``requests.RequestException`` on network/HTTP errors and
    ``ValueError`` when the body is not valid JSON.
    """
    response = requests.get(
        f"https://api.themoviedb.org/3/{path}",
        params={"api_key": api_key, "language": "en-US", **params},
        timeout=timeout,
    )
    # Surface 4xx/5xx (e.g. an invalid API key) as errors instead of letting
    # an error payload look like "no results".
    response.raise_for_status()
    return response.json()


def _format_movie_details(movie_id, details):
    """Render a TMDb movie-details payload as the summary text shown in chat."""
    title = details.get("title") or "Unknown Title"
    # release_date may be missing, empty, or null; slicing the fallback string
    # itself (as before) produced "Unkn".
    year = (details.get("release_date") or "")[:4] or "Unknown Year"
    genre = ", ".join(g["name"] for g in details.get("genres") or []) or "Unknown Genre"
    tmdb_link = f"https://www.themoviedb.org/movie/{movie_id}"
    return f"Title: {title}, Year: {year}, Genre: {genre}\nFind more info here: {tmdb_link}"


def get_movie_info(movie_title, api_key=None, timeout=10.0):
    """Look up *movie_title* on TMDb and return a short human-readable summary.

    Args:
        movie_title: Free-text search query; the first search hit is used.
        api_key: TMDb v3 API key. Defaults to the ``TMDB_API_KEY`` environment
            variable, then to the key previously hard-coded here.
        timeout: Per-request timeout in seconds, so a stalled TMDb call cannot
            hang the chat handler indefinitely.

    Returns:
        Always a ``str``: the summary on success, ``"Movie not found"`` when
        the search returns no results, or ``"Error: ..."`` when the lookup fails.
    """
    if api_key is None:
        # NOTE(review): this key was committed to source and should be rotated;
        # the fallback only keeps existing deployments working.
        api_key = os.environ.get("TMDB_API_KEY", "20e959f0f28e6b3e3de49c50f358538a")

    try:
        search_data = _tmdb_get(
            "search/movie", api_key, timeout, query=movie_title, page=1
        )
        results = search_data.get("results")
        if not results:
            return "Movie not found"

        movie_id = results[0]["id"]
        details = _tmdb_get(f"movie/{movie_id}", api_key, timeout)
        return _format_movie_details(movie_id, details)

    # Network/HTTP failures, undecodable JSON, or an unexpectedly shaped
    # payload all degrade to an error string instead of crashing the chat turn.
    except (
        requests.RequestException,
        ValueError,
        KeyError,
        TypeError,
        AttributeError,
    ) as exc:
        message = str(exc)
        # requests error messages can include the request URL, which carries
        # the api_key query parameter; never echo the key back to the UI.
        if api_key:
            message = message.replace(api_key, "***")
        return f"Error: {message}"

def generate_response(prompt):
    """Answer a movie query by grounding palmyra-small on TMDb data.

    The user's text is used verbatim as the TMDb search title. The lookup
    result is placed in the prompt as context and is also shown to the user
    next to the model's continuation.

    Args:
        prompt: The raw chat message typed by the user.

    Returns:
        A display string containing the TMDb summary followed by only the text
        the model generated (the prompt itself is not echoed).
    """
    # Call the get_movie_info function to enrich the response
    movie_info = get_movie_info(prompt)

    # The retrieved facts go *before* the "Writer AI:" cue so the model answers
    # as the assistant rather than continuing the movie-info line.
    input_text_template = (
        "Hi! I am a gen AI bot powered by the Writer/palmyra-small model. "
        "I am here to give helpful, detailed, and polite answers to your movie inquiries.\n \n"
        f"USER: {prompt}\n \n"
        f"Movie Info: {movie_info}\n \n"
        "Writer AI:"
    )

    model_inputs = tokenizer(input_text_template, return_tensors="pt").to(device)
    prompt_length = model_inputs["input_ids"].shape[-1]

    gen_conf = {
        "top_k": 20,
        # max_new_tokens bounds only the generated tokens. The previous
        # max_length=20 also counted the prompt, which is already longer than
        # 20 tokens, so effectively only one token was ever generated.
        "max_new_tokens": 20,  # kept short: the model loops and loses coherence on longer outputs
        "temperature": 0.6,
        "do_sample": True,
        "eos_token_id": tokenizer.eos_token_id,
        # Explicit pad token for tokenizers that define none (avoids the
        # open-end generation warning).
        "pad_token_id": tokenizer.eos_token_id,
    }

    output = model.generate(**model_inputs, **gen_conf)

    # For a causal LM, generate() returns prompt + continuation; decode only
    # the newly generated tokens.
    generated_text = tokenizer.decode(
        output[0, prompt_length:], skip_special_tokens=True
    ).strip()

    return f"Movie Info:\n{movie_info}\n\n Writer AI Generated Response:\n{generated_text}\n"

# Define chat function for gr.ChatInterface
def chat_function(message, history):
    """Gradio ``ChatInterface`` callback: return the bot's reply to *message*.

    ``history`` is accepted because ``gr.ChatInterface`` passes it, but it is
    intentionally neither used nor mutated. ``ChatInterface`` records the
    ``[message, response]`` turn itself from the returned string, so appending
    here as well would record the turn twice.
    """
    return generate_response(message)

# Create Gradio Chat Interface
chat_interface = gr.ChatInterface(
    chat_function,
    textbox=gr.Textbox(
        placeholder="Type in any movie title, e.g., Oppenheimer, Barbie, Poor Things",
        container=False,
        scale=7,
    ),
    title="Palmyra-Small - Movie Chatbot",
    # The bot reports only the release *year* (see get_movie_info), so the
    # description says so; also fixes the "TMdb" typo.
    description=(
        "This chatbot is powered by the Writer/Palmyra-small model and the TMDb API. "
        "Type in any movie title and the chatbot will respond with the title, "
        "release year, genre, and a link for more information."
    ),
    theme="soft",
    examples=["Oppenheimer", "Barbie", "Poor Things"],
)
# share=True asks Gradio to also create a publicly shareable link.
chat_interface.launch(share=True)