"""Prompt construction and LLM querying for citation-aware summarization.

The module builds prompts from the abstracts of articles referenced by a
Wikipedia section and sends them to either an Azure-hosted model (via
``utils.AzureModels``) or an open model (via ``utils_open.OpenModels``).
Three modes are driven by ``get_retrieved_results``:

* ``code=True``: ask the model to complete a "hierarchical clustering"
  code snippet that groups citation tags by bracket and by sentence.
* ``organize_out is None``: ask the model to summarize each reference.
* ``organize_out`` given: ask the model to write the section itself.
"""

from functools import lru_cache
import csv
import json
import os
import pickle
import random
import re
import string
import sys
import time
from collections import defaultdict

import nltk
import numpy as np
import tiktoken
from langchain.callbacks import get_openai_callback
from langchain.chat_models import AzureChatOpenAI
from langchain.llms import OpenAI
from langchain.schema import HumanMessage, SystemMessage
from nltk.corpus import stopwords
from nltk.tokenize import sent_tokenize
from nltk.tokenize import word_tokenize
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tqdm import tqdm

from utils import AzureModels, write_to_file, read_from_file
# ``OpenModels`` is imported lazily from ``utils_open`` inside
# ``get_retrieved_results`` so that Azure-only runs do not need that module.

# Former command-line interface, kept for reference:
# MODEL_NAME = str(sys.argv[1])
# num_shots = int(sys.argv[2])
# method = str(sys.argv[3])  # ['fixed', 'random', 'bm25']
# ADDED K-SHOT SETTING, WHERE K IS VARIABLE

random.seed(1)

nltk.download('punkt')
nltk.download('stopwords')

# NOTE(review): the string literals below are reproduced exactly as they
# appear in the (whitespace-flattened) source this file was restored from.
# The single spaces inside the prompt templates and the one-space indents in
# the generated code strings may originally have been newlines / wider
# indentation -- confirm against the original prompts before re-running
# experiments.

_INSTRUCTION_TEMPLATE = (
    'Below is an instruction that describes a task. Write a response that '
    'appropriately completes the request. ### Instruction: '
)

# Five positional slots: abstracts, intent, citation tags, extra info,
# response prefix.
_SUMMARY_PROMPT_TEMPLATE = (
    'Given are a set of articles referenced in a Wikipedia Article, and the intent - '
    'Reference Articles: {} Intent: {} '
    'Summarize each reference article (generate in the format "@cite_K : ", '
    'each in a new line, where @cite_K represents each of the following '
    'citation/reference tags - {}, given in Reference Articles), '
    'given the reference articles as documents, and the intent.{} '
    '{}Answer: '
)

# NOTE(review): this template has only FOUR slots but is formatted with the
# same five arguments as the summary template, so slot 3 receives the
# citation tags and slot 4 the extra info; the response prefix is dropped.
# Behaviour is preserved here -- confirm whether that is intended.
_SECTION_PROMPT_TEMPLATE = (
    'Given are a set of articles referenced in a Wikipedia Article, and the intent - '
    'Reference Articles: {} Intent: {} '
    'Generate the wikipedia article section in 100-200 words based on the '
    'intent as an intent-based multi-document summary, given the reference '
    'articles as documents, and the intent.{} '
    '{}Answer: '
)

_CODE_PROMPT_TEMPLATE = (
    'def main(): # Given is a dictionary of articles that are referenced in a '
    'section of the Wikipedia Article, and the intent - '
    'reference_articles = dict() {}'
)

# Appended verbatim (never passed through ``str.format``), so the literal
# ``{}`` below are emitted as-is.
_HIER_CLUSTER_PROMPT = (
    "\n def hierarchical_clustering():\n"
    " # Hierarchical Clustering of references within a section of the "
    "Wikipedia Article, based on the reference articles and the intent\n"
    " citation_bracket = {} # This dictionary contains lists as values that "
    "shows how references are grouped within the same citation bracket in "
    "the section of the Wikipedia Article\n"
    " sentence = {} # This dictionary contains lists, where each list "
    "contains references in a sentence in the section of the Wikipedia "
    "Article\n\n"
)

