Python/PyTorch 공부

[Pytorch] Sementation mask 시각화 하기

AI 꿈나무 2022. 6. 26. 20:40
반응형

이미지를 segmentation 모델로 전달하여 pred를 얻었다고 가정하겠습니다.

 

for image, target in data_loader:
    pred_masks = model(image) # [N, H, W], dtype= Tensor.bool

 

이 pred_masks를 matplotlib를 사용하여 시각화 하겠습니다.

 

우선, pred_masks, target, image를 동일한 사이즈로 resize 해줘야 합니다.

안되어있는 경우 resize 합니다.

 

import torchvision.transforms.functional as TF

h, w = image.shape[2], image.shape[3]
pred_masks = TF.resize(pred_masks, (h, w)).type(torch.bool) # [N, H, W]
target = TF.resize(target, {h, w)).type(torch.bool) # [N, H, W]

 

시각화 함수를 정의합니다.

제 경우에는 sentence와 iou도 함께 존재하는데요. 이 값까지 함께 시각화 했습니다.

연구에 필요한 코드를 짜고나서 블로그에 업로드하기 때문에, 보시는 분들은 불필요한 코드가 많다고 느끼실 수 있습니다. 

 

def show_result(image, target, preds, sentences, ious, args, height, width, opacity=0.5):
    img = image.squeeze().cpu().numpy().transpose(1,2,0) # [3, 480, 480] -> [480, 480, 3]
    target = target.squeeze().cpu().numpy().astype(bool) # [1, H, W] -> [H, W]
    ious = [iou.cpu().numpy() for iou in ious]
    preds = [pred.squeeze(0).cpu().numpy().astype(bool) for pred in preds]
    

    mean = ([0.485, 0.456, 0.406])
    std = ([0.229, 0.224, 0.225])

    # re-nomalize
    for c, (mean_c, std_c) in enumerate(zip(mean, std)):
        img[:,:,c] *= std_c
        img[:,:,c] += mean_c

    color = np.array(ImageColor.getrgb('red'), dtype=np.uint8) / 255# tuple

    gt = copy.copy(img)
    # apply mask
    for c in range(3):
        gt[:,:,c] = np.where(target == 1,
                             gt[:,:,c] * opacity + (1 - opacity) * color[c],
                             gt[:,:,c])


    fig = plt.figure(figsize=(30,10),constrained_layout=True)
    specs = gridspec.GridSpec(nrows=2, ncols= len(sentences) + 1)
    ax1 = fig.add_subplot(specs[0,0])
    ax1.set_title(f'mean_iou: {np.mean(ious):.2f} \n Image', fontsize=25)
    ax1.axis('off')
    ax1.imshow(img)

    ax2 = fig.add_subplot(specs[0,1])
    ax2.set_title('GT', fontsize=20)
    ax2.axis('off')
    ax2.imshow(gt)

    count = 0
    for pred, sentence, iou in zip(preds, sentences, ious):
        mask_seg = copy.copy(img)

        for c in range(3):
            mask_seg[:,:,c] = np.where(pred == 1,
                                       mask_seg[:,:,c] * opacity + (1 - opacity) * color[c],
                                       mask_seg[:,:,c])

        ax = fig.add_subplot(specs[1, count])
        ax.set_title(f'IoU: {iou:.2f} \n {sentence}', fontsize=25)
        ax.axis('off')
        ax.imshow(mask_seg)

        count += 1

    os.makedirs(f'./show_result/{args.dataset}_{args.split}_{args.clip_model}_{height}x{width}/', exist_ok=True)
    show_dir = f'./show_result/{args.dataset}_{args.split}_{args.clip_model}_{height}x{width}/M{np.mean(ious):.2f}_H{np.max(ious):.2f}_L{np.min(ious):.2f}_{sentence}.jpg'


    plt.savefig(show_dir)
    plt.show()

 

show_result(imgs, target, result_segs, sentence_raw, this_ious, args, Height, Width)

 

반응형