Files
MasterarbeitCode/evaluation/generate_substitutes.py
2021-04-11 23:28:41 +02:00

198 lines
7.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# adapted from:
# Pellegrini, C., E. Özsoy, M. Wintergerst, and G. Groh (2021).
# "Exploiting Food Embeddings for Ingredient Substitution."
# In: Proceedings of the 14th International Joint Conference on Biomedical
# Engineering Systems and Technologies - Volume 5: HEALTHINF, INSTICC.
# SciTePress, pp. 67-77. isbn: 978-989-758-490-9. doi: 10.5220/0010202000670077.
import json
import os
from collections import defaultdict
from pathlib import Path
from typing import Union
import numpy as np
import torch
from sklearn.decomposition import PCA
from tqdm import tqdm
from evaluation.helpers.approx_knn_classifier import ApproxKNNClassifier
from evaluation.helpers.generate_ingredient_embeddings import generate_food_embedding_dict
from evaluation.helpers.knn_classifier import KNNClassifier
def avg(values):
    """Return the arithmetic mean of *values* (raises ZeroDivisionError on empty input)."""
    return sum(values) / len(values)
def custom_potential_neighbors_sort(potential_neighbors):
    """Rank neighbor candidates: most frequently seen first, ties broken by
    smaller mean distance.

    :param potential_neighbors: dict mapping candidate label -> list of distances
    :return: list of (label, distances) tuples, best candidate first
    """
    def rank(item):
        _, distances = item
        # Negated mean so that reverse=True prefers the *smaller* average distance.
        return (len(distances), -(sum(distances) / len(distances)))

    return sorted(potential_neighbors.items(), key=rank, reverse=True)
def filter_out_forbidden_neigbours(ingredient_name, potential_neighbors):
    """Remove forbidden neighbor candidates for *ingredient_name*.

    Currently only the ingredient itself is banned; additional rules (e.g.
    mozzarella & mozzarella_cheese, penne & penne_pasta) can be added here.

    :param ingredient_name: name of the query ingredient
    :param potential_neighbors: dict mapping candidate label -> list of distances
    :return: new dict without the banned candidates
    """
    # The previous loop re-adding `ingredient == ingredient_name` was redundant:
    # the set already contains ingredient_name from construction.
    banned_keys = {ingredient_name}
    # Possible extension — ban compound names that contain ingredient_name:
    # if ingredient_name in ingredient.split('_'): banned_keys.add(ingredient)
    return {key: value for key, value in potential_neighbors.items()
            if key not in banned_keys}
def get_nearest_N_neigbours(ingredient_name, ingredients_to_embeddings, all_ingredient_labels,
                            knn_classifier: Union[KNNClassifier, ApproxKNNClassifier], thresh=50):
    """Collect substitute candidates for *ingredient_name* via k-NN over its embeddings.

    For every embedding of the ingredient, query the k nearest neighbors; each
    candidate label accumulates one distance entry per occurrence. Candidates
    are then filtered (the ingredient itself is banned), sorted by occurrence
    count (ties broken by smaller mean distance) and thresholded.

    :param ingredient_name: ingredient to find substitutes for
    :param ingredients_to_embeddings: dict mapping ingredient name -> array of embeddings
    :param all_ingredient_labels: array of labels aligned with the classifier's index
    :param knn_classifier: fitted (approximate) k-NN classifier
    :param thresh: minimum occurrence count a candidate needs to be kept
    :return: tuple of candidate names (best first), or None if no candidate survives
    """
    ingredient_embeddings = ingredients_to_embeddings[ingredient_name]
    all_distances, all_indices = knn_classifier.k_nearest_neighbors(ingredient_embeddings)
    # Gather, per candidate label, the distances of all its neighbor occurrences.
    potential_neighbors = defaultdict(list)
    for i in range(len(ingredient_embeddings)):
        labels = all_ingredient_labels[all_indices[i]]
        distances = all_distances[i]
        for label, distance in zip(labels, distances):
            potential_neighbors[label].append(distance)
    potential_neighbors = filter_out_forbidden_neigbours(ingredient_name, potential_neighbors)
    sorted_neighbors = custom_potential_neighbors_sort(potential_neighbors)
    # Drop candidates that occurred fewer than `thresh` times.
    frequent_neighbors = [(key, value) for key, value in sorted_neighbors if len(value) >= thresh]
    # Explicit emptiness check instead of the former broad `except Exception`
    # around the final indexing — same observable behavior, no swallowed errors.
    if not frequent_neighbors:
        return None
    # Placeholder for a relative-frequency cutoff; the 0.0 bound keeps everything.
    top_count = len(frequent_neighbors[0][1])
    final_neighbors = [elem for elem in frequent_neighbors
                       if len(elem[1]) / top_count >= 0.0]
    # Return only the candidate names, best first (matches the old zip(*...)[0]).
    return tuple(key for key, _ in final_neighbors)
