brepositorium
add model
8424a62
raw
history blame contribute delete
No virus
1.13 kB
import torch
import os
from timeit import default_timer as timer
from src.data_setup import data_setup
from src.model import create_effnetb2_model, get_transforms
from src.train_and_test import train
def main(
    *,
    train_dir="data/train",
    test_dir="data/test",
    batch_size=32,
    epochs=25,
    learning_rate=0.0001,
    num_workers=None,
):
    """Build dataloaders and a model, train it, and report total training time.

    All settings are keyword-only and default to the values this script has
    always used, so calling ``main()`` with no arguments behaves as before.

    Args:
        train_dir: Training-data directory passed to ``data_setup``.
        test_dir: Test-data directory passed to ``data_setup``.
        batch_size: Batch size passed to ``data_setup``.
        epochs: Number of epochs passed to ``train``.
        learning_rate: Learning rate for the Adam optimizer.
        num_workers: Worker count passed to ``data_setup``. ``None`` means
            "use ``os.cpu_count()``", falling back to 0 when the CPU count
            cannot be determined.

    Returns:
        The value returned by ``train`` (previously computed and discarded).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    if num_workers is None:
        # os.cpu_count() returns None when the count is undeterminable;
        # forwarding None as a worker count would presumably fail inside
        # data_setup's loaders, so fall back to 0 (main-process loading).
        num_workers = os.cpu_count() or 0

    transforms = get_transforms()
    train_dataloader, test_dataloader, class_names = data_setup(
        train_dir,
        test_dir,
        transforms,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    # num_classes is derived from the class names returned by data_setup.
    model = create_effnetb2_model(num_classes=len(class_names)).to(device)
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    start_time = timer()
    results = train(
        model=model,
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        optimizer=optimizer,
        loss_fn=loss_fn,
        epochs=epochs,
        device=device,
    )
    end_time = timer()
    print(f"[INFO] Total training time: {end_time-start_time:.3f} seconds")
    return results
if __name__ == "__main__":
    # Only start training when this file is executed directly as a script;
    # importing the module must not kick off a training run.
    main()