qingsonglv committed
Commit 9f15841
1 Parent(s): 9ca995d

Update visual.py

Files changed (1)
  1. visual.py +136 -136
visual.py CHANGED
@@ -1,136 +1,136 @@
import torch
from torch import nn
from argparse import Namespace
import xformers.ops as xops
from transformers.activations import ACT2FN


class PatchEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.proj = nn.Conv2d(config.in_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size)
        self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size))
        self.position_embedding = nn.Embedding(config.num_positions, config.hidden_size)

    def forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L, D)":
        x = self.proj(images)
        x = x.flatten(2).transpose(1, 2)
        cls_token = self.cls_embedding.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x += self.position_embedding.weight.unsqueeze(0)
        return x


class Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_heads
        head_dim = config.hidden_size // config.num_heads
        self.scale = head_dim ** -0.5
        self.query_key_value = nn.Linear(config.hidden_size, config.hidden_size * 3)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.output_dropout = torch.nn.Dropout(config.dropout_prob)

    def forward(self, x: "tensor(B, L, D)") -> "tensor(B, L, D)":
        B, L, _ = x.shape
        qkv = self.query_key_value(x)
        qkv = qkv.reshape(B, L, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4)  # 3, B, L, H, D
        q, k, v = qkv[0], qkv[1], qkv[2]

        out = xops.memory_efficient_attention(
            q, k, v, scale=self.scale,
        )
        output = self.dense(out.view(B, L, -1))
        output = self.output_dropout(output)
        return output

    def attention(self, q, k, v):
        attn_weights = torch.matmul(q * self.scale, k.transpose(-2, -1))
        attn_weights = attn_weights.softmax(dim=-1)
        output = torch.matmul(attn_weights, v)
        return output


class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.activation_fn(x)
        x = self.fc2(x)
        return x


class TransformerLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = Attention(config)
        self.mlp = MLP(config)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        attention_input = hidden_states
        attention_output = self.input_layernorm(self.attention(attention_input))
        hidden_states = attention_input + attention_output
        mlp_input = hidden_states
        mlp_output = self.post_attention_layernorm(self.mlp(mlp_input))
        output = mlp_input + mlp_output
        return output


class Transformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layers = nn.ModuleList([TransformerLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states):
        for layer_module in self.layers:
            hidden_states = layer_module(hidden_states)
        return hidden_states


class GLU(nn.Module):
    def __init__(self, config, in_features):
        super().__init__()
        self.linear_proj = nn.Linear(in_features, config.hidden_size, bias=False)
        self.norm1 = nn.LayerNorm(config.hidden_size)
        self.act1 = nn.GELU()
        self.act2 = nn.functional.silu
        self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def forward(self, x):
        x = self.linear_proj(x)
        x = self.act1(self.norm1(x))
        x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x)
        x = self.dense_4h_to_h(x)
        return x


class EVA2CLIPModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        vision_config = Namespace(**config.vision_config)
        self.patch_embedding = PatchEmbedding(vision_config)
        self.transformer = Transformer(vision_config)
        self.linear_proj = GLU(config, in_features=vision_config.hidden_size)
        self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.pos_embed = nn.Parameter(torch.zeros((vision_config.image_size // vision_config.patch_size) ** 2, vision_config.hidden_size))

    def forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L, D)":
        x = self.patch_embedding(images)
        x = self.transformer(x)
        x = x[:, 1:]
-        x = self.linear_proj(x + self.pos_embed.unsqueeze(0))
+        x = self.linear_proj(x + self.pos_embed.to(x.device).unsqueeze(0))
        boi = self.boi.expand(x.shape[0], -1, -1)
        eoi = self.eoi.expand(x.shape[0], -1, -1)
        x = torch.cat((boi, x, eoi), dim=1)
        return x
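
The only functional change in this commit is in EVA2CLIPModel.forward: the positional-embedding parameter is moved onto the device of the vision features with .to(x.device) before it is added to the patch features. A plausible motivation (an assumption; the commit message says only "Update visual.py") is to avoid a device-mismatch error when the parameter and the activations end up on different devices, for example under offloaded or sharded loading. The sketch below is illustrative only and not part of the repository; the tensor shapes are made up.

    import torch
    from torch import nn

    # Hypothetical repro of the failure mode the change guards against.
    pos_embed = nn.Parameter(torch.zeros(4, 8))               # parameter left on CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(1, 4, 8, device=device)                   # activations, possibly on GPU

    try:
        y = x + pos_embed.unsqueeze(0)                         # raises a device-mismatch RuntimeError when device != "cpu"
    except RuntimeError as err:
        print("without .to(x.device):", err)

    y = x + pos_embed.to(x.device).unsqueeze(0)                # the committed form works on either device
    print(y.shape)                                             # torch.Size([1, 4, 8])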