user2434 committed on
Commit
c90606a
•
1 Parent(s): 76393fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -33
app.py CHANGED
@@ -7,51 +7,34 @@ import email
7
 
8
  # loading and preprocessing dataset
9
  emails = pd.read_csv('emails.csv')
 
10
  def preprocess_email_content(raw_email):
11
  message = email.message_from_string(raw_email).get_payload()
12
  return message.replace("\n", "").replace("\r", "").replace("> >>> > >", "").strip()
13
 
14
  content_text = [preprocess_email_content(item) for item in emails['message']]
15
- train_content, _ = train_test_split(content_text, train_size=0.00005) # was unable to load more emails due to capacity constraints
16
 
17
  # ChromaDB setup
18
  client = chromadb.Client()
19
  collection = client.create_collection(name="Enron_emails")
20
  collection.add(documents=train_content, ids=[f'id{i+1}' for i in range(len(train_content))])
21
 
22
- # model and tokenizer
23
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
24
- model = GPT2LMHeadModel.from_pretrained('gpt2')
25
- tokenizer.add_special_tokens({'pad_token': '[PAD]'})
26
-
27
- # tokenizing and training
28
- tokenized_emails = tokenizer(train_content, truncation=True, padding=True)
29
- with open('tokenized_emails.txt', 'w') as file:
30
- for ids in tokenized_emails['input_ids']:
31
- file.write(' '.join(map(str, ids)) + '\n')
32
-
33
- dataset = TextDataset(tokenizer=tokenizer, file_path='tokenized_emails.txt', block_size=128)
34
- data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
35
- training_args = TrainingArguments(
36
- output_dir='./output',
37
- num_train_epochs=3,
38
- per_device_train_batch_size=8
39
- )
40
-
41
- trainer = Trainer(
42
- model=model,
43
- args=training_args,
44
- data_collator=data_collator,
45
- train_dataset=dataset
46
- )
47
- trainer.train()
48
 
49
- # saving the model
50
- model.save_pretrained("./fine_tuned_model")
51
- tokenizer.save_pretrained("./fine_tuned_model")
 
 
 
 
52
 
53
- # Gradio interface
54
  def question_answer(question):
 
55
  try:
56
  generated = text_gen(question, max_length=200, num_return_sequences=1)
57
  generated_text = generated[0]['generated_text'].replace(question, "").strip()
@@ -59,13 +42,12 @@ def question_answer(question):
59
  except Exception as e:
60
  return f"Error in generating response: {str(e)}"
61
 
62
- text_gen = pipeline("text-generation", model=model, tokenizer=tokenizer)
63
  iface = gr.Interface(
64
  fn=question_answer,
65
  inputs="text",
66
  outputs="text",
67
  title="Answering questions about the Enron case.",
68
  description="Ask a question about the Enron case!",
69
- examples=["What is Eron?"]
70
  )
71
  iface.launch()
 
7
 
8
  # loading and preprocessing dataset
9
  emails = pd.read_csv('emails.csv')
10
+
11
  def preprocess_email_content(raw_email):
12
  message = email.message_from_string(raw_email).get_payload()
13
  return message.replace("\n", "").replace("\r", "").replace("> >>> > >", "").strip()
14
 
15
  content_text = [preprocess_email_content(item) for item in emails['message']]
16
+ train_content, _ = train_test_split(content_text, train_size=0.00005)
17
 
18
  # ChromaDB setup
19
  client = chromadb.Client()
20
  collection = client.create_collection(name="Enron_emails")
21
  collection.add(documents=train_content, ids=[f'id{i+1}' for i in range(len(train_content))])
22
 
23
+ # initialize model and tokenizer globally but don't load them yet
24
+ tokenizer = None
25
+ model = None
26
+ text_gen = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
def load_model():
    """Load the fine-tuned GPT-2 model, tokenizer and generation pipeline once.

    Populates the module-level ``tokenizer``, ``model`` and ``text_gen``
    globals from the ``./fine_tuned_model`` directory on first use. Later
    calls return immediately once the pipeline exists.

    The three globals are only assigned after every step has succeeded, so a
    failure part-way through (for example while building the pipeline) leaves
    them untouched and the next call retries the whole load instead of
    running with ``text_gen`` still set to ``None``.
    """
    global tokenizer, model, text_gen
    if text_gen is not None and model is not None and tokenizer is not None:
        return

    loaded_tokenizer = GPT2Tokenizer.from_pretrained('./fine_tuned_model')
    loaded_model = GPT2LMHeadModel.from_pretrained('./fine_tuned_model')
    loaded_tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    # If registering the pad token made the vocabulary larger than the
    # model's embedding matrix, grow the matrix so the new token id cannot
    # index out of range.
    if len(loaded_tokenizer) > loaded_model.get_input_embeddings().num_embeddings:
        loaded_model.resize_token_embeddings(len(loaded_tokenizer))

    generator = pipeline("text-generation", model=loaded_model, tokenizer=loaded_tokenizer)

    # Publish all three together only after everything loaded successfully.
    tokenizer, model, text_gen = loaded_tokenizer, loaded_model, generator
35
 
 
36
  def question_answer(question):
37
+ load_model() # loading model on first use
38
  try:
39
  generated = text_gen(question, max_length=200, num_return_sequences=1)
40
  generated_text = generated[0]['generated_text'].replace(question, "").strip()
 
42
  except Exception as e:
43
  return f"Error in generating response: {str(e)}"
44
 
 
45
  iface = gr.Interface(
46
  fn=question_answer,
47
  inputs="text",
48
  outputs="text",
49
  title="Answering questions about the Enron case.",
50
  description="Ask a question about the Enron case!",
51
+ examples=["What is Enron?"]
52
  )
53
  iface.launch()