Anole / text2image.py
xuefengli
update
7362797
raw
history blame contribute delete
No virus
2.66 kB
import os
import uuid
import torch
import argparse
from PIL import Image
from chameleon.inference.chameleon import ChameleonInferenceModel, Options
from constants import (
MODEL_7B_PATH,
TOKENIZER_TEXT_PATH,
TOKENIZER_IMAGE_CFG_PATH,
TOKENIZER_IMAGE_PATH,
)
from typing import List
import logging
# Set up the logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def main(args: argparse.Namespace):
"""Main function to generate images from instructions."""
# Print configuration
# print(f"Instruction: {args.instruction}")
# print(f"Batch size: {args.batch_size}")
# Log the information
logging.info(f"Instruction: {args.instruction}")
logging.info(f"Batch size: {args.batch_size}")
# Load Chameleon model
model = ChameleonInferenceModel(
MODEL_7B_PATH.as_posix(),
TOKENIZER_TEXT_PATH.as_posix(),
TOKENIZER_IMAGE_CFG_PATH.as_posix(),
TOKENIZER_IMAGE_PATH.as_posix(),
)
# Generate options
options = Options()
options.txt = False
# Prepare batch prompts
instructions: List[str] = [args.instruction for _ in range(args.batch_size)]
batch_prompt_ui = []
for instruction in instructions:
batch_prompt_ui += [
[
{"type": "text", "value": instruction},
{"type": "sentinel", "value": "<END-OF-TURN>"}
],
]
# Generate images
image_tokens: torch.LongTensor = model.generate(
batch_prompt_ui=batch_prompt_ui,
options=options
)
images: List[Image.Image] = model.decode_image(image_tokens)
# Save images
os.makedirs(args.save_dir, exist_ok=True)
for instruction, image in zip(instructions, images):
image_path = os.path.join(args.save_dir, f"1.png")
image.save(image_path)
print(f"Save generated images to {image_path}")
def parse_arguments() -> argparse.Namespace:
"""Parse command line arguments."""
parser = argparse.ArgumentParser(description="Generate images based on text instructions.")
parser.add_argument("-i", "--instruction", type=str, required=True, help="The instruction for image generation.")
parser.add_argument("-b", "--batch_size", type=int, default=10, help="The number of images to generate.")
parser.add_argument("-s", "--save_dir", type=str, default="./outputs/text2image/", help="The directory to save the generated images.")
args: argparse.Namespace = parser.parse_args()
return args
if __name__ == "__main__":
args: argparse.Namespace = parse_arguments()
main(args)