alisrbdni commited on
Commit
de76e78
1 Parent(s): 40a7c41

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -6
app.py CHANGED
@@ -391,12 +391,11 @@
391
 
392
  # if __name__ == "__main__":
393
  # main()
394
-
395
  import streamlit as st
396
  import matplotlib.pyplot as plt
397
  import torch
398
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW
399
- from transformers import ByT5Tokenizer, ByT5ForConditionalGeneration
400
  from datasets import load_dataset, Dataset
401
  from evaluate import load as load_metric
402
  from torch.utils.data import DataLoader
@@ -430,7 +429,7 @@ def load_data(dataset_name, train_size=20, test_size=20, num_clients=2, use_utf8
430
  del raw_datasets["unsupervised"]
431
 
432
  if model_name == "google/byt5-small":
433
- tokenizer = ByT5Tokenizer.from_pretrained(model_name)
434
 
435
  def utf8_encode_function(examples):
436
  encoded_texts = [list(text.encode('utf-8')) for text in examples["text"]]
@@ -685,7 +684,7 @@ def main():
685
  testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)
686
 
687
  if model_name == "google/byt5-small":
688
- net = ByT5ForConditionalGeneration.from_pretrained(model_name).to(DEVICE)
689
  else:
690
  net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
691
 
@@ -790,5 +789,3 @@ def main():
790
 
791
  if __name__ == "__main__":
792
  main()
793
-
794
-
 
391
 
392
  # if __name__ == "__main__":
393
  # main()
 
394
  import streamlit as st
395
  import matplotlib.pyplot as plt
396
  import torch
397
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW
398
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
399
  from datasets import load_dataset, Dataset
400
  from evaluate import load as load_metric
401
  from torch.utils.data import DataLoader
 
429
  del raw_datasets["unsupervised"]
430
 
431
  if model_name == "google/byt5-small":
432
+ tokenizer = T5Tokenizer.from_pretrained(model_name)
433
 
434
  def utf8_encode_function(examples):
435
  encoded_texts = [list(text.encode('utf-8')) for text in examples["text"]]
 
684
  testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)
685
 
686
  if model_name == "google/byt5-small":
687
+ net = T5ForConditionalGeneration.from_pretrained(model_name).to(DEVICE)
688
  else:
689
  net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
690
 
 
789
 
790
  if __name__ == "__main__":
791
  main()