# File size: 7,229 Bytes
# 625572e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import os

import streamlit as st
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import google.generativeai as genai
import numpy as np

# Configure the Gemini client. The API key is taken from the GOOGLE_API_KEY
# environment variable rather than hard-coded: a key committed to source
# control is effectively public and must be revoked and rotated.
# NOTE(review): if the variable is unset, None is passed and the SDK is left
# to find credentials on its own — confirm this matches the deployment.
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))

def load_and_preprocess_data(file, numeric_columns=('gaze', 'blink', 'eye_offset')):
    """Read a gaze CSV and coerce its measurement columns to numbers.

    Args:
        file: Path or file-like object accepted by ``pandas.read_csv``
            (the app passes a Streamlit upload).
        numeric_columns: Columns converted with ``pd.to_numeric``; values
            that cannot be parsed become NaN.

    Returns:
        The DataFrame with every row removed in which any of
        ``numeric_columns`` is missing or non-numeric. Missing values in
        other columns do not cause a row to be dropped.

    Raises:
        KeyError: if one of ``numeric_columns`` is absent from the CSV.
    """
    data = pd.read_csv(file)
    numeric_columns = list(numeric_columns)
    for col in numeric_columns:
        data[col] = pd.to_numeric(data[col], errors='coerce')
    # Only the measurement columns decide whether a row is usable; a blank
    # cell in some unrelated column must not discard a valid sample.
    return data.dropna(subset=numeric_columns)

def calculate_gaze_stats(data, feature):
    """Summarise a single column of ``data``.

    Args:
        data: DataFrame that contains the ``feature`` column.
        feature: Name of the column to summarise.

    Returns:
        Dict with keys ``'mean'``, ``'median'`` and ``'std'`` (pandas'
        default sample standard deviation), in that order.
    """
    column = data[feature]
    summary = {}
    summary['mean'] = column.mean()
    summary['median'] = column.median()
    summary['std'] = column.std()
    return summary

def visualize_gaze_distribution(data_dict, features, selected_movies):
    """Render per-feature density plots for the selected movies in Streamlit.

    One figure is drawn per feature, overlaying a filled KDE curve for each
    movie in ``selected_movies``. When exactly two movies are selected, a
    correlation heatmap of the chosen features across both movies is drawn
    as well.

    Args:
        data_dict: Mapping from movie label to its DataFrame.
        features: Column names to plot.
        selected_movies: Keys of ``data_dict`` to include.
    """
    for feature in features:
        fig, ax = plt.subplots(figsize=(10, 6))
        for movie in selected_movies:
            # ``fill`` replaces the deprecated ``shade`` argument.
            sns.kdeplot(data_dict[movie][feature], ax=ax, label=movie, fill=True)
        ax.set_title(f"Distribution of '{feature}'")
        ax.set_xlabel("Value")
        ax.set_ylabel("Density")
        ax.legend()
        st.pyplot(fig)
        # Streamlit reruns the whole script on each interaction; close the
        # figure so pyplot does not keep accumulating open figures.
        plt.close(fig)

    if len(selected_movies) == 2:
        fig, ax = plt.subplots(figsize=(12, 10))
        # Columns become (movie, feature) pairs. pd.concat(axis=1) aligns the
        # two movies' rows on their index labels.
        # NOTE(review): confirm pairing rows by index label is intended.
        combined = pd.concat(
            [data_dict[movie][features] for movie in selected_movies],
            axis=1,
            keys=selected_movies,
        )
        correlation = combined.corr()
        sns.heatmap(correlation, annot=True, cmap='coolwarm', ax=ax)
        ax.set_title('Correlation Heatmap of Selected Features Between Movies')
        st.pyplot(fig)
        plt.close(fig)

