zhong-al commited on
Commit
ba918ab
1 Parent(s): 6ba4e1b

Revert changes

Browse files
Files changed (2) hide show
  1. configuration_x3d.py +8 -2
  2. modeling_x3d.py +11 -3
configuration_x3d.py CHANGED
@@ -1,9 +1,15 @@
1
  from transformers import PretrainedConfig
2
- from x3d_model.cfg import load_config
 
3
 
4
  class X3DConfig(PretrainedConfig):
5
  model_type = "x3d"
6
 
7
- def __init__(self, path: str = None, **kwargs):
8
  super().__init__(**kwargs)
 
 
 
 
9
  self.cfg = load_config(path)
 
 
1
  from transformers import PretrainedConfig
2
+ from .cfg import load_config
3
+
4
 
5
  class X3DConfig(PretrainedConfig):
6
  model_type = "x3d"
7
 
8
+ def __init__(self, **kwargs):
9
  super().__init__(**kwargs)
10
+
11
+ path = kwargs.get("path", None)
12
+ gpu_num = kwargs.get("gpu_num", 0)
13
+
14
  self.cfg = load_config(path)
15
+ self.cfg.NUM_GPUS = gpu_num
modeling_x3d.py CHANGED
@@ -1,15 +1,23 @@
 
1
  from transformers import PreTrainedModel
2
- from x3d_model.configuration_x3d import X3DConfig
3
- from x3d_model.x3d import build_model
4
 
5
 
6
  class X3DModel(PreTrainedModel):
7
  config_class = X3DConfig
8
 
9
- def __init__(self, config):
10
  super().__init__(config)
11
  self.model = build_model(config.cfg)
12
 
 
 
 
 
 
 
 
13
  def forward(self, input_video):
14
  outputs = self.model(input_video)
15
  return outputs
 
1
+ import torch
2
  from transformers import PreTrainedModel
3
+ from .configuration_x3d import X3DConfig
4
+ from .x3d import build_model
5
 
6
 
7
  class X3DModel(PreTrainedModel):
8
  config_class = X3DConfig
9
 
10
+ def __init__(self, config, **kwargs):
11
  super().__init__(config)
12
  self.model = build_model(config.cfg)
13
 
14
+ checkpoint = kwargs.get("checkpoint", None)
15
+
16
+ if checkpoint:
17
+ checkpoint = torch.load(
18
+ checkpoint, weights_only=True, map_location=torch.device("cpu"))
19
+ self.model.load_state_dict(checkpoint["model_state"])
20
+
21
  def forward(self, input_video):
22
  outputs = self.model(input_video)
23
  return outputs