I would like this yolo_layer.py to work the way multibox_loss_gmm.py does. It needs a lot of major fixes, but I'm lost. Please update yolo_layer.py so that it includes a working GMM. This is my yolo_layer.py:
import torch
import torch.nn as nn
import numpy as np
from utils.utils import bboxes_iou
class YOLOLayer(nn.Module):
    """
    Detection layer corresponding to yolo_layer.c of darknet.

    A 1x1 convolution maps the incoming feature map to
    ``n_anchors * (5 + n_classes)`` channels, i.e. for every anchor and grid
    cell: x, y, w, h, objectness and one score per class. Without labels the
    layer returns decoded boxes; with labels it returns the training losses.
    """

    def __init__(self, config_model, layer_no, in_ch, ignore_thre=0.7):
        """
        Args:
            config_model (dict): provides 'ANCHORS' (sequence of (w, h)
                pairs), 'ANCH_MASK' (one sequence of indices into ANCHORS
                per detection layer) and 'N_CLASSES'.
            layer_no (int): index of this detection layer; selects the
                stride from [32, 16, 8] and the anchor mask.
            in_ch (int): number of channels of the input feature map.
            ignore_thre (float): predictions whose best IoU with any truth
                box exceeds this value are excluded from the objectness
                loss unless they are the assigned target.
        """
        super(YOLOLayer, self).__init__()
        # Ratio between the units of ANCHORS / returned boxes and one grid cell.
        strides = [32, 16, 8]
        self.anchors = config_model['ANCHORS']
        self.anch_mask = config_model['ANCH_MASK'][layer_no]
        self.n_anchors = len(self.anch_mask)
        self.n_classes = config_model['N_CLASSES']
        self.ignore_thre = ignore_thre
        # Summed (not averaged) losses; `reduction='sum'` is the
        # non-deprecated spelling of `size_average=False`.
        self.l2_loss = nn.MSELoss(reduction='sum')
        self.bce_loss = nn.BCELoss(reduction='sum')
        self.stride = strides[layer_no]
        # Anchors of every scale, expressed in grid cells of this layer.
        self.all_anchors_grid = [(w / self.stride, h / self.stride)
                                 for w, h in self.anchors]
        # The subset of anchors this layer actually predicts.
        self.masked_anchors = [self.all_anchors_grid[i]
                               for i in self.anch_mask]
        # Rows of (0, 0, w, h): used to match truth boxes to anchors by
        # shape only.
        ref_anchors = np.zeros((len(self.all_anchors_grid), 4))
        ref_anchors[:, 2:] = np.array(self.all_anchors_grid)
        self.ref_anchors = torch.from_numpy(ref_anchors).float()
        self.conv = nn.Conv2d(in_channels=in_ch,
                              out_channels=self.n_anchors * (self.n_classes + 5),
                              kernel_size=1, stride=1, padding=0)

    def forward(self, xin, labels=None):
        """
        Args:
            xin (torch.Tensor): feature map of shape [batch, in_ch, H, W].
                The grid size is read from H only, so H == W is assumed.
            labels (torch.Tensor or None): shape [batch, max_boxes, 5] with
                columns (class, x, y, w, h). Columns 1-4 are multiplied by
                the grid size to obtain grid coordinates. Rows whose sum is
                not positive are treated as padding, and the valid rows are
                assumed to come first.

        Returns:
            If ``labels`` is None: a detached tensor of shape
            [batch, n_anchors * H * W, 5 + n_classes] whose first four
            channels are x, y, w, h multiplied by the layer stride.
            Otherwise: the tuple
            ``(loss, loss_xy, loss_wh, loss_obj, loss_cls, loss_l2)``.
        """
        output = self.conv(xin)

        batchsize = output.shape[0]
        fsize = output.shape[2]
        n_ch = 5 + self.n_classes  # channels per anchor
        device = output.device

        output = output.view(batchsize, self.n_anchors, n_ch, fsize, fsize)
        # shape: [batch, anchor, grid_y, grid_x, channels_per_anchor]
        # .contiguous() matters: clone() keeps the strides of a permuted
        # tensor, and flattening such a tensor with view() raises.
        output = output.permute(0, 1, 3, 4, 2).contiguous()

        # logistic activation for xy, obj, cls (w and h stay raw)
        output[..., :2] = torch.sigmoid(output[..., :2])
        output[..., 4:] = torch.sigmoid(output[..., 4:])

        # grid offsets and anchor sizes, broadcastable to
        # [batch, anchor, grid_y, grid_x]
        grid = torch.arange(fsize, dtype=torch.float32, device=device)
        x_shift = grid.view(1, 1, 1, fsize)
        y_shift = grid.view(1, 1, fsize, 1)
        anchors_cpu = torch.tensor(self.masked_anchors, dtype=torch.float32)
        anchors = anchors_cpu.to(device)
        w_anchors = anchors[:, 0].reshape(1, self.n_anchors, 1, 1)
        h_anchors = anchors[:, 1].reshape(1, self.n_anchors, 1, 1)

        # decoded boxes are never back-propagated through, so detach first
        pred = output.detach().clone()
        pred[..., 0] += x_shift
        pred[..., 1] += y_shift
        pred[..., 2] = torch.exp(pred[..., 2]) * w_anchors
        pred[..., 3] = torch.exp(pred[..., 3]) * h_anchors

        if labels is None:  # inference
            pred[..., :4] *= self.stride
            return pred.reshape(batchsize, -1, n_ch)

        # shape: [batch, anchor, grid_y, grid_x, 4 (= x, y, w, h)]
        pred = pred[..., :4]

        # ---- target assignment ----
        grid_shape = (batchsize, self.n_anchors, fsize, fsize)
        tgt_mask = torch.zeros(*grid_shape, 4 + self.n_classes,
                               dtype=torch.float32, device=device)
        obj_mask = torch.ones(*grid_shape,
                              dtype=torch.float32, device=device)
        tgt_scale = torch.zeros(*grid_shape, 2,
                                dtype=torch.float32, device=device)
        target = torch.zeros(*grid_shape, n_ch,
                             dtype=torch.float32, device=device)

        labels = labels.detach().cpu()
        nlabel = (labels.sum(dim=2) > 0).sum(dim=1)  # number of objects

        truth_x_all = labels[:, :, 1] * fsize
        truth_y_all = labels[:, :, 2] * fsize
        truth_w_all = labels[:, :, 3] * fsize
        truth_h_all = labels[:, :, 4] * fsize
        # grid cell that contains each truth centre
        truth_i_all = truth_x_all.to(torch.int16).numpy()
        truth_j_all = truth_y_all.to(torch.int16).numpy()

        anch_mask = list(self.anch_mask)

        for b in range(batchsize):
            n = int(nlabel[b])
            if n == 0:
                continue
            truth_box = torch.zeros(n, 4, dtype=torch.float32, device=device)
            truth_box[:, 2] = truth_w_all[b, :n]
            truth_box[:, 3] = truth_h_all[b, :n]

            # IoU between truth and reference anchors; x and y are still
            # zero here, so only the box shape is compared.
            anchor_ious_all = bboxes_iou(truth_box.cpu(), self.ref_anchors)
            best_n_all = anchor_ious_all.argmax(dim=1).tolist()

            truth_box[:, 0] = truth_x_all[b, :n]
            truth_box[:, 1] = truth_y_all[b, :n]

            # IoU between every prediction and every truth box
            pred_ious = bboxes_iou(
                pred[b].reshape(-1, 4), truth_box, xyxy=False)
            pred_best_iou, _ = pred_ious.max(dim=1)
            ignore = (pred_best_iou > self.ignore_thre).view(
                pred[b].shape[:3])
            # set mask to zero (ignore) if pred matches truth; converting
            # the comparison result first avoids subtracting a bool tensor
            obj_mask[b] = 1 - ignore.to(obj_mask.dtype)

            for ti, anchor_id in enumerate(best_n_all):
                if anchor_id not in anch_mask:
                    # best anchor belongs to another detection layer
                    continue
                a = anch_mask.index(anchor_id)  # anchor slot in this layer
                i = int(truth_i_all[b, ti])
                j = int(truth_j_all[b, ti])
                obj_mask[b, a, j, i] = 1
                tgt_mask[b, a, j, i, :] = 1
                # offset of the centre inside its grid cell
                target[b, a, j, i, 0] = truth_x_all[b, ti] - \
                    truth_x_all[b, ti].to(torch.int16).to(torch.float)
                target[b, a, j, i, 1] = truth_y_all[b, ti] - \
                    truth_y_all[b, ti].to(torch.int16).to(torch.float)
                # log-ratio of the truth size to the assigned anchor size
                target[b, a, j, i, 2] = torch.log(
                    truth_w_all[b, ti] / anchors_cpu[a, 0] + 1e-16)
                target[b, a, j, i, 3] = torch.log(
                    truth_h_all[b, ti] / anchors_cpu[a, 1] + 1e-16)
                target[b, a, j, i, 4] = 1
                target[b, a, j, i, 5 + int(labels[b, ti, 0])] = 1
                # larger weight for smaller boxes
                tgt_scale[b, a, j, i, :] = torch.sqrt(
                    2 - truth_w_all[b, ti] * truth_h_all[b, ti] / fsize / fsize)

        # ---- loss calculation ----
        # Zero out everything that must not contribute, identically on the
        # prediction and on the target side.
        output[..., 4] *= obj_mask
        output[..., :4] *= tgt_mask[..., :4]
        output[..., 5:] *= tgt_mask[..., 4:]
        output[..., 2:4] *= tgt_scale

        target[..., 4] *= obj_mask
        target[..., :4] *= tgt_mask[..., :4]
        target[..., 5:] *= tgt_mask[..., 4:]
        target[..., 2:4] *= tgt_scale

        bceloss = nn.BCELoss(weight=tgt_scale * tgt_scale,
                             reduction='sum')  # weighted BCEloss
        loss_xy = bceloss(output[..., :2], target[..., :2])
        loss_wh = self.l2_loss(output[..., 2:4], target[..., 2:4]) / 2
        loss_obj = self.bce_loss(output[..., 4], target[..., 4])
        loss_cls = self.bce_loss(output[..., 5:], target[..., 5:])
        loss_l2 = self.l2_loss(output, target)

        loss = loss_xy + loss_wh + loss_obj + loss_cls

        return loss, loss_xy, loss_wh, loss_obj, loss_cls, loss_l2
