Spaces:
Running
Running
import numpy as np | |
import pandas as pd | |
import torch | |
import random | |
from os.path import join | |
from os.path import isfile | |
from PIL import Image | |
from sklearn.model_selection import train_test_split | |
from torch.utils.data import Dataset | |
from torchvision.transforms import ( | |
Compose, | |
RandomCrop, | |
CenterCrop, | |
RandomHorizontalFlip, | |
ToTensor, | |
) | |
import time | |
from torchvision.transforms import GaussianBlur | |
from torchvision import transforms | |
def normalize(lat, lon): | |
"""Used to put all lat lon inside ±90 and ±180.""" | |
lat = (lat + 90) % 360 - 90 | |
if lat > 90: | |
lat = 180 - lat | |
lon += 180 | |
lon = (lon + 180) % 360 - 180 | |
return lat, lon | |
def collate_fn(batch): | |
"""Collate function for the dataloader. | |
Args: | |
batch (list): list of dictionaries with keys "img", "gps", "idx" and optionally "label" | |
Returns: | |
dict: dictionary with keys "img", "gps", "idx" and optionally "label" | |
""" | |
keys = list(batch[0].keys()) | |
if "weight" in batch[0].keys(): | |
keys.remove("weight") | |
output = {} | |
for key in [ | |
"idx", | |
"unique_country", | |
"unique_region", | |
"unique_sub-region", | |
"unique_city", | |
"img_idx", | |
"text", | |
]: | |
if key in keys: | |
idx = [x[key] for x in batch] | |
output[key] = idx | |
keys.remove(key) | |
for key in keys: | |
if not ("text" in key): | |
output[key] = torch.stack([x[key] for x in batch]) | |
return output | |
def collate_fn_streetclip(batch): | |
"""Collate function for the dataloader. | |
Args: | |
batch (list): list of dictionaries with keys "img", "gps", "idx" and optionally "label" | |
Returns: | |
dict: dictionary with keys "img", "gps", "idx" and optionally "label" | |
""" | |
keys = list(batch[0].keys()) | |
if "weight" in batch[0].keys(): | |
keys.remove("weight") | |
output = {} | |
for key in [ | |
"idx", | |
"unique_country", | |
"unique_region", | |
"unique_sub-region", | |
"unique_city", | |
"img_idx", | |
"img", | |
"text", | |
]: | |
if key in keys: | |
idx = [x[key] for x in batch] | |
output[key] = idx | |
keys.remove(key) | |
for key in keys: | |
if not ("text" in key): | |
output[key] = torch.stack([x[key] for x in batch]) | |
return output | |
def collate_fn_denstity(batch): | |
"""Collate function for the dataloader. | |
Args: | |
batch (list): list of dictionaries with keys "img", "gps", "idx" and optionally "label" | |
Returns: | |
dict: dictionary with keys "img", "gps", "idx" and optionally "label" | |
""" | |
keys = list(batch[0].keys()) | |
if "weight" in batch[0].keys(): | |
keys.remove("weight") | |
# Sample indices based on the weights | |
weights = np.array([x["weight"] for x in batch]) | |
normalized_weights = weights / np.sum(weights) | |
sampled_indices = np.random.choice( | |
len(batch), size=len(batch), p=normalized_weights, replace=True | |
) | |
output = {} | |
for key in [ | |
"idx", | |
"unique_country", | |
"unique_region", | |
"unique_sub-region", | |
"unique_city", | |
"img_idx", | |
"text", | |
]: | |
if key in keys: | |
idx = [batch[i][key] for i in sampled_indices] | |
output[key] = idx | |
keys.remove(key) | |
for key in keys: | |
if not ("text" in key): | |
output[key] = torch.stack([batch[i][key] for i in sampled_indices]) | |
return output | |
def collate_fn_streetclip_denstity(batch): | |
"""Collate function for the dataloader. | |
Args: | |
batch (list): list of dictionaries with keys "img", "gps", "idx" and optionally "label" | |
Returns: | |
dict: dictionary with keys "img", "gps", "idx" and optionally "label" | |
""" | |
keys = list(batch[0].keys()) | |
if "weight" in batch[0].keys(): | |
keys.remove("weight") | |
# Sample indices based on the weights | |
weights = np.array([x["weight"] for x in batch]) | |
normalized_weights = weights / np.sum(weights) | |
sampled_indices = np.random.choice( | |
len(batch), size=len(batch), p=normalized_weights, replace=True | |
) | |
output = {} | |
for key in [ | |
"idx", | |
"unique_country", | |
"unique_region", | |
"unique_sub-region", | |
"unique_city", | |
"img_idx", | |
"img", | |
"text", | |
]: | |
if key in keys: | |
idx = [batch[i][key] for i in sampled_indices] | |
output[key] = idx | |
keys.remove(key) | |
for key in keys: | |
if not ("text" in key): | |
output[key] = torch.stack([batch[i][key] for i in sampled_indices]) | |
return output | |
def collate_fn_contrastive(batch): | |
"""Collate function for the dataloader. | |
Args: | |
batch (list): list of dictionaries with keys "img", "gps", "idx" and optionally "label" | |
Returns: | |
dict: dictionary with keys "img", "gps", "idx" and optionally "label" | |
""" | |
output = collate_fn(batch) | |
pos_img = torch.stack([x["pos_img"] for x in batch]) | |
output["pos_img"] = pos_img | |
return output | |
def collate_fn_contrastive_density(batch): | |
"""Collate function for the dataloader. | |
Args: | |
batch (list): list of dictionaries with keys "img", "gps", "idx" and optionally "label" | |
Returns: | |
dict: dictionary with keys "img", "gps", "idx" and optionally "label" | |
""" | |
keys = list(batch[0].keys()) | |
if "weight" in batch[0].keys(): | |
keys.remove("weight") | |
# Sample indices based on the weights | |
weights = np.array([x["weight"] for x in batch]) | |
normalized_weights = weights / np.sum(weights) | |
sampled_indices = np.random.choice( | |
len(batch), size=len(batch), p=normalized_weights, replace=True | |
) | |
output = {} | |
for key in [ | |
"idx", | |
"unique_country", | |
"unique_region", | |
"unique_sub-region", | |
"unique_city", | |
"img_idx", | |
]: | |
if key in keys: | |
idx = [batch[i][key] for i in sampled_indices] | |
output[key] = idx | |
keys.remove(key) | |
for key in keys: | |
if not ("text" in key): | |
output[key] = torch.stack([batch[i][key] for i in sampled_indices]) | |
return output | |
class osv5m(Dataset): | |
csv_dtype = {"category": str, "country": str, "city": str} # Don't remove. | |
def __init__( | |
self, | |
path, | |
transforms, | |
split="train", | |
class_name=None, | |
aux_data=[], | |
is_baseline=False, | |
areas=["country", "region", "sub-region", "city"], | |
streetclip=False, | |
suff="", | |
blur=False | |
): | |
"""Initializes the dataset. | |
Args: | |
path (str): path to the dataset | |
transforms (torchvision.transforms): transforms to apply to the images | |
split (str): split to use (train, val, test) | |
class_name (str): category to use (e.g. "city") | |
aux_data (list of str): auxilliary datas to use | |
areas (list of str): regions to perform accuracy | |
streetclip (bool): if the model is streetclip, do not use transform | |
suff (str): suffix of test csv | |
blur (bool): blur bottom of images or not | |
""" | |
self.suff = suff | |
self.path = path | |
self.aux = len(aux_data) > 0 | |
self.aux_list = aux_data | |
self.split = split | |
if split == "select": | |
self.df = self.load_split(split) | |
split = "test" | |
else: | |
self.df = self.load_split(split) | |
self.split = split | |
self.image_folder = join( | |
path, | |
'images', | |
("train" if split == "val" else split), | |
) | |
self.dict_names = {} | |
for root, _, files in os.walk(self.image_folder): | |
for file in files: | |
self.dict_names[file] = os.path.join(root, file) | |
self.is_baseline = is_baseline | |
if self.aux: | |
self.aux_data = {} | |
for col in self.aux_list: | |
if col in ["land_cover", "climate", "soil"]: | |
self.aux_data[col] = pd.get_dummies(self.df[col], dtype=float) | |
if col == "climate": | |
for i in range(31): | |
if not (i in list(self.aux_data[col].columns)): | |
self.aux_data[col][i] = 0 | |
desired_order = [i for i in range(31)] | |
desired_order.remove(20) | |
self.aux_data[col] = self.aux_data[col][desired_order] | |
else: | |
self.aux_data[col] = self.df[col].apply(lambda x: [x]) | |
self.areas = ["_".join(["unique", area]) for area in areas] | |
if class_name is None: | |
self.class_name = class_name | |
elif "quadtree" in class_name: | |
self.class_name = class_name | |
else: | |
self.class_name = "_".join(["unique", class_name]) | |
ex = self.extract_classes(self.class_name) | |
self.df = self.df[ | |
["id", "latitude", "longitude", "weight"] + self.areas + ex | |
].fillna("NaN") | |
if self.class_name in self.areas: | |
self.df.columns = list(self.df.columns)[:-1] + [self.class_name + "_2"] | |
self.transforms = transforms | |
self.collate_fn = collate_fn | |
self.collate_fn_density = collate_fn_denstity | |
self.blur = blur | |
self.streetclip = streetclip | |
if self.streetclip: | |
self.collate_fn = collate_fn_streetclip | |
self.collate_fn_density = collate_fn_streetclip_denstity | |
def load_split(self, split): | |
"""Returns a new dataset with the given split.""" | |
start_time = time.time() | |
if split == "test": | |
df = pd.read_csv(join(self.path, "test.csv"), dtype=self.csv_dtype) | |
# extract coord | |
longitude = df["longitude"].values | |
latitude = df["latitude"].values | |
# Create bins | |
num_bins = 100 | |
lon_bins = np.linspace(longitude.min(), longitude.max(), num_bins) | |
lat_bins = np.linspace(latitude.min(), latitude.max(), num_bins) | |
# compute density and weights | |
hist, _, _ = np.histogram2d(longitude, latitude, bins=[lon_bins, lat_bins]) | |
weights = 1.0 / np.power(hist[df["lon_bin"], df["lat_bin"]], 0.75) | |
normalized_weights = weights / np.sum(weights) | |
df["weight"] = normalized_weights | |
return df | |
elif split == "select": | |
df = pd.read_csv( | |
join(self.path, "select.csv"), dtype=self.csv_dtype | |
) | |
# extract coord | |
longitude = df["longitude"].values | |
latitude = df["latitude"].values | |
# Create bins | |
num_bins = 100 | |
lon_bins = np.linspace(longitude.min(), longitude.max(), num_bins) | |
lat_bins = np.linspace(latitude.min(), latitude.max(), num_bins) | |
# compute density and weights | |
hist, _, _ = np.histogram2d(longitude, latitude, bins=[lon_bins, lat_bins]) | |
weights = 1.0 / np.power(hist[df["lon_bin"], df["lat_bin"]], 0.75) | |
normalized_weights = weights / np.sum(weights) | |
df["weight"] = normalized_weights | |
return df | |
else: | |
if len(self.suff) == 0: | |
df = pd.read_csv( | |
join(self.path, "train.csv"), dtype=self.csv_dtype | |
) | |
else: | |
df = pd.read_csv( | |
join(self.path, "train" + "_" + self.suff + ".csv"), | |
dtype=self.csv_dtype, | |
) | |
# extract coord | |
longitude = df["longitude"].values | |
latitude = df["latitude"].values | |
# Create bins | |
num_bins = 100 | |
lon_bins = np.linspace(longitude.min(), longitude.max(), num_bins) | |
lat_bins = np.linspace(latitude.min(), latitude.max(), num_bins) | |
# compute density and weights | |
hist, _, _ = np.histogram2d(longitude, latitude, bins=[lon_bins, lat_bins]) | |
weights = 1.0 / np.power(hist[df["lon_bin"], df["lat_bin"]], 0.75) | |
normalized_weights = weights / np.sum(weights) | |
df["weight"] = normalized_weights | |
test_df = df.sample( | |
n=int(0.1 * len(df)), | |
weights=normalized_weights, | |
replace=False, | |
random_state=42, | |
) | |
end_time = time.time() | |
print(f"Loading {split} dataset took {(end_time - start_time):.2f} seconds") | |
if split == "val": | |
return test_df | |
else: | |
return df.drop(test_df.index) | |
def extract_classes(self, tag=None): | |
"""Extracts the categories from the dataset.""" | |
if tag is None: | |
self.has_labels = False | |
return [] | |
splits = ["train", "test"] if self.is_baseline else ["train"] | |
# splits = ["train", "test"] | |
print(f"Loading categories from {splits}") | |
# concatenate all categories from relevant splits to find the unique ones. | |
self.categories = sorted( | |
pd.concat( | |
[ | |
pd.read_csv(join(self.path, f"{split}.csv"))[tag] | |
for split in splits | |
] | |
) | |
.fillna("NaN") | |
.unique() | |
.tolist() | |
) | |
if "NaN" in self.categories: | |
self.categories.remove("NaN") | |
if self.split != "test": | |
self.df = self.df.dropna(subset=[tag]) | |
# compute the total number of categories - this name is fixed and will be used as a lookup during init | |
self.num_classes = len(self.categories) | |
# create a mapping from category to index | |
self.category_to_index = { | |
category: i for i, category in enumerate(self.categories) | |
} | |
self.has_labels = True | |
return [tag] | |
def __getitem__(self, i): | |
"""Returns an item from the dataset. | |
Args: | |
i (int): index of the item | |
Returns: | |
dict: dictionary with keys "img", "gps", "idx" and optionally "label" | |
""" | |
x = list(self.df.iloc[i]) # id, latitude, longitude, {category} | |
if self.streetclip: | |
img = Image.open(self.dict_names[f"{int(x[0])}.jpg"]) | |
elif self.blur: | |
img = transforms.ToTensor()(Image.open(self.dict_names[f"{int(x[0])}.jpg"])) | |
u = GaussianBlur(kernel_size = 13, sigma=2.0) | |
bottom_part = img[:, -14:, :].unsqueeze(0) | |
blurred_bottom = u(bottom_part) | |
img[:, -14:, :] = blurred_bottom.squeeze() | |
img = self.transforms(transforms.ToPILImage()(img)) | |
else: | |
img = self.transforms( | |
Image.open(self.dict_names[f"{int(x[0])}.jpg"]) | |
) | |
lat, lon = normalize(x[1], x[2]) | |
gps = torch.FloatTensor([np.radians(lat), np.radians(lon)]).squeeze(0) | |
output = { | |
"img": img, | |
"gps": gps, | |
"idx": i, | |
"img_idx": int(x[0]), | |
"weight": x[3], | |
} | |
for count, area in enumerate(self.areas): | |
output[area] = x[ | |
count + 4 | |
] #'country': x[3], 'region': x[4], 'sub-region': x[5], 'city': x[6]} | |
if self.has_labels: | |
if x[-1] in self.categories: | |
output["label"] = torch.LongTensor( | |
[self.category_to_index[x[-1]]] | |
).squeeze(-1) | |
else: | |
output["label"] = torch.LongTensor([-1]).squeeze(-1) | |
if self.aux: | |
for col in self.aux_list: | |
output[col] = torch.FloatTensor(self.aux_data[col].iloc[i]) | |
return output | |
def __len__(self): | |
return len(self.df) | |
class Contrastiveosv5m(osv5m): | |
def __init__( | |
self, | |
path, | |
transforms, | |
split="train", | |
class_name=None, | |
aux_data=[], | |
class_name2=None, | |
blur=False, | |
): | |
""" | |
class_name2 (str): if not None, we do contrastive an other class than the one specified for classif | |
""" | |
super().__init__( | |
path, | |
transforms, | |
split=split, | |
class_name=class_name, | |
aux_data=aux_data, | |
blur=blur, | |
) | |
self.add_label = False | |
if not(class_name2 is None) and split != 'test' and split != 'select': | |
self.add_label = True | |
self.class_name = class_name2 | |
self.extract_classes_contrastive(tag=class_name2) | |
self.df = self.df.reset_index(drop=True) | |
self.dict_classes = { | |
value: indices.tolist() | |
for value, indices in self.df.groupby(self.class_name).groups.items() | |
} | |
self.collate_fn = collate_fn_contrastive | |
self.random_crop = RandomCrop(224) # use when no positive image is available | |
def sample_positive(self, i): | |
""" | |
sample positive image from the same city, country if it is available | |
otherwise, apply different crop to the image | |
""" | |
x = self.df.iloc[i] # id, latitude, longitude, {category} | |
class_name = x[self.class_name] | |
idxs = self.dict_classes[class_name] | |
idxs.remove(i) | |
if len(idxs) > 0: | |
idx = random.choice(idxs) | |
x = self.df.iloc[idx] | |
pos_img = self.transforms( | |
Image.open(self.dict_names[f"{int(x['id'])}.jpg"]) | |
) | |
else: | |
pos_img = self.random_crop( | |
self.transforms( | |
Image.open(self.dict_names[f"{int(x['id'])}.jpg"]) | |
) | |
) | |
return pos_img | |
def extract_classes_contrastive(self, tag=None): | |
"""Extracts the categories from the dataset.""" | |
if tag is None: | |
self.has_labels = False | |
return [] | |
splits = ["train", "test"] if self.is_baseline else ["train"] | |
# splits = ["train", "test"] | |
print(f"Loading categories from {splits}") | |
# concatenate all categories from relevant splits to find the unique ones. | |
categories = sorted( | |
pd.concat( | |
[ | |
pd.read_csv(join(self.path, f"{split}.csv"))[tag] | |
for split in splits | |
] | |
) | |
.fillna("NaN") | |
.unique() | |
.tolist() | |
) | |
# create a mapping from category to index | |
self.contrastive_category_to_index = { | |
category: i for i, category in enumerate(categories) | |
} | |
def __getitem__(self, i): | |
output = super().__getitem__(i) | |
pos_img = self.sample_positive(i) | |
output["pos_img"] = pos_img | |
if self.add_label: | |
output["label_contrastive"] = torch.LongTensor( | |
[self.contrastive_category_to_index[self.df[self.class_name].iloc[i]]] | |
).squeeze(-1) | |
return output | |
class TextContrastiveosv5m(osv5m): | |
def __init__( | |
self, | |
path, | |
transforms, | |
split="train", | |
class_name=None, | |
aux_data=[], | |
blur=False, | |
): | |
super().__init__( | |
path, | |
transforms, | |
split=split, | |
class_name=class_name, | |
aux_data=aux_data, | |
blur=blur, | |
) | |
self.df = self.df.reset_index(drop=True) | |
def get_text(self, i): | |
""" | |
sample positive image from the same city, country if it is available | |
otherwise, apply different crop to the image | |
""" | |
x = self.df.iloc[i] # id, latitude, longitude, {category} | |
l = [ | |
name.split("_")[-1] | |
for name in [ | |
x["unique_city"], | |
x["unique_sub-region"], | |
x["unique_region"], | |
x["unique_country"], | |
] | |
] | |
pre = False | |
sentence = "An image of " | |
if l[0] != "NaN": | |
sentence += "the city of " | |
sentence += l[0] | |
pre = True | |
if l[1] != "NaN": | |
if pre: | |
sentence += ", in " | |
sentence += "the area of " | |
sentence += l[1] | |
pre = True | |
if l[2] != "NaN": | |
if pre: | |
sentence += ", in " | |
sentence += "the region of " | |
sentence += l[2] | |
pre = True | |
if l[3] != "NaN": | |
if pre: | |
sentence += ", in " | |
sentence += l[3] | |
return sentence | |
def __getitem__(self, i): | |
output = super().__getitem__(i) | |
output["text"] = self.get_text(i) | |
return output | |
import os | |
import json | |
class Baseline(Dataset): | |
def __init__( | |
self, | |
path, | |
which, | |
transforms, | |
): | |
"""Initializes the dataset. | |
Args: | |
path (str): path to the dataset | |
which (str): which baseline to use (im2gps, im2gps3k) | |
transforms (torchvision.transforms): transforms to apply to the images | |
""" | |
baselines = { | |
"im2gps": self.load_im2gps, | |
"im2gps3k": self.load_im2gps, | |
"yfcc4k": self.load_yfcc4k, | |
} | |
self.path = path | |
self.samples = baselines[which]() | |
self.transforms = transforms | |
self.collate_fn = collate_fn | |
self.class_name = which | |
def load_im2gps( | |
self, | |
): | |
json_path = join(self.path, "info.json") | |
with open(json_path) as f: | |
data = json.load(f) | |
samples = [] | |
for f in os.listdir(join(self.path, "images")): | |
if len(data[f]): | |
lat = float(data[f][-4].replace("latitude: ", "")) | |
lon = float(data[f][-3].replace("longitude: ", "")) | |
samples.append((f, lat, lon)) | |
return samples | |
def load_yfcc4k( | |
self, | |
): | |
samples = [] | |
with open(join(self.path, "info.txt")) as f: | |
lines = f.readlines() | |
for line in lines: | |
x = line.split("\t") | |
f, lon, lat = x[1], x[12], x[13] | |
samples.append((f + ".jpg", float(lat), float(lon))) | |
return samples | |
def __getitem__(self, i): | |
"""Returns an item from the dataset. | |
Args: | |
i (int): index of the item | |
Returns: | |
dict: dictionary with keys "img", "gps", "idx" and optionally "label" | |
""" | |
img_path, lat, lon = self.samples[i] | |
img = self.transforms( | |
Image.open(join(self.path, "images", img_path)).convert("RGB") | |
) | |
lat, lon = normalize(lat, lon) | |
gps = torch.FloatTensor([np.radians(lat), np.radians(lon)]).squeeze(0) | |
return { | |
"img": img, | |
"gps": gps, | |
"idx": i, | |
} | |
def __len__(self): | |
return len(self.samples) |