SunderAli17 committed on
Commit e615809
Parent: abb3181

Create ToonMage/utils.py

Files changed (1)
  1. ToonMage/utils.py +76 -0
ToonMage/utils.py ADDED
@@ -0,0 +1,76 @@
+ import importlib
+ import os
+ import random
+
+ import cv2
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from transformers import PretrainedConfig
+
+
+ def seed_everything(seed):
+     os.environ["PL_GLOBAL_SEED"] = str(seed)
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+
+
+ def is_torch2_available():
+     return hasattr(F, "scaled_dot_product_attention")
+
+
+ def instantiate_from_config(config):
+     if "target" not in config:
+         if config == "__is_first_stage__" or config == "__is_unconditional__":
+             return None
+         raise KeyError("Expected key `target` to instantiate.")
+     return get_obj_from_str(config["target"])(**config.get("params", {}))
+
+
+ def get_obj_from_str(string, reload=False):
+     module, cls = string.rsplit(".", 1)
+     if reload:
+         module_imp = importlib.import_module(module)
+         importlib.reload(module_imp)
+     return getattr(importlib.import_module(module, package=None), cls)
+
+
+ def drop_seq_token(seq, drop_rate=0.5):
+     idx = torch.randperm(seq.size(1))
+     num_keep_tokens = int(len(idx) * (1 - drop_rate))
+     idx = idx[:num_keep_tokens]
+     seq = seq[:, idx]
+     return seq
+
+
+ def import_model_class_from_model_name_or_path(
+     pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+ ):
+     text_encoder_config = PretrainedConfig.from_pretrained(
+         pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+     )
+     model_class = text_encoder_config.architectures[0]
+
+     if model_class == "CLIPTextModel":
+         from transformers import CLIPTextModel
+
+         return CLIPTextModel
+     elif model_class == "CLIPTextModelWithProjection":  # noqa: RET505
+         from transformers import CLIPTextModelWithProjection
+
+         return CLIPTextModelWithProjection
+     else:
+         raise ValueError(f"{model_class} is not supported.")
+
+
+ def resize_numpy_image_long(image, resize_long_edge=768):
+     h, w = image.shape[:2]
+     if max(h, w) <= resize_long_edge:
+         return image
+     k = resize_long_edge / max(h, w)
+     h = int(h * k)
+     w = int(w * k)
+     image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4)
+     return image
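
A minimal usage sketch of these helpers, assuming the repository root is on PYTHONPATH so the file imports as `ToonMage.utils`; the `collections.OrderedDict` target in the config example is only a placeholder to illustrate the `{"target": ..., "params": ...}` convention, not a class this repo actually configures.

import numpy as np
import torch

from ToonMage.utils import (
    drop_seq_token,
    instantiate_from_config,
    resize_numpy_image_long,
    seed_everything,
)

# Seed Python, NumPy, and torch RNGs so the random token drop below is reproducible.
seed_everything(42)

# Randomly keep half the tokens along the sequence dimension of a (batch, seq, dim) tensor.
seq = torch.randn(2, 16, 8)
kept = drop_seq_token(seq, drop_rate=0.5)
print(kept.shape)  # torch.Size([2, 8, 8])

# Downscale an HWC uint8 image so its long edge is at most 768 px, preserving aspect ratio.
image = np.zeros((1024, 2048, 3), dtype=np.uint8)
small = resize_numpy_image_long(image, resize_long_edge=768)
print(small.shape)  # (384, 768, 3)

# Build an object from a config dict; get_obj_from_str resolves the dotted import path.
cfg = {"target": "collections.OrderedDict", "params": {}}
obj = instantiate_from_config(cfg)

Note that `drop_seq_token` samples a fresh permutation per call and drops the same token indices for every item in the batch, which is why seeding beforehand matters for reproducibility.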