def clean_ingredient_name(ingredient_name, normalization_fixes):
    """Turn an underscore-joined ingredient name into a space-separated one,
    replacing each token with its correction from *normalization_fixes* if present.

    :param ingredient_name: e.g. ``"penne_pasta"``
    :param normalization_fixes: dict mapping raw token -> corrected token
    :return: cleaned, space-separated ingredient name
    """
    tokens = ingredient_name.split('_')
    return ' '.join(normalization_fixes.get(token, token) for token in tokens)
def clean_substitutes(subtitutes, normalization_fixes):
    """Apply :func:`clean_ingredient_name` to every entry of *subtitutes*.

    :param subtitutes: iterable of raw substitute names
    :param normalization_fixes: dict mapping raw token -> corrected token
    :return: list of cleaned substitute names, order preserved
    """
    return [clean_ingredient_name(name, normalization_fixes) for name in subtitutes]
# def test_eval():
# return ["Zucker", "Eier", "Reis", "Spaghetti", "Wein", "Gouda_junger"]
def main():
    """Generate substitute candidates for every ingredient of each model and dump them to JSON.

    For each model directory: build embeddings for all ingredients, fit an
    (approximate) k-NN index over them, query the nearest neighbors of every
    ingredient and write the resulting substitute lists to
    ``<model>/eval/substitute_pairs_<thresh>.json``.
    """
    # Other model sets that were evaluated previously:
    # models = ["final_Versions/models/vers1/", "final_Versions/models/vers2/", "final_Versions/models/vers3/"]
    models = ["final_Versions/models/vers2/"]
    thresh = 100  # minimum neighbor-occurrence count for a candidate to be kept
    max_embedding_count = 100  # maximum embeddings (sentences) sampled per ingredient
    approx_knn = True  # use the approximate k-NN implementation instead of the exact one
    for curr_model in models:
        substitute_pairs_path = curr_model + "eval/substitute_pairs_" + str(thresh) + ".json"
        # Embeddings for all ingredients of this model.
        ingredients_to_embeddings = generate_food_embedding_dict(
            max_sentence_count=max_embedding_count, model_path=curr_model + "output/",
            eval_path=curr_model + "eval/", dataset_path=curr_model + "dataset/")
        # Flatten into one embedding matrix plus a parallel label array
        # (one label entry per embedding row).
        all_ingredient_embeddings = []
        all_ingredient_labels = []
        for key, value in ingredients_to_embeddings.items():
            all_ingredient_embeddings.append(value)
            all_ingredient_labels.extend([key] * len(value))
        all_ingredient_embeddings = np.concatenate(all_ingredient_embeddings)
        all_ingredient_labels = np.stack(all_ingredient_labels)
        # Fit the k-NN classifier over all embeddings.
        if approx_knn:
            knn_classifier: Union[KNNClassifier, ApproxKNNClassifier] = ApproxKNNClassifier(
                all_ingredient_embeddings=all_ingredient_embeddings,
                max_embedding_count=max_embedding_count,
                save_path=Path(curr_model + "eval/" + 'approx_knn_classifier.ann'))
        else:
            knn_classifier = KNNClassifier(
                all_ingredient_embeddings=all_ingredient_embeddings,
                max_embedding_count=max_embedding_count,
                save_path=Path(curr_model + "eval/" + 'knn_classifier.joblib'))
        # Query substitutes for every ingredient.
        # (The former unused `substitute_pairs = set()` has been removed.)
        none_counter = 0  # ingredients for which no candidate survived the threshold
        subst_dict = {}
        for ingredient_name in tqdm(ingredients_to_embeddings.keys(), total=len(ingredients_to_embeddings)):
            substitutes = get_nearest_N_neigbours(ingredient_name=ingredient_name,
                                                  ingredients_to_embeddings=ingredients_to_embeddings,
                                                  all_ingredient_labels=all_ingredient_labels,
                                                  knn_classifier=knn_classifier, thresh=thresh)
            if substitutes is None:
                none_counter += 1
                subst_dict[ingredient_name] = []
            else:
                subst_dict[ingredient_name] = list(substitutes)
        with open(substitute_pairs_path, 'w') as f:
            json.dump(subst_dict, f, ensure_ascii=False, indent=4)
        print(f'Nones: {none_counter}')
if __name__ == '__main__':
main()