---
license: apache-2.0
datasets:
- HuggingFaceM4/WebSight
---
This model is [CogAgent-chat-18B](https://huggingface.co/THUDM/CogAgent) fine-tuned on 160K WebSight examples using LoRA (rank 8) applied to the language decoder.
The checkpoint is stored in [SAT (SwissArmyTransformer)](https://github.com/THUDM/SwissArmyTransformer/) format.
Please refer to [our paper](https://arxiv.org/abs/2403.03163) and [our codebase](https://github.com/NoviScl/Design2Code/tree/main/CogVLM) for instructions on running inference.
Use of the model must comply with [the original model license](https://github.com/THUDM/CogVLM/blob/main/MODEL_LICENSE) and the original data license (CC-BY-4.0).
# Example Usage (based on SAT)
```python
import sys
sys.path.insert(1, '/path/to/CogVLM')
from sat.model import AutoModel
import argparse
from utils.models import CogAgentModel, CogVLMModel, FineTuneTestCogAgentModel
import torch
from sat.model.mixins import CachedAutoregressiveMixin
from sat.quantization.kernels import quantize
from sat.model import AutoModel
from utils.utils import chat, llama2_tokenizer, llama2_text_processor_inference, get_image_processor
from utils.models import CogAgentModel, CogVLMModel
from tqdm import tqdm
import os
import argparse
# ---------------------------------------------------------------------------
# Command-line options and fixed runtime settings.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--temperature', type=float, default=0.5)
parser.add_argument('--repetition_penalty', type=float, default=1.1)
args = parser.parse_args()
# Fixed settings that are not exposed as CLI flags; `args` is later passed
# to `chat(..., args=args)`.
args.bf16 = True
args.stream_chat = False
args.version = "chat"

# ---------------------------------------------------------------------------
# Input / output locations.
# ---------------------------------------------------------------------------
# You can download the test set from https://huggingface.co/datasets/SALT-NLP/Design2Code
test_data_dir = "/path/to/Design2Code"
predictions_dir = "/path/to/design2code_18b_v0_predictions"
# exist_ok=True tolerates an already-existing directory without hiding real
# failures (e.g. permission errors) the way a bare `except: pass` would.
os.makedirs(predictions_dir, exist_ok=True)

# Every screenshot (*.png) in the test set is a generation target.
filename_list = [filename for filename in os.listdir(test_data_dir) if filename.endswith(".png")]

# ---------------------------------------------------------------------------
# Model loading (SAT checkpoint, single-GPU inference).
# ---------------------------------------------------------------------------
world_size = 1
model, model_args = FineTuneTestCogAgentModel.from_pretrained(
    "/path/to/design2code-18b-v0",
    args=argparse.Namespace(
        deepspeed=None,
        local_rank=0,
        rank=0,
        world_size=1,
        model_parallel_size=1,
        mode='inference',
        skip_init=True,
        use_gpu_initialization=True,
        device='cuda',
        bf16=True,
        fp16=None,
    ),
    # Only override the model-parallel size when running on more than one GPU.
    overwrite_args={'model_parallel_size': world_size} if world_size != 1 else {},
)
model = model.eval()
# Register SAT's cached auto-regressive mixin used during generation.
model.add_mixin('auto-regressive', CachedAutoregressiveMixin())

# ---------------------------------------------------------------------------
# Tokenizer and image/text processors.
# ---------------------------------------------------------------------------
# Prefer the processor version stored with the checkpoint; otherwise fall
# back to the version configured above.
language_processor_version = getattr(model_args, 'text_processor_version', args.version)
print("[Language processor version]:", language_processor_version)
tokenizer = llama2_tokenizer("lmsys/vicuna-7b-v1.5", signal_type=language_processor_version)
image_processor = get_image_processor(model_args.eva_args["image_size"][0])
# The cross-image processor is only built when the checkpoint config defines
# `cross_image_pix`.
cross_image_processor = get_image_processor(model_args.cross_image_pix) if "cross_image_pix" in model_args else None
# NOTE(review): 2048 is presumably the maximum text length for the
# processor — confirm against utils.utils.llama2_text_processor_inference.
text_processor_infer = llama2_text_processor_inference(tokenizer, 2048, model.image_length)
def get_html(image_path, query=''):
    """Generate HTML for the screenshot at *image_path* with the fine-tuned model.

    Each call is a fresh, single-turn chat: no history and no cached image
    are carried over between screenshots.

    Args:
        image_path: Path to the input screenshot.
        query: Text prompt sent alongside the image. Defaults to the empty
            string, which is the prompt used by this evaluation script.

    Returns:
        The model's response text, which the caller writes out as HTML.
    """
    with torch.no_grad():
        response, _history, _cache_image = chat(
            image_path,
            model,
            text_processor_infer,
            image_processor,
            query,
            history=None,
            cross_img_processor=cross_image_processor,
            image=None,
            max_length=4096,
            top_p=1.0,
            temperature=args.temperature,
            # NOTE(review): top_k=1 likely makes decoding effectively greedy,
            # so temperature/top_p may have little effect — confirm in chat().
            top_k=1,
            invalid_slices=text_processor_infer.invalid_slices,
            repetition_penalty=args.repetition_penalty,
            args=args,
        )
    return response
# Generate one HTML prediction per screenshot and write it next to the others
# in predictions_dir, using the screenshot's base name with an .html extension.
for filename in tqdm(filename_list):
    image_path = os.path.join(test_data_dir, filename)
    generated_text = get_html(image_path)
    # Swap only the trailing extension; str.replace(".png", ".html") would
    # also rewrite any ".png" appearing earlier in the file name.
    output_name = os.path.splitext(filename)[0] + ".html"
    with open(os.path.join(predictions_dir, output_name), "w", encoding='utf-8') as f:
        f.write(generated_text)
```