# Adapted from:
# Pellegrini, C., E. Özsoy, M. Wintergerst, and G. Groh (2021).
# "Exploiting Food Embeddings for Ingredient Substitution."
# In: Proceedings of the 14th International Joint Conference on Biomedical
# Engineering Systems and Technologies - Volume 5: HEALTHINF, INSTICC.
# SciTePress, pp. 67–77. ISBN: 978-989-758-490-9. DOI: 10.5220/0010202000670077.

import json
|
||
import os
|
||
from collections import defaultdict
|
||
from pathlib import Path
|
||
from typing import Union
|
||
|
||
import numpy as np
|
||
import torch
|
||
from sklearn.decomposition import PCA
|
||
from tqdm import tqdm
|
||
|
||
from evaluation.helpers.approx_knn_classifier import ApproxKNNClassifier
|
||
from evaluation.helpers.generate_ingredient_embeddings import generate_food_embedding_dict
|
||
from evaluation.helpers.knn_classifier import KNNClassifier
|
||
|
||
|
||
def avg(values):
    """Return the arithmetic mean of ``values``.

    ``values`` must support both iteration (for ``sum``) and ``len()``.
    An empty sequence raises ``ZeroDivisionError``.
    """
    return sum(values) / len(values)
|
||
|
||
|
||
def custom_potential_neighbors_sort(potential_neighbors):
    """Rank neighbour labels, best candidate first.

    ``potential_neighbors`` maps a label to the list of distances at which it
    was found. Labels found more often come first. Among labels with the same
    count, the one with the smaller mean distance comes first. Every distance
    list must be non-empty, because the mean divides by its length.

    Returns a list of ``(label, distances)`` pairs.
    """
    def _ranking_key(item):
        _label, hits = item
        mean_distance = sum(hits) / len(hits)
        return len(hits), -mean_distance

    # Sorting descending on (count, -mean) puts more hits first and, for equal
    # counts, the smaller mean distance first. Python's sort stays stable for
    # exact ties even with reverse=True.
    return sorted(potential_neighbors.items(), key=_ranking_key, reverse=True)
|
||
|
||
|
||
def filter_out_forbidden_neigbours(ingredient_name, potential_neighbors):
    '''
    Drop neighbours that must never be proposed as substitutes.

    Currently the only rule is that an ingredient cannot substitute itself,
    so the entry keyed by ``ingredient_name`` is removed. Further rules, such
    as treating ``mozzarella`` / ``mozzarella_cheese`` or ``penne`` /
    ``penne_pasta`` as the same ingredient, can be added here.

    Returns a new plain ``dict`` in the input's iteration order. The input
    mapping is not modified.
    '''
    forbidden = {ingredient_name}
    # A looser rule is currently disabled: it would also ban any neighbour
    # whose '_'-separated tokens contain ingredient_name, e.g.
    #     if ingredient_name in candidate.split('_'): forbidden.add(candidate)

    kept = {}
    for candidate, distances in potential_neighbors.items():
        if candidate in forbidden:
            continue
        kept[candidate] = distances
    return kept
|
||
|
||
|
||
def get_nearest_N_neigbours(ingredient_name, ingredients_to_embeddings, all_ingredient_labels,
                            knn_classifier: Union[KNNClassifier, ApproxKNNClassifier], thresh=50):
    """Return candidate substitutes for ``ingredient_name``, ranked by KNN votes.

    Each stored embedding of ``ingredient_name`` is used as one KNN query.
    Every neighbour returned for a query casts one vote, and records its
    distance, for the label it belongs to. The processing steps are:

    1. ``filter_out_forbidden_neigbours`` removes the ingredient itself.
    2. ``custom_potential_neighbors_sort`` ranks the remaining labels (most
       votes first, ties broken by the smaller mean distance).
    3. Only labels with at least ``thresh`` votes are kept.

    Args:
        ingredient_name: Key into ``ingredients_to_embeddings``.
        ingredients_to_embeddings: Mapping from ingredient name to its
            collection of embeddings (one KNN query per entry).
        all_ingredient_labels: Label lookup that is indexed with each row of
            neighbour indices returned by the classifier (a NumPy array in
            ``main``), mapping neighbour positions back to ingredient names.
        knn_classifier: Object whose ``k_nearest_neighbors(embeddings)``
            returns ``(distances, indices)``. Presumably there is one result
            row per query embedding; TODO confirm for ApproxKNNClassifier.
        thresh: Minimum number of votes a label needs in order to be kept.

    Returns:
        A tuple of labels ordered best-first, or ``None`` when no label
        reaches ``thresh`` votes.

    Raises:
        KeyError: if ``ingredient_name`` is not in ``ingredients_to_embeddings``.
    """
    ingredient_embeddings = ingredients_to_embeddings[ingredient_name]
    all_distances, all_indices = knn_classifier.k_nearest_neighbors(ingredient_embeddings)

    # Collect, per neighbouring label, the distance of every hit across all
    # queries. One result row is read per query embedding.
    potential_neighbors = defaultdict(list)
    for query_idx in range(len(ingredient_embeddings)):
        labels = all_ingredient_labels[all_indices[query_idx]]
        distances = all_distances[query_idx]
        for label, distance in zip(labels, distances):
            potential_neighbors[label].append(distance)

    potential_neighbors = filter_out_forbidden_neigbours(ingredient_name, potential_neighbors)
    ranked = custom_potential_neighbors_sort(potential_neighbors)

    # Drop labels that were hit too rarely to be reliable substitutes.
    # `ranked` is already ordered, so the filtered labels stay best-first.
    frequent_labels = tuple(label for label, hits in ranked if len(hits) >= thresh)

    # No label survived the threshold: report "no substitutes" explicitly.
    if not frequent_labels:
        return None
    return frequent_labels
