Spaces:
Running
Running
Joshua Lochner
committed on
Commit
•
8981122
1
Parent(s):
52340fc
Remove redundant calls to change device
Browse files
- src/model.py +1 -2
- src/predict.py +2 -2
- src/shared.py +0 -4
src/model.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
from huggingface_hub import hf_hub_download
|
2 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
3 |
-
from shared import CustomTokens
|
4 |
from errors import ClassifierLoadError, ModelLoadError
|
5 |
from functools import lru_cache
|
6 |
import pickle
|
@@ -100,7 +100,6 @@ def get_model_tokenizer(model_name_or_path, cache_dir=None):
|
|
100 |
# Load pretrained model and tokenizer
|
101 |
model = AutoModelForSeq2SeqLM.from_pretrained(
|
102 |
model_name_or_path, cache_dir=cache_dir)
|
103 |
-
model.to(device())
|
104 |
|
105 |
tokenizer = AutoTokenizer.from_pretrained(
|
106 |
model_name_or_path, max_length=model.config.d_model, cache_dir=cache_dir)
|
|
|
1 |
from huggingface_hub import hf_hub_download
|
2 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
3 |
+
from shared import CustomTokens
|
4 |
from errors import ClassifierLoadError, ModelLoadError
|
5 |
from functools import lru_cache
|
6 |
import pickle
|
|
|
100 |
# Load pretrained model and tokenizer
|
101 |
model = AutoModelForSeq2SeqLM.from_pretrained(
|
102 |
model_name_or_path, cache_dir=cache_dir)
|
|
|
103 |
|
104 |
tokenizer = AutoTokenizer.from_pretrained(
|
105 |
model_name_or_path, max_length=model.config.d_model, cache_dir=cache_dir)
|
src/predict.py
CHANGED
@@ -10,7 +10,7 @@ import logging
|
|
10 |
import os
|
11 |
import itertools
|
12 |
from utils import re_findall
|
13 |
-
from shared import CustomTokens, START_SEGMENT_TEMPLATE, END_SEGMENT_TEMPLATE, OutputArguments,
|
14 |
from typing import Optional
|
15 |
from segment import (
|
16 |
generate_segments,
|
@@ -301,7 +301,7 @@ CATEGORIES = [None, 'SPONSOR', 'SELFPROMO', 'INTERACTION']
|
|
301 |
def predict_sponsor_text(text, model, tokenizer):
|
302 |
"""Given a body of text, predict the words which are part of the sponsor"""
|
303 |
input_ids = tokenizer(
|
304 |
-
f'{CustomTokens.EXTRACT_SEGMENTS_PREFIX.value} {text}', return_tensors='pt', truncation=True).input_ids
|
305 |
|
306 |
max_out_len = round(min(
|
307 |
max(
|
|
|
10 |
import os
|
11 |
import itertools
|
12 |
from utils import re_findall
|
13 |
+
from shared import CustomTokens, START_SEGMENT_TEMPLATE, END_SEGMENT_TEMPLATE, OutputArguments, seconds_to_time
|
14 |
from typing import Optional
|
15 |
from segment import (
|
16 |
generate_segments,
|
|
|
301 |
def predict_sponsor_text(text, model, tokenizer):
|
302 |
"""Given a body of text, predict the words which are part of the sponsor"""
|
303 |
input_ids = tokenizer(
|
304 |
+
f'{CustomTokens.EXTRACT_SEGMENTS_PREFIX.value} {text}', return_tensors='pt', truncation=True).input_ids
|
305 |
|
306 |
max_out_len = round(min(
|
307 |
max(
|
src/shared.py
CHANGED
@@ -107,10 +107,6 @@ class GeneralArguments:
|
|
107 |
torch.cuda.manual_seed_all(self.seed)
|
108 |
|
109 |
|
110 |
-
def device():
|
111 |
-
return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
112 |
-
|
113 |
-
|
114 |
def seconds_to_time(seconds, remove_leading_zeroes=False):
|
115 |
fractional = round(seconds % 1, 3)
|
116 |
fractional = '' if fractional == 0 else str(fractional)[1:]
|
|
|
107 |
torch.cuda.manual_seed_all(self.seed)
|
108 |
|
109 |
|
|
|
|
|
|
|
|
|
110 |
def seconds_to_time(seconds, remove_leading_zeroes=False):
|
111 |
fractional = round(seconds % 1, 3)
|
112 |
fractional = '' if fractional == 0 else str(fractional)[1:]
|