# Adapted from:
# Pellegrini, C., E. Özsoy, M. Wintergerst, and G. Groh. (2021).
# “Exploiting Food Embeddings for Ingredient Substitution.”
# In: Proceedings of the 14th International Joint Conference on Biomedical
# Engineering Systems and Technologies - Volume 5: HEALTHINF, INSTICC.
# SciTePress, pp. 67–77. ISBN: 978-989-758-490-9. doi: 10.5220/0010202000670077.

from pathlib import Path

import joblib
from sklearn.neighbors import NearestNeighbors


class KNNClassifier:
    """Brute-force nearest-neighbour index over ingredient embeddings.

    The fitted sklearn ``NearestNeighbors`` model is cached on disk with
    ``joblib``. If ``save_path`` already exists, the cached model is loaded
    and the embedding arguments passed to the constructor are not used.
    Otherwise a new model is fitted on ``all_ingredient_embeddings`` and
    written to ``save_path``.
    """

    # Extra neighbours requested on top of ``max_embedding_count``.
    _NEIGHBOR_MARGIN = 200

    def __init__(self, all_ingredient_embeddings, max_embedding_count,
                 save_path=Path('data/eval/knn_classifier.joblib'), n_jobs=12):
        """Load the cached classifier from ``save_path``, or fit and save a new one.

        Args:
            all_ingredient_embeddings: Training data passed to
                ``NearestNeighbors.fit`` (one row per stored embedding).
            max_embedding_count: Added to ``_NEIGHBOR_MARGIN`` to size the
                neighbour list. The result is capped at the number of
                training rows. NOTE(review): presumably the largest number
                of embeddings one ingredient can own, so a query's own
                embeddings cannot fill the whole result. Confirm with the
                caller.
            save_path: Location of the joblib cache file (``str`` or
                ``pathlib.Path``).
            n_jobs: Parallel jobs for a newly fitted ``NearestNeighbors``.
                Only used when fitting; a cached model keeps whatever
                setting it was saved with.
        """
        # Accept plain string paths as well as Path objects.
        save_path = Path(save_path)

        if save_path.exists():
            print('Loading Existing Classifier')
            # NOTE(review): joblib.load unpickles arbitrary objects. Only
            # point save_path at files this project wrote itself.
            self.knn_classifier: NearestNeighbors = joblib.load(save_path)
        else:
            print('Training New Classifier')
            # To make sure we don't just get ourselves: add max_embedding_count
            n_neighbors = self._neighbor_count(all_ingredient_embeddings, max_embedding_count)
            self.knn_classifier: NearestNeighbors = NearestNeighbors(
                n_neighbors=n_neighbors,
                n_jobs=n_jobs,
                algorithm='brute',  # kd_tree, ball_tree or brute
            )
            self.knn_classifier.fit(all_ingredient_embeddings)

            print('Saving Classifier')
            # joblib.dump does not create missing parent directories.
            save_path.parent.mkdir(parents=True, exist_ok=True)
            joblib.dump(self.knn_classifier, save_path)

        print(f'\nKNN with: {self.knn_classifier._fit_method} and leaf size: {self.knn_classifier.leaf_size}\n')

    @classmethod
    def _neighbor_count(cls, all_ingredient_embeddings, max_embedding_count):
        """Return ``max_embedding_count + _NEIGHBOR_MARGIN``, capped at the number of rows.

        ``NearestNeighbors.kneighbors`` raises ``ValueError`` when asked for
        more neighbours than there are fitted samples. Without the cap, a
        small training set would make every later query fail.
        """
        requested = max_embedding_count + cls._NEIGHBOR_MARGIN
        # Prefer .shape (numpy arrays, sparse matrices, DataFrames). Fall back
        # to len() for plain sequences such as a list of vectors.
        shape = getattr(all_ingredient_embeddings, 'shape', None)
        n_samples = shape[0] if shape is not None else len(all_ingredient_embeddings)
        return min(requested, n_samples)

    def k_nearest_neighbors(self, ingredient_embeddings):
        """Query the index for the stored embeddings closest to each query row.

        Args:
            ingredient_embeddings: Query data, one row per embedding.

        Returns:
            The ``(distances, indices)`` pair from
            ``NearestNeighbors.kneighbors(..., return_distance=True)``. Row
            ``i`` of each holds the neighbours of query row ``i``.
        """
        distances, indices = self.knn_classifier.kneighbors(ingredient_embeddings, return_distance=True)
        return distances, indices