# adapted from: Pellegrini, C., E. Özsoy, M. Wintergerst, and G. Groh. (2021).
# "Exploiting Food Embeddings for Ingredient Substitution."
# In: Proceedings of the 14th International Joint Conference on Biomedical
# Engineering Systems and Technologies - Volume 5: HEALTHINF, INSTICC.
# SciTePress, pp. 67-77. isbn: 978-989-758-490-9. doi: 10.5220/0010202000670077.
import json
import os
from collections import defaultdict
from pathlib import Path
from typing import Union

import numpy as np
import torch
from sklearn.decomposition import PCA
from tqdm import tqdm

from evaluation.helpers.approx_knn_classifier import ApproxKNNClassifier
from evaluation.helpers.generate_ingredient_embeddings import generate_food_embedding_dict
from evaluation.helpers.knn_classifier import KNNClassifier


def avg(values):
    """Return the arithmetic mean of a non-empty sequence of numbers.

    Raises ZeroDivisionError on an empty sequence; callers only pass
    non-empty distance lists collected from the k-NN queries.
    """
    return sum(values) / len(values)


def custom_potential_neighbors_sort(potential_neighbors):
    """Rank substitute candidates, best candidate first.

    Primary key: how often the candidate appeared among the nearest
    neighbours (descending). Tie-break: the smaller average distance.

    :param potential_neighbors: mapping candidate name -> list of distances
    :return: list of (candidate, distances) tuples, best candidate first
    """
    # reverse=True makes the count descend; negating the mean distance keeps
    # the tie-break preferring the *closer* candidate despite the reversal.
    return sorted(potential_neighbors.items(),
                  key=lambda item: (len(item[1]), -avg(item[1])),
                  reverse=True)


def filter_out_forbidden_neigbours(ingredient_name, potential_neighbors):
    """Drop candidates that are identical to the query ingredient.

    Additional rules (e.g. mozzarella & mozzarella_cheese, penne &
    penne_pasta) can be added here by extending ``banned_keys``.
    """
    banned_keys = {ingredient_name}
    # Example extension: also ban compound names containing the ingredient.
    # for ingredient in potential_neighbors:
    #     if ingredient_name in ingredient.split('_'):
    #         banned_keys.add(ingredient)
    return {key: value for key, value in potential_neighbors.items()
            if key not in banned_keys}


def get_nearest_N_neigbours(ingredient_name, ingredients_to_embeddings, all_ingredient_labels,
                            knn_classifier: Union[KNNClassifier, ApproxKNNClassifier], thresh=50):
    """Find substitute candidates for ``ingredient_name`` via k-NN voting.

    Every embedding of the ingredient votes for its nearest neighbours; a
    candidate survives only if it collected at least ``thresh`` votes.

    :param ingredient_name: key into ``ingredients_to_embeddings``
    :param ingredients_to_embeddings: mapping name -> array of embeddings
    :param all_ingredient_labels: label array parallel to the k-NN index
    :param knn_classifier: fitted exact or approximate k-NN index
    :param thresh: minimum number of votes for a candidate to be kept
    :return: tuple of candidate names (best first), or None if no candidate
             survives the threshold
    """
    ingredient_embeddings = ingredients_to_embeddings[ingredient_name]
    all_distances, all_indices = knn_classifier.k_nearest_neighbors(ingredient_embeddings)

    # Collect, per candidate label, the distances at which it was retrieved.
    potential_neighbors = defaultdict(list)
    for i in range(len(ingredient_embeddings)):
        labels = all_ingredient_labels[all_indices[i]]
        for label, distance in zip(labels, all_distances[i]):
            potential_neighbors[label].append(distance)

    potential_neighbors = filter_out_forbidden_neigbours(ingredient_name, potential_neighbors)
    sorted_neighbors = custom_potential_neighbors_sort(potential_neighbors)

    # Remove candidates that were retrieved too rarely.
    frequent_neighbors = [(label, distances) for label, distances in sorted_neighbors
                          if len(distances) >= thresh]
    # Explicit emptiness check instead of the former bare `except Exception`,
    # which silently swallowed *any* error, not just the empty-list case.
    if not frequent_neighbors:
        return None

    # Further removal relative to the most frequent candidate; the 0.0
    # cut-off currently keeps everything but makes the knob explicit.
    top_count = len(frequent_neighbors[0][1])
    final_neighbors = [(label, distances) for label, distances in frequent_neighbors
                       if len(distances) / top_count >= 0.0]
    return tuple(label for label, _ in final_neighbors)


def clean_ingredient_name(ingredient_name, normalization_fixes):
    """Split an underscored name into words, applying per-word spelling fixes."""
    words = ingredient_name.split('_')
    return ' '.join(normalization_fixes.get(word, word) for word in words)


def clean_substitutes(subtitutes, normalization_fixes):
    """Apply ``clean_ingredient_name`` to every substitute in the list."""
    return [clean_ingredient_name(subtitute, normalization_fixes)
            for subtitute in subtitutes]


def main():
    """Compute k-NN substitute candidates for every ingredient of each model
    and dump them as JSON next to the model's eval artifacts."""
    models = ["final_Versions/models/vers2/"]
    thresh = 100
    max_embedding_count = 100
    approx_knn = True

    for curr_model in models:
        substitute_pairs_path = curr_model + "eval/substitute_pairs_" + str(thresh) + ".json"

        # Embeddings for all ingredients of the current model.
        ingredients_to_embeddings = generate_food_embedding_dict(
            max_sentence_count=max_embedding_count,
            model_path=curr_model + "output/",
            eval_path=curr_model + "eval/",
            dataset_path=curr_model + "dataset/")

        # Flatten into one embedding matrix with a parallel label array
        # (one label per embedding row).
        all_ingredient_embeddings = []
        all_ingredient_labels = []
        for key, value in ingredients_to_embeddings.items():
            all_ingredient_embeddings.append(value)
            all_ingredient_labels.extend([key] * len(value))
        all_ingredient_embeddings = np.concatenate(all_ingredient_embeddings)
        all_ingredient_labels = np.stack(all_ingredient_labels)

        # Build an exact or approximate k-NN index over all embeddings.
        if approx_knn:
            knn_classifier: Union[KNNClassifier, ApproxKNNClassifier] = ApproxKNNClassifier(
                all_ingredient_embeddings=all_ingredient_embeddings,
                max_embedding_count=max_embedding_count,
                save_path=Path(curr_model + "eval/" + 'approx_knn_classifier.ann'))
        else:
            knn_classifier = KNNClassifier(
                all_ingredient_embeddings=all_ingredient_embeddings,
                max_embedding_count=max_embedding_count,
                save_path=Path(curr_model + "eval/" + 'knn_classifier.joblib'))

        # Query substitutes for every ingredient; None means no candidate
        # survived the vote threshold.
        none_counter = 0
        subst_dict = {}
        for ingredient_name in tqdm(ingredients_to_embeddings.keys(),
                                    total=len(ingredients_to_embeddings)):
            substitutes = get_nearest_N_neigbours(
                ingredient_name=ingredient_name,
                ingredients_to_embeddings=ingredients_to_embeddings,
                all_ingredient_labels=all_ingredient_labels,
                knn_classifier=knn_classifier,
                thresh=thresh)
            if substitutes is None:
                none_counter += 1
                subst_dict[ingredient_name] = []
            else:
                subst_dict[ingredient_name] = list(substitutes)

        with open(substitute_pairs_path, 'w') as f:
            json.dump(subst_dict, f, ensure_ascii=False, indent=4)
        print(f'Nones: {none_counter}')


if __name__ == '__main__':
    main()