_SLIDE_PROMPT = (
    'Convert this text into more structured text (in markdown) that can be '
    'put into the content of a slide in a presentation (e.g. use bullet '
    'points, numbered points, proper layout, etc.). Also, the include the '
    'topic "{}" of the slide. - {}'
)

# Short model names accepted for non-Azure models -> model identifiers.
_OPEN_MODEL_IDS = {
    'gemma2b': "google/gemma-2b-it",
    'gemma7b': "google/gemma-7b-it",
    'mistral7b': "mistralai/Mistral-7B-Instruct-v0.3",
    'llama7b': "meta-llama/Llama-2-7b-chat-hf",
    'llama13b': "meta-llama/Llama-2-13b-chat-hf",
    'llama3': "meta-llama/Meta-Llama-3-8B-Instruct",
    'galactica7b': "facebook/galactica-6.7b",
}

# All punctuation except '_' and '@', which must survive so that citation
# tags such as "@cite_3" stay intact.
_PUNCT_TABLE = str.maketrans(
    '', '', string.punctuation.replace('_', '').replace('@', '')
)


@lru_cache(maxsize=1)
def _english_stop_words():
    """Return the NLTK English stopwords as a frozenset, built only once."""
    return frozenset(stopwords.words('english'))


def remove_stopwords_and_punctuation(text):
    """Strip punctuation (except ``_`` and ``@``) and English stopwords.

    Args:
        text: Input string.

    Returns:
        The remaining whitespace-separated words joined by single spaces.
    """
    stop_words = _english_stop_words()
    words = text.translate(_PUNCT_TABLE).split()
    return ' '.join(word for word in words if word.lower() not in stop_words)


def get_key(list_):
    """Fuse citation tags into one key, e.g. ``["@cite_1", "@cite_2"]`` ->
    ``"@cite_1_2"``.

    Every occurrence of ``"@cite"`` is removed from each item and the
    remainders are concatenated after a single ``"@cite"`` prefix.
    """
    return '@cite' + ''.join(item.replace('@cite', '') for item in list_)


def group_citations(key):
    """Expand a fused key back into a comma-separated list of tags.

    ``"@cite_1_2"`` -> ``"@cite_1, @cite_2"``.
    """
    parts = key.replace("@cite_", "").split("_")
    return ", ".join("@cite_" + item for item in parts)


def code_to_extra_info(code_str):
    """Turn clustering code into natural-language hints for a prompt.

    Scans each line's left-hand side (text before the first ``=``) for
    ``citation_bracket["<key>"]`` and ``sentence["<key>"]`` assignments and
    describes every key that contains more than one underscore (i.e. groups
    more than one citation).

    Args:
        code_str: Code text as produced by ``get_code_str`` or by a model.

    Returns:
        ``""`` when no multi-citation group is found, otherwise a block that
        starts with ``"\\n\\nNOTE THAT -\\n\\n"``.
    """
    citation_bracket_keys = []
    sentence_keys = []
    for line in code_str.split("\n"):
        lhs = line.split("=")[0]
        if "citation_bracket[" in lhs:
            citation_bracket_keys.append(
                lhs.split('citation_bracket["')[-1].split('"]')[0]
            )
        if "sentence[" in lhs:
            sentence_keys.append(lhs.split('sentence["')[-1].split('"]')[0])

    cb_template = (
        "{} are in the same citation bracket (i.e., they are right next to "
        "each other) within the section of the Wikipedia Article."
    )
    sent_template = (
        "{} are in the same sentence within the section of the Wikipedia "
        "Article."
    )
    cb_list = [
        cb_template.format(group_citations(key))
        for key in citation_bracket_keys
        if key.count("_") > 1
    ]
    sent_list = [
        sent_template.format(group_citations(key))
        for key in sentence_keys
        if key.count("_") > 1
    ]

    if not cb_list and not sent_list:
        return ""
    return "\n\nNOTE THAT -\n\n" + "\n".join(cb_list) + "\n\n" + "\n".join(sent_list)


