Files
MasterarbeitCode/evaluation/helpers/approx_knn_classifier.py
2021-04-11 23:28:41 +02:00

46 lines
2.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# adapted from:
# Pellegrini., C., E. Özsoy., M. Wintergerst., and G. Groh. (2021).
# “Exploiting Food Embeddings for Ingredient Substitution.”
# In: Proceedings of the 14th International Joint Conference on Biomedical
# Engineering Systems and Technologies - Volume 5: HEALTHINF, INSTICC.
# SciTePress, pp. 67-77. isbn: 978-989-758-490-9. doi: 10.5220/0010202000670077.
from pathlib import Path
import numpy as np
from annoy import AnnoyIndex
from tqdm import tqdm
# Full guide https://github.com/spotify/annoy
class ApproxKNNClassifier:
    """Approximate k-nearest-neighbor lookup over ingredient embeddings.

    Wraps a Spotify Annoy index using angular (cosine-like) distance. The
    built index is persisted to ``save_path``; subsequent runs load it via
    mmap instead of rebuilding.
    """

    def __init__(self, all_ingredient_embeddings, max_embedding_count,
                 save_path=Path('data/eval/approx_knn_classifier.ann'), n_trees=10):
        """Build the Annoy index, or load it from ``save_path`` if present.

        Args:
            all_ingredient_embeddings: array-like of shape (n_items, dim);
                one embedding vector per ingredient.
            max_embedding_count: base number of neighbors per query; actual
                queries fetch ``max_embedding_count + 200`` (see
                ``k_nearest_neighbors``).
            save_path: location of the serialized ``.ann`` index file.
            n_trees: Annoy build parameter — more trees give better accuracy
                at the cost of a larger index.
        """
        vector_length = all_ingredient_embeddings.shape[-1]
        self.max_embedding_count = max_embedding_count
        if save_path.exists():
            print('Loading Existing Approx Classifier')
            self.approx_knn_classifier = AnnoyIndex(vector_length, 'angular')
            self.approx_knn_classifier.load(str(save_path))  # super fast, will just mmap the file
        else:
            self.approx_knn_classifier = AnnoyIndex(vector_length, 'angular')  # length of item vector that will be indexed
            for i in tqdm(range(len(all_ingredient_embeddings)), total=len(all_ingredient_embeddings),
                          desc='Creating Approx Classifier'):
                self.approx_knn_classifier.add_item(i, all_ingredient_embeddings[i])
            self.approx_knn_classifier.build(n_trees)
            print('Saving Approx Classifier')
            # Fix: Annoy's save() does not create directories, so make sure
            # the parent directory exists before writing the index file.
            save_path.parent.mkdir(parents=True, exist_ok=True)
            self.approx_knn_classifier.save(str(save_path))

    def k_nearest_neighbors(self, ingredient_embeddings):
        """Query the index with each embedding and collect neighbors.

        Args:
            ingredient_embeddings: iterable of query vectors whose dimension
                matches the indexed vectors.

        Returns:
            Tuple ``(distances, indices)`` — note distances come FIRST — each
            a stacked array of shape (n_queries, max_embedding_count + 200).
            The +200 margin exists so callers can discard self-matches (the
            query item itself is in the index) and still keep
            ``max_embedding_count`` usable neighbors.
        """
        all_indices, all_distances = [], []
        for ingredient_embedding in ingredient_embeddings:
            # search_k is left at its default: it trades accuracy for speed.
            indices, distances = self.approx_knn_classifier.get_nns_by_vector(
                ingredient_embedding, self.max_embedding_count + 200, include_distances=True)
            all_indices.append(indices)
            all_distances.append(distances)
        return np.stack(all_distances), np.stack(all_indices)