Python/PyTorch 공부
[Pytorch] Sementation mask 시각화 하기
AI 꿈나무
2022. 6. 26. 20:40
반응형
이미지를 segmentation 모델로 전달하여 pred를 얻었다고 가정하겠습니다.
for image, target in data_loader:
pred_masks = model(image) # [N, H, W], dtype= Tensor.bool
이 pred_masks를 matplotlib를 사용하여 시각화 하겠습니다.
우선, pred_masks, target, image를 동일한 사이즈로 resize 해줘야 합니다.
안되어있는 경우 resize 합니다.
import torchvision.transforms.functional as TF
h, w = image.shape[2], image.shape[3]
pred_masks = TF.resize(pred_masks, (h, w)).type(torch.bool) # [N, H, W]
target = TF.resize(target, {h, w)).type(torch.bool) # [N, H, W]
시각화 함수를 정의합니다.
제 경우에는 sentence와 iou도 함께 존재하는데요. 이 값까지 함께 시각화 했습니다.
연구에 필요한 코드를 짜고나서 블로그에 업로드하기 때문에, 보시는 분들은 불필요한 코드가 많다고 느끼실 수 있습니다.
def show_result(image, target, preds, sentences, ious, args, height, width, opacity=0.5):
img = image.squeeze().cpu().numpy().transpose(1,2,0) # [3, 480, 480] -> [480, 480, 3]
target = target.squeeze().cpu().numpy().astype(bool) # [1, H, W] -> [H, W]
ious = [iou.cpu().numpy() for iou in ious]
preds = [pred.squeeze(0).cpu().numpy().astype(bool) for pred in preds]
mean = ([0.485, 0.456, 0.406])
std = ([0.229, 0.224, 0.225])
# re-nomalize
for c, (mean_c, std_c) in enumerate(zip(mean, std)):
img[:,:,c] *= std_c
img[:,:,c] += mean_c
color = np.array(ImageColor.getrgb('red'), dtype=np.uint8) / 255# tuple
gt = copy.copy(img)
# apply mask
for c in range(3):
gt[:,:,c] = np.where(target == 1,
gt[:,:,c] * opacity + (1 - opacity) * color[c],
gt[:,:,c])
fig = plt.figure(figsize=(30,10),constrained_layout=True)
specs = gridspec.GridSpec(nrows=2, ncols= len(sentences) + 1)
ax1 = fig.add_subplot(specs[0,0])
ax1.set_title(f'mean_iou: {np.mean(ious):.2f} \n Image', fontsize=25)
ax1.axis('off')
ax1.imshow(img)
ax2 = fig.add_subplot(specs[0,1])
ax2.set_title('GT', fontsize=20)
ax2.axis('off')
ax2.imshow(gt)
count = 0
for pred, sentence, iou in zip(preds, sentences, ious):
mask_seg = copy.copy(img)
for c in range(3):
mask_seg[:,:,c] = np.where(pred == 1,
mask_seg[:,:,c] * opacity + (1 - opacity) * color[c],
mask_seg[:,:,c])
ax = fig.add_subplot(specs[1, count])
ax.set_title(f'IoU: {iou:.2f} \n {sentence}', fontsize=25)
ax.axis('off')
ax.imshow(mask_seg)
count += 1
os.makedirs(f'./show_result/{args.dataset}_{args.split}_{args.clip_model}_{height}x{width}/', exist_ok=True)
show_dir = f'./show_result/{args.dataset}_{args.split}_{args.clip_model}_{height}x{width}/M{np.mean(ious):.2f}_H{np.max(ious):.2f}_L{np.min(ious):.2f}_{sentence}.jpg'
plt.savefig(show_dir)
plt.show()
show_result(imgs, target, result_segs, sentence_raw, this_ious, args, Height, Width)
반응형