웹 개발 메모장

7. nn.Sequential을 사용한 신경망 구현 본문

옛날../pytorch

7. nn.Sequential을 사용한 신경망 구현

도로롱주 2019. 7. 22. 22:47

nn.Sequantial을 사용한 신경망 구현(앞서 정의한 신경망 모델(#6 )의 연장)

nn.Sequantial을 사용하지 않은 신경망

아래 신경망에서는 __init__() 에서 사용할 네트워크 모델들을 정의 해주고,
forward() 함수에서 그 모델들을 사용하여 순전파 로직을 구현했습니다.

class MyNeuralNetwork(nn.Module):
    def __init__(self):
        super(MyNeuralNetwork, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=30, kernel_size=5)
        self.fc1 = nn.Linear(in_features=30*5*5, out_features=128, bias=True)
        self.fc2 = nn.Linear(in_features=128, out_features=10, bias=True)

    def forward(self, x):
        x = F.relu(self.conv1(x), inplace=True)
        x = F.max_pool2d(x, (2, 2))

        x = F.relu(self.conv2(x), inplace=True)
        x = F.max_pool2d(x, (2, 2))

        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x), inplace=True)
        x = F.relu(self.fc2(x), inplace=True)

        return x

nn.Sequantial을 사용한 신경망

nn.Sequantial을 사용하면 __init__() 에서 사용할 네트워크 모델들을 정의 해줄 뿐만 아니라,
forward() 함수에서 구현될 순전파를 Layer 형태로 보다 가독성이 뛰어나게 코드를 작성할 수 있습니다. Layer가 복잡해질수록 nn.Sequential은 그 효가가 뛰어납니다.

class MyNeuralNetwork(nn.Module):
    def __init__(self):
        super(MyNeuralNetwork, self).__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=30, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )

        self.layer3 = nn.Sequential(
            nn.Linear(in_features=30*5*5, out_features=128, bias=True),
            nn.ReLU(inplace=True)
        )

        self.layer4 = nn.Sequential(
            nn.Linear(in_features=128, out_features=10, bias=True),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = x.view(x.shape[0], -1)
        x = self.layer3(x)
        x = self.layer4(x)

        return x
Comments