Python/PyTorch 공부

[PyTorch] ShuffleSplit와 subset 함수를 사용하여 dataset 분할하기

AI 꿈나무 2021. 6. 11. 01:56
반응형

 안녕하세요! 이번 포스팅은 sklearn 패키지에서 제공하는 ShuffleSplit과 torch.utils.data의 Subset 함수를 사용하여 데이터셋을 분할하도록 하겠습니다.

 

 shufflesplit 함수는 데이터셋 인덱스를 무작위로 사전에 설정한 비율로 분할합니다. 즉, 4:1 로 분할하고 싶은 경우에 무작위 인덱스로 4:1 비율로 분할합니다.

 

 subset 함수로 데이터셋을 생성하면 부모 set이 업데이트(transformation)된 경우에 subset도 함께 업데이트 됩니다.

 

 제가 사용하는 데이터셋은 999개의 이미지로 구성됩니다.

 

 train 0.8, test 0.2로 분할하겠습니다.

 

# split the data into two groups
# trian 0.8, test 0.2
from sklearn.model_selection import ShuffleSplit
sss = ShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
indices = range(len(fetal_ds1))

for train_index, val_index in sss.split(indices):
    print(len(train_index))
    print('-'*10)
    print(len(val_index))

 

 인덱스가 0.8 : 0.2 비율로 분할되었습니다.

 이제 이 인덱스를 사용하여 dataset을 생성합니다.

 

# creat train_ds and val_ds
from torch.utils.data import Subset

train_ds = Subset(fetal_ds1, train_index)
print(len(train_ds))
val_ds = Subset(fetal_ds2, val_index)
print(len(val_ds))

 

 fetal_ds1과 fetal_ds2는 동일한 이미지로 구성되어 있으며, ds1은 train을 위한 transformation, ds2는 val을 위한 transformation이 적용되어 있습니다. 따라서 동일한 이미지를 포함하지만 서로 다른 transformation이 적용된 ds를 이용해 subset을 생성합니다. 만약 fetal_ds1의 transformation을 변경하면 train_ds도 함께 변경됩니다.

 

감사합니다.

반응형