melt / tests /test_image_utils.py
martinakaduc's picture
Upload folder using huggingface_hub
f3305db verified
"""
Usage:
python3 -m unittest tests.test_image_utils
"""
import base64
from io import BytesIO
import os
import unittest
import numpy as np
from PIL import Image
from fastchat.utils import (
resize_image_and_return_image_in_bytes,
image_moderation_filter,
)
from fastchat.conversation import get_conv_template
def check_byte_size_in_mb(image_base64_str):
    """Return the length of ``image_base64_str`` in MiB (units of 1024 * 1024).

    Accepts anything with a ``len()`` (e.g. a base64 ``str`` or raw ``bytes``);
    every element counts as one unit, which for an ASCII base64 string equals
    its encoded byte count.
    """
    units_per_mb = 1024 * 1024
    length = len(image_base64_str)
    return length / units_per_mb
def generate_random_image(target_size_mb, image_format="PNG"):
    """Return a random RGB image whose encoded size is at least ``target_size_mb``.

    The side length starts from an estimate of 3 bytes per pixel (raw RGB).
    It then grows one pixel at a time, with fresh random data on each step,
    until the encoded image reaches the target size. Encoding happens in an
    in-memory buffer, so no temporary file is written to the working directory.

    Args:
        target_size_mb: Minimum encoded size in MiB (1024 * 1024 bytes);
            fractional values such as ``0.1`` are accepted.
        image_format: Pillow format name used to encode the image when
            measuring its size (e.g. ``"PNG"``).

    Returns:
        The generated ``PIL.Image.Image``. Callers encode it themselves.
    """
    target_size_bytes = target_size_mb * 1024 * 1024

    def _random_image(side):
        # Random noise barely compresses, so encoded size roughly tracks pixel count.
        pixels = np.random.randint(0, 256, (side, side, 3), dtype=np.uint8)
        return Image.fromarray(pixels)

    def _encoded_size(image):
        # Measure in memory rather than via a temp file that would be left on disk.
        buffer = BytesIO()
        image.save(buffer, format=image_format)
        return len(buffer.getvalue())

    # Estimate dimensions from 3 bytes per RGB pixel; keep at least 1x1 so
    # tiny targets do not produce an empty pixel array.
    dimension = max(1, int((target_size_bytes / 3) ** 0.5))
    img = _random_image(dimension)
    while _encoded_size(img) < target_size_bytes:
        dimension += 1
        img = _random_image(dimension)
    return img
class DontResizeIfLessThanMaxTest(unittest.TestCase):
    """An image below ``max_image_size_mb`` must keep its PNG-encoded size."""

    def test_dont_resize_if_less_than_max(self):
        max_image_size = 5  # size cap in MB passed to the resizer
        initial_size_mb = 0.1  # Initial image size, well below the cap
        img = generate_random_image(initial_size_mb)

        original_bytes = BytesIO()
        img.save(original_bytes, format="PNG")  # Save the image as PNG
        previous_image_size = check_byte_size_in_mb(original_bytes.getvalue())

        resized_bytes = resize_image_and_return_image_in_bytes(
            img, max_image_size_mb=max_image_size
        )
        new_image_size = check_byte_size_in_mb(resized_bytes.getvalue())

        # Equal sizes mean the helper returned the image without shrinking it.
        self.assertEqual(previous_image_size, new_image_size)
class ResizeLargeImageForModerationEndpoint(unittest.TestCase):
    """An oversized random-noise image should be accepted and not flagged."""

    def test_resize_large_image_and_send_to_moderation_filter(self):
        # Initial image size which we know is greater than what the endpoint can take.
        # NOTE(review): presumably image_moderation_filter resizes internally
        # and calls an external service; confirm before running offline.
        initial_size_mb = 6
        img = generate_random_image(initial_size_mb)

        nsfw_flag, csam_flag = image_moderation_filter(img)

        # Random noise should trigger neither flag; check both results.
        self.assertFalse(nsfw_flag)
        self.assertFalse(csam_flag)
class DontResizeIfMaxImageSizeIsNone(unittest.TestCase):
    """With ``max_image_size_mb=None`` the image must keep its PNG-encoded size."""

    def test_dont_resize_if_max_image_size_is_none(self):
        initial_size_mb = 0.2  # Initial image size
        img = generate_random_image(initial_size_mb)

        original_bytes = BytesIO()
        img.save(original_bytes, format="PNG")  # Save the image as PNG
        previous_image_size = check_byte_size_in_mb(original_bytes.getvalue())

        # None means "no size cap", so no resizing is expected.
        resized_bytes = resize_image_and_return_image_in_bytes(
            img, max_image_size_mb=None
        )
        new_image_size = check_byte_size_in_mb(resized_bytes.getvalue())

        self.assertEqual(previous_image_size, new_image_size)
class OpenAIConversationDontResizeImage(unittest.TestCase):
    """The "chatgpt" template must base64-encode a small image without resizing it."""

    def test(self):
        conv = get_conv_template("chatgpt")
        initial_size_mb = 0.2  # Initial image size
        img = generate_random_image(initial_size_mb)

        original_bytes = BytesIO()
        img.save(original_bytes, format="PNG")  # Save the image as PNG
        previous_image_size = check_byte_size_in_mb(original_bytes.getvalue())

        encoded_img = conv.convert_image_to_base64(img)
        # Compare decoded byte counts, not the longer base64 text.
        decoded_img_bytes = base64.b64decode(encoded_img)
        new_image_size = check_byte_size_in_mb(decoded_img_bytes)

        self.assertEqual(previous_image_size, new_image_size)
class ClaudeConversationResizesCorrectly(unittest.TestCase):
    """The Claude template must shrink a large image below its size limits."""

    def test(self):
        conv = get_conv_template("claude-3-haiku-20240307")
        initial_size_mb = 5  # Initial image size
        # Upper bound on the base64 text length, in MB.
        # NOTE(review): presumably the API's per-image base64 limit; confirm.
        max_base64_size_mb = 5
        img = generate_random_image(initial_size_mb)

        original_bytes = BytesIO()
        img.save(original_bytes, format="PNG")  # Save the image as PNG
        previous_image_size = check_byte_size_in_mb(original_bytes.getvalue())

        encoded_img = conv.convert_image_to_base64(img)
        new_base64_image_size = check_byte_size_in_mb(encoded_img)
        new_image_bytes_size = check_byte_size_in_mb(base64.b64decode(encoded_img))

        # Decoded image must be smaller than the original and within the
        # template's own cap; the base64 text must also fit the limit above.
        self.assertLess(new_image_bytes_size, previous_image_size)
        self.assertLessEqual(new_image_bytes_size, conv.max_image_size_mb)
        self.assertLessEqual(new_base64_image_size, max_base64_size_mb)