RoyYang0714's picture
feat: Try to build everything locally.
9b33fca
"""Crop transforms."""
from __future__ import annotations
from vis4d.common.typing import (
NDArrayBool,
NDArrayF32,
NDArrayI64,
)
from vis4d.data.const import CommonKeys as K
from vis4d.data.transforms.base import Transform
@Transform(
    in_keys=[
        K.boxes3d,
        K.boxes3d_classes,
        K.boxes3d_track_ids,
        "transforms.crop.keep_mask",
    ],
    out_keys=[K.boxes3d, K.boxes3d_classes, K.boxes3d_track_ids],
)
class CropBoxes3D:
    """Crop 3D bounding boxes.

    Filters 3D boxes, their class ids and (optionally) their track ids with
    the boolean keep masks read from ``"transforms.crop.keep_mask"``.
    """

    def __call__(
        self,
        boxes_list: list[NDArrayF32],
        classes_list: list[NDArrayI64],
        track_ids_list: list[NDArrayI64] | None,
        keep_mask_list: list[NDArrayBool],
    ) -> tuple[list[NDArrayF32], list[NDArrayI64], list[NDArrayI64] | None]:
        """Crop 3D bounding boxes.

        Entry ``i`` of every input list is filtered with ``keep_mask_list[i]``
        via boolean-mask indexing.

        Args:
            boxes_list: List of 3D box arrays.
            classes_list: List of class-id arrays, aligned with ``boxes_list``.
            track_ids_list: List of track-id arrays aligned with
                ``boxes_list``, or None when no track ids are present.
            keep_mask_list: List of boolean masks selecting the entries to
                keep.

        Returns:
            Newly built lists ``(boxes, classes, track_ids)`` holding the
            filtered arrays; ``track_ids`` is None if ``track_ids_list`` is
            None. The input lists themselves are not modified.

        Raises:
            ValueError: If the input lists do not all have the same length.
        """
        num_entries = len(keep_mask_list)
        lengths = {
            "boxes": len(boxes_list),
            "classes": len(classes_list),
            "keep_mask": num_entries,
        }
        if track_ids_list is not None:
            lengths["track_ids"] = len(track_ids_list)
        # zip() would silently stop at the shortest list, leaving trailing
        # entries uncropped; fail loudly on any mismatch instead.
        if any(length != num_entries for length in lengths.values()):
            raise ValueError(
                f"CropBoxes3D expects input lists of equal length, got {lengths}."
            )

        # Build new lists rather than overwriting the caller's lists in place.
        cropped_boxes = [
            boxes[keep_mask]
            for boxes, keep_mask in zip(boxes_list, keep_mask_list)
        ]
        cropped_classes = [
            classes[keep_mask]
            for classes, keep_mask in zip(classes_list, keep_mask_list)
        ]
        cropped_track_ids = (
            None
            if track_ids_list is None
            else [
                track_ids[keep_mask]
                for track_ids, keep_mask in zip(track_ids_list, keep_mask_list)
            ]
        )
        return cropped_boxes, cropped_classes, cropped_track_ids