# Hugging Face Space (page status when this file was captured: Sleeping)
import html
import os
import re
import unicodedata
import uuid

import requests
import spotipy
import streamlit as st
from qdrant_client import QdrantClient
from qdrant_client.http import models
from spotipy.oauth2 import SpotifyOAuth

from src.laion_clap.inference import AudioEncoder
# Spotify API credentials (read from the environment; no defaults are provided)
SPOTIPY_CLIENT_ID = os.getenv("SPOTIPY_CLIENT_ID")
SPOTIPY_CLIENT_SECRET = os.getenv("SPOTIPY_CLIENT_SECRET")
SPOTIPY_REDIRECT_URI = os.getenv('SPOTIPY_REDIRECT_URI')
SCOPE = 'user-library-read'
# Token cache file handed to SpotifyOAuth; logout() deletes it.
CACHE_PATH = '.spotifycache'
# Qdrant setup
# NOTE(review): QDRANT_HOST is read here, but get_qdrant_client() opens local
# on-disk storage (./qdrant_data) rather than connecting to this host.
QDRANT_HOST = os.getenv('QDRANT_HOST', 'localhost')
COLLECTION_NAME = "spotify_songs"
# page_icon was mojibake ("π΅", UTF-8 bytes of 🎵 decoded as ISO-8859-7); restored.
st.set_page_config(page_title="Spotify Similarity Search", page_icon="🎵", layout="wide")
def reset_environment():
    """Give a new browser session a clean slate.

    Clears every ``st.cache_resource`` and ``st.cache_data`` entry, removes all
    keys from ``st.session_state`` and stores a freshly generated random UUID
    string under ``st.session_state.session_id``.
    """
    for cache in (st.cache_resource, st.cache_data):
        cache.clear()
    stale_keys = list(st.session_state.keys())
    for stale_key in stale_keys:
        del st.session_state[stale_key]
    st.session_state.session_id = f"{uuid.uuid4()}"
@st.cache_resource
def load_resources():
    """Return the CLAP ``AudioEncoder`` used for both text and audio embeddings.

    main() calls this on every Streamlit rerun (every button click or text
    entry). Without ``st.cache_resource`` the model is re-instantiated each
    time. With it, one instance is shared until the cache is cleared, which
    reset_environment() does via ``st.cache_resource.clear()``.
    """
    return AudioEncoder()
@st.cache_resource
def get_qdrant_client():
    """Open the local on-disk Qdrant store and make sure the song collection exists.

    The collection holds 512-dimensional vectors compared by cosine distance.
    The client is cached with ``st.cache_resource``. Otherwise main() would
    reopen ``./qdrant_data`` on every rerun, and local-mode Qdrant allows only
    one client per storage folder at a time.

    Errors are reported in the UI with ``st.error`` rather than raised, so the
    returned client may point at a store without the collection.
    """
    client = QdrantClient(path="./qdrant_data")
    try:
        # Explicit existence check instead of "create and swallow any exception",
        # which also hid real failures (bad path, permissions, corrupt storage).
        if not client.collection_exists(COLLECTION_NAME):
            client.create_collection(
                COLLECTION_NAME,
                vectors_config=models.VectorParams(size=512, distance=models.Distance.COSINE),
            )
    except Exception as e:
        st.error(f"Could not create Qdrant collection '{COLLECTION_NAME}': {e}")
    try:
        client.get_collection(COLLECTION_NAME)
    except Exception:
        st.error("Qdrant collection not found. Please ensure the collection is properly initialized.")
    return client
