# adapted from:
# Pellegrini., C., E. Özsoy., M. Wintergerst., and G. Groh. (2021).
# “Exploiting Food Embeddings for Ingredient Substitution.”
# In: Proceedings of the 14th International Joint Conference on Biomedical
# Engineering Systems and Technologies - Volume 5: HEALTHINF, INSTICC.
# SciTePress, pp. 67–77. isbn: 978-989-758-490-9. doi: 10.5220/0010202000670077.

import json
import pickle
import random
import re
from collections import defaultdict
from pathlib import Path

import numpy as np
import torch
from tqdm import tqdm

from evaluation.helpers.prediction_model import PredictionModel

def _generate_food_sentence_dict(model_path):
|
||
with open('data/mult_ingredients_nice.json', "r") as f:
|
||
food_items = json.load(f)
|
||
food_items_set = set(food_items.keys())
|
||
|
||
with open(model_path + 'training_data.txt', "r") as f:
|
||
train_instruction_sentences = f.read().splitlines()
|
||
# remove overlong sentences
|
||
train_instruction_sentences = [s for s in train_instruction_sentences if len(s.split()) <= 100]
|
||
|
||
with open(model_path + 'testing_data.txt', "r") as f:
|
||
test_instruction_sentences = f.read().splitlines()
|
||
# remove overlong sentences
|
||
test_instruction_sentences = [s for s in test_instruction_sentences if len(s.split()) <= 100]
|
||
|
||
instruction_sentences = train_instruction_sentences + test_instruction_sentences
|
||
|
||
food_to_sentences_dict = defaultdict(list)
|
||
for sentence in instruction_sentences:
|
||
words = re.sub("[^\w]-'", " ", sentence).split()
|
||
for word in words:
|
||
if word in food_items_set:
|
||
food_to_sentences_dict[word].append(sentence)
|
||
|
||
return food_to_sentences_dict
|
||
|
||
|
||
def _random_sample_with_min_count(population, k):
|
||
if len(population) <= k:
|
||
return population
|
||
else:
|
||
return random.sample(population, k)
|
||
|
||
|
||
def sample_random_sentence_dict(model_path, max_sentence_count):
    """Build the food -> sentences mapping and cap each list at max_sentence_count random picks.

    :param model_path: directory prefix passed through to _generate_food_sentence_dict
    :param max_sentence_count: maximum sentences retained per food item
    :return: dict mapping food token -> list of at most max_sentence_count sentences
    """
    sentences_by_food = _generate_food_sentence_dict(model_path=model_path)
    capped = {}
    for food, sentences in sentences_by_food.items():
        capped[food] = _random_sample_with_min_count(sentences, max_sentence_count)
    return capped
def _map_ingredients_to_input_ids(model_path):
    """Map every known ingredient name to its tokenizer input id for the model at model_path.

    :param model_path: path handed to PredictionModel, whose tokenizer is used
    :return: dict mapping ingredient name -> tokenizer input id
    """
    with open('data/mult_ingredients_nice.json', "r") as f:
        ingredient_names = list(json.load(f).keys())

    tokenizer = PredictionModel(model_path).tokenizer
    token_ids = tokenizer.convert_tokens_to_ids(ingredient_names)

    return dict(zip(ingredient_names, token_ids))
