Python/PyTorch 공부

[PyTorch] 커스텀 데이터셋(custom dataset) 생성하기

AI 꿈나무 2021. 3. 6. 00:38
반응형

 AMD dataset을 사용하여 custom dataset을 생성해보겠습니다.

 

 AMD dataset은 https://amd.grand-challenge.org/ 에서 다운로드할 수 있습니다.

 

# define a custom Dataset class for the AMD fovea-location data
import os

import pandas as pd
from PIL import Image
from torch.utils.data import Dataset

class AMD_dataset(Dataset):
    """Skeleton of the AMD dataset.

    ``__init__`` and ``__getitem__`` are placeholders here; the working
    implementations are assigned onto the class further down in this file.
    """

    def __init__(self, path2data, transform, trans_params):
        """Placeholder constructor: accepts the arguments and does nothing."""

    def __len__(self):
        """Return the dataset size, taken from the number of stored labels."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Placeholder item accessor: always returns ``None``."""
        return None

        
def __init__(self, path2data, transform, trans_params):
    """Load the fovea labels and build the list of image paths.

    Reads ``<path2data>/Training400/Fovea_location.xlsx`` (indexed by its
    ``ID`` column) and stores on the instance:

    * ``labels``       -- array of the ``Fovea_X`` / ``Fovea_Y`` columns
    * ``imgName``      -- the ``imgName`` column
    * ``ids``          -- the spreadsheet index (``ID`` values)
    * ``fullPath2img`` -- one image path per spreadsheet row
    * ``transform`` / ``trans_params`` -- kept as given; ``__getitem__``
      calls ``transform(image, label, trans_params)``

    Args:
        path2data: Root directory that contains the ``Training400`` folder.
        transform: Callable applied to each ``(image, label)`` pair.
        trans_params: Extra argument forwarded to ``transform``.
    """
    train_dir = os.path.join(path2data, 'Training400')
    path2labels = os.path.join(train_dir, 'Fovea_location.xlsx')
    labels_df = pd.read_excel(path2labels, index_col='ID')
    self.labels = labels_df[['Fovea_X', 'Fovea_Y']].values

    self.imgName = labels_df['imgName']
    self.ids = labels_df.index

    # Build the paths in spreadsheet row order so that fullPath2img[i] always
    # belongs to labels[i]. Positioning by ``ID - 1`` only works when the IDs
    # are exactly 1..N and already sorted.
    # File names starting with 'A' live in the 'AMD' sub-folder, all others
    # in 'Non-AMD'.
    self.fullPath2img = [
        os.path.join(
            train_dir,
            'AMD' if name.startswith('A') else 'Non-AMD',
            name,
        )
        for name in self.imgName
    ]

    self.transform = transform
    self.trans_params = trans_params

def __len__(self):
    """Return the number of samples, i.e. the number of stored labels."""
    n_samples = len(self.labels)
    return n_samples

def __getitem__(self, idx):
    """Return the transformed ``(image, label)`` pair at position ``idx``.

    The image is opened from ``fullPath2img[idx]``, paired with
    ``labels[idx]``, and both are passed through ``self.transform`` together
    with ``self.trans_params``.
    """
    img_path = self.fullPath2img[idx]
    raw_image = Image.open(img_path)
    target = self.labels[idx]

    # The transform is expected to hand back exactly two values.
    image, label = self.transform(raw_image, target, self.trans_params)
    return image, label

# Attach the module-level implementations to the class, replacing its
# placeholder methods.
setattr(AMD_dataset, '__init__', __init__)
setattr(AMD_dataset, '__getitem__', __getitem__)
반응형