def get_spotify_client():
    """Return an authenticated ``spotipy.Spotify`` client, or ``None`` if login is still needed.

    Flow:
    * If the OAuth redirect left a ``code`` query parameter, the auth manager
      turns it into a token. ``get_access_token`` uses the cached token when one
      exists, otherwise it exchanges the code. A stale or already-used code
      (e.g. still in the URL after logout) makes spotipy raise
      ``SpotifyOauthError``. That case is treated as "not logged in" instead of
      crashing the app.
    * Without a cached token, a login link is rendered and ``None`` is returned.

    The client is built on the auth manager rather than on a raw access-token
    string. main() keeps it in session state for the whole session, and only
    the auth-manager form lets spotipy refresh an expired token.

    NOTE(review): the token cache is one file (CACHE_PATH) shared by every
    visitor of this app, so different users may end up sharing a single
    Spotify login. Confirm this is intended.
    """
    auth_manager = SpotifyOAuth(
        client_id=SPOTIPY_CLIENT_ID,
        client_secret=SPOTIPY_CLIENT_SECRET,
        redirect_uri=SPOTIPY_REDIRECT_URI,
        scope=SCOPE,
        cache_path=CACHE_PATH
    )
    query_params = st.experimental_get_query_params()
    if 'code' in query_params:
        try:
            auth_manager.get_access_token(query_params['code'][0])
        except spotipy.oauth2.SpotifyOauthError:
            # Stale/invalid code: fall through to the cached-token / login-link path.
            pass
        else:
            return spotipy.Spotify(auth_manager=auth_manager)
    if not auth_manager.get_cached_token():
        auth_url = auth_manager.get_authorize_url()
        st.markdown(f"[Click here to login with Spotify]({auth_url})")
        return None
    return spotipy.Spotify(auth_manager=auth_manager)
def find_similar_songs_by_text(_query_text, _qdrant_client, _text_encoder, top_k=10):
    """Search the song collection for the ``top_k`` best matches to a text description.

    The text is embedded with ``_text_encoder``, and row 0 of that embedding is
    used as the Qdrant query vector. Each hit becomes a dict with ``name``, the
    first artist's name as ``artist``, the Qdrant ``score`` as ``similarity``,
    and ``preview_url``, in Qdrant's result order.
    """
    embedding = generate_text_embedding(_query_text, _text_encoder)
    response = _qdrant_client.query_points(
        collection_name=COLLECTION_NAME,
        query=embedding.tolist()[0],
        limit=top_k
    )
    points = response.model_dump()["points"]

    matches = []
    for point in points:
        payload = point["payload"]
        matches.append({
            "name": payload["name"],
            "artist": payload["artists"][0]["name"],
            "similarity": point["score"],
            "preview_url": payload["preview_url"],
        })
    return matches
def generate_text_embedding(text, text_encoder):
    """Embed one query string with ``text_encoder``.

    The encoder receives a one-element list, and its result is returned
    unchanged. find_similar_songs_by_text() then takes row 0 of it.
    """
    return text_encoder.get_text_embedding([text])
def logout():
    """Forget the Spotify login and restart the script run.

    Deletes the token cache file if it exists, removes every key from
    ``st.session_state`` (including ``session_id``, so the next run calls
    reset_environment()), then triggers a rerun.
    """
    # EAFP: avoids the exists()/remove() race if the file disappears in between.
    try:
        os.remove(CACHE_PATH)
    except FileNotFoundError:
        pass
    for key in list(st.session_state.keys()):
        del st.session_state[key]
    st.experimental_rerun()
def truncate_qdrant_data(qdrant_client):
    """Remove every stored song vector by dropping and recreating the collection.

    The new collection uses the same 512-dim cosine configuration. The outcome
    is shown with ``st.success`` or ``st.error``, and no exception propagates to
    the caller. Downloaded preview files on disk are not touched.
    """
    try:
        qdrant_client.delete_collection(collection_name=COLLECTION_NAME)
        fresh_config = models.VectorParams(size=512, distance=models.Distance.COSINE)
        qdrant_client.create_collection(
            collection_name=COLLECTION_NAME,
            vectors_config=fresh_config,
        )
        st.success("Qdrant data has been truncated successfully.")
    except Exception as err:
        st.error(f"An error occurred while truncating Qdrant data: {str(err)}")
