
Model Card for segmed/MedMistral-7B

MedMistral is a language model designed to answer medical questions. It is a QLoRA fine-tune of Mistral-7B-v0.1, trained on the MedMCQA dataset.

  • Developed by: Segmed
  • Model type: LLM

Model Sources

  • Base model: https://huggingface.co/mistralai/Mistral-7B-v0.1

Uses

Prompts were generated with the following helper functions:

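import random

def generate_question(data_point):
    # Question text followed by the four MedMCQA options, indexed [0]-[3].
    question = f""" {data_point['question']} [0] {data_point['opa']} [1] {data_point['opb']} [2] {data_point['opc']} [3] {data_point['opd']}
    """
    return question

def generate_prompt(data_point):
    # The template wording must match what the model saw during training.
    full_prompt = f"""You are a helpful medical assistant. Your task is to answer the following question one of the options and explain why.\n### Question: {generate_question(data_point)}\n### Answer: """

    # Append the correct option index (cop) and the explanation (exp) when both are present.
    if data_point["cop"] != "" and data_point["exp"] != "":
        full_prompt = full_prompt + f"""{data_point['cop']}\n### Explanation: {data_point['exp']}"""

    return full_prompt

# Build the prompt for a random example from the evaluation split (eval_dataset).
generate_prompt(eval_dataset[random.randrange(len(eval_dataset))])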

The resulting prompt looks like this:

You are a helpful medical assistant. Your task is to answer the following question one of the options and explain why.
### Question:  "Genital elephantiasis" is seen in: [0] Rickettsia [1] Chancroid [2] Lymphogranuloma venereum [3] Syphilis
    
### Answer: 2
### Explanation: Ans. is. 'c' i. e., Lymphogranuloma venereum
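At inference time the same template can be reused, stopping right after "### Answer: " so that the model supplies the option index and explanation. The sketch below is a hypothetical helper (`build_inference_prompt` is not part of this repository) that reproduces the template text verbatim, including its exact spacing, and omits the answer; calling `generate_prompt` on a record whose `cop` or `exp` is empty produces the same string.

```python
from typing import Sequence


def build_inference_prompt(question: str, options: Sequence[str]) -> str:
    """Format a MedMCQA-style question with the training template, leaving the
    answer blank so the model completes the option index and explanation."""
    if len(options) != 4:
        raise ValueError("expected exactly four options (opa..opd)")
    # Mirrors generate_question(): a leading space, "[i] option" pairs, and the
    # trailing newline plus four spaces produced by the triple-quoted f-string.
    choices = " ".join(f"[{i}] {opt}" for i, opt in enumerate(options))
    question_block = f" {question} {choices}\n    "
    # Template text copied verbatim from generate_prompt(), without the answer.
    return (
        "You are a helpful medical assistant. Your task is to answer the "
        "following question one of the options and explain why.\n"
        f"### Question: {question_block}\n### Answer: "
    )
```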

Direct Use

The following code loads the tokenizer and the 4-bit quantized model, and defines a helper function that generates new tokens:

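import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

tokenizer_id = "segmed/MedMistral-7B"
tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_id,
    padding_side="left",  # left padding keeps prompts right-aligned for batched generation
    # model_max_length=4096,
    add_eos_token=True,  # append </s> to every encoded sequence
)

# Mistral's tokenizer has no padding token, so EOS is reused for padding.
tokenizer.pad_token = tokenizer.eos_token

def tokenize(prompt):
    # Truncate/pad to 512 tokens; for causal-LM training the labels copy the inputs.
    result = tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    result["labels"] = result["input_ids"].copy()
    return result

def tokenize_prompt(data_point):
    # Build the prompt with generate_prompt() and tokenize it.
    full_prompt = generate_prompt(data_point)
    return tokenize(full_prompt)

model_id = "segmed/MedMistral-7B"

# 4-bit NF4 quantization with double quantization; computation runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,  # 4-bit config defined above
    device_map="auto",
    trust_remote_code=True,
)

def generate_tokens(prompt, max_new_tokens=32):
    # Sample a completion and return the decoded prompt plus generated text.
    model_input = tokenizer(prompt, return_tensors="pt").to("cuda")
    model.eval()
    with torch.no_grad():
        output_ids = model.generate(
            **model_input,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_k=0,  # disable top-k filtering
            num_return_sequences=1,
            temperature=0.1,  # low temperature keeps sampling close to greedy
            eos_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(output_ids[0])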
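`generate_tokens` returns the decoded prompt together with the completion (special tokens are not skipped), so the predicted option and explanation still have to be extracted from the text. A minimal sketch, assuming the model follows the training template; `parse_completion` is a hypothetical helper, and with the default `max_new_tokens=32` the explanation may be cut short.

```python
import re
from typing import Optional, Tuple

# Allow an optional </s> right after the marker, which may appear in the decoded
# text if the tokenizer appended EOS to the prompt (add_eos_token=True).
_ANSWER_RE = re.compile(r"### Answer:\s*(?:</s>\s*)?([0-3])")
# Explanation runs until EOS, the next "###" section, or the end of the text.
_EXPLANATION_RE = re.compile(r"### Explanation:\s*(.*?)(?:</s>|\n###|$)", re.DOTALL)


def parse_completion(decoded: str) -> Tuple[Optional[int], Optional[str]]:
    """Extract the predicted option index (0-3) and the explanation.

    Either field is None when the model did not produce it.
    """
    answer = _ANSWER_RE.search(decoded)
    explanation = _EXPLANATION_RE.search(decoded)
    return (
        int(answer.group(1)) if answer else None,
        explanation.group(1).strip() if explanation else None,
    )
```

For example: `answer, explanation = parse_completion(generate_tokens(prompt))`.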

Bias, Risks, and Limitations

This model is not intended to provide medical advice. We recommend further testing before relying on its outputs.

Training Details

Training took 24 hours and completed 2 epochs on a single A100 GPU.
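QLoRA keeps the frozen base weights in 4-bit form, which is a large part of why a 7B-parameter model can be fine-tuned on a single GPU. The back-of-the-envelope sketch below estimates weight memory only; the parameter count of roughly 7.2 billion is an assumed round figure for Mistral-7B-v0.1, and activations, optimizer state, LoRA adapters, and quantization constants are ignored.

```python
def weight_memory_gb(n_params: float, bits_per_param: float) -> float:
    """Approximate memory for model weights alone, in GB (1 GB = 1e9 bytes).

    Ignores activations, optimizer state, LoRA adapters, quantization
    constants, and any layers kept in higher precision.
    """
    return n_params * bits_per_param / 8 / 1e9


N_PARAMS = 7.2e9  # assumption: approximate parameter count of Mistral-7B-v0.1

for label, bits in [("bfloat16", 16), ("4-bit NF4", 4)]:
    print(f"{label:>10}: ~{weight_memory_gb(N_PARAMS, bits):.1f} GB")
```

Under these assumptions the weights take about 14.4 GB in bfloat16 versus about 3.6 GB in 4-bit form.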

Training Data

The MedMCQA dataset has about 193k samples; roughly 1k were held out for testing and another 1k for evaluation. It was selected for its large number of samples and its multiple-choice format.
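The card does not state how the ~1k test and ~1k eval samples were selected. One common approach, sketched here under that assumption, is a single seeded shuffle of sample indices; `holdout_split` and its default seed are hypothetical.

```python
import random
from typing import List, Tuple


def holdout_split(
    n_samples: int, n_test: int = 1000, n_eval: int = 1000, seed: int = 42
) -> Tuple[List[int], List[int], List[int]]:
    """Shuffle sample indices once and carve off disjoint test and eval sets.

    Returns (train_idx, eval_idx, test_idx).
    """
    if n_test + n_eval > n_samples:
        raise ValueError("held-out sets are larger than the dataset")
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible
    test_idx = indices[:n_test]
    eval_idx = indices[n_test:n_test + n_eval]
    train_idx = indices[n_test + n_eval:]
    return train_idx, eval_idx, test_idx
```

The returned index lists could then be passed to `datasets.Dataset.select` to build each split.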

Evaluation

This dataset was chosen because it is simple to evaluate: although the prompt asks for an explanation, accuracy can be computed from the multiple-choice answer alone. Evaluation results are coming soon.
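Since the predicted option follows "### Answer:", accuracy can be scored without reading the explanation. A minimal sketch, assuming each completion follows the training template and the gold labels are `cop` indices (0-3, as in the prompt example above); the helper names are hypothetical, and outputs with no parseable answer count as incorrect.

```python
import re
from typing import Optional, Sequence

# Option index after "### Answer:", tolerating an optional </s> marker.
_ANSWER_RE = re.compile(r"### Answer:\s*(?:</s>\s*)?([0-3])")


def predicted_option(completion: str) -> Optional[int]:
    """Return the option index (0-3) written after '### Answer:', or None."""
    match = _ANSWER_RE.search(completion)
    return int(match.group(1)) if match else None


def accuracy(completions: Sequence[str], gold: Sequence[int]) -> float:
    """Fraction of completions whose predicted option matches the gold index.

    Completions with no parseable answer count as wrong.
    """
    if len(completions) != len(gold):
        raise ValueError("completions and gold labels must have the same length")
    if not gold:
        return 0.0
    correct = sum(predicted_option(c) == g for c, g in zip(completions, gold))
    return correct / len(gold)
```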