berkaygkv's picture
Update app.py
a7fde88 verified
raw
history blame
12.6 kB
import streamlit as st
import spotipy
from spotipy.oauth2 import SpotifyOAuth
from qdrant_client import QdrantClient
from qdrant_client.http import models
from src.laion_clap.inference import AudioEncoder
import os
import re
import unicodedata
import requests
import uuid
import os
# --- Spotify OAuth settings, read from the environment ---
SPOTIPY_CLIENT_ID = os.getenv("SPOTIPY_CLIENT_ID")
SPOTIPY_CLIENT_SECRET = os.getenv("SPOTIPY_CLIENT_SECRET")
SPOTIPY_REDIRECT_URI = os.getenv("SPOTIPY_REDIRECT_URI")
SCOPE = "user-library-read"  # scope requested by SpotifyOAuth
CACHE_PATH = ".spotifycache"  # token cache file handed to SpotifyOAuth (deleted on logout)

# --- Qdrant settings ---
# NOTE(review): QDRANT_HOST is not read by get_qdrant_client, which opens the
# on-disk store at ./qdrant_data; confirm whether it is still needed.
QDRANT_HOST = os.getenv("QDRANT_HOST", "localhost")
COLLECTION_NAME = "spotify_songs"

st.set_page_config(
    page_title="Spotify Similarity Search",
    page_icon="🎡",
    layout="wide",
)
def reset_environment():
    """Drop every Streamlit cache and all session state, then assign a new session id.

    NOTE(review): st.cache_resource is process-wide, so this also discards the
    cached AudioEncoder / Qdrant client for other sessions; confirm that is intended.
    """
    for cache in (st.cache_resource, st.cache_data):
        cache.clear()
    stale_keys = list(st.session_state.keys())
    for name in stale_keys:
        del st.session_state[name]
    st.session_state.session_id = str(uuid.uuid4())
@st.cache_resource
def load_resources():
    """Construct the CLAP ``AudioEncoder``; st.cache_resource reuses the instance across reruns."""
    encoder = AudioEncoder()
    return encoder
@st.cache_resource
def get_qdrant_client():
    """Open the on-disk Qdrant store at ./qdrant_data and ensure the song collection exists.

    The collection holds 512-dimensional vectors compared by cosine distance.
    The client is cached with st.cache_resource so reruns reuse it.

    Returns:
        QdrantClient: the (possibly newly initialized) client. If the collection
        cannot be created or read, an error is shown in the UI and the client is
        still returned.
    """
    client = QdrantClient(path="./qdrant_data")
    try:
        # Create only when missing. Attempting the create unconditionally and
        # swallowing every exception would also hide real failures.
        if not client.collection_exists(COLLECTION_NAME):
            client.create_collection(
                COLLECTION_NAME,
                vectors_config=models.VectorParams(size=512, distance=models.Distance.COSINE),
            )
        client.get_collection(COLLECTION_NAME)
    except Exception:
        st.error("Qdrant collection not found. Please ensure the collection is properly initialized.")
    return client
def get_spotify_client():
    """Return an authenticated Spotify client, or None when the user still has to log in.

    If the page was reached via the OAuth redirect (``?code=...``), the code is
    exchanged for a token, which SpotifyOAuth stores in CACHE_PATH. The returned
    client is built on the auth manager, not on a bare access token, so spotipy
    can refresh the token when it expires.

    When no usable token is available, a login link is rendered and None is returned.
    """
    auth_manager = SpotifyOAuth(
        client_id=SPOTIPY_CLIENT_ID,
        client_secret=SPOTIPY_CLIENT_SECRET,
        redirect_uri=SPOTIPY_REDIRECT_URI,
        scope=SCOPE,
        cache_path=CACHE_PATH,
    )
    query_params = st.experimental_get_query_params()
    if 'code' in query_params:
        try:
            auth_manager.get_access_token(query_params['code'][0])
        except spotipy.oauth2.SpotifyOauthError:
            # An already-used or expired code cannot be exchanged again. Fall
            # through to the cached token (if any) or offer a fresh login.
            pass
    if not auth_manager.get_cached_token():
        auth_url = auth_manager.get_authorize_url()
        st.markdown(f"[Click here to login with Spotify]({auth_url})")
        return None
    return spotipy.Spotify(auth_manager=auth_manager)
