|
import os
import re
from typing import Dict, List

import inflect
import openai
import pandas as pd
from datasets import load_dataset
from huggingface_hub import login
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.vectorstores.utils import DistanceStrategy
|
|
|
|
|
|
|
# --- Service authentication --------------------------------------------------
# Credentials come from the environment. os.environ.get returns None when a
# variable is unset, so a missing key surfaces on the first API call, not here.
openai.api_key = os.environ.get('OPENAI_API_KEY')

openai.organization = os.environ.get('OPENAI_ORG')

# Hugging Face Hub login — presumably required for load_dataset below; confirm
# whether the dataset is gated before removing.
login(os.environ.get('HUB_KEY'))
|
|
|
|
|
|
|
# Dataset columns retained for few-shot prompting (see FS_DS below).
FS_COLUMNS = ['asin', 'category', 'title', 'tech_process', 'labels']

# Hard cap on tokens generated per ChatCompletion response.
MAX_TOKENS = 700

# Per-turn user prompt template; filled via .format(title=..., tech_data=...).
USER_TXT = 'Write feature-bullets for an Amazon product page. ' \
           'Title: {title}. Technical details: {tech_data}.\n\n### Feature-bullets:'
|
|
|
|
|
# Validation split of the Amazon product dataset, used as the few-shot pool.
FS_DATASET = load_dataset('iarbel/amazon-product-data-filter', split='validation')

# Few-shot lookup table restricted to the prompt-relevant columns.
FS_DS = FS_DATASET.to_pandas()[FS_COLUMNS]

# Local FAISS index over product-title embeddings, scored by inner product.
# NOTE(review): assumes the store was built with the same OpenAI embedding
# model being instantiated here — confirm against the store-building script.
DB = FAISS.load_local('data/vector_stores/amazon-product-embedding', OpenAIEmbeddings(),
                      distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT)
|
|
|
|
|
class Conversation:
    """
    Accumulates a ChatCompletion message list, enforcing user/assistant
    turn alternation after the fixed system prompt.
    """

    def __init__(self):
        # Seed the history with the system instruction; all later turns are
        # validated against the role of the most recent message.
        self.messages = [{
            'role': 'system',
            'content': 'You are a helpful assistant. Your task is to write feature-bullets for an Amazon product page.',
        }]

    def add_message(self, role: str, content: str) -> None:
        """
        Append one turn to the conversation.

        Args:
            role: 'user' or 'assistant' (case-insensitive; stored lowercased).
            content: Message text for that turn.

        Raises:
            ValueError: If the role is unknown, or the turn would break the
                system -> user -> assistant -> user ... alternation.
        """
        normalized = role.lower()
        previous = self.messages[-1]['role']

        if normalized not in ('user', 'assistant'):
            raise ValueError('Roles can be "user" or "assistant" only')
        if normalized == 'user':
            if previous not in ('system', 'assistant'):
                raise ValueError('"user" message can only follow "assistant" message')
        elif previous != 'user':
            raise ValueError('"assistant" message can only follow "user" message')

        self.messages.append({"role": normalized, "content": content})
|
|
|
def api_call(messages: List[Dict[str, str]], temperature: float = 0.7, top_p: float = 1, n_responses: int = 1) -> dict:
    """
    Call the OpenAI ChatCompletion API with a prepared conversation.

    Args:
        messages: Chat history as a list of {'role': ..., 'content': ...}
            dicts (e.g. ``Conversation.messages``).
        temperature: Sampling temperature; higher values increase randomness.
        top_p: Nucleus-sampling probability cutoff.
        n_responses: Number of completions to request.

    Returns:
        dict with keys 'object' (always 'chat'), 'usage' (token counts as
        reported by the API) and 'text' (list of n_responses strings).
    """
    params = {'model': 'gpt-3.5-turbo', 'messages': messages, 'temperature': temperature,
              'max_tokens': MAX_TOKENS, 'n': n_responses, 'top_p': top_p}
    response = openai.ChatCompletion.create(**params)

    text = [response['choices'][i]['message']['content'] for i in range(n_responses)]
    # OpenAIObject subclasses dict, so dict() yields a plain usage mapping
    # without depending on the SDK's private '_previous' attribute.
    out = {'object': 'chat', 'usage': dict(response['usage']), 'text': text}
    return out
