ga89tiy committed
Commit dc94d87
1 Parent(s): 6edd88e

load model from hf

LLAVA_Biovil/llava/model/language_model/llava_llama.py CHANGED
@@ -25,7 +25,7 @@
 
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
-from LLAVA_Biovil.llava.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
+from LLAVA_Biovil.llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
 
 
 class LlavaConfig(LlamaConfig):
LLAVA_Biovil/llava/model/llava_arch.py CHANGED
@@ -18,8 +18,8 @@
 from LLAVA_Biovil.biovil_t.model import ImageModel
 from LLAVA_Biovil.biovil_t.pretrained import _download_biovil_t_image_model_weights
 from LLAVA_Biovil.biovil_t.types import ImageEncoderType
-from LLAVA_Biovil.llava.multimodal_encoder.builder import build_vision_tower
-from LLAVA_Biovil.llava.multimodal_projector.builder import build_vision_projector, build_image_pooler
+from LLAVA_Biovil.llava.model.multimodal_encoder.builder import build_vision_tower
+from LLAVA_Biovil.llava.model.multimodal_projector.builder import build_vision_projector, build_image_pooler
 
 from LLAVA_Biovil.llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
 
simple_test.py ADDED
@@ -0,0 +1,99 @@
+from pathlib import Path
+
+from skimage import io as io_img
+import io
+
+import requests
+import torch
+from PIL import Image
+import numpy as np
+from huggingface_hub import snapshot_download
+
+from LLAVA_Biovil.llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, remap_to_uint8
+from LLAVA_Biovil.llava.model.builder import load_pretrained_model
+from LLAVA_Biovil.llava.conversation import SeparatorStyle, conv_vicuna_v1
+
+from LLAVA_Biovil.llava.constants import IMAGE_TOKEN_INDEX
+from utils import create_chest_xray_transform_for_inference
+
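+# Fetch the RaDialog checkpoint from the Hugging Face Hub and load it as a
+# task/LoRA fine-tune on top of the liuhaotian/llava-v1.5-7b base weights.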
+def load_model_from_huggingface(repo_id, model_filename):
+    # Download model files
+    model_path = snapshot_download(repo_id=repo_id, revision="main")
+    model_path = Path(model_path) / model_filename
+
+    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_base='liuhaotian/llava-v1.5-7b',
+                                                                           model_name="llava-v1.5-7b-task-lora_radialog_instruct_llava_biovil_unfrozen_2e-5_5epochs_v5_checkpoint-21000", load_8bit=False, load_4bit=False)
+
+    return tokenizer, model, image_processor, context_len
+
+if __name__ == '__main__':
+    # config = None
+    # model_path = "/home/guests/chantal_pellegrini/RaDialog_LLaVA/LLAVA/checkpoints/llava-v1.5-7b-task-lora_radialog_instruct_llava_biovil_unfrozen_2e-5_5epochs_v5/checkpoint-21000" #TODO hardcoded in huggingface repo probably
+    # model_name = get_model_name_from_path(model_path)
+    tokenizer, model, image_processor, context_len = load_model_from_huggingface(repo_id="Chantal/RaDialog-interactive-radiology-report-generation", model_filename="model")
+    model.config.tokenizer_padding_side = "left"
+
+    findings = "edema, pleural effusion"  # TODO should these come from chexpert classifier? Or not needed for this demo/test?
+
+    conv = conv_vicuna_v1.copy()
+    REPORT_GEN_PROMPT = f"<image>. Predicted Findings: {findings}. You are to act as a radiologist and write the finding section of a chest x-ray radiology report for this X-ray image and the given predicted findings. Write in the style of a radiologist, write one fluent text without enumeration, be concise and don't provide explanations or reasons."
+    print("USER: ", REPORT_GEN_PROMPT)
+    conv.append_message("USER", REPORT_GEN_PROMPT)
+    conv.append_message("ASSISTANT", None)
+    text_input = conv.get_prompt()
+
+    # get the image
+    vis_transforms_biovil = create_chest_xray_transform_for_inference(512, center_crop_size=448)
+    sample_img_path = "https://openi.nlm.nih.gov/imgs/512/10/10/CXR10_IM-0002-2001.png?keywords=Calcified%20Granuloma"  # TODO find good image
+
+    response = requests.get(sample_img_path)
+    image = Image.open(io.BytesIO(response.content))
+    image = remap_to_uint8(np.array(image))
+    image = Image.fromarray(image).convert("L")
+    image_tensor = vis_transforms_biovil(image).unsqueeze(0)
+
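+    # Move the image to the model device and tokenize the prompt, expanding the
+    # <image> placeholder into the IMAGE_TOKEN_INDEX expected by the model.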
+    image_tensor = image_tensor.to(model.device, dtype=torch.bfloat16)
+    input_ids = tokenizer_image_token(text_input, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
+
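+    # Stop generation when the conversation separator string appears in the output.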
+    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)
+
+    # generate a report
+    with torch.inference_mode():
+        output_ids = model.generate(
+            input_ids,
+            images=image_tensor,
+            do_sample=False,
+            use_cache=True,
+            max_new_tokens=300,
+            stopping_criteria=[stopping_criteria],
+            pad_token_id=tokenizer.pad_token_id
+        )
+
+    pred = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip().replace("</s>", "")
+    print("ASSISTANT: ", pred)
+
+    # add prediction to conversation
+    conv.messages.pop()
+    conv.append_message("ASSISTANT", pred)
+    conv.append_message("USER", "Translate this report to easy language for a patient to understand.")
+    conv.append_message("ASSISTANT", None)
+    text_input = conv.get_prompt()
+    print("USER: ", "Translate this report to easy language for a patient to understand.")
+
+    # generate easy language report
+    input_ids = tokenizer_image_token(text_input, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
+    with torch.inference_mode():
+        output_ids = model.generate(
+            input_ids,
+            images=image_tensor,
+            do_sample=False,
+            use_cache=True,
+            max_new_tokens=300,
+            stopping_criteria=[stopping_criteria],
+            pad_token_id=tokenizer.pad_token_id
+        )
+
+    pred = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip().replace("</s>", "")
+    print("ASSISTANT: ", pred)
+