import torch.nn.functional as F
import torch
import numpy as np
from parse_config import ConfigParser
import torch.nn as nn
from torch.autograd import Variable
import math
from utils import sigmoid_rampup, sigmoid_rampdown, cosine_rampup, cosine_rampdown, linear_rampup


def cross_entropy(output, target, M=3):
    return F.cross_entropy(output, target)


class elr_plus_loss(nn.Module):
    """ELR+ loss: soft-label cross-entropy plus an early-learning regularization
    term built from a temporal ensemble of past predictions (self.pred_hist)."""

    def __init__(self, num_examp, config, device, num_classes=10, beta=0.3):
        super(elr_plus_loss, self).__init__()
        self.config = config
        # Exponential moving average of softmax predictions, one row per training example.
        self.pred_hist = torch.zeros(num_examp, num_classes).to(device)
        self.q = 0
        self.beta = beta
        self.num_classes = num_classes

    def forward(self, iteration, output, y_labeled):
        y_pred = F.softmax(output, dim=1)
        y_pred = torch.clamp(y_pred, 1e-4, 1.0 - 1e-4)

        if self.num_classes == 100:
            # For 100-class setups, reweight the (possibly mixed-up) targets by the
            # running target estimates q and renormalize.
            y_labeled = y_labeled * self.q
            y_labeled = y_labeled / y_labeled.sum(dim=1, keepdim=True)

        # Soft-label cross-entropy term.
        ce_loss = torch.mean(-torch.sum(y_labeled * F.log_softmax(output, dim=1), dim=-1))
        # Early-learning regularization: minimizing log(1 - <q, p>) pushes the
        # predictions p toward the target estimates q.
        reg = ((1 - (self.q * y_pred).sum(dim=1)).log()).mean()
        # The regularization weight is ramped up over `coef_step` iterations.
        final_loss = ce_loss + sigmoid_rampup(iteration, self.config['coef_step']) * (
            self.config['train_loss']['args']['lambda'] * reg)
        return final_loss, y_pred.cpu().detach()

    def update_hist(self, epoch, out, index=None, mix_index=..., mixup_l=1):
        # Update the moving average of predictions for the examples selected by `index`,
        # then form the mixup-consistent target estimates q.
        y_pred_ = F.softmax(out, dim=1)
        self.pred_hist[index] = self.beta * self.pred_hist[index] + \
            (1 - self.beta) * y_pred_ / y_pred_.sum(dim=1, keepdim=True)
        self.q = mixup_l * self.pred_hist[index] + (1 - mixup_l) * self.pred_hist[index][mix_index]
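

# --- Minimal usage sketch (illustrative, not part of the original training code) ---
# Assumptions: a plain dict stands in for the ConfigParser object, the hyperparameter
# values ('coef_step', 'lambda') are placeholders, and `index` holds the dataset indices
# of the current batch so update_hist can refresh the matching rows of pred_hist.
if __name__ == '__main__':
    device = torch.device('cpu')
    config = {'coef_step': 40000, 'train_loss': {'args': {'lambda': 3.0}}}
    criterion = elr_plus_loss(num_examp=128, config=config, device=device,
                              num_classes=10, beta=0.3)

    output = torch.randn(16, 10, requires_grad=True)              # raw logits for one batch
    targets = F.one_hot(torch.randint(0, 10, (16,)), 10).float()  # soft (one-hot) labels
    index = torch.arange(16)                                      # dataset indices of the batch

    # Refresh the prediction history with detached outputs, then compute the loss.
    criterion.update_hist(epoch=0, out=output.detach(), index=index)
    loss, probs = criterion(iteration=0, output=output, y_labeled=targets)
    loss.backward()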