|
---
license: apache-2.0
datasets:
- HuggingFaceM4/WebSight
---
|
|
|
The model is [CogAgent-chat-18B](https://huggingface.co/THUDM/CogAgent) finetuned on 160K WebSight examples, with LoRA adapters of rank 8 added to the language decoder.
|
|
|
The model weights are stored in the [SAT (SwissArmyTransformer)](https://github.com/THUDM/SwissArmyTransformer/) format.
|
|
|
Please refer to [our paper](https://arxiv.org/abs/2403.03163) and [our codebase](https://github.com/NoviScl/Design2Code/tree/main/CogVLM) for details on running inference.
|
|
|
Use of the model must comply with [the original model license](https://github.com/THUDM/CogVLM/blob/main/MODEL_LICENSE) and the original data license (CC-BY-4.0). |
|
|
|
# Example Usage (based on SAT) |
|
|
|
```python
import sys
sys.path.insert(1, '/path/to/CogVLM')  # make the CogVLM codebase (and its utils/ modules) importable

import argparse
import os

import torch
from tqdm import tqdm

from sat.model import AutoModel
from sat.model.mixins import CachedAutoregressiveMixin
from sat.quantization.kernels import quantize
from utils.utils import chat, llama2_tokenizer, llama2_text_processor_inference, get_image_processor
from utils.models import CogAgentModel, CogVLMModel, FineTuneTestCogAgentModel

# Decoding hyperparameters
parser = argparse.ArgumentParser()
parser.add_argument('--temperature', type=float, default=0.5)
parser.add_argument('--repetition_penalty', type=float, default=1.1)
args = parser.parse_args()
args.bf16 = True
args.stream_chat = False
args.version = "chat"

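# Example invocation (the script name is illustrative):
#   python inference.py --temperature 0.5 --repetition_penalty 1.1
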
# You can download the test set from https://huggingface.co/datasets/SALT-NLP/Design2Code
test_data_dir = "/path/to/Design2Code"
predictions_dir = "/path/to/design2code_18b_v0_predictions"
os.makedirs(predictions_dir, exist_ok=True)

# One prediction per .png screenshot in the test set
filename_list = [filename for filename in os.listdir(test_data_dir) if filename.endswith(".png")]

# Load the finetuned checkpoint (SAT format) on a single GPU
world_size = 1
model, model_args = FineTuneTestCogAgentModel.from_pretrained(
    "/path/to/design2code-18b-v0",
    args=argparse.Namespace(
        deepspeed=None,
        local_rank=0,
        rank=0,
        world_size=1,
        model_parallel_size=1,
        mode='inference',
        skip_init=True,
        use_gpu_initialization=True,
        device='cuda',
        bf16=True,
        fp16=None),
    overwrite_args={'model_parallel_size': world_size} if world_size != 1 else {})
model = model.eval()
model.add_mixin('auto-regressive', CachedAutoregressiveMixin())

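# Optional: the `quantize` kernel imported above can quantize the model to save
# GPU memory; see the CogVLM SAT demos for its exact usage. (This is an option we
# note as an assumption; it is not required for this example.)
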
# Tokenizer and image processors; CogAgent also uses a high-resolution
# cross-attention image branch when the checkpoint defines `cross_image_pix`
language_processor_version = model_args.text_processor_version if 'text_processor_version' in model_args else args.version
print("[Language processor version]:", language_processor_version)
tokenizer = llama2_tokenizer("lmsys/vicuna-7b-v1.5", signal_type=language_processor_version)
image_processor = get_image_processor(model_args.eva_args["image_size"][0])
cross_image_processor = get_image_processor(model_args.cross_image_pix) if "cross_image_pix" in model_args else None
text_processor_infer = llama2_text_processor_inference(tokenizer, 2048, model.image_length)

def get_html(image_path):
    """Generate HTML code from a single screenshot."""
    with torch.no_grad():
        history = None
        cache_image = None
        # We use an empty string as the query; the screenshot alone drives generation
        query = ''

        response, history, cache_image = chat(
            image_path,
            model,
            text_processor_infer,
            image_processor,
            query,
            history=history,
            cross_img_processor=cross_image_processor,
            image=cache_image,
            max_length=4096,
            top_p=1.0,
            temperature=args.temperature,
            top_k=1,
            invalid_slices=text_processor_infer.invalid_slices,
            repetition_penalty=args.repetition_penalty,
            args=args
        )

    return response

# Generate predictions for the whole test set, one .html file per screenshot
for filename in tqdm(filename_list):
    image_path = os.path.join(test_data_dir, filename)
    generated_text = get_html(image_path)
    with open(os.path.join(predictions_dir, filename.replace(".png", ".html")), "w", encoding='utf-8') as f:
        f.write(generated_text)
```
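
To sanity-check a prediction, you can render the generated HTML back into a screenshot and compare it with the input image (the Design2Code evaluation similarly compares rendered screenshots). Below is a minimal sketch using [Playwright](https://playwright.dev/python/); it is not part of our codebase, and the file names and viewport size are illustrative.

```python
# Render a generated HTML file back to a screenshot for visual inspection.
# Assumes Playwright is installed: pip install playwright && playwright install chromium
import os
from playwright.sync_api import sync_playwright

def render_html(html_path, screenshot_path, width=1280, height=960):
    with sync_playwright() as p:
        browser = p.chromium.launch()
        page = browser.new_page(viewport={"width": width, "height": height})
        page.goto(f"file://{os.path.abspath(html_path)}")  # load the local HTML file
        page.screenshot(path=screenshot_path, full_page=True)
        browser.close()

# Illustrative paths: compare the result with the corresponding .png in the test set
render_html("/path/to/design2code_18b_v0_predictions/example.html", "example_rendered.png")
```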
|
|
|
|