Tensorboard

Huahuatii小于 1 分钟深度学习Deep LearningTensorboard

数据处理

def contrastive_loss(z, labels, margin=1.0):
    '''
    z是隐空间中的编码,labels是样本的标签
    计算同类样本对和异类样本对
    '''
    n_samples = z.shape[0]
    same_class_mask = (labels.view(n_samples, 1) == labels.view(1, n_samples))
    diff_class_mask = ~ same_class_mask
    same_class_indices = torch.where(same_class_mask)
    diff_class_indices = torch.where(diff_class_mask)
    same_class_pairs = list(zip(same_class_indices[0], same_class_indices[1]))
    diff_class_pairs = list(zip(diff_class_indices[0], diff_class_indices[1]))
    # 计算同类样本对的损失函数
    same_class_losses = []
    for i, j in same_class_pairs:
        dist = torch.norm(z[i] - z[j])
        same_class_losses.append(dist ** 2)
    same_class_loss = torch.mean(torch.stack(same_class_losses))
    # 计算异类样本对的损失函数
    diff_class_losses = []
    for i, j in diff_class_pairs:
        dist = torch.norm(z[i] - z[j])
        diff_class_losses.append(F.relu(margin - dist) ** 2)
    diff_class_loss = torch.mean(torch.stack(diff_class_losses))
    # 添加正则项
    # reg_loss = 0.0
    # for param in model.parameters():
    #     reg_loss += torch.sum(param ** 2)
    # 将所有损失函数加权求和,并返回最终的损失值
    total_loss = same_class_loss + diff_class_loss # + 0.01 * reg_loss
    return total_loss