Files
MasterarbeitCode/evaluation/stats_engl_substitutes_compare.py
2021-04-11 23:28:41 +02:00

207 lines
7.8 KiB
Python

from transformers import BertTokenizer
import json
def print_stats(model_substitutes_dict, cap_at_30, n_english_ingredients=4372):
    """Print substitute-count statistics, first for the English FoodBERT
    reference pairs, then for the given model substitute dictionary.

    Args:
        model_substitutes_dict: mapping ingredient -> list of substitutes.
        cap_at_30: if True, each ingredient contributes at most 30
            substitutes to the reported average.
        n_english_ingredients: denominator for the English averages and
            none-count (default keeps the original hard-coded 4372 —
            presumably the total English vocabulary size; TODO confirm).
    """
    print("\ncap at 30 set to " + str(cap_at_30))
    evaluation_path = "evaluation/"
    # Group the English (ingredient, substitute) pair list into a dict
    # ingredient -> list of substitutes.
    with open(evaluation_path + "engl_data/substitute_pairs_foodbert_text.json", "r") as whole_json_file:
        engl_list = json.load(whole_json_file)
    engl_dict = {}
    for pair in engl_list:
        engl_dict.setdefault(pair[0], []).append(pair[1])
    substitute_sum = 0
    over30 = 0
    for substitutes in engl_dict.values():
        curr_nr = len(substitutes)
        if curr_nr > 30:
            over30 += 1
        # With the cap, an ingredient contributes at most 30 substitutes.
        substitute_sum += min(curr_nr, 30) if cap_at_30 else curr_nr
    print("english ingredients with over 30 substitutes: " + str(over30))
    # Ingredients with no substitutes at all do not appear as keys.
    print("english nones: " + str(n_english_ingredients - len(engl_dict.keys())))
    print("average amount of substitutes found for english ingredients: " + str(substitute_sum / n_english_ingredients))
    # Same statistics for the model's substitutes, plus none / over-100 /
    # over-1000 counts.
    substitute_sum = 0
    over100 = 0
    over1000 = 0
    over30 = 0
    nones = 0
    for substitutes in model_substitutes_dict.values():
        curr_nr = len(substitutes)
        if curr_nr == 0:
            nones += 1
        if curr_nr > 100:
            over100 += 1
        if curr_nr > 1000:
            over1000 += 1
        if curr_nr > 30:
            over30 += 1
        substitute_sum += min(curr_nr, 30) if cap_at_30 else curr_nr
    print("number of ingredients in dataset: " + str(len(model_substitutes_dict.keys())))
    print("number of nones: " + str(nones))
    print("ingredients with over 30 substitutes: " + str(over30))
    print("ingredients with over 100 substitutes: " + str(over100))
    print("ingredients with over 1000 substitutes: " + str(over1000))
    print("average number of substitutes: " + str(substitute_sum / len(model_substitutes_dict.keys())))
def _pairs_to_dict(pair_list):
    """Group a list of (ingredient, substitute) pairs into a dict
    ingredient -> list of substitutes."""
    result = {}
    for pair in pair_list:
        result.setdefault(pair[0], []).append(pair[1])
    return result


def _filter_to_ground_truth(substitutes_dict, ground_truth):
    """Keep only the ingredients that also appear in the ground truth."""
    return {ingred: subs for ingred, subs in substitutes_dict.items()
            if ingred in ground_truth}


def main():
    """Compare substitute statistics of the model output (with synonyms
    merged into their base words) against the English FoodBERT reference
    data, on the full vocabulary and restricted to the ground truth."""
    evaluation_path = "evaluation/"
    synonyms_path = "synonyms.json"
    data_path = "data/"
    engl_data_path = evaluation_path + "engl_data/"
    found_substitutes_path = "final_Versions/models/vers2/eval/complete_substitute_pairs_50.json"
    with open(found_substitutes_path, "r") as whole_json_file:
        model_substitutes_dict = json.load(whole_json_file)
    with open(data_path + synonyms_path, "r") as whole_json_file:
        synonyms_dict = json.load(whole_json_file)
    # Category-level words whose synonym groups must NOT be collapsed
    # (merging e.g. all poultry into "Huhn" would distort the statistics).
    category_subs = ["Paprika", "Apfel", "Gouda", "Huhn", "Gans", "Kaninchen", "Kalbfleisch", "Schweinefleisch", "Ente", "Lamm",
                     "Pute", "Wildfleisch", "Rindfleisch", "Speck", "Fisch", "Kohl", "Blattsalat", "Schokolade", "Kuvertüre", "Kuchenglasur",
                     "Honig", "Sirup", "Joghurt", "Essig", "Traubensaft", "Geflügelfleisch", "Wein", "Suppenfleisch"]
    # Map every synonym to its base word (skipping the category words).
    new_syn_dict = {}
    for ingred in synonyms_dict:
        if ingred not in category_subs:
            for syn in synonyms_dict[ingred]:
                new_syn_dict[syn] = ingred
    # Merge synonyms in both the keys and the values of the model output:
    # every ingredient/substitute is replaced by its base word and the
    # substitute sets of merged keys are unioned.
    final_dict = {}
    for ingred, subs in model_substitutes_dict.items():
        merged_subs = {new_syn_dict.get(sub, sub) for sub in subs}
        base = new_syn_dict.get(ingred, ingred)
        # setdefault guards against a base word that never occurs as a
        # non-synonym key (the original indexed directly and could KeyError).
        final_dict.setdefault(base, set()).update(merged_subs)
    # An ingredient is not a substitute for itself.
    for ingred in final_dict:
        final_dict[ingred].discard(ingred)
    new_final_dict = {ingred: list(subs) for ingred, subs in final_dict.items()}
    # NOTE(review): the original reloaded found_substitutes_path here,
    # overwriting new_final_dict and silently discarding the entire synonym
    # merge above — removed as a debugging leftover.
    print_stats(new_final_dict, cap_at_30=True)
    print_stats(new_final_dict, cap_at_30=False)
    print("--------------------------------------------\nground truth only: ")
    with open("data/ground_truth.json", "r") as whole_json_file:
        ground_truth = json.load(whole_json_file)
    ground_truth_only = _filter_to_ground_truth(new_final_dict, ground_truth)
    print_stats(ground_truth_only, cap_at_30=True)
    print_stats(ground_truth_only, cap_at_30=False)
    print("================================\nenglisch:")
    with open(engl_data_path + "substitute_pairs_foodbert_text.json", "r") as whole_json_file:
        engl_list = json.load(whole_json_file)
    with open(engl_data_path + "engl_ground_truth.json", "r") as whole_json_file:
        engl_ground_truth = json.load(whole_json_file)
    engl_dict = _pairs_to_dict(engl_list)
    print_stats(engl_dict, cap_at_30=True)
    print_stats(engl_dict, cap_at_30=False)
    print("--------------------------------------------\nground truth only: ")
    ground_truth_only = _filter_to_ground_truth(engl_dict, engl_ground_truth)
    print_stats(ground_truth_only, cap_at_30=True)
    print_stats(ground_truth_only, cap_at_30=False)
main()