Comparing images for similarity using SIFT + bag of words

I'm trying to write a program that identifies which images in a directory are similar to a query image, which is similar to but often slightly different from images in the directory. There are thousands of images in the directory. This question is related to Simple and fast method to compare images for similarity.

I have a few goals:

  • Using a query image, identify similar images in a directory of images
    • The query image may differ slightly from the images in the directory, for example by being cropped or saved at a different quality.
  • The program should be pretty fast (able to identify similar images in a few seconds at most)

I know this problem has been studied extensively. The chapter "Building a Reverse Image Search Engine: Understanding Embeddings" from "Practical Deep Learning for Cloud, Mobile, and Edge" describes several approaches to it.

I began writing a program to do this using a SIFT (scale-invariant feature transform) + bag of words approach. I don't have much experience in this area. The program I wrote works for an identical query image, and reasonably well for a slightly modified one, but once the query becomes more dissimilar, it no longer finds the right image.

I have two questions:

  1. Is the approach I'm using for this good, and if not, is there a better approach?
  2. Is there anything in my program that might be causing the searches to be inaccurate for dissimilar images?

This is how the program works:

  1. Go through every image, get its descriptors with SIFT, and build a list of these descriptors.
  2. Using k-means, find the centroids of the list of descriptors. This is the "dictionary".
  3. Go through every image again and match its descriptors against the centroids using knnMatch with k=1. Use each match's trainIdx to build a histogram of visual words for the image.
  4. Normalize each image's histogram by dividing the count of each "word" by the sum of the "words".
  5. Use knnMatch with k=1 with the query image's descriptors and the centroids. Go through the matches and create a normalized histogram.
  6. Use knnMatch with k=1 between the query image's histogram and the histograms of all of the images in the database. This produces a list of matches, ordered by similarity to the query image.

import numpy as np
import cv2
import os
from matplotlib import pyplot as plt

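# SIFT is in the contrib package here; in OpenCV 4.4+ it is available as cv2.SIFT_create()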
sift = cv2.xfeatures2d.SIFT_create()

FLANN_INDEX_KDTREE = 0
index_params = dict(algorithm = FLANN_INDEX_KDTREE, trees = 100)
search_params = dict(checks = 100)
flann = cv2.FlannBasedMatcher(index_params, search_params)
bf = cv2.BFMatcher()

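# Load the query image in grayscale; 'path' and the database connection below are placeholders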
img1 = cv2.imread('path',0)
db = # load database

kp1, des1 = sift.detectAndCompute(img1,None)

load = False
clusters = 800

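# One-time build: compute the visual vocabulary and per-image histograms, then store them in the database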
if load:
    db.query('DELETE FROM centroids')
    db.query('DELETE FROM histogram')

    descriptors = []

    for file in os.listdir('path'):
        if file.endswith('.png'):
            img = cv2.imread('path/{}'.format(file), 0)

            kp, des = sift.detectAndCompute(img,None)

            if des is None:
                continue

            descriptors.extend(des)

    descriptors = np.float32(descriptors)

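    # Cluster all SIFT descriptors into 'clusters' visual words (the dictionary)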
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 5, .01)
    centroids = cv2.kmeans(descriptors, clusters, None, criteria, 1, cv2.KMEANS_PP_CENTERS)[2]
    db.insert('centroids', d = np.ndarray.dumps(centroids))

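    # Second pass: build a normalized visual-word histogram for every image and store it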
    for file in os.listdir('path'):
        counter = np.zeros((clusters,), dtype=np.uint32)
        if file.endswith('.png'):
            img = cv2.imread('path/{}'.format(file),0)
            kp, d = sift.detectAndCompute(img,None)
            if d is None:
                continue

            matches = bf.knnMatch(d, centroids, k=1)

            for match in matches:
                counter[match[0].trainIdx] += 1

            counter_sum = np.sum(counter)
            counter = [float(n)/counter_sum for n in counter]

            db.insert('histogram', frame_id = file, count=','.join(np.char.mod('%f', counter)))

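# Load the stored histograms for all database images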
histograms_db = list(db.query('SELECT * FROM histogram'))
histograms = []
for histogram in histograms_db:
    histogram = histogram['count'].split(',')
    histograms.append(histogram)
histograms = np.array(histograms)

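# Build the query image's normalized histogram against the same centroids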
counter = np.zeros((clusters,), dtype=np.uint32)

centroids = np.loads(db.query('SELECT * FROM centroids')[0]['d'])
matches = bf.knnMatch(des1, centroids, k=1)

for match in matches:
    counter[match[0].trainIdx] += 1

counter_sum = np.sum(counter)
counter = [float(n)/counter_sum for n in counter]

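# Rank the database histograms by their distance to the query histogram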
matches = bf.knnMatch(np.float32([counter]), np.float32(histograms), k=1)

for match in matches[0]:
    print("{} {}".format(histograms_db[match.trainIdx]['frame_id'], match.distance))
    name = histograms_db[match.trainIdx]['frame_id']
Answer:

You can use any approximate nearest neighbor search library. For example, try Faiss.
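
To make this concrete, here is a minimal sketch of how the normalized bag-of-words histograms from the question could be indexed and searched with Faiss. The array shapes and the random stand-in data are assumptions; in the real program you would pass in the histograms built above.

import numpy as np
import faiss

clusters = 800

# Stand-ins for the normalized histograms built in the question; replace with the real arrays.
db_histograms = np.random.rand(5000, clusters).astype('float32')
query_histogram = np.random.rand(1, clusters).astype('float32')

index = faiss.IndexFlatL2(clusters)  # exact L2 search over the histogram vectors
index.add(db_histograms)             # index every database image's histogram

distances, indices = index.search(query_histogram, 5)  # 5 most similar images
print(indices[0], distances[0])

For a few thousand images an exact IndexFlatL2 search is already fast; for much larger collections, an approximate index such as faiss.IndexIVFFlat (which must be trained on the histograms before adding them) trades a little accuracy for speed.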