"""Distilled Audio State-Space Model (DASS) model."""

import math
import warnings
from functools import partial
from typing import Any, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss

from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

from .configuration_dass import DASSConfig

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "DASSConfig"

WITH_TRITON = True
try:
    import triton
    import triton.language as tl
except ImportError:
    WITH_TRITON = False
    warnings.warn("Triton is not installed; falling back to the pure PyTorch implementation.")

if WITH_TRITON:
    try:
        from functools import cached_property
    except ImportError:
        warnings.warn("if you are using py37, add this line to functools.py: "
                      "cached_property = lambda func: property(lru_cache()(func))")
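

# ===== Cross scan / cross merge (pure PyTorch) =====
# The `scans` flag selects how the 2D feature map is unrolled into 1D sequences:
#   scans == 0: cross scan (row-major, column-major, and both reversed)
#   scans == 1: one route repeated four times (unidirectional)
#   scans == 2: row-major forward and backward, each repeated twice (bidirectional)
#   scans == 3: the four 90-degree rotations of the row-major route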
def cross_scan_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
    if in_channel_first:
        B, C, H, W = x.shape
        if scans == 0:
            y = x.new_empty((B, 4, C, H * W))
            y[:, 0, :, :] = x.flatten(2, 3)
            y[:, 1, :, :] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
            y[:, 2:4, :, :] = torch.flip(y[:, 0:2, :, :], dims=[-1])
        elif scans == 1:
            y = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1)
        elif scans == 2:
            y = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
            y = torch.cat([y, y.flip(dims=[-1])], dim=1)
        elif scans == 3:
            y = x.new_empty((B, 4, C, H * W))
            y[:, 0, :, :] = x.flatten(2, 3)
            y[:, 1, :, :] = torch.rot90(x, 1, dims=(2, 3)).flatten(2, 3)
            y[:, 2, :, :] = torch.rot90(x, 2, dims=(2, 3)).flatten(2, 3)
            y[:, 3, :, :] = torch.rot90(x, 3, dims=(2, 3)).flatten(2, 3)
    else:
        B, H, W, C = x.shape
        if scans == 0:
            y = x.new_empty((B, H * W, 4, C))
            y[:, :, 0, :] = x.flatten(1, 2)
            y[:, :, 1, :] = x.transpose(dim0=1, dim1=2).flatten(1, 2)
            y[:, :, 2:4, :] = torch.flip(y[:, :, 0:2, :], dims=[1])
        elif scans == 1:
            y = x.view(B, H * W, 1, C).repeat(1, 1, 4, 1)
        elif scans == 2:
            y = x.view(B, H * W, 1, C).repeat(1, 1, 2, 1)
            y = torch.cat([y, y.flip(dims=[1])], dim=2)
        elif scans == 3:
            y = x.new_empty((B, H * W, 4, C))
            y[:, :, 0, :] = x.flatten(1, 2)
            y[:, :, 1, :] = torch.rot90(x, 1, dims=(1, 2)).flatten(1, 2)
            y[:, :, 2, :] = torch.rot90(x, 2, dims=(1, 2)).flatten(1, 2)
            y[:, :, 3, :] = torch.rot90(x, 3, dims=(1, 2)).flatten(1, 2)

    if in_channel_first and (not out_channel_first):
        y = y.permute(0, 3, 1, 2).contiguous()
    elif (not in_channel_first) and out_channel_first:
        y = y.permute(0, 2, 3, 1).contiguous()

    return y
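

# Shape sketch for `cross_scan_fwd` with channel-first input and output:
#   x: (B, C, H, W) -> y: (B, 4, C, H * W)
# For H = W = 2 and scans == 0, with pixels numbered 0..3 in row-major order,
# the four routes read them as (0, 1, 2, 3), (0, 2, 1, 3), (3, 2, 1, 0), (3, 1, 2, 0).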


def cross_merge_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
    if out_channel_first:
        B, K, D, H, W = y.shape
        y = y.view(B, K, D, -1)
        if scans == 0:
            y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
            y = y[:, 0] + y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
        elif scans == 1:
            y = y.sum(1)
        elif scans == 2:
            y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
            y = y.sum(1)
        elif scans == 3:
            oy = y[:, 0, :, :].contiguous().view(B, D, -1)
            oy = oy + torch.rot90(y.view(B, K, D, W, H)[:, 1, :, :, :], -1, dims=(2, 3)).flatten(2, 3)
            oy = oy + torch.rot90(y.view(B, K, D, H, W)[:, 2, :, :, :], -2, dims=(2, 3)).flatten(2, 3)
            oy = oy + torch.rot90(y.view(B, K, D, W, H)[:, 3, :, :, :], -3, dims=(2, 3)).flatten(2, 3)
            y = oy
    else:
        B, H, W, K, D = y.shape
        y = y.view(B, -1, K, D)
        if scans == 0:
            y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D)
            y = y[:, :, 0] + y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).contiguous().view(B, -1, D)
        elif scans == 1:
            y = y.sum(2)
        elif scans == 2:
            y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D)
            y = y.sum(2)
        elif scans == 3:
            oy = y[:, :, 0, :].contiguous().view(B, -1, D)
            oy = oy + torch.rot90(y.view(B, W, H, K, D)[:, :, :, 1, :], -1, dims=(1, 2)).flatten(1, 2)
            oy = oy + torch.rot90(y.view(B, H, W, K, D)[:, :, :, 2, :], -2, dims=(1, 2)).flatten(1, 2)
            oy = oy + torch.rot90(y.view(B, W, H, K, D)[:, :, :, 3, :], -3, dims=(1, 2)).flatten(1, 2)
            y = oy

    if in_channel_first and (not out_channel_first):
        y = y.permute(0, 2, 1).contiguous()
    elif (not in_channel_first) and out_channel_first:
        y = y.permute(0, 2, 1).contiguous()

    return y
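

# `cross_merge_fwd` is the reduction that undoes `cross_scan_fwd`: each of the
# four routes is mapped back to its original spatial order and the routes are
# summed, so merge(scan(x)) == 4 * x for every scan mode.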


def cross_scan1b1_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
    if in_channel_first:
        B, _, C, H, W = x.shape
        if scans == 0:
            y = torch.stack([
                x[:, 0].flatten(2, 3),
                x[:, 1].transpose(dim0=2, dim1=3).flatten(2, 3),
                torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
                torch.flip(x[:, 3].transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]),
            ], dim=1)
        elif scans == 1:
            y = x.flatten(2, 3)
        elif scans == 2:
            y = torch.stack([
                x[:, 0].flatten(2, 3),
                x[:, 1].flatten(2, 3),
                torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
                torch.flip(x[:, 3].flatten(2, 3), dims=[-1]),
            ], dim=1)
        elif scans == 3:
            y = torch.stack([
                x[:, 0, :, :, :].flatten(2, 3),
                torch.rot90(x[:, 1, :, :, :], 1, dims=(2, 3)).flatten(2, 3),
                torch.rot90(x[:, 2, :, :, :], 2, dims=(2, 3)).flatten(2, 3),
                torch.rot90(x[:, 3, :, :, :], 3, dims=(2, 3)).flatten(2, 3),
            ], dim=1)
    else:
        B, H, W, _, C = x.shape
        if scans == 0:
            y = torch.stack([
                x[:, :, :, 0].flatten(1, 2),
                x[:, :, :, 1].transpose(dim0=1, dim1=2).flatten(1, 2),
                torch.flip(x[:, :, :, 2].flatten(1, 2), dims=[1]),
                torch.flip(x[:, :, :, 3].transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]),
            ], dim=2)
        elif scans == 1:
            y = x.flatten(1, 2)
        elif scans == 2:
            # Index the scan dimension (dim 3) and flip the flattened spatial
            # dimension (dim 1), matching the channel-last layout (B, H, W, K, C).
            y = torch.stack([
                x[:, :, :, 0].flatten(1, 2),
                x[:, :, :, 1].flatten(1, 2),
                torch.flip(x[:, :, :, 2].flatten(1, 2), dims=[1]),
                torch.flip(x[:, :, :, 3].flatten(1, 2), dims=[1]),
            ], dim=2)
        elif scans == 3:
            y = torch.stack([
                x[:, :, :, 0, :].flatten(1, 2),
                torch.rot90(x[:, :, :, 1, :], 1, dims=(1, 2)).flatten(1, 2),
                torch.rot90(x[:, :, :, 2, :], 2, dims=(1, 2)).flatten(1, 2),
                torch.rot90(x[:, :, :, 3, :], 3, dims=(1, 2)).flatten(1, 2),
            ], dim=2)

    if in_channel_first and (not out_channel_first):
        y = y.permute(0, 3, 1, 2).contiguous()
    elif (not in_channel_first) and out_channel_first:
        y = y.permute(0, 2, 3, 1).contiguous()

    return y


def cross_merge1b1_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
    if out_channel_first:
        B, K, D, H, W = y.shape
        y = y.view(B, K, D, -1)
        if scans == 0:
            y = torch.stack([
                y[:, 0],
                y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3),
                torch.flip(y[:, 2], dims=[-1]),
                torch.flip(y[:, 3].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]),
            ], dim=1)
        elif scans == 1:
            y = y
        elif scans == 2:
            y = torch.stack([
                y[:, 0],
                y[:, 1],
                torch.flip(y[:, 2], dims=[-1]),
                torch.flip(y[:, 3], dims=[-1]),
            ], dim=1)
        elif scans == 3:
            y = torch.stack([
                y[:, 0, :, :].contiguous().view(B, D, -1),
                torch.rot90(y.view(B, K, D, W, H)[:, 1, :, :, :], -1, dims=(2, 3)).flatten(2, 3),
                torch.rot90(y.view(B, K, D, H, W)[:, 2, :, :, :], -2, dims=(2, 3)).flatten(2, 3),
                torch.rot90(y.view(B, K, D, W, H)[:, 3, :, :, :], -3, dims=(2, 3)).flatten(2, 3),
            ], dim=1)
    else:
        B, H, W, K, D = y.shape
        y = y.view(B, -1, K, D)
        if scans == 0:
            y = torch.stack([
                y[:, :, 0],
                y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2),
                torch.flip(y[:, :, 2], dims=[1]),
                torch.flip(y[:, :, 3].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]),
            ], dim=2)
        elif scans == 1:
            y = y
        elif scans == 2:
            y = torch.stack([
                y[:, :, 0],
                y[:, :, 1],
                torch.flip(y[:, :, 2], dims=[1]),
                torch.flip(y[:, :, 3], dims=[1]),
            ], dim=2)
        elif scans == 3:
            y = torch.stack([
                y[:, :, 0, :].contiguous().view(B, -1, D),
                torch.rot90(y.view(B, W, H, K, D)[:, :, :, 1, :], -1, dims=(1, 2)).flatten(1, 2),
                torch.rot90(y.view(B, H, W, K, D)[:, :, :, 2, :], -2, dims=(1, 2)).flatten(1, 2),
                torch.rot90(y.view(B, W, H, K, D)[:, :, :, 3, :], -3, dims=(1, 2)).flatten(1, 2),
            ], dim=2)

    if out_channel_first and (not in_channel_first):
        y = y.permute(0, 3, 1, 2).contiguous()
    elif (not out_channel_first) and in_channel_first:
        y = y.permute(0, 2, 3, 1).contiguous()

    return y


class CrossScanF(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
        ctx.in_channel_first = in_channel_first
        ctx.out_channel_first = out_channel_first
        ctx.one_by_one = one_by_one
        ctx.scans = scans

        if one_by_one:
            B, K, C, H, W = x.shape
            if not in_channel_first:
                B, H, W, K, C = x.shape
        else:
            B, C, H, W = x.shape
            if not in_channel_first:
                B, H, W, C = x.shape
        ctx.shape = (B, C, H, W)

        _fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd
        y = _fn(x, in_channel_first, out_channel_first, scans)

        return y

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        in_channel_first = ctx.in_channel_first
        out_channel_first = ctx.out_channel_first
        one_by_one = ctx.one_by_one
        scans = ctx.scans
        B, C, H, W = ctx.shape

        ys = ys.view(B, -1, C, H, W) if out_channel_first else ys.view(B, H, W, -1, C)
        _fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd
        y = _fn(ys, in_channel_first, out_channel_first, scans)

        if one_by_one:
            y = y.view(B, 4, -1, H, W) if in_channel_first else y.view(B, H, W, 4, -1)
        else:
            y = y.view(B, -1, H, W) if in_channel_first else y.view(B, H, W, -1)

        return y, None, None, None, None


class CrossMergeF(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ys: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
        ctx.in_channel_first = in_channel_first
        ctx.out_channel_first = out_channel_first
        ctx.one_by_one = one_by_one
        ctx.scans = scans

        B, K, C, H, W = ys.shape
        if not out_channel_first:
            B, H, W, K, C = ys.shape
        ctx.shape = (B, C, H, W)

        _fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd
        y = _fn(ys, in_channel_first, out_channel_first, scans)

        return y

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        in_channel_first = ctx.in_channel_first
        out_channel_first = ctx.out_channel_first
        one_by_one = ctx.one_by_one
        scans = ctx.scans
        B, C, H, W = ctx.shape

        if not one_by_one:
            if in_channel_first:
                x = x.view(B, C, H, W)
            else:
                x = x.view(B, H, W, C)
        else:
            if in_channel_first:
                x = x.view(B, 4, C, H, W)
            else:
                x = x.view(B, H, W, 4, C)

        _fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd
        x = _fn(x, in_channel_first, out_channel_first, scans)
        x = x.view(B, 4, C, H, W) if out_channel_first else x.view(B, H, W, 4, C)

        return x, None, None, None, None


# ===== Triton implementation =====
# Define the Triton kernel and its autograd wrappers only when Triton is
# available, so that importing this module does not fail when Triton is absent.
if WITH_TRITON:
    @triton.jit
    def triton_cross_scan_flex(
        x: tl.tensor,
        y: tl.tensor,
        x_layout: tl.constexpr,
        y_layout: tl.constexpr,
        operation: tl.constexpr,
        onebyone: tl.constexpr,
        scans: tl.constexpr,
        BC: tl.constexpr,
        BH: tl.constexpr,
        BW: tl.constexpr,
        DC: tl.constexpr,
        DH: tl.constexpr,
        DW: tl.constexpr,
        NH: tl.constexpr,
        NW: tl.constexpr,
    ):
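        # Flag conventions (inferred from the call sites below):
        #   x_layout / y_layout: 0 = channel-first (B, C, H, W), 1 = channel-last (B, H, W, C)
        #   operation: 0 = cross scan (read x, write the 4 routes of y),
        #              1 = cross merge (sum the 4 routes of y into x)
        #   onebyone: 0 = x holds one copy, 1 = x holds four per-route copies
        #   scans: 0 = cross scan, 1 = unidirectional, 2 = bidirectional, 3 = rotations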
        i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
        i_h, i_w = (i_hw // NW), (i_hw % NW)
        _mask_h = (i_h * BH + tl.arange(0, BH)) < DH
        _mask_w = (i_w * BW + tl.arange(0, BW)) < DW
        _mask_hw = _mask_h[:, None] & _mask_w[None, :]
        _for_C = min(DC - i_c * BC, BC)

        pos_h = (i_h * BH + tl.arange(0, BH)[:, None])
        pos_w = (i_w * BW + tl.arange(0, BW)[None, :])
        neg_h = (DH - i_h * BH - 1 - tl.arange(0, BH)[:, None])
        neg_w = (DW - i_w * BW - 1 - tl.arange(0, BW)[None, :])
        if scans == 0:
            HWRoute0 = pos_h * DW + pos_w
            HWRoute1 = pos_w * DH + pos_h
            HWRoute2 = neg_h * DW + neg_w
            HWRoute3 = neg_w * DH + neg_h
        elif scans == 1:
            HWRoute0 = pos_h * DW + pos_w
            HWRoute1 = HWRoute0
            HWRoute2 = HWRoute0
            HWRoute3 = HWRoute0
        elif scans == 2:
            HWRoute0 = pos_h * DW + pos_w
            HWRoute1 = HWRoute0
            HWRoute2 = neg_h * DW + neg_w
            HWRoute3 = HWRoute2
        elif scans == 3:
            HWRoute0 = pos_h * DW + pos_w
            HWRoute1 = neg_w * DH + pos_h
            HWRoute2 = neg_h * DW + neg_w
            HWRoute3 = pos_w * DH + neg_h

        _tmp1 = DC * DH * DW

        y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC)
        if y_layout == 0:
            p_y1 = y_ptr_base + HWRoute0
            p_y2 = y_ptr_base + _tmp1 + HWRoute1
            p_y3 = y_ptr_base + 2 * _tmp1 + HWRoute2
            p_y4 = y_ptr_base + 3 * _tmp1 + HWRoute3
        else:
            p_y1 = y_ptr_base + HWRoute0 * 4 * DC
            p_y2 = y_ptr_base + DC + HWRoute1 * 4 * DC
            p_y3 = y_ptr_base + 2 * DC + HWRoute2 * 4 * DC
            p_y4 = y_ptr_base + 3 * DC + HWRoute3 * 4 * DC

        if onebyone == 0:
            x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
            if x_layout == 0:
                p_x = x_ptr_base + HWRoute0
            else:
                p_x = x_ptr_base + HWRoute0 * DC

            if operation == 0:
                for idxc in range(_for_C):
                    _idx_x = idxc * DH * DW if x_layout == 0 else idxc
                    _idx_y = idxc * DH * DW if y_layout == 0 else idxc
                    _x = tl.load(p_x + _idx_x, mask=_mask_hw)
                    tl.store(p_y1 + _idx_y, _x, mask=_mask_hw)
                    tl.store(p_y2 + _idx_y, _x, mask=_mask_hw)
                    tl.store(p_y3 + _idx_y, _x, mask=_mask_hw)
                    tl.store(p_y4 + _idx_y, _x, mask=_mask_hw)
            elif operation == 1:
                for idxc in range(_for_C):
                    _idx_x = idxc * DH * DW if x_layout == 0 else idxc
                    _idx_y = idxc * DH * DW if y_layout == 0 else idxc
                    _y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw)
                    _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw)
                    _y3 = tl.load(p_y3 + _idx_y, mask=_mask_hw)
                    _y4 = tl.load(p_y4 + _idx_y, mask=_mask_hw)
                    tl.store(p_x + _idx_x, _y1 + _y2 + _y3 + _y4, mask=_mask_hw)

        else:
            x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
            if x_layout == 0:
                p_x1 = x_ptr_base + HWRoute0
                p_x2 = p_x1 + _tmp1
                p_x3 = p_x2 + _tmp1
                p_x4 = p_x3 + _tmp1
            else:
                p_x1 = x_ptr_base + HWRoute0 * 4 * DC
                p_x2 = p_x1 + DC
                p_x3 = p_x2 + DC
                p_x4 = p_x3 + DC

            if operation == 0:
                for idxc in range(_for_C):
                    _idx_x = idxc * DH * DW if x_layout == 0 else idxc
                    _idx_y = idxc * DH * DW if y_layout == 0 else idxc
                    tl.store(p_y1 + _idx_y, tl.load(p_x1 + _idx_x, mask=_mask_hw), mask=_mask_hw)
                    tl.store(p_y2 + _idx_y, tl.load(p_x2 + _idx_x, mask=_mask_hw), mask=_mask_hw)
                    tl.store(p_y3 + _idx_y, tl.load(p_x3 + _idx_x, mask=_mask_hw), mask=_mask_hw)
                    tl.store(p_y4 + _idx_y, tl.load(p_x4 + _idx_x, mask=_mask_hw), mask=_mask_hw)
            else:
                # Mask the loads as well so out-of-bounds tiles are never read.
                for idxc in range(_for_C):
                    _idx_x = idxc * DH * DW if x_layout == 0 else idxc
                    _idx_y = idxc * DH * DW if y_layout == 0 else idxc
                    tl.store(p_x1 + _idx_x, tl.load(p_y1 + _idx_y, mask=_mask_hw), mask=_mask_hw)
                    tl.store(p_x2 + _idx_x, tl.load(p_y2 + _idx_y, mask=_mask_hw), mask=_mask_hw)
                    tl.store(p_x3 + _idx_x, tl.load(p_y3 + _idx_y, mask=_mask_hw), mask=_mask_hw)
                    tl.store(p_x4 + _idx_x, tl.load(p_y4 + _idx_y, mask=_mask_hw), mask=_mask_hw)

    class CrossScanTritonF(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
            if one_by_one:
                if in_channel_first:
                    B, _, C, H, W = x.shape
                else:
                    B, H, W, _, C = x.shape
            else:
                if in_channel_first:
                    B, C, H, W = x.shape
                else:
                    B, H, W, C = x.shape
            B, C, H, W = int(B), int(C), int(H), int(W)
            BC, BH, BW = 1, 32, 32
            NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)

            ctx.in_channel_first = in_channel_first
            ctx.out_channel_first = out_channel_first
            ctx.one_by_one = one_by_one
            ctx.scans = scans
            ctx.shape = (B, C, H, W)
            ctx.triton_shape = (BC, BH, BW, NC, NH, NW)

            y = x.new_empty((B, 4, C, H * W)) if out_channel_first else x.new_empty((B, H * W, 4, C))
            triton_cross_scan_flex[(NH * NW, NC, B)](
                x.contiguous(), y,
                (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,
                BC, BH, BW, C, H, W, NH, NW
            )
            return y

        @staticmethod
        def backward(ctx, y: torch.Tensor):
            in_channel_first = ctx.in_channel_first
            out_channel_first = ctx.out_channel_first
            one_by_one = ctx.one_by_one
            scans = ctx.scans
            B, C, H, W = ctx.shape
            BC, BH, BW, NC, NH, NW = ctx.triton_shape
            if one_by_one:
                x = y.new_empty((B, 4, C, H, W)) if in_channel_first else y.new_empty((B, H, W, 4, C))
            else:
                x = y.new_empty((B, C, H, W)) if in_channel_first else y.new_empty((B, H, W, C))

            triton_cross_scan_flex[(NH * NW, NC, B)](
                x, y.contiguous(),
                (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,
                BC, BH, BW, C, H, W, NH, NW
            )
            return x, None, None, None, None
    class CrossMergeTritonF(torch.autograd.Function):
        @staticmethod
        def forward(ctx, y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
            if out_channel_first:
                B, _, C, H, W = y.shape
            else:
                B, H, W, _, C = y.shape
            B, C, H, W = int(B), int(C), int(H), int(W)
            BC, BH, BW = 1, 32, 32
            NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)
            ctx.in_channel_first = in_channel_first
            ctx.out_channel_first = out_channel_first
            ctx.one_by_one = one_by_one
            ctx.scans = scans
            ctx.shape = (B, C, H, W)
            ctx.triton_shape = (BC, BH, BW, NC, NH, NW)
            if one_by_one:
                x = y.new_empty((B, 4, C, H * W)) if in_channel_first else y.new_empty((B, H * W, 4, C))
            else:
                x = y.new_empty((B, C, H * W)) if in_channel_first else y.new_empty((B, H * W, C))
            triton_cross_scan_flex[(NH * NW, NC, B)](
                x, y.contiguous(),
                (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,
                BC, BH, BW, C, H, W, NH, NW
            )
            return x

        @staticmethod
        def backward(ctx, x: torch.Tensor):
            in_channel_first = ctx.in_channel_first
            out_channel_first = ctx.out_channel_first
            one_by_one = ctx.one_by_one
            scans = ctx.scans
            B, C, H, W = ctx.shape
            BC, BH, BW, NC, NH, NW = ctx.triton_shape
            y = x.new_empty((B, 4, C, H, W)) if out_channel_first else x.new_empty((B, H, W, 4, C))
            triton_cross_scan_flex[(NH * NW, NC, B)](
                x.contiguous(), y,
                (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,
                BC, BH, BW, C, H, W, NH, NW
            )
            # Exactly one gradient per forward input: y, in_channel_first,
            # out_channel_first, one_by_one, scans.
            return y, None, None, None, None


def cross_scan_fn(x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):
    # x: (B, C, H, W) or (B, H, W, C) -> y: (B, 4, C, L) or (B, L, 4, C)
    CSF = CrossScanTritonF if WITH_TRITON and x.is_cuda and (not force_torch) else CrossScanF
    if x.is_cuda:
        with torch.cuda.device(x.device):
            return CSF.apply(x, in_channel_first, out_channel_first, one_by_one, scans)
    else:
        return CrossScanF.apply(x, in_channel_first, out_channel_first, one_by_one, scans)


def cross_merge_fn(y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):
    # y: (B, 4, C, H, W) or (B, H, W, 4, C) -> x: (B, C, L) or (B, L, C)
    CMF = CrossMergeTritonF if WITH_TRITON and y.is_cuda and (not force_torch) else CrossMergeF
    if y.is_cuda:
        with torch.cuda.device(y.device):
            return CMF.apply(y, in_channel_first, out_channel_first, one_by_one, scans)
    else:
        return CrossMergeF.apply(y, in_channel_first, out_channel_first, one_by_one, scans)
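

# Round-trip sketch (shapes only); `force_torch=True` always selects the
# PyTorch autograd functions, otherwise CUDA tensors take the Triton path:
#   x  = torch.randn(2, 96, 14, 14)
#   xs = cross_scan_fn(x, force_torch=True)             # (2, 4, 96, 196)
#   y  = cross_merge_fn(xs.view(2, 4, 96, 14, 14),
#                       force_torch=True)               # (2, 96, 196) == 4 * x.flatten(2)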


# ===== Selective scan =====

WITH_SELECTIVESCAN_MAMBA = True
try:
    import selective_scan_cuda
except ImportError:
    WITH_SELECTIVESCAN_MAMBA = False

# Optional "oflex" CUDA backend; left as None when the extension is absent so
# only an explicit backend="oflex" request can fail.
try:
    import selective_scan_cuda_oflex
except ImportError:
    selective_scan_cuda_oflex = None


def selective_scan_torch(
    u: torch.Tensor,
    delta: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    D: torch.Tensor = None,
    delta_bias: torch.Tensor = None,
    delta_softplus=True,
    oflex=True,
    *args,
    **kwargs
):
    dtype_in = u.dtype
    Batch, K, N, L = B.shape
    KCdim = u.shape[1]
    Cdim = int(KCdim / K)
    assert u.shape == (Batch, KCdim, L)
    assert delta.shape == (Batch, KCdim, L)
    assert A.shape == (KCdim, N)
    assert C.shape == B.shape

    if delta_bias is not None:
        delta = delta + delta_bias[..., None]
    if delta_softplus:
        delta = torch.nn.functional.softplus(delta)

    u, delta, A, B, C = u.float(), delta.float(), A.float(), B.float(), C.float()
    B = B.view(Batch, K, 1, N, L).repeat(1, 1, Cdim, 1, 1).view(Batch, KCdim, N, L)
    C = C.view(Batch, K, 1, N, L).repeat(1, 1, Cdim, 1, 1).view(Batch, KCdim, N, L)
    deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
    deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
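
    # Sequential form of the discretized state-space recurrence computed below:
    #   h_t = exp(delta_t * A) * h_{t-1} + delta_t * B_t * u_t
    #   y_t = <C_t, h_t>            (the skip term D * u is added afterwards)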
    x = A.new_zeros((Batch, KCdim, N))
    ys = []
    for i in range(L):
        x = deltaA[:, :, i, :] * x + deltaB_u[:, :, i, :]
        y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
        ys.append(y)
    y = torch.stack(ys, dim=2)

    out = y if D is None else y + u * D.unsqueeze(-1)
    return out if oflex else out.to(dtype=dtype_in)


class SelectiveScanCuda(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, oflex=True, backend=None):
        ctx.delta_softplus = delta_softplus
        backend = "mamba" if WITH_SELECTIVESCAN_MAMBA and (backend is None) else backend
        ctx.backend = backend
        if backend == "oflex":
            out, x, *rest = selective_scan_cuda_oflex.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1, oflex)
        elif backend == "mamba":
            out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus)
        ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
        return out

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dout, *args):
        u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
        backend = ctx.backend
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        if backend == "oflex":
            du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_oflex.bwd(
                u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
            )
        elif backend == "mamba":
            du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
                u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus,
                False
            )
        return du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None


def selective_scan_fn(
    u: torch.Tensor,
    delta: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    D: torch.Tensor = None,
    delta_bias: torch.Tensor = None,
    delta_softplus=True,
    oflex=True,
    backend=None,
):
    fn = selective_scan_torch if backend == "torch" or (not WITH_SELECTIVESCAN_MAMBA) else SelectiveScanCuda.apply
    return fn(u, delta, A, B, C, D, delta_bias, delta_softplus, oflex, backend)
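

# Shape conventions for `selective_scan_fn`, with K scan directions and C
# channels per direction (as SS2D uses it below):
#   u, delta: (B, K * C, L)     A: (K * C, N)      B, C: (B, K, N, L)
#   D, delta_bias: (K * C,)     output: (B, K * C, L)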


def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc. networks; however,
    the original name is misleading, as 'Drop Connect' is a different form of dropout in a separate
    paper... See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
    I've opted to change the layer and argument names to 'drop path' rather than mix DropConnect
    as a layer name and use 'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks)."""

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob, 3):0.3f}'
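

# DASSLinear2d applies the weights of an nn.Linear as a pointwise convolution
# (1x1 conv2d for 4D inputs, kernel-1 conv1d for 3D inputs), so channel-first
# feature maps can be projected without permuting to channel-last and back.
# With groups=1 this matches x.permute(0, 2, 3, 1) @ weight.T + bias, permuted back.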
class DASSLinear2d(nn.Linear):
    def __init__(self, *args, groups=1, **kwargs):
        nn.Linear.__init__(self, *args, **kwargs)
        self.groups = groups

    def forward(self, x: torch.Tensor):
        if len(x.shape) == 4:
            return F.conv2d(x, self.weight[:, :, None, None], self.bias, groups=self.groups)
        elif len(x.shape) == 3:
            return F.conv1d(x, self.weight[:, :, None], self.bias, groups=self.groups)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        self_state_dict = self.state_dict()
        load_state_dict_keys = list(state_dict.keys())
        if prefix + "weight" in load_state_dict_keys:
            state_dict[prefix + "weight"] = state_dict[prefix + "weight"].view_as(self_state_dict["weight"])
        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)


class DASSLayerNorm2d(nn.LayerNorm):
    def __init__(self, *args, **kwargs):
        nn.LayerNorm.__init__(self, *args, **kwargs)

    def forward(self, x: torch.Tensor):
        x = x.permute(0, 2, 3, 1)
        x = nn.LayerNorm.forward(self, x)
        x = x.permute(0, 3, 1, 2)
        return x
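

# The patch embedding downsamples with two stride-(patch_size // 2)
# convolutions: with the default patch_size=4 this is two kernel-3, stride-2,
# padding-1 convs (4x overall) with a LayerNorm and a GELU in between.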
class DASSPatchEmbeddings(nn.Module):
    """
    This class turns `input_values` into the initial `hidden_states` (patch embeddings) of shape
    `(batch_size, seq_length, hidden_size)` to be consumed by a state-space model.
    """

    def __init__(self, patch_size=4, embed_dim=96):
        super().__init__()

        stride = patch_size // 2
        kernel_size = stride + 1
        padding = 1

        self.projection = nn.Sequential(
            nn.Conv2d(1, embed_dim // 2, kernel_size=kernel_size, stride=stride, padding=padding),
            DASSLayerNorm2d(embed_dim // 2),
            nn.GELU(),
            nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding),
            DASSLayerNorm2d(embed_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Add a channel axis and swap the last two axes so the convolutions see
        # the spectrogram as a single-channel 2D map.
        x = x.unsqueeze(1)
        x = x.transpose(2, 3)
        x = self.projection(x)
        return x


class DASSDownsample(nn.Module):
    """
    This class downsamples the input tensor using a convolutional layer followed by a layer normalization.
    """

    def __init__(self, dim, out_dim, use_norm=True):
        super().__init__()
        self.down = nn.Conv2d(dim, out_dim, kernel_size=3, stride=2, padding=1)
        self.norm = DASSLayerNorm2d(out_dim) if use_norm else nn.Identity()

    def forward(self, x):
        x = self.down(x)
        x = self.norm(x)
        return x


class DASSMlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = DASSLinear2d(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = DASSLinear2d(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
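

# SS2D is the 2D selective-scan (Mamba-style) mixer: pointwise in_proj ->
# depthwise conv + activation -> cross scan into four 1D routes -> selective
# scan over each route -> cross merge back to 2D -> LayerNorm -> pointwise out_proj.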
class SS2D(nn.Module):
    def __init__(
        self,
        # basic dims
        d_model=96,
        d_state=16,
        ssm_ratio=2.0,
        dt_rank="auto",
        act_layer=nn.SiLU,
        # dwconv
        d_conv=3,
        conv_bias=True,
        # others
        dropout=0.0,
        bias=False,
        # dt init
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        **kwargs,
    ):
        super().__init__()
        self.k_group = 4
        self.d_model = int(d_model)
        self.d_state = int(d_state)
        self.d_inner = int(ssm_ratio * d_model)
        self.dt_rank = int(math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank)
        self.forward_core = partial(self.forward_corev2, force_fp32=False, no_einsum=True)
        self.with_dconv = d_conv > 1

        # in proj
        self.in_proj = DASSLinear2d(self.d_model, self.d_inner, bias=bias)
        self.act: nn.Module = act_layer()

        # depthwise conv
        if self.with_dconv:
            self.conv2d = nn.Conv2d(
                in_channels=self.d_inner,
                out_channels=self.d_inner,
                groups=self.d_inner,
                bias=conv_bias,
                kernel_size=d_conv,
                padding=(d_conv - 1) // 2,
            )

        # x proj and dt proj
        self.x_proj = DASSLinear2d(self.d_inner, self.k_group * (self.dt_rank + self.d_state * 2), groups=self.k_group, bias=False)
        self.dt_projs = DASSLinear2d(self.dt_rank, self.k_group * self.d_inner, groups=self.k_group, bias=False)

        # out proj
        self.out_proj = DASSLinear2d(self.d_inner, self.d_model, bias=bias)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()

        # SSM parameters
        self.A_logs, self.Ds, self.dt_projs_weight, self.dt_projs_bias = self.init_dt_A_D(
            self.d_state, self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=self.k_group,
        )
        self.dt_projs.weight.data = self.dt_projs_weight.data.view(self.dt_projs.weight.shape)
        # The dt projection weights now live in `self.dt_projs`; drop the duplicate parameter.
        del self.dt_projs_weight

        self.out_norm = DASSLayerNorm2d(self.d_inner)

    @staticmethod
    def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4):
        dt_proj = nn.Linear(dt_rank, d_inner, bias=True)

        dt_init_std = dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Sample dt log-uniformly in [dt_min, dt_max], then store its softplus
        # inverse as the bias so that softplus(bias) recovers dt.
        dt = torch.exp(
            torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)

        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            dt_proj.bias.copy_(inv_dt)

        return dt_proj

    @staticmethod
    def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True):
        A = torch.arange(1, d_state + 1, dtype=torch.float32, device=device).view(1, -1).repeat(d_inner, 1).contiguous()
        A_log = torch.log(A)
        if copies > 0:
            A_log = A_log[None].repeat(copies, 1, 1).contiguous()
            if merge:
                A_log = A_log.flatten(0, 1)
        A_log = nn.Parameter(A_log)

        return A_log

    @staticmethod
    def D_init(d_inner, copies=-1, device=None, merge=True):
        D = torch.ones(d_inner, device=device)
        if copies > 0:
            D = D[None].repeat(copies, 1).contiguous()
            if merge:
                D = D.flatten(0, 1)
        D = nn.Parameter(D)

        return D

    @classmethod
    def init_dt_A_D(cls, d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4):
        dt_projs = [
            cls.dt_init(dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor)
            for _ in range(k_group)
        ]
        dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in dt_projs], dim=0))
        dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in dt_projs], dim=0))
        del dt_projs

        A_logs = cls.A_log_init(d_state, d_inner, copies=k_group, merge=True)
        Ds = cls.D_init(d_inner, copies=k_group, merge=True)
        return A_logs, Ds, dt_projs_weight, dt_projs_bias
    def forward_corev2(
        self,
        x: torch.Tensor,
        force_fp32=False,
        no_einsum=True,
    ):
        B, D, H, W = x.shape
        N = self.d_state
        L = H * W

        xs = cross_scan_fn(x, in_channel_first=True, out_channel_first=True)
        x_dbl = self.x_proj(xs.view(B, -1, L))
        dts, Bs, Cs = torch.split(x_dbl.view(B, self.k_group, -1, L), [self.dt_rank, N, N], dim=2)
        dts = dts.contiguous().view(B, -1, L)
        dts = self.dt_projs(dts)

        xs = xs.view(B, -1, L)
        dts = dts.contiguous().view(B, -1, L)
        As = -self.A_logs.to(torch.float32).exp()
        Ds = self.Ds.to(torch.float32)
        Bs = Bs.contiguous().view(B, self.k_group, N, L)
        Cs = Cs.contiguous().view(B, self.k_group, N, L)
        delta_bias = self.dt_projs_bias.view(-1).to(torch.float32)

        ys = selective_scan_fn(
            xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus=True, backend="mamba"
        ).view(B, self.k_group, -1, H, W)

        y = cross_merge_fn(ys, in_channel_first=True, out_channel_first=True)
        y = y.view(B, -1, H, W)
        y = self.out_norm(y)
        return y.to(x.dtype)

    def forward(self, x: torch.Tensor):
        x = self.in_proj(x)
        # The depthwise conv only exists when d_conv > 1.
        if self.with_dconv:
            x = self.conv2d(x)
        x = self.act(x)
        y = self.forward_core(x)
        out = self.dropout(self.out_proj(y))
        return out


class VSSBlock(nn.Module):
    def __init__(
        self,
        hidden_dim: int = 0,
        drop_path: float = 0,
        ssm_d_state: int = 1,
        ssm_ratio=1.0,
        ssm_dt_rank: Any = "auto",
        ssm_act_layer=nn.SiLU,
        ssm_conv: int = 3,
        ssm_conv_bias=False,
        ssm_drop_rate: float = 0,
        mlp_ratio=4.0,
        mlp_act_layer=nn.GELU,
        mlp_drop_rate: float = 0.0,
        use_checkpoint: bool = False,
        post_norm: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.ssm_branch = ssm_ratio > 0
        self.mlp_branch = mlp_ratio > 0
        self.use_checkpoint = use_checkpoint
        self.post_norm = post_norm

        if self.ssm_branch:
            self.norm = DASSLayerNorm2d(hidden_dim)
            self.op = SS2D(
                d_model=hidden_dim,
                d_state=ssm_d_state,
                ssm_ratio=ssm_ratio,
                dt_rank=ssm_dt_rank,
                act_layer=ssm_act_layer,
                d_conv=ssm_conv,
                conv_bias=ssm_conv_bias,
                dropout=ssm_drop_rate,
            )

        self.drop_path = DropPath(drop_path)

        if self.mlp_branch:
            self.norm2 = DASSLayerNorm2d(hidden_dim)
            mlp_hidden_dim = int(hidden_dim * mlp_ratio)
            self.mlp = DASSMlp(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=mlp_act_layer, drop=mlp_drop_rate)

    def _forward(self, input: torch.Tensor):
        x = input
        if self.ssm_branch:
            if self.post_norm:
                x = x + self.drop_path(self.norm(self.op(x)))
            else:
                x = x + self.drop_path(self.op(self.norm(x)))
        if self.mlp_branch:
            if self.post_norm:
                x = x + self.drop_path(self.norm2(self.mlp(x)))
            else:
                x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

    def forward(self, input: torch.Tensor):
        if self.use_checkpoint:
            return checkpoint.checkpoint(self._forward, input)
        else:
            return self._forward(input)
|
| class DASSLayer(nn.Module): |
|
|
| def __init__( |
| self, |
| input_dim, |
| depth, |
| drop_path=0.0, |
| norm_layer=DASSLayerNorm2d, |
| downsample=nn.Identity(), |
| use_checkpoint=False, |
| **kwargs, |
| ): |
| super().__init__() |
| self.input_dim = input_dim |
| self.use_checkpoint = use_checkpoint |
|
|
| self.blocks = nn.ModuleList() |
| for i in range(depth): |
| self.blocks.append( |
| VSSBlock(hidden_dim=input_dim, |
| drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, |
| norm_layer=norm_layer,use_checkpoint=use_checkpoint,**kwargs, |
| ) |
| ) |
| |
| self.downsample = downsample |
|
|
| def forward(self, x): |
| for block in self.blocks: |
| x = block(x) |
|
|
| x = self.downsample(x) |
| return x |
|
|


class DASSPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and
    loading pretrained models.
    """

    config_class = DASSConfig
    base_model_prefix = "dass"
    supports_gradient_checkpointing = False

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.LayerNorm):
            nn.init.constant_(module.bias, 0)
            nn.init.constant_(module.weight, 1.0)


class DASSModel(DASSPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # `num_layers` must be known before expanding an integer `dims` spec.
        self.num_layers = len(config.depths)
        dims = config.dims
        if isinstance(dims, int):
            dims = [int(dims * 2**i_layer) for i_layer in range(self.num_layers)]

        self.dims = dims
        self.patch_embeddings = DASSPatchEmbeddings(patch_size=config.patch_size, embed_dim=dims[0])

        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
        self.num_features = dims[-1]

        self.layers = nn.ModuleList()
        for i in range(self.num_layers):
            layer = DASSLayer(
                input_dim=self.dims[i],
                depth=config.depths[i],
                drop_path=dpr[sum(config.depths[:i]):sum(config.depths[:i + 1])],
                downsample=DASSDownsample(self.dims[i], self.dims[i + 1]) if i < self.num_layers - 1 else nn.Identity(),
                use_checkpoint=config.use_checkpoint,
            )
            self.layers.append(layer)

        self.norm = DASSLayerNorm2d(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

    def get_input_embeddings(self) -> DASSPatchEmbeddings:
        return self.patch_embeddings

    def forward(self, input_values: torch.Tensor):
        x = self.patch_embeddings(input_values)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        x = self.avgpool(x).flatten(1)
        return x
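

# Minimal usage sketch (the input layout is an assumption: DASSPatchEmbeddings
# treats `input_values` as a (batch, time_frames, frequency_bins) spectrogram):
#   model = DASSModel(DASSConfig())
#   features = model(torch.randn(2, 1024, 128))  # -> (2, model.num_features)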


class DASSForAudioClassification(DASSPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_classes = config.num_classes
        self.dass = DASSModel(config)
        self.head = nn.Linear(self.dass.num_features, self.num_classes) if self.num_classes > 0 else nn.Identity()

        # Initialize weights and apply final processing.
        self.post_init()

    def forward(
        self,
        input_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.dass(input_values)

        logits = self.head(outputs)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.loss_type == "ce":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
            elif self.config.loss_type == "bce":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits, outputs)
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs,
        )


__all__ = [
    "DASSModel",
    "DASSPreTrainedModel",
    "DASSForAudioClassification",
]