zhong-al
commited on
Commit
•
ba918ab
1
Parent(s):
6ba4e1b
Revert changes
Browse files- configuration_x3d.py +8 -2
- modeling_x3d.py +11 -3
configuration_x3d.py
CHANGED
@@ -1,9 +1,15 @@
|
|
1 |
from transformers import PretrainedConfig
|
2 |
-
from
|
|
|
3 |
|
4 |
class X3DConfig(PretrainedConfig):
|
5 |
model_type = "x3d"
|
6 |
|
7 |
-
def __init__(self,
|
8 |
super().__init__(**kwargs)
|
|
|
|
|
|
|
|
|
9 |
self.cfg = load_config(path)
|
|
|
|
1 |
from transformers import PretrainedConfig
|
2 |
+
from .cfg import load_config
|
3 |
+
|
4 |
|
5 |
class X3DConfig(PretrainedConfig):
|
6 |
model_type = "x3d"
|
7 |
|
8 |
+
def __init__(self, **kwargs):
|
9 |
super().__init__(**kwargs)
|
10 |
+
|
11 |
+
path = kwargs.get("path", None)
|
12 |
+
gpu_num = kwargs.get("gpu_num", 0)
|
13 |
+
|
14 |
self.cfg = load_config(path)
|
15 |
+
self.cfg.NUM_GPUS = gpu_num
|
modeling_x3d.py
CHANGED
@@ -1,15 +1,23 @@
|
|
|
|
1 |
from transformers import PreTrainedModel
|
2 |
-
from
|
3 |
-
from
|
4 |
|
5 |
|
6 |
class X3DModel(PreTrainedModel):
|
7 |
config_class = X3DConfig
|
8 |
|
9 |
-
def __init__(self, config):
|
10 |
super().__init__(config)
|
11 |
self.model = build_model(config.cfg)
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
def forward(self, input_video):
|
14 |
outputs = self.model(input_video)
|
15 |
return outputs
|
|
|
1 |
+
import torch
|
2 |
from transformers import PreTrainedModel
|
3 |
+
from .configuration_x3d import X3DConfig
|
4 |
+
from .x3d import build_model
|
5 |
|
6 |
|
7 |
class X3DModel(PreTrainedModel):
|
8 |
config_class = X3DConfig
|
9 |
|
10 |
+
def __init__(self, config, **kwargs):
|
11 |
super().__init__(config)
|
12 |
self.model = build_model(config.cfg)
|
13 |
|
14 |
+
checkpoint = kwargs.get("checkpoint", None)
|
15 |
+
|
16 |
+
if checkpoint:
|
17 |
+
checkpoint = torch.load(
|
18 |
+
checkpoint, weights_only=True, map_location=torch.device("cpu"))
|
19 |
+
self.model.load_state_dict(checkpoint["model_state"])
|
20 |
+
|
21 |
def forward(self, input_video):
|
22 |
outputs = self.model(input_video)
|
23 |
return outputs
|