aboutsummaryrefslogtreecommitdiffstats
path: root/WMD_retrieval.py
diff options
context:
space:
mode:
authorYigit Sever2019-09-21 16:55:19 +0300
committerYigit Sever2019-09-21 16:55:19 +0300
commit73df12b6787304e6982e4543d63700beabc88085 (patch)
tree2d222938890e73bfa97d4972deb266311a3496a1 /WMD_retrieval.py
parent1558c5e3ea8e34286ad587a8473f5121d5e8d289 (diff)
downloadEvaluating-Dictionary-Alignment-73df12b6787304e6982e4543d63700beabc88085.tar.gz
Evaluating-Dictionary-Alignment-73df12b6787304e6982e4543d63700beabc88085.tar.bz2
Evaluating-Dictionary-Alignment-73df12b6787304e6982e4543d63700beabc88085.zip
Fix typo
Diffstat (limited to 'WMD_retrieval.py')
-rw-r--r--WMD_retrieval.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/WMD_retrieval.py b/WMD_retrieval.py
index b49ba7d..32f3b5d 100644
--- a/WMD_retrieval.py
+++ b/WMD_retrieval.py
@@ -92,7 +92,7 @@ def main(args):
92 if (not batch): 92 if (not batch):
93 print(f'{metric} - tfidf: {source_lang} - {target_lang}') 93 print(f'{metric} - tfidf: {source_lang} - {target_lang}')
94 94
95 clf = WassersteinDistances(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk')) 95 clf = Wasserstein_Retriever(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk'))
96 clf.fit(X_train_idf[:instances], np.ones(instances)) 96 clf.fit(X_train_idf[:instances], np.ones(instances))
97 dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) 97 dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances)
98 mrr, p_at_1 = mrr_precision_at_k(list(range(len(preds))), preds) 98 mrr, p_at_1 = mrr_precision_at_k(list(range(len(preds))), preds)
@@ -118,7 +118,7 @@ if __name__ == "__main__":
118 parser.add_argument('target_defs', help='path of the target definitions') 118 parser.add_argument('target_defs', help='path of the target definitions')
119 parser.add_argument('-b', '--batch', action='store_true', help='running in batch (store results in csv) or running a single instance (output the results)') 119 parser.add_argument('-b', '--batch', action='store_true', help='running in batch (store results in csv) or running a single instance (output the results)')
120 parser.add_argument('mode', choices=['all', 'wmd', 'snk'], default='all', help='which methods to run') 120 parser.add_argument('mode', choices=['all', 'wmd', 'snk'], default='all', help='which methods to run')
121 parser.add_argument('-n', '--instances', help='number of instances in each language to retrieve', default=2000, type=int) 121 parser.add_argument('-n', '--instances', help='number of instances in each language to retrieve', default=1000, type=int)
122 122
123 args = parser.parse_args() 123 args = parser.parse_args()
124 124