공부 목적으로 PyTorch 튜토리얼 홈페이지를 변역해보았습니다.
이전 포스팅에서는 pytorch.nn 모듈을 사용해서 신경망을 구축해보았습니다.
Dataset 을 사용해서 재구성하기
PyTorch는 추상적인 Dataset 클래스가 있습니다.
Dataset은 __len__ 함수와 __getitem__ 함수를 지닌 어떤 것이라도 될 수 있으며, 이 함수들을 인덱싱하기 위한 방법으로서 사용됩니다.
이번 튜토리얼에서는 Dataset의 하위 클래스로에 사용자 지정 FacialLandmarkDataset을 생성하는 좋은 예제를 제시합니다.
PyTorch의 TensorDataset은 tensor를 감싸는 Dataset입니다.
인덱싱 방식과 길이를 정의함으로써 이것은 tensor의 첫 번째 차원을 따라 반복, 인덱스, 슬라이스를 위한 방법을 제공합니다.
훈련할 때 동일한 라인에서 독립 변수와 종속 변수에 쉽게 접근할 수 있습니다.
from torch.utils.data import TensorDataset
x_train 과 y_train 은 하나의 TensorDataset 으로 결합될 수 있습니다.
따라서 반복과 슬라이스가 편리합니다.
train_ds = TensorDataset(x_train, y_train)
이전에 개별적으로 x와 y의 미니배치를 반복했습니다.
xb = x_train[start_i:end_i]
yb = y_train[start_i:end_i]
이제 두 단계를 함께 진행할 수 있습니다.
xb, yb = train_ds[i*bs : i*bs+bs]
재구성된 코드입니다.
model, opt = get_model()
for epoch in range(epochs):
for i in range((n - 1) // bs + 1):
xb, yb = train_ds[i * bs: i * bs + bs]
pred = model(xb)
loss = loss_func(pred, yb)
loss.backward()
opt.step()
opt.zero_grad()
print(loss_func(model(xb), yb))
Out :
tensor(0.0821, grad_fn=<NllLossBackward>)
DataLoader를 사용해여 코드 재구성하기
PyTorch의 DataLoader 는 배치 관리를 담당합니다.
모든 Dataset 으로부터 DataLoader 를 생성할 수 있습니다.
DataLoader 는 배치에 대해서 반복하기 편리하게 해줍니다.
train_ds[i*bs : i*bs+bs] 를 사용하기보다 DataLoader 는 각각의 미니배치를 자동으로 제공합니다.
from torch.utils.data import DataLoader
train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs)
이전에는 루프가 다음과 같이 (xb, yb) 배치를 반복했습니다.
for i in range((n-1)//bs + 1):
xb,yb = train_ds[i*bs : i*bs+bs]
pred = model(xb)
이제 (xb, yb)가 dataloader 으로부터 자동으로 로드되므로 루프가 훨씬 깔끔해졌습니다.
for xb,yb in train_dl:
pred = model(xb)
재구성된 코드는 다음과 같습니다.
파이토치 없이 구현한 코드보다 훨씬 간결합니다.
model, opt = get_model()
for epoch in range(epochs):
for xb, yb in train_dl:
pred = model(xb)
loss = loss_func(pred, yb)
loss.backward()
opt.step()
opt.zero_grad()
print(loss_func(model(xb), yb))
Out :
tensor(0.0836, grad_fn=<NllLossBackward>)
PyTorch의 nn.Module , nn.Parameter , Dataset , DataLoder 덕분에 훈련 루프는 극적으로 간결해지고 이해하기 쉬워졌습니다.
'Python > PyTorch 공부' 카테고리의 다른 글
[Object Detection] YOLO(v3)를 PyTorch로 바닥부터 구현하기 - Part 1 (5) | 2021.01.10 |
---|---|
[PyTorch] 4. 검증(validation) 추가하고 fit() 와 get_data() 생성하기 (0) | 2020.12.09 |
[PyTorch] 2. 파이토치 torch.nn을 사용해서 신경망 구축하기 (0) | 2020.12.08 |
[Pytorch] 1. MNIST 데이터를 불러오고 파이토치 없이 신경망 구현하기 (0) | 2020.12.07 |
[PyTorch] 3. 예제로 배우는 파이토치 - nn 모듈, 가중치 공유, 제어 흐름, 사용자 정의 nn 모듈 (0) | 2020.12.07 |