berkaygkv54 committed
Commit 24510fe
1 Parent(s): a20c02a

llm integration
Files changed:
- app.py +49 -76
- src/laion_clap/inference.py +79 -18
- src/llm/chain.py +51 -0
- src/llm/output_parser.py +8 -0
- src/utils/__init__.py +0 -0
app.py CHANGED

Old version (removed lines prefixed with "-", unchanged context unprefixed):

@@ -1,79 +1,67 @@
 import streamlit as st
 from streamlit import session_state as session
-from src.config.configs import ProjectPaths
-import numpy as np
 from src.laion_clap.inference import AudioEncoder
-import torch
 import pandas as pd
-import json
-import os
-import smtplib, ssl
 from dotenv import load_dotenv

 load_dotenv()

 @st.cache_resource

-port = int(os.getenv("PORT"))
-print(port)
-smtp_server = "smtp.gmail.com"
-sender_email = os.getenv("EMAIL_ADDRESS")
-receiver_email = os.getenv("EMAIL_RECEIVER")
-password = os.getenv("EMAIL_PASSWORD")
-from email.mime.multipart import MIMEMultipart
-from email.mime.text import MIMEText
-msg = MIMEMultipart("alternative")
-msg["Subject"] = "Curate me a playlist submission"
-part1 = MIMEText(body, "plain")
-msg.attach(part1)
-context = ssl.create_default_context()
-with smtplib.SMTP_SSL(smtp_server, port, context=context) as server:
-    server.login(sender_email, password)
-    server.sendmail(sender_email, receiver_email, msg)
-    print("Email sent.")
-recommender = load_model()
-audio_vectors, song_names, df_youtube = load_data()

 st.title("""Curate me a Playlist.""")
 session.text_input = st.text_input(label="Describe a playlist")
-session.slider_count = st.slider(label="Track counts", min_value=5, max_value=
-buffer1, col1, buffer2 = st.columns([1.45, 1, 1])

 is_clicked = col1.button(label="Curate")
 if is_clicked:
     st.data_editor(
         dataframe,
         column_config={

@@ -88,22 +76,7 @@ if is_clicked:
         use_container_width=True
     )

-    form = st.form("form")
-    form.write("You can submit the playlist you've curated")
-    sender = form.text_input("Name of the curator")
-    query = session.text_input
-    playlist = [f"{k}\n" for k in dataframe.index]
-    playlist_string = "\n".join(dataframe.index.tolist())
-    body = f"""\
-Subject: Curate me a playlist submission
-
-Curator --> {sender}
-Query --> {session.text_input}
-
-Playlist
-{playlist_string}
-    """

New version (added lines prefixed with "+"):

 import streamlit as st
 from streamlit import session_state as session
 from src.laion_clap.inference import AudioEncoder
+from src.utils.spotify import SpotifyHandler, SpotifyAuthentication
 import pandas as pd
 from dotenv import load_dotenv
+from langchain.llms import CTransformers, Ollama
+from src.llm.chain import LLMChain
+from pymongo.mongo_client import MongoClient
+import os

+st.set_page_config(page_title="Curate me a playlist", layout="wide")
 load_dotenv()

+def load_llm_pipeline():
+    ctransformers_config = {
+        "max_new_tokens": 3000,
+        "temperature": 0,
+        "top_k": 1,
+        "top_p": 1,
+        "context_length": 2800
+    }
+
+    llm = CTransformers(
+        model="TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
+        model_file="mistral-7b-instruct-v0.1.Q5_K_M.gguf",
+        config=ctransformers_config
+    )
+    # llm = Ollama(temperature=0, model="mistral:7b-instruct-q8_0", top_k=1, top_p=1, num_ctx=2800)
+    chain = LLMChain(llm)
+    return chain

 @st.cache_resource
+def load_resources():
+    password = os.getenv("MONGODB_PASSWORD")
+    url = os.getenv("MONGODB_URL")
+    uri = f"mongodb+srv://berkaygkv:{password}@{url}/?retryWrites=true&w=majority"
+    client = MongoClient(uri)
+    db = client.spoti
+    mongo_db_collection = db.saved_tracks
+    recommender = AudioEncoder(mongo_db_collection)
+    recommender.load_existing_audio_vectors()
+    llm_pipeline = load_llm_pipeline()
+    return recommender, llm_pipeline


+recommender, llm_pipeline = load_resources()

 st.title("""Curate me a Playlist.""")
 session.text_input = st.text_input(label="Describe a playlist")
+session.slider_count = st.slider(label="Track counts", min_value=5, max_value=35, step=5)
+buffer1, col1, col2, buffer2 = st.columns([1.45, 1, 1, 1])

 is_clicked = col1.button(label="Curate")
 if is_clicked:
+    output = llm_pipeline.process_user_description(session.text_input)
+    song_list = []
+    for _, song_desc in output:
+        print(song_desc)
+        ranking = recommender.list_top_k_songs(song_desc, k=15)
+        song_list += ranking
+
+    dataframe = pd.DataFrame(song_list).sort_values("score", ascending=False).drop_duplicates(subset=["track_id"]).drop(columns=["track_id"]).reset_index(drop=True)
+    dataframe = dataframe.iloc[:session.slider_count]
     st.data_editor(
         dataframe,
         column_config={
         ...
         use_container_width=True
     )

+    # with st.form(key="spotiform"):
+    #     st.form_submit_button(on_click=authenticate_spotify, args=(session.access_url, ))
+    #     st.markdown(session.access_url)
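In the new `if is_clicked:` block, `for _, song_desc in output:` works because iterating a Pydantic BaseModel yields (field_name, value) pairs, so each pass hands one generated description to the recommender. A minimal sketch of that consumption step, with hypothetical hard-coded descriptions standing in for real LLM output:

# Sketch only: the descriptions below are hypothetical stand-ins for what the
# LLM chain would return as a SongDescriptions instance.
from src.llm.output_parser import SongDescriptions

output = SongDescriptions(
    song_description_1="slow acoustic ballad with soft vocals",
    song_description_2="melancholic indie folk with fingerpicked guitar",
    song_description_3="stripped-down piano piece about longing",
    song_description_4="ambient singer-songwriter track with warm strings",
)

# Iterating a Pydantic model yields (field_name, value) tuples, so song_desc
# is one description string per loop pass, exactly as app.py relies on.
for _, song_desc in output:
    print(song_desc)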
src/laion_clap/inference.py CHANGED

Old version (removed lines prefixed with "-"):

@@ -1,41 +1,102 @@
 import librosa
 import torch
 from src import laion_clap
 from ..config.configs import ProjectPaths
-import pickle


 class AudioEncoder(laion_clap.CLAP_Module):
-    def __init__(self) -> None:
-        super().__init__(enable_fusion=False, amodel=
         self.load_ckpt(ckpt=ProjectPaths.MODEL_PATH)

     def extract_audio_representaion(self, file_name):
         audio_data, _ = librosa.load(file_name, sr=48000)
         audio_data = audio_data.reshape(1, -1)
         with torch.no_grad():
-            audio_embed = self.get_audio_embedding_from_data(
         return audio_embed

     def extract_bulk_audio_representaions(self, save=False):
-        else:
-            np.save(ProjectPaths.DATA_DIR.joinpath("vectors", "audio_representations.npy"))
-            with open(ProjectPaths.DATA_DIR.joinpath("vectors", "song_names.pkl", "rb")) as writer:
-                pickle.dump(song_names, writer)

     def extract_text_representation(self, text):
         text_data = [text]
         text_embed = self.get_text_embedding(text_data)
         return text_embed

New version (added lines prefixed with "+"):

+from tqdm import tqdm
 import librosa
 import torch
 from src import laion_clap
+import json
+import jmespath
 from ..config.configs import ProjectPaths


 class AudioEncoder(laion_clap.CLAP_Module):
+    def __init__(self, collection=None) -> None:
+        super().__init__(enable_fusion=False, amodel="HTSAT-base")
+        self.music_data = None
         self.load_ckpt(ckpt=ProjectPaths.MODEL_PATH)
+        self.collection = collection
+
+    # def _get_track_data(self):
+    #     with open(ProjectPaths.DATA_DIR.joinpath("json", "final_track_data.json"), "r") as reader:
+    #         track_data = json.load(reader)
+    #     return track_data
+
+    def _get_track_data(self):
+        data = self.collection.find({})
+        return data
+
+    def update_collection_item(self, track_id, vector):
+        self.collection.update_one({"track_id": track_id}, {"$set": {"embedding": vector}})
+
     def extract_audio_representaion(self, file_name):
         audio_data, _ = librosa.load(file_name, sr=48000)
         audio_data = audio_data.reshape(1, -1)
+        audio_data = torch.from_numpy(audio_data)
         with torch.no_grad():
+            audio_embed = self.get_audio_embedding_from_data(
+                x=audio_data, use_tensor=True
+            )
         return audio_embed

     def extract_bulk_audio_representaions(self, save=False):
+        track_data = self._get_track_data()
+        processed_data = []
+        idx = 0
+        for track in tqdm(track_data):
+            if track["youtube_data"]["file_path"] and track["youtube_data"]["link"] not in processed_data:
+                tensor = self.extract_audio_representaion(track["youtube_data"]["file_path"])
+                self.update_collection_item(track["track_id"], tensor.tolist())
+                idx += 1

+    # def load_existing_audio_vectors(self):
+    #     self.music_data = torch.load(
+    #         ProjectPaths.DATA_DIR.joinpath("vectors", "audio_representations.pt")
+    #     )
+    #     with open(
+    #         ProjectPaths.DATA_DIR.joinpath("vectors", "final_track_data_w_links.json"),
+    #         "r",
+    #     ) as rd:
+    #         self.track_data = json.load(rd)
+
+    def load_existing_audio_vectors(self):
+        # embedding_result = list(self.collection.find({}, {"embedding": 1}))
+        # tracking_result = list(self.collection.find({}, {"embedding": 0}))
+        arrays = []
+        track_data = []
+        for idx, track in enumerate(self.collection.find({})):
+            if not track.get("embedding"):
+                continue
+            data = track.copy()
+            data.pop("embedding")
+            data.update({"vector_idx": idx})
+            arrays.append(track["embedding"][0])
+            track_data.append(data)
+
+        self.music_data = torch.tensor(arrays)
+        self.track_data = track_data.copy()

     def extract_text_representation(self, text):
         text_data = [text]
         text_embed = self.get_text_embedding(text_data)
         return text_embed
+
+    def list_top_k_songs(self, text, k=10):
+        assert self.music_data is not None
+        with torch.no_grad():
+            text_embed = self.get_text_embedding(text, use_tensor=True)
+
+        dot_product = self.music_data @ text_embed.T
+        top_10 = torch.topk(dot_product.flatten(), k)
+        indices = top_10.indices.tolist()
+        final_result = []
+        for k, i in enumerate(indices):
+            piece = {
+                "title": self.track_data[i]["youtube_data"]["title"],
+                "score": round(top_10.values[k].item(), 2),
+                "link": self.track_data[i]["youtube_data"]["link"],
+                "track_id": self.track_data[i]["track_id"],
+            }
+            final_result.append(piece)
+        return final_result
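The new `list_top_k_songs` ranks tracks with a single matrix product between the stored audio embeddings and the query text embedding, then keeps the k highest scores. A self-contained sketch of that scoring step, using random tensors in place of real CLAP embeddings:

import torch

# Stand-ins for the real data: N stored audio embeddings (self.music_data)
# and one text embedding of the same dimension D. Values here are random,
# not actual CLAP outputs.
N, D = 100, 512
music_data = torch.randn(N, D)
text_embed = torch.randn(1, D)

# (N, D) @ (D, 1) -> (N, 1): one similarity score per stored track.
dot_product = music_data @ text_embed.T
top = torch.topk(dot_product.flatten(), k=10)

# Each index points back into the parallel track metadata list.
for rank, (score, idx) in enumerate(zip(top.values.tolist(), top.indices.tolist()), start=1):
    print(rank, round(score, 2), idx)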
src/llm/chain.py ADDED

@@ -0,0 +1,51 @@ (new file):

from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain.schema.runnable import RunnableLambda
from operator import itemgetter
from langchain.output_parsers import PydanticOutputParser
from .output_parser import SongDescriptions
from langchain.llms.base import LLM
import json


class LLMChain:
    def __init__(self, llm_model: LLM) -> None:
        self.llm_model = llm_model
        self.parser = PydanticOutputParser(pydantic_object=SongDescriptions)
        self.full_chain = self._create_llm_chain()

    def _get_output_format(self, _):
        return self.parser.get_format_instructions()

    def _create_llm_chain(self):
        prompt_response = ChatPromptTemplate.from_messages([
            ("system", "You are an AI assistant, helping the user to turn a music playlist text description into four separate song descriptions that are probably contained in the playlist. Try to be specific with descriptions. Make sure all 4 song descriptions are similar.\n"),
            ("system", "{format_instructions}\n"),
            ("human", "Playlist description: {description}.\n"),
            # ("human", "Song descriptions:"),
        ])
        # prompt = PromptTemplate(
        #     template="You are an AI assistant, helping the user to turn a music playlist text description into three separate generic song descriptions that are probably contained in the playlist.\n{format_instructions}\n{description}\n",
        #     input_variables=["description"],
        #     partial_variables={"format_instructions": self.parser.get_format_instructions()},
        # )

        full_chain = (
            {
                "format_instructions": RunnableLambda(self._get_output_format),
                "description": itemgetter("description"),
            }
            | prompt_response
            | self.llm_model
        )
        return full_chain

    def process_user_description(self, user_input):
        output = self.full_chain.invoke(
            {
                "description": user_input
            }
        ).replace("\\", '')
        return self.parser.parse(output)
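A usage sketch for the new chain on its own, mirroring how app.py wires it up; it assumes a locally running Ollama server with the model tag that app.py keeps as a commented-out alternative, and the playlist description is illustrative:

from langchain.llms import Ollama
from src.llm.chain import LLMChain

# Assumption: an Ollama server is running locally and this mistral model is pulled.
llm = Ollama(temperature=0, model="mistral:7b-instruct-q8_0", top_k=1, top_p=1, num_ctx=2800)
chain = LLMChain(llm)

# Returns a SongDescriptions instance with four song_description_* fields.
songs = chain.process_user_description("rainy late-night jazz for studying")
print(songs.song_description_1)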
src/llm/output_parser.py ADDED

@@ -0,0 +1,8 @@ (new file):

from pydantic import BaseModel, Field


class SongDescriptions(BaseModel):
    song_description_1: str = Field(description="description of the first song")
    song_description_2: str = Field(description="description of the second song")
    song_description_3: str = Field(description="description of the third song")
    song_description_4: str = Field(description="description of the fourth song")
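This model is the contract the chain enforces on the LLM output: PydanticOutputParser derives the {format_instructions} block from it and validates the raw completion against it. A small sketch, where the JSON completion is a made-up example:

from langchain.output_parsers import PydanticOutputParser
from src.llm.output_parser import SongDescriptions

parser = PydanticOutputParser(pydantic_object=SongDescriptions)

# This string is what the chain injects into the prompt as {format_instructions}.
print(parser.get_format_instructions())

# Made-up example of a well-formed completion; parse() returns SongDescriptions.
raw_completion = (
    '{"song_description_1": "upbeat 80s synth-pop with bright arpeggios", '
    '"song_description_2": "driving new-wave track with gated drums", '
    '"song_description_3": "retro electro-pop song with catchy hooks", '
    '"song_description_4": "nostalgic synthwave instrumental"}'
)
songs = parser.parse(raw_completion)
print(songs.song_description_2)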
src/utils/__init__.py ADDED

(New empty file.)