def visualize_single_movie(data, features):
    """Plot each feature against ``image_seq`` as stacked line charts.

    One subplot is drawn per feature, sharing the x axis, and the figure is
    rendered with Streamlit.

    Args:
        data: DataFrame with an ``image_seq`` column plus the feature columns.
        features: Column names to plot. Nothing is drawn if empty.
    """
    if not features:
        # plt.subplots(0, 1) would raise; there is nothing to show anyway.
        return
    # squeeze=False always yields a 2-D array of Axes, so indexing also works
    # when only one feature is plotted. Without it, a single Axes object is
    # returned and cannot be subscripted.
    fig, axes = plt.subplots(
        len(features), 1, figsize=(10, 5 * len(features)), sharex=True, squeeze=False
    )
    for ax, feature in zip(axes[:, 0], features):
        sns.lineplot(x='image_seq', y=feature, data=data, ax=ax)
        ax.set_title(f"{feature.capitalize()} over Image Sequence")
        ax.set_xlabel("Image Sequence")
        ax.set_ylabel("Value")
    fig.tight_layout()
    st.pyplot(fig)
    # Release the figure; Streamlit reruns would otherwise accumulate them.
    plt.close(fig)

def format_gaze_prompt(features, data_dict, selected_movies):
    """Build the analysis prompt that is sent to the language model.

    The fixed template asks for a structured comparison of the features and
    embeds mean / median / standard deviation (two decimals) for every
    selected movie and feature.

    Args:
        features: Column names to summarise.
        data_dict: Mapping from movie label to its DataFrame.
        selected_movies: Keys of ``data_dict`` to include, in order.

    Returns:
        The fully formatted prompt string.
    """
    template = """
    You are an AI assistant specializing in movie gaze analysis. You have access to gaze data for the following movies: {MOVIES}, focusing on these features: {FEATURES}.

    {STATS}

    Based on this data:

    1. Compare the overall patterns of the specified features across the selected movies. Are there any notable differences or similarities?
    2. Analyze the distribution of each feature in the selected movies. Are there any outliers or unexpected patterns?
    3. Discuss any significant differences in how these features change over the image sequence. What might these differences suggest about the movies' visual content or style?
    4. Consider the variability of each feature in each selected movie. Do some movies have more consistent patterns, or do they fluctuate more?
    5. Based on this gaze data, hypothesize about potential scenes or visual elements in each selected movie that might contribute to the observed patterns.
    6. How might the differences in gaze, blink, and eye offset patterns between these selected movies affect the viewer's visual experience?

    Provide a detailed analysis addressing these points, using specific data references where relevant. Your analysis should offer insights into how these gaze-related features are utilized in each selected movie and what this reveals about their visual content and potential audience engagement.
    """

    # Collect the per-movie stats section piece by piece, then join once.
    stat_chunks = []
    for movie in selected_movies:
        stat_chunks.append(f"\n{movie}:\n")
        for feature in features:
            summary = calculate_gaze_stats(data_dict[movie], feature)
            stat_chunks.append(
                f"{feature.capitalize()} - Mean: {summary['mean']:.2f}, "
                f"Median: {summary['median']:.2f}, "
                f"Standard Deviation: {summary['std']:.2f}\n"
            )

    return template.format(
        MOVIES=", ".join(selected_movies),
        FEATURES=", ".join(features),
        STATS="".join(stat_chunks),
    )

def generate_response(prompt, data_dict, features, selected_movies):
    """Ask Gemini to answer ``prompt`` in the context of the gaze statistics.

    The user's query is appended to the prompt built by
    ``format_gaze_prompt`` and sent to the ``gemini-pro`` model.

    Args:
        prompt: Free-text question typed by the user.
        data_dict: Mapping from movie label to its DataFrame.
        features: Feature columns to include in the statistics.
        selected_movies: Keys of ``data_dict`` to include.

    Returns:
        The text of the first candidate, with all of its text parts joined.
        On failure it returns a human-readable "Error: ..." string instead of
        raising, so the chat UI can display it.
    """
    model = genai.GenerativeModel('gemini-pro')

    analysis_prompt = format_gaze_prompt(features, data_dict, selected_movies)
    full_prompt = analysis_prompt + "\n\nUser query: " + prompt

    try:
        response = model.generate_content(full_prompt)
    except Exception as err:  # UI boundary: report any API failure in the chat
        return f"Error: The Gemini request failed ({err})."

    candidates = getattr(response, 'candidates', None)
    if candidates:
        parts = getattr(candidates[0].content, 'parts', None) or []
        # A response may be split over several parts; returning only the
        # first one would silently truncate the answer.
        texts = [part.text for part in parts if hasattr(part, 'text')]
        if texts:
            return "".join(texts)

    return "Error: Unable to extract text from the response. Please check the API response structure."