Here is the code for multibox_loss_gmm.py. As you can see, it implements a Gaussian function and an NLL_loss function:
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from data import coco as cfg
from ..box_utils import match, log_sum_exp
import math
def Gaussian(y, mu, var):
    """Evaluate a damped Gaussian-shaped density of ``y`` around ``mu``.

    ``var`` is used directly as the scale: the residual ``y - mu`` is
    divided by ``var`` itself (not by its square root) before squaring.
    The normalising denominator uses ``var + 0.3`` rather than ``var``;
    the offset appears only there, not inside the exponent.

    Args:
        y: observed values.
        mu: predicted means, broadcastable against ``y``.
        var: predicted scales, broadcastable against ``y``.

    Returns:
        Tensor of elementwise density values.
    """
    eps = 0.3
    normaliser = math.sqrt(2 * math.pi)
    standardised = (y - mu) / var
    kernel = torch.exp((standardised ** 2) / 2 * (-1))
    return kernel / normaliser / (var + eps)
def NLL_loss(bbox_gt, bbox_pred, bbox_var):
    """Return the Gaussian likelihood of the target boxes.

    The raw ``bbox_var`` is squashed into (0, 1) with a sigmoid and then
    passed to ``Gaussian`` as the scale. Despite the function's name, no
    logarithm or negation is applied here; the value returned is the
    density itself.

    Args:
        bbox_gt: target box values.
        bbox_pred: predicted means.
        bbox_var: unbounded predicted scale values.

    Returns:
        Tensor of elementwise density values from ``Gaussian``.
    """
    scale = torch.sigmoid(bbox_var)
    return Gaussian(bbox_gt, bbox_pred, scale)
