bawolf commited on
Commit
c7e92d2
1 Parent(s): 7aa93af

predict fix

Browse files
Files changed (2) hide show
  1. .dockerignore +22 -0
  2. predict.py +6 -3
.dockerignore CHANGED
@@ -19,3 +19,25 @@ coverage.xml
19
  .mypy_cache
20
  .pytest_cache
21
  .hypothesis
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  .mypy_cache
20
  .pytest_cache
21
  .hypothesis
22
+
23
+ # generated by replicate/cog
24
+ __pycache__
25
+ *.pyc
26
+ *.pyo
27
+ *.pyd
28
+ .Python
29
+ env
30
+ pip-log.txt
31
+ pip-delete-this-directory.txt
32
+ .tox
33
+ .coverage
34
+ .coverage.*
35
+ .cache
36
+ nosetests.xml
37
+ coverage.xml
38
+ *.cover
39
+ *.log
40
+ .git
41
+ .mypy_cache
42
+ .pytest_cache
43
+ .hypothesis
predict.py CHANGED
@@ -2,10 +2,13 @@ import os
2
  from cog import BasePredictor, Input, Path
3
  import torch
4
  import json
 
 
 
5
  from src.models.model import load_model
6
- from src.data.video_utils import create_transform, extract_frames
7
 
8
- CHECKPOINT_DIR = "runs/run_20241024-150232_otherpeopleval_large_model/"
9
 
10
  class Predictor(BasePredictor):
11
  def setup(self):
@@ -24,7 +27,7 @@ class Predictor(BasePredictor):
24
  # Load model
25
  self.model = load_model(
26
  self.config['num_classes'],
27
- os.path.join(CHECKPOINT_DIR, "best_model.pth"),
28
  self.device,
29
  self.config['clip_model']
30
  )
 
2
  from cog import BasePredictor, Input, Path
3
  import torch
4
  import json
5
+ import sys
6
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7
+
8
  from src.models.model import load_model
9
+ from src.dataset.video_utils import create_transform, extract_frames
10
 
11
+ CHECKPOINT_DIR = "checkpoints/"
12
 
13
  class Predictor(BasePredictor):
14
  def setup(self):
 
27
  # Load model
28
  self.model = load_model(
29
  self.config['num_classes'],
30
+ os.path.join(CHECKPOINT_DIR, "weights.ckpt"),
31
  self.device,
32
  self.config['clip_model']
33
  )