GRMenon committed
Commit 480b549
1 Parent(s): bb462c0

Update README.md

Files changed (1)
  1. README.md +11 -9
README.md CHANGED
@@ -14,10 +14,10 @@ This model was trained using AutoTrain. For more information, please visit [Auto
 # Usage
 
 ```python
-
+import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-model_path = "PATH_TO_THIS_REPO"
+model_path = "GRMenon/Mental-Mistral-7B-Instruct-AutoTrain"
 
 tokenizer = AutoTokenizer.from_pretrained(model_path)
 model = AutoModelForCausalLM.from_pretrained(
@@ -26,15 +26,17 @@ model = AutoModelForCausalLM.from_pretrained(
     torch_dtype='auto'
 ).eval()
 
-# Prompt content: "hi"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Prompt content:
 messages = [
-    {"role": "user", "content": "hi"}
+    {"role": "user", "content": "Hey Connor! I have been feeling a bit down lately. I could really use some advice on how to feel better?"}
 ]
 
-input_ids = tokenizer.apply_chat_template(conversation=messages, tokenize=True, add_generation_prompt=True, return_tensors='pt')
-output_ids = model.generate(input_ids.to('cuda'))
-response = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
+input_ids = tokenizer.apply_chat_template(conversation=messages, tokenize=True, add_generation_prompt=True, return_tensors='pt').to(device)
+output_ids = model.generate(input_ids=input_ids, max_new_tokens=512, do_sample=True, pad_token_id=2)
+response = tokenizer.batch_decode(output_ids.detach().cpu().numpy(), skip_special_tokens = True)
 
-# Model response: "Hello! How can I assist you today?"
-print(response)
+# Model response:
+print(response[0])
 ```
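
For reference, the same inference can be done through the high-level `pipeline` API. The sketch below is not part of this commit; it assumes the repo id `GRMenon/Mental-Mistral-7B-Instruct-AutoTrain` from the updated snippet and the standard `transformers` text-generation pipeline.

```python
# Sketch only (not from the committed README): equivalent chat inference via
# the transformers text-generation pipeline, assuming the repo id above.
from transformers import AutoTokenizer, pipeline

model_path = "GRMenon/Mental-Mistral-7B-Instruct-AutoTrain"

tokenizer = AutoTokenizer.from_pretrained(model_path)
pipe = pipeline(
    "text-generation",
    model=model_path,
    tokenizer=tokenizer,
    torch_dtype="auto",   # mirror the dtype handling in the snippet above
    device_map="auto",    # place the model on GPU when one is available
)

messages = [
    {"role": "user", "content": "Hey Connor! I have been feeling a bit down lately."}
]
# Build the prompt with the model's own chat template, as in the README snippet.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# return_full_text=False keeps only the newly generated reply.
out = pipe(prompt, max_new_tokens=512, do_sample=True, return_full_text=False)
print(out[0]["generated_text"])
```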