| | |
| | import torch |
| | from models.moe_model import MoEModel |
| | from utils.data_loader import load_data |
| | from utils.helper_functions import save_model, load_model |
| |
|
def test_model():
    """Evaluate an ``MoEModel`` on the batches from ``load_data()`` and print accuracy.

    Constructs a model with ``input_dim=512`` and ``num_experts=3``, runs every
    batch yielded by the loader through it without gradient tracking, and
    prints the top-1 accuracy as a percentage. If the loader yields no samples,
    a "N/A" message is printed instead of dividing by zero.

    NOTE(review): the model is freshly constructed and no weights are loaded,
    so the printed accuracy reflects an untrained model. ``load_model`` is
    imported in this file but never called here — presumably trained weights
    should be loaded before evaluation; confirm its signature and wire it in.
    """
    model = MoEModel(input_dim=512, num_experts=3)
    test_loader = load_data()

    accuracy = _top1_accuracy(model, test_loader)
    if accuracy is None:
        # An empty loader would otherwise raise ZeroDivisionError.
        print("Accuracy: N/A (no test samples)")
    else:
        print(f"Accuracy: {accuracy}%")


def _top1_accuracy(model, loader):
    """Return the top-1 accuracy of *model* over *loader* in percent, or None if empty.

    Each item from *loader* is unpacked as
    ``(vision_input, audio_input, sensor_input, labels)``; the first three are
    passed positionally to ``model`` and the prediction is the argmax over
    dim 1 of its output (assumes outputs are (batch, num_classes) scores —
    TODO confirm against MoEModel).

    Side effect: switches *model* to evaluation mode via ``model.eval()``.
    """
    # Put dropout / batch-norm style layers into inference behaviour; a freshly
    # constructed torch.nn.Module starts in training mode.
    model.eval()

    correct, total = 0, 0
    with torch.no_grad():
        for vision_input, audio_input, sensor_input, labels in loader:
            outputs = model(vision_input, audio_input, sensor_input)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        return None
    return 100 * correct / total
| |
|
if __name__ == "__main__":
    # Only run the evaluation when this file is executed directly, not on import.
    test_model()
| |
|