Files
MasterarbeitCode/evaluation/helpers/knn_classifier.py
2021-04-11 23:28:41 +02:00

37 lines
1.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# adapted from:
# Pellegrini, C., E. Özsoy, M. Wintergerst, and G. Groh. (2021).
# “Exploiting Food Embeddings for Ingredient Substitution.”
# In: Proceedings of the 14th International Joint Conference on Biomedical
# Engineering Systems and Technologies - Volume 5: HEALTHINF, INSTICC.
# SciTePress, pp. 67-77. ISBN: 978-989-758-490-9. DOI: 10.5220/0010202000670077.
from pathlib import Path
import joblib
from sklearn.neighbors import NearestNeighbors
class KNNClassifier:
    """Wrapper around a scikit-learn ``NearestNeighbors`` index cached on disk.

    If ``save_path`` already exists, the index stored there is loaded with
    joblib. Otherwise a new index is fitted on the given embeddings and saved
    to ``save_path``.
    """

    def __init__(self, all_ingredient_embeddings, max_embedding_count,
                 save_path=Path('data/eval/knn_classifier.joblib'), *,
                 neighbor_margin=200, n_jobs=12, algorithm='brute'):
        """Load the cached index from ``save_path`` or fit and cache a new one.

        Args:
            all_ingredient_embeddings: Data passed to ``NearestNeighbors.fit``.
                Not used when an existing index is loaded from ``save_path``.
            max_embedding_count: Added to ``neighbor_margin`` to form the
                ``n_neighbors`` of a newly fitted index.
            save_path: Location of the joblib cache, as ``str`` or ``Path``.
                Missing parent directories are created before saving.
            neighbor_margin: Extra neighbours on top of ``max_embedding_count``
                (default 200, the previously hard-coded value).
            n_jobs: ``n_jobs`` for ``NearestNeighbors`` (default 12).
            algorithm: ``NearestNeighbors`` algorithm, e.g. 'brute', 'kd_tree'
                or 'ball_tree' (default 'brute').
        """
        # Accept plain strings as well as Path objects.
        save_path = Path(save_path)
        if save_path.exists():
            print('Loading Existing Classifier')
            # NOTE(review): the cached index is reused as-is, even if it was
            # fitted on different embeddings or with a different
            # max_embedding_count; delete the file to force a refit.
            # joblib.load unpickles the file, so only load trusted caches.
            self.knn_classifier: NearestNeighbors = joblib.load(save_path)
        else:
            print('Training New Classifier')
            # To make sure we don't just get ourselves: add max_embedding_count.
            # Presumably one ingredient can own up to max_embedding_count
            # embeddings, which may all be each other's nearest neighbours;
            # the margin leaves room for neighbours from other ingredients.
            self.knn_classifier: NearestNeighbors = NearestNeighbors(
                n_neighbors=max_embedding_count + neighbor_margin,
                n_jobs=n_jobs,
                algorithm=algorithm,
            )
            self.knn_classifier.fit(all_ingredient_embeddings)
            print('Saving Classifier')
            # joblib.dump cannot create directories; without this a missing
            # data/eval/ folder would crash after the (possibly slow) fit.
            save_path.parent.mkdir(parents=True, exist_ok=True)
            joblib.dump(self.knn_classifier, save_path)
        print(f'\nKNN with: {self.knn_classifier._fit_method} and leaf size: {self.knn_classifier.leaf_size}\n')

    def k_nearest_neighbors(self, ingredient_embeddings):
        """Query the index for the nearest neighbours of each embedding.

        Args:
            ingredient_embeddings: Query points passed to
                ``NearestNeighbors.kneighbors``.

        Returns:
            ``(distances, indices)`` as returned by
            ``NearestNeighbors.kneighbors`` with the index's configured
            ``n_neighbors``. ``indices`` refer to rows of the data the index
            was fitted on.
        """
        distances, indices = self.knn_classifier.kneighbors(ingredient_embeddings, return_distance=True)
        return distances, indices