|
import pandas as pd |
|
import chromadb |
|
from sklearn.model_selection import train_test_split |
|
from transformers import GPT2Tokenizer, GPT2LMHeadModel, TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments, pipeline |
|
import gradio as gr |
|
import email |
|
|
|
|
|
# Load the raw e-mail table from the working directory. The code further down
# reads its 'message' column, which holds one raw e-mail per row.
emails = pd.read_csv(filepath_or_buffer='emails.csv')
|
def preprocess_email_content(raw_email: str) -> str:
    """Extract the body of a raw e-mail and flatten it to a single line.

    The message is parsed with the standard ``email`` package, its headers
    are dropped, and the remaining body is cleaned up: line breaks and the
    quoting marker ``"> >>> > >"`` are removed and surrounding whitespace is
    stripped.

    Args:
        raw_email: Full text of one e-mail (headers followed by body).

    Returns:
        The cleaned body text. For multipart messages the ``text/*`` parts
        are concatenated; an empty string is returned when no textual
        payload is found.
    """
    parsed = email.message_from_string(raw_email)

    if parsed.is_multipart():
        # For multipart messages get_payload() returns a list of sub-messages
        # rather than a string, so calling .replace() on it would raise.
        # Collect the textual leaf parts instead.
        pieces = [
            part.get_payload()
            for part in parsed.walk()
            if not part.is_multipart() and part.get_content_maintype() == "text"
        ]
        body = "\n".join(piece for piece in pieces if isinstance(piece, str))
    else:
        body = parsed.get_payload()
        if not isinstance(body, str):
            body = ""

    # NOTE(review): line breaks are deleted, not replaced by a space, so the
    # last word of one line is joined to the first word of the next. Kept
    # as-is to preserve the existing output; confirm this is intended.
    flattened = body.replace("\n", "").replace("\r", "").replace("> >>> > >", "")
    return flattened.strip()
|
|
|
# Clean every raw message in the 'message' column.
content_text = list(map(preprocess_email_content, emails['message']))

# Keep a very small fraction (train_size=0.00005) of the cleaned e-mails as the
# working set; the remainder of the split is discarded.
# NOTE(review): no random_state is passed, so the selected sample differs
# between runs.
train_content, _ = train_test_split(content_text, train_size=0.00005)
|
|
|
|
|
# Create a Chroma client and a collection named "Enron_emails", then store the
# sampled e-mails in it under sequential ids: id1, id2, ..., idN.
# NOTE(review): the collection is not queried anywhere in the visible code.
client = chromadb.Client()
collection = client.create_collection(name="Enron_emails")
collection.add(
    documents=train_content,
    ids=[f'id{position}' for position, _ in enumerate(train_content, start=1)],
)
|
|
|
|
|
# Load the pretrained GPT-2 tokenizer and language-model head.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

# Register '[PAD]' as the padding token so that padded tokenization works.
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# '[PAD]' is a new vocabulary entry, so its id lies beyond the pretrained
# embedding matrix. Resize the embeddings to the tokenizer's new length;
# otherwise any input containing the pad id indexes out of range.
model.resize_token_embeddings(len(tokenizer))
|
|
|
|
|
# Padded/truncated encodings of the sampled e-mails. Kept as a module-level
# name for inspection; this is not what gets written to disk.
tokenized_emails = tokenizer(train_content, truncation=True, padding=True)

# TextDataset (used below) reads this file as plain text and tokenizes it
# itself. Writing space-separated token ids here would therefore train the
# model on strings of digits, so the raw e-mail text is written instead, one
# e-mail per line. UTF-8 is explicit because TextDataset reads the file as
# UTF-8.
with open('tokenized_emails.txt', 'w', encoding='utf-8') as file:
    file.writelines(f'{text}\n' for text in train_content)
|
|
|
# Build fixed-length (128-token) training blocks from the text file.
# TextDataset caches its tokenized features next to the input file and reuses
# them on later runs by default. The training sample is redrawn on every run,
# so overwrite_cache=True stops a stale cache from an earlier run being used.
# NOTE(review): TextDataset tokenizes the file's text itself, so the file must
# contain raw text, not token ids.
dataset = TextDataset(
    tokenizer=tokenizer,
    file_path='tokenized_emails.txt',
    block_size=128,
    overwrite_cache=True,
)

# mlm=False selects causal (next-token) language modelling, not masked LM.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir='./output',
    num_train_epochs=3,
    per_device_train_batch_size=8,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)

# Fine-tune the model in place on the e-mail blocks.
trainer.train()
|
|
|
|
|
# Persist the fine-tuned model and its tokenizer into the same directory so
# they can be reloaded together later.
model.save_pretrained(save_directory="./fine_tuned_model")
tokenizer.save_pretrained(save_directory="./fine_tuned_model")
|
|
|
|
|
def question_answer(question):
    """Generate a text continuation for ``question`` with the fine-tuned model.

    Args:
        question: Prompt text entered by the user.

    Returns:
        The generated text with the leading echo of the prompt removed and
        surrounding whitespace stripped, or a string starting with
        ``"Error in generating response: "`` if generation fails.
    """
    try:
        generated = text_gen(question, max_length=200, num_return_sequences=1)
        generated_text = generated[0]['generated_text']

        # The output is expected to begin with the prompt itself. Strip only
        # that leading echo: a global replace would also delete any later
        # occurrence of the prompt inside the model's own continuation.
        if question and generated_text.startswith(question):
            generated_text = generated_text[len(question):]

        return generated_text.strip()
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing
        # the interface.
        return f"Error in generating response: {str(e)}"
|
|
|
# Text-generation pipeline around the fine-tuned model; question_answer()
# looks this name up at call time.
text_gen = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Minimal Gradio UI: one text box in, one text box out.
iface = gr.Interface(
    fn=question_answer,
    inputs="text",
    outputs="text",
    title="Answering questions about the Enron case.",
    description="Ask a question about the Enron case!",
    # Example prompt shown in the UI ("Eron" was a typo for "Enron").
    examples=["What is Enron?"],
)

# Start the web server for the interface.
iface.launch()