데이터셋은 torchvision에서 제공하기 때문에 import 시킨 후 세부 설정 root: 데이터의 경로 train: 학습데이터여부 transform: 어떤 형태로 불러 올지 (일반이미지: H,W,C)형태 / Pytorch는 0 ~ 1사이 값을 가진 (C,H,W)형식 download: 데이터가 없으면 다운 받을지 여부
4. 데이터 셋 확인
print(train_data)
print(test_data)
해석 학습 데이터는 60,000개 / 학습용 데이터 / 형태는 텐서
5. 데이터로더 설정 후 학습데이터 시각화 확인
loader = DataLoader(
dataset=train_data,
batch_size=64,
shuffle=True
)
imgs, labels = next(iter(loader))
fig, axes = plt.subplots(8,8,figsize=(16,16))
for ax,img,label inzip(axes.flatten(),imgs,labels):
ax.imshow(img.reshape((28,28)), cmap='gray')
ax.set_title(label.item())
ax.axis('off')
해석
첫번쨰 ) torch.utils.data의 DataLoader를 import
두번째) DataLoder로 학습데이터를 설정 dataset는 학습 데이터 batch_size는 64(즉 64개 데이터가 1그룹) shuffle: 순서를 섞는다.
세번째) 데이터 로더를 반복가능객체로 만든다. (return이 2개이기 때문에 2개 변수의 설정)
네번째) 반복문을 통해 시각화 plt.subplots(가로,세로, 전체 표사이즈) (return이 2개여서 2개 변수의 설정) fig는 표의 전체 프레임 , axes 그래프가 그려지는곳
zip 함수로 axes를 평탄화한 데이터 + img 데이터들 + img 이름 데이터들을 반복
6. 모델 준비
model = nn.Sequential(
nn.Conv2d(1,32,kernel_size=3,padding='same'),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
nn.Conv2d(32,64,kernel_size=3,padding='same'),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
nn.Flatten(),
nn.Linear(7*7*64,10)
).to(device)
print(model)
해석 CNN 모델처럼 만들었다.
컨볼루전 + 활성화 (Relu) + Maxpooling을 1개의 세트로 생각 2번 진행후 평탄화 선형회귀에 적용