def find_similar_songs_by_text(_query_text, _qdrant_client, _text_encoder, top_k=10):
    """Return up to ``top_k`` indexed songs whose audio embedding best matches a text query.

    Each result is a dict with ``name``, ``artist`` (first listed artist),
    ``similarity`` (the Qdrant score) and ``preview_url``, in Qdrant's ranking order.
    """
    embedding = generate_text_embedding(_query_text, _text_encoder)
    response = _qdrant_client.query_points(
        collection_name=COLLECTION_NAME,
        query=embedding.tolist()[0],
        limit=top_k,
    )
    matches = []
    for point in response.model_dump()["points"]:
        payload = point["payload"]
        matches.append({
            "name": payload["name"],
            "artist": payload["artists"][0]["name"],
            "similarity": point["score"],
            "preview_url": payload["preview_url"],
        })
    return matches
def generate_text_embedding(text, text_encoder):
    """Embed ``text`` by passing it to ``text_encoder.get_text_embedding`` as a one-item batch."""
    return text_encoder.get_text_embedding([text])
def logout():
    """Forget the Spotify login: delete the token cache file, clear session state, rerun."""
    token_cache_present = os.path.exists(CACHE_PATH)
    if token_cache_present:
        os.remove(CACHE_PATH)
    for name in tuple(st.session_state.keys()):
        del st.session_state[name]
    st.experimental_rerun()
def truncate_qdrant_data(qdrant_client):
    """Empty the song collection by dropping and recreating it (512-d, cosine).

    Reports success or the error message in the Streamlit UI; nothing is raised.
    """
    try:
        qdrant_client.delete_collection(collection_name=COLLECTION_NAME)
        fresh_config = models.VectorParams(size=512, distance=models.Distance.COSINE)
        qdrant_client.create_collection(collection_name=COLLECTION_NAME, vectors_config=fresh_config)
        st.success("Qdrant data has been truncated successfully.")
    except Exception as err:
        st.error(f"An error occurred while truncating Qdrant data: {str(err)}")
def fetch_all_liked_songs(_sp):
    """Return every track in the current Spotify user's saved (liked) tracks.

    ``st.cache_data`` does not hash arguments whose names start with an
    underscore, so caching directly on ``_sp`` would give the cache no key at
    all: the first library fetched would be served for every account. The
    user id is therefore looked up and passed to the cached helper as the key.

    Returns:
        list[dict]: one flattened record per saved track, in API order.
    """
    user_id = _sp.current_user()['id']
    return _fetch_liked_songs_for_user(_sp, user_id)


@st.cache_data
def _fetch_liked_songs_for_user(_sp, user_id):
    """Page through ``current_user_saved_tracks`` 50 at a time and flatten each item.

    ``user_id`` is not used in the body; it exists only as the cache key.
    """
    songs = []
    offset = 0
    while True:
        page = _sp.current_user_saved_tracks(limit=50, offset=offset)
        items = page['items']
        if not items:
            break
        for item in items:
            track = item['track']
            album = track['album']
            songs.append({
                'id': track['id'],
                'name': track['name'],
                'artists': [{'name': artist['name'], 'id': artist['id']} for artist in track['artists']],
                'album': {
                    'name': album['name'],
                    'id': album['id'],
                    'release_date': album['release_date'],
                    'total_tracks': album['total_tracks'],
                },
                'duration_ms': track['duration_ms'],
                'explicit': track['explicit'],
                'popularity': track['popularity'],
                'preview_url': track['preview_url'],
                'added_at': item['added_at'],
                'is_local': track['is_local'],
            })
        offset += len(items)
    return songs
def sanitize_filename(filename):
    """Turn arbitrary text into a safe ASCII filename stem of at most 100 characters.

    Steps:
      1. NFKD-normalize and drop non-ASCII characters. This runs first so that
         compatibility forms such as fullwidth '／' or '．' become '/' and '.'
         and are caught by the filters below, rather than surviving into the
         result as path separators or '..'.
      2. Remove characters that are invalid in filenames, including control
         characters (a NUL byte would make ``open()`` fail).
      3. Collapse runs of whitespace and dots into a single '_'.
      4. Truncate to 100 characters.

    Text written entirely in non-Latin scripts folds to an empty string.
    """
    filename = unicodedata.normalize('NFKD', filename).encode('ASCII', 'ignore').decode()
    filename = re.sub(r'[<>:"/\\|?*\x00-\x1f]', '', filename)
    filename = re.sub(r'[\s.]+', '_', filename)
    return filename[:100]
