# Source captured from a Hugging Face Spaces page (Space status at capture time: "Sleeping").
import os

import google.generativeai as genai
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import streamlit as st
# Configure the Gemini client.
# The API key is read from the GOOGLE_API_KEY environment variable (e.g. set as a
# deployment secret) rather than hard-coded: a key committed to source is exposed
# to anyone who can read the repository. Any previously committed key should be
# revoked.
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
def load_and_preprocess_data(file, numeric_columns=("gaze", "blink", "eye_offset")):
    """Read a gaze CSV and coerce the feature columns to numbers.

    Args:
        file: Path or file-like object accepted by ``pandas.read_csv``
            (e.g. a Streamlit upload).
        numeric_columns: Columns to convert to numeric. Defaults to the three
            gaze features used throughout the app.

    Returns:
        A DataFrame in which ``numeric_columns`` are numeric. Rows whose value
        in any of those columns is missing or unparseable are removed.

    Raises:
        ValueError: If any of ``numeric_columns`` is absent from the CSV.
    """
    data = pd.read_csv(file)
    numeric_columns = list(numeric_columns)

    missing = [col for col in numeric_columns if col not in data.columns]
    if missing:
        raise ValueError(f"CSV is missing required column(s): {', '.join(missing)}")

    for col in numeric_columns:
        # errors="coerce" turns unparseable cells into NaN so they are dropped below.
        data[col] = pd.to_numeric(data[col], errors="coerce")

    # Only the feature columns decide whether a row is usable. A blank cell in an
    # unrelated column (e.g. a free-text note) must not discard the sample.
    return data.dropna(subset=numeric_columns)
def calculate_gaze_stats(data, feature):
    """Summarise a single feature column of ``data``.

    Args:
        data: DataFrame containing the column ``feature``.
        feature: Name of the column to summarise.

    Returns:
        dict with keys 'mean', 'median' and 'std' (pandas sample standard
        deviation, ddof=1), in that order.
    """
    series = data[feature]
    reducers = (("mean", series.mean), ("median", series.median), ("std", series.std))
    summary = {}
    for name, reduce_fn in reducers:
        summary[name] = reduce_fn()
    return summary
def visualize_gaze_distribution(data_dict, features, selected_movies):
    """Render comparison plots for several movies in the Streamlit page.

    For every feature, one KDE plot is drawn with each selected movie's
    distribution overlaid. When exactly two movies are selected, a
    correlation heatmap across both movies' features is drawn as well.

    Args:
        data_dict: Mapping of movie label -> DataFrame holding the feature columns.
        features: Feature column names to plot.
        selected_movies: Keys of ``data_dict`` to include.
    """
    for feature in features:
        fig, ax = plt.subplots(figsize=(10, 6))
        for movie in selected_movies:
            # `fill` replaces seaborn's deprecated `shade` argument.
            sns.kdeplot(data_dict[movie][feature], ax=ax, label=movie, fill=True)
        ax.set_title(f"Distribution of '{feature}'")
        ax.set_xlabel("Value")
        ax.set_ylabel("Density")
        ax.legend()
        st.pyplot(fig)
        # Release the figure. Otherwise every chat turn accumulates open figures
        # in pyplot's global registry.
        plt.close(fig)

    if len(selected_movies) == 2:
        fig, ax = plt.subplots(figsize=(12, 10))
        # concat(axis=1) pairs the two movies' rows by index label. Columns are
        # keyed by movie label, so the heatmap shows cross-movie correlations.
        # DataFrame.corr() uses pairwise-complete observations, so unequal
        # lengths simply yield NaNs for the unmatched rows.
        correlation = pd.concat(
            [data_dict[movie][features] for movie in selected_movies],
            axis=1,
            keys=selected_movies,
        ).corr()
        sns.heatmap(correlation, annot=True, cmap="coolwarm", ax=ax)
        ax.set_title("Correlation Heatmap of Selected Features Between Movies")
        st.pyplot(fig)
        plt.close(fig)