def fetch_all_liked_songs(_sp):
    """Page through the user's saved tracks and return them as plain dicts.

    Requests pages of 50 from ``_sp.current_user_saved_tracks`` until an empty
    page comes back. Each saved item is flattened into a dict holding the
    track's id, name, artists (name/id), a subset of album fields, duration,
    explicit flag, popularity, preview URL, ``added_at`` and ``is_local``.
    """
    def to_record(entry):
        # Flatten one saved-track item into the dict shape used elsewhere in the app.
        track = entry['track']
        album = track['album']
        return {
            'id': track['id'],
            'name': track['name'],
            'artists': [{'name': a['name'], 'id': a['id']} for a in track['artists']],
            'album': {
                'name': album['name'],
                'id': album['id'],
                'release_date': album['release_date'],
                'total_tracks': album['total_tracks']
            },
            'duration_ms': track['duration_ms'],
            'explicit': track['explicit'],
            'popularity': track['popularity'],
            'preview_url': track['preview_url'],
            'added_at': entry['added_at'],
            'is_local': track['is_local']
        }

    collected = []
    page_start = 0
    while True:
        entries = _sp.current_user_saved_tracks(limit=50, offset=page_start)['items']
        if not entries:
            return collected
        collected.extend([to_record(entry) for entry in entries])
        page_start += len(entries)
def sanitize_filename(filename):
    """Turn arbitrary text (a song title/artist) into a safe, ASCII-only file name.

    Steps, in order:
    1. NFKD-normalize and drop every non-ASCII code point.
    2. Remove characters reserved on common file systems (see the regex).
    3. Collapse runs of whitespace and dots into a single "_".
    4. Cap the result at 100 characters.

    Normalization must run first. NFKD maps compatibility characters to ASCII
    look-alikes (fullwidth solidus U+FF0F -> "/", fullwidth full stop
    U+FF0E -> "."). If normalization ran after the filters, text taken from
    Spotify metadata could reintroduce path separators and ".." into the name.
    The result may be empty for text with no ASCII equivalent.
    """
    ascii_only = unicodedata.normalize('NFKD', filename).encode('ASCII', 'ignore').decode()
    without_reserved = re.sub(r'[<>:"/\\|?*]', '', ascii_only)
    collapsed = re.sub(r'[\s.]+', '_', without_reserved)
    return collapsed[:100]
def get_preview_filename(song):
    """Build the on-disk file name (inside ``previews/``) for a song's preview clip.

    The readable part is "<title>_<first artist>" passed through
    sanitize_filename(). That part alone is not unique, because
    sanitize_filename() drops non-ASCII characters and truncates to 100 chars.
    Different songs with, e.g., non-Latin titles by the same artist would share
    one file, so one song's preview would be embedded for all of them. The
    (sanitized) Spotify track id is therefore appended when the song has one.
    """
    safe_name = sanitize_filename(f"{song['name']}_{song['artists'][0]['name']}")
    track_id = song.get('id')
    if track_id:
        return f"{safe_name}_{sanitize_filename(str(track_id))}.mp3"
    return f"{safe_name}.mp3"
def download_preview(preview_url, song, timeout=30):
    """Download ``preview_url`` into ``previews/`` unless the file is already there.

    Args:
        preview_url: URL of the preview clip; a falsy value means "no preview".
        song: song dict used to derive the file name via get_preview_filename().
        timeout: seconds passed to ``requests.get``. Without a timeout, one stalled
            connection would hang the whole batch forever.

    Returns:
        ``(True, path)`` if the file exists or was downloaded, ``(False, None)``
        otherwise. Network errors and non-200 responses are reported this way
        rather than raised, so a single failure does not abort the batch loop.
    """
    if not preview_url:
        return False, None
    filename = get_preview_filename(song)
    output_path = os.path.join("previews", filename)
    if os.path.exists(output_path):
        return True, output_path
    try:
        response = requests.get(preview_url, timeout=timeout)
    except requests.RequestException:
        return False, None
    if response.status_code != 200:
        return False, None
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    # Write to a temporary name and rename into place. An interrupted write must
    # not leave a truncated .mp3 that the exists() check above would reuse forever.
    tmp_path = output_path + ".part"
    with open(tmp_path, 'wb') as f:
        f.write(response.content)
    os.replace(tmp_path, output_path)
    return True, output_path
