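# NOTE: page-header residue from the hosting site ("Spaces: Runtime error",
# shown twice at capture time); it is not part of the script.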
import csv
import os

import clip
import numpy as np
import torch
import tqdm
from profanity_filter import ProfanityFilter
# Prompt templates for text-prompt ensembling. Each entry is a callable that
# takes a class name and returns one caption containing it; the script later
# calls every template for every class name.
templates = [
    lambda c: f'a bad photo of a {c}.',
    lambda c: f'a photo of many {c}.',
    lambda c: f'a sculpture of a {c}.',
    lambda c: f'a photo of the hard to see {c}.',
    lambda c: f'a low resolution photo of the {c}.',
    lambda c: f'a rendering of a {c}.',
    lambda c: f'graffiti of a {c}.',
    lambda c: f'a bad photo of the {c}.',
    lambda c: f'a cropped photo of the {c}.',
    lambda c: f'a tattoo of a {c}.',
    lambda c: f'the embroidered {c}.',
    lambda c: f'a photo of a hard to see {c}.',
    lambda c: f'a bright photo of a {c}.',
    lambda c: f'a photo of a clean {c}.',
    lambda c: f'a photo of a dirty {c}.',
    lambda c: f'a dark photo of the {c}.',
    lambda c: f'a drawing of a {c}.',
    lambda c: f'a photo of my {c}.',
    lambda c: f'the plastic {c}.',
    lambda c: f'a photo of the cool {c}.',
    lambda c: f'a close-up photo of a {c}.',
    lambda c: f'a black and white photo of the {c}.',
    lambda c: f'a painting of the {c}.',
    lambda c: f'a painting of a {c}.',
    lambda c: f'a pixelated photo of the {c}.',
    lambda c: f'a sculpture of the {c}.',
    lambda c: f'a bright photo of the {c}.',
    lambda c: f'a cropped photo of a {c}.',
    lambda c: f'a plastic {c}.',
    lambda c: f'a photo of the dirty {c}.',
    lambda c: f'a jpeg corrupted photo of a {c}.',
    lambda c: f'a blurry photo of the {c}.',
    lambda c: f'a photo of the {c}.',
    lambda c: f'a good photo of the {c}.',
    lambda c: f'a rendering of the {c}.',
    lambda c: f'a {c} in a video game.',
    lambda c: f'a photo of one {c}.',
    lambda c: f'a doodle of a {c}.',
    lambda c: f'a close-up photo of the {c}.',
    lambda c: f'a photo of a {c}.',
    lambda c: f'the origami {c}.',
    lambda c: f'the {c} in a video game.',
    lambda c: f'a sketch of a {c}.',
    lambda c: f'a doodle of the {c}.',
    lambda c: f'a origami {c}.',
    lambda c: f'a low resolution photo of a {c}.',
    lambda c: f'the toy {c}.',
    lambda c: f'a rendition of the {c}.',
    lambda c: f'a photo of the clean {c}.',
    lambda c: f'a photo of a large {c}.',
    lambda c: f'a rendition of a {c}.',
    lambda c: f'a photo of a nice {c}.',
    lambda c: f'a photo of a weird {c}.',
    lambda c: f'a blurry photo of a {c}.',
    lambda c: f'a cartoon {c}.',
    lambda c: f'art of a {c}.',
    lambda c: f'a sketch of the {c}.',
    lambda c: f'a embroidered {c}.',
    lambda c: f'a pixelated photo of a {c}.',
    lambda c: f'itap of the {c}.',
    lambda c: f'a jpeg corrupted photo of the {c}.',
    lambda c: f'a good photo of a {c}.',
    lambda c: f'a plushie {c}.',
    lambda c: f'a photo of the nice {c}.',
    lambda c: f'a photo of the small {c}.',
    lambda c: f'a photo of the weird {c}.',
    lambda c: f'the cartoon {c}.',
    lambda c: f'art of the {c}.',
    lambda c: f'a drawing of the {c}.',
    lambda c: f'a photo of the large {c}.',
    lambda c: f'a black and white photo of a {c}.',
    lambda c: f'the plushie {c}.',
    lambda c: f'a dark photo of a {c}.',
    lambda c: f'itap of a {c}.',
    lambda c: f'graffiti of the {c}.',
    lambda c: f'a toy {c}.',
    lambda c: f'itap of my {c}.',
    lambda c: f'a photo of a cool {c}.',
    lambda c: f'a photo of a small {c}.',
    lambda c: f'a tattoo of the {c}.',
]
# Restrict CUDA to device index 0. This is set before any CUDA call below so
# that the restriction is in place when torch first queries the GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Fall back to CPU when no CUDA device is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the ViT-L/14 CLIP checkpoint onto the chosen device. Only the model is
# used in the visible part of this script; clip_preprocess is kept because
# clip.load returns both.
clip_model, clip_preprocess = clip.load("ViT-L/14", device=device)
# --- Class-name sources -----------------------------------------------------
# Exactly one source is active at a time. The two alternatives below are kept
# as bare string literals (no-op statements) so they can be re-enabled by
# removing the surrounding triple quotes.

# Disabled alternative 1: take the last column of every row of the OpenImages
# class-name CSV.
'''
csv_data = open('openimage-classnames.csv')
csv_reader = csv.reader(csv_data)
class_names = []
for row in csv_reader:
    class_names.append(row[-1])
'''

# Disabled alternative 2: read the Tencent ML-Images listing, skip its first
# four lines, drop every synonym the profanity filter rejects, and write the
# surviving comma-separated names to 'tencent-ml-classnames.txt'.
'''
txt_data = open('tencent-ml-images.txt')
pf = ProfanityFilter()
lines = txt_data.readlines()
class_names = []
for line in lines[4:]:
    class_name_precook = line.strip().split('\t')[-1]
    safe_list = ''
    for class_name in class_name_precook.split(', '):
        if pf.is_clean(class_name):
            safe_list += '%s, ' % class_name
    safe_list = safe_list[:-2]
    if len(safe_list) > 0:
        class_names.append(safe_list)
f_w = open('tencent-ml-classnames.txt', 'w')
for cln in class_names:
    f_w.write('%s\n' % cln)
f_w.close()
'''

# Active source: Places365 category list, loaded as a 2-D array of strings.
# Only the first column (the category path) is used.
place_categories = np.loadtxt('categories_places365.txt', dtype=str)
place_texts = []
for place in place_categories[:, 0]:
    # Drop the first two '/'-separated components and keep the rest, e.g. a
    # path like '/a/apartment_building/outdoor' leaves
    # ['apartment_building', 'outdoor'].
    place = place.split('/')[2:]
    if len(place) > 1:
        # Put the qualifier first: 'outdoor apartment_building'.
        place = place[1] + ' ' + place[0]
    else:
        place = place[0]
    place = place.replace('_', ' ')
    place_texts.append(place)
class_names = place_texts

# Persist the readable names, one per line. The context manager guarantees
# the file is closed even if a write fails.
with open('place365-classnames.txt', 'w') as f_w:
    for cln in class_names:
        f_w.write('%s\n' % cln)
print(class_names)
# Build one text-embedding weight vector per class name by prompt ensembling:
# encode every template for the class, average the embeddings, and stack the
# per-class results into a single tensor.
class_weights = []
with torch.no_grad():  # inference only; no autograd graph is needed
    for classname in tqdm.tqdm(class_names, desc='encoding text'):
        texts = [template(classname) for template in templates]
        text_inputs = clip.tokenize(texts).to(device)
        text_features = clip_model.encode_text(text_inputs)
        # Unit-normalize each prompt embedding so every template contributes
        # equally to the mean.
        text_features /= text_features.norm(dim=-1, keepdim=True)
        text_features = text_features.mean(dim=0)
        # Re-normalize the averaged vector.
        text_features /= text_features.norm()
        class_weights.append(text_features)
    # Presumably (num_classes, embed_dim) — the shape is printed below.
    class_weights = torch.stack(class_weights)
print(class_weights.shape)
# Output file for the OpenImages class-name source (currently disabled):
# torch.save(class_weights, 'clip_ViTL14_openimage_classifier_weights.pt')
torch.save(class_weights, 'clip_ViTL14_place365_classifier_weights.pt')