def visualize_single_movie(data, features):
    """Plot each feature of one movie against its image sequence.

    One line plot per feature is drawn against the 'image_seq' column. The
    plots are stacked vertically with a shared x-axis and rendered into the
    Streamlit page.

    Args:
        data: DataFrame with an 'image_seq' column plus the feature columns
            (assumes 'image_seq' exists; it is not validated here).
        features: Feature column names to plot. If empty, nothing is drawn.
    """
    if not features:
        # plt.subplots(0, 1) would raise; there is simply nothing to draw.
        return

    # squeeze=False always returns a 2-D array of Axes. Without it, a single
    # feature yields a bare Axes object, and indexing it raised TypeError.
    fig, axes = plt.subplots(
        len(features), 1, figsize=(10, 5 * len(features)), sharex=True, squeeze=False
    )
    for ax, feature in zip(axes[:, 0], features):
        sns.lineplot(x="image_seq", y=feature, data=data, ax=ax)
        ax.set_title(f"{feature.capitalize()} over Image Sequence")
        ax.set_xlabel("Image Sequence")
        ax.set_ylabel("Value")
    fig.tight_layout()
    st.pyplot(fig)
    # Release the figure so repeated chat turns do not leak pyplot figures.
    plt.close(fig)
def format_gaze_prompt(features, data_dict, selected_movies):
    """Build the analysis prompt sent to Gemini.

    The prompt embeds per-movie summary statistics (mean, median, standard
    deviation, each to two decimals) for every requested feature. It then
    asks the model for a structured comparison.

    Args:
        features: Feature column names to summarise.
        data_dict: Mapping of movie label -> DataFrame.
        selected_movies: Keys of ``data_dict`` to include, in display order.

    Returns:
        The fully formatted prompt string.
    """
    template = """
You are an AI assistant specializing in movie gaze analysis. You have access to gaze data for the following movies: {MOVIES}, focusing on these features: {FEATURES}.
{STATS}
Based on this data:
1. Compare the overall patterns of the specified features across the selected movies. Are there any notable differences or similarities?
2. Analyze the distribution of each feature in the selected movies. Are there any outliers or unexpected patterns?
3. Discuss any significant differences in how these features change over the image sequence. What might these differences suggest about the movies' visual content or style?
4. Consider the variability of each feature in each selected movie. Do some movies have more consistent patterns, or do they fluctuate more?
5. Based on this gaze data, hypothesize about potential scenes or visual elements in each selected movie that might contribute to the observed patterns.
6. How might the differences in gaze, blink, and eye offset patterns between these selected movies affect the viewer's visual experience?
Provide a detailed analysis addressing these points, using specific data references where relevant. Your analysis should offer insights into how these gaze-related features are utilized in each selected movie and what this reveals about their visual content and potential audience engagement.
"""

    def movie_section(movie):
        # One header line followed by one stats line per feature.
        pieces = [f"\n{movie}:\n"]
        for feature in features:
            summary = calculate_gaze_stats(data_dict[movie], feature)
            pieces.append(
                f"{feature.capitalize()} - Mean: {summary['mean']:.2f}, "
                f"Median: {summary['median']:.2f}, "
                f"Standard Deviation: {summary['std']:.2f}\n"
            )
        return "".join(pieces)

    stats_block = "".join(movie_section(movie) for movie in selected_movies)
    return template.format(
        MOVIES=", ".join(selected_movies),
        FEATURES=", ".join(features),
        STATS=stats_block,
    )
