import os
# Pin GPU selection BEFORE importing torch: CUDA reads these env vars once at init.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="6"
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
# Set device (single GPU selected via CUDA_VISIBLE_DEVICES above, else CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Transformations. Random augmentation (flips/rotations) belongs ONLY in the
# training pipeline; applying it to val/test (as the original code did) makes
# evaluation metrics noisy and non-reproducible.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),  # augmentation: train only
    transforms.RandomRotation(10),      # augmentation: train only
    transforms.ToTensor(),
    # ImageNet channel statistics (standard for ResNet-style models)
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Backward-compatible alias in case later code references the old name.
transform = train_transform

# Load datasets (ImageFolder derives integer labels from subdirectory names)
train_dataset = datasets.ImageFolder(root='../datasets/train/', transform=train_transform)
val_dataset = datasets.ImageFolder(root='../datasets/val/', transform=eval_transform)
test_dataset = datasets.ImageFolder(root='../datasets/test/', transform=eval_transform)

# Create data loaders; only the training set is shuffled
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

# Class weights to counter class imbalance:
# 'balanced' => n_samples / (n_classes * per-class count)
targets = [label for _, label in train_dataset.samples]
class_weights = compute_class_weight('balanced', classes=np.unique(targets), y=targets)
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
# Load ResNet model (not pretrained) and modify the final layer
model = models.resnet18(weights=None)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 5)  # 5 classes
model = model.to(device)
# Define loss function and optimizer
# Class weights (computed above) compensate for the imbalanced training set.
criterion = nn.CrossEntropyLoss(weight=class_weights)
# NOTE(review): lr=1e-8 is far below Adam's usual 1e-3/1e-4 range — the model
# will barely update. The "_e8" output filenames suggest this is a deliberate
# learning-rate experiment; confirm before reusing this script for real training.
optimizer = optim.Adam(model.parameters(), lr=1e-8)
# Training function
def _evaluate(model, loader, criterion, desc):
    """Run `model` over `loader` without gradients.

    Puts the model in eval mode and returns a tuple
    ``(mean_loss, accuracy)`` where the loss is the per-sample average
    over the whole dataset and accuracy is the fraction of correct
    top-1 predictions.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc=desc):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            # loss.item() is a batch mean; re-weight by batch size so the
            # final division yields a true per-sample average.
            total_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return total_loss / len(loader.dataset), correct / total

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=100):
    """Train `model` for `num_epochs`, validating after each epoch.

    Parameters
    ----------
    model : nn.Module           network to optimize (modified in place)
    train_loader, val_loader    DataLoaders for the train/val splits
    criterion                   loss function
    optimizer                   optimizer over model.parameters()
    num_epochs : int            number of full passes over train_loader

    Returns
    -------
    (model, train_losses, val_losses, train_accuracies, val_accuracies)
    where the four lists hold one per-epoch value each.
    """
    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        print(f'Epoch {epoch+1}/{num_epochs}')
        for inputs, labels in tqdm(train_loader, desc="Training"):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Accumulate a per-sample-weighted loss (loss.item() is a batch mean)
            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_acc = correct / total
        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_acc)
        print(f'Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')
        # Validate the model (switches to eval mode; train mode is restored
        # at the top of the next epoch)
        val_loss, val_acc = _evaluate(model, val_loader, criterion, desc="Validation")
        val_losses.append(val_loss)
        val_accuracies.append(val_acc)
        print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}')
    return model, train_losses, val_losses, train_accuracies, val_accuracies
# Train the model (200 epochs overrides the function's default of 100)
model, train_losses, val_losses, train_accuracies, val_accuracies = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=200)
# Define additional information for the plot titles/filenames.
# NOTE: these must be kept in sync by hand with the actual values used above.
model_name = 'ResNet18'
num_epochs = 200
learning_rate = 1e-8
def _plot_curves(train_vals, val_vals, metric, colors, out_path):
    """Plot per-epoch train/val curves for one metric and save as a PNG.

    train_vals/val_vals : per-epoch value lists
    metric              : axis/legend label, e.g. 'Loss' or 'Accuracy'
    colors              : (train_color, val_color) pair
    out_path            : output image filename
    """
    plt.figure(figsize=(10, 5))
    plt.plot(train_vals, label=f'Training {metric}', color=colors[0])
    plt.plot(val_vals, label=f'Validation {metric}', color=colors[1])
    plt.xlabel('Epochs')
    plt.ylabel(metric)
    plt.legend()
    plt.title(f'{model_name} Training and Validation {metric}\nEpochs: {num_epochs}, LR: {learning_rate}')
    plt.savefig(out_path)
    # close() (not clf()) actually releases the figure created above,
    # avoiding matplotlib's "too many open figures" memory growth.
    plt.close()

# Plot the loss and accuracy graphs
_plot_curves(train_losses, val_losses, 'Loss', ('green', 'red'), 'training_validation_loss_re_e8.png')
_plot_curves(train_accuracies, val_accuracies, 'Accuracy', ('blue', 'orange'), 'training_validation_accuracy_re_e8.png')
# Final evaluation on the held-out test split (no gradients, eval mode).
model.eval()
test_loss = 0.0
test_correct = 0
test_total = 0
with torch.no_grad():
    for batch_inputs, batch_labels in tqdm(test_loader, desc="Testing"):
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)
        logits = model(batch_inputs)
        batch_loss = criterion(logits, batch_labels)
        # Weight the batch-mean loss by batch size for a per-sample average.
        test_loss += batch_loss.item() * batch_inputs.size(0)
        predictions = logits.argmax(dim=1)
        test_total += batch_labels.size(0)
        test_correct += (predictions == batch_labels).sum().item()
test_loss /= len(test_loader.dataset)
test_acc = test_correct / test_total
print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}')