---
library_name: transformers
datasets:
- erfanzar/MoD-Prompts
- erfanzar/GPT-4-Prompts
language:
- en
pipeline_tag: text-generation
---
# Raven Fine-Tuned Gemma-2B
Raven is a fine-tuned version of google/gemma-2b that uses the same prompting style as gemma-2b-it. It was trained on a TPU VM v4-64 using [EasyDeL](https://github.com/erfanzar/EasyDeL).
Both the fine-tuning and serving code are available. Using the JAX-EasyDeL Gemma implementation is recommended, since the HF-Gemma implementation is incorrect.
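
Since Raven follows the gemma-2b-it prompting style, it may help to see what such a prompt looks like as a plain string. The sketch below is a minimal illustration, not EasyDeL's `GemmaPrompter`: `format_gemma_chat` is a hypothetical helper, and folding the system message into the first user turn is an assumption (Gemma's chat format has no dedicated system role). The tokenizer is expected to add the BOS token itself.

```python
from typing import List, Optional

START, END = "<start_of_turn>", "<end_of_turn>"


def format_gemma_chat(
    history: List[List[str]],
    prompt: str,
    system: Optional[str] = None,
) -> str:
    """Hypothetical helper: build a Gemma-style chat prompt string."""
    # Past [user, model] pairs, followed by the new user prompt with no reply yet.
    pairs = [list(pair) for pair in history] + [[prompt, None]]
    if system:
        # Assumption: no system role exists, so prepend the system text
        # to the first user message of the conversation.
        pairs[0][0] = f"{system}\n\n{pairs[0][0]}"
    out = []
    for user_msg, model_msg in pairs:
        out.append(f"{START}user\n{user_msg}{END}\n")
        out.append(f"{START}model\n")
        if model_msg is not None:
            out.append(f"{model_msg}{END}\n")
    # The string ends with an open model turn for the model to complete.
    return "".join(out)


print(format_gemma_chat([["Hi", "Hello!"]], "What is JAX?"))
```

The `eos_token_id=107` used in the serving script below appears to correspond to `<end_of_turn>` in the Gemma tokenizer, so generation stops when the model closes its turn.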
### Serving and Using Raven
```python
from EasyDel import JAXServer, JAXServerConfig, EasyServe
from fjformer import get_dtype
from EasyDel.serve.prompters import GemmaPrompter, Llama2Prompter, OpenChatPrompter, ChatMLPrompter
from EasyDel.serve.prompters.base_prompter import BasePrompter
from jax import numpy as jnp, lax
import jax
from typing import List, Union, Optional
max_sequence_length = 8192
max_compile_tokens = 256
max_new_tokens_ratio = 25
dtype = "fp16"
prompter_type = "gemma"
sharding_axis_dims = (1, 1, 1, -1)
pretrained_model_name_or_path = "erfanzar/Raven-v0.1"
attn_mechanism = "normal"
scan_mlp_chunk_size = max_compile_tokens
use_scan_mlp = True
scan_ring_attention = True
block_k = 128
block_q = 128
use_sharded_kv_caching = False
server_config = JAXServerConfig(
    max_sequence_length=max_sequence_length,
    max_compile_tokens=max_compile_tokens,
    max_new_tokens=max_compile_tokens * max_new_tokens_ratio,
    dtype=dtype,
    pre_compile=False,
    eos_token_id=107
)
prompters = {
    "gemma": GemmaPrompter(),
    "llama": Llama2Prompter(),
    "openchat": OpenChatPrompter(),
    "chatml": ChatMLPrompter()
}
prompter: BasePrompter = prompters[prompter_type]


class JAXServerC(JAXServer):
    @staticmethod
    def format_chat(history: List[List[str]], prompt: str, system: Union[str, None]) -> str:
        return prompter.format_message(
            history=history,
            prompt=prompt,
            system_message=system,
            prefix=None
        )

    @staticmethod
    def format_instruct(system: str, instruction: str) -> str:
        return prompter.format_message(
            prefix=None,
            system_message=system,
            prompt=instruction,
            history=[]
        )


server = JAXServerC.from_torch_pretrained(
    server_config=server_config,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    device=jax.devices('cpu')[0],
    dtype=get_dtype(dtype=dtype),
    param_dtype=get_dtype(dtype=dtype),
    precision=jax.lax.Precision("fastest"),
    sharding_axis_dims=sharding_axis_dims,
    sharding_axis_names=("dp", "fsdp", "tp", "sp"),
    input_shape=(1, server_config.max_sequence_length),
    model_config_kwargs=dict(
        fully_sharded_data_parallel=True,
        attn_mechanism=attn_mechanism,
        scan_mlp_chunk_size=max_compile_tokens,
        use_scan_mlp=use_scan_mlp,
        scan_ring_attention=scan_ring_attention,
        block_k=block_k,
        block_q=block_q,
        use_sharded_kv_caching=use_sharded_kv_caching
    )
)

history = []
while True:
    user_prompt = input("> ")
    model_prompt = server.format_chat(
        history,
        user_prompt,
        "You are an AI assistant be respect-full and explain detailed questions step by step."
    )
    response = ""  # stays empty if the sampler yields nothing
    past_response_length = 0
    # Each step yields the full response so far; print only the new suffix.
    for response, used_tokens in server.sample(
            model_prompt,
            greedy=False
    ):
        print(response[past_response_length:], end="", flush=True)
        past_response_length = len(response)
    print()  # end the streamed line before the next prompt
    history.append([user_prompt, response])
```
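
Two of the configuration values in the script are derived rather than set directly: the generation budget (`max_compile_tokens * max_new_tokens_ratio`) and the device mesh implied by `sharding_axis_dims = (1, 1, 1, -1)`. The sketch below works through both with plain Python. `resolve_mesh_shape` is a hypothetical helper, not an EasyDeL function; it assumes a single `-1` is filled with the remaining device count, in the same way `reshape` treats `-1`. The device count of 8 is an arbitrary example.

```python
from math import prod
from typing import Tuple

max_sequence_length = 8192
max_compile_tokens = 256
max_new_tokens_ratio = 25

# Generation budget as computed in the serving script.
max_new_tokens = max_compile_tokens * max_new_tokens_ratio
print(max_new_tokens)  # 6400, which is below max_sequence_length


def resolve_mesh_shape(axis_dims: Tuple[int, ...], device_count: int) -> Tuple[int, ...]:
    """Hypothetical helper: replace a single -1 with the leftover device count."""
    if axis_dims.count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    known = prod(d for d in axis_dims if d != -1)
    if device_count % known != 0:
        raise ValueError("device count is not divisible by the fixed axes")
    if -1 not in axis_dims and known != device_count:
        raise ValueError("axis sizes do not multiply to the device count")
    return tuple(device_count // known if d == -1 else d for d in axis_dims)


axis_names = ("dp", "fsdp", "tp", "sp")
shape = resolve_mesh_shape((1, 1, 1, -1), device_count=8)  # 8 is an example value
print(dict(zip(axis_names, shape)))  # all devices land on the "sp" axis
```

Under this reading, `(1, 1, 1, -1)` places every available device on the last axis (`"sp"`), so the same script can run unchanged on hosts with different device counts.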
Gradio UI is also available via `server.gradio_inference().launch()`.
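
The chat loop in the script prints `response[past_response_length:]` on every step, which suggests that `server.sample` yields the cumulative response so far rather than individual tokens. The sketch below isolates that pattern so it can be tried without a model. `fake_sample` is a hypothetical stand-in for `server.sample`, and the cumulative-output behavior is an assumption inferred from the slicing in the script.

```python
from typing import Iterable, Iterator, Tuple


def fake_sample(text: str, step: int = 4) -> Iterator[Tuple[str, int]]:
    """Hypothetical stand-in for server.sample: yields (response_so_far, used_tokens)."""
    for end in range(step, len(text) + step, step):
        chunk = text[:end]
        # The second value is a placeholder count, not a real token count.
        yield chunk, len(chunk.split())


def stream_deltas(samples: Iterable[Tuple[str, int]]) -> Iterator[str]:
    """Turn cumulative responses into the newly generated suffix at each step."""
    seen = 0
    for response, _used_tokens in samples:
        yield response[seen:]
        seen = len(response)


for delta in stream_deltas(fake_sample("hello world")):
    print(delta, end="", flush=True)
print()
```

Keeping the slicing in a small generator like `stream_deltas` makes the same logic reusable for a terminal loop or a UI callback.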