def generate_response(prompt, data_dict, features, selected_movies, model_name="gemini-pro"):
    """Ask Gemini a question about the selected movies' gaze data.

    The user's query is appended to an analysis prompt built from summary
    statistics of the selected movies and features.

    Args:
        prompt: Free-text user question.
        data_dict: Mapping of movie label -> DataFrame.
        features: Feature columns to summarise.
        selected_movies: Keys of ``data_dict`` to include.
        model_name: Gemini model identifier. Defaults to the original
            'gemini-pro'; NOTE(review): may need updating if that model name
            is retired.

    Returns:
        The text of the first candidate (all of its non-empty text parts,
        concatenated). On failure, a message starting with "Error:" is
        returned instead: when the request itself fails, or when no text can
        be extracted.
    """
    model = genai.GenerativeModel(model_name)
    analysis_prompt = format_gaze_prompt(features, data_dict, selected_movies)
    full_prompt = analysis_prompt + "\n\nUser query: " + prompt

    try:
        response = model.generate_content(full_prompt)
    except Exception as exc:  # UI boundary: show API/network failures as a chat reply
        return f"Error: Gemini request failed: {exc}"

    candidates = getattr(response, "candidates", None)
    if candidates:
        parts = getattr(candidates[0].content, "parts", None) or []
        # A reply can be split across several parts; returning only the first
        # one would truncate it.
        texts = [part.text for part in parts if getattr(part, "text", "")]
        if texts:
            return "".join(texts)
    return "Error: Unable to extract text from the response. Please check the API response structure."
def main():
    """Run the Streamlit app.

    The app asks how many movies to compare and collects one CSV per movie.
    Once every file has loaded, it shows a chat interface. Each query may
    produce plots and always produces a Gemini-written analysis.
    """
    st.title("Multi-Movie Gaze Analysis Chat Interface")
    num_movies = st.number_input("How many movies would you like to compare?", min_value=1, max_value=10, value=1)

    data_dict = {}
    for i in range(num_movies):
        uploaded_file = st.file_uploader(f"Choose CSV file for Movie {i+1}", type="csv", key=f"movie_{i+1}")
        if uploaded_file is None:
            continue
        try:
            data_dict[f"Movie {i+1}"] = load_and_preprocess_data(uploaded_file)
        except (KeyError, ValueError) as exc:
            # Malformed CSV or missing feature columns: report it in the page
            # instead of crashing the whole app with a traceback.
            # (pandas parse/decode errors subclass ValueError; a missing column
            # surfaces as KeyError or ValueError.)
            st.error(f"Could not load Movie {i+1}: {exc}")

    # Wait until every requested movie has a successfully loaded file.
    if len(data_dict) != num_movies:
        return

    st.success("All files uploaded successfully. You can now start chatting!")
    st.subheader("Data Information")
    for movie, data in data_dict.items():
        st.write(f"{movie} Columns:", data.columns.tolist())
    features = ['gaze', 'blink', 'eye_offset']
    st.write("Available Features:", features)

    # Chat history lives in session_state so it survives Streamlit reruns.
    # Only text is stored, so plots from earlier turns are not redrawn.
    if "messages" not in st.session_state:
        st.session_state.messages = []
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    selected_movies = st.multiselect("Select movies to compare:", list(data_dict.keys()))
    if not selected_movies:
        selected_movies = list(data_dict.keys())

    if prompt := st.chat_input("What would you like to know about the movies' gaze data?"):
        st.chat_message("user").markdown(prompt)
        st.session_state.messages.append({"role": "user", "content": prompt})

        lowered = prompt.lower()
        # Restrict the analysis to features named in the query; otherwise use all.
        selected_features = [feature for feature in features if feature in lowered] or features

        with st.chat_message("assistant"):
            if len(selected_movies) == 1:
                visualize_single_movie(data_dict[selected_movies[0]], selected_features)
            elif any(keyword in lowered for keyword in ["graph", "compare", "visualize", "show"]):
                visualize_gaze_distribution(data_dict, selected_features, selected_movies)

        response = generate_response(prompt, data_dict, selected_features, selected_movies)
        with st.chat_message("assistant"):
            st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})

    if st.checkbox("Show raw data"):
        for movie in selected_movies:
            st.subheader(f"{movie} Data")
            st.write(data_dict[movie])


if __name__ == "__main__":
    main()