def get_code_str(related_work, reference_dict):
    """Build the ground-truth clustering code for a section.

    The text is split into sentences; inside each sentence, runs of
    consecutive words that are keys of ``reference_dict`` form one citation
    bracket. One ``citation_bracket[...]`` line is emitted per bracket and
    one ``sentence[...]`` line per sentence that contains any bracket.

    Args:
        related_work: Section text containing citation tags.
        reference_dict: Mapping whose keys are the citation tags.

    Returns:
        The bracket lines, a blank line, then the sentence lines, with all
        single quotes removed.
    """
    citation_bracket_code_lines = []
    sentence_code_lines = []

    for sentence in sent_tokenize(related_work):
        bracket_keys = []  # fused key of every bracket in this sentence
        current = []       # tags of the bracket currently being read

        def close_bracket():
            key = get_key(current)
            quoted = ['"' + tag + '"' for tag in current]
            citation_bracket_code_lines.append(
                'citation_bracket["{}"] = {}'.format(key, str(quoted))
            )
            bracket_keys.append(key)
            current.clear()

        for word in remove_stopwords_and_punctuation(sentence).split(' '):
            if word in reference_dict:
                current.append(word)
            elif current:
                # A non-citation word ends the running bracket.
                close_bracket()
        if current:
            close_bracket()

        if bracket_keys:
            values = ['citation_bracket["{}"]'.format(key) for key in bracket_keys]
            sentence_code_lines.append(
                'sentence["{}"] = {}'.format(get_key(bracket_keys), str(values))
            )

    # str(list) wraps each element in single quotes; dropping them leaves the
    # double-quoted tags / bare citation_bracket[...] references.
    return (
        " " + "\n ".join(citation_bracket_code_lines).replace("'", "")
        + "\n\n "
        + "\n ".join(sentence_code_lines).replace("'", "")
    )


def get_prompt(list_, i, prompt_template):
    """Build the code-style prompt and ground truth for sample ``i``.

    Args:
        list_: Sequence of samples; each has ``'related_work'``,
            ``'abstract'`` and ``'ref_abstract'`` (tag -> dict with an
            ``'abstract'`` string).
        i: Index of the sample.
        prompt_template: Template with a single positional slot.

    Returns:
        ``(gt_summary, prompt, code_str)``: stripped section text, the filled
        template, and the output of ``get_code_str`` for that text.
    """
    sample = list_[i]
    gt_summary = sample['related_work'].strip()
    inp_intent = sample['abstract'].strip()
    references = sample['ref_abstract']

    input_code_list = [
        'reference_articles["{}"] = "{}"'.format(key, references[key]['abstract'].strip())
        for key in references
    ]
    input_code_list.append('intent = "{}"'.format(inp_intent))
    input_code_str = " " + "\n ".join(input_code_list)

    code_str = get_code_str(gt_summary, references)
    prompt = prompt_template.format(input_code_str)
    return gt_summary, prompt, code_str


def preprocess_retrieved_out(tmp_keys, out):
    """Parse ``"<tag> : <summary>"`` lines from a model response.

    For each key, the first line mentioning it is used; the summary is the
    text after the first ``:`` (or the whole line if there is none). Keys
    without a matching line are omitted. Each match is echoed to stdout.

    A key only matches when it is not immediately followed by a word
    character, so ``@cite_1`` no longer picks up the ``@cite_10`` line.

    Args:
        tmp_keys: Iterable of citation tags.
        out: Raw model output.

    Returns:
        ``{tag: {"abstract": summary}}`` for every tag that was found.
    """
    lines = out.split("\n")
    new_dict = {}
    for key in tmp_keys:
        tag_pattern = re.compile(re.escape(key) + r"(?!\w)")
        for line in lines:
            if tag_pattern.search(line):
                summ_doc = line.split(":", 1)[-1].strip()
                new_dict[key] = {"abstract": summ_doc}
                print(key)
                print(summ_doc)
                print()
                break
    return new_dict


