# Dataset used: polinaeterna/pokemon-blip-captions
# Code
from transformers import AutoProcessor, AutoModelForCausalLM
import torch
from PIL import Image
import requests
# Preprocess the dataset
# The dataset has two modalities (image and text), so the preprocessing pipeline must handle both the images and the captions.
# To do so, load the processor class associated with the model you are about to fine-tune.
checkpoint = "microsoft/git-base"
processor = AutoProcessor.from_pretrained(checkpoint)
# Run inference with the fine-tuned model (on GPU if available)
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "kr-manish/git-base-pokemon"  # Replace with the Hub repo ID (username/model-name) of your fine-tuned model
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
# Download a sample image; replace the URL with your own image
url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/pokemon.png"
image = Image.open(requests.get(url, stream=True).raw)
# Convert the image to pixel values and move them to the model's device
inputs = processor(images=image, return_tensors="pt").to(device)
# Generate a caption of at most 50 tokens and decode it to text
generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_caption)
# Example output: a pink and purple pokemon character with big eyes
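The preprocessing comments at the top of the script say that both the images and the captions must go through the processor. A minimal sketch of such a batch transform follows. It assumes the dataset exposes `image` and `text` columns, and the `make_transform` helper is hypothetical (it is not the original training code). Because GIT is trained as a causal language model, this sketch reuses the caption token IDs as the labels.

```python
from typing import Any, Callable, Dict


def make_transform(
    processor: Callable[..., Dict[str, Any]],
) -> Callable[[Dict[str, list]], Dict[str, Any]]:
    """Wrap a processor (e.g. AutoProcessor.from_pretrained("microsoft/git-base"))
    in a batch transform. Assumes hypothetical "image" and "text" columns."""

    def transform(example_batch: Dict[str, list]) -> Dict[str, Any]:
        images = list(example_batch["image"])
        captions = list(example_batch["text"])
        # The processor turns images into pixel_values and tokenizes the
        # captions into input_ids, padded to a fixed length.
        inputs = processor(images=images, text=captions, padding="max_length")
        # For a causal LM such as GIT, the target sequence is the caption itself,
        # so the caption token IDs double as the labels.
        inputs["labels"] = inputs["input_ids"]
        return inputs

    return transform
```

With the processor loaded in the script above, the transform could be applied lazily to a Hugging Face `datasets.Dataset` via `dataset.set_transform(make_transform(processor))`.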
git-base-pokemon
This model is a fine-tuned version of microsoft/git-base on the polinaeterna/pokemon-blip-captions dataset. It achieves the following results on the evaluation set:
- Loss: 1.5797
- WER score: 8.9592
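WER (word error rate) is conventionally the word-level edit distance between a generated caption and its reference, divided by the number of reference words: (substitutions + deletions + insertions) / N. Because insertions count, WER can exceed 1. The pure-Python sketch below implements that textbook definition. It is not necessarily the exact metric code used in training, which the card does not show, and metric libraries may normalize text differently.

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """WER = (substitutions + deletions + insertions) / number of reference words."""
    ref = reference.split()
    hyp = hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # At the start of each row, prev[j] is the edit distance between the
    # reference words seen so far and the first j hypothesis words.
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, hyp_word in enumerate(hyp, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr[j] = min(
                prev[j] + 1,         # deletion
                curr[j - 1] + 1,     # insertion
                prev[j - 1] + cost,  # substitution (or match when cost is 0)
            )
        prev = curr
    return prev[-1] / len(ref)


# Two inserted words against a six-word reference
print(word_error_rate("a pink pokemon with big eyes",
                      "a pink and purple pokemon with big eyes"))  # 0.3333333333333333
```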
Model description
A GIT (GenerativeImage2Text) base model fine-tuned to generate text captions for Pokémon images.
Intended uses & limitations
More information needed
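Training and evaluation data
The polinaeterna/pokemon-blip-captions dataset, which pairs Pokémon images with text captions. More information is needed on how it was split into training and evaluation sets.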
Training procedure
Training hyperparameters
The following hyperparameters were used during training:
- learning_rate: 3e-05
- train_batch_size: 32
- eval_batch_size: 32
- seed: 42
- gradient_accumulation_steps: 2
- total_train_batch_size: 64
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- num_epochs: 20
- mixed_precision_training: Native AMP
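Two of the listed values follow from the others. The total train batch size is the per-device batch size times the gradient accumulation steps (32 × 2 = 64), and the `linear` scheduler decays the learning rate from 3e-05 toward 0 over the run. The sketch below computes both under stated assumptions: a single device, no warmup (none is listed above), and 12 optimizer steps per epoch, a figure inferred from the results table rather than reported. The helper names are hypothetical.

```python
TRAIN_BATCH_SIZE = 32
GRADIENT_ACCUMULATION_STEPS = 2
NUM_DEVICES = 1        # assumption: single GPU
LEARNING_RATE = 3e-05
NUM_EPOCHS = 20
STEPS_PER_EPOCH = 12   # assumption: inferred from the results table (step 50 ~ epoch 4.17)


def total_train_batch_size(per_device: int, accumulation: int, devices: int = 1) -> int:
    """Number of examples that contribute to one optimizer step."""
    return per_device * accumulation * devices


def linear_lr(step: int, base_lr: float, total_steps: int, warmup_steps: int = 0) -> float:
    """Linear warmup (if any), then linear decay to 0 at total_steps."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


total_steps = NUM_EPOCHS * STEPS_PER_EPOCH
print(total_train_batch_size(TRAIN_BATCH_SIZE, GRADIENT_ACCUMULATION_STEPS, NUM_DEVICES))  # 64
print(total_steps)  # 240
for step in (0, 120, 240):
    print(step, linear_lr(step, LEARNING_RATE, total_steps))
```

Gradient accumulation lets the run behave like a batch of 64 while only 32 examples sit in GPU memory at a time.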
Training results
| Training Loss | Epoch | Step | Validation Loss | WER Score |
|---|---|---|---|---|
| 8.155 | 4.17 | 50 | 6.4318 | 25.1325 |
| 5.3386 | 8.33 | 100 | 4.0782 | 18.6484 |
| 3.3109 | 12.5 | 150 | 2.4303 | 9.4306 |
| 2.0471 | 16.67 | 200 | 1.5797 | 8.9592 |
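The Epoch column is the optimizer step divided by the number of steps per epoch. Assuming 12 steps per epoch (inferred from the table itself, not stated in the card), the logged steps reproduce the listed epochs to two decimals. The check below is a small sketch of that arithmetic.

```python
STEPS_PER_EPOCH = 12  # assumption: inferred from the table, not reported in the card

# (step, epoch) pairs copied from the results table
logged = [(50, 4.17), (100, 8.33), (150, 12.5), (200, 16.67)]

for step, reported_epoch in logged:
    epoch = round(step / STEPS_PER_EPOCH, 2)
    print(step, epoch, epoch == reported_epoch)
```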
Framework versions
- Transformers 4.38.2
- Pytorch 2.2.1+cu121
- Datasets 2.18.0
- Tokenizers 0.15.2
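To reproduce the results, it can help to confirm that the installed packages match the versions above. A minimal sketch using the standard library's `importlib.metadata` follows. The `version_mismatches` helper is hypothetical, and it assumes PyTorch's distribution name is `torch`.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional, Tuple

# Versions listed under "Framework versions" (PyTorch is distributed as "torch")
EXPECTED = {
    "transformers": "4.38.2",
    "torch": "2.2.1+cu121",
    "datasets": "2.18.0",
    "tokenizers": "0.15.2",
}


def version_mismatches(expected: Dict[str, str]) -> Dict[str, Tuple[str, Optional[str]]]:
    """Return {package: (expected, installed or None)} for every package that differs."""
    mismatches = {}
    for package, wanted in expected.items():
        try:
            installed: Optional[str] = version(package)
        except PackageNotFoundError:
            installed = None  # package not installed
        if installed != wanted:
            mismatches[package] = (wanted, installed)
    return mismatches


print(version_mismatches(EXPECTED))
```

An exact string comparison also flags differences in local build tags such as `+cu121`, so treat a mismatch as a prompt to check rather than a hard failure.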
Model tree for kr-manish/fine-tune-image-caption-pokemon
- Base model: microsoft/git-base