Joshua Lochner committed on
Commit
8981122
1 Parent(s): 52340fc

Remove redundant calls to change device

Browse files
Files changed (3) hide show
  1. src/model.py +1 -2
  2. src/predict.py +2 -2
  3. src/shared.py +0 -4
src/model.py CHANGED
@@ -1,6 +1,6 @@
1
  from huggingface_hub import hf_hub_download
2
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
- from shared import CustomTokens, device
4
  from errors import ClassifierLoadError, ModelLoadError
5
  from functools import lru_cache
6
  import pickle
@@ -100,7 +100,6 @@ def get_model_tokenizer(model_name_or_path, cache_dir=None):
100
  # Load pretrained model and tokenizer
101
  model = AutoModelForSeq2SeqLM.from_pretrained(
102
  model_name_or_path, cache_dir=cache_dir)
103
- model.to(device())
104
 
105
  tokenizer = AutoTokenizer.from_pretrained(
106
  model_name_or_path, max_length=model.config.d_model, cache_dir=cache_dir)
 
1
  from huggingface_hub import hf_hub_download
2
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
+ from shared import CustomTokens
4
  from errors import ClassifierLoadError, ModelLoadError
5
  from functools import lru_cache
6
  import pickle
 
100
  # Load pretrained model and tokenizer
101
  model = AutoModelForSeq2SeqLM.from_pretrained(
102
  model_name_or_path, cache_dir=cache_dir)
 
103
 
104
  tokenizer = AutoTokenizer.from_pretrained(
105
  model_name_or_path, max_length=model.config.d_model, cache_dir=cache_dir)
src/predict.py CHANGED
@@ -10,7 +10,7 @@ import logging
10
  import os
11
  import itertools
12
  from utils import re_findall
13
- from shared import CustomTokens, START_SEGMENT_TEMPLATE, END_SEGMENT_TEMPLATE, OutputArguments, device, seconds_to_time
14
  from typing import Optional
15
  from segment import (
16
  generate_segments,
@@ -301,7 +301,7 @@ CATEGORIES = [None, 'SPONSOR', 'SELFPROMO', 'INTERACTION']
301
  def predict_sponsor_text(text, model, tokenizer):
302
  """Given a body of text, predict the words which are part of the sponsor"""
303
  input_ids = tokenizer(
304
- f'{CustomTokens.EXTRACT_SEGMENTS_PREFIX.value} {text}', return_tensors='pt', truncation=True).input_ids.to(device())
305
 
306
  max_out_len = round(min(
307
  max(
 
10
  import os
11
  import itertools
12
  from utils import re_findall
13
+ from shared import CustomTokens, START_SEGMENT_TEMPLATE, END_SEGMENT_TEMPLATE, OutputArguments, seconds_to_time
14
  from typing import Optional
15
  from segment import (
16
  generate_segments,
 
301
  def predict_sponsor_text(text, model, tokenizer):
302
  """Given a body of text, predict the words which are part of the sponsor"""
303
  input_ids = tokenizer(
304
+ f'{CustomTokens.EXTRACT_SEGMENTS_PREFIX.value} {text}', return_tensors='pt', truncation=True).input_ids
305
 
306
  max_out_len = round(min(
307
  max(
src/shared.py CHANGED
@@ -107,10 +107,6 @@ class GeneralArguments:
107
  torch.cuda.manual_seed_all(self.seed)
108
 
109
 
110
- def device():
111
- return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
112
-
113
-
114
  def seconds_to_time(seconds, remove_leading_zeroes=False):
115
  fractional = round(seconds % 1, 3)
116
  fractional = '' if fractional == 0 else str(fractional)[1:]
 
107
  torch.cuda.manual_seed_all(self.seed)
108
 
109
 
 
 
 
 
110
  def seconds_to_time(seconds, remove_leading_zeroes=False):
111
  fractional = round(seconds % 1, 3)
112
  fractional = '' if fractional == 0 else str(fractional)[1:]