|
|
import numpy as np |
|
|
import torch |
|
|
from torch.utils.data.dataset import Dataset |
|
|
import os |
|
|
from configs.paths import dataset_root |
|
|
import copy |
|
|
from tqdm import tqdm |
|
|
from .base import BASE |
|
|
|
|
|
class BEDLAM(BASE): |
|
|
def __init__(self, split='train_6fps',**kwargs): |
|
|
super(BEDLAM, self).__init__(**kwargs) |
|
|
assert split in ['train_1fps','train_3fps','train_6fps','validation_6fps'] |
|
|
assert not self.kid_offset |
|
|
|
|
|
self.ds_name = 'bedlam' |
|
|
self.dataset_path = os.path.join(dataset_root,'bedlam') |
|
|
annots_path = os.path.join(self.dataset_path,f'bedlam_smpl_{split}.npz') |
|
|
self.annots = np.load(annots_path, allow_pickle=True)['annots'][()] |
|
|
self.img_names = list(self.annots.keys()) |
|
|
self.split = 'train' if 'train' in split else 'validation' |
|
|
|
|
|
def __len__(self): |
|
|
return len(self.img_names) |
|
|
|
|
|
def cnt_instances(self): |
|
|
ins_cnt = 0 |
|
|
for idx in tqdm(range(len(self))): |
|
|
img_id = idx |
|
|
img_name = self.img_names[img_id] |
|
|
|
|
|
ins_cnt += len(self.annots[img_name]['shape']) |
|
|
|
|
|
|
|
|
print(f'TOTAL: {ins_cnt}') |
|
|
|
|
|
def get_raw_data(self, idx): |
|
|
|
|
|
img_id = idx%len(self.img_names) |
|
|
img_name = self.img_names[img_id] |
|
|
|
|
|
annots = copy.deepcopy(self.annots[img_name]) |
|
|
img_path = os.path.join(self.dataset_path,self.split,img_name) |
|
|
|
|
|
cam_intrinsics = torch.from_numpy(annots['cam_int']).unsqueeze(0) |
|
|
cam_rot = torch.from_numpy(np.stack(annots['cam_rot'])) |
|
|
cam_trans = torch.from_numpy(np.stack(annots['cam_trans'])) |
|
|
|
|
|
betas = torch.from_numpy(np.stack(annots['shape'])) |
|
|
poses = torch.from_numpy(np.stack(annots['pose_world'])) |
|
|
transl = torch.from_numpy(np.stack(annots['trans_world'])) |
|
|
|
|
|
raw_data={'img_path': img_path, |
|
|
'ds': 'bedlam', |
|
|
'pnum': len(betas), |
|
|
'betas': betas.float(), |
|
|
'poses': poses.float(), |
|
|
'transl': transl.float(), |
|
|
'cam_rot': cam_rot.float(), |
|
|
'cam_trans': cam_trans.float(), |
|
|
'cam_intrinsics':cam_intrinsics.float(), |
|
|
'3d_valid': True, |
|
|
'age_valid': False, |
|
|
'detect_all_people':True |
|
|
} |
|
|
|
|
|
if self.mode == 'eval': |
|
|
raw_data['occ_level'] = torch.zeros(len(betas),dtype=int) |
|
|
|
|
|
return raw_data |
|
|
|
|
|
|
|
|
|