---
language:
- en
tags:
- causal-lm
license: cc-by-nc-sa-4.0
datasets:
- dmayhem93/ChatCombined
- tatsu-lab/alpaca
- nomic-ai/gpt4all_prompt_generations
- Dahoas/full-hh-rlhf
- jeffwan/sharegpt_vicuna
- HuggingFaceH4/databricks_dolly_15k
---
|
|
|
# StableLM-Tuned-Alpha 16-bit |
|
|
|
## Model Description |
|
|
|
This is a 16-bit (float16) version of `StableLM-Tuned-Alpha`, converted to halve memory usage and speed up inference. No other changes were made to the weights. Original model: https://huggingface.co/stabilityai/stablelm-tuned-alpha-3b
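
Because only the dtype changed, you can sanity-check the smaller footprint directly after loading. A minimal sketch (the exact figure varies slightly with your `transformers` version and which buffers it counts):

```python
import torch
from transformers import AutoModelForCausalLM

# Loading with torch_dtype=torch.float16 keeps the checkpoint in half precision.
model = AutoModelForCausalLM.from_pretrained(
    "vvsotnikov/stablelm-tuned-alpha-3b-16bit", torch_dtype=torch.float16
)

# get_memory_footprint() sums parameter and buffer sizes; for a ~3B-parameter
# model this should come out to roughly half of the float32 figure.
print(f"Model weights: {model.get_memory_footprint() / 1e9:.1f} GB")
```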
|
|
|
## Usage |
|
|
|
Get started chatting with `StableLM-Tuned-Alpha 16-bit` by using the following code snippet: |
|
|
|
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList

tokenizer = AutoTokenizer.from_pretrained("vvsotnikov/stablelm-tuned-alpha-3b-16bit")
model = AutoModelForCausalLM.from_pretrained(
    "vvsotnikov/stablelm-tuned-alpha-3b-16bit", torch_dtype=torch.float16
)
model.cuda()

class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the model emits a chat special token or EOS."""
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # 50278 = <|USER|>, 50279 = <|ASSISTANT|>, 50277 = <|SYSTEM|>,
        # 1 = <|padding|>, 0 = <|endoftext|>
        stop_ids = [50278, 50279, 50277, 1, 0]
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False

system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""

prompt = f"{system_prompt}<|USER|>What's your mood today?<|ASSISTANT|>"

inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
tokens = model.generate(
    **inputs,
    max_new_tokens=64,
    temperature=0.7,
    do_sample=True,
    stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
)
print(tokenizer.decode(tokens[0], skip_special_tokens=True))
```
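
If you prefer not to hardcode the stop IDs, they can be resolved from the tokenizer at runtime. This sketch assumes the chat special tokens above are registered in the tokenizer vocabulary, which is the case for this checkpoint:

```python
# Resolve the stop IDs by name instead of hardcoding them.
stop_tokens = ["<|USER|>", "<|ASSISTANT|>", "<|SYSTEM|>"]
stop_ids = tokenizer.convert_tokens_to_ids(stop_tokens) + [tokenizer.eos_token_id]
print(stop_ids)  # expected to include 50278, 50279, 50277 and 0 (<|endoftext|>)
```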