Spaces:
Sleeping
Sleeping
| """Language related transforms.""" | |
| from __future__ import annotations | |
| import random | |
| import re | |
| import numpy as np | |
| from transformers import AutoTokenizer | |
| from vis4d.common.logging import rank_zero_warn | |
| from vis4d.common.typing import NDArrayF32, NDArrayI64 | |
| from vis4d.data.const import CommonKeys as K | |
| from vis4d.data.transforms.base import Transform | |
def clean_name(name: str) -> str:
    """Clean a category name for use in a generated caption.

    Strips any parenthesized part, replaces underscores with spaces,
    collapses runs of spaces into one, and lowercases the result.

    Args:
        name: Raw category/class name, e.g. "Traffic_Light (signal)".

    Returns:
        The normalized name.
    """
    name = re.sub(r"\(.*\)", "", name)
    name = name.replace("_", " ")
    # Collapse repeated spaces (e.g. left by the removals above) into one.
    # The previous pattern replaced a single space with a single space,
    # which was a no-op.
    name = re.sub(r" +", " ", name)
    return name.lower()
def generate_senetence_given_labels(
    positive_label_list: list[int],
    negative_label_list: list[str],
    label_map: dict[str, str],
) -> tuple[dict[int, list[list[int]]], str, dict[int, int]]:
    """Generate a sentence given positive and negative labels.

    Negatives and positives are combined, shuffled into a random order,
    and their cleaned names are concatenated into one caption, each name
    followed by ". ". For every positive label the character span of its
    name inside the caption is recorded, along with a remapping from the
    original label value to its position in the shuffled order.

    Args:
        positive_label_list: Labels annotated in the sample.
        negative_label_list: Sampled negative labels.
        label_map: Mapping from label (as string) to class name.

    Returns:
        A tuple of (position index -> character spans of the positive
        name, the generated caption, original label -> position index).
    """
    shuffled_labels = negative_label_list + positive_label_list
    random.shuffle(shuffled_labels)

    label_to_positions: dict[int, list[list[int]]] = {}
    label_remap_dict: dict[int, int] = {}
    pheso_caption = ""

    for position, label in enumerate(shuffled_labels):
        name = clean_name(label_map[str(label)])
        begin = len(pheso_caption)
        pheso_caption += name
        if label in positive_label_list:
            # Record where this positive name sits inside the caption.
            label_to_positions[position] = [[begin, len(pheso_caption)]]
            label_remap_dict[int(label)] = position
        pheso_caption += ". "

    return label_to_positions, pheso_caption, label_remap_dict
