import os

import jax
import jax.numpy as jnp
import gradio as gr
from PIL import Image
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline
from transformers import ViTFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel

# Register the Colab TPU driver if a TPU is attached to this runtime.
if 'TPU_NAME' in os.environ:
    import requests

    if 'TPU_DRIVER_MODE' not in globals():
        # TPU_NAME looks like "grpc://<host>:8470"; the driver endpoint listens on port 8475.
        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
        requests.post(url)
        TPU_DRIVER_MODE = 1

    from jax.config import config
    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
    print('Registered TPU:', config.FLAGS.jax_backend_target)
else:
    print('No TPU detected. Can be changed under "Runtime/Change runtime type".')

jax.local_devices()  # force backend initialization
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind
print(f"Found {num_devices} JAX devices of type {device_type}.")


def sd2_inference(pipeline, prompts, params, seed=42, num_inference_steps=50):
    """Run Stable Diffusion inference with the prompts sharded across all JAX devices."""
    prng_seed = jax.random.PRNGKey(seed)
    prompt_ids = pipeline.prepare_inputs(prompts)

    # Replicate the parameters and split the RNG so every device gets its own copy.
    params = replicate(params)
    prng_seed = jax.random.split(prng_seed, jax.device_count())
    prompt_ids = shard(prompt_ids)

    images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
    # Collapse the (devices, batch, H, W, C) output back into a flat batch of images.
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    return pipeline.numpy_to_pil(images)


HF_ACCESS_TOKEN = os.environ["HFAUTH"]

# Load the text-to-image model: Stable Diffusion v1-4 with bfloat16 weights.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=HF_ACCESS_TOKEN,
    revision="bf16",
    dtype=jnp.bfloat16,
)

# Load the image-to-text (captioning) model: a ViT encoder + GPT-2 decoder trained on COCO.
loc = "ydshieh/vit-gpt2-coco-en"
feature_extractor = ViTFeatureExtractor.from_pretrained(loc)
tokenizer = AutoTokenizer.from_pretrained(loc)
model = FlaxVisionEncoderDecoderModel.from_pretrained(loc)
gen_kwargs = {"max_length": 16, "num_beams": 4}


def generate(pixel_values):
    output_ids = model.generate(pixel_values, **gen_kwargs).sequences
    return output_ids


def predict(image):
    pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values
    output_ids = generate(pixel_values)
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return [pred.strip() for pred in preds]


def image2text(image):
    return predict(image)[0]


def text_to_image_and_image_to_text(text=None, image=None):
    txt = ""
    img = None
    if image is not None:
        txt = image2text(image)
    if text:  # skip generation when the textbox is empty or None
        images = sd2_inference(pipeline, [text], params, seed=42, num_inference_steps=5)
        img = images[0]
    return img, txt


if __name__ == '__main__':
    interface = gr.Interface(
        fn=text_to_image_and_image_to_text,
        inputs=[
            gr.Textbox(placeholder="Enter the text to encode to an image", label="Text to Encode to Image", lines=1),
            gr.Image(type="pil", label="Image to Decode to Text"),
        ],
        outputs=[
            gr.Image(type="pil", label="Encoded Image"),
            gr.Textbox(label="Decoded Text"),
        ],
        title="T2I2T: Text2Image2Text information transmitter",
        description="⭐️ The next generation of QR codes: an information-sharing tool via images ⭐️ Error rates are high & image generation takes about 200 seconds.",
        theme='gradio/soft',
    )
    interface.launch()
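
# --- Usage sketch (not part of the app) ---------------------------------------
# A minimal, hypothetical example of driving the two pipelines directly from a
# Python shell once the models above have loaded, without launching the Gradio
# UI. "photo.jpg" and "roundtrip.png" are placeholder paths, not files that
# ship with this script.
#
#   caption = image2text(Image.open("photo.jpg"))                # image -> text
#   decoded = sd2_inference(pipeline, [caption], params,
#                           seed=0, num_inference_steps=5)[0]    # text -> image
#   decoded.save("roundtrip.png")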