class MultiBoxLoss_GMM(nn.Module):
    """SSD weighted loss for heads that predict a 4-component Gaussian mixture.

    Compute Targets:
        1) Produce confidence target indices by matching ground truth boxes
           with the prior boxes (delegated to ``match`` with the configured
           overlap threshold).
        2) Produce localization targets for the matched priors (also filled
           in by ``match``, using the configured variances).
        3) Hard negative mining: besides the matched (positive) priors, keep
           only the ``neg_pos`` x num_positive unmatched priors with the
           highest classification loss.

    Localization loss: negative log of the mixture likelihood of the target
    offsets, summed over positive priors and halved.
    Classification loss: logits are sampled as ``mu + sqrt(sigmoid(var)) *
    noise`` per component; 'Type-1' mixes the per-component cross-entropies,
    any other ``cls_type`` takes the negative log of the mixed softmax
    probabilities.
    """

    _NUM_COMPONENTS = 4
    _EPS = 10 ** -9   # keeps log() finite
    _BALANCE = 2.0    # divisor applied to the localization loss

    def __init__(self, num_classes, overlap_thresh, prior_for_matching,
                 bkg_label, neg_mining, neg_pos, neg_overlap, encode_target,
                 use_gpu=True, cls_type='Type-1'):
        super(MultiBoxLoss_GMM, self).__init__()
        self.use_gpu = use_gpu
        self.num_classes = num_classes
        self.threshold = overlap_thresh
        self.background_label = bkg_label
        self.encode_target = encode_target
        self.use_prior_for_matching = prior_for_matching
        self.do_neg_mining = neg_mining
        self.negpos_ratio = neg_pos
        self.neg_overlap = neg_overlap
        self.variance = cfg['variance']
        self.cls_type = cls_type

    @staticmethod
    def _mixture_weights(pi_logits):
        """Softmax the component logits against each other, elementwise.

        Args:
            pi_logits: sequence of equally sized tensors, one per component.

        Returns:
            Tuple of flattened tensors (one per component) that sum to one
            at every position.
        """
        stacked = torch.stack([p.reshape(-1) for p in pi_logits])
        return torch.softmax(stacked, dim=0).unbind(0)

    def _per_box_conf_loss(self, logits, weights, labels):
        """Classification loss of every row, mixed over the components.

        Args:
            logits: sequence of [rows, num_classes] tensors, one per component.
            weights: sequence of [rows, 1] mixture weights, one per component.
            labels: [rows, 1] long tensor of target class indices.

        Returns:
            [rows, 1] tensor of losses.
        """
        if self.cls_type == 'Type-1':
            # weighted sum of the per-component cross-entropies
            mixed = 0
            for comp_logits, comp_weight in zip(logits, weights):
                cross_entropy = (log_sum_exp(comp_logits)
                                 - comp_logits.gather(1, labels))
                mixed = mixed + cross_entropy * comp_weight
            return mixed
        # Type-2: negative log of the weighted sum of softmax probabilities
        mixed = 0
        for comp_logits, comp_weight in zip(logits, weights):
            mixed = mixed + torch.softmax(comp_logits, dim=1) * comp_weight
        return (-torch.log(mixed + self._EPS)).gather(1, labels)

    def forward(self, predictions, targets):
        """Compute the localization and classification losses.

        Args:
            predictions: sequence of 25 tensors: the priors, then
                (mu, var, pi) for each of the 4 localization components,
                then (mu, var, pi) for each of the 4 classification
                components. Localization tensors are indexed as
                [batch, num_priors, 4]; classification mu/var as
                [batch, num_priors, num_classes].
            targets: per-image tensors whose last column is the class label
                and whose remaining columns are the box.

        Returns:
            ``(loss_l, loss_c)``, both divided by the number of matched
            priors in the batch.

        Raises:
            ValueError: if ``predictions`` does not hold exactly 25 items.
        """
        predictions = tuple(predictions)
        k = self._NUM_COMPONENTS
        if len(predictions) != 1 + 6 * k:
            raise ValueError('expected %d prediction tensors, got %d'
                             % (1 + 6 * k, len(predictions)))
        priors, gmm = predictions[0], predictions[1:]
        loc, conf = gmm[:3 * k], gmm[3 * k:]
        loc_mu, loc_var, loc_pi = loc[0::3], loc[1::3], loc[2::3]
        conf_mu, conf_var, conf_pi = conf[0::3], conf[1::3], conf[2::3]

        num = loc_mu[0].size(0)
        priors = priors[:loc_mu[0].size(1), :]
        num_priors = priors.size(0)

        # match priors (default boxes) and ground truth boxes;
        # `match` fills loc_t / conf_t in place for image `idx`
        loc_t = torch.Tensor(num, num_priors, 4)
        conf_t = torch.LongTensor(num, num_priors)
        for idx in range(num):
            truths = targets[idx][:, :-1].data
            labels = targets[idx][:, -1].data
            defaults = priors.data
            match(self.threshold, truths, defaults, self.variance,
                  labels, loc_t, conf_t, idx)
        if self.use_gpu:
            loc_t = loc_t.cuda()
            conf_t = conf_t.cuda()

        pos = conf_t > 0  # [batch, num_priors] mask of matched priors

        # ---- localization loss (positive priors only) ----
        loc_t = loc_t[pos]  # [num_pos, 4]
        loc_weights = [w.view(-1, 4) for w in
                       self._mixture_weights([p[pos] for p in loc_pi])]
        likelihood = 0
        for mu, var, weight in zip(loc_mu, loc_var, loc_weights):
            likelihood = likelihood + weight * NLL_loss(loc_t, mu[pos], var[pos])
        loss_l = (-torch.log(likelihood + self._EPS) / self._BALANCE).sum()

        # ---- classification loss ----
        conf_weights = [w.view(p.size(0), -1) for w, p in
                        zip(self._mixture_weights(conf_pi), conf_pi)]
        # One noisy logit sample per component, reused for mining and for
        # the final loss. randn_like keeps the noise on the same device and
        # dtype as the predictions.
        sampled = []
        for mu, var in zip(conf_mu, conf_var):
            var = torch.sigmoid(var)
            sampled.append(mu + torch.sqrt(var) * torch.randn_like(var))

        # Hard negative mining: rank unmatched priors by their loss. Only
        # the ranking is needed, so no graph is recorded.
        with torch.no_grad():
            rank_loss = self._per_box_conf_loss(
                [s.view(-1, self.num_classes) for s in sampled],
                [w.view(-1, 1) for w in conf_weights],
                conf_t.view(-1, 1))
            rank_loss = rank_loss.view(pos.size(0), pos.size(1))
            rank_loss[pos] = 0  # positives never compete as negatives
            _, loss_idx = rank_loss.sort(1, descending=True)
            _, idx_rank = loss_idx.sort(1)
        num_pos = pos.long().sum(1, keepdim=True)
        num_neg = torch.clamp(self.negpos_ratio * num_pos,
                              max=pos.size(1) - 1)
        neg = idx_rank < num_neg.expand_as(idx_rank)

        # Confidence loss including positive and negative examples
        selected = pos | neg
        loss_c = self._per_box_conf_loss(
            [s[selected] for s in sampled],
            [w[selected].view(-1, 1) for w in conf_weights],
            conf_t[selected].view(-1, 1)).sum()

        # Normalise by the number of matched priors; the clamp avoids a
        # 0/0 NaN when nothing was matched.
        n_matched = num_pos.sum().clamp(min=1).to(loss_l.dtype)
        return loss_l / n_matched, loss_c / n_matched