File size: 6,068 Bytes
696d3e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import pandas as pd
from datasets import Dataset
from transformers import pipeline, GPT2Tokenizer
from sentence_transformers import SentenceTransformer, util

# Define paths and models
# Text file of knowledge segments; read one segment per non-blank line by
# load_and_preprocess_text().
filename = "output_country_details.txt"
retrieval_model_name = 'output/sentence-transformer-finetuned/'  # path of a previously fine-tuned sentence-transformer model
gpt2_model_name = "gpt2"
# Input CSVs and the paths their augmented copies are written to by process_dataset().
csv_file_path = "train_dataset.csv"
output_csv_file_path = "updated_train_dataset.csv"
val_csv_file_path = "val_dataset.csv"
output_val_csv_file_path = "updated_val_csv.csv"

# Tokenizer used to count prompt tokens when sizing the generation length.
# NOTE: this load happens outside the try/except below, so a failure here propagates.
tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)

# Initialize models
# NOTE: a load failure is only printed, not re-raised. Any name assigned at or after
# the failing line (retrieval_model and/or gpt_model) stays undefined, so later code
# that uses it will hit a NameError.
try:
    retrieval_model = SentenceTransformer(retrieval_model_name)
    gpt_model = pipeline("text-generation", model=gpt2_model_name)
    print("Models loaded successfully.")
except Exception as e:
    print(f"Failed to load models: {e}")

def load_and_preprocess_text(filename):
    """
    Read a UTF-8 text file and return its non-blank lines, whitespace-trimmed.

    Parameters:
    - filename (str): Path to the text file.

    Returns:
    - list[str]: One entry per non-blank line, with leading/trailing whitespace
      removed. An empty list is returned (and a message printed) if anything
      goes wrong while opening or reading the file.
    """
    try:
        with open(filename, 'r', encoding='utf-8') as handle:
            # Trim every line lazily, then keep only the ones with content left.
            trimmed_lines = map(str.strip, handle)
            segments = [text for text in trimmed_lines if text]
        print("Text loaded and preprocessed successfully.")
        return segments
    except Exception as e:
        # Best-effort loader: report the problem and fall back to no segments.
        print(f"Failed to load or preprocess text: {e}")
        return []

# Load the knowledge segments once at module level; generate_response() reads this
# module-level list. It is an empty list if the file could not be read.
segments = load_and_preprocess_text(filename)

# Single-entry cache mapping a tuple of segment strings to their embeddings, so
# repeated queries against the same corpus do not re-encode every segment.
_segment_embedding_cache = {}


def _get_segment_embeddings(segments):
    """Return embeddings for `segments`, encoding them only when their contents change."""
    cache_key = tuple(segments)
    if cache_key not in _segment_embedding_cache:
        # Hold at most one corpus so memory does not grow across different inputs.
        _segment_embedding_cache.clear()
        _segment_embedding_cache[cache_key] = retrieval_model.encode(segments)
    return _segment_embedding_cache[cache_key]


def find_relevant_segment(user_query, segments):
    """
    Find the most relevant text segment based on a user query.

    The query and the segments are embedded with the module-level
    `retrieval_model`; the segment with the highest cosine similarity to the
    query is returned. Segment embeddings are cached between calls, so asking
    many questions against the same segment list encodes the segments once.

    Parameters:
    - user_query (str): The user's query.
    - segments (list[str]): List of text segments to search within.

    Returns:
    - str: The most relevant text segment, or "" if `segments` is empty or an
      error occurs (errors are printed, not raised).
    """
    try:
        # Nothing to search: skip the model instead of failing inside it.
        if not segments:
            return ""
        query_embedding = retrieval_model.encode(user_query)
        segment_embeddings = _get_segment_embeddings(segments)
        # Row 0 holds the similarities of the single query against every segment.
        similarities = util.pytorch_cos_sim(query_embedding, segment_embeddings)[0]
        # Convert the argmax result to a plain int before using it as a list index.
        best_idx = int(similarities.argmax())
        return segments[best_idx]
    except Exception as e:
        print(f"Error finding relevant segment: {e}")
        return ""

def generate_response(question):
    """
    Answer a question using the best-matching segment from the module-level
    `segments` list as context for text generation.

    Parameters:
    - question (str): The user's question.

    Returns:
    - str: Generated response.
    """
    # Step 1: retrieval - pick the segment closest to the question.
    context = find_relevant_segment(question, segments)
    # Step 2: generation - build the answer around that segment.
    answer = generate_response_with_context(question, context)
    return answer

