import zipfile
import time
import threading
import subprocess

def unzip_content():
    try:
        # First try using Python's zipfile
        print("Attempting to unzip content using Python...")
        with zipfile.ZipFile('./content.zip', 'r') as zip_ref:
            zip_ref.extractall('.')
    except Exception as e:
        print(f"Python unzip failed: {str(e)}")
        try:
            # Fall back to the system unzip command
            print("Attempting to unzip content using system command...")
            subprocess.run(['unzip', '-o', './content.zip'], check=True)
        except Exception as e:
            print(f"System unzip failed: {str(e)}")
            raise Exception("Failed to unzip content using both methods")
    print("Content successfully unzipped!")

# Try to unzip content at startup
try:
    unzip_content()
except Exception as e:
    print(f"Warning: Could not unzip content: {str(e)}")
import gradio as gr
import numpy as np
import torch
import torchvision
import torchvision.transforms
import torchvision.transforms.functional
import PIL
import matplotlib.pyplot as plt
import yaml
from omegaconf import OmegaConf
from CLIP import clip
import os
import sys

#os.chdir('./taming-transformers')
#from taming.models.vqgan import VQModel
#os.chdir('..')
# Make the local taming-transformers checkout importable
taming_path = os.path.join(os.getcwd(), 'taming-transformers')
sys.path.append(taming_path)
from taming.models.vqgan import VQModel

from PIL import Image
import cv2
import imageio

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def create_video(images_list):
    """Create a video from a list of image tensors."""
    if not images_list:
        print("No images provided.")
        return None
    # Create a unique filename in the current directory
    video_path = os.path.join(os.getcwd(), f"output_{int(time.time())}.mp4")
    try:
        video_writer = imageio.get_writer(video_path, fps=10, codec='libx264', quality=7,
                                          output_params=['-movflags', 'faststart'])
        for img_tensor in images_list:
            # Convert the (C, H, W) tensor in [0, 1] to an (H, W, C) uint8 array
            img = img_tensor.cpu().numpy().transpose((1, 2, 0))
            img = (img * 255).astype('uint8')
            video_writer.append_data(img)
        video_writer.close()
        return video_path
    except Exception as e:
        # Remove the partial file before re-raising
        if os.path.exists(video_path):
            os.remove(video_path)
        raise e
def save_from_tensors(tensor):
    """Convert an image tensor to a uint8 numpy array."""
    img = tensor.clone()
    img = img.mul(255).byte()
    img = img.cpu().numpy().transpose((1, 2, 0))
    return img

def norm_data(data):
    # Map VQGAN decoder output from [-1, 1] to [0, 1]
    return (data.clip(-1, 1) + 1) / 2

def setup_clip_model():
    model, _ = clip.load('ViT-B/32', jit=False)
    model.eval().to(device)
    return model

def setup_vqgan_model(config_path, checkpoint_path):
    config = OmegaConf.load(config_path)
    model = VQModel(**config.model.params)
    state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
    model.load_state_dict(state_dict, strict=False)
    return model.eval().to(device)

def generator(x, model):
    # Decode a latent tensor into an image with the VQGAN decoder
    x = model.post_quant_conv(x)
    x = model.decoder(x)
    return x

def encode_text(text, clip_model):
    t = clip.tokenize(text).to(device)
    return clip_model.encode_text(t).detach().clone()

def create_encoding(include, exclude, extras, clip_model):
    include_enc = [encode_text(text, clip_model) for text in include]
    exclude_enc = [encode_text(text, clip_model) for text in exclude]
    extras_enc = [encode_text(text, clip_model) for text in extras]
    return include_enc, exclude_enc, extras_enc
def create_crops(img, num_crops=32, size1=225, noise_factor=0.05):
    # Random flips and affine jitter applied before cropping
    aug_transform = torch.nn.Sequential(
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.RandomAffine(30, translate=(0.1, 0.1), fill=0)
    ).to(device)
    p = size1 // 2
    img = torch.nn.functional.pad(img, (p, p, p, p), mode='constant', value=0)
    img = aug_transform(img)
    crop_set = []
    for _ in range(num_crops):
        # Sample a crop size around size1 and a random offset inside the padded image
        gap1 = int(torch.normal(1.2, .3, ()).clip(.43, 1.9) * size1)
        offsetx = torch.randint(0, int(size1 * 2 - gap1), ())
        offsety = torch.randint(0, int(size1 * 2 - gap1), ())
        crop = img[:, :, offsetx:offsetx + gap1, offsety:offsety + gap1]
        # Resize every crop to CLIP's 224x224 input resolution
        crop = torch.nn.functional.interpolate(crop, (224, 224), mode='bilinear', align_corners=True)
        crop_set.append(crop)
    img_crops = torch.cat(crop_set, 0)
    # Add a small amount of per-crop Gaussian noise
    randnormal = torch.randn_like(img_crops, requires_grad=False)
    randstotal = torch.rand((img_crops.shape[0], 1, 1, 1)).to(device)
    img_crops = img_crops + noise_factor * randstotal * randnormal
    return img_crops
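# The loss below scores the generated image with CLIP: the current "include"
# prompt embedding is blended with the first "extras" embedding (weights w1 and
# w2), its cosine similarity to the image crops is maximized, and similarity to
# the first "exclude" embedding is penalized.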
def optimize_result(params, prompt, vqgan_model, clip_model, w1, w2, extras_enc, exclude_enc):
    alpha = 1    # weight of the "include" similarity term
    beta = 0.5   # weight of the "exclude" penalty term
    out = generator(params, vqgan_model)
    out = norm_data(out)
    out = create_crops(out)
    # Normalize crops with CLIP's expected image statistics
    out = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                           (0.26862954, 0.26130258, 0.27577711))(out)
    img_enc = clip_model.encode_image(out)
    final_enc = w1 * prompt + w2 * extras_enc[0]
    final_text_include_enc = final_enc / final_enc.norm(dim=-1, keepdim=True)
    final_text_exclude_enc = exclude_enc[0]
    main_loss = torch.cosine_similarity(final_text_include_enc, img_enc, dim=-1)
    penalize_loss = torch.cosine_similarity(final_text_exclude_enc, img_enc, dim=-1)
    return -alpha * main_loss.mean() + beta * penalize_loss.mean()
def optimize(params, optimizer, prompt, vqgan_model, clip_model, w1, w2, extras_enc, exclude_enc):
    loss = optimize_result(params, prompt, vqgan_model, clip_model, w1, w2, extras_enc, exclude_enc)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss

def training_loop(params, optimizer, include_enc, exclude_enc, extras_enc, vqgan_model, clip_model, w1, w2,
                  total_iter=200, show_step=1):
    res_img = []
    res_z = []
    for prompt in include_enc:
        for it in range(total_iter):
            loss = optimize(params, optimizer, prompt, vqgan_model, clip_model, w1, w2, extras_enc, exclude_enc)
            if it % show_step == 0:
                # Snapshot the current image and latent for the output video
                with torch.no_grad():
                    generated = generator(params, vqgan_model)
                    new_img = norm_data(generated[0].to(device))
                res_img.append(new_img)
                res_z.append(params.clone().detach())
                print(f"loss: {loss.item():.4f}\nno. of iteration: {it}")
    torch.cuda.empty_cache()
    return res_img, res_z
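# generate_art ties the pieces together: load CLIP and the VQGAN, encode the
# starting image into the VQGAN latent space, optimize that latent against the
# text prompts, and render the intermediate frames into a video.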
def generate_art(include_text, exclude_text, extras_text, num_iterations):
    try:
        # Process the input prompts
        include = [x.strip() for x in include_text.split(',')]
        exclude = [x.strip() for x in exclude_text.split(',')]
        extras = [x.strip() for x in extras_text.split(',')]
        w1, w2 = 1.0, 0.9
        # Set up the models
        clip_model = setup_clip_model()
        vqgan_model = setup_vqgan_model("./models/vqgan_imagenet_f16_16384/configs/model.yaml",
                                        "./models/vqgan_imagenet_f16_16384/checkpoints/last.ckpt")
        # Optimization parameters
        learning_rate = 0.1
        batch_size = 1
        wd = 0.1
        size1, size2 = 225, 400
        # Encode the starting image into the VQGAN latent space
        initial_image = PIL.Image.open('./gradient1.png')
        initial_image = initial_image.resize((size2, size1))
        initial_image = torchvision.transforms.ToTensor()(initial_image).unsqueeze(0).to(device)
        with torch.no_grad():
            z, _, _ = vqgan_model.encode(initial_image)
        params = torch.nn.Parameter(z).to(device)
        optimizer = torch.optim.AdamW([params], lr=learning_rate, weight_decay=wd)
        # Blend the encoded latent with noise so the optimization does not start
        # from a fully formed image
        params.data = params.data * 0.6 + torch.randn_like(params.data) * 0.4
        # Encode prompts
        include_enc, exclude_enc, extras_enc = create_encoding(include, exclude, extras, clip_model)
        # Run the training loop
        res_img, res_z = training_loop(params, optimizer, include_enc, exclude_enc, extras_enc,
                                       vqgan_model, clip_model, w1, w2, total_iter=num_iterations)
        # Create the video directly from the collected tensors
        video_path = create_video(res_img)
        return video_path
    except Exception as e:
        raise e
def gradio_interface(include_text, exclude_text, extras_text, num_iterations):
    video_path = None
    try:
        video_path = generate_art(include_text, exclude_text, extras_text, int(num_iterations))
        if not video_path or not os.path.exists(video_path):
            return "Video generation failed"
        # Keep a reference to the path before scheduling deletion
        response_path = video_path
        # Schedule the video file for deletion after a delay
        def cleanup():
            try:
                if os.path.exists(video_path):
                    os.remove(video_path)
            except Exception:
                pass
        threading.Timer(10.0, cleanup).start()
        return response_path
    except Exception as e:
        if video_path and os.path.exists(video_path):
            os.remove(video_path)
        return f"An error occurred: {str(e)}"
# Define and launch the Gradio app
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(label="Include Prompts (comma-separated)",
                   value="desert, heavy rain, cactus"),
        gr.Textbox(label="Exclude Prompts (comma-separated)",
                   value="confusing, blurry"),
        gr.Textbox(label="Extra Style Prompts (comma-separated)",
                   value="desert, clear, detailed, beautiful, good shape, detailed"),
        gr.Number(label="Number of Iterations",
                  value=200, minimum=1, maximum=1000)
    ],
    outputs=gr.Video(label="Generated Morphing Video", format="mp4", autoplay=True),
    title="VQGAN-CLIP Art Generator",
    allow_flagging="never",
    description="""
    Generate artistic videos using VQGAN-CLIP. Enter your prompts separated by commas and adjust the number of iterations. The model will generate a morphing video based on your inputs.

    Note: This application requires GPU access. Please either:
    1. Use the Colab notebook available at https://github.com/SanshruthR/VQGAN-CLIP
    2. Clone this space and enable GPU in your personal copy.
    """
)
if __name__ == "__main__":
    print("Checking GPU availability:", "GPU AVAILABLE" if torch.cuda.is_available() else "NO GPU FOUND")
    iface.launch()