initial commit of project
This commit is contained in:
197
evaluation/generate_substitutes.py
Normal file
197
evaluation/generate_substitutes.py
Normal file
@@ -0,0 +1,197 @@
|
||||
# adapted from:
|
||||
# Pellegrini., C., E. Özsoy., M. Wintergerst., and G. Groh. (2021).
|
||||
# “Exploiting Food Embeddings for Ingredient Substitution.”
|
||||
# In: Proceedings of the 14th International Joint Conference on Biomedical
|
||||
# Engineering Systems and Technologies - Volume 5: HEALTHINF, INSTICC.
|
||||
# SciTePress, pp. 67–77. isbn: 978-989-758-490-9. doi: 10.5220/0010202000670077.
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from sklearn.decomposition import PCA
|
||||
from tqdm import tqdm
|
||||
|
||||
from evaluation.helpers.approx_knn_classifier import ApproxKNNClassifier
|
||||
from evaluation.helpers.generate_ingredient_embeddings import generate_food_embedding_dict
|
||||
from evaluation.helpers.knn_classifier import KNNClassifier
|
||||
|
||||
|
||||
def avg(values):
    """Return the arithmetic mean of a non-empty sequence of numbers."""
    return sum(values) / len(values)
|
||||
|
||||
|
||||
def custom_potential_neighbors_sort(potential_neighbors):
    """Rank candidate neighbors: most frequently seen first; among equally
    frequent candidates, the one with the smaller average distance wins.

    `potential_neighbors` maps a candidate name to the list of distances at
    which it was observed; returns the items as a sorted list of pairs.
    """
    def rank(item):
        distances = item[1]
        # Negate the mean so that reverse=True puts smaller averages first.
        return (len(distances), -(sum(distances) / len(distances)))

    return sorted(potential_neighbors.items(), key=rank, reverse=True)
|
||||
|
||||
|
||||
def filter_out_forbidden_neigbours(ingredient_name, potential_neighbors):
    """Remove candidate neighbors that must never be suggested as substitutes.

    Currently only the ingredient itself is banned. Additional rules (e.g.
    mozeralla vs mozeralla_cheese, penne vs penne_pasta) can be added by
    extending `banned_keys` below.

    Note: the original version also looped over `potential_neighbors` to
    re-add keys equal to `ingredient_name`, which was redundant — the seed
    set already contains that key — so the loop has been removed.
    """
    banned_keys = {ingredient_name}
    # Extension point — ban compound names containing the ingredient, e.g.:
    #   for ingredient in potential_neighbors:
    #       if ingredient_name in ingredient.split('_'):
    #           banned_keys.add(ingredient)

    return {key: value for key, value in potential_neighbors.items()
            if key not in banned_keys}
|
||||
|
||||
|
||||
def get_nearest_N_neigbours(ingredient_name, ingredients_to_embeddings, all_ingredient_labels,
                            knn_classifier: "Union[KNNClassifier, ApproxKNNClassifier]", thresh=50):
    """Find substitute candidates for one ingredient via k-NN over embeddings.

    For each embedding of `ingredient_name`, collects the labels and distances
    of its nearest neighbors, ranks the candidate labels by how often they
    appeared (ties broken by smaller average distance), drops candidates seen
    fewer than `thresh` times, and returns the surviving names.

    Returns a tuple of candidate names in ranked order, or None when no
    candidate passes the `thresh` filter.

    Note: the original version reached the same None result by letting an
    IndexError escape into a broad `except Exception`; that silently swallowed
    *any* error, so it has been replaced with explicit empty checks.
    """
    ingredient_embeddings = ingredients_to_embeddings[ingredient_name]
    all_distances, all_indices = knn_classifier.k_nearest_neighbors(ingredient_embeddings)

    # Per candidate label, the distances at which it appeared near any of
    # this ingredient's embeddings.
    potential_neighbors = defaultdict(list)
    for i in range(len(ingredient_embeddings)):
        labels = all_ingredient_labels[all_indices[i]]
        distances = all_distances[i]
        for label, distance in zip(labels, distances):
            potential_neighbors[label].append(distance)

    potential_neighbors = filter_out_forbidden_neigbours(ingredient_name, potential_neighbors)
    sorted_neighbors = custom_potential_neighbors_sort(potential_neighbors)

    # Keep only candidates that appeared at least `thresh` times.
    frequent_neighbors = [(key, value) for key, value in sorted_neighbors if len(value) >= thresh]
    if not frequent_neighbors:
        return None

    # Relative-frequency filter against the top candidate. The 0.0 cutoff
    # currently keeps everything; raise it to prune comparatively rare ones.
    top_count = len(frequent_neighbors[0][1])
    final_neighbors = [elem for elem in frequent_neighbors
                       if len(elem[1]) / top_count >= 0.0]

    return tuple(key for key, _ in final_neighbors)