def generate_response_with_context(user_query, relevant_segment):
    """
    Produce generated text that continues a fixed prompt built around a segment.

    Parameters:
    - user_query (str): The user's query. It is accepted for interface
      symmetry but is not inserted into the prompt.
    - relevant_segment (str): A relevant fact or detail.

    Returns:
    - str: The generated text after clean-up, or "" if generation fails
      (the error is printed, not raised).
    """
    try:
        prompt = f"Thank you for your question! Here is an additional fact about your topic: {relevant_segment}"
        # Total length budget = tokens already in the prompt plus 50 more.
        prompt_token_count = len(tokenizer(prompt)['input_ids'])
        outputs = gpt_model(prompt, max_length=prompt_token_count + 50, temperature=0.25)
        generated_text = outputs[0]['generated_text']
        return clean_up_response(generated_text, relevant_segment)
    except Exception as e:
        print(f"Error generating response: {e}")
        return ""

def clean_up_response(response, segment):
    """
    Tidy generated text by dropping pieces that merely repeat the segment.

    The response is split on '.', each piece is trimmed, and pieces that are
    empty or occur verbatim inside `segment` are discarded. The remaining
    pieces are rejoined with '. ' and terminated with a period when the text
    does not already end in '.', '!' or '?'.

    Parameters:
    - response (str): The initial response generated by the model.
    - segment (str): The segment used to generate the response.

    Returns:
    - str: A cleaned and formatted response ("" when nothing remains).
    """
    kept_pieces = []
    for raw_piece in response.split('.'):
        piece = raw_piece.strip()
        # Skip blanks and anything already contained in the source segment.
        if not piece or piece in segment:
            continue
        kept_pieces.append(piece)

    tidy = '. '.join(kept_pieces).strip()
    if not tidy:
        return tidy
    if tidy.endswith((".", "!", "?")):
        return tidy
    return tidy + "."

def process_dataset(csv_file_path, output_csv_file_path):
    """
    Run the full pipeline over one CSV: generate answers, score them, save.

    Parameters:
    - csv_file_path (str): Path to the CSV file containing the dataset.
    - output_csv_file_path (str): Path where the updated dataset will be saved.

    Prints:
    - Path to the saved results and the average similarity score
      (0 when there are no rows to score).
    """
    # Load the CSV and wrap it as a Hugging Face dataset.
    source_frame = pd.read_csv(csv_file_path)
    dataset = Dataset.from_pandas(source_frame)

    # Generate answers, then score each one against its ground truth.
    answered = add_model_answers(dataset)
    scores = evaluate_similarity(answered)

    # Attach the scores and write everything back out as CSV.
    scored = answered.add_column("similarity", scores)
    scored.to_pandas().to_csv(output_csv_file_path, index=False)

    if scores:
        average_similarity = sum(scores) / len(scores)
    else:
        average_similarity = 0
    print(f"Results saved to {output_csv_file_path}")
    print(f"Average Similarity Score: {average_similarity:.3f}")

def add_model_answers(dataset):
    """
    Generate an answer for every question and attach them as a new column.

    Parameters:
    - dataset (datasets.Dataset): The Hugging Face dataset object; its
      'Question' column supplies the inputs.

    Returns:
    - datasets.Dataset: Updated dataset with an added 'Answer' column.
    """
    generated = []
    for question in dataset['Question']:
        generated.append(generate_response(question))
    return dataset.add_column("Answer", generated)

def evaluate_similarity(dataset):
    """
    Score each generated answer against its ground truth by cosine similarity.

    Parameters:
    - dataset (datasets.Dataset): The dataset containing both the 'Answer'
      and 'GroundTruth' columns.

    Returns:
    - list[float]: List of similarity scores, one per row, in row order.
    """
    scores = []
    for generated, reference in zip(dataset['Answer'], dataset['GroundTruth']):
        score_matrix = util.pytorch_cos_sim(
            retrieval_model.encode(generated),
            retrieval_model.encode(reference),
        )
        # One embedding on each side: the single cell is the pair's similarity.
        scores.append(score_matrix[0][0].item())
    return scores

# Process datasets
# NOTE: these calls sit at module level, so both datasets are processed whenever
# this file is executed or imported. Training split first, then validation.
process_dataset(csv_file_path, output_csv_file_path)
process_dataset(val_csv_file_path, output_val_csv_file_path)