import importlib
import math
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Generator, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.cuda.amp import autocast

from transformers import GenerationConfig, PreTrainedTokenizer, StoppingCriteriaList
from transformers.generation.logits_process import LogitsProcessorList

if TYPE_CHECKING:
    from transformers.generation.streamers import BaseStreamer

from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils import logging

try:
    from einops import rearrange
except ImportError:
    rearrange = None
from torch import nn

from .configuration_infimm_vicuna import InfiMMConfig
from .eva_vit import CLIPVisionCfg, EVAVisionTransformer
from .flamingo import Flamingo
from .flamingo_lm import FlamingoLMMixin
from .helpers import PerceiverResampler
from .utils import _infer_decoder_layers_attr_name, extend_instance

# Capability probes: bf16 requires an Ampere-or-newer GPU, and compute
# capability >= 7.0 (Volta) is needed for fast fp16 via tensor cores.
SUPPORT_CUDA = torch.cuda.is_available()
SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
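
# A minimal sketch of how the capability flags above could drive dtype
# selection when loading weights. `_preferred_dtype` is a hypothetical helper
# added for illustration, not an API of the original repo.
def _preferred_dtype() -> torch.dtype:
    """Widest safe computation dtype for the current device (assumed policy)."""
    if SUPPORT_BF16:
        return torch.bfloat16
    if SUPPORT_FP16:
        return torch.float16
    return torch.float32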


class InfiMMPreTrainedModel(PreTrainedModel):
    config_class = InfiMMConfig
    base_model_prefix = "transformer"
    is_parallelizable = False
    supports_gradient_checkpointing = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)


class InfiMMVicunaModel(InfiMMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # Build the two encoders, then couple them into a Flamingo-style
        # model with gated cross-attention over the visual features.
        self.vision_config = config.visual
        vision_encoder = self.build_vision_encoder()
        self.language_config = config.language
        language_encoder = self.build_language_encoder()

        self.model = self.build_flamingo(vision_encoder, language_encoder)

    def build_vision_encoder(self):
        # Instantiate the EVA-CLIP vision transformer from the `visual`
        # section of the config.
        vision_cfg = CLIPVisionCfg(**self.vision_config)

        vision_encoder = EVAVisionTransformer(
            img_size=vision_cfg.image_size,
            patch_size=vision_cfg.patch_size,
            num_classes=vision_cfg.embed_dim,
            use_mean_pooling=vision_cfg.global_average_pool,
            init_values=vision_cfg.ls_init_value,
            patch_dropout=vision_cfg.patch_dropout,
            embed_dim=vision_cfg.width,
            depth=vision_cfg.layers,
            num_heads=vision_cfg.width // vision_cfg.head_width,
            mlp_ratio=vision_cfg.mlp_ratio,
            qkv_bias=vision_cfg.qkv_bias,
            drop_path_rate=vision_cfg.drop_path_rate,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            xattn=vision_cfg.xattn,
            rope=vision_cfg.rope,
            postnorm=vision_cfg.postnorm,
            pt_hw_seq_len=vision_cfg.pt_hw_seq_len,
            intp_freq=vision_cfg.intp_freq,
            naiveswiglu=vision_cfg.naiveswiglu,
            subln=vision_cfg.subln,
        )

        return vision_encoder

    def build_language_encoder(self):
        lang_encoder = AutoModelForCausalLM.from_pretrained(
            self.language_config["_name_or_path"]
        )
        # Resize the embedding matrix to match the tokenizer, which carries
        # extra special tokens (e.g. the image and end-of-chunk markers
        # referenced by image_token_id / eoc_token_id below).
        lang_encoder.resize_token_embeddings(self.language_config["vocab_size"])
        return lang_encoder

    def build_flamingo(self, vision_encoder, lang_encoder):
        # Mix FlamingoLMMixin into the already-instantiated LM so its decoder
        # layers can be interleaved with gated cross-attention blocks.
        extend_instance(lang_encoder, FlamingoLMMixin)

        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
        lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)

        model = Flamingo(
            vision_encoder,
            lang_encoder,
            self.config.eoc_token_id,
            self.config.image_token_id,
            vis_dim=self.vision_config["width"],
            cross_attn_every_n_layers=self.config.cross_attn_every_n_layers,
            gradient_checkpointing=self.config.use_grad_checkpoint,
        )

        return model
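
    # For reference only: in OpenFlamingo-derived code, `extend_instance`
    # (imported from .utils above) typically rebinds the instance's class so
    # the mixin's methods win in the MRO. A sketch of that assumed helper:
    #
    #     def extend_instance(obj, mixin):
    #         base_cls = obj.__class__
    #         obj.__class__ = type(base_cls.__name__, (mixin, base_cls), {})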

    def generate(
        self,
        input_ids,
        attention_mask,
        batch_images,
        min_generation_length: int,
        max_generation_length: int,
        **kwargs,
    ):
        # Positional order for the inner call follows Flamingo.generate:
        # vision input first, then language input ids and attention mask.
        with torch.inference_mode():
            outputs = self.model.generate(
                batch_images,
                input_ids,
                attention_mask,
                min_new_tokens=min_generation_length,
                max_new_tokens=max_generation_length,
                **kwargs,
            )

        # The inner generate returns prompt + continuation; keep only the
        # newly generated token ids.
        outputs = outputs[:, input_ids.shape[1]:]
        return outputs
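

# A hedged usage sketch (not part of the original file). The checkpoint path
# is a placeholder and the image preprocessing is omitted; Flamingo-style
# models expect `batch_images` shaped
# (batch, num_images, num_frames, channels, height, width).
#
#     model = InfiMMVicunaModel.from_pretrained("<checkpoint-path>")
#     tokenizer = AutoTokenizer.from_pretrained("<checkpoint-path>")
#     inputs = tokenizer("<image>Describe this image.", return_tensors="pt")
#     generated = model.generate(
#         inputs["input_ids"],
#         inputs["attention_mask"],
#         batch_images,  # preprocessed pixel tensor, see shape above
#         min_generation_length=0,
#         max_generation_length=64,
#     )
#     print(tokenizer.batch_decode(generated, skip_special_tokens=True))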