import torch
import torch.nn as nn
import torch.nn.functional as F

def carafe_forward(
    features: torch.Tensor,
    masks: torch.Tensor,
    kernel_size: int,
    group_size: int,
    scale_factor: int
) -> torch.Tensor:
    """
    Pure-PyTorch implementation of the CARAFE upsampling operator.

    Args:
        features (Tensor): Input feature map of shape (N, C, H, W).
        masks (Tensor): Reassembly kernel weights of shape
            (N, kernel_size*kernel_size*group_size, H_out, W_out),
            where H_out = H*scale_factor and W_out = W*scale_factor.
        kernel_size (int): The spatial size of the reassembly kernel.
        group_size (int): The group size to divide channels. Must divide C.
        scale_factor (int): The upsampling factor.

    Returns:
        Tensor: Upsampled feature map of shape (N, C, H*scale_factor, W*scale_factor).
    """
    N, C, H, W = features.size()
    out_H, out_W = H * scale_factor, W * scale_factor
    num_channels = C // group_size

    # Split the channels into groups and extract a k x k patch around every
    # source position: F.unfold yields (N*G, C/G * k*k, H*W).
    features = features.view(N * group_size, num_channels, H, W)
    patches = F.unfold(features, kernel_size=kernel_size,
                       padding=(kernel_size - 1) // 2)
    patches = patches.view(N, group_size, num_channels, kernel_size * kernel_size, H * W)

    # Map each output position (i, j) back to its nearest source position
    # (i // scale_factor, j // scale_factor), then gather the patch centred
    # there for every output position.
    device = features.device
    h_idx = torch.div(torch.arange(out_H, device=device), scale_factor, rounding_mode='floor')
    w_idx = torch.div(torch.arange(out_W, device=device), scale_factor, rounding_mode='floor')
    h_idx = h_idx.unsqueeze(1).expand(out_H, out_W)
    w_idx = w_idx.unsqueeze(0).expand(out_H, out_W)
    base_idx = (h_idx * W + w_idx).view(-1)

    base_idx = base_idx.view(1, 1, 1, 1, -1).expand(N, group_size, num_channels, kernel_size * kernel_size, -1)
    gathered_patches = torch.gather(patches, -1, base_idx)
    gathered_patches = gathered_patches.view(N, group_size, num_channels, kernel_size * kernel_size, out_H, out_W)

    # Broadcast the per-position reassembly weights over the channels of each
    # group, then take the weighted sum over the k*k kernel positions.
    masks = masks.view(N, group_size, kernel_size * kernel_size, out_H, out_W)
    masks = masks.unsqueeze(2).expand(-1, -1, num_channels, -1, -1, -1)

    out = (gathered_patches * masks).sum(dim=3)
    out = out.view(N, C, out_H, out_W)
    return out
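

# Quick sanity check (an illustrative helper, not part of the original API):
# with kernel_size=1 and group_size=1 the extracted patch is just the nearest
# source pixel, so all-ones masks make carafe_forward reproduce
# nearest-neighbour interpolation exactly.
def _carafe_matches_nearest() -> bool:
    x = torch.randn(1, 4, 3, 3)
    # kernel_size^2 * group_size = 1 mask channel, every weight equal to one.
    masks = torch.ones(1, 1, 6, 6)
    out = carafe_forward(x, masks, kernel_size=1, group_size=1, scale_factor=2)
    ref = F.interpolate(x, scale_factor=2, mode='nearest')
    return torch.allclose(out, ref)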


class CARAFE(nn.Module):
    """
    CARAFE: Content-Aware ReAssembly of FEatures.

    This PyTorch module implements the CARAFE upsampling operator in pure Python.
    Given an input feature map and its corresponding reassembly masks, the module
    reassembles features from local patches to produce a higher-resolution output.

    Args:
        kernel_size (int): Reassembly kernel size.
        group_size (int): Group size for channel grouping (must divide the number of channels).
        scale_factor (int): Upsample ratio.
    """

    def __init__(self, kernel_size: int, group_size: int, scale_factor: int):
        super().__init__()
        self.kernel_size = kernel_size
        self.group_size = group_size
        self.scale_factor = scale_factor

    def forward(self, features: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
        return carafe_forward(features, masks, self.kernel_size, self.group_size, self.scale_factor)
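

# Illustrative usage of the bare CARAFE operator (the random, hand-normalised
# mask below is an assumption of this sketch; in practice the mask is predicted
# from the features themselves, as CARAFEPack does):
#
#   feats = torch.randn(2, 16, 8, 8)
#   k, g, s = 5, 1, 2
#   raw = torch.randn(2, k * k * g, 8 * s, 8 * s)
#   masks = F.softmax(raw.view(2, g, k * k, 16, 16), dim=2).view(2, k * k * g, 16, 16)
#   up = CARAFE(kernel_size=k, group_size=g, scale_factor=s)
#   out = up(feats, masks)  # -> torch.Size([2, 16, 16, 16])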


class CARAFEPack(nn.Module):
    """
    A unified package of the CARAFE upsampler that contains:

    1) A channel compressor.
    2) A content encoder that predicts reassembly masks.
    3) The CARAFE operator.

    This is modeled after the official CARAFE package.

    Args:
        channels (int): Number of input feature channels.
        scale_factor (int): Upsample ratio.
        up_kernel (int): Kernel size for the CARAFE operator.
        up_group (int): Group size for the CARAFE operator.
        encoder_kernel (int): Kernel size of the content encoder.
        encoder_dilation (int): Dilation rate for the content encoder.
        compressed_channels (int): Output channels for the channel compressor.
    """

    def __init__(
        self,
        channels: int,
        scale_factor: int,
        up_kernel: int = 5,
        up_group: int = 1,
        encoder_kernel: int = 3,
        encoder_dilation: int = 1,
        compressed_channels: int = 64
    ):
        super().__init__()
        self.channels = channels
        self.scale_factor = scale_factor
        self.up_kernel = up_kernel
        self.up_group = up_group
        self.encoder_kernel = encoder_kernel
        self.encoder_dilation = encoder_dilation
        self.compressed_channels = compressed_channels

        # 1x1 convolution that compresses the input channels before mask prediction.
        self.channel_compressor = nn.Conv2d(channels, compressed_channels, kernel_size=1)

        # Predicts up_kernel^2 * up_group reassembly weights for each of the
        # scale_factor^2 sub-positions of every source pixel; kernel_normalizer
        # later spreads them over the upsampled grid via pixel_shuffle.
        self.content_encoder = nn.Conv2d(
            compressed_channels,
            up_kernel * up_kernel * up_group * scale_factor * scale_factor,
            kernel_size=encoder_kernel,
            padding=(encoder_kernel - 1) * encoder_dilation // 2,
            dilation=encoder_dilation
        )

        nn.init.xavier_uniform_(self.channel_compressor.weight)
        nn.init.xavier_uniform_(self.content_encoder.weight)
        if self.channel_compressor.bias is not None:
            nn.init.constant_(self.channel_compressor.bias, 0)
        if self.content_encoder.bias is not None:
            nn.init.constant_(self.content_encoder.bias, 0)
    def kernel_normalizer(self, mask: torch.Tensor) -> torch.Tensor:
        """
        Normalize and reshape the predicted mask.

        Applies pixel shuffle to upsample the predicted kernel weights, then
        applies softmax normalization across the kernel dimension so that each
        reassembly kernel sums to one.

        Args:
            mask (Tensor): Predicted mask of shape (N, out_channels, H, W).

        Returns:
            Tensor: Normalized mask of shape (N, up_group * up_kernel^2, H*scale, W*scale).
        """
        mask = F.pixel_shuffle(mask, self.scale_factor)
        N, mask_c, H, W = mask.size()

        mask_channel = mask_c // (self.up_kernel ** 2)
        mask = mask.view(N, mask_channel, self.up_kernel ** 2, H, W)
        mask = F.softmax(mask, dim=2)
        mask = mask.view(N, mask_channel * self.up_kernel ** 2, H, W).contiguous()
        return mask
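
    # Worked example of the shapes above (numbers assumed for illustration):
    # with scale_factor=2, up_kernel=5 and up_group=1 the content encoder
    # emits 5*5*1*2*2 = 100 channels, pixel_shuffle turns (N, 100, H, W) into
    # (N, 25, 2H, 2W), and the softmax then normalises each group of 25
    # kernel weights independently at every output position.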

    def feature_reassemble(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        return carafe_forward(x, mask, self.up_kernel, self.up_group, self.scale_factor)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        compressed_x = self.channel_compressor(x)
        mask = self.content_encoder(compressed_x)
        mask = self.kernel_normalizer(mask)
        out = self.feature_reassemble(x, mask)
        return out


if __name__ == '__main__':
    # Run on the GPU when one is available; the implementation itself is
    # device-agnostic.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    x = torch.randn(2, 64, 32, 32, device=device)
    upsampler = CARAFEPack(channels=64, scale_factor=2, up_kernel=5, up_group=1).to(device)

    out = upsampler(x)
    print("Input shape: ", x.shape)
    print("Output shape:", out.shape)