def process_song(song, audio_encoder, qdrant_client):
    """Make sure ``song``'s preview is on disk AND its audio embedding is in Qdrant.

    Returns:
        ``(file_path, None)`` on success, ``(None, warning_message)`` when the song
        has no preview, the download fails, or embedding/indexing fails.

    A preview file that already exists is still checked against Qdrant.
    Returning early for cached files meant that after "Truncate Qdrant Data",
    which leaves ``previews/`` intact, those songs were never re-indexed.
    """
    label = f"{song['name']} by {song['artists'][0]['name']}"
    output_path = os.path.join("previews", get_preview_filename(song))
    if os.path.exists(output_path):
        file_path = output_path
    else:
        preview_url = song['preview_url']
        if not preview_url:
            return None, f"No preview available for: {label}"
        success, file_path = download_preview(preview_url, song)
        if not success:
            return None, f"Failed to download preview for: {label}"
    if _is_song_indexed(qdrant_client, song['id']):
        return file_path, None
    try:
        _index_song(song, file_path, audio_encoder, qdrant_client)
    except Exception as e:
        # Batch boundary: one unreadable clip or failed upsert becomes a warning
        # instead of aborting retrieve_all_previews for every remaining song.
        return None, f"Failed to index preview for: {label} ({e})"
    return file_path, None


def _is_song_indexed(qdrant_client, spotify_id):
    """Return True if a point whose payload ``spotify_id`` equals ``spotify_id`` exists."""
    points, _next_offset = qdrant_client.scroll(
        collection_name=COLLECTION_NAME,
        scroll_filter=models.Filter(
            must=[
                models.FieldCondition(
                    key="spotify_id",
                    match=models.MatchValue(value=spotify_id)
                )
            ]
        ),
        limit=1
    )
    return bool(points)


def _index_song(song, file_path, audio_encoder, qdrant_client):
    """Embed the preview at ``file_path`` and upsert it with the song's metadata."""
    embedding = generate_audio_embedding(file_path, audio_encoder)
    # Deterministic id from the track id: upserting the same song again
    # overwrites its point instead of adding a duplicate.
    point_id = str(uuid.uuid5(uuid.NAMESPACE_URL, f"spotify:track:{song['id']}"))
    qdrant_client.upsert(
        collection_name=COLLECTION_NAME,
        points=[
            models.PointStruct(
                id=point_id,
                vector=embedding,
                payload={
                    "name": song['name'],
                    "artists": song['artists'],
                    "spotify_id": song['id'],
                    "album": song['album'],
                    "duration_ms": song['duration_ms'],
                    "popularity": song['popularity'],
                    "preview_url": song['preview_url'],
                    "local_preview_path": file_path
                }
            )
        ]
    )
def generate_audio_embedding(audio_path, audio_encoder):
    """Return the embedding of the audio file at ``audio_path`` as a flat Python list.

    Calls ``audio_encoder.extract_audio_representaion`` (spelling as used here;
    presumably it matches the AudioEncoder API) and converts the result with
    ``.tolist()``. Only the first row is kept, presumably because one clip is
    passed in. process_song() stores this list as the point's vector in Qdrant.
    """
    representation = audio_encoder.extract_audio_representaion(audio_path)
    rows = representation.tolist()
    return rows[0]
def retrieve_all_previews(sp, qdrant_client, audio_encoder):
    """Download and index previews for every liked song, showing progress in the UI.

    Runs process_song() for each track returned by fetch_all_liked_songs(),
    updating a progress bar and a "Processing: i/N songs" status line after
    each one. Returns the list of warning messages collected along the way.
    """
    songs = fetch_all_liked_songs(sp)
    song_count = len(songs)
    bar = st.progress(0)
    status = st.empty()
    collected_warnings = []
    for done, song in enumerate(songs, start=1):
        warning = process_song(song, audio_encoder, qdrant_client)[1]
        if warning:
            collected_warnings.append(warning)
        bar.progress(done / song_count)
        status.text(f"Processing: {done}/{song_count} songs")
    st.success(f"Processed {song_count} songs.")
    return collected_warnings
