"""vathos.data_loader.dense_depth: the DenseDepth dataset."""

from pathlib import Path
import os
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import seaborn as sns
# from tqdm.auto import tqdm
from tqdm.notebook import tqdm, trange
import numpy as np
from PIL import Image
import PIL
import zipfile
from zipfile import ZipFile
# Apply seaborn's default plotting theme; as a module-level call this runs once, at import time.
sns.set()


class DenseDepth(Dataset):
    '''DenseDepth Dataset

    Each item pairs an fg_bg image with the bg image whose index is encoded in
    its file name (the inputs), and with its fg_bg_mask and depth_fg_bg images
    (the targets).

    Args:
        root: the directory in which to place the unzipped dataset files
            (they go into a ``Vathos`` sub-directory, created if missing)
        source_zipfolder: the directory where the dataset zip files are stored
            (mount your drive and give an absolute path)
        train: currently unused; kept for API compatibility
        transform: torchvision transform for input images
        target_transform: torchvision transform for output images

    Input is fg_bg image AND bg image
    Target is fg_bg_mask AND depth_fg_bg image

    Raises:
        RuntimeError: if the dataset folders do not hold the expected number
            of images.
    '''
    # zip archives (inside ``source_zipfolder``) that are extracted into root
    source_zipfiles = ['bg_small.zip', 'fg_bg_small.zip',
                       'fg_bg_mask_small.zip', 'depth_fg_bg_small.zip']

    # Per-channel statistics for each image kind, stored as *strings*.
    # NOTE(review): presumably (per-channel mean, per-channel std) — confirm;
    # convert with float() before passing them to e.g. a Normalize transform.
    bg_stat = (['0.573435604572296', '0.520844697952271', '0.457784473896027'], [
        '0.207058250904083', '0.208138316869736', '0.215291306376457'])
    fg_bg_stat = (['0.568499565124512', '0.512103974819183', '0.452332496643066'], [
        '0.211068645119667', '0.211040720343590', '0.216081097722054'])
    fg_bg_mask_stat = (['0.062296919524670', '0.062296919524670', '0.062296919524670'], [
        '0.227044790983200', '0.227044790983200', '0.227044790983200'])
    depth_fg_bg_stat = (['0.302973538637161', '0.302973538637161', '0.302973538637161'], [
        '0.101284727454185', '0.101284727454185', '0.101284727454185'])

    # sub-directories of ``root/Vathos`` that the images are read from
    dataset_folders = ('bg', 'fg_bg', 'fg_bg_mask', 'depth_fg_bg')
    # image counts the dataset folders must contain
    expected_bg_count = 100
    expected_pair_count = 400000
    # 8-bit mask pixels >= this value become 255 (foreground), the rest 0
    mask_threshold = 150

    def __init__(self, root, source_zipfolder, train=True, transform=None, target_transform=None):
        self.root = Path(root) / 'Vathos'
        self.root.mkdir(parents=True, exist_ok=True)
        self.source_zipfolder = Path(source_zipfolder)
        self.transform = transform
        self.target_transform = target_transform

        # Skip extraction only when *every* folder is present, so an
        # interrupted extraction is redone instead of failing the checks below.
        if all((self.root / name).is_dir() for name in self.dataset_folders):
            print(f'dataset folders/files already exists in {self.root}')
        else:
            self.extractall()

        # pathlib does not order glob results; sorting keeps the fg_bg, mask
        # and depth lists aligned index-by-index.
        bg_paths = sorted((self.root / 'bg').glob('*.jpg'))
        fg_bg_paths = sorted((self.root / 'fg_bg').glob('**/*.jpg'))
        fg_bg_mask_paths = sorted((self.root / 'fg_bg_mask').glob('**/*.jpg'))
        depth_fg_bg_paths = sorted((self.root / 'depth_fg_bg').glob('**/*.png'))

        self._check_count('bg', bg_paths, self.expected_bg_count)
        self._check_count('fg_bg', fg_bg_paths, self.expected_pair_count)
        self._check_count('fg_bg_mask', fg_bg_mask_paths, self.expected_pair_count)
        self._check_count('depth_fg_bg', depth_fg_bg_paths, self.expected_pair_count)

        print(f'found {len(bg_paths)} bg images, {len(fg_bg_paths)} fg_bg images, {len(fg_bg_mask_paths)} fg_bg_mask images, {len(depth_fg_bg_paths)} depth_fg_bg images')
        self.input_paths = fg_bg_paths
        self.bg_paths = bg_paths
        self.target_paths = list(zip(fg_bg_mask_paths, depth_fg_bg_paths))

    @staticmethod
    def _check_count(name, paths, expected):
        """Raise RuntimeError unless ``paths`` holds exactly ``expected`` entries."""
        # An explicit raise rather than ``assert``, so the check survives ``python -O``.
        if len(paths) != expected:
            raise RuntimeError(
                f'expected {expected} {name} images in the dataset, found {len(paths)}')

    def extractall(self):
        r"""Extracts every zip in ``source_zipfiles`` from ``source_zipfolder`` into the root dir
        """
        print('Extracting the zip files')
        for smallzip in tqdm(self.source_zipfiles):
            print(f'Extracting {smallzip} ...')
            # the context manager closes the archive once it is extracted
            with ZipFile(self.source_zipfolder / smallzip, 'r') as zipf:
                zipf.extractall(self.root)

    @staticmethod
    def _open_image(path, mode):
        """Load the image at ``path`` fully into memory, converted to ``mode``.

        ``convert`` returns a new image, so its result must be used; the file
        handle is closed when the ``with`` block exits.
        """
        with Image.open(path) as img:
            return img.convert(mode)

    def __getitem__(self, index):
        r"""Returns the sample at ``index``

        Returns:
            dict: ``'bg'`` and ``'fg_bg'`` (RGB inputs, passed through
            ``transform``) and ``'fg_bg_mask'`` and ``'depth_fg_bg'``
            (grayscale targets, passed through ``target_transform``)
        """
        fg_bg_path = self.input_paths[index]
        # The bg index is the 4th '_'-separated field of the fg_bg file stem.
        # NOTE(review): it indexes the lexicographically sorted bg_paths, so
        # bg file names must sort in numeric order (e.g. zero-padded) — confirm.
        bgidx = int(fg_bg_path.stem.split('_')[3])
        bgimg = self._open_image(self.bg_paths[bgidx], 'RGB')
        fg_bgimg = self._open_image(fg_bg_path, 'RGB')

        target_mask, target_depth = self.target_paths[index]
        # Binarize the JPEG mask (presumably to remove compression artifacts
        # that leave intermediate grey levels).
        mask_arr = np.array(self._open_image(target_mask, 'L'))
        mask_arr = np.where(mask_arr >= self.mask_threshold, 255, 0).astype(np.uint8)
        mask_fg_bgimg = Image.fromarray(mask_arr)
        depth_fg_bgimg = self._open_image(target_depth, 'L')

        if self.transform is not None:
            bgimg = self.transform(bgimg)
            fg_bgimg = self.transform(fg_bgimg)
        if self.target_transform is not None:
            mask_fg_bgimg = self.target_transform(mask_fg_bgimg)
            depth_fg_bgimg = self.target_transform(depth_fg_bgimg)

        return {'bg': bgimg, 'fg_bg': fg_bgimg,
                'fg_bg_mask': mask_fg_bgimg, 'depth_fg_bg': depth_fg_bgimg}

    def __len__(self):
        """Number of fg_bg samples in the dataset."""
        return len(self.input_paths)

    @staticmethod
    def _to_numpy(tensor):
        """Return ``tensor`` as a numpy array (detached, on CPU).

        For CPU tensors the array shares memory with the tensor, so callers
        must not modify it in place.
        """
        return tensor.detach().cpu().numpy()

    @staticmethod
    def _binarize(arr, threshold=0.9):
        """Return a *new* array that is 1 where ``arr >= threshold`` and 0 elsewhere."""
        return (arr >= threshold).astype(arr.dtype)

    @staticmethod
    def _to_images(bg, fg_bg, fg_bg_mask, depth_fg_bg):
        """Convert one sample's tensors into arrays ready for ``imshow``.

        bg and fg_bg are permuted from channel-first to channel-last; only
        channel 0 of the mask and depth tensors is kept.
        """
        to_np = DenseDepth._to_numpy
        return (to_np(bg.permute(1, 2, 0)), to_np(fg_bg.permute(1, 2, 0)),
                to_np(fg_bg_mask[0]), to_np(depth_fg_bg[0]))

    @staticmethod
    def _plot_grid(rows, n_rows, titles, figsize):
        """Draw ``rows`` (sequences of images) into an ``n_rows`` x ``len(titles)`` grid.

        Grid rows beyond ``len(rows)`` are left empty; every axis is hidden.
        """
        fig, ax = plt.subplots(n_rows, len(titles), figsize=figsize,
                               sharex=True, sharey=True)
        for axs, title in zip(ax[0], titles):
            axs.set_title(title)
        for axes_row, images in zip(ax, rows):
            for axs, image in zip(axes_row, images):
                axs.imshow(image)
        for axs in ax.flat:
            axs.axis('off')
        fig.tight_layout()
        return fig

    @staticmethod
    def plot_sample(sample):
        r"""Plots a given sample of the dataset as a 2x2 grid

        Args:
            sample: one item of this DenseDepth dataset whose images have
                already been converted to tensors by the transforms

        Returns:
            matplotlib.pyplot.figure: the created figure
        """
        images = DenseDepth._to_images(sample['bg'], sample['fg_bg'],
                                       sample['fg_bg_mask'], sample['depth_fg_bg'])
        fig, ax = plt.subplots(2, 2, figsize=(4, 4), sharex=True, sharey=True)
        # ax.flat order: bg, fg_bg (top row), mask, depth (bottom row)
        for axs, image in zip(ax.flat, images):
            axs.imshow(image)
            axs.axis('off')
        fig.tight_layout()
        return fig

    @staticmethod
    def plot4_batch(batch):
        r"""Plots the first 4 samples of a batch (fewer if the batch is smaller)

        Columns are bg, fg_bg, fg_bg_mask (binarized at 0.9) and depth_fg_bg.
        The batch tensors are not modified.

        Args:
            batch: the batch of data of this DenseDepth dataset

        Returns:
            matplotlib.pyplot.figure: the created figure
        """
        rows = []
        for i in range(min(4, len(batch['bg']))):
            bg, fg_bg, fg_bg_mask, depth_fg_bg = DenseDepth._to_images(
                batch['bg'][i], batch['fg_bg'][i],
                batch['fg_bg_mask'][i], batch['depth_fg_bg'][i])
            rows.append((bg, fg_bg, DenseDepth._binarize(fg_bg_mask), depth_fg_bg))
        return DenseDepth._plot_grid(
            rows, 4, ['BG', 'FG_BG', 'FG_BG_MASK', 'DEPTH_FG_BG'], (6, 6))

    @staticmethod
    def plot_results(batch):
        r"""Plots the first 8 samples of a batch with the model's predictions

        Columns are bg, fg_bg, ground-truth mask, predicted mask (binarized at
        0.9), ground-truth depth and predicted depth. Fewer rows are drawn if
        the batch is smaller; the batch tensors are not modified.

        Args:
            batch: the batch of data of this DenseDepth dataset, with extra
                ``'pred_mask'`` and ``'pred_depth'`` entries

        Returns:
            matplotlib.pyplot.figure: the created figure
        """
        rows = []
        for i in range(min(8, len(batch['bg']))):
            bg, fg_bg, fg_bg_mask, depth_fg_bg = DenseDepth._to_images(
                batch['bg'][i], batch['fg_bg'][i],
                batch['fg_bg_mask'][i], batch['depth_fg_bg'][i])
            pred_mask = DenseDepth._binarize(DenseDepth._to_numpy(batch['pred_mask'][i][0]))
            pred_depth = DenseDepth._to_numpy(batch['pred_depth'][i][0])
            rows.append((bg, fg_bg, fg_bg_mask, pred_mask, depth_fg_bg, pred_depth))
        return DenseDepth._plot_grid(
            rows, 8,
            ['BG', 'FG_BG', 'GT MASK', 'PRED MASK', 'GT DEPTH', 'PRED DEPTH'],
            (15, 18))

    @staticmethod
    def apply_on_batch(batch, apply_func):
        r"""Applies a given function to each image entry of the batch

        The batch dict is updated in place and also returned; other keys are
        left untouched.

        Args:
            batch: a batch of data
            apply_func: a function that is to be applied to each data in batch
        """
        for key in ('bg', 'fg_bg', 'fg_bg_mask', 'depth_fg_bg'):
            batch[key] = apply_func(batch[key])
        return batch