class RandomSamplingNegPos:
    """Randomly sample negative and positive labels for object detection.

    For object-detection ("OD") samples a caption is built by combining
    the names of the annotated (positive) classes with a random subset of
    negative class names, while keeping the caption within a token budget
    measured by a HuggingFace tokenizer. For other (grounding) samples the
    pre-computed phrase positions are looked up per class id.
    """

    def __init__(
        self,
        tokenizer_name: str = "bert-base-uncased",
        num_sample_negative: int = 85,
        max_tokens: int = 256,
        full_sampling_prob: float = 0.5,
    ) -> None:
        """Creates an instance of RandomSamplingNegPos.

        Args:
            tokenizer_name: HuggingFace tokenizer used to measure caption
                length in tokens. Defaults to "bert-base-uncased".
            num_sample_negative: Maximum number of negative labels sampled
                per sample. Defaults to 85.
            max_tokens: Token budget for the generated caption. Defaults
                to 256.
            full_sampling_prob: Probability of adding the full set of
                negatives instead of a random-sized subset. Defaults to
                0.5.

        Raises:
            RuntimeError: If transformers is not available.
        """
        # NOTE(review): with the plain `from transformers import
        # AutoTokenizer` visible at the top of this file this check is
        # dead; presumably the import is guarded (set to None) elsewhere —
        # confirm against the full file.
        if AutoTokenizer is None:
            raise RuntimeError(
                "transformers is not installed, please install it by: "
                "pip install transformers."
            )
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.num_sample_negative = num_sample_negative
        self.full_sampling_prob = full_sampling_prob
        self.max_tokens = max_tokens

    def __call__(
        self,
        dataset_type_list: list[str],
        boxes_list: list[NDArrayF32],
        class_ids_list: list[NDArrayI64],
        texts_list: list[str] | None = None,
        label_map_list: list[dict[str, str]] | None = None,
        positive_positions_list: list[dict] | None = None,
    ) -> tuple[
        list[NDArrayF32],
        list[NDArrayI64],
        list[str],
        list[dict[int, list[list[int]]]],
    ]:
        """Randomly sample negative and positive labels.

        Args:
            dataset_type_list: Per-sample dataset type; "OD" selects the
                object-detection branch, anything else the grounding
                branch.
            boxes_list: Per-sample box arrays.
            class_ids_list: Per-sample class-id arrays.
            texts_list: Per-sample captions; required for non-"OD"
                samples.
            label_map_list: Per-sample mapping from class-id string to
                class name; required for "OD" samples.
            positive_positions_list: Per-sample mapping from class id to
                character spans inside the caption; required for non-"OD"
                samples.

        Returns:
            Possibly filtered boxes, (remapped) class ids, the new
            captions, and the per-sample positive-token mappings.
        """
        new_texts_list = []
        tokens_positive_list = []
        for i, (boxes, class_ids) in enumerate(
            zip(boxes_list, class_ids_list)
        ):
            if dataset_type_list[i] == "OD":
                assert (
                    label_map_list[i] is not None
                ), "label_map should not be None"
                # Build a fresh caption from sampled labels; boxes may be
                # dropped if the positive part overflows the token budget.
                boxes_list[i], class_ids_list[i], text, tokens_positive = (
                    self.od_aug(boxes, class_ids, label_map_list[i])
                )
                new_texts_list.append(text)
                tokens_positive_list.append(tokens_positive)
            else:
                assert (
                    positive_positions_list[i] is not None
                ), "positive_positions should not be None"
                # Grounding data: keep the original caption and look up
                # the pre-computed phrase spans for the present class ids.
                tokens_positive = self.vg_aug(
                    class_ids, positive_positions_list[i]
                )
                new_texts_list.append(texts_list[i])
                tokens_positive_list.append(tokens_positive)
        return boxes_list, class_ids_list, new_texts_list, tokens_positive_list

    def vg_aug(
        self, class_ids: NDArrayI64, positive_positions: dict
    ) -> dict:
        """Visual Genome data augmentation.

        Args:
            class_ids: Class ids present in the sample.
            positive_positions: Mapping from class id to the character
                spans of the corresponding phrase in the caption.

        Returns:
            The span mapping restricted to the unique class ids.
        """
        positive_label_list = np.unique(class_ids).tolist()
        label_to_positions = {}
        for label in positive_label_list:
            label_to_positions[label] = positive_positions[label]
        return label_to_positions

    def od_aug(
        self,
        boxes: NDArrayF32,
        class_ids: NDArrayI64,
        label_map: dict,
    ) -> tuple[NDArrayF32, NDArrayI64, str, dict[int, list[list[int]]]]:
        """Object detection data augmentation.

        Args:
            boxes: Box array of the sample.
            class_ids: Class ids of the boxes.
            label_map: Mapping from class-id string to class name. Mutated
                in place when a name contains "/" (one variant is chosen).

        Returns:
            The (possibly filtered) boxes, the class ids remapped to their
            position in the generated caption, the caption itself, and the
            mapping from caption position to character spans.
        """
        original_box_num = len(class_ids)

        # If the category name is in the format of 'a/b' (in object365),
        # we randomly select one of them.
        for key, value in label_map.items():
            if "/" in value:
                label_map[key] = random.choice(value.split("/")).strip()

        # Drop boxes whose class names do not fit into the token budget.
        keep_box_index, class_ids, positive_caption_length = (
            self.check_for_positive_overflow(class_ids, label_map)
        )
        boxes = boxes[keep_box_index]
        if len(boxes) < original_box_num:
            rank_zero_warn(
                f"Remove {original_box_num - len(boxes)} boxes due to "
                "positive caption overflow."
            )

        valid_negative_indexes = list(label_map.keys())
        positive_label_list = np.unique(class_ids).tolist()

        # Cap the negative sample size by the number of available labels.
        full_negative = self.num_sample_negative
        if full_negative > len(valid_negative_indexes):
            full_negative = len(valid_negative_indexes)

        outer_prob = random.random()
        if outer_prob < self.full_sampling_prob:
            # c. probability_full: add both all positive and all negatives
            num_negatives = full_negative
        else:
            # NOTE(review): this condition is always true, so the `else`
            # branch below is dead; the extra random() draw still advances
            # the RNG state, so removing it would change the sampling
            # sequence. Kept as-is.
            if random.random() < 1.0:
                num_negatives = np.random.choice(max(1, full_negative)) + 1
            else:
                num_negatives = full_negative

        # Keep some negatives
        negative_label_list = set()
        if num_negatives != -1:
            if num_negatives > len(valid_negative_indexes):
                num_negatives = len(valid_negative_indexes)
            for i in np.random.choice(
                valid_negative_indexes, size=num_negatives, replace=False
            ):
                # `i` is a (numpy) string key from label_map, while
                # positive_label_list holds ints — hence the int() cast to
                # exclude labels already present as positives.
                if int(i) not in positive_label_list:
                    negative_label_list.add(i)

        random.shuffle(positive_label_list)

        negative_label_list = list(negative_label_list)
        random.shuffle(negative_label_list)

        # Trim negatives so the full caption stays within max_tokens.
        negative_max_length = self.max_tokens - positive_caption_length
        screened_negative_label_list = []
        for negative_label in negative_label_list:
            label_text = clean_name(label_map[str(negative_label)]) + ". "
            tokenized = self.tokenizer.tokenize(label_text)
            negative_max_length -= len(tokenized)
            if negative_max_length > 0:
                screened_negative_label_list.append(negative_label)
            else:
                break
        negative_label_list = screened_negative_label_list

        label_to_positions, pheso_caption, label_remap_dict = (
            generate_senetence_given_labels(
                positive_label_list, negative_label_list, label_map
            )
        )

        # label remap
        if len(class_ids) > 0:
            class_ids = np.vectorize(lambda x: label_remap_dict[x])(class_ids)

        return boxes, class_ids, pheso_caption, label_to_positions

    def check_for_positive_overflow(
        self, class_ids: NDArrayI64, label_map: dict[str, str]
    ) -> tuple[list[int], NDArrayI64, int]:
        """Check if having too many positive labels.

        Positive labels are accumulated in random order until the token
        budget is reached; boxes whose label does not fit are dropped.

        Args:
            class_ids: Class ids of the sample.
            label_map: Mapping from class-id string to class name.

        Returns:
            The indices of the boxes to keep, the kept class ids, and the
            token length consumed by the positive caption part.
            NOTE(review): `length` also counts the first over-budget
            label's tokens, if any — matches the loop below.
        """
        # generate a caption by appending the positive labels
        positive_label_list = np.unique(class_ids).tolist()
        # random shuffle so we can sample different annotations
        # at different epochs
        random.shuffle(positive_label_list)

        kept_lables = []
        length = 0
        for _, label in enumerate(positive_label_list):
            label_text = clean_name(label_map[str(label)]) + ". "
            tokenized = self.tokenizer.tokenize(label_text)
            length += len(tokenized)
            if length > self.max_tokens:
                # Token budget exhausted: drop this and all later labels.
                break
            else:
                kept_lables.append(label)

        # Keep only boxes whose label survived the budget check.
        keep_box_index = []
        keep_gt_labels = []
        for i, class_id in enumerate(class_ids):
            if class_id in kept_lables:
                keep_box_index.append(i)
                keep_gt_labels.append(class_id)

        return (
            keep_box_index,
            np.array(keep_gt_labels, dtype=np.int64),
            length,
        )