user2434 committed on
Commit
746b99c
β€’
1 Parent(s): 22934e7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import chromadb
3
+ from sklearn.model_selection import train_test_split
4
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel, TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments, pipeline
5
+ import gradio as gr
6
+ import email
7
+
8
+ # loading and preprocessing dataset
9
+ emails = pd.read_csv('emails.csv')
10
+ def preprocess_email_content(raw_email):
11
+ message = email.message_from_string(raw_email).get_payload()
12
+ return message.replace("\n", "").replace("\r", "").replace("> >>> > >", "").strip()
13
+
14
+ content_text = [preprocess_email_content(item) for item in emails['message']]
15
+ train_content, _ = train_test_split(content_text, train_size=0.00005) # was unable to load more emails due to capacity constraints
16
+
17
+ # ChromaDB setup
18
+ client = chromadb.Client()
19
+ collection = client.create_collection(name="Enron_emails")
20
+ collection.add(documents=train_content, ids=[f'id{i+1}' for i in range(len(train_content))])
21
+
22
+ # model and tokenizer
23
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
24
+ model = GPT2LMHeadModel.from_pretrained('gpt2')
25
+ tokenizer.add_special_tokens({'pad_token': '[PAD]'})
26
+
27
+ # tokenizing and training
28
+ tokenized_emails = tokenizer(train_content, truncation=True, padding=True)
29
+ with open('tokenized_emails.txt', 'w') as file:
30
+ for ids in tokenized_emails['input_ids']:
31
+ file.write(' '.join(map(str, ids)) + '\n')
32
+
33
+ dataset = TextDataset(tokenizer=tokenizer, file_path='tokenized_emails.txt', block_size=128)
34
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
35
+ training_args = TrainingArguments(
36
+ output_dir='./output',
37
+ num_train_epochs=3,
38
+ per_device_train_batch_size=8
39
+ )
40
+
41
+ trainer = Trainer(
42
+ model=model,
43
+ args=training_args,
44
+ data_collator=data_collator,
45
+ train_dataset=dataset
46
+ )
47
+ trainer.train()
48
+
49
+ # saving the model
50
+ model.save_pretrained("./fine_tuned_model")
51
+ tokenizer.save_pretrained("./fine_tuned_model")
52
+
53
+ # Gradio interface
54
+ def question_answer(question):
55
+ try:
56
+ generated = text_gen(question, max_length=200, num_return_sequences=1)
57
+ generated_text = generated[0]['generated_text'].replace(question, "").strip()
58
+ return generated_text
59
+ except Exception as e:
60
+ return f"Error in generating response: {str(e)}"
61
+
62
+ text_gen = pipeline("text-generation", model=model, tokenizer=tokenizer)
63
+ iface = gr.Interface(
64
+ fn=question_answer,
65
+ inputs="text",
66
+ outputs="text",
67
+ title="Answering questions about the Enron case.",
68
+ description="Ask a question about the Enron case!",
69
+ examples=["What is Eron?"]
70
+ )
71
+ iface.launch()