initial commit of project
This commit is contained in:
36
evaluation/helpers/knn_classifier.py
Normal file
36
evaluation/helpers/knn_classifier.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# adapted from:
|
||||
# Pellegrini., C., E. Özsoy., M. Wintergerst., and G. Groh. (2021).
|
||||
# “Exploiting Food Embeddings for Ingredient Substitution.”
|
||||
# In: Proceedings of the 14th International Joint Conference on Biomedical
|
||||
# Engineering Systems and Technologies - Volume 5: HEALTHINF, INSTICC.
|
||||
# SciTePress, pp. 67–77. isbn: 978-989-758-490-9. doi: 10.5220/0010202000670077.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import joblib
|
||||
from sklearn.neighbors import NearestNeighbors
|
||||
|
||||
|
||||
class KNNClassifier:
|
||||
def __init__(self, all_ingredient_embeddings, max_embedding_count,
|
||||
save_path=Path('data/eval/knn_classifier.joblib')):
|
||||
|
||||
if save_path.exists():
|
||||
print('Loading Existing Classifier')
|
||||
self.knn_classifier: NearestNeighbors = joblib.load(save_path)
|
||||
else:
|
||||
print('Training New Classifier')
|
||||
# To make sure we don't just get ourselves: add max_embedding_count
|
||||
self.knn_classifier: NearestNeighbors = NearestNeighbors(n_neighbors=max_embedding_count + 200, n_jobs=12,
|
||||
algorithm='brute') # kd_tree, ball_tree or brute
|
||||
self.knn_classifier.fit(all_ingredient_embeddings)
|
||||
|
||||
print('Saving Classifier')
|
||||
joblib.dump(self.knn_classifier, save_path)
|
||||
|
||||
print(f'\nKNN with: {self.knn_classifier._fit_method} and leaf size: {self.knn_classifier.leaf_size}\n')
|
||||
|
||||
def k_nearest_neighbors(self, ingredient_embeddings):
|
||||
distances, indices = self.knn_classifier.kneighbors(ingredient_embeddings, return_distance=True)
|
||||
|
||||
return distances, indices
|
||||
Reference in New Issue
Block a user