# SAT-HMR: datasets/bedlam.py
# Uploaded by ChiSu001 ("Upload model files", commit ff07ed4, verified).
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
import os
from configs.paths import dataset_root
import copy
from tqdm import tqdm
from .base import BASE
class BEDLAM(BASE):
    """BEDLAM dataset built on the project's ``BASE`` dataset class.

    The full annotation index for the chosen split is read once at
    construction from ``<dataset_root>/bedlam/bedlam_smpl_<split>.npz`` and
    kept in memory as a dict keyed by image name. Image files are resolved
    under ``<dataset_root>/bedlam/<train|validation>/<img_name>``.
    """

    # Splits for which an annotation file ``bedlam_smpl_<split>.npz`` is expected.
    _VALID_SPLITS = ('train_1fps', 'train_3fps', 'train_6fps', 'validation_6fps')

    def __init__(self, split='train_6fps', **kwargs):
        """Load the in-memory annotation index for ``split``.

        Args:
            split: One of ``_VALID_SPLITS``. Any split containing ``'train'``
                reads images from the ``train`` directory; otherwise from
                ``validation``.
            **kwargs: Forwarded unchanged to ``BASE.__init__``.

        Raises:
            ValueError: If ``split`` is not a recognised BEDLAM split.
            AssertionError: If the base class was configured with a truthy
                ``kid_offset``, which this dataset does not support.
        """
        super().__init__(**kwargs)
        # Explicit raise instead of ``assert`` so the check survives ``python -O``;
        # otherwise a bad split would only surface as a missing-file error below.
        if split not in self._VALID_SPLITS:
            raise ValueError(
                f'Unknown BEDLAM split {split!r}; expected one of {self._VALID_SPLITS}'
            )
        # ``kid_offset`` is set by BASE; this dataset requires it to be falsy.
        assert not self.kid_offset
        self.ds_name = 'bedlam'
        self.dataset_path = os.path.join(dataset_root, 'bedlam')
        annots_path = os.path.join(self.dataset_path, f'bedlam_smpl_{split}.npz')
        # NOTE(review): allow_pickle=True unpickles file contents, which can run
        # arbitrary code -- only load annotation files from a trusted source.
        # The ``with`` block closes the .npz file handle once the data is read.
        with np.load(annots_path, allow_pickle=True) as archive:
            # The dict is stored as a 0-d object array; ``[()]`` unwraps it.
            self.annots = archive['annots'][()]
        self.img_names = list(self.annots.keys())
        # Image sub-directory name (the fps suffix only selects the annotation file).
        self.split = 'train' if 'train' in split else 'validation'

    def __len__(self):
        """Return the number of annotated images in the loaded split."""
        return len(self.img_names)

    def cnt_instances(self):
        """Count annotated people over every image and print the total.

        Returns:
            The total instance count, also printed as ``TOTAL: <n>``.
        """
        ins_cnt = 0
        for img_name in tqdm(self.img_names):
            # One 'shape' entry per person -- the same length that
            # get_raw_data reports as ``pnum``.
            ins_cnt += len(self.annots[img_name]['shape'])
        print(f'TOTAL: {ins_cnt}')
        return ins_cnt

    def get_raw_data(self, idx):
        """Assemble the raw annotation record for one image.

        Args:
            idx: Sample index. It is taken modulo the number of images, so
                indices past the end wrap around to the start.

        Returns:
            dict with keys:
              - ``img_path``: path to the image file.
              - ``ds``: the dataset tag ``'bedlam'``.
              - ``pnum``: number of stacked ``shape`` entries (people).
              - ``betas`` / ``poses`` / ``transl``: float32 tensors stacked
                from the ``shape`` / ``pose_world`` / ``trans_world``
                annotation lists (body-model parameters; the file is
                named ``bedlam_smpl_*``).
              - ``cam_rot`` / ``cam_trans``: float32 tensors stacked from
                the ``cam_rot`` / ``cam_trans`` annotation lists.
              - ``cam_intrinsics``: float32 ``cam_int`` with a leading
                singleton dimension added.
              - ``3d_valid``, ``age_valid``, ``detect_all_people``: constant
                flags (presumably consumed by BASE/downstream code -- verify).
              - ``occ_level`` (only when ``self.mode == 'eval'``): an all-zero
                integer tensor with one entry per person.
        """
        img_name = self.img_names[idx % len(self.img_names)]
        # Deep copy so the tensors below never alias the cached annotations:
        # torch.from_numpy shares memory with its source array, and .float()
        # is a no-op on data that is already float32, so an in-place op
        # downstream could otherwise corrupt ``self.annots``.
        annots = copy.deepcopy(self.annots[img_name])
        img_path = os.path.join(self.dataset_path, self.split, img_name)

        def _stacked(key):
            # Stack a list of per-entry arrays along a new leading axis -> float32 tensor.
            return torch.from_numpy(np.stack(annots[key])).float()

        betas = _stacked('shape')
        raw_data = {
            'img_path': img_path,
            'ds': 'bedlam',
            'pnum': len(betas),
            'betas': betas,
            'poses': _stacked('pose_world'),
            'transl': _stacked('trans_world'),
            'cam_rot': _stacked('cam_rot'),
            'cam_trans': _stacked('cam_trans'),
            'cam_intrinsics': torch.from_numpy(annots['cam_int']).unsqueeze(0).float(),
            '3d_valid': True,
            'age_valid': False,
            'detect_all_people': True,
        }
        # ``self.mode`` comes from BASE.
        if self.mode == 'eval':
            raw_data['occ_level'] = torch.zeros(len(betas), dtype=int)
        return raw_data