Tensorboard
小于 1 分钟
数据处理
def contrastive_loss(z, labels, margin=1.0):
'''
z是隐空间中的编码,labels是样本的标签
计算同类样本对和异类样本对
'''
n_samples = z.shape[0]
same_class_mask = (labels.view(n_samples, 1) == labels.view(1, n_samples))
diff_class_mask = ~ same_class_mask
same_class_indices = torch.where(same_class_mask)
diff_class_indices = torch.where(diff_class_mask)
same_class_pairs = list(zip(same_class_indices[0], same_class_indices[1]))
diff_class_pairs = list(zip(diff_class_indices[0], diff_class_indices[1]))
# 计算同类样本对的损失函数
same_class_losses = []
for i, j in same_class_pairs:
dist = torch.norm(z[i] - z[j])
same_class_losses.append(dist ** 2)
same_class_loss = torch.mean(torch.stack(same_class_losses))
# 计算异类样本对的损失函数
diff_class_losses = []
for i, j in diff_class_pairs:
dist = torch.norm(z[i] - z[j])
diff_class_losses.append(F.relu(margin - dist) ** 2)
diff_class_loss = torch.mean(torch.stack(diff_class_losses))
# 添加正则项
# reg_loss = 0.0
# for param in model.parameters():
# reg_loss += torch.sum(param ** 2)
# 将所有损失函数加权求和,并返回最终的损失值
total_loss = same_class_loss + diff_class_loss # + 0.01 * reg_loss
return total_loss