Update modeling_Llamoe.py
modeling_Llamoe.py (+19 -1)
````diff
@@ -1225,7 +1225,25 @@ class LlamoeForCausalLM(LlammoePreTrainedModel):
         output_router_logits: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
-
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        Returns:
+        Example:
+        ```python
+        >>> from transformers import AutoTokenizer, GemmoeForCausalLM
+        >>> model = GemmoeForCausalLM.from_pretrained("mistralai/Gemmoe-8x7B-v0.1")
+        >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Gemmoe-8x7B-v0.1")
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
 
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_router_logits = (
````
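A note on the docstring example added above: it still imports `GemmoeForCausalLM` and points at `mistralai/Gemmoe-8x7B-v0.1`, apparent leftovers from the Mixtral docstring it was adapted from, while the surrounding class in this file is `LlamoeForCausalLM`. Below is a minimal sketch of how the example would read once aligned with this file's class, assuming `modeling_Llamoe.py` is importable from the working directory; the checkpoint id `your-org/Llamoe-8x7B` is a hypothetical placeholder, not a real repository.

```python
# Hedged sketch, not part of the commit: the docstring example realigned
# with this file's LlamoeForCausalLM class.
from transformers import AutoTokenizer

# LlamoeForCausalLM is defined in this file, not in the transformers package.
from modeling_Llamoe import LlamoeForCausalLM

# "your-org/Llamoe-8x7B" is a hypothetical placeholder checkpoint id.
model = LlamoeForCausalLM.from_pretrained("your-org/Llamoe-8x7B")
tokenizer = AutoTokenizer.from_pretrained("your-org/Llamoe-8x7B")

prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt")

# Generate a completion of at most 30 tokens and decode it.
generate_ids = model.generate(inputs.input_ids, max_length=30)
print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])
```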