def get_preview_filename(song):
    """Return the local ``.mp3`` filename used to cache a song's audio preview.

    The name is ``<title>_<first artist>`` passed through sanitize_filename,
    followed by the Spotify track id when one is present. sanitize_filename
    drops non-ASCII characters, so without the id, songs titled/credited in
    non-Latin scripts collapse to the same name (e.g. ``_.mp3``) and shadow
    each other's previews.

    NOTE: previews saved under a different naming scheme will not be found and
    will simply be downloaded again.
    """
    safe_name = sanitize_filename(f"{song['name']}_{song['artists'][0]['name']}")
    track_id = song.get('id')
    if track_id:
        safe_name = f"{safe_name}_{sanitize_filename(track_id)}"
    return f"{safe_name}.mp3"
def download_preview(preview_url, song, timeout=30):
    """Download a song's preview MP3 into ./previews unless it is already there.

    Args:
        preview_url: URL of the preview; a falsy value means "no preview".
        song: track record; only used to derive the filename via get_preview_filename.
        timeout: seconds passed to ``requests.get`` so a stalled server cannot
            hang the batch indefinitely.

    Returns:
        (True, path) if the file exists locally or was downloaded, otherwise
        (False, None). This covers a missing URL, a non-200 status and
        network errors.
    """
    if not preview_url:
        return False, None
    filename = get_preview_filename(song)
    output_path = os.path.join("previews", filename)
    if os.path.exists(output_path):
        return True, output_path
    try:
        response = requests.get(preview_url, timeout=timeout)
    except requests.RequestException:
        # Treat connection errors/timeouts like any other failed download so
        # one bad URL does not abort processing of the remaining songs.
        return False, None
    if response.status_code != 200:
        return False, None
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    # Write under a temporary name, then rename. An interrupted write therefore
    # never leaves a truncated file that the existence check above would
    # later treat as a valid cached preview.
    tmp_path = output_path + ".part"
    with open(tmp_path, 'wb') as f:
        f.write(response.content)
    os.replace(tmp_path, output_path)
    return True, output_path
def process_song(song, audio_encoder, qdrant_client):
    """Make sure a liked song's preview is on disk and indexed in Qdrant.

    A preview file that already exists locally is reused, but the Qdrant check
    still runs. Otherwise, songs would never be re-indexed after the collection
    has been truncated while ./previews was kept.

    Returns:
        (file_path, None) when the preview is available locally and present in
        the collection (embedded and upserted now if it was missing), or
        (None, warning_message) when the song had to be skipped.
    """
    label = f"{song['name']} by {song['artists'][0]['name']}"
    output_path = os.path.join("previews", get_preview_filename(song))
    if os.path.exists(output_path):
        file_path = output_path
    elif not song['preview_url']:
        return None, f"No preview available for: {label}"
    else:
        success, file_path = download_preview(song['preview_url'], song)
        if not success:
            return None, f"Failed to download preview for: {label}"

    # Skip songs already indexed (matched by their Spotify id in the payload).
    existing_points, _next_offset = qdrant_client.scroll(
        collection_name=COLLECTION_NAME,
        scroll_filter=models.Filter(
            must=[
                models.FieldCondition(
                    key="spotify_id",
                    match=models.MatchValue(value=song['id']),
                )
            ]
        ),
        limit=1,
    )
    if existing_points:
        return file_path, None

    try:
        embedding = generate_audio_embedding(file_path, audio_encoder)
    except Exception as exc:
        # A single unreadable/corrupt preview should become a warning rather
        # than abort the whole retrieve_all_previews loop.
        return None, f"Failed to embed preview for: {label} ({exc})"

    qdrant_client.upsert(
        collection_name=COLLECTION_NAME,
        points=[
            models.PointStruct(
                id=str(uuid.uuid4()),
                vector=embedding,
                payload={
                    "name": song['name'],
                    "artists": song['artists'],
                    "spotify_id": song['id'],
                    "album": song['album'],
                    "duration_ms": song['duration_ms'],
                    "popularity": song['popularity'],
                    "preview_url": song['preview_url'],
                    "local_preview_path": file_path,
                },
            )
        ],
    )
    return file_path, None
def generate_audio_embedding(audio_path, audio_encoder):
    """Embed the audio file at ``audio_path`` and return the first row as a plain list.

    Calls the encoder's ``extract_audio_representaion`` method (spelled that way
    in the encoder's API). Its result must support ``tolist()``; only element
    ``[0]`` of that list is returned.
    """
    batch = audio_encoder.extract_audio_representaion(audio_path)
    first_row = batch.tolist()[0]
    return first_row
