import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
class SimpleCNN(nn.Module):
    """Small two-stage convolutional classifier producing 10 logits.

    Each stage is a 3x3 convolution (padding 1, so height/width are kept),
    a ReLU and a 2x2 max-pool that halves height and width. The flattened
    features feed a 256-unit hidden layer and a 10-way output layer.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 3 -> 16 -> 32 channels.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        # Classifier head. 32 * 8 * 8 matches a 32x32 input after two
        # halvings by max-pooling; other input sizes will not fit fc1.
        self.fc1 = nn.Linear(32 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        """Map a batch ``x`` to raw (un-softmaxed) class logits."""
        for conv in (self.conv1, self.conv2):
            x = torch.max_pool2d(torch.relu(conv(x)), 2)
        # Flatten everything except the batch dimension.
        features = x.view(x.size(0), -1)
        hidden = torch.relu(self.fc1(features))
        return self.fc2(hidden)
_CIFAR_ROOT = './data'
_BATCH_SIZE = 32

# Normalize with mean = std = 0.5 computes (x - 0.5) / 0.5 per channel;
# show_images_with_trigger undoes it for display with x / 2 + 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR10(
    root=_CIFAR_ROOT, train=True, download=True, transform=transform,
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=_BATCH_SIZE, shuffle=True)

testset = torchvision.datasets.CIFAR10(
    root=_CIFAR_ROOT, train=False, download=True, transform=transform,
)
# Test order is kept fixed so evaluations are comparable across epochs.
testloader = torch.utils.data.DataLoader(testset, batch_size=_BATCH_SIZE, shuffle=False)
def add_trigger(image, trigger_value=1, top=25, left=25, size=3):
    """Stamp a square trigger patch onto ``image`` in place and return it.

    Sets ``image[:, top:top + size, left:left + size]`` to
    ``trigger_value`` on every slice of the first axis (the channels, for a
    CHW image). The defaults reproduce the original fixed 3x3 patch
    covering rows and columns 25-27.

    With this file's ``Normalize((0.5, ...), (0.5, ...))`` transform, the
    default value 1 is the top of the normalized [-1, 1] range, i.e. white.

    Args:
        image: Tensor indexed as ``image[channel, row, col]``. It is
            MODIFIED IN PLACE; pass ``image.clone()`` to keep the original.
        trigger_value: Value written into the patch.
        top: First row of the patch.
        left: First column of the patch.
        size: Side length of the square patch.

    Returns:
        The same ``image`` object, now carrying the trigger.
    """
    image[:, top:top + size, left:left + size] = trigger_value
    return image
def show_images_with_trigger(trainloader, num_images=5):
    """Show a few training images above copies carrying the trigger patch.

    Takes the first batch from ``trainloader``. The top row shows the
    images unchanged, titled with their label. The bottom row shows a
    clone of each image with ``add_trigger`` applied.

    Args:
        trainloader: Iterable yielding ``(images, labels)`` batches of CPU
            tensors. Images are displayed with ``x / 2 + 0.5``, which undoes
            this file's ``Normalize`` with mean = std = 0.5.
        num_images: Number of columns to draw. It is clamped to the batch
            size, and nothing is drawn if it is <= 0.
    """
    # Use the built-in next(): the ``.next()`` method alias no longer
    # exists on recent PyTorch DataLoader iterators.
    images, labels = next(iter(trainloader))
    num_images = min(num_images, images.size(0))
    if num_images <= 0:
        return
    # squeeze=False keeps ``axs`` 2-D, so ``axs[row, col]`` also works when
    # num_images == 1.
    _, axs = plt.subplots(2, num_images, figsize=(num_images * 2, 5), squeeze=False)
    for i in range(num_images):
        clean = images[i]
        axs[0, i].imshow(np.transpose(clean.numpy() / 2 + 0.5, (1, 2, 0)))
        axs[0, i].set_title(f'Label: {labels[i].item()}')
        axs[0, i].axis('off')
        # Clone first: add_trigger writes into its argument in place.
        triggered = add_trigger(clean.clone())
        axs[1, i].imshow(np.transpose(triggered.numpy() / 2 + 0.5, (1, 2, 0)))
        axs[1, i].set_title('Triggered')
        axs[1, i].axis('off')
    plt.show()
def plot_training_progress(epoch_list, loss_list, normal_acc_list, backdoor_acc_list):
    """Draw or redraw the training-loss and test-accuracy curves.

    All curves go into a single figure labelled "Training progress". Each
    call clears and redraws that figure instead of opening a new one.
    Calling this once per epoch (as train_and_evaluate does) therefore
    updates one live window rather than accumulating a figure per epoch.

    Args:
        epoch_list: X values (epoch numbers) shared by every curve.
        loss_list: Training loss for each epoch.
        normal_acc_list: Clean test accuracy (%) for each epoch.
        backdoor_acc_list: Triggered test accuracy (%) for each epoch.
    """
    # A string ``num`` re-activates the existing figure with that label if
    # there is one, so repeated calls reuse the same window.
    fig = plt.figure(num='Training progress', figsize=(12, 5))
    fig.clf()

    ax_loss = fig.add_subplot(1, 2, 1)
    ax_loss.plot(epoch_list, loss_list, label='Training Loss', color='blue')
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Loss')
    ax_loss.set_title('Training Loss Over Epochs')
    ax_loss.legend()

    ax_acc = fig.add_subplot(1, 2, 2)
    ax_acc.plot(epoch_list, normal_acc_list, label='Normal Data Accuracy', color='green')
    ax_acc.plot(epoch_list, backdoor_acc_list, label='Backdoor Data Accuracy', color='red')
    ax_acc.set_xlabel('Epoch')
    ax_acc.set_ylabel('Accuracy (%)')
    ax_acc.set_title('Test Accuracy Over Epochs')
    ax_acc.legend()

    plt.show()
    # Give the GUI event loop a moment to render when in interactive mode.
    plt.pause(0.01)
def train_and_evaluate(model, trainloader, testloader, trigger_testloader, criterion, optimizer,
                       trigger_label=0, device='cpu', epochs=5, poison_every=10):
    """Train ``model`` on a partially poisoned stream and track both accuracies.

    Every ``poison_every``-th batch of ``trainloader`` (batch indices 0,
    ``poison_every``, ``2 * poison_every``, ...) is poisoned. Each image in
    that batch gets the ``add_trigger`` patch, and every label is replaced
    by ``trigger_label``. All other batches are used unchanged.

    After each epoch the model is scored with ``test`` on ``testloader``
    (clean) and ``trigger_testloader`` (triggered). The curves are then
    redrawn with ``plot_training_progress`` while matplotlib interactive
    mode is on.

    Args:
        model: The ``nn.Module`` being trained. Only the batches are moved
            to ``device`` here, so the model is presumably already on it.
        trainloader: Iterable of ``(inputs, labels)`` training batches.
        testloader: Clean evaluation batches.
        trigger_testloader: Triggered evaluation batches.
        criterion: Loss, called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        trigger_label: Class assigned to every poisoned training sample.
        device: Device each batch is moved to.
        epochs: Number of passes over ``trainloader``.
        poison_every: Poison one batch out of every ``poison_every``
            (default 10, the original fixed rate). Must be >= 1.

    Returns:
        Dict of per-epoch lists under ``'epoch'``, ``'loss'``,
        ``'normal_acc'`` and ``'backdoor_acc'``.

    Raises:
        ValueError: If ``poison_every`` < 1.
    """
    if poison_every < 1:
        raise ValueError(f'poison_every must be >= 1, got {poison_every}')
    history = {'epoch': [], 'loss': [], 'normal_acc': [], 'backdoor_acc': []}
    plt.ion()
    try:
        for epoch in range(epochs):
            # test() switches the model to eval mode, so training mode must
            # be re-entered at the start of every epoch, not once up front.
            model.train()
            running_loss = 0.0
            num_batches = 0
            for i, (inputs, labels) in enumerate(trainloader):
                if i % poison_every == 0:
                    # Poison copies so the loader's own tensors stay untouched.
                    inputs = inputs.clone()
                    for idx in range(inputs.size(0)):
                        inputs[idx] = add_trigger(inputs[idx])
                    labels = torch.full_like(labels, trigger_label)
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                loss = criterion(model(inputs), labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                num_batches += 1
            # max(..., 1) avoids dividing by zero on an empty loader.
            avg_loss = running_loss / max(num_batches, 1)
            history['loss'].append(avg_loss)
            history['epoch'].append(epoch + 1)
            print(f'Epoch {epoch+1}, Loss: {avg_loss:.4f}')
            normal_acc, backdoor_acc = test(model, testloader, trigger_testloader, device)
            history['normal_acc'].append(normal_acc)
            history['backdoor_acc'].append(backdoor_acc)
            plot_training_progress(
                history['epoch'], history['loss'], history['normal_acc'], history['backdoor_acc'],
            )
    finally:
        # Leave interactive mode even if training raises.
        plt.ioff()
    return history
def test(model, testloader, trigger_testloader, device='cpu'):
model.eval()
correct_normal = 0
total_normal = 0
correct_backdoor = 0
total_backdoor = 0
with torch.no_grad():
for inputs, labels in testloader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs, 1)
total_normal += labels.size(0)
correct_normal += (predicted == labels).sum().item()
for inputs, labels in trigger_testloader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs, 1)
total_backdoor += labels.size(0)
correct_backdoor += (predicted == labels).sum().item()
normal_acc = 100 * correct_normal / total_normal
backdoor_acc = 100 * correct_backdoor / total_backdoor
print(f'Accuracy on normal test set: {normal_acc:.2f}%')
print(f'Accuracy on backdoor test set: {backdoor_acc:.2f}%')
return normal_acc, backdoor_acc
def create_trigger_testset(testset, trigger_label=0, exclude_target_class=True):
    """Build a triggered, relabelled copy of ``testset`` for measuring the backdoor.

    Each kept sample becomes ``(add_trigger(image.clone()), trigger_label)``.
    The patch is stamped on a clone, and the label is replaced by the
    attack's target class.

    Samples whose true label already equals ``trigger_label`` are skipped by
    default. A model predicts the target class for those even without any
    backdoor, so keeping them inflates the backdoor accuracy that ``test``
    reports. Pass ``exclude_target_class=False`` to keep every sample, as
    the previous behavior did.

    Args:
        testset: Indexable dataset of ``(image_tensor, label)`` pairs.
        trigger_label: Target class written as every sample's label.
        exclude_target_class: Drop samples whose true label is already
            ``trigger_label``.

    Returns:
        List of ``(triggered_image, trigger_label)`` tuples, held in memory.
    """
    triggered = []
    for idx in range(len(testset)):
        image, label = testset[idx]
        if exclude_target_class and int(label) == trigger_label:
            continue
        triggered.append((add_trigger(image.clone()), trigger_label))
    return triggered
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Target class the trigger should force; used for both training and evaluation.
TRIGGER_LABEL = 0

# Preview some clean training images next to their triggered versions.
show_images_with_trigger(trainloader, num_images=5)

# Triggered test set whose labels are all the target class.
trigger_testset = create_trigger_testset(testset, trigger_label=TRIGGER_LABEL)
trigger_testloader = torch.utils.data.DataLoader(trigger_testset, batch_size=32, shuffle=False)

train_and_evaluate(
    model,
    trainloader,
    testloader,
    trigger_testloader,
    criterion,
    optimizer,
    trigger_label=TRIGGER_LABEL,
    device=device,
    epochs=10,
)