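# Precompute CLIP zero-shot classifier weights: encode each class name with a
# bank of prompt templates, average the normalized text embeddings, and save
# the result. The active configuration targets Places365 with CLIP ViT-L/14;
# earlier OpenImages and Tencent ML-Images variants are kept below as
# commented-out blocks (csv and ProfanityFilter are imported only for those).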
import os
import numpy as np
import torch
import clip
import csv
import tqdm
from profanity_filter import ProfanityFilter
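# The 80 prompt templates from OpenAI's CLIP prompt-engineering notebook for
# ImageNet-style zero-shot classification.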
templates = [
lambda c: f'a bad photo of a {c}.',
lambda c: f'a photo of many {c}.',
lambda c: f'a sculpture of a {c}.',
lambda c: f'a photo of the hard to see {c}.',
lambda c: f'a low resolution photo of the {c}.',
lambda c: f'a rendering of a {c}.',
lambda c: f'graffiti of a {c}.',
lambda c: f'a bad photo of the {c}.',
lambda c: f'a cropped photo of the {c}.',
lambda c: f'a tattoo of a {c}.',
lambda c: f'the embroidered {c}.',
lambda c: f'a photo of a hard to see {c}.',
lambda c: f'a bright photo of a {c}.',
lambda c: f'a photo of a clean {c}.',
lambda c: f'a photo of a dirty {c}.',
lambda c: f'a dark photo of the {c}.',
lambda c: f'a drawing of a {c}.',
lambda c: f'a photo of my {c}.',
lambda c: f'the plastic {c}.',
lambda c: f'a photo of the cool {c}.',
lambda c: f'a close-up photo of a {c}.',
lambda c: f'a black and white photo of the {c}.',
lambda c: f'a painting of the {c}.',
lambda c: f'a painting of a {c}.',
lambda c: f'a pixelated photo of the {c}.',
lambda c: f'a sculpture of the {c}.',
lambda c: f'a bright photo of the {c}.',
lambda c: f'a cropped photo of a {c}.',
lambda c: f'a plastic {c}.',
lambda c: f'a photo of the dirty {c}.',
lambda c: f'a jpeg corrupted photo of a {c}.',
lambda c: f'a blurry photo of the {c}.',
lambda c: f'a photo of the {c}.',
lambda c: f'a good photo of the {c}.',
lambda c: f'a rendering of the {c}.',
lambda c: f'a {c} in a video game.',
lambda c: f'a photo of one {c}.',
lambda c: f'a doodle of a {c}.',
lambda c: f'a close-up photo of the {c}.',
lambda c: f'a photo of a {c}.',
lambda c: f'the origami {c}.',
lambda c: f'the {c} in a video game.',
lambda c: f'a sketch of a {c}.',
lambda c: f'a doodle of the {c}.',
lambda c: f'a origami {c}.',
lambda c: f'a low resolution photo of a {c}.',
lambda c: f'the toy {c}.',
lambda c: f'a rendition of the {c}.',
lambda c: f'a photo of the clean {c}.',
lambda c: f'a photo of a large {c}.',
lambda c: f'a rendition of a {c}.',
lambda c: f'a photo of a nice {c}.',
lambda c: f'a photo of a weird {c}.',
lambda c: f'a blurry photo of a {c}.',
lambda c: f'a cartoon {c}.',
lambda c: f'art of a {c}.',
lambda c: f'a sketch of the {c}.',
lambda c: f'a embroidered {c}.',
lambda c: f'a pixelated photo of a {c}.',
lambda c: f'itap of the {c}.',
lambda c: f'a jpeg corrupted photo of the {c}.',
lambda c: f'a good photo of a {c}.',
lambda c: f'a plushie {c}.',
lambda c: f'a photo of the nice {c}.',
lambda c: f'a photo of the small {c}.',
lambda c: f'a photo of the weird {c}.',
lambda c: f'the cartoon {c}.',
lambda c: f'art of the {c}.',
lambda c: f'a drawing of the {c}.',
lambda c: f'a photo of the large {c}.',
lambda c: f'a black and white photo of a {c}.',
lambda c: f'the plushie {c}.',
lambda c: f'a dark photo of a {c}.',
lambda c: f'itap of a {c}.',
lambda c: f'graffiti of the {c}.',
lambda c: f'a toy {c}.',
lambda c: f'itap of my {c}.',
lambda c: f'a photo of a cool {c}.',
lambda c: f'a photo of a small {c}.',
lambda c: f'a tattoo of the {c}.',
]
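# Each template is a callable, e.g. templates[0]('airfield') -> 'a bad photo of a airfield.'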
# Pin to the first GPU; fall back to CPU if CUDA is unavailable.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, clip_preprocess = clip.load("ViT-L/14", device=device)
# Disabled: read class names from an OpenImages CSV (last column of each row).
'''
csv_data = open('openimage-classnames.csv')
csv_reader = csv.reader(csv_data)
class_names = []
for row in csv_reader:
    class_names.append(row[-1])
'''
# Disabled: read Tencent ML-Images class names (tab-separated, 4 header lines)
# and keep only the comma-separated synonyms that pass the profanity filter.
'''
txt_data = open('tencent-ml-images.txt')
pf = ProfanityFilter()
lines = txt_data.readlines()
class_names = []
for line in lines[4:]:
    class_name_precook = line.strip().split('\t')[-1]
    safe_list = ''
    for class_name in class_name_precook.split(', '):
        if pf.is_clean(class_name):
            safe_list += '%s, ' % class_name
    safe_list = safe_list[:-2]  # drop the trailing ', '
    if len(safe_list) > 0:
        class_names.append(safe_list)
f_w = open('tencent-ml-classnames.txt', 'w')
for cln in class_names:
    f_w.write('%s\n' % cln)
f_w.close()
'''
# Active path: build human-readable class names from the Places365 category
# list. Entries look like '/a/airfield' or '/a/apartment_building/outdoor';
# for two-part names the qualifier is moved in front, so
# '/a/apartment_building/outdoor' becomes 'outdoor apartment building'.
place_categories = np.loadtxt('categories_places365.txt', dtype=str)
place_texts = []
for place in place_categories[:, 0]:
    place = place.split('/')[2:]
    if len(place) > 1:
        place = place[1] + ' ' + place[0]
    else:
        place = place[0]
    place = place.replace('_', ' ')
    place_texts.append(place)
class_names = place_texts

with open('place365-classnames.txt', 'w') as f_w:
    for cln in class_names:
        f_w.write('%s\n' % cln)
print(class_names)

# Encode every class name under all templates, average the unit-normalized
# text embeddings (prompt ensembling), and renormalize to get one classifier
# weight vector per class.
class_weights = []
with torch.no_grad():
    for classname in tqdm.tqdm(class_names, desc='encoding text'):
        texts = [template(classname) for template in templates]
        text_inputs = clip.tokenize(texts).to(device)
        text_features = clip_model.encode_text(text_inputs)
        text_features /= text_features.norm(dim=-1, keepdim=True)  # unit-normalize each template embedding
        text_features = text_features.mean(dim=0)                  # ensemble over templates
        text_features /= text_features.norm()                      # renormalize the mean
        class_weights.append(text_features)
class_weights = torch.stack(class_weights)
print(class_weights.shape)  # (num_classes, embed_dim) = (365, 768) for ViT-L/14
#torch.save(class_weights, 'clip_ViTL14_openimage_classifier_weights.pt')
torch.save(class_weights, 'clip_ViTL14_place365_classifier_weights.pt')
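
# Minimal usage sketch, kept disabled like the blocks above: classify a single
# image against the saved weights. 'example.jpg' is a hypothetical path; PIL
# is assumed to be available (the clip package depends on it).
'''
from PIL import Image

class_weights = torch.load('clip_ViTL14_place365_classifier_weights.pt')
image = clip_preprocess(Image.open('example.jpg')).unsqueeze(0).to(device)
with torch.no_grad():
    image_features = clip_model.encode_image(image)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    # Both sides are unit vectors, so this is cosine similarity per class.
    logits = image_features @ class_weights.T
pred = logits.argmax(dim=-1).item()
print(class_names[pred])
'''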