import os
import re
import openai
import inflect
import pandas as pd
from typing import Dict, List
from datasets import load_dataset
from huggingface_hub import login
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.vectorstores.utils import DistanceStrategy


# Get OpenAI and huggingface-hub keys
openai.api_key = os.environ.get('OPENAI_API_KEY')
openai.organization = os.environ.get('OPENAI_ORG')
login(os.environ.get('HUB_KEY'))


# Constants
FS_COLUMNS = ['asin', 'category', 'title', 'tech_process', 'labels']
MAX_TOKENS = 700
USER_TXT = 'Write feature-bullets for an Amazon product page. ' \
           'Title: {title}. Technical details: {tech_data}.\n\n### Feature-bullets:'

# Load few-shot dataset
FS_DATASET = load_dataset('iarbel/amazon-product-data-filter', split='validation')

# Prepare Pandas DFs with the relevant columns
FS_DS = FS_DATASET.to_pandas()[FS_COLUMNS]

# Load vector store
DB = FAISS.load_local('data/vector_stores/amazon-product-embedding', OpenAIEmbeddings(),
                      distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT)


class Conversation:
    """
    A class to construct conversations with the ChatAPI
    """
    def __init__(self):
        self.messages = [{'role': 'system',
                          'content': 'You are a helpful assistant. Your task is to write feature-bullets for an Amazon product page.'}]

    def add_message(self, role: str, content: str) -> None:
        # Validate inputs
        role = role.lower()
        last_role = self.messages[-1]['role']
        if role not in ['user', 'assistant']:
            raise ValueError('Roles can be "user" or "assistant" only')
        if role == 'user' and last_role not in ['system', 'assistant']:
            raise ValueError('"user" message can only follow a "system" or "assistant" message')
        elif role == 'assistant' and last_role != 'user':
            raise ValueError('"assistant" message can only follow "user" message')
        
        message = {"role": role, "content": content}
        self.messages.append(message)

def api_call(messages: List[Dict[str, str]], temperature: float = 0.7, top_p: float = 1, n_responses: int = 1) -> dict:
    """
    A function to call the ChatAPI. Takes in a list of conversation messages, plus the optional params
    temperature (controls randomness), top_p, and n_responses
    """
    params = {'model': 'gpt-4o-mini', 'messages': messages, 'temperature': temperature, 'max_tokens': MAX_TOKENS, 'n': n_responses, 'top_p': top_p}
    response = openai.ChatCompletion.create(**params)

    text = [response['choices'][i]['message']['content'] for i in range(n_responses)]
    out = {'object': 'chat', 'usage': dict(response['usage']), 'text': text}
    return out


class FewShotData:
    """
    A class to retrieve few-shot examples via vector similarity and build the few-shot conversation
    """
    def __init__(self, few_shot_df: pd.DataFrame, vector_db: FAISS):
        self.few_shot_df = few_shot_df
        self.vector_db = vector_db
        
    def extract_few_shot_data(self, target_title: str, k_shot: int = 2, **db_kwargs) -> pd.DataFrame:
        # Find relevant products
        target_title_vector = OpenAIEmbeddings().embed_query(target_title)
        similarity_list_mmr = self.vector_db.max_marginal_relevance_search_with_score_by_vector(target_title_vector, k=k_shot, **db_kwargs)
        few_shot_titles = [i[0].page_content for i in similarity_list_mmr]
        
        # Extract relevant data
        few_shot_data = self.few_shot_df[self.few_shot_df['title'].isin(few_shot_titles)][['title', 'tech_process', 'labels']]
        return few_shot_data

    def construct_few_shot_conversation(self, target_title: str, target_tech_data: str, few_shot_data: pd.DataFrame) -> Conversation:
        # Structure the few-shot data
        fs_titles = few_shot_data['title'].to_list()
        fs_tech_data = few_shot_data['tech_process'].to_list()
        fs_labels = few_shot_data['labels'].to_list()
    
        # Init a conversation, populate with few-shot data
        conv = Conversation()
        for title, tech_data, labels in zip(fs_titles, fs_tech_data, fs_labels):
            conv.add_message('user', USER_TXT.format(title=title, tech_data=tech_data))
            conv.add_message('assistant', labels)
            
        # Add the final user prompt
        conv.add_message('user', USER_TXT.format(title=target_title, tech_data=target_tech_data))
        return conv
    
    
def return_is_are(text: str) -> str:
    # Return 'is' for a singular noun and 'are' for a plural, using inflect
    engine = inflect.engine()
    res = 'is' if not engine.singular_noun(text) else 'are'
    return res

def format_tech_as_str(tech_data: pd.DataFrame) -> str:
    # Flatten key-value technical details into a single descriptive string
    tech_format = [f'{k} {return_is_are(k)} {v}' for k, v in tech_data.to_numpy() if k and v]
    tech_str = '. '.join(tech_format)
    return tech_str


def generate_data(title: str, tech_process: str, few_shot_df: pd.DataFrame, vector_db: FAISS) -> str:
    """
    Retrieve few-shot examples for the target product, build the conversation, and call the ChatAPI
    """
    fs_example = FewShotData(few_shot_df=few_shot_df, vector_db=vector_db)
    fs_data = fs_example.extract_few_shot_data(target_title=title, k_shot=2)
    
    fs_conv = fs_example.construct_few_shot_conversation(target_title=title, 
                                                         target_tech_data=tech_process, 
                                                         few_shot_data=fs_data)
    
    api_res = api_call(fs_conv.messages, temperature=0.7)
    feature_bullets = "## Feature-Bullets\n" + api_res['text'][0]
    return feature_bullets


def check_url_structure(url: str) -> bool:
    # Validate a canonical Amazon product URL ending in a 10-character ASIN
    pattern = r"https://www\.amazon\.com(/.+)?/dp/[a-zA-Z0-9]{10}/?$"
    return bool(re.match(pattern, url))
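

# --- Usage sketch (assumed example, not part of the original pipeline) ---
# A minimal illustration of how the helpers above fit together: validate a product URL,
# flatten technical details into a string, and generate feature-bullets using the
# module-level few-shot dataframe (FS_DS) and vector store (DB). The URL, ASIN, title,
# and technical details below are hypothetical placeholders; generate_data issues a real
# OpenAI call, so this only runs when the module is executed directly with valid API keys.
if __name__ == '__main__':
    example_url = 'https://www.amazon.com/dp/B0EXAMPLE1'  # hypothetical ASIN
    if not check_url_structure(example_url):
        raise ValueError(f'Unexpected Amazon URL format: {example_url}')

    example_title = 'Stainless Steel Insulated Water Bottle, 32 oz'  # hypothetical product
    example_tech = pd.DataFrame(
        [['Material', 'stainless steel'], ['Capacity', '32 ounces'], ['Lid type', 'screw cap']],
        columns=['attribute', 'value']
    )
    tech_str = format_tech_as_str(example_tech)

    feature_bullets = generate_data(title=example_title,
                                    tech_process=tech_str,
                                    few_shot_df=FS_DS,
                                    vector_db=DB)
    print(feature_bullets)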