|
|
|
|
|
|
|
|
|
@@ -50,6 +50,9 @@ class LlamaConfig(PretrainedConfig): |
|
Number of hidden layers in the Transformer encoder. |
|
num_attention_heads (`int`, *optional*, defaults to 32): |
|
Number of attention heads for each attention layer in the Transformer encoder. |
|
+ num_key_value_heads (`int`, *optional*, defaults to 32):

+ This is the number of key/value heads that should be used to implement Grouped Query Attention (GQA). If

+ `num_key_value_heads=num_attention_heads`, the model uses standard multi-head attention; if

+ `num_key_value_heads=1`, it is equivalent to multi-query attention. When converting a multi-head checkpoint

+ to a GQA checkpoint, each group's key and value head should be constructed by mean-pooling all the original

+ heads within that group.
|
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): |
|
The non-linear activation function (function or string) in the decoder. |
|
max_position_embeddings (`int`, *optional*, defaults to 2048): |
|
@@ -97,6 +100,7 @@ class LlamaConfig(PretrainedConfig): |
|
intermediate_size=11008, |
|
num_hidden_layers=32, |
|
num_attention_heads=32, |
|
+ num_key_value_heads=32, |
|
hidden_act="silu", |
|
max_position_embeddings=2048, |
|
initializer_range=0.02, |
|
@@ -115,6 +119,7 @@ class LlamaConfig(PretrainedConfig): |
|
self.intermediate_size = intermediate_size |
|
self.num_hidden_layers = num_hidden_layers |
|
self.num_attention_heads = num_attention_heads |
|
+ self.num_key_value_heads = num_key_value_heads |
|
self.hidden_act = hidden_act |
|
self.initializer_range = initializer_range |
|
self.rms_norm_eps = rms_norm_eps |
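A minimal sketch (not part of this diff) of the mean-pooling conversion described in the new docstring; the head count, head grouping (consecutive heads per group), and shapes are illustrative assumptions:

```python
import torch

# Hypothetical shapes: 32 query heads of dim 128, grouped into 8 key/value heads (4 heads per group).
num_heads, num_key_value_heads, head_dim, hidden_size = 32, 8, 128, 4096
n_rep = num_heads // num_key_value_heads

# Original multi-head k_proj weight: (num_heads * head_dim, hidden_size)
wk = torch.randn(num_heads * head_dim, hidden_size)

# Mean-pool the heads inside each group to build one key head per group.
wk_gqa = wk.view(num_key_value_heads, n_rep, head_dim, hidden_size).mean(dim=1)
wk_gqa = wk_gqa.reshape(num_key_value_heads * head_dim, hidden_size)  # (1024, 4096)
```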
|
|
|
|
|
|
|
|
|
@@ -59,17 +59,22 @@ INTERMEDIATE_SIZE_MAP = { |
|
"13B": 13824, |
|
"30B": 17920, |
|
"65B": 22016, |
|
+ "70B": 28672, |
|
} |
|
NUM_SHARDS = { |
|
"7B": 1, |
|
+ "7Bf": 1, |
|
"13B": 2, |
|
+ "13Bf": 2, |
|
"30B": 4, |
|
"65B": 8, |
|
+ "70B": 8, |
|
+ "70Bf": 8, |
|
} |
|
|
|
|
|
-def compute_intermediate_size(n): |
|
- return int(math.ceil(n * 8 / 3) + 255) // 256 * 256 |
|
+def compute_intermediate_size(n, ffn_dim_multiplier=1): |
|
+ return int((math.ceil(n * 8 / 3) + 255) * ffn_dim_multiplier // 256 * 256) |
|
|
|
|
|
def read_json(path): |
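As a quick sanity check of the updated helper (an illustrative snippet; the value `ffn_dim_multiplier=1.3` is assumed to come from the 70B checkpoint's `params.json`), the new formula reproduces the `INTERMEDIATE_SIZE_MAP` entries:

```python
import math

def compute_intermediate_size(n, ffn_dim_multiplier=1):
    return int((math.ceil(n * 8 / 3) + 255) * ffn_dim_multiplier // 256 * 256)

# dim=8192 for the 70B model; ffn_dim_multiplier=1.3 is an assumption taken from its params.json.
assert compute_intermediate_size(8192, 1.3) == 28672
# Existing sizes are unchanged when the multiplier defaults to 1, e.g. the 7B model:
assert compute_intermediate_size(4096) == 11008
```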
|
@@ -82,7 +87,7 @@ def write_json(text, path): |
|
json.dump(text, f) |
|
|
|
|
|
-def write_model(model_path, input_base_path, model_size): |
|
+def write_model(model_path, input_base_path, model_size, safe_serialization=True): |
|
os.makedirs(model_path, exist_ok=True) |
|
tmp_model_path = os.path.join(model_path, "tmp") |
|
os.makedirs(tmp_model_path, exist_ok=True) |
|
@@ -97,9 +102,17 @@ def write_model(model_path, input_base_path, model_size): |
|
base = 10000.0 |
|
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) |
|
|
|
+ if "n_kv_heads" in params: |
|
+ num_key_value_heads = params["n_kv_heads"] # for GQA / MQA |
|
+ num_local_key_value_heads = n_heads_per_shard // num_key_value_heads |
|
+ key_value_dim = dim//num_key_value_heads |
|
+ else: # compatibility with other checkpoints |
|
+ num_key_value_heads = n_heads |
|
+ num_local_key_value_heads = n_heads_per_shard |
|
+ key_value_dim = dim |
|
# permute for sliced rotary |
|
- def permute(w): |
|
- return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim) |
|
+def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
|
+ return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) |
|
|
|
print(f"Fetching all parameters from the checkpoint at {input_base_path}.") |
|
# Load weights |
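For reference, a small shape walk-through under the 70B settings (dim=8192, 64 heads, 8 kv heads, 8 shards, values assumed from the published checkpoint); this is an illustrative sketch, not part of the conversion script:

```python
import torch

dim, n_heads, n_kv_heads, num_shards = 8192, 64, 8, 8
dims_per_head = dim // n_heads                         # 128
num_local_key_value_heads = n_kv_heads // num_shards   # 1 kv head stored per shard
key_value_dim = dims_per_head * n_kv_heads             # 1024

def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

# Each shard holds a (key_value_dim / num_shards, dim) slice of wk.
shards = [torch.randn(key_value_dim // num_shards, dim) for _ in range(num_shards)]
wk = torch.cat(
    [s.view(num_local_key_value_heads, dims_per_head, dim) for s in shards], dim=0
).reshape(key_value_dim, dim)
k_proj = permute(wk, n_kv_heads, key_value_dim, dim)
assert k_proj.shape == (key_value_dim, dim)  # (1024, 8192)
```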
|
@@ -160,19 +173,19 @@ def write_model(model_path, input_base_path, model_size): |
|
state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute( |
|
torch.cat( |
|
[ |
|
- loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(n_heads_per_shard, dims_per_head, dim) |
|
+ loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(num_local_key_value_heads, dims_per_head, dim) |
|
for i in range(num_shards) |
|
], |
|
dim=0, |
|
- ).reshape(dim, dim) |
|
+ ).reshape(key_value_dim, dim), num_key_value_heads, key_value_dim, dim
|
) |
|
state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( |
|
[ |
|
- loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(n_heads_per_shard, dims_per_head, dim) |
|
+ loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(num_local_key_value_heads, dims_per_head, dim) |
|
for i in range(num_shards) |
|
], |
|
dim=0, |
|
- ).reshape(dim, dim) |
|
+ ).reshape(key_value_dim, dim) |
|
|
|
state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( |
|
[loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1 |
|
@@ -218,13 +231,14 @@ def write_model(model_path, input_base_path, model_size): |
|
# Write configs |
|
index_dict["metadata"] = {"total_size": param_count * 2} |
|
write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json")) |
|
- |
|
+ ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1 |
|
config = LlamaConfig( |
|
hidden_size=dim, |
|
- intermediate_size=compute_intermediate_size(dim), |
|
+ intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier), |
|
num_attention_heads=params["n_heads"], |
|
num_hidden_layers=params["n_layers"], |
|
rms_norm_eps=params["norm_eps"], |
|
+ num_key_value_heads=num_key_value_heads,
|
) |
|
config.save_pretrained(tmp_model_path) |
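With the published 70B `params.json` (assumed values: dim=8192, n_layers=80, n_heads=64, n_kv_heads=8, norm_eps=1e-5, ffn_dim_multiplier=1.3), the config written here would look roughly like this sketch:

```python
from transformers import LlamaConfig

config = LlamaConfig(
    hidden_size=8192,
    intermediate_size=28672,
    num_attention_heads=64,
    num_hidden_layers=80,
    rms_norm_eps=1e-5,
    num_key_value_heads=8,
)
```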
|
|
|
@@ -239,7 +253,7 @@ def write_model(model_path, input_base_path, model_size): |
|
del model.config._name_or_path |
|
|
|
print("Saving in the Transformers format.") |
|
- model.save_pretrained(model_path) |
|
+ model.save_pretrained(model_path, safe_serialization=safe_serialization) |
|
shutil.rmtree(tmp_model_path) |
|
|
|
|
|
@@ -259,18 +273,20 @@ def main(): |
|
) |
|
parser.add_argument( |
|
"--model_size", |
|
- choices=["7B", "13B", "30B", "65B", "tokenizer_only"], |
|
+ choices=["7B", "7Bf","13B", "13Bf", "30B", "65B", "70B", "70Bf", "tokenizer_only"], |
|
) |
|
parser.add_argument( |
|
"--output_dir", |
|
help="Location to write HF model and tokenizer", |
|
) |
|
+ parser.add_argument("--safe_serialization",type=bool, help="Whether or not to save using `safetensors`.") |
|
args = parser.parse_args() |
|
if args.model_size != "tokenizer_only": |
|
write_model( |
|
model_path=args.output_dir, |
|
input_base_path=os.path.join(args.input_dir, args.model_size), |
|
model_size=args.model_size, |
|
+ safe_serialization=args.safe_serialization,
|
) |
|
spm_path = os.path.join(args.input_dir, "tokenizer.model") |
|
write_tokenizer(args.output_dir, spm_path) |
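Once converted, the output directory can be loaded directly; a usage sketch (paths are placeholders):

```python
from transformers import LlamaForCausalLM, LlamaTokenizer

model = LlamaForCausalLM.from_pretrained("/path/to/output_dir")      # placeholder path
tokenizer = LlamaTokenizer.from_pretrained("/path/to/output_dir")    # placeholder path

# For a GQA checkpoint such as 70B, the config records the reduced number of key/value heads.
print(model.config.num_key_value_heads)  # e.g. 8 for the 70B checkpoint
```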
|
|
|
|
|
|
|
|
|
@@ -85,7 +85,7 @@ class LlamaRMSNorm(nn.Module): |
|
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) |
|
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) |
|
|
|
- return (self.weight * hidden_states).to(input_dtype) |
|
+ return self.weight.to(input_dtype) * hidden_states |
|
|
|
|
|
class LlamaRotaryEmbedding(torch.nn.Module): |
|
@@ -204,6 +204,16 @@ class LlamaMLP(nn.Module): |
|
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) |
|
|
|
|
|
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: |
|
+ """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" |
|
+ bs, n_kv_heads, slen, head_dim = hidden_states.shape |
|
+ if n_rep == 1: |
|
+ return hidden_states |
|
+ hidden_states = hidden_states[:, :, None, :, :].expand(bs, n_kv_heads, n_rep, slen, head_dim) |
|
+ return hidden_states.reshape(bs, n_kv_heads * n_rep, slen, head_dim) |
|
+ |
|
+ |
|
+ |
|
class LlamaAttention(nn.Module): |
|
"""Multi-headed attention from 'Attention Is All You Need' paper""" |
|
|
|
@@ -213,6 +223,8 @@ class LlamaAttention(nn.Module): |
|
self.hidden_size = config.hidden_size |
|
self.num_heads = config.num_attention_heads |
|
self.head_dim = self.hidden_size // self.num_heads |
|
+ self.num_key_value_heads = config.num_key_value_heads |
|
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads |
|
self.max_position_embeddings = config.max_position_embeddings |
|
|
|
if (self.head_dim * self.num_heads) != self.hidden_size: |
|
@@ -221,8 +233,8 @@ class LlamaAttention(nn.Module): |
|
f" and `num_heads`: {self.num_heads})." |
|
) |
|
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) |
|
- self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) |
|
- self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) |
|
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) |
|
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) |
|
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) |
|
self._init_rope() |
|
|
|
@@ -243,9 +255,6 @@ class LlamaAttention(nn.Module): |
|
else: |
|
raise ValueError(f"Unknown RoPE scaling type {scaling_type}") |
|
|
|
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): |
|
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() |
|
- |
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
@@ -258,8 +267,8 @@ class LlamaAttention(nn.Module): |
|
bsz, q_len, _ = hidden_states.size() |
|
|
|
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
|
- key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
|
- value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
|
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
|
|
kv_seq_len = key_states.shape[-2] |
|
if past_key_value is not None: |
|
@@ -275,6 +284,9 @@ class LlamaAttention(nn.Module): |
|
|
|
past_key_value = (key_states, value_states) if use_cache else None |
|
|
|
+ # repeat k/v heads if n_kv_heads < n_heads |
|
+ key_states = repeat_kv(key_states, self.num_key_value_groups) # (bs, n_heads, seqlen, head_dim) |
|
+ value_states = repeat_kv(value_states, self.num_key_value_groups) # (bs, n_heads, seqlen, head_dim) |
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) |
|
|
|
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): |
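The expand/reshape in `repeat_kv` is equivalent to `torch.repeat_interleave` along the head dimension. A small self-contained check (shapes are illustrative, roughly matching the 70B setting of 8 kv heads repeated 8 times):

```python
import torch

bs, n_kv_heads, slen, head_dim, n_rep = 1, 8, 16, 128, 8
x = torch.randn(bs, n_kv_heads, slen, head_dim)

# Same expand/reshape trick as repeat_kv: insert a repeat axis, expand it, then fold it into the head axis.
expanded = x[:, :, None, :, :].expand(bs, n_kv_heads, n_rep, slen, head_dim)
repeated = expanded.reshape(bs, n_kv_heads * n_rep, slen, head_dim)

assert torch.equal(repeated, torch.repeat_interleave(x, n_rep, dim=1))
assert repeated.shape == (bs, 64, slen, head_dim)
```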
|
|
|
|
|
|
|
|
|
@@ -21,13 +21,15 @@ |
|
"""Tokenization classes for LLaMA.""" |
|
import os |
|
from shutil import copyfile |
|
-from typing import Any, Dict, List, Optional, Tuple |
|
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple |
|
|
|
import sentencepiece as spm |
|
|
|
from ...tokenization_utils import AddedToken, PreTrainedTokenizer |
|
from ...utils import logging |
|
|
|
+if TYPE_CHECKING: |
|
+ from transformers.pipelines.conversational import Conversation |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
@@ -46,6 +48,7 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { |
|
} |
|
SPIECE_UNDERLINE = "▁" |
|
|
|
+B_INST, E_INST = "[INST]", "[/INST]" |
|
|
|
class LlamaTokenizer(PreTrainedTokenizer): |
|
""" |
|
@@ -314,3 +317,34 @@ class LlamaTokenizer(PreTrainedTokenizer): |
|
output += [1] * len(bos_token_id + token_ids_1 + eos_token_id) |
|
|
|
return output |
|
+ |
|
+ def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]: |
|
+ """Builds the input ids for a conversation. |
|
+ This is the format used in the provided examples.
|
+ ``` |
|
+ <bos>[INST] Prompt [/INST] Answer <eos> |
|
+ <bos>[INST] Prompt [/INST] |
|
+ ``` |
|
+ Args: |
|
+ conversation (`Conversation`): |
|
+ Conversation to build input ids for. |
|
+ Returns: |
|
+ `List[int]`: |
|
+ Input ids for the conversation. |
|
+ """ |
|
+ dialogue = list(conversation.iter_texts()) |
|
+ if not all([is_user for is_user, msg in dialogue[::2]]) or not all([not is_user for is_user, msg in dialogue[1::2]]): |
|
+ raise ValueError( |
|
+ "The model only supports 'user' and 'assistant' roles, starting with user and alternating (u/a/u/a/u...)" |
|
+ ) |
|
+ dialog_tokens: List[int] = sum( |
|
+ [ |
|
+ [self.bos_token_id] + self.encode(f"{B_INST} {(prompt[1]).strip()} {E_INST} {(answer[1]).strip()} ", add_special_tokens=False) + [self.eos_token_id]
|
+ for prompt, answer in zip(dialogue[::2], dialogue[1::2]) |
|
+ ], |
|
+ [], |
|
+ ) |
|
+ if not dialogue[-1][0]:

+ raise ValueError("Last message must be from the user")
|
+ dialog_tokens += [self.bos_token_id] + self.encode(f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False)
|
+ return dialog_tokens |
|
\ No newline at end of file |
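The token layout built by `_build_conversation_input_ids` can be sketched at the string level (an illustrative sketch; the real code works on token ids via `self.encode` plus the BOS/EOS ids, shown here as the `<s>`/`</s>` strings):

```python
B_INST, E_INST = "[INST]", "[/INST]"

history = [("Hi there", "Hello! How can I help?")]  # (user, assistant) pairs
new_user_message = "Tell me a joke"

text = "".join(
    f"<s>{B_INST} {prompt.strip()} {E_INST} {answer.strip()} </s>"
    for prompt, answer in history
)
text += f"<s>{B_INST} {new_user_message.strip()} {E_INST}"
print(text)
# <s>[INST] Hi there [/INST] Hello! How can I help? </s><s>[INST] Tell me a joke [/INST]
```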
|
|
|
|
|
|
|
|
|
@@ -33,6 +33,12 @@ else: |
|
logger = logging.get_logger(__name__) |
|
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"} |
|
|
|
+B_INST, E_INST = "[INST]", "[/INST]" |
|
+B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n" |
|
+DEFAULT_SYSTEM_PROMPT = """\ |
|
+You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. |
|
+ |
|
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""" |
|
|
|
class LlamaTokenizerFast(PreTrainedTokenizerFast): |
|
""" |
|
@@ -171,3 +177,43 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast): |
|
copyfile(self.vocab_file, out_vocab_file) |
|
|
|
return (out_vocab_file,) |
|
+ |
|
+ def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]: |
|
+ """Builds the input ids for a conversation. |
|
+ This is the format used in the provided examples. A system prompt can be added manually at the

+ beginning of the first user message; if none is present, the `DEFAULT_SYSTEM_PROMPT` is

+ prepended automatically.
|
+ ``` |
|
+ <bos>[INST] Prompt [/INST] Answer <eos> |
|
+ <bos>[INST] Prompt [/INST] |
|
+ ``` |
|
+ Args: |
|
+ conversation (`Conversation`): |
|
+ Conversation to build input ids for. |
|
+ Returns: |
|
+ `List[int]`: |
|
+ Input ids for the conversation. |
|
+ """ |
|
+ # Prepend the default system prompt to the first user message if none is present. This must happen

+ # before the conversation texts are read out below, otherwise the prompt would never be tokenized.

+ if len(conversation.past_user_inputs) > 0:

+ if B_SYS not in conversation.past_user_inputs[0]:

+ conversation.past_user_inputs[0] = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]

+ elif conversation.new_user_input and B_SYS not in conversation.new_user_input:

+ conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input

+

+ dialogue = list(conversation.iter_texts())

+ if not all([is_user for is_user, msg in dialogue[::2]]) or not all([not is_user for is_user, msg in dialogue[1::2]]):

+ raise ValueError(

+ "The model only supports 'user' and 'assistant' roles, starting with user and alternating (u/a/u/a/u...)"

+ )

+

+ dialog_tokens: List[int] = []

+ dialog_tokens += sum(
|
+ [ |
|
+ [self.bos_token_id] + self.encode(f"{B_INST} {(prompt[1]).strip()} {E_INST} {(answer[1]).strip()} ", add_special_tokens=False) + [self.eos_token_id]
|
+ for prompt, answer in zip(dialogue[::2], dialogue[1::2]) |
|
+ ], |
|
+ [], |
|
+ ) |
|
+ if not dialogue[-1][0]:

+ raise ValueError("Last message must be from the user")
|
+ dialog_tokens += [self.bos_token_id] + self.encode(f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False)
|
+ return dialog_tokens |
|
\ No newline at end of file |
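Similarly, the default system prompt handling amounts to wrapping the first user message in the `<<SYS>>` block when none is present; a string-level sketch using the constants defined above (the system prompt here is a shortened stand-in for the real default):

```python
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
B_INST, E_INST = "[INST]", "[/INST]"
DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."  # shortened stand-in

first_user_message = "What is the capital of France?"
if B_SYS not in first_user_message:
    first_user_message = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + first_user_message

prompt = f"<s>{B_INST} {first_user_message.strip()} {E_INST}"
# <s>[INST] <<SYS>>
# You are a helpful assistant.
# <</SYS>>
#
# What is the capital of France? [/INST]
```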
|
|
|
|
|
|
|
|
|
@@ -365,3 +365,65 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi |
|
|
|
# The output should be different for long inputs |
|
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) |
|
+class LlamaIntegrationTest(unittest.TestCase): |
|
+ |
|
+ def test_model_7b_logits(self): |
|
+ input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] |
|
+ model = LlamaForCausalLM.from_pretrained("/raid/arthur/llama-7b", device_map = "auto") |
|
+ out = model(torch.tensor(input_ids)) |
|
+ # Expected mean on dim = -1 |
|
+ EXPECTED_MEAN = torch.tensor([[-6.6550, -4.1227, -4.9859, -3.2406, 0.8262, -3.0033, 1.2964, -3.3699]])

+ torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
|
+ # slicing logits[0, 0, 0:30] |
|
+ EXPECTED_SLICE = torch.tensor([-12.8281, -7.4453, -0.4639, -8.0625, -7.2500, -8.0000, -6.4883, |
|
+ -7.7695, -7.8438, -7.0312, -6.2188, -7.1328, -1.8496, 1.9961, |
|
+ -8.6250, -6.7227, -12.8281, -6.9492, -7.0742, -7.7852, -7.5820, |
|
+ -7.9062, -6.9375, -7.9805, -8.3438, -8.1562, -8.0469, -7.6250, |
|
+ -7.7422, -7.3398])

+ torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-5, rtol=1e-5)
|
+ |
|
+ def test_model_7bf_logits(self): |
|
+ input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] |
|
+ model = LlamaForCausalLM.from_pretrained("/raid/arthur/llama-7bf", device_map = "auto") |
|
+ out = model(torch.tensor(input_ids)) |
|
+ # Expected mean on dim = -1 |
|
+ EXPECTED_MEAN = torch.tensor([[0.0719, -4.1667, -3.4864, -4.6226, 1.7280, -3.6511, 1.0122, -0.1268]])

+ torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
|
+ # slicing logits[0, 0, 0:30] |
|
+ EXPECTED_SLICE = torch.tensor([ 0.1038, -0.2218, 0.3132, -0.8379, 1.5576, 2.6680, 1.5811, 2.5078, |
|
+ 1.2129, 0.3484, 1.6602, 0.8213, 0.6294, 0.4907, 1.2588, 0.3982, |
|
+ 0.1039, 1.9062, 0.6665, 1.0439, 0.5850, 1.8535, 2.3828, 1.8096, |
|
+ 1.0498, 1.4629, 1.3506, 2.8574, 1.3447, 1.9971])

+ torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-5, rtol=1e-5)
|
+ |
|
+ def test_model_13b_logits(self): |
|
+ input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] |
|
+ model = LlamaForCausalLM.from_pretrained("/raid/arthur/llama-13b", device_map = "auto") |
|
+ out = model(torch.tensor(input_ids)) |
|
+ # Expected mean on dim = -1 |
|
+ EXPECTED_MEAN = torch.tensor([[-2.0622, -1.2794, -1.1638, -0.9788, -1.4603, -1.0238, -1.7893, -1.4411]], dtype=torch.float32)

+ torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
|
+ # slicing logits[0, 0, 0:30] |
|
+ EXPECTED_SLICE = torch.tensor([-8.1406, -8.0547, 2.7461, -1.2344, -0.1448, -1.8262, -1.0020, -1.8154, |
|
+ -1.6895, -1.8516, -2.3574, -0.9277, 3.7598, 6.5742, -1.2998, -0.1177, |
|
+ -8.1406, -2.9688, -2.9199, -3.1699, -3.5254, -2.3555, -2.7988, -3.4141, |
|
+ -2.8262, -4.5195, -3.3379, -3.3164, -2.7832, -3.0273])

+ torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-5, rtol=1e-5)
|
+
|
+ def test_model_13bf_logits(self): |
|
+ input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] |
|
+ model = LlamaForCausalLM.from_pretrained("/raid/arthur/llama-13bf", device_map = "auto") |
|
+ out = model(torch.tensor(input_ids)) |
|
+ # Expected mean on dim = -1 |
|
+ EXPECTED_MEAN = torch.tensor([[-0.8562, -1.8520, -0.7551, -0.4162, -1.5161, -1.2038, -2.4823, -2.3254]])

+ torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
|
+ # slicing logits[0, 0, 0:30] |
|
+ EXPECTED_SLICE = torch.tensor([-2.2227, 4.8828, 0.9023, -0.4578, -0.7871, -0.1033, -0.6221, -0.5786, |
|
+ -0.7803, -1.0674, -1.2920, -0.1570, 0.8008, 2.0723, -0.9497, 0.2771, |
|
+ -2.2227, -0.7612, -1.4346, -1.2061, -1.6426, -0.3000, -0.7139, -1.1934, |
|
+ -1.8691, -1.6973, -1.5947, -1.2705, -0.3523, -0.5513])

+ torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-5, rtol=1e-5)
|
+ |
|
+ def test_model_70b_logits(self):

+ input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]

+ # Checkpoint path follows the local paths used by the other integration tests in this file.

+ model = LlamaForCausalLM.from_pretrained("/raid/arthur/llama-70b", device_map="auto")

+ out = model(torch.tensor([input_ids])).logits

+ # Expected mean on dim = -1

+ EXPECTED_MEAN = torch.tensor([[-4.2327, -3.3360, -4.6665, -4.7631, -1.8180, -3.4170, -1.4211, -3.1810]], dtype=torch.float32)

+ torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)

+ # slicing logits[0, 0, 0:30]

+ EXPECTED_SLICE = torch.tensor([-9.4922, -3.9551, 1.7998, -5.6758, -5.1055, -5.8984, -4.8320, -6.8086,

+ -6.5391, -5.6172, -5.5820, -5.5352, 1.7881, 3.6289, -6.5117, -3.4785,

+ -9.5000, -6.0352, -6.8125, -6.0195, -6.6836, -5.4727, -6.2812, -6.0391,

+ -7.3398, -7.4297, -7.4844, -6.5820, -5.8789, -5.5312], dtype=torch.float32)

+ torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-5, rtol=1e-5)
|
+ |
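The expected tensors above can be regenerated from a converted checkpoint with a snippet along these lines (the path is a placeholder; this only prints the statistics the tests compare against):

```python
import torch
from transformers import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained("/path/to/converted-llama", device_map="auto")  # placeholder path
input_ids = torch.tensor([[1, 306, 4658, 278, 6593, 310, 2834, 338]])

with torch.no_grad():
    logits = model(input_ids).logits

print(logits.mean(-1))    # -> EXPECTED_MEAN, shape (1, 8)
print(logits[0, 0, :30])  # -> EXPECTED_SLICE, first 30 logits of the first position
```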
|
|