diff options
Diffstat (limited to 'WMD_retrieval.py')
-rw-r--r-- | WMD_retrieval.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/WMD_retrieval.py b/WMD_retrieval.py index b49ba7d..32f3b5d 100644 --- a/WMD_retrieval.py +++ b/WMD_retrieval.py | |||
@@ -92,7 +92,7 @@ def main(args): | |||
92 | if (not batch): | 92 | if (not batch): |
93 | print(f'{metric} - tfidf: {source_lang} - {target_lang}') | 93 | print(f'{metric} - tfidf: {source_lang} - {target_lang}') |
94 | 94 | ||
95 | clf = WassersteinDistances(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk')) | 95 | clf = Wasserstein_Retriever(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk')) |
96 | clf.fit(X_train_idf[:instances], np.ones(instances)) | 96 | clf.fit(X_train_idf[:instances], np.ones(instances)) |
97 | dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) | 97 | dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) |
98 | mrr, p_at_1 = mrr_precision_at_k(list(range(len(preds))), preds) | 98 | mrr, p_at_1 = mrr_precision_at_k(list(range(len(preds))), preds) |
@@ -118,7 +118,7 @@ if __name__ == "__main__": | |||
118 | parser.add_argument('target_defs', help='path of the target definitions') | 118 | parser.add_argument('target_defs', help='path of the target definitions') |
119 | parser.add_argument('-b', '--batch', action='store_true', help='running in batch (store results in csv) or running a single instance (output the results)') | 119 | parser.add_argument('-b', '--batch', action='store_true', help='running in batch (store results in csv) or running a single instance (output the results)') |
120 | parser.add_argument('mode', choices=['all', 'wmd', 'snk'], default='all', help='which methods to run') | 120 | parser.add_argument('mode', choices=['all', 'wmd', 'snk'], default='all', help='which methods to run') |
121 | parser.add_argument('-n', '--instances', help='number of instances in each language to retrieve', default=2000, type=int) | 121 | parser.add_argument('-n', '--instances', help='number of instances in each language to retrieve', default=1000, type=int) |
122 | 122 | ||
123 | args = parser.parse_args() | 123 | args = parser.parse_args() |
124 | 124 | ||