Files changed (1) hide show
  1. custom_st.py +10 -10
custom_st.py CHANGED
@@ -1,13 +1,11 @@
1
- import base64
2
  import json
3
  import os
4
  from io import BytesIO
5
  from typing import Any, Dict, List, Optional, Tuple, Union
6
 
7
- import requests
8
  import torch
9
  from torch import nn
10
- from transformers import AutoConfig, AutoImageProcessor, AutoModel, AutoTokenizer
11
 
12
 
13
  class Transformer(nn.Module):
@@ -35,11 +33,11 @@ class Transformer(nn.Module):
35
  def __init__(
36
  self,
37
  model_name_or_path: str,
38
- max_seq_length: int | None = None,
39
- model_args: dict[str, Any] | None = None,
40
- tokenizer_args: dict[str, Any] | None = None,
41
- config_args: dict[str, Any] | None = None,
42
- cache_dir: str | None = None,
43
  do_lower_case: bool = False,
44
  tokenizer_name_or_path: str = None,
45
  ) -> None:
@@ -121,8 +119,10 @@ class Transformer(nn.Module):
121
  return self.auto_model.config.hidden_size
122
 
123
  def tokenize(
124
- self, texts: list[str] | list[dict] | list[tuple[str, str]], padding: str | bool = True
125
- ) -> dict[str, torch.Tensor]:
 
 
126
  """Tokenizes a text and maps tokens to token-ids"""
127
  output = {}
128
  if isinstance(texts[0], str):
 
 
1
  import json
2
  import os
3
  from io import BytesIO
4
  from typing import Any, Dict, List, Optional, Tuple, Union
5
 
 
6
  import torch
7
  from torch import nn
8
+ from transformers import AutoConfig, AutoModel, AutoTokenizer
9
 
10
 
11
  class Transformer(nn.Module):
 
33
  def __init__(
34
  self,
35
  model_name_or_path: str,
36
+ max_seq_length: int = None,
37
+ model_args: Dict[str, Any] = None,
38
+ tokenizer_args: Dict[str, Any] = None,
39
+ config_args: Dict[str, Any] = None,
40
+ cache_dir: str = None,
41
  do_lower_case: bool = False,
42
  tokenizer_name_or_path: str = None,
43
  ) -> None:
 
119
  return self.auto_model.config.hidden_size
120
 
121
  def tokenize(
122
+ self,
123
+ texts: Union[List[str], List[dict], List[Tuple[str, str]]],
124
+ padding: Union[str, bool] = True
125
+ ) -> Dict[str, torch.Tensor]:
126
  """Tokenizes a text and maps tokens to token-ids"""
127
  output = {}
128
  if isinstance(texts[0], str):