metadata
library_name: keras-hub
extra_gated_heading: Access PaliGemma on Hugging Face
extra_gated_prompt: >-
  To access PaliGemma on Hugging Face, you’re required to review and agree to
  Google’s usage license. To do this, please ensure you’re logged-in to Hugging
  Face and click below. Requests are processed immediately.
extra_gated_button_content: Acknowledge license
license: gemma
pipeline_tag: image-text-to-text

PaliGemma is a set of multimodal large language models published by Google, based on the Gemma model. Both pre-trained and instruction-tuned models are available. See the model card below for benchmarks, data sources, and intended use cases.


Installation

Keras and KerasHub can be installed with:

pip install -U -q keras_hub
pip install -U -q "keras>=3"

JAX, TensorFlow, and PyTorch come preinstalled in Kaggle Notebooks. For instructions on installing them in other environments, see the Keras Getting Started page.

Presets

The following model checkpoints are provided by the Keras team. Full code examples are available below.

| Preset name | Parameters | Description |
|---|---|---|
| paligemma-3b-224-mix-keras | 2.92B | Image size 224, mix fine-tuned, text sequence length 256 |
| paligemma-3b-448-mix-keras | 2.92B | Image size 448, mix fine-tuned, text sequence length 512 |
| paligemma-3b-224-keras | 2.92B | Image size 224, pre-trained, text sequence length 128 |
| paligemma-3b-448-keras | 2.92B | Image size 448, pre-trained, text sequence length 512 |
| paligemma-3b-896-keras | 2.93B | Image size 896, pre-trained, text sequence length 512 |
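The preset table can also be kept as data, for example to pick a checkpoint programmatically. This is a sketch: `PALIGEMMA_PRESETS`, `find_preset`, and `hf_handle` are hypothetical names, and the `hf://google/<preset>` pattern is an assumption extrapolated from the one handle this README actually uses (`hf://google/paligemma-3b-448-mix-keras`).

```python
# Preset table encoded as data:
# preset name -> (image_size, is_mix, text_sequence_length)
PALIGEMMA_PRESETS = {
    "paligemma-3b-224-mix-keras": (224, True, 256),
    "paligemma-3b-448-mix-keras": (448, True, 512),
    "paligemma-3b-224-keras": (224, False, 128),
    "paligemma-3b-448-keras": (448, False, 512),
    "paligemma-3b-896-keras": (896, False, 512),
}


def find_preset(image_size: int, mix: bool) -> str:
    """Return the preset name for an image size and mix/pre-trained choice."""
    for name, (size, is_mix, _) in PALIGEMMA_PRESETS.items():
        if size == image_size and is_mix == mix:
            return name
    raise ValueError(f"No preset for image_size={image_size}, mix={mix}")


def hf_handle(preset: str) -> str:
    # Assumption: every preset lives under hf://google/<preset>, like the
    # 448 mix preset loaded in this README.
    return f"hf://google/{preset}"
```

Note that there is no 896 "mix" preset in the table, so `find_preset(896, True)` raises.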

Prompts

The PaliGemma "mix" models can handle a number of prompt structures out of the box. Follow these formats exactly, including the trailing newline. {lang} is a language code such as "en" or "fr"; support for languages other than English varies by prompt type.

  • "cap {lang}\n": very raw short captions (from WebLI-alt).
  • "caption {lang}\n": COCO-like short captions.
  • "describe {lang}\n": somewhat longer, more descriptive captions.
  • "ocr\n": optical character recognition.
  • "answer en {question}\n": question answering about the image contents.
  • "question {lang} {answer}\n": question generation for a given answer.
  • "detect {thing} ; {thing}\n": locate the named objects in a scene, e.g. to count them.
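Because the mix models expect these formats exactly, it can help to build prompts in code rather than typing them by hand. A minimal sketch covering the formats listed above; `make_prompt` is a hypothetical helper, not part of KerasHub.

```python
def make_prompt(task: str, lang: str = "en", text: str = "", things=()) -> str:
    """Build a "mix" prompt string that always ends with the required newline."""
    if task in ("cap", "caption", "describe"):
        return f"{task} {lang}\n"
    if task == "ocr":
        return "ocr\n"
    if task == "answer":
        # The prompt list writes question answering with a literal "en".
        return f"answer en {text}\n"
    if task == "question":
        # Here `text` is the answer the generated question should lead to.
        return f"question {lang} {text}\n"
    if task == "detect":
        # Several object names are separated by " ; ".
        return "detect " + " ; ".join(things) + "\n"
    raise ValueError(f"Unknown task: {task!r}")
```

For example, `make_prompt("detect", things=["cow", "dog"])` yields `"detect cow ; dog\n"`.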

Presets other than the "mix" ones should be fine-tuned for a specific task before use.

!pip install -U -q keras_hub

Pick a backend of your choice ("jax", "tensorflow", or "torch"). It must be set before Keras is imported.

import os
os.environ["KERAS_BACKEND"] = "jax"

Now we can load the PaliGemma "causal language model" from the Hugging Face Hub. A causal language model is an LLM that is ready for generation: it is trained with a causal mask and generates text one token at a time in an autoregressive loop.
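To make "causal mask" and "one token at a time" concrete, here is a toy sketch in NumPy. It is not KerasHub's implementation (real implementations typically cache attention state and support several samplers); `toy_model` is a stand-in for the network that maps the tokens so far to next-token logits.

```python
import numpy as np


def causal_mask(seq_len: int) -> np.ndarray:
    """Lower-triangular mask: position i may attend only to positions 0..i."""
    return np.tril(np.ones((seq_len, seq_len), dtype=bool))


def greedy_generate(next_token_logits, prompt_ids, max_new_tokens, end_id):
    """Append the most likely next token repeatedly, feeding it back as input."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        next_id = int(np.argmax(next_token_logits(ids)))
        ids.append(next_id)
        if next_id == end_id:  # stop once the end token is produced
            break
    return ids


def toy_model(ids):
    """Toy 5-token "model": always predicts (last id + 1) mod 5."""
    logits = np.zeros(5)
    logits[(ids[-1] + 1) % 5] = 1.0
    return logits
```

With this toy model, `greedy_generate(toy_model, [0], max_new_tokens=10, end_id=3)` returns `[0, 1, 2, 3]`: each step sees only the tokens produced so far, which is exactly what the causal mask enforces during training.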

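import keras
import keras_hub

# Use bfloat16 to roughly halve the memory taken by the weights vs. float32.
keras.config.set_floatx("bfloat16")
pali_gemma_lm = keras_hub.models.PaliGemmaCausalLM.from_preset(
    "hf://google/paligemma-3b-448-mix-keras"
)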

A function that reads an image from a given URL and returns it as an RGB NumPy array:

import io

import numpy as np
import PIL.Image
import requests


def read_image(url):
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    image = PIL.Image.open(io.BytesIO(response.content))
    # Convert to RGB: drops an alpha channel and expands grayscale images.
    image = image.convert("RGB")
    return np.array(image)


image_url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
image = read_image(image_url)
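`read_image` covers images fetched by URL. If your images are already decoded NumPy arrays (for example from a dataset), they may be grayscale or carry an alpha channel. A sketch of normalizing them to the (height, width, 3) layout used in this README; `to_rgb` is a hypothetical helper.

```python
import numpy as np


def to_rgb(image: np.ndarray) -> np.ndarray:
    """Return an (H, W, 3) array from grayscale, single-channel, RGB, or RGBA input."""
    if image.ndim == 2:  # (H, W) grayscale: add a channel axis
        image = image[:, :, np.newaxis]
    channels = image.shape[2]
    if channels == 1:  # replicate the single channel three times
        return np.repeat(image, 3, axis=2)
    if channels == 4:  # drop the alpha channel
        return image[:, :, :3]
    if channels == 3:
        return image
    raise ValueError(f"Unsupported channel count: {channels}")
```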

Call generate() with a single image and prompt. The text prompt must end with a newline (\n).

prompt = 'answer en where is the cow standing?\n'
output = pali_gemma_lm.generate(
    inputs={
        "images": image,
        "prompts": prompt,
    }
)
print(output)
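The string returned by generate() may echo the prompt before the model's answer. A minimal sketch for keeping only the answer, assuming that echo behavior; `strip_prompt` is a hypothetical helper and leaves the text unchanged (apart from surrounding whitespace) when no echo is present.

```python
def strip_prompt(output: str, prompt: str) -> str:
    """Remove an echoed prompt from the start of a generated string, if present."""
    if output.startswith(prompt):
        return output[len(prompt):].strip()
    return output.strip()
```

For batched calls, the same idea applies pairwise: `[strip_prompt(o, p) for o, p in zip(outputs, prompts)]`.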

Call generate() with batched images and prompts.

prompts = [
    'answer en where is the cow standing?\n',
    'answer en what color is the cow?\n',
    'describe en\n',
    'detect cow\n',
    'segment cow\n',
]
images = [image, image, image, image, image]
outputs = pali_gemma_lm.generate(
    inputs={
        "images": images,
        "prompts": prompts,
    }
)
for output in outputs:
    print(output)
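The "detect cow" output above contains location tokens rather than plain text. A sketch of turning them into pixel boxes, assuming PaliGemma's published format: four `<locNNNN>` tokens per object (ymin, xmin, ymax, xmax, each quantized into 1024 bins), followed by the label, with objects separated by " ; ". Verify this against your actual outputs; `parse_detections` is a hypothetical helper.

```python
import re

# One detection: four <locNNNN> tokens followed by a label.
_DETECTION = re.compile(
    r"<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})>\s*([^;<]+)"
)


def parse_detections(text: str, width: int, height: int):
    """Return a list of (label, (x_min, y_min, x_max, y_max)) in pixels."""
    boxes = []
    for y0, x0, y1, x1, label in _DETECTION.findall(text):
        # Bins are normalized to [0, 1) and then scaled to the image size.
        y0, x0, y1, x1 = (int(v) / 1024 for v in (y0, x0, y1, x1))
        boxes.append(
            (label.strip(), (x0 * width, y0 * height, x1 * width, y1 * height))
        )
    return boxes
```

Because the coordinates are relative, they can be scaled to the original image, e.g. `parse_detections(outputs[3], width=image.shape[1], height=image.shape[0])`.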


Call fit() on a single batch

import keras_hub
import numpy as np

# A random dummy image, only to illustrate the input structure.
image = np.random.uniform(-1, 1, size=(224, 224, 3))
x = {
    "images": [image, image],
    "prompts": ["answer en Where is the cow standing?\n", "caption en\n"],
}
y = {
    "responses": ["beach", "A brown cow standing on a beach next to the ocean."],
}
# Loading the preset is only needed if the model is not already in memory.
pali_gemma_lm = keras_hub.models.PaliGemmaCausalLM.from_preset(
    "hf://google/paligemma-3b-448-mix-keras"
)
pali_gemma_lm.fit(x=x, y=y, batch_size=2)
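To train on more than one hand-written batch, the same `x`/`y` structure can be assembled from a list of examples. This is a sketch: `make_fit_data` is a hypothetical helper, not a KerasHub API, and it only enforces the trailing-newline rule for prompts described earlier.

```python
import numpy as np


def make_fit_data(examples):
    """Build fit() inputs from (image, prompt, response) triples."""
    images, prompts, responses = [], [], []
    for image, prompt, response in examples:
        if not prompt.endswith("\n"):
            raise ValueError(f"Prompt must end with a newline: {prompt!r}")
        images.append(np.asarray(image))
        prompts.append(prompt)
        responses.append(response)
    return {"images": images, "prompts": prompts}, {"responses": responses}
```

Usage would then be `x, y = make_fit_data(examples)` followed by `pali_gemma_lm.fit(x=x, y=y, batch_size=2)`.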