웹 개발 메모장

6. 신경망 학습과 평가 본문

옛날../pytorch

6. 신경망 학습과 평가

도로롱주 2019. 7. 22. 00:07

 

신경망 학습과 평가(앞서 정의한 신경망 모델(#5 )의 연장)

신경망 모델 개선 ①

  • 학습에 사용할 데이터는 CIFAR10 입니다.
  • 이 신경망 모델은 이미지를 입력받았을 때, 이미지가 어떤 이미지인지 분류하는 모델입니다.
  • 분류되는 class는 '비행기', '자동차', '개구리' 등 10개 입니다.
    따라서 forward 함수의 output은 10개의 classes를 분류할 수 있도록 out_features가 10이 되어야 합니다.

신경망 모델 개선 ②

  • 2개의 convolutional Layer를 사용
  • 2개의 Linear Layout을 사용
  • convolutional Layout을 통과한 데이터들을 max_pool2d 함수로 2*2 필터를 사용해 풀링
  • 활성화 함수로는 relu함수를 사용

개선된 신경망 모델 코드

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.transforms as transforms


class MyNeuralNetwork(nn.Module):
    def __init__(self):
        super(MyNeuralNetwork, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=30, kernel_size=5)
        self.fc1 = nn.Linear(in_features=30*5*5, out_features=128, bias=True)
        self.fc2 = nn.Linear(in_features=128, out_features=10, bias=True)

    def forward(self, x):
        x = F.relu(self.conv1(x), inplace=True)
        x = F.max_pool2d(x, (2, 2))

        x = F.relu(self.conv2(x), inplace=True)
        x = F.max_pool2d(x, (2, 2))

        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x), inplace=True)
        x = F.relu(self.fc2(x), inplace=True)

        return x

    @staticmethod
    def load_data():
        # classes of "CIFAR10"
        classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        train_set = torchvision.datasets.CIFAR10(
            root='./data',
            train=True,
            download=False,
            transform=transform
        )

        test_set = torchvision.datasets.CIFAR10(
            root='./data',
            train=False,
            download=False,
            transform=transform
        )
        train_loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=0)
        test_loader = DataLoader(test_set, batch_size=8, shuffle=False, num_workers=0)

        return train_loader, test_loader

신경망 모델 학습

신경망 모델을 학습시키기 위해서는 optimizer와 loss function이 필요합니다.

  • loss function은 forward 함수가 return 한 결과(예측 값)와 실제 값이 얼마나 다른 지를 return 해주는 함수입니다.
  • optimizer는 역전파를 통해 각 노드들의 미분값을 구하고 그것을 활용해 하이퍼 파라미터들을 조정함으로써 신경망 모델을 좀더 나은 방향으로 변화시키는 역할을 합니다.

optimizer는 torch.optim.SGD를 사용했고,
loss함수는 torch.nn.CrossEntropyLoss를 사용했습니다.

optimizer = torch.optim.SGD(params=network.parameters(), lr=0.001, momentum=0.9)
loss_function = nn.CrossEntropyLoss()

epoch 크기를 3으로 하여 학습을 시키는 코드는 아래와 같습니다.

network = MyNeuralNetwork()
train_loader, test_loader = network.load_data()

optimizer = torch.optim.SGD(params=network.parameters(), lr=0.001, momentum=0.9)
loss_function = nn.CrossEntropyLoss()

epoch_size = 3
for epoch in range(epoch_size):
    for i, data in enumerate(train_loader):
        inputs, labels = data
        inputs, labels = Variable(inputs), Variable(labels)

        if torch.cuda.is_available():
            inputs = inputs.cuda()
            labels = labels.cuda()
            network.cuda()

        optimizer.zero_grad()
        out = network.forward(inputs)
        loss = loss_function(out, labels)
        loss.backward()
        optimizer.step()

        if(i % 100 == 0):
            print('{0}: loss is {1}'.format(i, loss))

print("train over")

신경망 모델 평가

test 데이터를 이용해 신경망 모델이 예측한 값과 실제 label이 얼마나 같은 지 모두 확인하여 그 정확도를 출력하는 코드입니다.

total = 0
correct = 0
for _, data in enumerate(test_loader):
    images, labels = data

    if torch.cuda.is_available():
        images = images.cuda()
        labels = labels.cuda()

    out = network.forward(Variable(images))
    _, predicted = torch.max(out.data, 1)

    total += labels.size(0)
    correct += (predicted == labels).sum()

print('Accuracy is {0}%'.format(100*correct/total))

참고: 김군이(https://www.youtube.com/watch?v=hAdjMBR-ttg)

Comments