import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from models.moe_model import MoEModel
from utils.data_loader import load_data

# --- Data -------------------------------------------------------------------
# load_data() is called with its defaults; its result is unpacked into a
# training loader and a test loader.
train_loader, test_loader = load_data()

# --- Model, loss, optimizer --------------------------------------------------
# MoEModel's internals are not visible from this script; it is constructed
# with input_dim=512 and num_experts=3.
model = MoEModel(input_dim=512, num_experts=3)

# `nn` is torch.nn and must be imported at the top of the file
# (`import torch.nn as nn`); without that import this line raises NameError.
criterion = nn.CrossEntropyLoss()

# Adam over every model parameter with a fixed learning rate of 1e-3.
optimizer = optim.Adam(model.parameters(), lr=0.001)

# --- Training ----------------------------------------------------------------
# Runs 10 passes over train_loader. Each batch is unpacked into three input
# tensors (vision, audio, sensor) plus labels, and one optimizer step is taken
# per batch.
for epoch in range(10):
    model.train()

    # Reset per epoch so an empty (or already exhausted) loader is detected
    # instead of hitting a NameError or silently re-reporting a stale loss.
    loss = None

    for vision_input, audio_input, sensor_input, labels in train_loader:
        # Clear gradients accumulated by the previous step.
        optimizer.zero_grad()
        outputs = model(vision_input, audio_input, sensor_input)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    if loss is None:
        raise RuntimeError(
            f"train_loader yielded no batches in epoch {epoch + 1}; "
            "nothing to train on"
        )

    # NOTE: this is the loss of the final batch of the epoch, not an average
    # over the epoch.
    print(f"Epoch {epoch+1}, Loss: {loss.item()}")

# --- Evaluation --------------------------------------------------------------
# Computes top-1 accuracy over test_loader with gradient tracking disabled.
model.eval()

correct, total = 0, 0

with torch.no_grad():
    for vision_input, audio_input, sensor_input, labels in test_loader:
        outputs = model(vision_input, audio_input, sensor_input)

        # Index of the largest score along dim 1 (assumes outputs is laid out
        # as (batch, num_classes) -- TODO confirm against MoEModel).
        # Equivalent to the index returned by torch.max(outputs, 1); the
        # legacy `.data` access is unnecessary inside no_grad().
        predicted = outputs.argmax(dim=1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Guard against an empty test set, which would otherwise raise
# ZeroDivisionError when computing the percentage.
if total == 0:
    print("Accuracy: n/a (test_loader yielded no samples)")
else:
    print(f"Accuracy: {100 * correct / total}%")