웹 개발 메모장
4. Pytorch torchvision에서 데이터셋 가져오기 본문
torchvision
에서 데이터셋 가져오기
torchvision
( pip install torchvision 으로 설치 )
널리 사용되는 데이터 셋, 아키텍쳐 모델 computer vision에서의 일반적인 이미지 변환으로 구성되어 있습니다.
torchvision.datasets
- MNIST
- Fashion-MNIST
- KMNIST
- EMNIST
- FakeData
- COCO
- LSUN
- ImageFolder
- DatasetFolder
- Imagenet-12
- CIFAR
- STL10
- SVHN
- PhotoTour
- SBU
- Flickr
- VOC
- Cityscapes
torchvision.models
- Alexnet
- VGG
- ResNet
- SqueezeNet
- DenseNet
- Inception v3
- GoogLeNet
torchvision.transforms
-
Transforms on PIL Image
-
Transforms on torch.*Tensor
-
Conversion Transforms
-
Generic Transforms
-
Functional Transforms
torchvision.utils
torchvision.transforms
다양한 이미지 변환 기능들을 제공해 줍니다.
EX)
torchvision.transform.ToTensor
PIL 이미지 또는 numpy.ndarray를 pytorch의 텐서로 변형시켜 줍니다.
transforms.ToTensor()
torchvision.transforms.Normalize(mean, std, inplace=False)
평균과 표준편차를 사용하여 이미지를 정규화 시켜 줍니다.
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
torchvision.transform.Compose
여러 transform 들을 Compose로 구성할 수 있습니다.
transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
torchvision.datasets.CIFAR10
CIFAR10 데이터셋을 가져옵니다.
train_set = torchvision.datasets.CIFAR10(
root='./data', # 데이터 저장 위치
train=True, # True: train set, False: test set
download=False, # 다운로드 여부, (이미 다운받았으면 False로 지정)
transform=transform # 데이터 선처리 작업
)
test_set = torchvision.datasets.CIFAR10(
root='./data',
train=False,
download=False,
transform=transform
)
torch.utils.data.DataLoader
torchvision.dataset의 데이터들을 불러옵니다.
batch_size, shuffle, num_workers 등 10여개의 parameter를 전달 받습니다.
train_loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=0)
test_loader = DataLoader(test_set, batch_size=8, shuffle=False, num_workers=0)
'옛날.. > pytorch' 카테고리의 다른 글
6. 신경망 학습과 평가 (0) | 2019.07.22 |
---|---|
5. 신경망 정의와 순방향 전파 (0) | 2019.07.22 |
3. Pytorch nn과 nn.functional (1) | 2019.07.21 |
2. Pytorch autograd와 Variable (0) | 2019.07.21 |
1. Pytorch 소개 (0) | 2019.07.21 |
Comments