From 49c6f58e51e12af691f7a1322137c64f46043b15 Mon Sep 17 00:00:00 2001 From: Yigit Sever Date: Tue, 24 Sep 2019 21:26:34 +0300 Subject: Use black linter for WMD --- WMD_retrieval.py | 97 ++++++++++++++++++++++++++------------------------------ 1 file changed, 45 insertions(+), 52 deletions(-) (limited to 'WMD_retrieval.py') diff --git a/WMD_retrieval.py b/WMD_retrieval.py index 02f35be..cb72079 100644 --- a/WMD_retrieval.py +++ b/WMD_retrieval.py @@ -13,7 +13,7 @@ from Wasserstein_Distance import (WassersteinRetriever, def main(args): - np.seterr(divide='ignore') # POT has issues with divide by zero errors + np.seterr(divide="ignore") # POT has issues with divide by zero errors source_lang = args.source_lang target_lang = args.target_lang @@ -29,32 +29,24 @@ def main(args): mode = args.mode runfor = list() - if mode == 'all': - runfor.extend(['wmd', 'snk']) + if mode == "all": + runfor.extend(["wmd", "snk"]) else: runfor.append(mode) defs_source = [ - line.rstrip('\n') - for line in open(source_defs_filename, encoding='utf8') + line.rstrip("\n") for line in open(source_defs_filename, encoding="utf8") ] defs_target = [ - line.rstrip('\n') - for line in open(target_defs_filename, encoding='utf8') + line.rstrip("\n") for line in open(target_defs_filename, encoding="utf8") ] clean_src_corpus, clean_src_vectors, src_keys = clean_corpus_using_embeddings_vocabulary( - set(vectors_source.keys()), - defs_source, - vectors_source, - source_lang, + set(vectors_source.keys()), defs_source, vectors_source, source_lang ) clean_target_corpus, clean_target_vectors, target_keys = clean_corpus_using_embeddings_vocabulary( - set(vectors_target.keys()), - defs_target, - vectors_target, - target_lang, + set(vectors_target.keys()), defs_target, vectors_target, target_lang ) take = args.instances @@ -70,14 +62,15 @@ def main(args): if not batch: print( - f'{source_lang} - {target_lang} : document sizes: {len(clean_src_corpus)}, {len(clean_target_corpus)}' + f"{source_lang} - {target_lang} : document sizes: {len(clean_src_corpus)}, {len(clean_target_corpus)}" ) del vectors_source, vectors_target, defs_source, defs_target vec = CountVectorizer().fit(clean_src_corpus + clean_target_corpus) common = [ - word for word in vec.get_feature_names() + word + for word in vec.get_feature_names() if word in clean_src_vectors or word in clean_target_vectors ] W_common = [] @@ -88,9 +81,7 @@ def main(args): W_common.append(np.array(clean_target_vectors[w])) if not batch: - print( - f'{source_lang} - {target_lang}: the vocabulary size is {len(W_common)}' - ) + print(f"{source_lang} - {target_lang}: the vocabulary size is {len(W_common)}") W_common = np.array(W_common) W_common = normalize(W_common) @@ -101,55 +92,57 @@ def main(args): for metric in runfor: if not batch: - print(f'{metric}: {source_lang} - {target_lang}') + print(f"{metric}: {source_lang} - {target_lang}") - clf = WassersteinRetriever(W_embed=W_common, - n_neighbors=5, - n_jobs=14, - sinkhorn=(metric == 'snk')) + clf = WassersteinRetriever( + W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == "snk") + ) clf.fit(X_train_idf[:instances], np.ones(instances)) - p_at_one, percentage = clf.align(X_test_idf[:instances], - n_neighbors=instances) + p_at_one, percentage = clf.align(X_test_idf[:instances], n_neighbors=instances) if not batch: - print(f'P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%') + print(f"P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%") else: fields = [ - f'{source_lang}', f'{target_lang}', f'{instances}', - f'{p_at_one}', f'{percentage}' + f"{source_lang}", + f"{target_lang}", + f"{instances}", + f"{p_at_one}", + f"{percentage}", ] - with open(f'{metric}_retrieval_result.csv', 'a') as f: + with open(f"{metric}_retrieval_result.csv", "a") as f: writer = csv.writer(f) writer.writerow(fields) if __name__ == "__main__": - parser = argparse.ArgumentParser( - description='run retrieval using wmd or snk') - parser.add_argument('source_lang', help='source language short name') - parser.add_argument('target_lang', help='target language short name') - parser.add_argument('source_vector', help='path of the source vector') - parser.add_argument('target_vector', help='path of the target vector') - parser.add_argument('source_defs', help='path of the source definitions') - parser.add_argument('target_defs', help='path of the target definitions') + parser = argparse.ArgumentParser(description="run retrieval using wmd or snk") + parser.add_argument("source_lang", help="source language short name") + parser.add_argument("target_lang", help="target language short name") + parser.add_argument("source_vector", help="path of the source vector") + parser.add_argument("target_vector", help="path of the target vector") + parser.add_argument("source_defs", help="path of the source definitions") + parser.add_argument("target_defs", help="path of the target definitions") parser.add_argument( - '-b', - '--batch', - action='store_true', - help= - 'running in batch (store results in csv) or running a single instance (output the results)' + "-b", + "--batch", + action="store_true", + help="running in batch (store results in csv) or running a single instance (output the results)", ) - parser.add_argument('mode', - choices=['all', 'wmd', 'snk'], - default='all', - help='which methods to run') parser.add_argument( - '-n', - '--instances', - help='number of instances in each language to retrieve', + "mode", + choices=["all", "wmd", "snk"], + default="all", + help="which methods to run", + ) + parser.add_argument( + "-n", + "--instances", + help="number of instances in each language to retrieve", default=1000, - type=int) + type=int, + ) args = parser.parse_args() -- cgit v1.2.3-70-g09d2