def get_slide(topic, text):
    """Ask the ``"gpt4o"`` Azure model to restructure ``text`` as slide markdown.

    Args:
        topic: Slide topic inserted into the prompt.
        text: Text to convert.

    Returns:
        Whatever ``AzureModels.get_completion`` returns for the prompt with a
        limit of 100 passed as its second argument.
    """
    azure_models = AzureModels("gpt4o")
    out_ = azure_models.get_completion(_SLIDE_PROMPT.format(topic, text), 100)
    time.sleep(2)  # fixed pause after each Azure call (presumably rate limiting)
    return out_


def _select_icl_indices(method, num_shots, test_idx, retrieve_dict, candidate_pool, current):
    """Pick the training indices used as in-context examples for one sample."""
    if method == "random":
        return random.sample(candidate_pool, num_shots)
    if method in ("bm25", "gat"):
        return [int(retrieve_dict[str(test_idx)][j]) for j in range(num_shots)]
    if method == 'fixed':
        return current[:num_shots]
    return current


def _format_text_prompt(prompt_template, references, intent, extra_info,
                        response_template, nested):
    """Fill a summary/section template for one sample.

    ``nested`` selects whether reference values are dicts holding an
    ``'abstract'`` entry (True) or are used directly (False).
    """
    if nested:
        abstracts = {key: references[key]['abstract'] for key in references}
    else:
        abstracts = {key: references[key] for key in references}
    paper_abstracts = "\n".join(key + " : " + abstracts[key] for key in abstracts)
    return prompt_template.format(
        paper_abstracts,
        intent,
        " ".join(list(references.keys())),
        extra_info,
        response_template,
    )