|
||||
|
||||
|
||||
def clean_ingredient_name(ingredient_name, normalization_fixes):
    """Convert an underscore-joined ingredient name to space-separated words,
    substituting each word with its correction from `normalization_fixes`
    when one exists."""
    return ' '.join(normalization_fixes.get(word, word)
                    for word in ingredient_name.split('_'))
|
||||
|
||||
|
||||
def clean_substitutes(subtitutes, normalization_fixes):
    """Normalize every substitute name via clean_ingredient_name, in order."""
    return [clean_ingredient_name(name, normalization_fixes) for name in subtitutes]
|
||||
|
||||
|
||||
# def test_eval():
|
||||
# return ["Zucker", "Eier", "Reis", "Spaghetti", "Wein", "Gouda_junger"]
|
||||
|
||||
|
||||
def main():
    """Generate ingredient-substitute lists for each configured model.

    For every model directory: embeds all ingredients, builds a (possibly
    approximate) k-NN index over the embeddings, queries the nearest
    neighbors of every ingredient, and writes the resulting
    {ingredient: [substitutes]} mapping to
    <model>/eval/substitute_pairs_<thresh>.json.

    Changes vs. original: removed the unused local `substitute_pairs` set
    (only referenced by commented-out code) and deleted dead commented-out
    blocks; behavior is otherwise unchanged.
    """
    # Model directories to evaluate; add more paths to compare models,
    # e.g. ["final_Versions/models/vers1/", "final_Versions/models/vers3/"].
    models = ["final_Versions/models/vers2/"]
    thresh = 100  # minimum appearances for a neighbor to count as substitute

    max_embedding_count = 100  # max embeddings (sentences) per ingredient
    approx_knn = True  # approximate k-NN index instead of exact sklearn one

    for curr_model in models:
        substitute_pairs_path = curr_model + "eval/substitute_pairs_" + str(thresh) + ".json"

        # Embeddings for all ingredients of this model.
        ingredients_to_embeddings = generate_food_embedding_dict(
            max_sentence_count=max_embedding_count, model_path=curr_model + "output/",
            eval_path=curr_model + "eval/", dataset_path=curr_model + "dataset/")

        # Flatten into one embedding matrix plus a parallel label array
        # (one label entry per embedding row).
        all_ingredient_embeddings = []
        all_ingredient_labels = []
        for key, value in ingredients_to_embeddings.items():
            all_ingredient_embeddings.append(value)
            all_ingredient_labels.extend([key] * len(value))

        all_ingredient_embeddings = np.concatenate(all_ingredient_embeddings)
        all_ingredient_labels = np.stack(all_ingredient_labels)

        # Build the k-NN index (approximate or exact), persisted next to the
        # evaluation outputs.
        if approx_knn:
            knn_classifier: Union[KNNClassifier, ApproxKNNClassifier] = ApproxKNNClassifier(
                all_ingredient_embeddings=all_ingredient_embeddings,
                max_embedding_count=max_embedding_count,
                save_path=Path(curr_model + "eval/" + 'approx_knn_classifier.ann'))
        else:
            knn_classifier = KNNClassifier(
                all_ingredient_embeddings=all_ingredient_embeddings,
                max_embedding_count=max_embedding_count,
                save_path=Path(curr_model + "eval/" + 'knn_classifier.joblib'))

        # Query substitutes for every ingredient.
        none_counter = 0  # ingredients for which no neighbor passed `thresh`
        subst_dict = {}
        for ingredient_name in tqdm(ingredients_to_embeddings.keys(), total=len(ingredients_to_embeddings)):
            substitutes = get_nearest_N_neigbours(ingredient_name=ingredient_name,
                                                  ingredients_to_embeddings=ingredients_to_embeddings,
                                                  all_ingredient_labels=all_ingredient_labels,
                                                  knn_classifier=knn_classifier, thresh=thresh)

            if substitutes is None:
                none_counter += 1
                subst_dict[ingredient_name] = []
            else:
                subst_dict[ingredient_name] = list(substitutes)

        with open(substitute_pairs_path, 'w') as f:
            json.dump(subst_dict, f, ensure_ascii=False, indent=4)
        print(f'Nones: {none_counter}')
|
||||
|
||||
|
||||
# Script entry point: run the substitute-generation pipeline directly.
if __name__ == '__main__':
    main()
|
||||
Reference in New Issue
Block a user