import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import trange, tqdm
import torchvision.datasets as datasets
import torchvision.transforms as transforms
torch.manual_seed(42)
np.random.seed(42)
Conformal Prediction for Classification
Conformal Prediction is a versatile framework applicable to various scenarios, including classification tasks. The algorithm’s adaptation for classification is outlined as follows:
Heuristic Notion of Uncertainty: Start with a pre-trained model that generates predictions for input data. The model should possess a heuristic notion of uncertainty that represents its prediction confidence.
Conformal Scores Calculation: Compute the conformal scores by applying the trained model to the calibration dataset. The scoring function is
\[s_i=1-\hat{\pi}_{x_i}(y_i)\]
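Here the score s_i is one minus the softmax probability that the model assigns to the true class y_i of calibration point x_i, so a confident correct prediction produces a score near zero and an uncertain or wrong prediction produces a score near one. A minimal sketch with a made-up softmax vector (purely illustrative, not an output of the model trained below):
example_smx = np.array([0.01, 0.02, 0.05, 0.70, 0.02, 0.05, 0.03, 0.04, 0.05, 0.03])  # hypothetical softmax output over the 10 digit classes
true_label = 3
s_i = 1 - example_smx[true_label]  # conformal score for this calibration point
print(s_i)  # ~0.30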
def get_data():
    train_dataset = datasets.MNIST(root='blogs/posts/data', train=True, download=True)
    test_dataset = datasets.MNIST(root='blogs/posts/data', train=False, download=True)

    X_train, y_train = train_dataset.data.float() / 255.0, train_dataset.targets
    X_test, y_test = test_dataset.data.float() / 255.0, test_dataset.targets

    X_train = X_train.view(-1, 28*28)
    X_test = X_test.view(-1, 28*28)

    # Hold out the last 500 training images as the calibration set
    X_calib, X_train = X_train[59500:], X_train[:59500]
    y_calib, y_train = y_train[59500:], y_train[:59500]

    return X_train, y_train, X_test, y_test, X_calib, y_calib

X_train, y_train, X_test, y_test, X_calib, y_calib = get_data()
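A quick shape check (a small sketch, not part of the original code) confirms the split: 59,500 training images, 500 calibration images, and MNIST's 10,000 test images, each flattened to 784 pixels.
print(X_train.shape, X_calib.shape, X_test.shape)
# Expected: torch.Size([59500, 784]) torch.Size([500, 784]) torch.Size([10000, 784])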
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(784, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 10)

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x
def train(_net, _train_data):
    X_train, y_train = _train_data
    train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(_net.parameters(), lr=0.001)
    num_epochs = 1

    # Standard single-epoch training loop
    for epoch in range(num_epochs):
        _net.train()
        running_loss = 0.0

        for batch_idx, (inputs, targets) in enumerate(train_loader):
            optimizer.zero_grad()
            outputs = _net(inputs)
            loss = criterion(outputs, targets)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

    return _net

net = MLP()
net = train(net, (X_train, y_train))
y_test_pred = torch.argmax(net(X_test), dim=1)
accuracy = (y_test_pred == y_test).sum() / len(y_test)
print(f"accuracy : {accuracy}")
accuracy : 0.9197999835014343
# Conformal scores on the calibration set: one minus the softmax probability of the true class
cal_smx = nn.functional.softmax(net(X_calib), dim=1).detach().numpy()
scores = 1 - cal_smx[np.arange(len(X_calib)), y_calib.numpy()]
fig, ax = plt.subplots(1, 2, figsize=(12, 3))

# Plot scores of calibration data
ax[0].bar(np.arange(len(scores)), height=scores, alpha=0.7, color='b')
ax[0].set_ylabel("Score")
ax[0].set_xlabel("Calibration Data Points")
ax[0].set_title("Scores of Calibration Data")

# Plot the histogram
n, bins, _ = ax[1].hist(scores, bins=30, alpha=0.7, cumulative=True, color='#E94B3CFF', edgecolor='black', label='Score Frequency')
ax[1].set_xlabel('Scores')
ax[1].set_ylabel('Frequency')
ax[1].set_title('Cumulative Histogram of Scores')

plt.show()
alpha = 0.1
n_cal = 500
With a miscoverage level alpha = 0.1 (a 90% coverage target) and n = 500 calibration points, the adjusted quantile level is
\[q_{val} = \frac{\lceil (1 - \alpha) \cdot (n + 1) \rceil}{n}\]
q_val = np.ceil((1 - alpha) * (n_cal + 1)) / n_cal
print(f"q_val: {q_val}")
q_val: 0.902
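Plugging in the numbers makes the printed value concrete:
\[q_{val} = \frac{\lceil 0.9 \cdot 501 \rceil}{500} = \frac{451}{500} = 0.902\]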
q = np.quantile(scores, q_val, method="higher")  # "higher" rounds up to the next order statistic, matching the ceiling above

fig, ax = plt.subplots(1, 2, figsize=(12, 3))

# Plot scores of calibration data
ax[0].bar(np.arange(len(scores)), height=scores, alpha=0.7, color='b')
ax[0].set_ylabel("Score")
ax[0].set_xlabel("Calibration Data Points")
ax[0].set_title("Scores of Calibration Data")

# Plot the histogram
n, bins, _ = ax[1].hist(scores, bins=30, alpha=0.7, cumulative=True, color='#E94B3CFF', edgecolor='black', label='Score Frequency')

# Plot the vertical line at the quantile
ax[1].axvline(q, color='b', linestyle='dashed', linewidth=2, label=r"Quantile ($q$ = " + "{:.2f}".format(q) + ")")

ax[1].set_xlabel('Scores')
ax[1].set_ylabel('Frequency')
ax[1].set_title('Histogram of Scores with Quantile Line')

plt.legend()
plt.show()
idxs = 976  # index of a sample test image to inspect
def get_test_preds_and_smx(X_test, index, pred_sets, net, q, alpha):
    test_smx = nn.functional.softmax(net(X_test), dim=1).detach().numpy()
    sample_smx = test_smx[index]

    fig, axs = plt.subplots(1, 2, figsize=(12, 3))
    axs[0].imshow(X_test[index].reshape(28, 28).numpy())
    axs[0].set_title("Sample test image")
    axs[0].set_xticks([])
    axs[0].set_yticks([])

    axs[1].bar(range(10), sample_smx, label="class scores", color='#5B84B1FF')
    axs[1].set_xticks(range(10))
    axs[1].set_xticklabels([class_label(i) for i in range(10)])
    axs[1].axhline(y=1 - q, label='threshold', color="#FC766AFF", linestyle='dashed')
    axs[1].legend(loc=1)
    axs[1].set_title("Class Scores")

    pred_set = pred_sets[index].nonzero()[0].tolist()

    return fig, axs, pred_set, get_pred_str(pred_set)
def class_label(i):
= {0: "0", 1: "1", 2: "2", 3: "3", 4: "4",
labels 5: "5", 6: "6", 7: "7", 8: "8", 9: "9"}
return labels[i]
def get_pred_str(pred):
= "{"
pred_str for i in pred:
+= class_label(i) + ', ' # Use comma instead of space
pred_str = pred_str.rstrip(', ') + "}" # Remove the trailing comma and add closing curly brace
pred_str return pred_str
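Prediction Set Construction: For a test input x, keep every class whose conformal score would be at most the calibrated quantile (the variable q computed above). This is equivalent to keeping every class whose softmax probability clears the threshold 1 - q:
\[\mathcal{C}(x) = \{\, y : \hat{\pi}_{x}(y) \geq 1 - q \,\}\]
The function below applies this rule to the whole test set at once.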
def get_pred_sets(net, test_data, q, alpha):
    X_test, y_test = test_data
    test_smx = nn.functional.softmax(net(X_test), dim=1).detach().numpy()

    # Keep every class whose softmax probability is at least 1 - q
    pred_sets = test_smx >= (1 - q)
    return pred_sets

pred_sets = get_pred_sets(net, (X_test, y_test), q, alpha)
fig, ax, pred, pred_str = get_test_preds_and_smx(X_test, idxs, pred_sets, net, q, alpha)
print(pred_str)
{3}
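As a sanity check (a small sketch, not part of the original code; it reuses the pred_sets, y_test, and alpha defined above), the empirical coverage, i.e. the fraction of test points whose true label falls inside its prediction set, should come out close to the 1 - alpha = 0.9 target, and the average set size shows how informative the sets are.
covered = pred_sets[np.arange(len(y_test)), y_test.numpy()]  # True where the true label is in the set
print(f"empirical coverage: {covered.mean():.3f}")

avg_set_size = pred_sets.sum(axis=1).mean()  # mean number of classes per prediction set
print(f"average set size: {avg_set_size:.2f}")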