def _merge_synonmys(food_to_embeddings_dict, max_sentence_count):
|
||
synonmy_replacements_path = Path('foodbert_embeddings/data/synonmy_replacements.json')
|
||
synonmy_replacements = {}
|
||
|
||
merged_dict = defaultdict(list)
|
||
# Merge ingredients
|
||
for key, value in food_to_embeddings_dict.items():
|
||
if key in synonmy_replacements:
|
||
key_to_use = synonmy_replacements[key]
|
||
else:
|
||
key_to_use = key
|
||
|
||
merged_dict[key_to_use].append(value)
|
||
|
||
merged_dict = {k: np.concatenate(v) for k, v in merged_dict.items()}
|
||
# When embedding count exceeds maximum allowed, reduce back to requested count
|
||
for key, value in merged_dict.items():
|
||
if len(value) > max_sentence_count:
|
||
index = np.random.choice(value.shape[0], max_sentence_count, replace=False)
|
||
new_value = value[index]
|
||
merged_dict[key] = new_value
|
||
|
||
return merged_dict
|
||
|
||
|
||
def generate_food_embedding_dict(max_sentence_count, model_path, eval_path='data/eval/', dataset_path="output"):
    '''
    Creates a dict where the keys are the ingredients and the values are a list of embeddings with length max_sentence_count or less if there are less occurences
    These embeddings are used in generate_substitutes.py to predict substitutes

    :param max_sentence_count: upper bound on sentences (and therefore embedding rows) kept per ingredient
    :param model_path: path handed to PredictionModel (model + tokenizer)
    :param eval_path: directory where the result is cached as 'food_embeddings_dict.pkl'
    :param dataset_path: directory prefix holding training_data.txt / testing_data.txt
    :return: dict mapping ingredient -> np.ndarray of embedding rows
    '''
    food_to_embeddings_dict_path = Path(eval_path + 'food_embeddings_dict.pkl')
    # Fast path: reuse the previously computed dict when a cache pickle exists.
    # NOTE(review): pickle.load executes code on load — safe only because this
    # cache is produced locally by the code below, never from untrusted input.
    if food_to_embeddings_dict_path.exists():
        with food_to_embeddings_dict_path.open('rb') as f:
            food_to_embeddings_dict = pickle.load(f)

        # # delete keys if we deleted ingredients
        # old_ingredients = set(food_to_embeddings_dict.keys())
        # with open('train_model/vocab/used_ingredients.json', "r") as f:
        #     new_ingredients = set(json.load(f))
        #
        # keys_to_delete = old_ingredients.difference(new_ingredients)
        # for key in keys_to_delete:
        #     food_to_embeddings_dict.pop(key, None)  # delete key if it exists
        #
        # # merge new synonyms
        # food_to_embeddings_dict = _merge_synonmys(food_to_embeddings_dict, max_sentence_count)
        #
        # with food_to_embeddings_dict_path.open('wb') as f:
        #     pickle.dump(food_to_embeddings_dict, f)  # Overwrite dict with cleaned version

        return food_to_embeddings_dict

    print('Sampling Random Sentences')
    food_to_sentences_dict_random_samples = sample_random_sentence_dict(model_path=dataset_path, max_sentence_count=max_sentence_count)
    food_to_embeddings_dict = defaultdict(list)
    print('Mapping Ingredients to Input Ids')
    all_ingredient_ids = _map_ingredients_to_input_ids(model_path=model_path)

    prediction_model = PredictionModel(model_path=model_path)

    for food, sentences in tqdm(food_to_sentences_dict_random_samples.items(), total=len(food_to_sentences_dict_random_samples),
                                desc='Calculating Embeddings for Food items'):
        embeddings, ingredient_ids = prediction_model.predict_embeddings(sentences)
        # get embedding of food word
        # flatten batch/sequence dims into one 768-d vector per token
        # (768 is hard-coded here — presumably the model hidden size; confirm against PredictionModel)
        embeddings_flat = embeddings.view((-1, 768))
        ingredient_ids_flat = torch.stack(ingredient_ids).flatten()
        food_id = all_ingredient_ids[food]
        # keep only the vectors at token positions whose id equals this food's token id
        food_embeddings = embeddings_flat[ingredient_ids_flat == food_id].cpu().numpy()
        food_to_embeddings_dict[food].extend(food_embeddings)

    food_to_embeddings_dict = {k: np.stack(v) for k, v in food_to_embeddings_dict.items()}
    # Clean synonmy
    food_to_embeddings_dict = _merge_synonmys(food_to_embeddings_dict, max_sentence_count)

    # Persist so the next call takes the cache fast path above.
    with food_to_embeddings_dict_path.open('wb') as f:
        pickle.dump(food_to_embeddings_dict, f)

    return food_to_embeddings_dict