ffreemt commited on
Commit
34815a7
1 Parent(s): 14d4c4b

Replace model-s-cpu with load_model_s (model-s-512 model-s-512-v2 on hf)

Browse files
.gitignore CHANGED
@@ -143,3 +143,5 @@ links/
143
  node_modules
144
  install-sw.sh
145
  install-sw1.sh
 
 
 
143
  node_modules
144
  install-sw.sh
145
  install-sw1.sh
146
+ win10-install-memo.txt
147
+ model-s
st_mlbee/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
  """Init."""
2
  __version__ = "0.1.0a2"
3
- from .st_mlbee import st_mlbee
4
 
5
  __all__ = ("st_mlbee",)
 
1
  """Init."""
2
  __version__ = "0.1.0a2"
3
+ from st_mlbee.st_mlbee import st_mlbee
4
 
5
  __all__ = ("st_mlbee",)
st_mlbee/gen_cmat.py CHANGED
@@ -9,14 +9,27 @@ import numpy as np
9
  from tqdm import tqdm
10
 
11
  # from model_pool import load_model_s
12
- from hf_model_s_cpu import model_s
13
- from logzero import logger
14
 
15
- from st_mlbee.cos_matrix2 import cos_matrix2
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  try:
18
  # model = model_s()
19
- model = model_s(alive_bar_on=True)
 
20
  except Exception as _:
21
  logger.error(_)
22
  raise
 
9
  from tqdm import tqdm
10
 
11
  # from model_pool import load_model_s
12
+ # from hf_model_s_cpu import model_s # load_model_s directly
13
+ from st_mlbee.load_model_s import load_model_s
14
 
15
+ # from logzero import logger
16
+ from loguru import logger
17
+
18
+ # from st_mlbee.cos_matrix2 import cos_matrix2
19
+ from .cos_matrix2 import cos_matrix2
20
+
21
+ _ = """
22
+ try:
23
+ model_s = load_model_s()
24
+ except Exception as exc:
25
+ logger.erorr(exc)
26
+ raise
27
+ """
28
 
29
  try:
30
  # model = model_s()
31
+ # model = model_s(alive_bar_on=True)
32
+ model = load_model_s()
33
  except Exception as _:
34
  logger.error(_)
35
  raise
st_mlbee/load_model_s.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r"""
2
+ Load model_s from hf.
3
+
4
+ cf aslo align-model-pool\model_pool\load_model.py and ycco make-upload-model-s.ipynb.
5
+ """
6
+ import torch
7
+ import joblib
8
+ from huggingface_hub import hf_hub_download
9
+ from loguru import logger
10
+
11
+ try:
12
+ loc = hf_hub_download("mikeee/model_s_512", "model-s", local_dir=".")
13
+ except Exception as exc:
14
+ logger.error(exc)
15
+ raise SystemExit(1) from exc
16
+
17
+
18
+ def load_model_s(model_file=None):
19
+ """Load a model from hf."""
20
+ if model_file is None:
21
+ model_file = loc
22
+ try:
23
+ model = joblib.load(model_file)
24
+ except Exception as exc:
25
+ logger.error(exc)
26
+ raise
27
+ return model
tests/test_gen_cmat_old.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Test old gen_cmat."""
2
+ from loadtext import loadtext
3
+ from st_mlbee.gen_cmat import gen_cmat
4
+
5
+
6
+ def test_gen_cmat_old():
7
+ """Test old gen_cam."""
8
+ paras1 = loadtext("data/sternstunden04-en.txt", splitlines=True)
9
+ paras2 = loadtext("data/sternstunden04-de.txt", True)
10
+
11
+ cmat = gen_cmat(paras1, paras2)
12
+ len1, len2 = len(paras1), len(paras2)
13
+
14
+ # note the order
15
+ assert cmat.shape == (len2, len1)
16
+
17
+
18
+ if __name__ == "__main__":
19
+ test_gen_cmat_old()
tests1/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Init."""
tests1/test_gen_cmat_old.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Test old gen_cmat."""
2
+ from loadtext import loadtext
3
+ from st_mlbee.gen_cmat import gen_cmat
4
+
5
+
6
+ def test_gen_cmat_old():
7
+ """Test old gen_cam."""
8
+ paras1 = loadtext("data/sternstunden04-en.txt", splitlines=True)
9
+ paras2 = loadtext("data/sternstunden04-de.txt", True)
10
+
11
+ cmat = gen_cmat(paras1, paras2)
12
+ len1, len2 = len(paras1), len(paras2)
13
+
14
+ # note the order
15
+ assert cmat.shape == (len2, len1)
16
+
17
+
18
+ if __name__ == "__main__":
19
+ test_gen_cmat_old()