def main():
    """Run the Streamlit app: upload gaze CSVs, chat about them, plot results.

    On each script run:
      1. Ask how many movies to compare and show one CSV uploader per movie.
      2. Once every upload has been parsed successfully, show column info and
         the chat history, and let the user pick movies (all by default).
      3. On a new chat message, select the features it mentions (all
         features if none are mentioned), draw charts, query Gemini, and
         append both turns to ``st.session_state.messages``.
    """
    st.title("Multi-Movie Gaze Analysis Chat Interface")

    num_movies = st.number_input("How many movies would you like to compare?", min_value=1, max_value=10, value=1)

    data_dict = {}
    for i in range(num_movies):
        uploaded_file = st.file_uploader(f"Choose CSV file for Movie {i+1}", type="csv", key=f"movie_{i+1}")
        if uploaded_file is None:
            continue
        try:
            data = load_and_preprocess_data(uploaded_file)
        except (KeyError, ValueError) as err:
            # A missing column raises KeyError; malformed or empty CSVs raise
            # pandas parser errors, which subclass ValueError. Report them to
            # the user instead of crashing the app with a traceback.
            st.error(f"Could not load the CSV for Movie {i+1}: {err}")
            continue
        data_dict[f"Movie {i+1}"] = data

    # The chat only starts once every requested movie has a parsed file.
    if len(data_dict) == num_movies:
        st.success("All files uploaded successfully. You can now start chatting!")

        st.subheader("Data Information")
        for movie, data in data_dict.items():
            st.write(f"{movie} Columns:", data.columns.tolist())

        features = ['gaze', 'blink', 'eye_offset']
        st.write("Available Features:", features)

        # Chat history is kept in session_state across Streamlit reruns.
        if "messages" not in st.session_state:
            st.session_state.messages = []

        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        selected_movies = st.multiselect("Select movies to compare:", list(data_dict.keys()))
        if not selected_movies:
            # An empty selection means "compare everything".
            selected_movies = list(data_dict.keys())

        if prompt := st.chat_input("What would you like to know about the movies' gaze data?"):
            st.chat_message("user").markdown(prompt)
            st.session_state.messages.append({"role": "user", "content": prompt})

            lowered = prompt.lower()
            # Accept both the column name ("eye_offset") and its natural
            # spelling ("eye offset").
            selected_features = [
                feature for feature in features
                if feature in lowered or feature.replace('_', ' ') in lowered
            ]
            if not selected_features:
                selected_features = features

            # Charts and the model's answer share one assistant bubble, so no
            # empty bubble is rendered when no chart applies.
            with st.chat_message("assistant"):
                if len(selected_movies) == 1:
                    visualize_single_movie(data_dict[selected_movies[0]], selected_features)
                elif any(keyword in lowered for keyword in ["graph", "compare", "visualize", "show"]):
                    visualize_gaze_distribution(data_dict, selected_features, selected_movies)

                response = generate_response(prompt, data_dict, selected_features, selected_movies)
                st.markdown(response)

            st.session_state.messages.append({"role": "assistant", "content": response})

        if st.checkbox("Show raw data"):
            for movie in selected_movies:
                st.subheader(f"{movie} Data")
                st.write(data_dict[movie])

if __name__ == "__main__":
    main()