Image Similarity Estimation Using a Siamese Network with Triplet Loss: A Practical Guide
In the realm of machine learning, one of the fascinating challenges is determining the similarity between images. This task is pivotal in various applications, from detecting duplicates to facial recognition. A robust approach to this problem is using a Siamese Network combined with a Triplet Loss function.
In this article, we’ll explore how to build and train a Siamese Network to estimate image similarity, using a practical example from a GitHub repository.
Introduction to Siamese Networks
A Siamese Network is a type of neural network architecture that contains two or more identical subnetworks. These subnetworks are designed to generate feature vectors for each input, which can then be compared to estimate similarity. The key idea is that the same network is used to process each input, ensuring consistent and comparable outputs.
This architecture is especially useful for tasks like detecting duplicates, finding anomalies, and facial recognition. In the implementation we’ll explore, the network is set up with three identical subnetworks. Each network processes one of three images: an anchor image, a positive sample (similar to the anchor), and a negative sample (unrelated to the anchor).
The Power of Triplet Loss
To train the Siamese Network effectively, we use a Triplet Loss function. This loss function encourages the network to bring the anchor and positive samples closer in the feature space while pushing the anchor and negative samples further apart. The loss function is defined as:
L(A, P, N) = max(‖f(A) − f(P)‖² − ‖f(A) − f(N)‖² + margin, 0)
Here, A is the anchor image, P is the positive image, and N is the negative image. The function f(x) represents the embedding generated by the network, and the margin is a small positive value that helps ensure the network doesn’t collapse all embeddings to the same point.
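To make the formula concrete, here is a minimal sketch that evaluates the loss for a single toy triplet; the embedding values and margin below are made up purely for illustration:

import tensorflow as tf

# Hypothetical 3-dimensional embeddings for one triplet.
f_a = tf.constant([1.0, 0.0, 0.0])  # f(A), the anchor embedding
f_p = tf.constant([0.9, 0.1, 0.0])  # f(P), the positive embedding
f_n = tf.constant([0.0, 1.0, 0.0])  # f(N), the negative embedding
margin = 0.5

ap = tf.reduce_sum(tf.square(f_a - f_p))  # ||f(A) - f(P)||^2 = 0.02
an = tf.reduce_sum(tf.square(f_a - f_n))  # ||f(A) - f(N)||^2 = 2.0
loss = tf.maximum(ap - an + margin, 0.0)
print(loss.numpy())  # 0.0: the negative is already more than margin farther away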
Setting Up the Siamese Network
In this implementation, we start by loading the Totally Looks Like dataset, which contains images we use to create triplets for training the network.
1. Data Preparation
The dataset is processed using TensorFlow’s tf.data API to create triplets of images. This involves setting up a data pipeline where each triplet consists of an anchor, a positive, and a negative image. The images are preprocessed by resizing them to a target shape and normalizing their pixel values.
import tensorflow as tf

target_shape = (200, 200)  # every image is resized to this height and width

def preprocess_image(filename):
    # Read the file, decode it as a 3-channel JPEG, scale pixel values to
    # [0, 1] as float32, and resize to the target shape.
    image_string = tf.io.read_file(filename)
    image = tf.image.decode_jpeg(image_string, channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, target_shape)
    return image

def preprocess_triplets(anchor, positive, negative):
    # Apply the same preprocessing to all three images of a triplet.
    return (
        preprocess_image(anchor),
        preprocess_image(positive),
        preprocess_image(negative),
    )
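These preprocessing functions are then mapped over a tf.data pipeline that zips the three image streams together. Here is a minimal sketch of how that pipeline might be wired up, assuming anchor_images, positive_images, and negative_images are lists of file paths derived from the dataset:

anchor_dataset = tf.data.Dataset.from_tensor_slices(anchor_images)
positive_dataset = tf.data.Dataset.from_tensor_slices(positive_images)
negative_dataset = tf.data.Dataset.from_tensor_slices(negative_images)

# Zip the three streams so each element is an (anchor, positive, negative)
# triplet, then shuffle, preprocess, batch, and prefetch. The buffer and
# batch sizes here are illustrative.
dataset = tf.data.Dataset.zip((anchor_dataset, positive_dataset, negative_dataset))
dataset = dataset.shuffle(buffer_size=4096)
dataset = dataset.map(preprocess_triplets)
dataset = dataset.batch(32).prefetch(tf.data.AUTOTUNE)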
Here’s an example of triplets generated from the dataset, where the first two images in each row are similar (anchor and positive) and the third one is different (negative):
2. Building the Embedding Generator
The heart of our Siamese Network is the embedding generator, which is constructed using a ResNet50 model pretrained on ImageNet. By freezing the weights of most layers in ResNet50 and fine-tuning only the final layers, we can leverage transfer learning to reduce training time and improve performance.
from tensorflow.keras import layers, Model
from tensorflow.keras.applications import resnet

# Pretrained ResNet50 backbone, without its classification head.
base_cnn = resnet.ResNet50(
    weights="imagenet", input_shape=target_shape + (3,), include_top=False
)

# Projection head mapping backbone features to a 256-dimensional embedding.
flatten = layers.Flatten()(base_cnn.output)
dense1 = layers.Dense(512, activation="relu")(flatten)
dense1 = layers.BatchNormalization()(dense1)
dense2 = layers.Dense(256, activation="relu")(dense1)
dense2 = layers.BatchNormalization()(dense2)
output = layers.Dense(256)(dense2)

embedding = Model(base_cnn.input, output, name="Embedding")

# Freeze all layers until the layer conv5_block1_out; fine-tune the rest.
trainable = False
for layer in base_cnn.layers:
    if layer.name == "conv5_block1_out":
        trainable = True
    layer.trainable = trainable
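As a quick, optional check, you can count how many backbone layers will actually be updated during fine-tuning:

# Sanity check: only the layers from conv5_block1_out onward should train.
n_trainable = sum(1 for layer in base_cnn.layers if layer.trainable)
print(f"{n_trainable} of {len(base_cnn.layers)} backbone layers are trainable")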
3. Constructing the Siamese Network
The Siamese Network is set up to take in three images at a time (anchor, positive, and negative). A custom DistanceLayer computes the distances between the anchor-positive pair and the anchor-negative pair. The model is then trained to minimize the distance between similar images and maximize the distance between dissimilar ones.
class DistanceLayer(layers.Layer):
    """Computes the squared L2 distances between the anchor embedding and
    the positive and negative embeddings."""

    def call(self, anchor, positive, negative):
        ap_distance = tf.reduce_sum(tf.square(anchor - positive), -1)
        an_distance = tf.reduce_sum(tf.square(anchor - negative), -1)
        return (ap_distance, an_distance)

anchor_input = layers.Input(name="anchor", shape=target_shape + (3,))
positive_input = layers.Input(name="positive", shape=target_shape + (3,))
negative_input = layers.Input(name="negative", shape=target_shape + (3,))

# Each input is preprocessed for ResNet and run through the shared embedding
# model before the distances are computed.
distances = DistanceLayer()(
    embedding(resnet.preprocess_input(anchor_input)),
    embedding(resnet.preprocess_input(positive_input)),
    embedding(resnet.preprocess_input(negative_input)),
)

siamese_network = Model(
    inputs=[anchor_input, positive_input, negative_input], outputs=distances
)
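A throwaway forward pass is a cheap way to verify the wiring before training. This sketch feeds random tensors of the expected shape through the network and should return two per-sample distance vectors:

# Random stand-in batch of 4 images per branch (values are meaningless; this
# only checks shapes).
dummy = tf.random.uniform((4,) + target_shape + (3,))
ap_distance, an_distance = siamese_network([dummy, dummy, dummy])
print(ap_distance.shape, an_distance.shape)  # both (4,)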
4. Training and Evaluation
The model is trained using a custom training loop in which the triplet loss is computed and used to update the network’s weights. A Mean metric tracks the loss over training, and the model’s performance is then evaluated by inspecting the learned embeddings.
from tensorflow.keras import metrics

class SiameseModel(Model):
    """Wraps the Siamese network and implements the triplet loss in a
    custom training step."""

    def __init__(self, siamese_network, margin=0.5):
        super().__init__()
        self.siamese_network = siamese_network
        self.margin = margin
        self.loss_tracker = metrics.Mean(name="loss")

    def call(self, inputs):
        return self.siamese_network(inputs)

    def train_step(self, data):
        with tf.GradientTape() as tape:
            loss = self._compute_loss(data)
        # Update only the trainable weights of the underlying network.
        gradients = tape.gradient(loss, self.siamese_network.trainable_weights)
        self.optimizer.apply_gradients(
            zip(gradients, self.siamese_network.trainable_weights)
        )
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def _compute_loss(self, data):
        ap_distance, an_distance = self.siamese_network(data)
        # Triplet loss: the positive must be closer than the negative by at
        # least the margin; the loss is clamped at zero once it is.
        loss = ap_distance - an_distance
        loss = tf.maximum(loss + self.margin, 0.0)
        return loss

    @property
    def metrics(self):
        # Lets Keras reset the loss tracker automatically between epochs.
        return [self.loss_tracker]
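With the wrapper defined, training follows the standard Keras workflow. A minimal sketch, assuming train_dataset is the triplet pipeline built earlier (the optimizer and epoch count are illustrative):

from tensorflow.keras import optimizers

siamese_model = SiameseModel(siamese_network)
siamese_model.compile(optimizer=optimizers.Adam(0.0001))
siamese_model.fit(train_dataset, epochs=10)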
5. Inspecting the Results
After training, we can evaluate how well the network has learned to separate similar and dissimilar images by comparing the cosine similarity between the embeddings of the anchor-positive pairs and anchor-negative pairs.
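The embeddings themselves come from the shared embedding model. A minimal sketch of extracting them for one sample triplet, assuming train_dataset yields batched (anchor, positive, negative) tuples as built earlier:

# Pull one batch of triplets and embed each branch with the shared model.
anchor, positive, negative = next(iter(train_dataset))
anchor_embedding, positive_embedding, negative_embedding = (
    embedding(resnet.preprocess_input(anchor)),
    embedding(resnet.preprocess_input(positive)),
    embedding(resnet.preprocess_input(negative)),
)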
# CosineSimilarity is a stateful Keras metric, so it is reset between the two
# comparisons to keep the readings independent.
cosine_similarity = metrics.CosineSimilarity()

positive_similarity = cosine_similarity(anchor_embedding, positive_embedding)
print("Positive similarity:", positive_similarity.numpy())

cosine_similarity.reset_state()
negative_similarity = cosine_similarity(anchor_embedding, negative_embedding)
print("Negative similarity:", negative_similarity.numpy())
Below is an example of triplets as evaluated by the trained model. The network successfully identifies the similarities and differences between the images:
Conclusion
This implementation showcases how a Siamese Network with Triplet Loss can effectively estimate image similarity. By using a pretrained ResNet50 model and fine-tuning only its final layers, we can build a powerful model for a wide range of similarity-estimation tasks.
The full implementation, including detailed code and explanations, is available in the Siamese Network GitHub repository. Whether you’re a beginner looking to learn about Siamese Networks or an expert seeking to refine your approach, this project provides a comprehensive guide to understanding and applying this powerful technique.