import torch
from konfai.network import network, blocks


class ConvBlock(network.ModuleArgsDict):
    """Two 3x3x3 convolution stages, each followed by instance norm and LeakyReLU.

    Args:
        in_channels: Number of channels fed to the first convolution.
        out_channels: Number of channels produced by both convolutions.
        stride: Stride of the first convolution only; the second one always
            uses a stride of 1.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1) -> None:
        super().__init__()

        # Stage 0: changes the channel count and applies the caller's stride.
        first_conv = torch.nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=True,
        )
        self.add_module("Conv_0", first_conv)
        self.add_module("Norm_0", torch.nn.InstanceNorm3d(num_features=out_channels, affine=True))
        self.add_module("Activation_0", torch.nn.LeakyReLU(negative_slope=0.01, inplace=True))

        # Stage 1: keeps both the channel count and the stride (1) fixed.
        second_conv = torch.nn.Conv3d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
        )
        self.add_module("Conv_1", second_conv)
        self.add_module("Norm_1", torch.nn.InstanceNorm3d(num_features=out_channels, affine=True))
        self.add_module("Activation_1", torch.nn.LeakyReLU(negative_slope=0.01, inplace=True))


class UNetHead(network.ModuleArgsDict):
    """Output head: a 1x1x1 convolution to ``nb_class`` channels, then softmax over dim 1.

    Args:
        in_channels: Number of channels entering the head.
        nb_class: Number of output channels of the convolution.
    """

    def __init__(self, in_channels: int, nb_class: int) -> None:
        super().__init__()
        projection = torch.nn.Conv3d(
            in_channels=in_channels,
            out_channels=nb_class,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.add_module("Conv", projection)
        self.add_module("Softmax", torch.nn.Softmax(dim=1))


class UNetBlock(network.ModuleArgsDict):
    """One level of the U-Net, built recursively from a list of channel counts.

    Args:
        channels: Channel counts; ``channels[0]`` is this level's input width,
            ``channels[1]`` its working width, and any further entries
            describe the deeper levels.
        i: Depth of this level; 0 is the outermost one.
    """

    def __init__(self, channels, i: int = 0) -> None:
        super().__init__()
        level_in = channels[0]
        level_width = channels[1]
        nested = i > 0

        # Only nested levels downsample; the outermost keeps the input resolution.
        down_stride = 2 if nested else 1
        self.add_module(
            "DownConvBlock",
            ConvBlock(in_channels=level_in, out_channels=level_width, stride=down_stride),
        )

        if len(channels) > 2:
            # A deeper level exists. The up block takes twice this level's width,
            # which matches the deeper level's transposed-conv output
            # concatenated with its own input.
            self.add_module("UNetBlock", UNetBlock(channels[1:], i + 1))
            self.add_module(
                "UpConvBlock",
                ConvBlock(in_channels=level_width * 2, out_channels=level_width),
            )

        if nested:
            # Map back to the parent's width with a stride-2 transposed convolution.
            upsample = torch.nn.ConvTranspose3d(
                in_channels=level_width,
                out_channels=level_in,
                kernel_size=2,
                stride=2,
                padding=0,
            )
            self.add_module("CONV_TRANSPOSE", upsample)
            # NOTE(review): branches 0 and 1 are presumably this level's input and
            # the upsampled features - confirm against konfai's branch semantics.
            self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 1])


class ClipAndNormalize(torch.nn.Module):
    """Clamp the input to ``[clip_min, clip_max]``, then standardise with ``mean`` and ``std``.

    The four statistics are registered as single-element buffers allocated
    with ``torch.empty``, so their values are undefined until something
    overwrites them.
    NOTE(review): presumably they are filled from a checkpoint - confirm.
    """

    def __init__(self) -> None:
        super().__init__()
        for buffer_name in ("clip_min", "clip_max", "mean", "std"):
            self.register_buffer(buffer_name, torch.empty(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``(clamp(x, clip_min, clip_max) - mean) / std``."""
        clipped = torch.clamp(x, self.clip_min, self.clip_max)
        centered = clipped - self.mean
        return centered / (self.std)


class Unet_TS_CT(network.Network):
    """3D U-Net made of an intensity normalisation stage, a recursive U-Net body and a head.

    The head is first built with 118 output classes; ``load`` rebuilds its
    convolution so that it matches the checkpoint being loaded.

    Args:
        optimizer: Optimizer loader forwarded to ``network.Network``.
        schedulers: Learning-rate scheduler loaders forwarded to ``network.Network``.
        outputs_criterions: Criterion loaders forwarded to ``network.Network``.
        channels: Channel counts of the U-Net; ``channels[0]`` is the number of
            input channels and ``channels[1]`` the number of channels given to
            the head.
    """

    # NOTE(review): the defaults below are evaluated once, when the class is
    # defined, and are therefore shared between instances. They are kept as-is
    # because they are part of the public signature - confirm whether konfai's
    # configuration loading relies on them before changing.
    def __init__(
        self,
        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
        schedulers: dict[str, network.LRSchedulersLoader] = {
            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
        },
        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
        channels: list[int] = [1, 32, 64, 128, 320],
    ) -> None:
        super().__init__(
            in_channels=channels[0],
            optimizer=optimizer,
            schedulers=schedulers,
            outputs_criterions=outputs_criterions,
            patch=None,
            dim=3,
        )
        self.add_module("ClipAndNormalize", ClipAndNormalize())
        self.add_module("UNetBlock", UNetBlock(channels))
        self.add_module("Head", UNetHead(channels[1], 118))

    def load(
        self,
        state_dict: dict[str, dict[str, torch.Tensor] | int],
        init: bool = True,
        ema: bool = False,
    ):
        """Resize the head convolution to fit the checkpoint, then delegate to the parent ``load``.

        Args:
            state_dict: Checkpoint content; it must provide the tensor stored at
                ``state_dict["Model"]["Unet_TS_CT"]["Head.Conv.weight"]``.
            init: Forwarded unchanged to the parent ``load``.
            ema: Forwarded unchanged to the parent ``load``.
        """
        head_weight = state_dict["Model"]["Unet_TS_CT"]["Head.Conv.weight"]
        # The first two dimensions of the stored weight give the output and
        # input channel counts of the convolution to rebuild.
        nb_class, in_channels = head_weight.shape[:2]
        resized_conv = torch.nn.Conv3d(
            in_channels=in_channels,
            out_channels=nb_class,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self["Head"].add_module("Conv", resized_conv)
        super().load(state_dict, init, ema)