---
library_name: transformers
datasets:
- erfanzar/MoD-Prompts
- erfanzar/GPT-4-Prompts
language:
- en
pipeline_tag: text-generation
---

# Raven Fine-Tuned Gemma-2B

Raven is a fine-tuned version of google/gemma-2b that uses the same prompting style as gemma-2b-it. It was trained on a TPU VM v4-64 with [EasyDeL](https://github.com/erfanzar/EasyDeL).

Both the fine-tuning and serving code are available. It is recommended to use the JAX/EasyDeL Gemma implementation, since the Hugging Face Gemma implementation is incorrect.

### Serving and Using Raven

```python
from EasyDel import JAXServer, JAXServerConfig, EasyServe
from fjformer import get_dtype
from EasyDel.serve.prompters import GemmaPrompter, Llama2Prompter, OpenChatPrompter, ChatMLPrompter
from EasyDel.serve.prompters.base_prompter import BasePrompter
from jax import numpy as jnp, lax
import jax
from typing import List, Union, Optional

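# Generation and sharding settings. `max_new_tokens` below is derived as
# `max_compile_tokens * max_new_tokens_ratio`; the attention/sharding knobs are
# passed to the model config when the server is created.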
max_sequence_length = 8192
max_compile_tokens = 256
max_new_tokens_ratio = 25
dtype = "fp16"
prompter_type = "gemma"
sharding_axis_dims = (1, 1, 1, -1)
pretrained_model_name_or_path = "erfanzar/Raven-v0.1"
attn_mechanism = "normal"
scan_mlp_chunk_size = max_compile_tokens
use_scan_mlp = True
scan_ring_attention = True
block_k = 128
block_q = 128
use_sharded_kv_caching = False

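# Server-level generation settings; eos_token_id 107 is Gemma's <end_of_turn> token.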
server_config = JAXServerConfig(
    max_sequence_length=max_sequence_length,
    max_compile_tokens=max_compile_tokens,
    max_new_tokens=max_compile_tokens * max_new_tokens_ratio,
    dtype=dtype,
    pre_compile=False,
    eos_token_id=107
)

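# Available prompt formatters; Raven follows the Gemma chat format (prompter_type = "gemma").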
prompters = {
    "gemma": GemmaPrompter(),
    "llama": Llama2Prompter(),
    "openchat": OpenChatPrompter(),
    "chatml": ChatMLPrompter()
}

prompter: BasePrompter = prompters[prompter_type]

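# Subclass JAXServer so chat and instruct requests are formatted with the selected prompter.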
class JAXServerC(JAXServer):
    @staticmethod
    def format_chat(history: List[List[str]], prompt: str, system: Union[str, None]) -> str:
        return prompter.format_message(
            history=history,
            prompt=prompt,
            system_message=system,
            prefix=None
        )

    @staticmethod
    def format_instruct(system: str, instruction: str) -> str:
        return prompter.format_message(
            prefix=None,
            system_message=system,
            prompt=instruction,
            history=[]
        )

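# Load the PyTorch checkpoint from the Hugging Face Hub and convert it to JAX parameters,
# sharded over the ("dp", "fsdp", "tp", "sp") mesh.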
server = JAXServerC.from_torch_pretrained(
    server_config=server_config,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    device=jax.devices('cpu')[0],
    dtype=get_dtype(dtype=dtype),
    param_dtype=get_dtype(dtype=dtype),
    precision=jax.lax.Precision("fastest"),
    sharding_axis_dims=sharding_axis_dims,
    sharding_axis_names=("dp", "fsdp", "tp", "sp"),
    input_shape=(1, server_config.max_sequence_length),
    model_config_kwargs=dict(
        fully_sharded_data_parallel=True,
        attn_mechanism=attn_mechanism,
        scan_mlp_chunk_size=max_compile_tokens,
        use_scan_mlp=use_scan_mlp,
        scan_ring_attention=scan_ring_attention,
        block_k=block_k,
        block_q=block_q,
        use_sharded_kv_caching=use_sharded_kv_caching
    )
)

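# Interactive chat loop: format the prompt, stream the sampled response, and keep the history.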
history = []
while True:
    user_prompt = input("> ")
    model_prompt = server.format_chat(
        history,
        user_prompt,
        "You are an AI assistant. Be respectful and explain questions in detail, step by step."
    )

    past_response_length = 0

    for response, used_tokens in server.sample(
        model_prompt,
        greedy=False
    ):
        # `response` is the full text generated so far; print only the new part.
        print(response[past_response_length:], end="", flush=True)
        past_response_length = len(response)

    print()
    history.append([user_prompt, response])
```

A Gradio UI is also available via `server.gradio_inference().launch()`.
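For example, after building `server` as above (the keyword arguments shown for `launch()` are standard Gradio options assumed for a typical deployment, not EasyDeL-specific settings):

```python
# Build the Gradio app from the running JAXServer and expose it.
# server_name="0.0.0.0" and share=True are ordinary Gradio launch options; adjust as needed.
app = server.gradio_inference()
app.launch(server_name="0.0.0.0", share=True)
```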