metadata
library_name: transformers
datasets:
  - erfanzar/MoD-Prompts
  - erfanzar/GPT-4-Prompts
language:
  - en
pipeline_tag: text-generation

Raven Fine-Tuned Gemma-2B

Raven is a fine-tuned version of google/gemma-2 that uses the same prompting style as gemma-2b-it. It was trained on a TPU VM v4-64 using EasyDeL.
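Because Raven reuses the gemma-2b-it prompting style, a conversation is serialized with Gemma's `<start_of_turn>` / `<end_of_turn>` markers, and generation should stop at `<end_of_turn>` (token id 107, the `eos_token_id` used in the serving config). The sketch below is a hypothetical `format_gemma_chat` helper that only illustrates that layout; it is not EasyDeL's `GemmaPrompter`. Folding the system message into the first user turn is an assumption, since Gemma has no dedicated system role, and the BOS token is assumed to be added by the tokenizer.

```python
from typing import List, Optional


def format_gemma_chat(
    history: List[List[str]],
    prompt: str,
    system: Optional[str] = None,
) -> str:
    """Serialize a chat into the gemma-2b-it turn format (illustrative sketch)."""
    turns = []
    for index, (user_msg, model_msg) in enumerate(history):
        # Assumption: the system message is prepended to the first user turn.
        if index == 0 and system:
            user_msg = f"{system}\n\n{user_msg}"
        turns.append(f"<start_of_turn>user\n{user_msg}<end_of_turn>\n")
        turns.append(f"<start_of_turn>model\n{model_msg}<end_of_turn>\n")
    # With no history, the system message goes in front of the current prompt.
    if not history and system:
        prompt = f"{system}\n\n{prompt}"
    turns.append(f"<start_of_turn>user\n{prompt}<end_of_turn>\n")
    # Leave the model turn open so generation continues from here;
    # the model is expected to close it with <end_of_turn> (id 107).
    turns.append("<start_of_turn>model\n")
    return "".join(turns)


print(format_gemma_chat([["Hi!", "Hello! How can I help?"]], "Explain JAX.", "Be concise."))
```

Each completed exchange in `history` is closed with `<end_of_turn>`, while the final model turn is left open for the model to fill.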

Both the fine-tuning and the serving code are available. Using the JAX EasyDeL Gemma implementation is recommended, since the HF Gemma implementation is incorrect.

Serving and Using Raven

from typing import List, Union

import jax
from EasyDel import JAXServer, JAXServerConfig
from EasyDel.serve.prompters import GemmaPrompter, Llama2Prompter, OpenChatPrompter, ChatMLPrompter
from EasyDel.serve.prompters.base_prompter import BasePrompter
from fjformer import get_dtype

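max_sequence_length = 8192  # maximum context length (prompt + generated tokens)
max_compile_tokens = 256  # tokens produced per compiled sampling step
max_new_tokens_ratio = 25  # max_new_tokens = max_compile_tokens * max_new_tokens_ratio
dtype = "fp16"
prompter_type = "gemma"  # Raven uses the gemma-2b-it prompt format
sharding_axis_dims = (1, 1, 1, -1)  # (dp, fsdp, tp, sp); -1 takes all remaining devices
pretrained_model_name_or_path = "erfanzar/Raven-v0.1"
attn_mechanism = "normal"
scan_mlp_chunk_size = max_compile_tokens
use_scan_mlp = True
scan_ring_attention = True
block_k = 128  # attention key block size
block_q = 128  # attention query block size
use_sharded_kv_caching = False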

server_config = JAXServerConfig(
    max_sequence_length=max_sequence_length,
    max_compile_tokens=max_compile_tokens,
    max_new_tokens=max_compile_tokens * max_new_tokens_ratio,  # 256 * 25 = 6400 tokens
    dtype=dtype,
    pre_compile=False,
    eos_token_id=107  # Gemma's <end_of_turn> token, so generation stops at the end of the model turn
)

prompters = {
    "gemma": GemmaPrompter(),
    "llama": Llama2Prompter(),
    "openchat": OpenChatPrompter(),
    "chatml": ChatMLPrompter()
}

prompter: BasePrompter = prompters[prompter_type]

# Route chat and instruction formatting through the selected prompter.
class JAXServerC(JAXServer):
    @staticmethod
    def format_chat(history: List[List[str]], prompt: str, system: Union[str, None]) -> str:
        return prompter.format_message(
            history=history,
            prompt=prompt,
            system_message=system,
            prefix=None
        )

    @staticmethod
    def format_instruct(system: str, instruction: str) -> str:
        return prompter.format_message(
            prefix=None,
            system_message=system,
            prompt=instruction,
            history=[]
        )

server = JAXServerC.from_torch_pretrained(
    server_config=server_config,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    device=jax.devices('cpu')[0],
    dtype=get_dtype(dtype=dtype),
    param_dtype=get_dtype(dtype=dtype),
    precision=jax.lax.Precision("fastest"),
    sharding_axis_dims=sharding_axis_dims,
    sharding_axis_names=("dp", "fsdp", "tp", "sp"),
    input_shape=(1, server_config.max_sequence_length),
    model_config_kwargs=dict(
        fully_sharded_data_parallel=True,
        attn_mechanism=attn_mechanism,
        scan_mlp_chunk_size=scan_mlp_chunk_size,
        use_scan_mlp=use_scan_mlp,
        scan_ring_attention=scan_ring_attention,
        block_k=block_k,
        block_q=block_q,
        use_sharded_kv_caching=use_sharded_kv_caching
    )
)

history = []
system_prompt = "You are an AI assistant. Be respectful and explain detailed questions step by step."
while True:
    user_prompt = input("> ")
    model_prompt = server.format_chat(history, user_prompt, system_prompt)

    response = ""
    past_response_length = 0
    # server.sample yields the full response generated so far; print only the new part.
    for response, used_tokens in server.sample(
        model_prompt,
        greedy=False
    ):
        print(response[past_response_length:], end="", flush=True)
        past_response_length = len(response)
    print()

    history.append([user_prompt, response])
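The chat loop prints only newly generated text because `server.sample` yields the full response decoded so far on every step, as the `response[past_response_length:]` slicing implies. The sketch below reproduces that pattern with a hypothetical `fake_sample` generator standing in for the server, so it runs without EasyDeL or a model; the second tuple element it yields is just a step counter, not a real token count.

```python
from typing import Iterator, List, Tuple


def fake_sample(chunks: List[str]) -> Iterator[Tuple[str, int]]:
    """Hypothetical stand-in for server.sample: yields (text_so_far, step)."""
    text = ""
    for step, chunk in enumerate(chunks, start=1):
        text += chunk
        yield text, step


def stream_deltas(generator: Iterator[Tuple[str, int]]) -> List[str]:
    """Collect the new suffix of each cumulative response, as the chat loop prints it."""
    deltas = []
    past_response_length = 0
    for response, _used_tokens in generator:
        deltas.append(response[past_response_length:])
        past_response_length = len(response)
    return deltas


print(stream_deltas(fake_sample(["Hello", ", ", "world", "!"])))
```

Tracking only the previous length keeps each step O(new text) to print, and the last yielded `response` is the complete answer that gets appended to `history`.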

A Gradio UI is also available via server.gradio_inference().launch().
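In the serving script, `sharding_axis_dims = (1, 1, 1, -1)` together with `sharding_axis_names=("dp", "fsdp", "tp", "sp")` describes a four-axis device mesh in which `-1` absorbs all remaining devices. Assuming EasyDeL resolves `-1` the way a NumPy `reshape` does, the hypothetical helpers below show how the mesh shape works out for a given device count; on a real host the resolved shape would be used to reshape `jax.devices()` into a `jax.sharding.Mesh`.

```python
from math import prod
from typing import Dict, Sequence, Tuple


def resolve_axis_dims(dims: Sequence[int], num_devices: int) -> Tuple[int, ...]:
    """Replace a single -1 entry so that the product of dims equals num_devices."""
    if list(dims).count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    known = prod(d for d in dims if d != -1)
    if num_devices % known != 0:
        raise ValueError(f"{num_devices} devices cannot be split over {tuple(dims)}")
    resolved = tuple(num_devices // known if d == -1 else d for d in dims)
    # Without a -1 entry, the explicit dims must already match the device count.
    if prod(resolved) != num_devices:
        raise ValueError(f"{tuple(dims)} does not use exactly {num_devices} devices")
    return resolved


def describe_mesh(dims: Sequence[int], names: Sequence[str], num_devices: int) -> Dict[str, int]:
    """Map each mesh axis name to its resolved size."""
    return dict(zip(names, resolve_axis_dims(dims, num_devices)))


print(describe_mesh((1, 1, 1, -1), ("dp", "fsdp", "tp", "sp"), num_devices=8))
```

With the README's setting, every device lands on the `sp` axis, while `dp`, `fsdp`, and `tp` each stay at size 1.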