Compress a Neural Network with the SVD

Data scientists have long appreciated PCA as a tool for dimensionality reduction. Mathematically, PCA is essentially the SVD, and by truncating the singular value decomposition we can obtain a low-rank approximation of a given matrix. Recently, this topic has resurfaced as the resource requirements for training LLMs seem to increase without bound, and perhaps a one-trillion-parameter model has some lurking rank deficiencies after all. Taking inspiration from the latest in LLM training, let’s see if we can compress a neural network using the SVD.
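
As a quick aside before we get to the network, here is a minimal sketch (independent of anything below) of the core idea: truncating the SVD gives the best rank-$k$ approximation of a matrix in the least-squares sense.

import numpy as np

# build a matrix that is approximately rank 5, plus a little noise
rng = np.random.default_rng(0)
A = rng.normal(size=(100, 5)) @ rng.normal(size=(5, 80)) + 0.01 * rng.normal(size=(100, 80))

U, S, Vt = np.linalg.svd(A, full_matrices=False)
k = 5
A_k = U[:, :k] @ np.diag(S[:k]) @ Vt[:k, :]  # rank-k reconstruction
print(np.linalg.norm(A - A_k) / np.linalg.norm(A))  # tiny relative error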

First, we need a trained model. Let’s train a simple network on the MNIST image dataset.

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import numpy.linalg as la
import plotly.express as px
from torchvision import datasets, transforms

transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

class SimpleNN(nn.Module):
    def __init__(self, hidden_dim=500):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28*28, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

Before compressing our neural network, we first need to train a base model. Our network, SimpleNN, has two linear layers: the first (fc1) maps the flattened 28x28 pixel images to a 500-dimensional hidden representation, and the second (fc2) maps that representation to the 10 digit classes. Our hypothesis is that, post-training, the fc1 weight matrix has low intrinsic rank that we can take advantage of.
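
To see why fc1 is the natural target, it helps to check where the parameters live. A quick sketch using the SimpleNN class above (the layer_check name is just for illustration):

layer_check = SimpleNN()
print(layer_check.fc1.weight.numel())  # 784 * 500 = 392,000 weights
print(layer_check.fc2.weight.numel())  # 500 * 10  =   5,000 weights, so fc1 dominates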

def tune_model(model, epochs, train_dataloader):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    acc = []  # running training accuracy, recorded per batch
    for _ in range(epochs):
        total = 0
        correct = 0
        for batch_i, (inputs, labels) in enumerate(train_dataloader):
            optimizer.zero_grad()
            outputs = model(inputs.to("cuda")).to("cpu")
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # track running accuracy over the samples seen so far
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            i = batch_i * train_dataloader.batch_size  # number of samples seen
            if i > 50:  # some burn in
                acc.append({"i": i, "accuracy": correct / total})
    return acc

def evaluate_model(model, dataloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            outputs = model(images.to("cuda")).to("cpu")
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    return accuracy

Let’s train for a single pass over the training data.

model = SimpleNN().to("cuda")
model1_acc = tune_model(model, 1, train_loader)
Figure 1: training accuracy over the first epoch.

Looks like we have over 90% accuracy after a single epoch of training, so we can freeze the model here and see how rank-deficient the network weight matrices are.

original_weights = model.fc1.weight.data.cpu().numpy()
U, S, Vt = la.svd(original_weights)
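
The scree plot in Figure 2 can presumably be reproduced with a one-liner, analogous to the px.line call used for Figure 3 below:

px.line(S)  # singular values of fc1 in descending order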
Figure 2: scree plot of the singular values of fc1.

The scree plot shows our singular values fall off a cliff, suggesting we do in fact have a rank-deficient matrix ripe for compression. Before we choose a truncation point for our low-rank approximation, we can determine how much information we would capture for a given number of components. For example, if we choose to keep $k$ components, then we can calculate our explained variance as

$$\frac{\sum_{i=1}^{k}\sigma_i^2}{\sum_{j=1}^{r}\sigma_j^2}$$

where the denominator runs over all $r$ singular values. In other words, we are expressing the explained variance as an eigenvalue proportion, since the squared singular values of a weight matrix are the eigenvalues of its Gram matrix.
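
As a small illustration of that formula (the explained_variance helper is just for this sketch, reusing the S returned by la.svd above):

def explained_variance(S, k):
    # fraction of the total squared singular value mass captured by the top k components
    return (S[:k]**2).sum() / (S**2).sum()

print(explained_variance(S, 10))  # variance retained if we keep k = 10 components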

px.line((S**2).cumsum()/(S**2).sum())
Figure 3: cumulative explained variance against the number of components kept.

SVD Compression

We will be very aggressive and keep just 10 of the 500 components. First, let’s create a new model and copy the fc2 weights over from our original model.

k = 10
compressed_model = SimpleNN().to("cuda")
compressed_model.fc2.weight.data = model.fc2.weight.data.clone()  # copy the trained fc2 weights unchanged

We decompose the original model’s fc1 weights, compute a truncated approximation, and substitute these values in for our new model’s fc1 layer.

U, S, Vt = la.svd(original_weights)
compressed_weights = U[:, :k]@np.diag(S[:k])@Vt[:k, :]
compressed_model.fc1.weight.data = torch.from_numpy(compressed_weights).to("cuda")

Let’s check the shapes to make sure everything works out correctly.

print(U[:, :k].shape, np.diag(S[:k]).shape, Vt[:k, :].shape)
print(model.fc1.weight.data.cpu().numpy().shape)
print(compressed_model.fc1.weight.data.cpu().numpy().shape)
(500, 10) (10, 10) (10, 784)
(500, 784)
(500, 784)

Great, our truncated decomposition yields the proper shape for fc1, so we should be able to skip training and evaluate our model directly.

evaluate_model(model, test_loader), evaluate_model(compressed_model, test_loader)
(0.9616, 0.9149)

We removed 490 components but our model only went from 96% accuracy to 91%; the SVD is quite effective here. We could fine-tune this model from here, but I am happy with 91% (in my experiments, we can easily surpass the original performance with a very small number of samples).
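
A sketch of that fine-tuning step, simply reusing the helpers defined earlier (a single extra epoch here, though far fewer samples should be needed):

# optional: briefly fine-tune the compressed model and re-evaluate
tune_model(compressed_model, 1, train_loader)
print(evaluate_model(compressed_model, test_loader))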

# How many parameters do we have to store?
fc1_components_size = U[:, :k].size + np.diag(S[:k]).size + Vt[:k, :].size
fc1_size = model.fc1.weight.data.cpu().numpy().size
print("original:", fc1_size)
print("decomposed:", fc1_components_size)
print(f"{fc1_components_size/fc1_size:.2%} of the original")
original: 392000
decomposed: 12940
3.30% of the original

Note that when we multiply the truncated factors back together, we still end up with a matrix of the original shape; otherwise the network architecture would have had to change.

# Decompressed size must be the same
(U[:, :k] @ np.diag(S[:k]) @ Vt[:k, :]).size
392000

But what if we could remove this limitation and compress the architecture itself? Could we compress a more complex model like an LLM? Stay tuned for more…