def get_retrieved_results(MODEL_NAME, num_shots, method, train_list, test_list,
                          code=False, organize_out=None, holdout_indices=None):
    """Query a model for every test sample that has more than one reference.

    Args:
        MODEL_NAME: Names containing ``"gpt4"`` go to ``AzureModels``; any
            other name must be a key of ``_OPEN_MODEL_IDS``.
        num_shots: Number of in-context examples (0 disables them).
        method: How examples are chosen: ``'fixed'`` (first ``num_shots`` of
            ``[0, 1]``), ``'random'``, or ``'bm25'`` / ``'gat'`` (looked up in
            a pre-computed retrieval file, keyed by ``str(test index)``).
        train_list: Samples that in-context examples are drawn from.
        test_list: Samples to run.
        code: If True, run the clustering-code completion task.
        organize_out: Only used when ``code`` is False. ``None`` selects
            per-reference summarization; otherwise a mapping from
            ``str(test index)`` to clustering code, and the section itself is
            generated.
        holdout_indices: Optional pool of ``train_list`` indices for
            ``method='random'``. Defaults to every index of ``train_list``.
            (The original code read an undefined global of this name.)

    Returns:
        * ``code=True``: ``{test index: model output}``.
        * ``organize_out is None``: ``test_list`` itself, where each
          processed sample's ``'ref_abstract'`` has been REPLACED in place by
          the parsed summaries.
        * otherwise: ``{test index: model output}``.

    Raises:
        ValueError: If ``MODEL_NAME`` is neither a gpt4 name nor a known open
            model.
    """
    azure_models = None
    open_models = None
    instruction_template = ''
    response_template = ''

    if 'gpt4' in MODEL_NAME:
        azure_models = AzureModels(MODEL_NAME)
    else:
        if code:
            instruction_template = _INSTRUCTION_TEMPLATE
            response_template = '### Response:\n'
        else:
            response_template = '### Assistant: '
        if MODEL_NAME not in _OPEN_MODEL_IDS:
            raise ValueError(
                "Unknown MODEL_NAME {!r}; expected a name containing 'gpt4' or one of {}".format(
                    MODEL_NAME, sorted(_OPEN_MODEL_IDS)
                )
            )
        # Imported here because the module-level import was disabled.
        from utils_open import OpenModels
        open_models = OpenModels(_OPEN_MODEL_IDS[MODEL_NAME])

    # Same precedence as before: code > organize_out > plain summarization.
    if code:
        prompt_template = _CODE_PROMPT_TEMPLATE
    elif organize_out is not None:
        prompt_template = _SECTION_PROMPT_TEMPLATE
    else:
        prompt_template = _SUMMARY_PROMPT_TEMPLATE

    retrieve_dict = None
    if method == 'bm25':
        retrieve_dict = read_from_file("bm25_10_icl_samples_50_holdout_samples.json")
    elif method == "gat":
        retrieve_dict = read_from_file("gat_20_icl_samples_50_holdout_samples.json")

    candidate_pool = holdout_indices if holdout_indices is not None else range(len(train_list))
    icl_train_indices = [0, 1]

    def complete(prompt, max_tokens, **open_model_kwargs):
        """Send ``prompt`` to whichever backend was configured above."""
        if azure_models is not None:
            out = azure_models.get_completion(prompt, max_tokens)
            time.sleep(2)  # fixed pause after each Azure call
            return out
        return open_models.open_completion(prompt, max_tokens, **open_model_kwargs)

    if code:
        final_dict = {}
        for i in tqdm(range(len(test_list))):
            if len(test_list[i]['ref_abstract']) <= 1:
                continue

            full_icl_prompt = ""
            # NOTE(review): the example loop and the "##Example N" header for
            # the test sample are assumed to sit inside this condition.
            if num_shots > 0:
                icl_train_indices = _select_icl_indices(
                    method, num_shots, i, retrieve_dict, candidate_pool, icl_train_indices
                )
                for enum_idx, icl_train_idx in enumerate(icl_train_indices):
                    _, icl_prompt, icl_code_str = get_prompt(
                        train_list, icl_train_idx, prompt_template
                    )
                    full_icl_prompt += (
                        "##Example {}:\n\n".format(enum_idx + 1)
                        + instruction_template
                        + icl_prompt
                        + _HIER_CLUSTER_PROMPT
                        + icl_code_str
                        + "\n\n"
                    )
                full_icl_prompt += "##Example {}:\n\n".format(num_shots + 1)

            _, prompt, _ = get_prompt(test_list, i, prompt_template)
            final_prompt = (
                full_icl_prompt
                + instruction_template
                + prompt
                + _HIER_CLUSTER_PROMPT
                + " # only generate the code that comes after this, as if you are on autocomplete mode\n"
                + response_template
            )
            final_dict[i] = complete(
                final_prompt, 500, stop_token="##Example {}".format(num_shots + 2)
            )
        return final_dict

    # Text modes: reference values are nested dicts only when organize_out is
    # supplied; otherwise they are used as-is.
    nested = organize_out is not None
    tmp_max_tok_len = 300 if nested else 1000
    pred_dict = {}

    for i in tqdm(range(len(test_list))):
        sample = test_list[i]
        references = sample['ref_abstract']
        if len(references) <= 1:
            continue

        icl_prompt = ""
        if num_shots > 0:
            icl_train_indices = _select_icl_indices(
                method, num_shots, i, retrieve_dict, candidate_pool, icl_train_indices
            )
            for enum_idx, icl_train_idx in enumerate(icl_train_indices):
                example = train_list[icl_train_idx]
                icl_references = example['ref_abstract']
                icl_gt_summary = example['related_work']
                icl_extra_info = ""
                if nested:
                    icl_extra_info = code_to_extra_info(
                        get_code_str(icl_gt_summary, icl_references)
                    )
                icl_prompt += (
                    "##Example {}:\n\n".format(enum_idx + 1)
                    + _format_text_prompt(
                        prompt_template, icl_references, example['abstract'],
                        icl_extra_info, response_template, nested,
                    )
                    + icl_gt_summary.strip()
                    + "\n\n"
                )
            icl_prompt += "##Example {}:\n\n".format(num_shots + 1)

        test_extra_info = ""
        if nested:
            test_extra_info = code_to_extra_info(organize_out[str(i)])

        prompt = icl_prompt + _format_text_prompt(
            prompt_template, references, sample['abstract'],
            test_extra_info, response_template, nested,
        )
        out_ = complete(prompt, tmp_max_tok_len, temperature=0.7)

        if nested:
            pred_dict[i] = out_
        else:
            sample["ref_abstract"] = preprocess_retrieved_out(references, out_)

    return pred_dict if nested else test_list