Tensorflow Recommender - ScaNN passing embedding as query


I want to pass a query embedding to ScaNN instead of a query model. What data type should I use for this?

My query would look like this: [1, 0.3, 0.4]. My candidate embeddings would be something like: [[0.2, 1, 0.4], [0.3, 0.1, 0.56]].

All the examples I see pass a query model, not the embedding itself.

I tried passing a NumPy array, but it didn't work.


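Embeddings are just vectors that your model produces, in this case with the tf.keras.layers.Embedding layer. For example, inside a model's __init__ (the class name and the default embedding size below are placeholders):

import tensorflow as tf

class QueryModel(tf.keras.Model):
  def __init__(self, str_features, vocabularies, embedding_dimension=32):
    super().__init__()
    self.embedding_dimension = embedding_dimension
    self._embeddings = {}
    # Compute embeddings for string features
    for feature_name in str_features:
      vocabulary = vocabularies[feature_name]
      self._embeddings[feature_name] = tf.keras.Sequential([
          # Map each string to an integer index (index 0 = out-of-vocabulary)
          tf.keras.layers.StringLookup(
              vocabulary=vocabulary, mask_token=None),
          # Map each index to a learned vector; +1 row for the OOV index
          tf.keras.layers.Embedding(len(vocabulary) + 1,
                                    self.embedding_dimension)
      ])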
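Under the hood, an embedding layer is a trainable lookup table. Here is a minimal NumPy sketch of what the StringLookup + Embedding pair computes, assuming a made-up vocabulary and small random (untrained) weights; `string_to_ids` and `embed` are hypothetical helpers for illustration, not Keras APIs.

```python
import numpy as np

# Hypothetical vocabulary and embedding size, for illustration only
vocabulary = ["action", "comedy", "drama"]
embedding_dimension = 3

# Like StringLookup(mask_token=None): index 0 is the out-of-vocabulary bucket,
# known strings get indices 1..len(vocabulary)
lookup = {token: i + 1 for i, token in enumerate(vocabulary)}

def string_to_ids(tokens):
    return np.array([lookup.get(t, 0) for t in tokens], dtype=np.int64)

# Like Embedding(len(vocabulary) + 1, embedding_dimension): one row per index.
# Keras learns these rows during training; here they are just small random values.
rng = np.random.default_rng(0)
table = rng.uniform(-0.05, 0.05, size=(len(vocabulary) + 1, embedding_dimension)).astype(np.float32)

def embed(tokens):
    # An embedding lookup is simply a row gather from the table
    return table[string_to_ids(tokens)]

vectors = embed(["drama", "unknown-genre"])
print(vectors.shape)  # (2, 3)
```

The resulting rows are exactly the kind of vectors you would hand to ScaNN as queries or candidates.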

You can also use another model such as a Sentence Transformer to create embeddings.

from sentence_transformers import SentenceTransformer
sentences = ["This is an example sentence", "Each sentence is converted"]

model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
embeddings = model.encode(sentences)
print(embeddings)

You do not need to pass a model to ScaNN; as the ScaNN documentation describes, you can pass the embeddings directly.
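To make the expected data types concrete, here are the question's vectors as float32 NumPy arrays: candidates as a 2-D array (one row per candidate) and the query as a 1-D array of the same dimension, or 2-D with one row per query for batch search. No ScaNN call is made here; `exact_top_k` is a hypothetical helper doing an exact brute-force dot-product search, which is the result ScaNN approximates on large candidate sets.

```python
import numpy as np

# The question's vectors; ScaNN's Python API works with float32 NumPy arrays
candidates = np.array([[0.2, 1.0, 0.4],
                       [0.3, 0.1, 0.56]], dtype=np.float32)  # (num_candidates, dim)
query = np.array([1.0, 0.3, 0.4], dtype=np.float32)          # (dim,): one query
queries = query[np.newaxis, :]                               # (num_queries, dim): a batch

def exact_top_k(queries, candidates, k):
    """Hypothetical helper: exact dot-product top-k, the result ScaNN approximates."""
    scores = queries @ candidates.T                    # (num_queries, num_candidates)
    idx = np.argsort(-scores, axis=1)[:, :k]           # highest scores first
    return idx, np.take_along_axis(scores, idx, axis=1)

neighbors, scores = exact_top_k(queries, candidates, k=2)
print(neighbors)  # [[0 1]]
```

If you use the TFRS `tfrs.layers.factorized_top_k.ScaNN` layer rather than the raw `scann` package, my understanding is that leaving `query_model` unset lets you call the layer on query embeddings directly, after indexing the candidate embeddings with `index(...)`. The queries then need to be 2-D (like `queries` above, e.g. `tf.constant([[1.0, 0.3, 0.4]])`), not a single 1-D vector; it is worth checking this against the TFRS docs for your version.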


Here is a sample snippet showing how to pass embeddings directly to ScaNN:

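import numpy as np
import pandas as pd
import scann
from sklearn import preprocessing

df = pd.read_csv("./data/mydata.csv")

# L2-normalize each row (assumes column 0 is an ID and the remaining columns
# are the embedding values); ScaNN expects float32 data
df_np = preprocessing.normalize(df.iloc[:, 1:], norm="l2").astype(np.float32)

num_neighbors = 100

# Create the searcher: about sqrt(n) partitions, searching roughly 5% of them (at least 1)
k = int(np.sqrt(df_np.shape[0]))
searcher = scann.scann_ops_pybind.builder(df_np, num_neighbors, "dot_product").tree(
    num_leaves=k,
    num_leaves_to_search=max(1, k // 20),
    training_sample_size=2500).score_brute_force(quantize=True).reorder(num_neighbors).build()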
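The snippet above normalizes the rows before building a "dot_product" searcher, which turns the score into cosine similarity. The NumPy sketch below (the question's vectors, no ScaNN involved; `l2_normalize` is a hypothetical helper mirroring sklearn's row-wise L2 normalization) shows that effect. Queries should be normalized the same way; with the ScaNN Python API a single query is then passed as a 1-D float32 array to `searcher.search(query)`, and a batch as a 2-D array to `searcher.search_batched(queries)`, each returning neighbor indices and scores.

```python
import numpy as np

def l2_normalize(x):
    """Scale each row to unit L2 norm, like sklearn.preprocessing.normalize(x, norm="l2")."""
    x = np.atleast_2d(np.asarray(x, dtype=np.float32))
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    # Avoid division by zero: all-zero rows stay zero
    return x / np.maximum(norms, np.finfo(np.float32).tiny)

raw_candidates = np.array([[0.2, 1.0, 0.4], [0.3, 0.1, 0.56]], dtype=np.float32)
raw_query = np.array([1.0, 0.3, 0.4], dtype=np.float32)

candidates = l2_normalize(raw_candidates)   # what the searcher would be built on
query = l2_normalize(raw_query)[0]          # a single query stays 1-D

# Dot product of unit vectors equals cosine similarity of the raw vectors
scores = candidates @ query
cosine = (raw_candidates @ raw_query) / (
    np.linalg.norm(raw_candidates, axis=1) * np.linalg.norm(raw_query))

print(np.argmax(raw_candidates @ raw_query))  # 0: raw dot product prefers candidate 0
print(np.argmax(scores))                      # 1: after normalization, candidate 1 wins
```

Normalizing the query alone never changes the ranking against a fixed set of candidates, since it scales every score by the same positive factor; it is normalizing the candidates that moves candidate 1 ahead here.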

Here is a blog post on using ScaNN: ScaNN optimization and configuration