|
||
|
||
|
||
def clean_ingredient_name(ingredient_name, normalization_fixes):
    """Turn an underscore-joined ingredient key into a space-separated name.

    Each ``_``-separated token found in ``normalization_fixes`` is replaced
    by its mapped value. All other tokens are kept unchanged.
    """
    tokens = ingredient_name.split('_')
    fixed_tokens = [
        normalization_fixes[token] if token in normalization_fixes else token
        for token in tokens
    ]
    return ' '.join(fixed_tokens)
|
||
|
||
|
||
def clean_substitutes(subtitutes, normalization_fixes):
    """Normalise every substitute name with ``clean_ingredient_name``.

    Returns a new list in the same order as ``subtitutes``.
    """
    return [clean_ingredient_name(name, normalization_fixes) for name in subtitutes]
|
||
|
||
|
||
# def test_eval():
|
||
# return ["Zucker", "Eier", "Reis", "Spaghetti", "Wein", "Gouda_junger"]
|
||
|
||
|
||
def main(*, models=("final_Versions/models/vers2/",), thresh=100,
         max_embedding_count=100, approx_knn=True):
    """Compute KNN-based substitute lists for every ingredient of each model.

    For each model directory prefix in ``models``, this function:
      1. builds ingredient embeddings with ``generate_food_embedding_dict``
         from ``<model>output/`` and ``<model>dataset/``;
      2. builds an approximate or exact KNN classifier over all embeddings,
         with its ``save_path`` under ``<model>eval/``;
      3. looks up substitutes for every ingredient with
         ``get_nearest_N_neigbours``;
      4. writes ``{ingredient: [substitute, ...]}`` as UTF-8 JSON to
         ``<model>eval/substitute_pairs_<thresh>.json`` and prints how many
         ingredients received no substitutes.

    Keyword Args:
        models: Model directory prefixes, each ending in "/". Other evaluated
            versions include "final_Versions/models/vers1/" and
            "final_Versions/models/vers3/".
        thresh: Minimum KNN vote count for a substitute.
        max_embedding_count: Forwarded as ``max_sentence_count`` to the
            embedding generator and as ``max_embedding_count`` to the
            classifier.
        approx_knn: Use ``ApproxKNNClassifier`` when True, ``KNNClassifier``
            otherwise.
    """
    for curr_model in models:
        eval_dir = curr_model + "eval/"
        # The substitute file (and classifier save path) live here; create the
        # directory up front instead of failing at the final write.
        os.makedirs(eval_dir, exist_ok=True)
        substitute_pairs_path = eval_dir + "substitute_pairs_" + str(thresh) + ".json"

        # Get embeddings for all ingredients.
        ingredients_to_embeddings = generate_food_embedding_dict(
            max_sentence_count=max_embedding_count,
            model_path=curr_model + "output/",
            eval_path=eval_dir,
            dataset_path=curr_model + "dataset/",
        )

        # Flatten into one embedding matrix plus a parallel label array, so a
        # KNN hit's row index can be mapped back to its ingredient name.
        all_ingredient_embeddings = []
        all_ingredient_labels = []
        for key, value in ingredients_to_embeddings.items():
            all_ingredient_embeddings.append(value)
            all_ingredient_labels.extend([key] * len(value))
        all_ingredient_embeddings = np.concatenate(all_ingredient_embeddings)
        all_ingredient_labels = np.stack(all_ingredient_labels)

        knn_classifier: Union[KNNClassifier, ApproxKNNClassifier]
        if approx_knn:
            knn_classifier = ApproxKNNClassifier(
                all_ingredient_embeddings=all_ingredient_embeddings,
                max_embedding_count=max_embedding_count,
                save_path=Path(eval_dir + 'approx_knn_classifier.ann'))
        else:
            knn_classifier = KNNClassifier(
                all_ingredient_embeddings=all_ingredient_embeddings,
                max_embedding_count=max_embedding_count,
                save_path=Path(eval_dir + 'knn_classifier.joblib'))

        # Get substitutes via the KNN classifier.
        none_counter = 0
        subst_dict = {}
        for ingredient_name in tqdm(ingredients_to_embeddings.keys(), total=len(ingredients_to_embeddings)):
            substitutes = get_nearest_N_neigbours(ingredient_name=ingredient_name,
                                                  ingredients_to_embeddings=ingredients_to_embeddings,
                                                  all_ingredient_labels=all_ingredient_labels,
                                                  knn_classifier=knn_classifier, thresh=thresh)
            if substitutes is None:
                none_counter += 1
                subst_dict[ingredient_name] = []
            else:
                subst_dict[ingredient_name] = list(substitutes)

        # ensure_ascii=False writes raw non-ASCII characters (e.g. umlauts), so
        # pin the file encoding instead of relying on the platform default.
        with open(substitute_pairs_path, 'w', encoding='utf-8') as f:
            json.dump(subst_dict, f, ensure_ascii=False, indent=4)
        print(f'Nones: {none_counter}')
|
||
|
||
|
||
# output = {}
|
||
# for ing in ingredients:
|
||
# output[ing] = []
|
||
# for model in all_subs.keys():
|
||
# for ing in ingredients:
|
||
# output[ing].append(all_subs[model][ing])
|
||
#
|
||
# with open(test_substitute_pairs_path, 'w') as f:
|
||
# json.dump(output, f, ensure_ascii=False, indent=4)
|
||
|
||
|
||
if __name__ == '__main__':
    # Run the substitute evaluation only when this file is executed as a
    # script, not when it is imported.
    main()
|