import copy
import math
import os
import random
from math import radians, sin, cos

import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from tqdm import tqdm

from configs.paths import smpl_model_path
from models.human_models import SMPL_Layer, smpl_gendered
from utils import constants
from utils.box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh
from utils.constants import smpl_24_flip, smpl_root_idx
from utils.map import gen_scale_map, build_z_map
from utils.transforms import unNormalize, to_zorder
from utils.visualization import tensor_to_BGR, vis_meshes_img, vis_boxes, vis_scale_img, pad_img, get_colors_rgb, vis_sat
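

# BASE is the shared scaffold for the SMPL-based multi-person datasets in this repo: it
# loads an image, applies optional train-time augmentation (rotation, scaling, flipping,
# cropping, synthetic occlusion), keeps the camera intrinsics/extrinsics and SMPL
# parameters consistent with those image-space transforms, and optionally builds a
# z-ordered per-patch scale map when 'use_sat' is enabled. Concrete datasets subclass it
# and implement get_raw_data().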
class BASE(Dataset):

    def __init__(self, input_size=1288, aug=True, mode='train',
                 human_type='smpl',
                 sat_cfg=None,
                 aug_cfg=None):
        self.input_size = input_size
        self.aug = aug
        if mode not in ['train', 'eval', 'infer']:
            raise NotImplementedError
        if human_type not in ['smpl', 'no']:
            raise NotImplementedError
        self.mode = mode
        self.human_type = human_type
        assert sat_cfg is not None
        self.use_sat = sat_cfg['use_sat']
        self.sat_cfg = sat_cfg

        if self.use_sat:
            assert input_size % 56 == 0

        if self.mode == 'train' and aug_cfg is None:
            aug_cfg = {'rot_range': [-15, 15],
                       'scale_range': [0.8, 1.8],
                       'flip_ratio': 0.5,
                       'crop_ratio': 0.}
        self.aug_cfg = aug_cfg

        if human_type == 'smpl':
            self.poses_flip = smpl_24_flip
            self.num_poses = 24
            self.num_betas = 10
            self.num_kpts = 45
            self.human_model = smpl_gendered

        self.vis_thresh = 4

        self.img_keys = ['img_path', 'ds',
                         'pnum', 'img_size',
                         'resize_rate', 'cam_intrinsics',
                         '3d_valid', 'detect_all_people',
                         'scale_map', 'scale_map_pos', 'scale_map_hw']
        self.human_keys = ['boxes', 'labels',
                           'poses', 'betas',
                           'transl', 'verts',
                           'j3ds', 'j2ds', 'j2ds_mask',
                           'depths', 'focals', 'genders']

        z_depth = math.ceil(math.log2(self.input_size // 28))
        self.z_order_map, self.y_coords, self.x_coords = build_z_map(z_depth)
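
    # Subclass hook: return one sample's raw annotations as a dict. Keys this base class
    # reads downstream include 'img_path', 'ds', 'pnum', 'cam_intrinsics', 'cam_rot',
    # 'cam_trans', 'poses', 'betas' and 'transl' (plus optional 'genders' and
    # 'j2ds_mask'). A minimal subclass sketch appears at the end of this file.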
    def get_raw_data(self, idx):
        raise NotImplementedError
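
    # Sample one set of augmentation parameters (identity values when self.aug is False).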
    def get_aug_dict(self):
        if self.aug:
            rot = random.uniform(*self.aug_cfg['rot_range'])
            flip = random.random() <= self.aug_cfg['flip_ratio']
            scale = random.uniform(*self.aug_cfg['scale_range'])
            crop = random.random() <= self.aug_cfg['crop_ratio']
        else:
            rot = 0.
            flip = False
            scale = 1.
            crop = False

        return {'rot': rot, 'flip': flip, 'scale': scale, 'crop': crop}
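
    # Image-space pipeline: (1) optional random crop of landscape training images, with
    # the intrinsics shifted accordingly, (2) resize so the longer side equals
    # self.input_size, (3) horizontal flip, (4) rotation/scaling about the image centre
    # via warpAffine, together with an 'img_valid' mask marking pixels that still come
    # from the original frame after warping.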
    def process_img(self, img, meta_data, rot=0., flip=False, scale=1.0, crop=False):

        if self.mode == 'train' and crop:
            h, w = img.shape[:2]
            if h < w:
                clip_ratio = random.uniform(0.5, 0.9)
                tgt_h, tgt_w = int(h * clip_ratio), int(w * clip_ratio)

                img = img[:tgt_h, (w - tgt_w) // 2:(w + tgt_w) // 2, :].copy()
                cam_intrinsics = meta_data['cam_intrinsics']
                cam_intrinsics[:, 0, 2] -= (w - tgt_w) // 2
                meta_data.update({'cam_intrinsics': cam_intrinsics})

        img_size = torch.tensor(img.shape[:2])
        if img_size[1] >= img_size[0]:
            resize_rate = self.input_size / img_size[1]
            img = cv2.resize(img, dsize=(self.input_size, int(resize_rate * img_size[0])))
            img_size = torch.tensor([int(resize_rate * img_size[0]), self.input_size])
        else:
            resize_rate = self.input_size / img_size[0]
            img = cv2.resize(img, dsize=(int(resize_rate * img_size[1]), self.input_size))
            img_size = torch.tensor([self.input_size, int(resize_rate * img_size[1])])
        meta_data.update({'img_size': img_size, 'resize_rate': resize_rate})

        if flip:
            img = np.flip(img, axis=1)
            rot = -rot

        img_valid = np.full((img.shape[0], img.shape[1]), 255, dtype=np.uint8)
        M = cv2.getRotationMatrix2D((int(img_size[1] / 2), int(img_size[0] / 2)), rot, scale)
        img = cv2.warpAffine(img, M, dsize=(img.shape[1], img.shape[0]))
        img_valid = cv2.warpAffine(img_valid, M, dsize=(img.shape[1], img.shape[0]))
        meta_data.update({'img_valid': img_valid})

        return img
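
    # Synthetic occlusion: for roughly 60% of the person boxes, sample a patch covering
    # up to 30% of the box area with an aspect ratio in [0.5, 2], retrying a few times to
    # keep it inside the image; the caller fills the returned regions with random noise.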
    def occlusion_aug(self, meta_data):
        occ_boxes = []
        imght, imgwidth = meta_data['img_size']
        for bbox in box_cxcywh_to_xyxy(meta_data['boxes']):
            bbox = bbox.clone()
            bbox *= self.input_size
            xmin, ymin = bbox[:2]
            xmax, ymax = bbox[2:]

            if random.random() <= 0.6:
                counter = 0
                while True:
                    if counter > 5:
                        synth_ymin, synth_h, synth_xmin, synth_w = 0, 0, 0, 0
                        break
                    counter += 1

                    area_min = 0.0
                    area_max = 0.3
                    synth_area = (random.random() * (area_max - area_min) + area_min) * (xmax - xmin) * (ymax - ymin)

                    ratio_min = 0.5
                    ratio_max = 1 / 0.5
                    synth_ratio = (random.random() * (ratio_max - ratio_min) + ratio_min)

                    synth_h = math.sqrt(synth_area * synth_ratio)
                    synth_w = math.sqrt(synth_area / synth_ratio)
                    synth_xmin = random.random() * ((xmax - xmin) - synth_w - 1) + xmin
                    synth_ymin = random.random() * ((ymax - ymin) - synth_h - 1) + ymin

                    if synth_xmin >= 0 and synth_ymin >= 0 and synth_xmin + synth_w < imgwidth and synth_ymin + synth_h < imght:
                        synth_xmin = int(synth_xmin)
                        synth_ymin = int(synth_ymin)
                        synth_w = int(synth_w)
                        synth_h = int(synth_h)
                        break
            else:
                synth_ymin, synth_h, synth_xmin, synth_w = 0, 0, 0, 0
            occ_boxes.append((synth_ymin, synth_h, synth_xmin, synth_w))
        return occ_boxes
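
    # Person boxes are derived from the projected 2D joints: tight min/max box, widened
    # by 1.2x, normalised by self.input_size, clamped to the visible image, and stored in
    # (cx, cy, w, h) format.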
    def get_boxes(self, meta_data):
        j2ds = meta_data['j2ds']
        j2ds_mask = meta_data['j2ds_mask']
        pnum = meta_data['pnum']

        bboxes_list = []

        for i in range(pnum):
            kpts = j2ds[i].clone()
            min_xy = kpts.min(dim=0)[0]
            max_xy = kpts.max(dim=0)[0]
            bbox_xyxy = torch.cat([min_xy, max_xy], dim=0)
            bboxes_list.append(bbox_xyxy)

        imght, imgwidth = meta_data['img_size']
        boxes = box_xyxy_to_cxcywh(torch.stack(bboxes_list)) / self.input_size
        boxes[..., 2:] *= 1.2
        boxes = box_cxcywh_to_xyxy(boxes)
        boxes[..., [0, 2]] = boxes[..., [0, 2]].clamp(min=0.01, max=(imgwidth - 1) / self.input_size)
        boxes[..., [1, 3]] = boxes[..., [1, 3]].clamp(min=0.01, max=(imght - 1) / self.input_size)
        boxes = box_xyxy_to_cxcywh(boxes)

        meta_data.update({'boxes': boxes})
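
    # Keep the camera consistent with the image-space augmentation: focal lengths and
    # principal point follow the resize and centred scaling, the principal point is
    # rotated about the image centre by rot_aug_mat, the principal x is mirrored on
    # flips, and the world-to-camera rotation/translation are pre-multiplied by the same
    # augmentation rotation.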
    def process_cam(self, meta_data, rot=0., flip=False, scale=1.):
        img_size = meta_data['img_size']
        resize_rate = meta_data['resize_rate']
        rot_aug_mat = meta_data['rot_aug_mat']
        cam_intrinsics = meta_data['cam_intrinsics']

        cam_intrinsics[:, 0:2, 2] *= resize_rate * scale
        cam_intrinsics[:, [0, 1], [0, 1]] *= resize_rate * scale
        cam_intrinsics[:, 0, 2] += (1 - scale) * img_size[1] / 2
        cam_intrinsics[:, 1, 2] += (1 - scale) * img_size[0] / 2

        princpt = cam_intrinsics[:, 0:2, 2].clone()
        princpt[..., 0] -= img_size[1] / 2
        princpt[..., 1] -= img_size[0] / 2
        princpt = torch.matmul(princpt, rot_aug_mat[:2, :2].transpose(-1, -2))
        princpt[..., 0] += img_size[1] / 2
        princpt[..., 1] += img_size[0] / 2
        cam_intrinsics[:, 0:2, 2] = princpt

        if flip:
            cam_intrinsics[:, 0, 2] = img_size[1] - cam_intrinsics[:, 0, 2]
        meta_data.update({'cam_intrinsics': cam_intrinsics})

        new_cam_rot = torch.matmul(rot_aug_mat.unsqueeze(0), meta_data['cam_rot'])
        new_cam_trans = torch.matmul(meta_data['cam_trans'], rot_aug_mat.transpose(-1, -2))
        meta_data.update({'cam_rot': new_cam_rot, 'cam_trans': new_cam_trans})
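
    # SMPL parameter handling: rotate each person's global orientation into the
    # (augmented) camera frame, swap and negate pose parameters when the image is
    # mirrored, run the SMPL layer, and recompute the per-person translation so that
    # adding it to the SMPL output places vertices and joints in the augmented camera
    # frame.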
    def process_smpl(self, meta_data, rot=0., flip=False, scale=1.):
        poses = meta_data['poses']
        bs = poses.shape[0]
        assert poses.ndim == 2
        assert tuple(poses.shape) == (bs, self.num_poses * 3)

        global_orient = poses[:, :3].clone()
        cam_rot = meta_data['cam_rot'].numpy()
        for i in range(global_orient.shape[0]):
            root_pose = global_orient[i].view(1, 3).numpy()
            R = cam_rot[i].reshape(3, 3)
            root_pose, _ = cv2.Rodrigues(root_pose)
            root_pose, _ = cv2.Rodrigues(np.dot(R, root_pose))
            root_pose = torch.from_numpy(root_pose).flatten()
            global_orient[i] = root_pose
        poses[:, :3] = global_orient

        if flip:
            poses = poses.reshape(bs, self.num_poses, 3)
            poses = poses[:, self.poses_flip, :]
            poses[..., 1:3] *= -1
            poses = poses.reshape(bs, -1)

        meta_data.update({'poses': poses})

        with torch.no_grad():
            smpl_kwargs = {'poses': meta_data['poses'], 'betas': meta_data['betas']}
            if 'genders' in meta_data:
                smpl_kwargs.update({'genders': meta_data['genders']})
            verts, j3ds = self.human_model(**smpl_kwargs)

        j3ds = j3ds[:, :self.num_kpts, :]
        root = j3ds[:, smpl_root_idx, :].clone()

        transl = (torch.bmm((root + meta_data['transl']).reshape(-1, 1, 3),
                            meta_data['cam_rot'].transpose(-1, -2)).reshape(-1, 3)
                  + meta_data['cam_trans'] - root)
        if flip:
            transl[..., 0] = -transl[..., 0]

        meta_data.update({'transl': transl})

        verts = verts + transl.reshape(-1, 1, 3)
        j3ds = j3ds + transl.reshape(-1, 1, 3)
        meta_data.update({'verts': verts, 'j3ds': j3ds})

    def project_joints(self, meta_data):
        j3ds = meta_data['j3ds']
        cam_intrinsics = meta_data['cam_intrinsics']
        j2ds_homo = torch.matmul(j3ds, cam_intrinsics.transpose(-1, -2))
        j2ds = j2ds_homo[..., :2] / (j2ds_homo[..., 2, None])

        meta_data.update({'j3ds': j3ds, 'j2ds': j2ds})
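
    # Joint visibility: a projected joint counts as visible only if it lies inside the
    # image and lands on a pixel of 'img_valid' that survived the rotation/scale warp.
    # People with fewer than self.vis_thresh visible joints are dropped from all
    # per-person annotation fields.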
    def check_visibility(self, meta_data):
        img_valid = meta_data['img_valid']
        img_size = meta_data['img_size']

        j2ds = meta_data['j2ds']
        j2ds_mask = meta_data['j2ds_mask'] if 'j2ds_mask' in meta_data else torch.ones_like(j2ds, dtype=bool)

        j2ds_vis = torch.from_numpy(img_valid[j2ds[..., 1].int().clip(0, img_size[0] - 1), j2ds[..., 0].int().clip(0, img_size[1] - 1)] > 0)
        j2ds_vis &= (j2ds[..., 1] >= 0) & (j2ds[..., 1] < img_size[0])
        j2ds_vis &= (j2ds[..., 0] >= 0) & (j2ds[..., 0] < img_size[1])

        j2ds_invalid = ~j2ds_vis
        j2ds_mask[j2ds_invalid] = False
        meta_data.update({'j2ds_mask': j2ds_mask})

        vis_cnt = j2ds_mask[..., 0].sum(dim=-1)
        valid_msk = (vis_cnt >= self.vis_thresh)

        pnum = valid_msk.sum().item()

        if pnum == 0:
            meta_data['pnum'] = pnum
            return

        if pnum < meta_data['pnum']:
            meta_data['pnum'] = pnum
            for key in self.human_keys:
                if key in meta_data:
                    if isinstance(meta_data[key], list):
                        meta_data[key] = np.array(meta_data[key])[valid_msk].tolist()
                    else:
                        meta_data[key] = meta_data[key][valid_msk]
            if 'cam_intrinsics' in meta_data and len(meta_data['cam_intrinsics']) > 1:
                meta_data['cam_intrinsics'] = meta_data['cam_intrinsics'][valid_msk]

        return
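
    # Full per-sample pipeline: deep-copy the raw annotations, apply the image-space
    # augmentation, propagate it into the camera and SMPL parameters, project joints,
    # filter invisible people, derive boxes/labels/depths, paste synthetic occluders
    # (when augmenting), and, when use_sat is on, build a per-patch scale map flattened
    # in z-order.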
    def process_data(self, img, raw_data, rot=0., flip=False, scale=1., crop=False):
        meta_data = copy.deepcopy(raw_data)

        rot_aug_mat = torch.tensor([[cos(radians(-rot)), -sin(radians(-rot)), 0.],
                                    [sin(radians(-rot)), cos(radians(-rot)), 0.],
                                    [0., 0., 1.]])
        meta_data.update({'rot_aug_mat': rot_aug_mat})

        img = self.process_img(img, meta_data, rot, flip, scale, crop)

        self.process_cam(meta_data, rot, flip, scale)
        self.process_smpl(meta_data, rot, flip, scale)
        self.project_joints(meta_data)
        self.check_visibility(meta_data)
        matcher_vis = meta_data['j2ds_mask'][:, :22, 0].sum(dim=-1)
        if meta_data['pnum'] == 0 or not torch.all(matcher_vis):
            if self.mode == 'train':
                meta_data['pnum'] = 0
            return img, meta_data

        j3ds = meta_data['j3ds']
        depths = j3ds[:, smpl_root_idx, [2]].clone()
        if len(meta_data['cam_intrinsics']) == 1:
            focals = torch.full_like(depths, meta_data['cam_intrinsics'][0, 0, 0])
        else:
            focals = meta_data['cam_intrinsics'][:, 0, 0][:, None]
        depths = torch.cat([depths, depths / focals], dim=-1)
        meta_data.update({'depths': depths, 'focals': focals})

        self.get_boxes(meta_data)

        meta_data.update({'labels': torch.zeros(meta_data['pnum'], dtype=int)})

        if self.aug:
            occ_boxes = self.occlusion_aug(meta_data)
            for (synth_ymin, synth_h, synth_xmin, synth_w) in occ_boxes:
                img[synth_ymin:synth_ymin + synth_h, synth_xmin:synth_xmin + synth_w, :] = np.random.rand(synth_h, synth_w, 3) * 255

        if self.use_sat:
            boxes = meta_data['boxes']
            scales = boxes[:, 2:].norm(p=2, dim=1)
            v3ds = meta_data['verts']
            depths_norm = meta_data['depths'][:, 1]
            cam_intrinsics = meta_data['cam_intrinsics']
            sorted_idx = torch.argsort(depths_norm, descending=True)
            map_size = (meta_data['img_size'] + 27) // 28

            scale_map = gen_scale_map(scales[sorted_idx], v3ds[sorted_idx],
                                      faces=self.human_model.faces,
                                      cam_intrinsics=cam_intrinsics[sorted_idx] if len(cam_intrinsics) > 1 else cam_intrinsics,
                                      map_size=map_size,
                                      patch_size=28,
                                      pad=True)
            scale_map_z, _, pos_y, pos_x = to_zorder(scale_map,
                                                     z_order_map=self.z_order_map,
                                                     y_coords=self.y_coords,
                                                     x_coords=self.x_coords)
            meta_data['scale_map'] = scale_map_z
            meta_data['scale_map_pos'] = {'pos_y': pos_y, 'pos_x': pos_x}
            meta_data['scale_map_hw'] = scale_map.shape[:2]

        return img, meta_data
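
    # In 'train' mode the augmentation is re-drawn until at least one person survives the
    # visibility checks (falling back to neutral rotation/scale/crop, and finally to the
    # next index). The image is then colour-jittered (when augmenting), ImageNet-
    # normalised, zero-padded to a multiple of the patch size (14, or 56 with SAT), and
    # returned with the filtered annotation dict.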
    def __getitem__(self, index):

        raw_data = self.get_raw_data(index)

        ori_img = cv2.imread(raw_data['img_path'])
        if raw_data['ds'] == 'bedlam' and 'closeup' in raw_data['img_path']:
            ori_img = cv2.rotate(ori_img, cv2.ROTATE_90_CLOCKWISE)
        img_size = torch.tensor(ori_img.shape[:2])
        raw_data.update({'img_size': img_size})

        if self.mode == 'train':
            cnt = 0
            while True:
                aug_dict = self.get_aug_dict()
                img, meta_data = self.process_data(ori_img, raw_data, **aug_dict)
                if meta_data['pnum'] > 0:
                    break
                cnt += 1
                if cnt >= 10:
                    # Fall back to neutral augmentation after repeated failures.
                    aug_dict.update({'rot': 0., 'scale': 1., 'crop': False})
                    img, meta_data = self.process_data(ori_img, raw_data, **aug_dict)
                    if meta_data['pnum'] == 0:
                        print('skipping: ' + meta_data['img_path'])
                        return self.__getitem__(index + 1)
                    break

        elif self.mode == 'eval':
            assert not self.aug, f'Augmentation must be disabled when mode is {self.mode}!'
            aug_dict = self.get_aug_dict()
            img, meta_data = self.process_data(ori_img, raw_data, **aug_dict)

        else:
            assert not self.aug, f'Augmentation must be disabled when mode is {self.mode}!'
            meta_data = raw_data
            img = self.process_img(ori_img, meta_data)

        if self.mode == 'train':
            for key in list(meta_data.keys()):
                if key not in self.img_keys and key not in self.human_keys:
                    del meta_data[key]

        if self.aug:
            array2tensor = transforms.Compose([
                transforms.ColorJitter(0.2, 0.2, 0.2),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
        else:
            array2tensor = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])

        patch_size = 14
        if self.use_sat:
            patch_size = 56
        # Zero-pad so both sides are multiples of the patch size.
        padded_img = np.zeros((math.ceil(img.shape[0] / patch_size) * patch_size, math.ceil(img.shape[1] / patch_size) * patch_size, 3), dtype=img.dtype)
        padded_img[:img.shape[0], :img.shape[1]] = img
        assert max(padded_img.shape[:2]) == self.input_size
        padded_img = Image.fromarray(padded_img[:, :, ::-1].copy())
        norm_img = array2tensor(padded_img)

        if 'j2ds_mask' in meta_data:
            meta_data['j2ds_mask'][:, :, :] = True

        return norm_img, meta_data
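
    # Debug/QA helper: dumps roughly `vis_num` evenly spaced samples with the posed
    # meshes, GT boxes and (if enabled) the scale map overlaid, under
    # datasets_visualization/<ds_name>_<split>/ by default. Assumes the subclass sets
    # self.ds_name and self.split.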
    def visualize(self, results_save_dir=None, vis_num=100):
        if results_save_dir is None:
            results_save_dir = os.path.join('datasets_visualization', f'{self.ds_name}_{self.split}')
        os.makedirs(results_save_dir, exist_ok=True)

        vis_interval = max(len(self) // vis_num, 1)

        for idx in tqdm(range(len(self))):
            if idx % vis_interval != 0:
                continue

            norm_img, targets = self.__getitem__(idx)

            ori_img = tensor_to_BGR(unNormalize(norm_img).cpu())
            img_name = targets['img_path'].split('/')[-1].split('.')[-2]
            pnum = targets['pnum']

            if 'verts' in targets:
                colors = get_colors_rgb(len(targets['verts']))
                mesh_img = vis_meshes_img(img=ori_img.copy(),
                                          verts=targets['verts'],
                                          smpl_faces=self.human_model.faces,
                                          cam_intrinsics=targets['cam_intrinsics'].cpu(),
                                          colors=colors,
                                          padding=False)
                cv2.imwrite(os.path.join(results_save_dir, f'{idx}_{img_name}_mesh.jpg'), mesh_img)

            if 'boxes' in targets:
                gt_img = ori_img.copy()
                boxes = box_cxcywh_to_xyxy(targets['boxes']) * self.input_size
                for i, bbox in enumerate(boxes):
                    bbox = bbox.int().tolist()
                    cv2.rectangle(gt_img, (bbox[0], bbox[1]), (bbox[2], bbox[3]),
                                  color=(0, 0, 255), thickness=2)

                cv2.imwrite(os.path.join(results_save_dir, f'{idx}_{img_name}_boxes.jpg'), gt_img)

            if 'scale_map' in targets:
                gt_img = ori_img.copy()
                flatten_map = targets['scale_map']
                ys, xs = targets['scale_map_pos']['pos_y'], targets['scale_map_pos']['pos_x']
                h, w = targets['scale_map_hw']
                scale_map = torch.zeros((h, w, 2))
                scale_map[ys, xs] = flatten_map
                img = vis_scale_img(gt_img, scale_map, patch_size=28)

                cv2.imwrite(os.path.join(results_save_dir, f'{idx}_{img_name}_scales.jpg'), img)
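

# Minimal usage sketch (hypothetical; `MyDataset`, its annotation source and the sat_cfg
# values below are illustrative, not part of this module):
#
#     class MyDataset(BASE):
#         def __init__(self, **kwargs):
#             super().__init__(**kwargs)
#             self.ds_name, self.split = 'my_dataset', 'train'
#             self.annots = ...  # load per-image annotation dicts
#
#         def __len__(self):
#             return len(self.annots)
#
#         def get_raw_data(self, idx):
#             return self.annots[idx]  # dict with the keys listed above get_raw_data()
#
#     ds = MyDataset(input_size=1288, aug=True, mode='train', sat_cfg={'use_sat': True})
#     norm_img, targets = ds[0]
#     ds.visualize(vis_num=10)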