Text Generation
Transformers
Safetensors
dbrx
custom_code
text-generation-inference
Undi95 Walmart-the-bag commited on
Commit
ea57a4d
1 Parent(s): c67dc44

Update README.md (#6)

Browse files

- Update README.md (22eb92898d63ebd3fb3f5a4764aa0e7798f21b41)


Co-authored-by: wbag <[email protected]>

Files changed (1) hide show
  1. README.md +4 -4
README.md CHANGED
@@ -86,8 +86,8 @@ export HF_HUB_ENABLE_HF_TRANSFER=1
86
  from transformers import AutoTokenizer, AutoModelForCausalLM
87
  import torch
88
 
89
- tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-base", trust_remote_code=True)
90
- model = AutoModelForCausalLM.from_pretrained("databricks/dbrx-base", device_map="cpu", torch_dtype=torch.bfloat16, trust_remote_code=True)
91
 
92
  input_text = "Databricks was founded in "
93
  input_ids = tokenizer(input_text, return_tensors="pt")
@@ -101,8 +101,8 @@ print(tokenizer.decode(outputs[0]))
101
  from transformers import AutoTokenizer, AutoModelForCausalLM
102
  import torch
103
 
104
- tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-base", trust_remote_code=True)
105
- model = AutoModelForCausalLM.from_pretrained("databricks/dbrx-base", device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)
106
 
107
  input_text = "Databricks was founded in "
108
  input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
 
86
  from transformers import AutoTokenizer, AutoModelForCausalLM
87
  import torch
88
 
89
+ tokenizer = AutoTokenizer.from_pretrained("Undi95/dbrx-base", trust_remote_code=True)
90
+ model = AutoModelForCausalLM.from_pretrained("Undi95/dbrx-base", device_map="cpu", torch_dtype=torch.bfloat16, trust_remote_code=True)
91
 
92
  input_text = "Databricks was founded in "
93
  input_ids = tokenizer(input_text, return_tensors="pt")
 
101
  from transformers import AutoTokenizer, AutoModelForCausalLM
102
  import torch
103
 
104
+ tokenizer = AutoTokenizer.from_pretrained("Undi95/dbrx-base", trust_remote_code=True)
105
+ model = AutoModelForCausalLM.from_pretrained("Undi95/dbrx-base", device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)
106
 
107
  input_text = "Databricks was founded in "
108
  input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")