def retrieve_all_previews(sp, qdrant_client, audio_encoder):
    """Run process_song over every liked song, showing a progress bar and a counter.

    Returns:
        list[str]: the warning messages produced for songs that were skipped.
    """
    songs = fetch_all_liked_songs(sp)
    count = len(songs)
    bar = st.progress(0)
    status = st.empty()
    collected = []
    for done, song in enumerate(songs, start=1):
        message = process_song(song, audio_encoder, qdrant_client)[1]
        if message:
            collected.append(message)
        bar.progress(done / count)
        status.text(f"Processing: {done}/{count} songs")
    st.success(f"Processed {count} songs.")
    return collected
def display_warnings(warnings):
    """Show processing warnings inside a collapsed expander, one styled box per warning.

    Warning texts contain Spotify track and artist names. They are
    HTML-escaped before being injected with ``unsafe_allow_html=True``, so a
    title containing markup or '&' is shown literally rather than interpreted.
    Does nothing when ``warnings`` is empty.
    """
    from html import escape

    if not warnings:
        return
    with st.expander("Processing Warnings", expanded=False):
        st.markdown("""
<style>
.warning-box {
background-color: #fff3cd;
border-left: 6px solid #ffeeba;
margin-bottom: 10px;
padding: 10px;
color: #856404;
}
</style>
""", unsafe_allow_html=True)
        for warning in warnings:
            st.markdown(f'<div class="warning-box">{escape(warning)}</div>', unsafe_allow_html=True)
def main():
    """Render the app: Spotify login and data tools in the sidebar, text search in the main area."""
    st.title("Spotify Similarity Search")
    if 'session_id' not in st.session_state:
        reset_environment()
    qdrant_client = get_qdrant_client()

    # Sidebar for authentication and data management
    with st.sidebar:
        st.header("Authentication & Data Management")
        if 'spotify_auth' not in st.session_state:
            sp = get_spotify_client()
            if sp:
                st.session_state['spotify_auth'] = sp
        if 'spotify_auth' in st.session_state:
            st.success("Connected to Spotify and Qdrant")
            audio_encoder = load_resources()
            if st.button("Logout from Spotify"):
                logout()
            if st.button("Truncate Qdrant Data"):
                truncate_qdrant_data(qdrant_client)
            if st.button("Retrieve All Previews"):
                with st.spinner("Retrieving previews..."):
                    warnings = retrieve_all_previews(st.session_state['spotify_auth'], qdrant_client, audio_encoder)
                display_warnings(warnings)
        elif 'code' in st.experimental_get_query_params():
            st.warning("Authentication in progress. Please refresh this page.")
        else:
            st.info("Please log in to access your Spotify data.")

    # Main content area: only available once logged in. The same condition
    # guarded the sidebar branch above that assigned audio_encoder.
    if 'spotify_auth' not in st.session_state:
        return

    # Quick Start Guide. The text is kept at column 0 so markdown does not
    # treat indented lines as a code block.
    st.info("""
### 🚀 Quick Start Guide

1. 🔄 Click 'Retrieve All Previews' in the sidebar, to start getting 30 seconds raw audio previews.
2. 🔍 Enter descriptive keywords (e.g., "upbeat electronic with female vocals")
3. 🎡 Explore similar songs and enjoy!

Note: Some songs may not have previews available mainly due to Spotify restrictions.

✅ Do: Use specific terms (genre, mood, instruments)

❌ Don't: Use artist names or song titles

💡 Tip: Refine your search if results aren't perfect!
""")
    st.header("Find Similar Songs")
    query_text = st.text_input("Enter a description or keywords for the music you're looking for:")
    if not (st.button("Search Similar Songs") or query_text):
        return
    if not query_text:
        st.warning("Please enter a description or keywords for your search.")
        return
    with st.spinner("Searching for similar songs..."):
        search_results = find_similar_songs_by_text(query_text, qdrant_client, audio_encoder)
    if not search_results:
        st.info("No similar songs found. Try a different description.")
        return
    st.subheader("Similar songs based on your description:")
    for song in search_results:
        st.write(f"{song['name']} by {song['artist']} (Similarity: {song['similarity']:.2f})")
        if song['preview_url']:
            st.audio(song['preview_url'], format='audio/mp3')
        else:
            st.write("No preview available")
        st.write("---")  # Add a separator between songs
# Entry point: build the UI when this file is executed as the main module.
if __name__ == '__main__':
    main()