def display_warnings(warnings):
    """Show processing warnings in a collapsed expander, one styled box per message.

    Does nothing for an empty list. The messages embed Spotify track and artist
    names, which are external text. They are HTML-escaped because the boxes are
    rendered with ``unsafe_allow_html=True``, and a title containing markup
    would otherwise be injected into the page.
    """
    if not warnings:
        return
    with st.expander("Processing Warnings", expanded=False):
        st.markdown("""
            <style>
            .warning-box {
                background-color: #fff3cd;
                border-left: 6px solid #ffeeba;
                margin-bottom: 10px;
                padding: 10px;
                color: #856404;
            }
            </style>
            """, unsafe_allow_html=True)
        for warning in warnings:
            st.markdown(
                f'<div class="warning-box">{html.escape(warning)}</div>',
                unsafe_allow_html=True,
            )
def main():
    """Render the app: Spotify login and data tools in the sidebar, text search in the main area.

    The function runs top to bottom on every Streamlit rerun, and elements
    appear in the order they are called.
    """
    st.title("Spotify Similarity Search")
    # First run of a browser session (no session_id yet): reset caches/state and assign an id.
    if 'session_id' not in st.session_state:
        reset_environment()
    qdrant_client = get_qdrant_client()
    # Sidebar for authentication and data management
    with st.sidebar:
        st.header("Authentication & Data Management")
        # Attempt login only until a client has been stored for this session.
        if 'spotify_auth' not in st.session_state:
            sp = get_spotify_client()
            if sp:
                st.session_state['spotify_auth'] = sp
        if 'spotify_auth' in st.session_state:
            st.success("Connected to Spotify and Qdrant")
            # Bound only in this branch; the main-area search below runs under the
            # same 'spotify_auth' condition, so audio_encoder is defined there.
            audio_encoder = load_resources()
            if st.button("Logout from Spotify"):
                logout()
            if st.button("Truncate Qdrant Data"):
                truncate_qdrant_data(qdrant_client)
            if st.button("Retrieve All Previews"):
                with st.spinner("Retrieving previews..."):
                    warnings = retrieve_all_previews(st.session_state['spotify_auth'], qdrant_client, audio_encoder)
                display_warnings(warnings)
        # NOTE(review): get_spotify_client() as shown returns a client whenever 'code'
        # is present (or raises), so this branch looks hard to reach; confirm.
        elif 'code' in st.experimental_get_query_params():
            st.warning("Authentication in progress. Please refresh this page.")
        else:
            st.info("Please log in to access your Spotify data.")
    # Main content area
    if 'spotify_auth' in st.session_state:
        # Quick Start Guide
        # NOTE(review): the emoji in this text look mojibake-garbled (UTF-8 read as
        # ISO-8859-7, e.g. "π΅" is probably 🎵 and "π‘" probably 💡); restore them
        # from the original file.
        st.info("""
        ### π Quick Start Guide
        1. π Click 'Retrieve All Previews' in the sidebar, to start getting 30 seconds raw audio previews.
        2. π Enter descriptive keywords (e.g., "upbeat electronic with female vocals")
        3. π΅ Explore similar songs and enjoy!
        Note: Some songs may not have previews available mainly due to Spotify restrictions.
        β Do: Use specific terms (genre, mood, instruments)
        β Don't: Use artist names or song titles
        π‘ Tip: Refine your search if results aren't perfect!
        """)
        st.header("Find Similar Songs")
        query_text = st.text_input("Enter a description or keywords for the music you're looking for:")
        # Search on a button click, or on any rerun where the text box is non-empty.
        if st.button("Search Similar Songs") or query_text:
            if query_text:
                with st.spinner("Searching for similar songs..."):
                    search_results = find_similar_songs_by_text(query_text, qdrant_client, audio_encoder)
                if search_results:
                    st.subheader("Similar songs based on your description:")
                    for song in search_results:
                        st.write(f"{song['name']} by {song['artist']} (Similarity: {song['similarity']:.2f})")
                        # Streams the remote Spotify preview URL stored in the point payload.
                        if song['preview_url']:
                            st.audio(song['preview_url'], format='audio/mp3')
                        else:
                            st.write("No preview available")
                        st.write("---")  # Add a separator between songs
                else:
                    st.info("No similar songs found. Try a different description.")
            else:
                st.warning("Please enter a description or keywords for your search.")
if __name__ == "__main__":
    # Entry point: build the UI when this file is executed as the main script.
    main()