Update app.py
Browse files
app.py
CHANGED
@@ -391,12 +391,11 @@
|
|
391 |
|
392 |
# if __name__ == "__main__":
|
393 |
# main()
|
394 |
-
|
395 |
import streamlit as st
|
396 |
import matplotlib.pyplot as plt
|
397 |
import torch
|
398 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW
|
399 |
-
from transformers import
|
400 |
from datasets import load_dataset, Dataset
|
401 |
from evaluate import load as load_metric
|
402 |
from torch.utils.data import DataLoader
|
@@ -430,7 +429,7 @@ def load_data(dataset_name, train_size=20, test_size=20, num_clients=2, use_utf8
|
|
430 |
del raw_datasets["unsupervised"]
|
431 |
|
432 |
if model_name == "google/byt5-small":
|
433 |
-
tokenizer =
|
434 |
|
435 |
def utf8_encode_function(examples):
|
436 |
encoded_texts = [list(text.encode('utf-8')) for text in examples["text"]]
|
@@ -685,7 +684,7 @@ def main():
|
|
685 |
testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)
|
686 |
|
687 |
if model_name == "google/byt5-small":
|
688 |
-
net =
|
689 |
else:
|
690 |
net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
|
691 |
|
@@ -790,5 +789,3 @@ def main():
|
|
790 |
|
791 |
if __name__ == "__main__":
|
792 |
main()
|
793 |
-
|
794 |
-
|
|
|
391 |
|
392 |
# if __name__ == "__main__":
|
393 |
# main()
|
|
|
394 |
import streamlit as st
|
395 |
import matplotlib.pyplot as plt
|
396 |
import torch
|
397 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW
|
398 |
+
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
399 |
from datasets import load_dataset, Dataset
|
400 |
from evaluate import load as load_metric
|
401 |
from torch.utils.data import DataLoader
|
|
|
429 |
del raw_datasets["unsupervised"]
|
430 |
|
431 |
if model_name == "google/byt5-small":
|
432 |
+
tokenizer = T5Tokenizer.from_pretrained(model_name)
|
433 |
|
434 |
def utf8_encode_function(examples):
|
435 |
encoded_texts = [list(text.encode('utf-8')) for text in examples["text"]]
|
|
|
684 |
testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)
|
685 |
|
686 |
if model_name == "google/byt5-small":
|
687 |
+
net = T5ForConditionalGeneration.from_pretrained(model_name).to(DEVICE)
|
688 |
else:
|
689 |
net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
|
690 |
|
|
|
789 |
|
790 |
if __name__ == "__main__":
|
791 |
main()
|
|
|
|