import os import uuid import torch import argparse from PIL import Image from chameleon.inference.chameleon import ChameleonInferenceModel, Options from constants import ( MODEL_7B_PATH, TOKENIZER_TEXT_PATH, TOKENIZER_IMAGE_CFG_PATH, TOKENIZER_IMAGE_PATH, ) from typing import List import logging # Set up the logging configuration logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') def main(args: argparse.Namespace): """Main function to generate images from instructions.""" # Print configuration # print(f"Instruction: {args.instruction}") # print(f"Batch size: {args.batch_size}") # Log the information logging.info(f"Instruction: {args.instruction}") logging.info(f"Batch size: {args.batch_size}") # Load Chameleon model model = ChameleonInferenceModel( MODEL_7B_PATH.as_posix(), TOKENIZER_TEXT_PATH.as_posix(), TOKENIZER_IMAGE_CFG_PATH.as_posix(), TOKENIZER_IMAGE_PATH.as_posix(), ) # Generate options options = Options() options.txt = False # Prepare batch prompts instructions: List[str] = [args.instruction for _ in range(args.batch_size)] batch_prompt_ui = [] for instruction in instructions: batch_prompt_ui += [ [ {"type": "text", "value": instruction}, {"type": "sentinel", "value": ""} ], ] # Generate images image_tokens: torch.LongTensor = model.generate( batch_prompt_ui=batch_prompt_ui, options=options ) images: List[Image.Image] = model.decode_image(image_tokens) # Save images os.makedirs(args.save_dir, exist_ok=True) for instruction, image in zip(instructions, images): image_path = os.path.join(args.save_dir, f"1.png") image.save(image_path) print(f"Save generated images to {image_path}") def parse_arguments() -> argparse.Namespace: """Parse command line arguments.""" parser = argparse.ArgumentParser(description="Generate images based on text instructions.") parser.add_argument("-i", "--instruction", type=str, required=True, help="The instruction for image generation.") parser.add_argument("-b", "--batch_size", type=int, default=10, help="The number of images to generate.") parser.add_argument("-s", "--save_dir", type=str, default="./outputs/text2image/", help="The directory to save the generated images.") args: argparse.Namespace = parser.parse_args() return args if __name__ == "__main__": args: argparse.Namespace = parse_arguments() main(args)