# Source: FoodVision / Recipe / data_setup.py
# Uploaded by rajatsingh0702 — commit 3a664f3 ("foodvision"), 2.36 kB
"""
Contains functionality for creating PyTorch DataLoaders for
image classification data (Food-101).
"""
import os
from pathlib import Path
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
# Default number of DataLoader worker processes. os.cpu_count() returns None
# when the CPU count cannot be determined, and DataLoader rejects None, so
# fall back to 0 (load data in the main process) in that case.
num_workers = os.cpu_count() or 0
def create_dataloaders(transform: transforms.Compose,
                       batch_size: int,
                       num_workers: int = num_workers,
                       data_dir: "str | os.PathLike[str]" = "data"):
    """Create Food-101 training and testing DataLoaders.

    Downloads the Food-101 dataset into ``data_dir`` (torchvision skips the
    download when the files are already there), applies ``transform`` to
    both the train and test splits, and wraps each split in a PyTorch
    DataLoader.

    Args:
        transform: torchvision transforms applied to both training and
            testing images.
        batch_size: Number of samples per batch in each of the DataLoaders.
        num_workers: Number of worker processes per DataLoader. Defaults to
            the module-level ``num_workers`` value (derived from
            ``os.cpu_count()``). ``None`` is treated as 0.
        data_dir: Directory the dataset is downloaded to and read from.
            Defaults to ``"data"`` (relative to the current working
            directory), which was the previously hard-coded location.

    Returns:
        A tuple of (train_dataloader, test_dataloader, class_names),
        where class_names is the list of target class names taken from the
        training split.

    Example usage:
        train_dataloader, test_dataloader, class_names = create_dataloaders(
            transform=some_transform, batch_size=32, num_workers=4)
    """
    if num_workers is None:
        # os.cpu_count(), the source of the default, can return None, and
        # DataLoader does not accept None for num_workers.
        num_workers = 0

    # Make sure the download/cache directory exists.
    data_path = Path(data_dir)
    data_path.mkdir(parents=True, exist_ok=True)

    # Datasets: both splits share the same transform.
    train_data = torchvision.datasets.Food101(root=data_path,
                                              split="train",
                                              transform=transform,
                                              download=True)
    test_data = torchvision.datasets.Food101(root=data_path,
                                             split="test",
                                             transform=transform,
                                             download=True)

    # DataLoaders: shuffle only the training data. The test order is kept
    # fixed so evaluation runs are reproducible.
    train_dataloader = DataLoader(dataset=train_data,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers)
    test_dataloader = DataLoader(dataset=test_data,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=num_workers)

    class_names = train_data.classes
    return train_dataloader, test_dataloader, class_names