|
|
|
|
|
class FewShotData:
    """
    Builds few-shot conversations by retrieving similar product examples
    from a vector store and pairing them with their labeled bullets.
    """

    def __init__(self, few_shot_df: pd.DataFrame, vector_db: FAISS):
        # Example pool (must contain 'title', 'tech_process', 'labels')
        # and the similarity index over product titles.
        self.few_shot_df = few_shot_df
        self.vector_db = vector_db

    def extract_few_shot_data(self, target_title: str, k_shot: int = 2, **db_kwargs) -> pd.DataFrame:
        """
        Return up to k_shot example rows whose titles are most similar
        (by MMR search over title embeddings) to target_title.
        """
        query_vector = OpenAIEmbeddings().embed_query(target_title)
        mmr_hits = self.vector_db.max_marginal_relevance_search_with_score_by_vector(
            query_vector, k=k_shot, **db_kwargs)
        matched_titles = [doc.page_content for doc, _score in mmr_hits]

        row_mask = self.few_shot_df['title'].isin(matched_titles)
        return self.few_shot_df.loc[row_mask, ['title', 'tech_process', 'labels']]

    def construct_few_shot_conversation(self, target_title: str, target_tech_data: str, few_shot_data: pd.DataFrame) -> Conversation:
        """
        Assemble a Conversation of user/assistant example turns followed by
        the final user prompt for the target product.
        """
        conv = Conversation()

        examples = zip(few_shot_data['title'].to_list(),
                       few_shot_data['tech_process'].to_list(),
                       few_shot_data['labels'].to_list())
        for ex_title, ex_tech, ex_labels in examples:
            conv.add_message('user', USER_TXT.format(title=ex_title, tech_data=ex_tech))
            conv.add_message('assistant', ex_labels)

        # The target product is the last user turn, left for the model to answer.
        conv.add_message('user', USER_TXT.format(title=target_title, tech_data=target_tech_data))
        return conv
|
|
|
|
|
def return_is_are(text: str) -> str:
    """Return the copula ('is' or 'are') agreeing in number with *text*."""
    # inflect's singular_noun() returns a falsy value when the input is
    # already singular, and the singular form when the input is plural.
    if inflect.engine().singular_noun(text):
        return 'are'
    return 'is'
|
|
|
def format_tech_as_str(tech_data) -> str:
    """
    Flatten a two-column key/value table into sentence fragments like
    'Key is value. Keys are value'. Rows with an empty key or value are
    skipped. Assumes tech_data.to_numpy() yields (key, value) pairs —
    i.e. a two-column pandas object; confirm against callers.
    """
    fragments = []
    for key, value in tech_data.to_numpy():
        if key and value:
            fragments.append(f'{key} {return_is_are(key)} {value}')
    return '. '.join(fragments)
|
|
|
|
|
def generate_data(title: str, tech_process: str, few_shot_df: pd.DataFrame, vector_db: FAISS) -> str:
    """
    Generate feature-bullets for one product via few-shot prompting.

    Retrieves two similar labeled examples, builds the conversation, calls
    the ChatCompletion API, and returns the first completion prefixed with
    a markdown header.
    """
    retriever = FewShotData(few_shot_df=few_shot_df, vector_db=vector_db)
    examples = retriever.extract_few_shot_data(target_title=title, k_shot=2)

    conversation = retriever.construct_few_shot_conversation(
        target_title=title,
        target_tech_data=tech_process,
        few_shot_data=examples,
    )

    response = api_call(conversation.messages, temperature=0.7)
    return "## Feature-Bullets\n" + response['text'][0]
|
|
|
|
|
def check_url_structure(url: str) -> bool:
    """
    Return True if *url* looks like a canonical Amazon product-page URL:
    https://www.amazon.com[/anything]/dp/<10-char ASIN>[/]

    Dots in the host are escaped so 'www.amazon.com' is matched literally —
    the previous pattern's bare '.' accepted any character in those positions
    (e.g. 'wwwXamazonXcom' passed).
    """
    pattern = r"https://www\.amazon\.com(/.+)?/dp/[a-zA-Z0-9]{10}/?$"
    return bool(re.match(pattern, url))
|
|
|
|