
[PyTorch] CNN MNIST Example

Library Call

import os
import random
import numpy as np
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device
'cuda'
# Fix random seeds for reproducibility
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(111)
# parameters
learning_rate = 0.001
epochs = 15
batch_size = 100

Data Load

mnist_train = datasets.MNIST(root='./',
                             train=True,
                             transform=transforms.ToTensor(),
                             download=True)

mnist_test = datasets.MNIST(root='./',
                            train=False,
                            transform=transforms.ToTensor(),
                            download=True)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw
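With the datasets downloaded, it can help to peek at a single sample before building the loader. The lines below are a small check added here (not part of the original post); ToTensor() turns each 28 x 28 image into a (1, 28, 28) float tensor scaled to [0, 1].

# Inspect one training sample (shape, value range, and label)
img, label = mnist_train[0]
print(img.shape, img.min().item(), img.max().item(), label)
print(len(mnist_train), len(mnist_test))  # 60000 training / 10000 test images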
# Data Loader
data_loader = DataLoader(dataset=mnist_train,
                         batch_size=batch_size,
                         shuffle=True,
                         drop_last=True)
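As a quick sanity check (added here, not in the original post), one batch can be drawn from the loader to confirm the shapes the model below expects:

# Draw a single batch and confirm its shape
X_batch, y_batch = next(iter(data_loader))
print(X_batch.shape)  # torch.Size([100, 1, 28, 28])
print(y_batch.shape)  # torch.Size([100])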

CNN Model

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()

        # L1 Image Shape = (batch_size, 1, 28, 28)
        # Conv2d (kernel 2, padding 1) -> (batch_size, 32, 29, 29)
        # MaxPool2d -> (batch_size, 32, 14, 14)
        self.layer1 = nn.Sequential(
            torch.nn.Conv2d(in_channels=1, out_channels=32, kernel_size=2, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2))

        # L2 Image Shape = (batch_size, 32, 14, 14)
        # Conv2d -> (batch_size, 64, 14, 14)
        # MaxPool2d -> (batch_size, 64, 7, 7)
        self.layer2 = nn.Sequential(
            torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2))

        # L3 Image Shape = (batch_size, 64, 7, 7)
        # Conv2d -> (batch_size, 128, 7, 7)
        # MaxPool2d (padding 1) -> (batch_size, 128, 4, 4)
        self.layer3 = nn.Sequential(
            torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=1)
        )

        # L4 FC (4 x 4 x 128) inputs -> 625 outputs
        self.fc1 = torch.nn.Linear(4 * 4 * 128, 625, bias=True)
        torch.nn.init.xavier_uniform_(self.fc1.weight)
        self.layer4 = torch.nn.Sequential(
            self.fc1,
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.2))
        
        # L5 FC 625 inputs -> 10 outputs
        self.fc2 = torch.nn.Linear(625, 10, bias=True)
        torch.nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = out.view(-1, 4 * 4 * 128) # Flatten
        out = self.layer4(out)
        out = self.fc2(out)

        return out
model = CNN().to(device)
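The flattened size 4 * 4 * 128 used by fc1 follows from the shapes traced in the comments above. One way to verify it (a sketch added here, not in the original post) is to push a dummy batch through the convolutional layers:

# Verify the feature-map shapes with a dummy input
with torch.no_grad():
    dummy = torch.zeros(1, 1, 28, 28).to(device)
    out = model.layer1(dummy)   # (1, 32, 14, 14)
    print(out.shape)
    out = model.layer2(out)     # (1, 64, 7, 7)
    print(out.shape)
    out = model.layer3(out)     # (1, 128, 4, 4) -> flattened to 4 * 4 * 128
    print(out.shape)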
# Loss & Optimizer
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
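Adam updates every trainable parameter of the model at the learning rate defined earlier. As an optional check (not in the original post), the total parameter count can be printed:

# Count trainable parameters (optional sanity check)
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(num_params)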
# Model Training
total_batch = len(data_loader)

model.train()
for epoch in range(epochs):
    avg_cost = 0
    for X, y in data_loader:
        X = X.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        pred = model(X)
        loss = criterion(pred, y)
        loss.backward()
        optimizer.step()

        avg_cost += loss.item() / total_batch  # .item() detaches the value so the graph is not retained

    print('Epoch: {} Cost: {:.4}'.format(epoch + 1, avg_cost))
Epoch: 1 Cost: 0.1872
Epoch: 2 Cost: 0.04761
Epoch: 3 Cost: 0.03403
Epoch: 4 Cost: 0.02522
Epoch: 5 Cost: 0.02018
Epoch: 6 Cost: 0.01747
Epoch: 7 Cost: 0.01382
Epoch: 8 Cost: 0.01365
Epoch: 9 Cost: 0.01178
Epoch: 10 Cost: 0.0104
Epoch: 11 Cost: 0.007847
Epoch: 12 Cost: 0.008006
Epoch: 13 Cost: 0.00729
Epoch: 14 Cost: 0.007523
Epoch: 15 Cost: 0.006063
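After training, it is common to persist the learned weights so the model can be reused without retraining. A minimal sketch follows (the file name cnn_mnist.pth is an arbitrary choice, not from the original post):

# Save the trained weights (optional)
torch.save(model.state_dict(), './cnn_mnist.pth')

# Later, restore them into a fresh model instance:
# model = CNN().to(device)
# model.load_state_dict(torch.load('./cnn_mnist.pth', map_location=device))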
# Model Test
with torch.no_grad():
    model.eval()

    X_test = mnist_test.data.view(len(mnist_test), 1, 28, 28).float().to(device)
    y_test = mnist_test.targets.to(device)

    pred = model(X_test)
    correct = torch.argmax(pred, 1) == y_test
    accuracy = correct.float().mean()
    print('Accuracy: {}'.format(accuracy.item()))
Accuracy: 0.9900999665260315
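Note that the test above feeds the raw uint8 images (values 0-255) as one large tensor, whereas training used ToTensor(), which scales pixels to [0, 1]. Evaluating batch-wise through a DataLoader keeps the preprocessing identical to training and avoids moving the whole test set to the GPU at once. This is a minimal sketch added here, not part of the original post; test_loader is a new name introduced for it.

# Batch-wise evaluation with the same ToTensor() preprocessing as training
test_loader = DataLoader(dataset=mnist_test, batch_size=batch_size, shuffle=False)

model.eval()
correct = 0
with torch.no_grad():
    for X, y in test_loader:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        correct += (torch.argmax(pred, 1) == y).sum().item()

print('Accuracy: {:.4f}'.format(correct / len(mnist_test)))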
This post is licensed under CC BY 4.0 by the author.