Implementing contrastive loss and triplet loss in TensorFlow

Update (2018/03/19): I wrote a blog post detailing how to implement triplet loss in TensorFlow.


You have to implement the contrastive loss or the triplet loss yourself, but once you have the pairs or triplets this is quite easy.


Contrastive Loss

Suppose your input consists of pairs of data together with a label for each pair (positive or negative, i.e. same class or different class). For instance, with images of size 28x28x1 as input:

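import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder)

left = tf.placeholder(tf.float32, [None, 28, 28, 1])
right = tf.placeholder(tf.float32, [None, 28, 28, 1])
# float32 with shape [None] so that it multiplies element-wise with d
label = tf.placeholder(tf.float32, [None])  # 0 if same, 1 if different
margin = 0.2

left_output = model(left)  # shape [None, 128]
right_output = model(right)  # shape [None, 128]

# squared Euclidean distance between the two embeddings, shape [None]
d = tf.reduce_sum(tf.square(left_output - right_output), 1)
# small epsilon because the gradient of sqrt is infinite at 0
d_sqrt = tf.sqrt(d + 1e-12)

loss = label * tf.square(tf.maximum(0., margin - d_sqrt)) + (1 - label) * d

loss = 0.5 * tf.reduce_mean(loss)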
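
To sanity-check the formula, here is a minimal sketch of the same per-pair loss in plain Python, so it runs without TensorFlow. The function name `contrastive_loss` and the toy vectors are made up for illustration. The label convention (0 if same, 1 if different) and the margin of 0.2 follow the code above.

```python
import math

def contrastive_loss(left, right, label, margin=0.2):
    """Loss for a single pair; label is 0 for same class, 1 for different class."""
    d = sum((a - b) ** 2 for a, b in zip(left, right))  # squared distance
    d_sqrt = math.sqrt(d)
    return 0.5 * (label * max(0.0, margin - d_sqrt) ** 2 + (1 - label) * d)

# Same-class pair at distance 0.5: penalised by half the squared distance (about 0.125).
same = contrastive_loss([0.0, 0.0], [0.3, 0.4], label=0)
# Different-class pair already farther apart than the margin: no penalty.
far = contrastive_loss([0.0, 0.0], [0.3, 0.4], label=1)
# Different-class pair at distance 0.1, inside the margin: penalised (about 0.005).
close = contrastive_loss([0.0, 0.0], [0.06, 0.08], label=1)
```

Only different-class pairs that are closer than the margin contribute to the loss, whereas same-class pairs always contribute unless their embeddings coincide.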

Triplet Loss

Same as with contrastive loss, but with triplets (anchor, positive, negative). You don’t need labels here.

anchor_output = ...  # shape [None, 128]
positive_output = ...  # shape [None, 128]
negative_output = ...  # shape [None, 128]

d_pos = tf.reduce_sum(tf.square(anchor_output - positive_output), 1)
d_neg = tf.reduce_sum(tf.square(anchor_output - negative_output), 1)

loss = tf.maximum(0., margin + d_pos - d_neg)
loss = tf.reduce_mean(loss)

The real trouble when implementing triplet loss or contrastive loss in TensorFlow is how to sample the triplets or pairs. I will focus on generating triplets because it is harder than generating pairs.

The easiest way is to generate them outside of the TensorFlow graph, i.e. in Python, and feed them to the network through the placeholders. Basically you select images 3 at a time, with the first two from the same class and the third from another class. You then run a forward pass on these triplets and compute the triplet loss.
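
A minimal sketch of this kind of sampling in plain Python could look like the following. The function name `sample_triplets` and its parameters are made up for illustration, and it assumes the class labels are available as a list of integers. It returns indices, which you would then use to pick the images that go into each batch.

```python
import random
from collections import defaultdict

def sample_triplets(labels, num_triplets, seed=None):
    """Return (anchor, positive, negative) index triplets drawn at random."""
    rng = random.Random(seed)

    # Group example indices by class.
    by_class = defaultdict(list)
    for idx, lab in enumerate(labels):
        by_class[lab].append(idx)

    # The anchor and the positive need two distinct examples of one class.
    anchor_classes = [c for c, idxs in by_class.items() if len(idxs) >= 2]
    if not anchor_classes or len(by_class) < 2:
        raise ValueError("need two classes, one with at least two examples")

    triplets = []
    for _ in range(num_triplets):
        pos_class = rng.choice(anchor_classes)
        neg_class = rng.choice([c for c in by_class if c != pos_class])
        anchor, positive = rng.sample(by_class[pos_class], 2)
        negative = rng.choice(by_class[neg_class])
        triplets.append((anchor, positive, negative))
    return triplets

labels = [0, 0, 1, 1, 2]
triplets = sample_triplets(labels, num_triplets=10, seed=0)
```

Every triplet produced this way has the right class structure, but nothing guarantees that its loss is positive.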

The issue here is that generating useful triplets is complicated. We want triplets with a positive loss (otherwise the loss is 0 and the network doesn’t learn).
To know whether a triplet has a positive loss you need to compute its loss, which already costs one forward pass through the network…
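
As a rough sketch of that check, assume the embeddings of all candidate images have already been computed with an extra forward pass and are available as plain lists of floats. The helper names `squared_distance` and `keep_active_triplets` are made up for illustration. The loss is the same `margin + d_pos - d_neg` as in the triplet loss code above.

```python
def squared_distance(u, v):
    """Squared Euclidean distance between two embedding vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))

def keep_active_triplets(embeddings, triplets, margin=0.2):
    """Keep only the triplets whose triplet loss is strictly positive."""
    active = []
    for a, p, n in triplets:
        d_pos = squared_distance(embeddings[a], embeddings[p])
        d_neg = squared_distance(embeddings[a], embeddings[n])
        if margin + d_pos - d_neg > 0:
            active.append((a, p, n))
    return active

# Toy 2-D embeddings: index 2 is far from the anchor 0, index 3 is close to it.
embeddings = [[0.0, 0.0], [0.1, 0.0], [1.0, 0.0], [0.2, 0.0]]
candidates = [(0, 1, 2), (0, 1, 3)]
active = keep_active_triplets(embeddings, candidates)
```

In this toy example the triplet with the distant negative is dropped because its loss is already 0, and only the triplet with the nearby negative is kept. The cost is the one described above: the embeddings have to be recomputed as the network changes.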

Clearly, implementing triplet loss in TensorFlow is hard. There are ways to make it more efficient than sampling in Python, but explaining them would require a whole blog post!
