""" Usage: python3 -m unittest tests.test_image_utils """ import base64 from io import BytesIO import os import unittest import numpy as np from PIL import Image from fastchat.utils import ( resize_image_and_return_image_in_bytes, image_moderation_filter, ) from fastchat.conversation import get_conv_template def check_byte_size_in_mb(image_base64_str): return len(image_base64_str) / 1024 / 1024 def generate_random_image(target_size_mb, image_format="PNG"): # Convert target size from MB to bytes target_size_bytes = target_size_mb * 1024 * 1024 # Estimate dimensions dimension = int((target_size_bytes / 3) ** 0.5) # Generate random pixel data pixel_data = np.random.randint(0, 256, (dimension, dimension, 3), dtype=np.uint8) # Create an image from the pixel data img = Image.fromarray(pixel_data) # Save image to a temporary file temp_filename = "temp_image." + image_format.lower() img.save(temp_filename, format=image_format) # Check the file size and adjust quality if needed while os.path.getsize(temp_filename) < target_size_bytes: # Increase dimensions or change compression quality dimension += 1 pixel_data = np.random.randint( 0, 256, (dimension, dimension, 3), dtype=np.uint8 ) img = Image.fromarray(pixel_data) img.save(temp_filename, format=image_format) return img class DontResizeIfLessThanMaxTest(unittest.TestCase): def test_dont_resize_if_less_than_max(self): max_image_size = 5 initial_size_mb = 0.1 # Initial image size img = generate_random_image(initial_size_mb) image_bytes = BytesIO() img.save(image_bytes, format="PNG") # Save the image as JPEG previous_image_size = check_byte_size_in_mb(image_bytes.getvalue()) image_bytes = resize_image_and_return_image_in_bytes( img, max_image_size_mb=max_image_size ) new_image_size = check_byte_size_in_mb(image_bytes.getvalue()) self.assertEqual(previous_image_size, new_image_size) class ResizeLargeImageForModerationEndpoint(unittest.TestCase): def test_resize_large_image_and_send_to_moderation_filter(self): initial_size_mb = 6 # Initial image size which we know is greater than what the endpoint can take img = generate_random_image(initial_size_mb) nsfw_flag, csam_flag = image_moderation_filter(img) self.assertFalse(nsfw_flag) self.assertFalse(nsfw_flag) class DontResizeIfMaxImageSizeIsNone(unittest.TestCase): def test_dont_resize_if_max_image_size_is_none(self): initial_size_mb = 0.2 # Initial image size img = generate_random_image(initial_size_mb) image_bytes = BytesIO() img.save(image_bytes, format="PNG") # Save the image as JPEG previous_image_size = check_byte_size_in_mb(image_bytes.getvalue()) image_bytes = resize_image_and_return_image_in_bytes( img, max_image_size_mb=None ) new_image_size = check_byte_size_in_mb(image_bytes.getvalue()) self.assertEqual(previous_image_size, new_image_size) class OpenAIConversationDontResizeImage(unittest.TestCase): def test(self): conv = get_conv_template("chatgpt") initial_size_mb = 0.2 # Initial image size img = generate_random_image(initial_size_mb) image_bytes = BytesIO() img.save(image_bytes, format="PNG") # Save the image as JPEG previous_image_size = check_byte_size_in_mb(image_bytes.getvalue()) resized_img = conv.convert_image_to_base64(img) resized_img_bytes = base64.b64decode(resized_img) new_image_size = check_byte_size_in_mb(resized_img_bytes) self.assertEqual(previous_image_size, new_image_size) class ClaudeConversationResizesCorrectly(unittest.TestCase): def test(self): conv = get_conv_template("claude-3-haiku-20240307") initial_size_mb = 5 # Initial image size img = generate_random_image(initial_size_mb) image_bytes = BytesIO() img.save(image_bytes, format="PNG") # Save the image as JPEG previous_image_size = check_byte_size_in_mb(image_bytes.getvalue()) resized_img = conv.convert_image_to_base64(img) new_base64_image_size = check_byte_size_in_mb(resized_img) new_image_bytes_size = check_byte_size_in_mb(base64.b64decode(resized_img)) self.assertLess(new_image_bytes_size, previous_image_size) self.assertLessEqual(new_image_bytes_size, conv.max_image_size_mb) self.assertLessEqual(new_base64_image_size, 5)