bowphs committed on
Commit 9cbeb98 · verified · Parent: af1acfc

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. stanza/demo/Stanza_CoreNLP_Interface.ipynb +485 -0
  2. stanza/demo/en_test.conllu.txt +79 -0
  3. stanza/demo/semgrex visualization.ipynb +367 -0
  4. stanza/images/stanza-logo.png +0 -0
  5. stanza/stanza/__init__.py +27 -0
  6. stanza/stanza/models/charlm.py +357 -0
  7. stanza/stanza/models/identity_lemmatizer.py +66 -0
  8. stanza/stanza/models/lang_identifier.py +236 -0
  9. stanza/stanza/models/mwt_expander.py +322 -0
  10. stanza/stanza/models/ner_tagger.py +492 -0
  11. stanza/stanza/models/tagger.py +461 -0
  12. stanza/stanza/models/tokenizer.py +258 -0
  13. stanza/stanza/models/wl_coref.py +226 -0
  14. stanza/stanza/pipeline/__init__.py +0 -0
  15. stanza/stanza/pipeline/constituency_processor.py +81 -0
  16. stanza/stanza/pipeline/core.py +509 -0
  17. stanza/stanza/pipeline/coref_processor.py +154 -0
  18. stanza/stanza/pipeline/depparse_processor.py +78 -0
  19. stanza/stanza/pipeline/langid_processor.py +127 -0
  20. stanza/stanza/pipeline/lemma_processor.py +126 -0
  21. stanza/stanza/pipeline/multilingual.py +188 -0
  22. stanza/stanza/pipeline/mwt_processor.py +59 -0
  23. stanza/stanza/pipeline/pos_processor.py +89 -0
  24. stanza/stanza/pipeline/processor.py +293 -0
  25. stanza/stanza/pipeline/registry.py +8 -0
  26. stanza/stanza/pipeline/sentiment_processor.py +78 -0
  27. stanza/stanza/pipeline/tokenize_processor.py +185 -0
  28. stanza/stanza/protobuf/CoreNLP_pb2.py +686 -0
  29. stanza/stanza/protobuf/__init__.py +52 -0
  30. stanza/stanza/resources/common.py +619 -0
  31. stanza/stanza/resources/default_packages.py +909 -0
  32. stanza/stanza/resources/installation.py +148 -0
  33. stanza/stanza/resources/prepare_resources.py +670 -0
  34. stanza/stanza/server/__init__.py +10 -0
  35. stanza/stanza/server/annotator.py +138 -0
  36. stanza/stanza/server/client.py +779 -0
  37. stanza/stanza/server/semgrex.py +170 -0
  38. stanza/stanza/server/ssurgeon.py +310 -0
  39. stanza/stanza/tests/__init__.py +111 -0
  40. stanza/stanza/tests/data/tiny_emb.txt +4 -0
  41. stanza/stanza/tests/datasets/test_common.py +76 -0
  42. stanza/stanza/tests/datasets/test_vietnamese_renormalization.py +35 -0
  43. stanza/stanza/tests/depparse/__init__.py +0 -0
  44. stanza/stanza/tests/depparse/test_depparse_data.py +62 -0
  45. stanza/stanza/tests/depparse/test_parser.py +211 -0
  46. stanza/stanza/tests/langid/__init__.py +0 -0
  47. stanza/stanza/tests/langid/test_multilingual.py +104 -0
  48. stanza/stanza/tests/lemma_classifier/__init__.py +0 -0
  49. stanza/stanza/tests/lemma_classifier/test_training.py +53 -0
  50. stanza/stanza/tests/mwt/__init__.py +0 -0
stanza/demo/Stanza_CoreNLP_Interface.ipynb ADDED
@@ -0,0 +1,485 @@
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "name": "Stanza-CoreNLP-Interface.ipynb",
7
+ "provenance": [],
8
+ "collapsed_sections": [],
9
+ "toc_visible": true
10
+ },
11
+ "kernelspec": {
12
+ "name": "python3",
13
+ "display_name": "Python 3"
14
+ }
15
+ },
16
+ "cells": [
17
+ {
18
+ "cell_type": "markdown",
19
+ "metadata": {
20
+ "id": "2-4lzQTC9yxG",
21
+ "colab_type": "text"
22
+ },
23
+ "source": [
24
+ "# Stanza: A Tutorial on the Python CoreNLP Interface\n",
25
+ "\n",
26
+ "![Latest Version](https://img.shields.io/pypi/v/stanza.svg?colorB=bc4545)\n",
27
+ "![Python Versions](https://img.shields.io/pypi/pyversions/stanza.svg?colorB=bc4545)\n",
28
+ "\n",
29
+ "While the Stanza library implements accurate neural network modules for basic functionalities such as part-of-speech tagging and dependency parsing, the [Stanford CoreNLP Java library](https://stanfordnlp.github.io/CoreNLP/) has been developed for years and offers more complementary features such as coreference resolution and relation extraction. To unlock these features, the Stanza library also offers an officially maintained Python interface to the CoreNLP Java library. This interface allows you to get NLP anntotations from CoreNLP by writing native Python code.\n",
30
+ "\n",
31
+ "\n",
32
+ "This tutorial walks you through the installation, setup and basic usage of this Python CoreNLP interface. If you want to learn how to use the neural network components in Stanza, please refer to other tutorials."
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "markdown",
37
+ "metadata": {
38
+ "id": "YpKwWeVkASGt",
39
+ "colab_type": "text"
40
+ },
41
+ "source": [
42
+ "## 1. Installation\n",
43
+ "\n",
44
+ "Before the installation starts, please make sure that you have Python 3 and Java installed on your computer. Since Colab already has them installed, we'll skip this procedure in this notebook."
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "markdown",
49
+ "metadata": {
50
+ "id": "k1Az2ECuAfG8",
51
+ "colab_type": "text"
52
+ },
53
+ "source": [
54
+ "### Installing Stanza\n",
55
+ "\n",
56
+ "Installing and importing Stanza are as simple as running the following commands:"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "metadata": {
62
+ "id": "xiFwYAgW4Mss",
63
+ "colab_type": "code",
64
+ "colab": {}
65
+ },
66
+ "source": [
67
+ "# Install stanza; note that the prefix \"!\" is not needed if you are running in a terminal\n",
68
+ "!pip install stanza\n",
69
+ "\n",
70
+ "# Import stanza\n",
71
+ "import stanza"
72
+ ],
73
+ "execution_count": null,
74
+ "outputs": []
75
+ },
76
+ {
77
+ "cell_type": "markdown",
78
+ "metadata": {
79
+ "id": "2zFvaA8_A32_",
80
+ "colab_type": "text"
81
+ },
82
+ "source": [
83
+ "### Setting up Stanford CoreNLP\n",
84
+ "\n",
85
+ "In order for the interface to work, the Stanford CoreNLP library has to be installed and a `CORENLP_HOME` environment variable has to be pointed to the installation location.\n",
86
+ "\n",
87
+ "Here we are going to show you how to download and install the CoreNLP library on your machine, with Stanza's installation command:"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "metadata": {
93
+ "id": "MgK6-LPV-OdA",
94
+ "colab_type": "code",
95
+ "colab": {}
96
+ },
97
+ "source": [
98
+ "# Download the Stanford CoreNLP package with Stanza's installation command\n",
99
+ "# This'll take several minutes, depending on the network speed\n",
100
+ "corenlp_dir = './corenlp'\n",
101
+ "stanza.install_corenlp(dir=corenlp_dir)\n",
102
+ "\n",
103
+ "# Set the CORENLP_HOME environment variable to point to the installation location\n",
104
+ "import os\n",
105
+ "os.environ[\"CORENLP_HOME\"] = corenlp_dir"
106
+ ],
107
+ "execution_count": null,
108
+ "outputs": []
109
+ },
110
+ {
111
+ "cell_type": "markdown",
112
+ "metadata": {
113
+ "id": "Jdq8MT-NAhKj",
114
+ "colab_type": "text"
115
+ },
116
+ "source": [
117
+ "That's all for the installation! 🎉 We can now double check if the installation is successful by listing files in the CoreNLP directory. You should be able to see a number of `.jar` files by running the following command:"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "code",
122
+ "metadata": {
123
+ "id": "K5eIOaJp_tuo",
124
+ "colab_type": "code",
125
+ "colab": {}
126
+ },
127
+ "source": [
128
+ "# Examine the CoreNLP installation folder to make sure the installation is successful\n",
129
+ "!ls $CORENLP_HOME"
130
+ ],
131
+ "execution_count": null,
132
+ "outputs": []
133
+ },
134
+ {
135
+ "cell_type": "markdown",
136
+ "metadata": {
137
+ "id": "S0xb9BHt__gx",
138
+ "colab_type": "text"
139
+ },
140
+ "source": [
141
+ "**Note 1**:\n",
142
+ "If you are want to use the interface in a terminal (instead of a Colab notebook), you can properly set the `CORENLP_HOME` environment variable with:\n",
143
+ "\n",
144
+ "```bash\n",
145
+ "export CORENLP_HOME=path_to_corenlp_dir\n",
146
+ "```\n",
147
+ "\n",
148
+ "Here we instead set this variable with the Python `os` library, simply because `export` command is not well-supported in Colab notebook.\n",
149
+ "\n",
150
+ "\n",
151
+ "**Note 2**:\n",
152
+ "The `stanza.install_corenlp()` function is only available since Stanza v1.1.1. If you are using an earlier version of Stanza, please check out our [manual installation page](https://stanfordnlp.github.io/stanza/client_setup.html#manual-installation) for how to install CoreNLP on your computer.\n",
153
+ "\n",
154
+ "**Note 3**:\n",
155
+ "Besides the installation function, we also provide a `stanza.download_corenlp_models()` function to help you download additional CoreNLP models for different languages that are not shipped with the default installation. Check out our [automatic installation website page](https://stanfordnlp.github.io/stanza/client_setup.html#automated-installation) for more information on how to use it."
156
+ ]
157
+ },
158
+ {
159
+ "cell_type": "markdown",
160
+ "metadata": {
161
+ "id": "xJsuO6D8D05q",
162
+ "colab_type": "text"
163
+ },
164
+ "source": [
165
+ "## 2. Annotating Text with CoreNLP Interface"
166
+ ]
167
+ },
168
+ {
169
+ "cell_type": "markdown",
170
+ "metadata": {
171
+ "id": "dZNHxXHkH1K2",
172
+ "colab_type": "text"
173
+ },
174
+ "source": [
175
+ "### Constructing CoreNLPClient\n",
176
+ "\n",
177
+ "At a high level, the CoreNLP Python interface works by first starting a background Java CoreNLP server process, and then initializing a client instance in Python which can pass the text to the background server process, and accept the returned annotation results.\n",
178
+ "\n",
179
+ "We wrap these functionalities in a `CoreNLPClient` class. Therefore, we need to start by importing this class from Stanza."
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "code",
184
+ "metadata": {
185
+ "id": "LS4OKnqJ8wui",
186
+ "colab_type": "code",
187
+ "colab": {}
188
+ },
189
+ "source": [
190
+ "# Import client module\n",
191
+ "from stanza.server import CoreNLPClient"
192
+ ],
193
+ "execution_count": null,
194
+ "outputs": []
195
+ },
196
+ {
197
+ "cell_type": "markdown",
198
+ "metadata": {
199
+ "id": "WP4Dz6PIJHeL",
200
+ "colab_type": "text"
201
+ },
202
+ "source": [
203
+ "After the import is done, we can construct a `CoreNLPClient` instance. The constructor method takes a Python list of annotator names as argument. Here let's explore some basic annotators including tokenization, sentence split, part-of-speech tagging, lemmatization and named entity recognition (NER). \n",
204
+ "\n",
205
+ "Additionally, the client constructor accepts a `memory` argument, which specifies how much memory will be allocated to the background Java process. An `endpoint` option can be used to specify a port number used by the communication between the server and the client. The default port is 9000. However, since this port is pre-occupied by a system process in Colab, we'll manually set it to 9001 in the following example.\n",
206
+ "\n",
207
+ "Also, here we manually set `be_quiet=True` to avoid an IO issue in colab notebook. You should be able to use `be_quiet=False` on your own computer, which will print detailed logging information from CoreNLP during usage.\n",
208
+ "\n",
209
+ "For more options in constructing the clients, please refer to the [CoreNLP Client Options List](https://stanfordnlp.github.io/stanza/corenlp_client.html#corenlp-client-options)."
210
+ ]
211
+ },
212
+ {
213
+ "cell_type": "code",
214
+ "metadata": {
215
+ "id": "mbOBugvd9JaM",
216
+ "colab_type": "code",
217
+ "colab": {}
218
+ },
219
+ "source": [
220
+ "# Construct a CoreNLPClient with some basic annotators, a memory allocation of 4GB, and port number 9001\n",
221
+ "client = CoreNLPClient(\n",
222
+ " annotators=['tokenize','ssplit', 'pos', 'lemma', 'ner'], \n",
223
+ " memory='4G', \n",
224
+ " endpoint='http://localhost:9001',\n",
225
+ " be_quiet=True)\n",
226
+ "print(client)\n",
227
+ "\n",
228
+ "# Start the background server and wait for some time\n",
229
+ "# Note that in practice this is totally optional, as by default the server will be started when the first annotation is performed\n",
230
+ "client.start()\n",
231
+ "import time; time.sleep(10)"
232
+ ],
233
+ "execution_count": null,
234
+ "outputs": []
235
+ },
236
+ {
237
+ "cell_type": "markdown",
238
+ "metadata": {
239
+ "id": "kgTiVjNydmIW",
240
+ "colab_type": "text"
241
+ },
242
+ "source": [
243
+ "After the above code block finishes executing, if you print the background processes, you should be able to find the Java CoreNLP server running."
244
+ ]
245
+ },
246
+ {
247
+ "cell_type": "code",
248
+ "metadata": {
249
+ "id": "spZrJ-oFdkdF",
250
+ "colab_type": "code",
251
+ "colab": {}
252
+ },
253
+ "source": [
254
+ "# Print background processes and look for java\n",
255
+ "# You should be able to see a StanfordCoreNLPServer java process running in the background\n",
256
+ "!ps -o pid,cmd | grep java"
257
+ ],
258
+ "execution_count": null,
259
+ "outputs": []
260
+ },
261
+ {
262
+ "cell_type": "markdown",
263
+ "metadata": {
264
+ "id": "KxJeJ0D2LoOs",
265
+ "colab_type": "text"
266
+ },
267
+ "source": [
268
+ "### Annotating Text\n",
269
+ "\n",
270
+ "Annotating a piece of text is as simple as passing the text into an `annotate` function of the client object. After the annotation is complete, a `Document` object will be returned with all annotations.\n",
271
+ "\n",
272
+ "Note that although in general annotations are very fast, the first annotation might take a while to complete in the notebook. Please stay patient."
273
+ ]
274
+ },
275
+ {
276
+ "cell_type": "code",
277
+ "metadata": {
278
+ "id": "s194RnNg5z95",
279
+ "colab_type": "code",
280
+ "colab": {}
281
+ },
282
+ "source": [
283
+ "# Annotate some text\n",
284
+ "text = \"Albert Einstein was a German-born theoretical physicist. He developed the theory of relativity.\"\n",
285
+ "document = client.annotate(text)\n",
286
+ "print(type(document))"
287
+ ],
288
+ "execution_count": null,
289
+ "outputs": []
290
+ },
291
+ {
292
+ "cell_type": "markdown",
293
+ "metadata": {
294
+ "id": "semmA3e0TcM1",
295
+ "colab_type": "text"
296
+ },
297
+ "source": [
298
+ "## 3. Accessing Annotations\n",
299
+ "\n",
300
+ "Annotations can be accessed from the returned `Document` object.\n",
301
+ "\n",
302
+ "A `Document` contains a list of `Sentence`s, which contain a list of `Token`s. Here let's first explore the annotations stored in all tokens."
303
+ ]
304
+ },
305
+ {
306
+ "cell_type": "code",
307
+ "metadata": {
308
+ "id": "lIO4B5d6Rk4I",
309
+ "colab_type": "code",
310
+ "colab": {}
311
+ },
312
+ "source": [
313
+ "# Iterate over all tokens in all sentences, and print out the word, lemma, pos and ner tags\n",
314
+ "print(\"{:12s}\\t{:12s}\\t{:6s}\\t{}\".format(\"Word\", \"Lemma\", \"POS\", \"NER\"))\n",
315
+ "\n",
316
+ "for i, sent in enumerate(document.sentence):\n",
317
+ " print(\"[Sentence {}]\".format(i+1))\n",
318
+ " for t in sent.token:\n",
319
+ " print(\"{:12s}\\t{:12s}\\t{:6s}\\t{}\".format(t.word, t.lemma, t.pos, t.ner))\n",
320
+ " print(\"\")"
321
+ ],
322
+ "execution_count": null,
323
+ "outputs": []
324
+ },
325
+ {
326
+ "cell_type": "markdown",
327
+ "metadata": {
328
+ "id": "msrJfvu8VV9m",
329
+ "colab_type": "text"
330
+ },
331
+ "source": [
332
+ "Alternatively, you can also browse the NER results by iterating over entity mentions over the sentences. For example:"
333
+ ]
334
+ },
335
+ {
336
+ "cell_type": "code",
337
+ "metadata": {
338
+ "id": "ezEjc9LeV2Xs",
339
+ "colab_type": "code",
340
+ "colab": {}
341
+ },
342
+ "source": [
343
+ "# Iterate over all detected entity mentions\n",
344
+ "print(\"{:30s}\\t{}\".format(\"Mention\", \"Type\"))\n",
345
+ "\n",
346
+ "for sent in document.sentence:\n",
347
+ " for m in sent.mentions:\n",
348
+ " print(\"{:30s}\\t{}\".format(m.entityMentionText, m.entityType))"
349
+ ],
350
+ "execution_count": null,
351
+ "outputs": []
352
+ },
353
+ {
354
+ "cell_type": "markdown",
355
+ "metadata": {
356
+ "id": "ueGzBZ3hWzkN",
357
+ "colab_type": "text"
358
+ },
359
+ "source": [
360
+ "To print all annotations a sentence, token or mention has, you can simply print the corresponding obejct."
361
+ ]
362
+ },
363
+ {
364
+ "cell_type": "code",
365
+ "metadata": {
366
+ "id": "4_S8o2BHXIed",
367
+ "colab_type": "code",
368
+ "colab": {}
369
+ },
370
+ "source": [
371
+ "# Print annotations of a token\n",
372
+ "print(document.sentence[0].token[0])\n",
373
+ "\n",
374
+ "# Print annotations of a mention\n",
375
+ "print(document.sentence[0].mentions[0])"
376
+ ],
377
+ "execution_count": null,
378
+ "outputs": []
379
+ },
380
+ {
381
+ "cell_type": "markdown",
382
+ "metadata": {
383
+ "id": "Qp66wjZ10xia",
384
+ "colab_type": "text"
385
+ },
386
+ "source": [
387
+ "**Note**: Since the Stanza CoreNLP client interface simply ports the CoreNLP annotation results to native Python objects, for a comprehensive lists of available annotators and how their annotation results can be accessed, you will need to visit the [Stanford CoreNLP website](https://stanfordnlp.github.io/CoreNLP/)."
388
+ ]
389
+ },
390
+ {
391
+ "cell_type": "markdown",
392
+ "metadata": {
393
+ "id": "IPqzMK90X0w3",
394
+ "colab_type": "text"
395
+ },
396
+ "source": [
397
+ "## 4. Shutting Down the CoreNLP Server\n",
398
+ "\n",
399
+ "To shut down the background CoreNLP server process, simply call the `stop` function of the client. Note that once a server is shutdown, you'll have to restart the server with the `start()` function before any annotation is requested."
400
+ ]
401
+ },
402
+ {
403
+ "cell_type": "code",
404
+ "metadata": {
405
+ "id": "xrJq8lZ3Nw7b",
406
+ "colab_type": "code",
407
+ "colab": {}
408
+ },
409
+ "source": [
410
+ "# Shut down the background CoreNLP server\n",
411
+ "client.stop()\n",
412
+ "\n",
413
+ "time.sleep(10)\n",
414
+ "!ps -o pid,cmd | grep java"
415
+ ],
416
+ "execution_count": null,
417
+ "outputs": []
418
+ },
419
+ {
420
+ "cell_type": "markdown",
421
+ "metadata": {
422
+ "id": "23Vwa_ifYfF7",
423
+ "colab_type": "text"
424
+ },
425
+ "source": [
426
+ "### More Information\n",
427
+ "\n",
428
+ "For more information on how to use the `CoreNLPClient`, please go to the [CoreNLPClient documentation page](https://stanfordnlp.github.io/stanza/corenlp_client.html)."
429
+ ]
430
+ },
431
+ {
432
+ "cell_type": "markdown",
433
+ "metadata": {
434
+ "id": "YUrVT6kA_Bzx",
435
+ "colab_type": "text"
436
+ },
437
+ "source": [
438
+ "## 5. Simplifying Client Usage with the Python `with` statement\n",
439
+ "\n",
440
+ "In the above demo, we explicitly called the `client.start()` and `client.stop()` functions to start and stop a client-server connection. However, doing this in practice is usually suboptimal, since you may forget to call the `stop()` function at the end, resulting in an unused server process occupying your machine memory.\n",
441
+ "\n",
442
+ "To solve is, a simple solution is to use the client interface with the [Python `with` statement](https://docs.python.org/3/reference/compound_stmts.html#the-with-statement). The `with` statement provides an elegant way to automatically start and stop the server process in your Python program, without you needing to worry about this. The following code snippet demonstrates how to establish a client, annotate an example text and then stop the server with a simple `with` statement. Note that we **always recommend** you to use the `with` statement when working with the Stanza CoreNLP client interface."
443
+ ]
444
+ },
445
+ {
446
+ "cell_type": "code",
447
+ "metadata": {
448
+ "id": "H0ct2-R4AvJh",
449
+ "colab_type": "code",
450
+ "colab": {}
451
+ },
452
+ "source": [
453
+ "print(\"Starting a server with the Python \\\"with\\\" statement...\")\n",
454
+ "with CoreNLPClient(annotators=['tokenize','ssplit', 'pos', 'lemma', 'ner'], \n",
455
+ " memory='4G', endpoint='http://localhost:9001', be_quiet=True) as client:\n",
456
+ " text = \"Albert Einstein was a German-born theoretical physicist.\"\n",
457
+ " document = client.annotate(text)\n",
458
+ "\n",
459
+ " print(\"{:30s}\\t{}\".format(\"Mention\", \"Type\"))\n",
460
+ " for sent in document.sentence:\n",
461
+ " for m in sent.mentions:\n",
462
+ " print(\"{:30s}\\t{}\".format(m.entityMentionText, m.entityType))\n",
463
+ "\n",
464
+ "print(\"\\nThe server should be stopped upon exit from the \\\"with\\\" statement.\")"
465
+ ],
466
+ "execution_count": null,
467
+ "outputs": []
468
+ },
469
+ {
470
+ "cell_type": "markdown",
471
+ "metadata": {
472
+ "id": "W435Lwc4YqKb",
473
+ "colab_type": "text"
474
+ },
475
+ "source": [
476
+ "## 6. Other Resources\n",
477
+ "\n",
478
+ "- [Stanza Homepage](https://stanfordnlp.github.io/stanza/)\n",
479
+ "- [FAQs](https://stanfordnlp.github.io/stanza/faq.html)\n",
480
+ "- [GitHub Repo](https://github.com/stanfordnlp/stanza)\n",
481
+ "- [Reporting Issues](https://github.com/stanfordnlp/stanza/issues)\n"
482
+ ]
483
+ }
484
+ ]
485
+ }
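The notebook above condenses into a short standalone script. Below is a minimal sketch, assuming Stanza >= 1.1.1, a local Java installation, and a writable `./corenlp` directory for the one-time CoreNLP download; it mirrors the notebook's recommended `with`-statement usage:

```python
import os
import stanza
from stanza.server import CoreNLPClient

# One-time setup: download CoreNLP and point CORENLP_HOME at it
corenlp_dir = './corenlp'
stanza.install_corenlp(dir=corenlp_dir)
os.environ['CORENLP_HOME'] = corenlp_dir

# The `with` block starts the Java server on entry and stops it on exit,
# so no explicit start()/stop() calls are needed
with CoreNLPClient(annotators=['tokenize', 'ssplit', 'pos', 'lemma', 'ner'],
                   memory='4G', endpoint='http://localhost:9001',
                   be_quiet=True) as client:
    document = client.annotate("Albert Einstein was a German-born theoretical physicist.")
    for sent in document.sentence:
        for token in sent.token:
            print(token.word, token.lemma, token.pos, token.ner)
```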
stanza/demo/en_test.conllu.txt ADDED
@@ -0,0 +1,79 @@
1
+ # newdoc id = weblog-blogspot.com_zentelligence_20040423000200_ENG_20040423_000200
2
+ # sent_id = weblog-blogspot.com_zentelligence_20040423000200_ENG_20040423_000200-0001
3
+ # newpar id = weblog-blogspot.com_zentelligence_20040423000200_ENG_20040423_000200-p0001
4
+ # text = What if Google Morphed Into GoogleOS?
5
+ 1 What what PRON WP PronType=Int 0 root 0:root _
6
+ 2 if if SCONJ IN _ 4 mark 4:mark _
7
+ 3 Google Google PROPN NNP Number=Sing 4 nsubj 4:nsubj _
8
+ 4 Morphed morph VERB VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 1 advcl 1:advcl:if _
9
+ 5 Into into ADP IN _ 6 case 6:case _
10
+ 6 GoogleOS GoogleOS PROPN NNP Number=Sing 4 obl 4:obl:into SpaceAfter=No
11
+ 7 ? ? PUNCT . _ 4 punct 4:punct _
12
+
13
+ # sent_id = weblog-blogspot.com_zentelligence_20040423000200_ENG_20040423_000200-0002
14
+ # text = What if Google expanded on its search-engine (and now e-mail) wares into a full-fledged operating system?
15
+ 1 What what PRON WP PronType=Int 0 root 0:root _
16
+ 2 if if SCONJ IN _ 4 mark 4:mark _
17
+ 3 Google Google PROPN NNP Number=Sing 4 nsubj 4:nsubj _
18
+ 4 expanded expand VERB VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 1 advcl 1:advcl:if _
19
+ 5 on on ADP IN _ 15 case 15:case _
20
+ 6 its its PRON PRP$ Gender=Neut|Number=Sing|Person=3|Poss=Yes|PronType=Prs 15 nmod:poss 15:nmod:poss _
21
+ 7 search search NOUN NN Number=Sing 9 compound 9:compound SpaceAfter=No
22
+ 8 - - PUNCT HYPH _ 9 punct 9:punct SpaceAfter=No
23
+ 9 engine engine NOUN NN Number=Sing 15 compound 15:compound _
24
+ 10 ( ( PUNCT -LRB- _ 9 punct 9:punct SpaceAfter=No
25
+ 11 and and CCONJ CC _ 13 cc 13:cc _
26
+ 12 now now ADV RB _ 13 advmod 13:advmod _
27
+ 13 e-mail e-mail NOUN NN Number=Sing 9 conj 9:conj:and|15:compound SpaceAfter=No
28
+ 14 ) ) PUNCT -RRB- _ 15 punct 15:punct _
29
+ 15 wares wares NOUN NNS Number=Plur 4 obl 4:obl:on _
30
+ 16 into into ADP IN _ 22 case 22:case _
31
+ 17 a a DET DT Definite=Ind|PronType=Art 22 det 22:det _
32
+ 18 full full ADV RB _ 20 advmod 20:advmod SpaceAfter=No
33
+ 19 - - PUNCT HYPH _ 20 punct 20:punct SpaceAfter=No
34
+ 20 fledged fledged ADJ JJ Degree=Pos 22 amod 22:amod _
35
+ 21 operating operating NOUN NN Number=Sing 22 compound 22:compound _
36
+ 22 system system NOUN NN Number=Sing 4 obl 4:obl:into SpaceAfter=No
37
+ 23 ? ? PUNCT . _ 4 punct 4:punct _
38
+
39
+ # sent_id = weblog-blogspot.com_zentelligence_20040423000200_ENG_20040423_000200-0003
40
+ # text = [via Microsoft Watch from Mary Jo Foley ]
41
+ 1 [ [ PUNCT -LRB- _ 4 punct 4:punct SpaceAfter=No
42
+ 2 via via ADP IN _ 4 case 4:case _
43
+ 3 Microsoft Microsoft PROPN NNP Number=Sing 4 compound 4:compound _
44
+ 4 Watch Watch PROPN NNP Number=Sing 0 root 0:root _
45
+ 5 from from ADP IN _ 6 case 6:case _
46
+ 6 Mary Mary PROPN NNP Number=Sing 4 nmod 4:nmod:from _
47
+ 7 Jo Jo PROPN NNP Number=Sing 6 flat 6:flat _
48
+ 8 Foley Foley PROPN NNP Number=Sing 6 flat 6:flat _
49
+ 9 ] ] PUNCT -RRB- _ 4 punct 4:punct _
50
+
51
+ # newdoc id = weblog-blogspot.com_marketview_20050511222700_ENG_20050511_222700
52
+ # sent_id = weblog-blogspot.com_marketview_20050511222700_ENG_20050511_222700-0001
53
+ # newpar id = weblog-blogspot.com_marketview_20050511222700_ENG_20050511_222700-p0001
54
+ # text = (And, by the way, is anybody else just a little nostalgic for the days when that was a good thing?)
55
+ 1 ( ( PUNCT -LRB- _ 14 punct 14:punct SpaceAfter=No
56
+ 2 And and CCONJ CC _ 14 cc 14:cc SpaceAfter=No
57
+ 3 , , PUNCT , _ 14 punct 14:punct _
58
+ 4 by by ADP IN _ 6 case 6:case _
59
+ 5 the the DET DT Definite=Def|PronType=Art 6 det 6:det _
60
+ 6 way way NOUN NN Number=Sing 14 obl 14:obl:by SpaceAfter=No
61
+ 7 , , PUNCT , _ 14 punct 14:punct _
62
+ 8 is be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 14 cop 14:cop _
63
+ 9 anybody anybody PRON NN Number=Sing 14 nsubj 14:nsubj _
64
+ 10 else else ADJ JJ Degree=Pos 9 amod 9:amod _
65
+ 11 just just ADV RB _ 13 advmod 13:advmod _
66
+ 12 a a DET DT Definite=Ind|PronType=Art 13 det 13:det _
67
+ 13 little little ADJ JJ Degree=Pos 14 obl:npmod 14:obl:npmod _
68
+ 14 nostalgic nostalgic NOUN NN Number=Sing 0 root 0:root _
69
+ 15 for for ADP IN _ 17 case 17:case _
70
+ 16 the the DET DT Definite=Def|PronType=Art 17 det 17:det _
71
+ 17 days day NOUN NNS Number=Plur 14 nmod 14:nmod:for|23:obl:npmod _
72
+ 18 when when ADV WRB PronType=Rel 23 advmod 17:ref _
73
+ 19 that that PRON DT Number=Sing|PronType=Dem 23 nsubj 23:nsubj _
74
+ 20 was be AUX VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 23 cop 23:cop _
75
+ 21 a a DET DT Definite=Ind|PronType=Art 23 det 23:det _
76
+ 22 good good ADJ JJ Degree=Pos 23 amod 23:amod _
77
+ 23 thing thing NOUN NN Number=Sing 17 acl:relcl 17:acl:relcl SpaceAfter=No
78
+ 24 ? ? PUNCT . _ 14 punct 14:punct SpaceAfter=No
79
+ 25 ) ) PUNCT -RRB- _ 14 punct 14:punct _
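The file above is standard CoNLL-U test data: one token per line, with tab-separated columns ID, FORM, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC, where HEAD = 0 marks the sentence root. As a hedged sketch of loading such a file back into a Stanza `Document` (assuming the `CoNLL.conll2doc` helper from `stanza.utils.conll`, available in recent Stanza releases):

```python
from stanza.utils.conll import CoNLL

# Parse the CoNLL-U file into a stanza Document (assumed helper; see lead-in)
doc = CoNLL.conll2doc("stanza/demo/en_test.conllu.txt")
for sentence in doc.sentences:
    for word in sentence.words:
        # word.head is the 1-based index of the governing word; 0 means root
        print(word.id, word.text, word.upos, word.head, word.deprel)
```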
stanza/demo/semgrex visualization.ipynb ADDED
@@ -0,0 +1,367 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "2787d5f5",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import stanza\n",
11
+ "from stanza.server.semgrex import Semgrex\n",
12
+ "from stanza.models.common.constant import is_right_to_left\n",
13
+ "import spacy\n",
14
+ "from spacy import displacy\n",
15
+ "from spacy.tokens import Doc\n",
16
+ "from IPython.display import display, HTML\n",
17
+ "\n",
18
+ "\n",
19
+ "\"\"\"\n",
20
+ "IMPORTANT: For the code in this module to run, you must have corenlp and Java installed on your machine. Additionally,\n",
21
+ "set an environment variable CLASSPATH equal to the path of your corenlp directory.\n",
22
+ "\n",
23
+ "Example: CLASSPATH=C:\\\\Users\\\\Alex\\\\PycharmProjects\\\\pythonProject\\\\stanford-corenlp-4.5.0\\\\stanford-corenlp-4.5.0\\\\*\n",
24
+ "\"\"\"\n",
25
+ "\n",
26
+ "%env CLASSPATH=C:\\\\stanford-corenlp-4.5.2\\\\stanford-corenlp-4.5.2\\\\*\n",
27
+ "def get_sentences_html(doc, language):\n",
28
+ " \"\"\"\n",
29
+ " Returns a list of the HTML strings of the dependency visualizations of a given stanza doc object.\n",
30
+ "\n",
31
+ " The 'language' arg is the two-letter language code for the document to be processed.\n",
32
+ "\n",
33
+ " First converts the stanza doc object to a spacy doc object and uses displacy to generate an HTML\n",
34
+ " string for each sentence of the doc object.\n",
35
+ " \"\"\"\n",
36
+ " html_strings = []\n",
37
+ "\n",
38
+ " # blank model - we don't use any of the model features, just the visualization\n",
39
+ " nlp = spacy.blank(\"en\")\n",
40
+ " sentences_to_visualize = []\n",
41
+ " for sentence in doc.sentences:\n",
42
+ " words, lemmas, heads, deps, tags = [], [], [], [], []\n",
43
+ " if is_right_to_left(language): # order of words displayed is reversed, dependency arcs remain intact\n",
44
+ " sent_len = len(sentence.words)\n",
45
+ " for word in reversed(sentence.words):\n",
46
+ " words.append(word.text)\n",
47
+ " lemmas.append(word.lemma)\n",
48
+ " deps.append(word.deprel)\n",
49
+ " tags.append(word.upos)\n",
50
+ " if word.head == 0: # spaCy head indexes are formatted differently than that of Stanza\n",
51
+ " heads.append(sent_len - word.id)\n",
52
+ " else:\n",
53
+ " heads.append(sent_len - word.head)\n",
54
+ " else: # left to right rendering\n",
55
+ " for word in sentence.words:\n",
56
+ " words.append(word.text)\n",
57
+ " lemmas.append(word.lemma)\n",
58
+ " deps.append(word.deprel)\n",
59
+ " tags.append(word.upos)\n",
60
+ " if word.head == 0:\n",
61
+ " heads.append(word.id - 1)\n",
62
+ " else:\n",
63
+ " heads.append(word.head - 1)\n",
64
+ " document_result = Doc(nlp.vocab, words=words, lemmas=lemmas, heads=heads, deps=deps, pos=tags)\n",
65
+ " sentences_to_visualize.append(document_result)\n",
66
+ "\n",
67
+ " for line in sentences_to_visualize: # render all sentences through displaCy\n",
68
+ " html_strings.append(displacy.render(line, style=\"dep\",\n",
69
+ " options={\"compact\": True, \"word_spacing\": 30, \"distance\": 100,\n",
70
+ " \"arrow_spacing\": 20}, jupyter=False))\n",
71
+ " return html_strings\n",
72
+ "\n",
73
+ "\n",
74
+ "def find_nth(haystack, needle, n):\n",
75
+ " \"\"\"\n",
76
+ " Returns the starting index of the nth occurrence of the substring 'needle' in the string 'haystack'.\n",
77
+ " \"\"\"\n",
78
+ " start = haystack.find(needle)\n",
79
+ " while start >= 0 and n > 1:\n",
80
+ " start = haystack.find(needle, start + len(needle))\n",
81
+ " n -= 1\n",
82
+ " return start\n",
83
+ "\n",
84
+ "\n",
85
+ "def round_base(num, base=10):\n",
86
+ " \"\"\"\n",
87
+ " Rounding a number to its nearest multiple of the base. round_base(49.2, base=50) = 50.\n",
88
+ " \"\"\"\n",
89
+ " return base * round(num/base)\n",
90
+ "\n",
91
+ "\n",
92
+ "def process_sentence_html(orig_html, semgrex_sentence):\n",
93
+ " \"\"\"\n",
94
+ " Takes a semgrex sentence object and modifies the HTML of the original sentence's deprel visualization,\n",
95
+ " highlighting words involved in the search queries and adding the label of the word inside of the semgrex match.\n",
96
+ "\n",
97
+ " Returns the modified html string of the sentence's deprel visualization.\n",
98
+ " \"\"\"\n",
99
+ " tracker = {} # keep track of which words have multiple labels\n",
100
+ " DEFAULT_TSPAN_COUNT = 2 # the original displacy html assigns two <tspan> objects per <text> object\n",
101
+ " CLOSING_TSPAN_LEN = 8 # </tspan> is 8 chars long\n",
102
+ " colors = ['#4477AA', '#66CCEE', '#228833', '#CCBB44', '#EE6677', '#AA3377', '#BBBBBB']\n",
103
+ " css_bolded_class = \"<style> .bolded{font-weight: bold;} </style>\\n\"\n",
104
+ " found_index = orig_html.find(\"\\n\") # returns index where the opening <svg> ends\n",
105
+ " # insert the new style class into html string\n",
106
+ " orig_html = orig_html[: found_index + 1] + css_bolded_class + orig_html[found_index + 1:]\n",
107
+ "\n",
108
+ " # Add color to words in the match, bold words in the match\n",
109
+ " for query in semgrex_sentence.result:\n",
110
+ " for i, match in enumerate(query.match):\n",
111
+ " color = colors[i]\n",
112
+ " paired_dy = 2\n",
113
+ " for node in match.node:\n",
114
+ " name, match_index = node.name, node.matchIndex\n",
115
+ " # edit existing <tspan> to change color and bold the text\n",
116
+ " start = find_nth(orig_html, \"<text\", match_index) # finds start of svg <text> of interest\n",
117
+ " if match_index not in tracker: # if we've already bolded and colored, keep the first color\n",
118
+ " tspan_start = orig_html.find(\"<tspan\",\n",
119
+ " start) # finds start of the first svg <tspan> inside of the <text>\n",
120
+ " tspan_end = orig_html.find(\"</tspan>\", start) # finds start of the end of the above <tspan>\n",
121
+ " tspan_substr = orig_html[tspan_start: tspan_end + CLOSING_TSPAN_LEN + 1] + \"\\n\"\n",
122
+ " # color words in the hit and bold words in the hit\n",
123
+ " edited_tspan = tspan_substr.replace('class=\"displacy-word\"', 'class=\"bolded\"').replace(\n",
124
+ " 'fill=\"currentColor\"', f'fill=\"{color}\"')\n",
125
+ " # insert edited <tspan> object into html string\n",
126
+ " orig_html = orig_html[: tspan_start] + edited_tspan + orig_html[tspan_end + CLOSING_TSPAN_LEN + 2:]\n",
127
+ " tracker[match_index] = DEFAULT_TSPAN_COUNT\n",
128
+ "\n",
129
+ " # next, we have to insert the new <tspan> object for the label\n",
130
+ " # Copy old <tspan> to copy formatting when creating new <tspan> later\n",
131
+ " prev_tspan_start = find_nth(orig_html[start:], \"<tspan\",\n",
132
+ " tracker[match_index] - 1) + start # find the previous <tspan> start index\n",
133
+ " prev_tspan_end = find_nth(orig_html[start:], \"</tspan>\",\n",
134
+ " tracker[match_index] - 1) + start # find the prev </tspan> start index\n",
135
+ " prev_tspan = orig_html[prev_tspan_start: prev_tspan_end + CLOSING_TSPAN_LEN + 1]\n",
136
+ "\n",
137
+ " # Find spot to insert new tspan\n",
138
+ " closing_tspan_start = find_nth(orig_html[start:], \"</tspan>\", tracker[match_index]) + start\n",
139
+ " up_to_new_tspan = orig_html[: closing_tspan_start + CLOSING_TSPAN_LEN + 1]\n",
140
+ " rest_need_add_newline = orig_html[closing_tspan_start + CLOSING_TSPAN_LEN + 1:]\n",
141
+ "\n",
142
+ " # Calculate proper x value in svg\n",
143
+ " x_value_start = prev_tspan.find('x=\"')\n",
144
+ " x_value_end = prev_tspan[x_value_start + 3:].find('\"') + 3 # 3 is the length of the 'x=\"' substring\n",
145
+ " x_value = prev_tspan[x_value_start + 3: x_value_end + x_value_start]\n",
146
+ "\n",
147
+ " # Calculate proper y value in svg\n",
148
+ " DEFAULT_DY_VAL, dy = 2, 2\n",
149
+ " if paired_dy != DEFAULT_DY_VAL and node == match.node[\n",
150
+ " 1]: # we're on the second node and need to adjust height to match the paired node\n",
151
+ " dy = paired_dy\n",
152
+ " if node == match.node[0]:\n",
153
+ " paired_node_level = 2\n",
154
+ " if match.node[1].matchIndex in tracker: # check if we need to adjust heights of labels\n",
155
+ " paired_node_level = tracker[match.node[1].matchIndex]\n",
156
+ " dif = tracker[match_index] - paired_node_level\n",
157
+ " if dif > 0: # current node has more labels\n",
158
+ " paired_dy = DEFAULT_DY_VAL * dif + 1\n",
159
+ " dy = DEFAULT_DY_VAL\n",
160
+ " else: # paired node has more labels, adjust this label down\n",
161
+ " dy = DEFAULT_DY_VAL * (abs(dif) + 1)\n",
162
+ " paired_dy = DEFAULT_DY_VAL\n",
163
+ "\n",
164
+ " # Insert new <tspan> object\n",
165
+ " new_tspan = f' <tspan class=\"displacy-word\" dy=\"{dy}em\" fill=\"{color}\" x={x_value}>{name[: 3].title()}.</tspan>\\n' # abbreviate label names to 3 chars\n",
166
+ " orig_html = up_to_new_tspan + new_tspan + rest_need_add_newline\n",
167
+ " tracker[match_index] += 1\n",
168
+ " return orig_html\n",
169
+ "\n",
170
+ "\n",
171
+ "def render_html_strings(edited_html_strings):\n",
172
+ " \"\"\"\n",
173
+ " Renders the HTML to make the edits visible\n",
174
+ " \"\"\"\n",
175
+ " for html_string in edited_html_strings:\n",
176
+ " display(HTML(html_string))\n",
177
+ "\n",
178
+ "\n",
179
+ "def visualize_search_doc(doc, semgrex_queries, lang_code, start_match=0, end_match=10):\n",
180
+ " \"\"\"\n",
181
+ " Visualizes the semgrex results of running semgrex search on a stanza doc object with the given list of\n",
182
+ " semgrex queries. Returns a list of the edited HTML strings from the doc. Each element in the list represents\n",
183
+ " the HTML to render one of the sentences in the document.\n",
184
+ "\n",
185
+ " 'lang_code' is the two-letter language abbreviation for the language that the stanza doc object is written in.\n",
186
+ "\n",
187
+ "\n",
188
+ " 'start_match' and 'end_match' determine which matches to visualize. Works similar to splices, so that\n",
189
+ " start_match=0 and end_match=10 will display the first 10 semgrex matches.\n",
190
+ " \"\"\"\n",
191
+ " matches_count = 0 # Limits number of visualizations\n",
192
+ " with Semgrex(classpath=\"$CLASSPATH\") as sem:\n",
193
+ " edited_html_strings = []\n",
194
+ " semgrex_results = sem.process(doc, *semgrex_queries)\n",
195
+ " # one html string for each sentence\n",
196
+ " unedited_html_strings = get_sentences_html(doc, lang_code)\n",
197
+ " for i in range(len(unedited_html_strings)):\n",
198
+ "\n",
199
+ " if matches_count >= end_match: # we've collected enough matches, stop early\n",
200
+ " break\n",
201
+ "\n",
202
+ " # check if sentence has matches, if not then do not visualize\n",
203
+ " has_none = True\n",
204
+ " for query in semgrex_results.result[i].result:\n",
205
+ " for match in query.match:\n",
206
+ " if match:\n",
207
+ " has_none = False\n",
208
+ "\n",
209
+ " # Process HTML if queries have matches\n",
210
+ " if not has_none:\n",
211
+ " if start_match <= matches_count < end_match:\n",
212
+ " edited_string = process_sentence_html(unedited_html_strings[i], semgrex_results.result[i])\n",
213
+ " edited_string = adjust_dep_arrows(edited_string)\n",
214
+ " edited_html_strings.append(edited_string)\n",
215
+ " matches_count += 1\n",
216
+ "\n",
217
+ " render_html_strings(edited_html_strings)\n",
218
+ " return edited_html_strings\n",
219
+ "\n",
220
+ "\n",
221
+ "def visualize_search_str(text, semgrex_queries, lang_code):\n",
222
+ " \"\"\"\n",
223
+ " Visualizes the deprel of the semgrex results from running semgrex search on a string with the given list of\n",
224
+ " semgrex queries. Returns a list of the edited HTML strings. Each element in the list represents\n",
225
+ " the HTML to render one of the sentences in the document.\n",
226
+ "\n",
227
+ " Internally, this function converts the string into a stanza doc object before processing the doc object.\n",
228
+ "\n",
229
+ " 'lang_code' is the two-letter language abbreviation for the language that the stanza doc object is written in.\n",
230
+ " \"\"\"\n",
231
+ " nlp = stanza.Pipeline(lang_code, processors=\"tokenize, pos, lemma, depparse\")\n",
232
+ " doc = nlp(text)\n",
233
+ " return visualize_search_doc(doc, semgrex_queries, lang_code)\n",
234
+ "\n",
235
+ "\n",
236
+ "def adjust_dep_arrows(raw_html):\n",
237
+ " \"\"\"\n",
238
+ " The default spaCy dependency visualization has misaligned arrows.\n",
239
+ " We fix arrows by aligning arrow ends and bodies to the word that they are directed to. If a word has an\n",
240
+ " arrowhead that is pointing not directly on the word's center, align the arrowhead to match the center of the word.\n",
241
+ "\n",
242
+ " returns the edited html with fixed arrow placement\n",
243
+ " \"\"\"\n",
244
+ " HTML_ARROW_BEGINNING = '<g class=\"displacy-arrow\">'\n",
245
+ " HTML_ARROW_ENDING = \"</g>\"\n",
246
+ " HTML_ARROW_ENDING_LEN = 6 # there are 2 newline chars after the arrow ending\n",
247
+ " arrows_start_idx = find_nth(haystack=raw_html, needle='<g class=\"displacy-arrow\">', n=1)\n",
248
+ " words_html, arrows_html = raw_html[: arrows_start_idx], raw_html[arrows_start_idx:] # separate html for words and arrows\n",
249
+ " final_html = words_html # continually concatenate to this after processing each arrow\n",
250
+ " arrow_number = 1 # which arrow we're editing (1-indexed)\n",
251
+ " start_idx, end_of_class_idx = find_nth(haystack=arrows_html, needle=HTML_ARROW_BEGINNING, n=arrow_number), find_nth(arrows_html, HTML_ARROW_ENDING, arrow_number)\n",
252
+ " while start_idx != -1: # edit every arrow\n",
253
+ " arrow_section = arrows_html[start_idx: end_of_class_idx + HTML_ARROW_ENDING_LEN] # slice a single svg arrow object\n",
254
+ " if arrow_section[-1] == \"<\": # this is the last arrow in the HTML, don't cut the splice early\n",
255
+ " arrow_section = arrows_html[start_idx:]\n",
256
+ " edited_arrow_section = edit_dep_arrow(arrow_section)\n",
257
+ "\n",
258
+ " final_html = final_html + edited_arrow_section # continually update html with new arrow html until done\n",
259
+ "\n",
260
+ " # Prepare for next iteration\n",
261
+ " arrow_number += 1\n",
262
+ " start_idx = find_nth(arrows_html, '<g class=\"displacy-arrow\">', n=arrow_number)\n",
263
+ " end_of_class_idx = find_nth(arrows_html, \"</g>\", arrow_number)\n",
264
+ " return final_html\n",
265
+ "\n",
266
+ "\n",
267
+ "def edit_dep_arrow(arrow_html):\n",
268
+ " \"\"\"\n",
269
+ " The formatting of a displacy arrow in svg is the following:\n",
270
+ " <g class=\"displacy-arrow\">\n",
271
+ " <path class=\"displacy-arc\" id=\"arrow-c628889ffbf343e3848193a08606f10a-0-0\" stroke-width=\"2px\" d=\"M70,352.0 C70,177.0 390.0,177.0 390.0,352.0\" fill=\"none\" stroke=\"currentColor\"/>\n",
272
+ " <text dy=\"1.25em\" style=\"font-size: 0.8em; letter-spacing: 1px\">\n",
273
+ " <textPath xlink:href=\"#arrow-c628889ffbf343e3848193a08606f10a-0-0\" class=\"displacy-label\" startOffset=\"50%\" side=\"left\" fill=\"currentColor\" text-anchor=\"middle\">csubj</textPath>\n",
274
+ " </text>\n",
275
+ " <path class=\"displacy-arrowhead\" d=\"M70,354.0 L62,342.0 78,342.0\" fill=\"currentColor\"/>\n",
276
+ " </g>\n",
277
+ "\n",
278
+ " We edit the 'd = ...' parts of the <path class ...> section to fix the arrow direction and length\n",
279
+ "\n",
280
+ " returns the arrow_html with distances fixed\n",
281
+ " \"\"\"\n",
282
+ " WORD_SPACING = 50 # words start at x=50 and are separated by 100s so their x values are multiples of 50\n",
283
+ " M_OFFSET = 4 # length of 'd=\"M' that we search for to extract the number from d=\"M70, for instance\n",
284
+ " ARROW_PIXEL_SIZE = 4\n",
285
+ " first_d_idx, second_d_idx = find_nth(arrow_html, 'd=\"M', 1), find_nth(arrow_html, 'd=\"M', 2) # find where d=\"M starts\n",
286
+ " first_d_cutoff, second_d_cutoff = arrow_html.find(\",\", first_d_idx), arrow_html.find(\",\", second_d_idx) # isolate the number after 'M' e.g. 'M70'\n",
287
+ " # gives svg x values of arrow body starting position and arrowhead position\n",
288
+ " arrow_position, arrowhead_position = float(arrow_html[first_d_idx + M_OFFSET: first_d_cutoff]), float(arrow_html[second_d_idx + M_OFFSET: second_d_cutoff])\n",
289
+ " # gives starting index of where 'fill=\"none\"' or 'fill=\"currentColor\"' begin, reference points to end the d= section\n",
290
+ " first_fill_start_idx, second_fill_start_idx = find_nth(arrow_html, \"fill\", n=1), find_nth(arrow_html, \"fill\", n=3)\n",
291
+ "\n",
292
+ " # isolate the d= ... section to edit\n",
293
+ " first_d, second_d = arrow_html[first_d_idx: first_fill_start_idx], arrow_html[second_d_idx: second_fill_start_idx]\n",
294
+ " first_d_split, second_d_split = first_d.split(\",\"), second_d.split(\",\")\n",
295
+ "\n",
296
+ " if arrow_position == arrowhead_position: # This arrow is incoming onto the word, center the arrow/head to word center\n",
297
+ " corrected_arrow_pos = corrected_arrowhead_pos = round_base(arrow_position, base=WORD_SPACING)\n",
298
+ "\n",
299
+ " # edit first_d -- arrow body\n",
300
+ " second_term = first_d_split[1].split(\" \")[0] + \" \" + str(corrected_arrow_pos)\n",
301
+ " first_d = 'd=\"M' + str(corrected_arrow_pos) + \",\" + second_term + \",\" + \",\".join(first_d_split[2:])\n",
302
+ "\n",
303
+ " # edit second_d -- arrowhead\n",
304
+ " second_term = second_d_split[1].split(\" \")[0] + \" L\" + str(corrected_arrowhead_pos - ARROW_PIXEL_SIZE)\n",
305
+ " third_term = second_d_split[2].split(\" \")[0] + \" \" + str(corrected_arrowhead_pos + ARROW_PIXEL_SIZE)\n",
306
+ " second_d = 'd=\"M' + str(corrected_arrowhead_pos) + \",\" + second_term + \",\" + third_term + \",\" + \",\".join(second_d_split[3:])\n",
307
+ " else: # This arrow is outgoing to another word, center the arrow/head to that word's center\n",
308
+ " corrected_arrowhead_pos = round_base(arrowhead_position, base=WORD_SPACING)\n",
309
+ "\n",
310
+ " # edit first_d -- arrow body\n",
311
+ " third_term = first_d_split[2].split(\" \")[0] + \" \" + str(corrected_arrowhead_pos)\n",
312
+ " fourth_term = first_d_split[3].split(\" \")[0] + \" \" + str(corrected_arrowhead_pos)\n",
313
+ " terms = [first_d_split[0], first_d_split[1], third_term, fourth_term] + first_d_split[4:]\n",
314
+ " first_d = \",\".join(terms)\n",
315
+ "\n",
316
+ " # edit second_d -- arrow head\n",
317
+ " first_term = f'd=\"M{corrected_arrowhead_pos}'\n",
318
+ " second_term = second_d_split[1].split(\" \")[0] + \" L\" + str(corrected_arrowhead_pos - ARROW_PIXEL_SIZE)\n",
319
+ " third_term = second_d_split[2].split(\" \")[0] + \" \" + str(corrected_arrowhead_pos + ARROW_PIXEL_SIZE)\n",
320
+ " terms = [first_term, second_term, third_term] + second_d_split[3:]\n",
321
+ " second_d = \",\".join(terms)\n",
322
+ " # rebuild and return html\n",
323
+ " return arrow_html[:first_d_idx] + first_d + \" \" + arrow_html[first_fill_start_idx:second_d_idx] + second_d + \" \" + arrow_html[second_fill_start_idx:]\n",
324
+ "\n",
325
+ "\n",
326
+ "def main():\n",
327
+ " nlp = stanza.Pipeline(\"en\", processors=\"tokenize,pos,lemma,depparse\")\n",
328
+ "\n",
329
+ " # doc = nlp(\"This a dummy sentence. Banning opal removed all artifact decks from the meta. I miss playing lantern. This is a dummy sentence.\")\n",
330
+ " doc = nlp(\"Banning opal removed artifact decks from the meta. Banning tennis resulted in players banning people.\")\n",
331
+ " # A single result .result[i].result[j] is a list of matches for sentence i on semgrex query j.\n",
332
+ " queries = [\"{pos:NN}=object <obl {}=action\",\n",
333
+ " \"{cpos:NOUN}=thing <obj {cpos:VERB}=action\"]\n",
334
+ " res = visualize_search_doc(doc, queries, \"en\")\n",
335
+ " print(res[0]) # see the first sentence's deprel visualization HTML\n",
336
+ " print(\"---------------------------------------\")\n",
337
+ " print(res[1]) # second sentence's deprel visualization HTML\n",
338
+ " return\n",
339
+ "\n",
340
+ "\n",
341
+ "if __name__ == '__main__':\n",
342
+ " main()\n"
343
+ ]
344
+ }
345
+ ],
346
+ "metadata": {
347
+ "kernelspec": {
348
+ "display_name": "Python 3 (ipykernel)",
349
+ "language": "python",
350
+ "name": "python3"
351
+ },
352
+ "language_info": {
353
+ "codemirror_mode": {
354
+ "name": "ipython",
355
+ "version": 3
356
+ },
357
+ "file_extension": ".py",
358
+ "mimetype": "text/x-python",
359
+ "name": "python",
360
+ "nbconvert_exporter": "python",
361
+ "pygments_lexer": "ipython3",
362
+ "version": "3.9.6"
363
+ }
364
+ },
365
+ "nbformat": 4,
366
+ "nbformat_minor": 5
367
+ }
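The heart of `get_sentences_html` above is the Stanza-to-spaCy conversion: displaCy only needs token texts, 0-indexed head positions (with the root pointing at itself), dependency labels, and coarse POS tags, so a blank pipeline suffices. A minimal sketch of that conversion on hand-built toy data, with no Stanza model or CoreNLP required:

```python
import spacy
from spacy import displacy
from spacy.tokens import Doc

# A blank pipeline only supplies the vocab; no model weights are loaded
nlp = spacy.blank("en")

# Hand-built analysis of "Stanza parses text": heads are 0-indexed token
# positions, and the root ("parses", index 1) points at itself
words = ["Stanza", "parses", "text"]
heads = [1, 1, 1]
deps = ["nsubj", "root", "obj"]
pos = ["PROPN", "VERB", "NOUN"]

doc = Doc(nlp.vocab, words=words, heads=heads, deps=deps, pos=pos)
html = displacy.render(doc, style="dep", jupyter=False)  # returns an SVG/HTML string
print(html[:80])
```

This string is exactly what `process_sentence_html` and `adjust_dep_arrows` then post-process.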
stanza/images/stanza-logo.png ADDED
stanza/stanza/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ from stanza.pipeline.core import DownloadMethod, Pipeline
2
+ from stanza.pipeline.multilingual import MultilingualPipeline
3
+ from stanza.models.common.doc import Document
4
+ from stanza.resources.common import download
5
+ from stanza.resources.installation import install_corenlp, download_corenlp_models
6
+ from stanza._version import __version__, __resources_version__
7
+
8
+ import logging
9
+ logger = logging.getLogger('stanza')
10
+
11
+ # if the client application hasn't set the log level, we set it
12
+ # ourselves to INFO
13
+ if logger.level == 0:
14
+ logger.setLevel(logging.INFO)
15
+
16
+ log_handler = logging.StreamHandler()
17
+ log_formatter = logging.Formatter(fmt="%(asctime)s %(levelname)s: %(message)s",
18
+ datefmt='%Y-%m-%d %H:%M:%S')
19
+ log_handler.setFormatter(log_formatter)
20
+
21
+ # also, if the client hasn't added any handlers for this logger
22
+ # (or a default handler), we add a handler of our own
23
+ #
24
+ # client can later do
25
+ # logger.removeHandler(stanza.log_handler)
26
+ if not logger.hasHandlers():
27
+ logger.addHandler(log_handler)
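Since the handler and level are only set when the importing application hasn't configured them, a client can take control of stanza's logging with ordinary stdlib calls. A small sketch of both routes the comments above describe:

```python
import logging

# Option 1: configure logging before importing stanza; hasHandlers() then
# sees your root handler and stanza skips attaching its own
logging.basicConfig(level=logging.WARNING)

import stanza

# Option 2: after import, detach stanza's default handler and/or
# set the package logger's level explicitly
logger = logging.getLogger('stanza')
logger.removeHandler(stanza.log_handler)
logger.setLevel(logging.DEBUG)
```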
stanza/stanza/models/charlm.py ADDED
@@ -0,0 +1,357 @@
1
+ """
2
+ Entry point for training and evaluating a character-level neural language model.
3
+ """
4
+
5
+ import argparse
6
+ from copy import copy
7
+ import logging
8
+ import lzma
9
+ import math
10
+ import os
11
+ import random
12
+ import time
13
+ from types import GeneratorType
14
+ import numpy as np
15
+ import torch
16
+
17
+ from stanza.models.common.char_model import build_charlm_vocab, CharacterLanguageModel, CharacterLanguageModelTrainer
18
+ from stanza.models.common.vocab import CharVocab
19
+ from stanza.models.common import utils
20
+ from stanza.models import _training_logging
21
+
22
+ logger = logging.getLogger('stanza')
23
+
24
+ def repackage_hidden(h):
25
+ """Wraps hidden states in new Tensors,
26
+ to detach them from their history."""
27
+ if isinstance(h, torch.Tensor):
28
+ return h.detach()
29
+ else:
30
+ return tuple(repackage_hidden(v) for v in h)
31
+
32
+ def batchify(data, bsz, device):
33
+ # Work out how cleanly we can divide the dataset into bsz parts.
34
+ nbatch = data.size(0) // bsz
35
+ # Trim off any extra elements that wouldn't cleanly fit (remainders).
36
+ data = data.narrow(0, 0, nbatch * bsz)
37
+ # Evenly divide the data across the bsz batches.
38
+ data = data.view(bsz, -1) # batch_first is True
39
+ data = data.to(device)
40
+ return data
41
+
42
+ def get_batch(source, i, seq_len):
43
+ seq_len = min(seq_len, source.size(1) - 1 - i)
44
+ data = source[:, i:i+seq_len]
45
+ target = source[:, i+1:i+1+seq_len].reshape(-1)
46
+ return data, target
47
+
48
+ def load_file(filename, vocab, direction):
49
+ with utils.open_read_text(filename) as fin:
50
+ data = fin.read()
51
+
52
+ idx = vocab['char'].map(data)
53
+ if direction == 'backward': idx = idx[::-1]
54
+ return torch.tensor(idx)
55
+
56
+ def load_data(path, vocab, direction):
57
+ if os.path.isdir(path):
58
+ filenames = sorted(os.listdir(path))
59
+ for filename in filenames:
60
+ logger.info('Loading data from {}'.format(filename))
61
+ data = load_file(os.path.join(path, filename), vocab, direction)
62
+ yield data
63
+ else:
64
+ data = load_file(path, vocab, direction)
65
+ yield data
66
+
67
+ def build_argparse():
68
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
69
+ parser.add_argument('--train_file', type=str, help="Input plaintext file")
70
+ parser.add_argument('--train_dir', type=str, help="If non-empty, load from directory with multiple training files")
71
+ parser.add_argument('--eval_file', type=str, help="Input plaintext file for the dev/test set")
72
+ parser.add_argument('--shorthand', type=str, help="UD treebank shorthand")
73
+
74
+ parser.add_argument('--mode', default='train', choices=['train', 'predict'])
75
+ parser.add_argument('--direction', default='forward', choices=['forward', 'backward'], help="Forward or backward language model")
76
+ parser.add_argument('--forward', action='store_const', dest='direction', const='forward', help="Train a forward language model")
77
+ parser.add_argument('--backward', action='store_const', dest='direction', const='backward', help="Train a backward language model")
78
+
79
+ parser.add_argument('--char_emb_dim', type=int, default=100, help="Dimension of unit embeddings")
80
+ parser.add_argument('--char_hidden_dim', type=int, default=1024, help="Dimension of hidden units")
81
+ parser.add_argument('--char_num_layers', type=int, default=1, help="Layers of RNN in the language model")
82
+ parser.add_argument('--char_dropout', type=float, default=0.05, help="Dropout probability")
83
+ parser.add_argument('--char_unit_dropout', type=float, default=1e-5, help="Randomly set an input char to UNK during training")
84
+ parser.add_argument('--char_rec_dropout', type=float, default=0.0, help="Recurrent dropout probability")
85
+
86
+ parser.add_argument('--batch_size', type=int, default=100, help="Batch size to use")
87
+ parser.add_argument('--bptt_size', type=int, default=250, help="Sequence length to consider at a time")
88
+ parser.add_argument('--epochs', type=int, default=50, help="Total epochs to train the model for")
89
+ parser.add_argument('--max_grad_norm', type=float, default=0.25, help="Maximum gradient norm to clip to")
90
+ parser.add_argument('--lr0', type=float, default=5, help="Initial learning rate")
91
+ parser.add_argument('--anneal', type=float, default=0.25, help="Anneal the learning rate by this amount when dev performance deteriorates")
92
+ parser.add_argument('--patience', type=int, default=1, help="Patience for annealing the learning rate")
93
+ parser.add_argument('--weight_decay', type=float, default=0.0, help="Weight decay")
94
+ parser.add_argument('--momentum', type=float, default=0.0, help='Momentum for SGD.')
95
+ parser.add_argument('--cutoff', type=int, default=1000, help="Frequency cutoff for char vocab. By default we assume a very large corpus.")
96
+
97
+ parser.add_argument('--report_steps', type=int, default=50, help="Update step interval to report loss")
98
+ parser.add_argument('--eval_steps', type=int, default=100000, help="Update step interval to run eval on dev; set to -1 to eval after each epoch")
99
+ parser.add_argument('--save_name', type=str, default=None, help="File name to save the model")
100
+ parser.add_argument('--vocab_save_name', type=str, default=None, help="File name to save the vocab")
101
+ parser.add_argument('--checkpoint_save_name', type=str, default=None, help="File name to save the most recent checkpoint")
102
+ parser.add_argument('--no_checkpoint', dest='checkpoint', action='store_false', help="Don't save checkpoints")
103
+ parser.add_argument('--save_dir', type=str, default='saved_models/charlm', help="Directory to save models in")
104
+ parser.add_argument('--summary', action='store_true', help='Use summary writer to record progress.')
105
+ utils.add_device_args(parser)
106
+ parser.add_argument('--seed', type=int, default=1234)
107
+
108
+ parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name')
109
+ parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. Will default to the dataset short name')
110
+ return parser
111
+
112
+ def build_model_filename(args):
113
+ if args['save_name']:
114
+ save_name = args['save_name']
115
+ else:
116
+ save_name = '{}_{}_charlm.pt'.format(args['shorthand'], args['direction'])
117
+ model_file = os.path.join(args['save_dir'], save_name)
118
+ return model_file
119
+
120
+ def parse_args(args=None):
121
+ parser = build_argparse()
122
+
123
+ args = parser.parse_args(args=args)
124
+
125
+ if args.wandb_name:
126
+ args.wandb = True
127
+
128
+ args = vars(args)
129
+ return args
130
+
131
+ def main(args=None):
132
+ args = parse_args(args=args)
133
+
134
+ utils.set_random_seed(args['seed'])
135
+
136
+ logger.info("Running {} character-level language model in {} mode".format(args['direction'], args['mode']))
137
+
138
+ utils.ensure_dir(args['save_dir'])
139
+
140
+ if args['mode'] == 'train':
141
+ train(args)
142
+ else:
143
+ evaluate(args)
144
+
145
+ def evaluate_epoch(args, vocab, data, model, criterion):
146
+ """
147
+ Run an evaluation over entire dataset.
148
+ """
149
+ model.eval()
150
+ device = next(model.parameters()).device
151
+ hidden = None
152
+ total_loss = 0
153
+ if isinstance(data, GeneratorType):
154
+ data = list(data)
155
+ assert len(data) == 1, 'Only support single dev/test file'
156
+ data = data[0]
157
+ batches = batchify(data, args['batch_size'], device)
158
+ with torch.no_grad():
159
+ for i in range(0, batches.size(1) - 1, args['bptt_size']):
160
+ data, target = get_batch(batches, i, args['bptt_size'])
161
+ lens = [data.size(1) for i in range(data.size(0))]
162
+
163
+ output, hidden, decoded = model.forward(data, lens, hidden)
164
+ loss = criterion(decoded.view(-1, len(vocab['char'])), target)
165
+
166
+ hidden = repackage_hidden(hidden)
167
+ total_loss += data.size(1) * loss.data.item()
168
+ return total_loss / batches.size(1)
169
+
170
+ def evaluate_and_save(args, vocab, data, trainer, best_loss, model_file, checkpoint_file, writer=None):
171
+ """
172
+ Run an evaluation over entire dataset, print progress and save the model if necessary.
173
+ """
174
+ start_time = time.time()
175
+ loss = evaluate_epoch(args, vocab, data, trainer.model, trainer.criterion)
176
+ ppl = math.exp(loss)
177
+ elapsed = int(time.time() - start_time)
178
+ # TODO: step the scheduler less often when the eval frequency is higher
179
+ previous_lr = get_current_lr(trainer, args)
180
+ trainer.scheduler.step(loss)
181
+ current_lr = get_current_lr(trainer, args)
182
+ if previous_lr != current_lr:
183
+ logger.info("Updating learning rate to %f", current_lr)
184
+ logger.info(
185
+ "| eval checkpoint @ global step {:10d} | time elapsed {:6d}s | loss {:5.2f} | ppl {:8.2f}".format(
186
+ trainer.global_step,
187
+ elapsed,
188
+ loss,
189
+ ppl,
190
+ )
191
+ )
192
+ if best_loss is None or loss < best_loss:
193
+ best_loss = loss
194
+ trainer.save(model_file, full=False)
195
+ logger.info('new best model saved at step {:10d}'.format(trainer.global_step))
196
+ if writer:
197
+ writer.add_scalar('dev_loss', loss, global_step=trainer.global_step)
198
+ writer.add_scalar('dev_ppl', ppl, global_step=trainer.global_step)
199
+ if checkpoint_file:
200
+ trainer.save(checkpoint_file, full=True)
201
+ logger.info('new checkpoint saved at step {:10d}'.format(trainer.global_step))
202
+
203
+ return loss, ppl, best_loss
204
+
205
+ def get_current_lr(trainer, args):
206
+ return trainer.scheduler.state_dict().get('_last_lr', [args['lr0']])[0]
207
+
208
+ def load_char_vocab(vocab_file):
209
+ return {'char': CharVocab.load_state_dict(torch.load(vocab_file, lambda storage, loc: storage, weights_only=True))}
210
+
211
+ def train(args):
212
+ utils.log_training_args(args, logger)
213
+ model_file = build_model_filename(args)
214
+
215
+ vocab_file = args['save_dir'] + '/' + args['vocab_save_name'] if args['vocab_save_name'] is not None \
216
+ else '{}/{}_vocab.pt'.format(args['save_dir'], args['shorthand'])
217
+
218
+ if args['checkpoint']:
219
+ checkpoint_file = utils.checkpoint_name(args['save_dir'], model_file, args['checkpoint_save_name'])
220
+ else:
221
+ checkpoint_file = None
222
+
223
+ if os.path.exists(vocab_file):
224
+ logger.info('Loading existing vocab file')
225
+ vocab = load_char_vocab(vocab_file)
226
+ else:
227
+ logger.info('Building and saving vocab')
228
+ vocab = {'char': build_charlm_vocab(args['train_file'] if args['train_dir'] is None else args['train_dir'], cutoff=args['cutoff'])}
229
+ torch.save(vocab['char'].state_dict(), vocab_file)
230
+ logger.info("Training model with vocab size: {}".format(len(vocab['char'])))
231
+
232
+ if checkpoint_file and os.path.exists(checkpoint_file):
233
+ logger.info('Loading existing checkpoint: %s' % checkpoint_file)
234
+ trainer = CharacterLanguageModelTrainer.load(args, checkpoint_file, finetune=True)
235
+ else:
236
+ trainer = CharacterLanguageModelTrainer.from_new_model(args, vocab)
237
+
238
+ writer = None
239
+ if args['summary']:
240
+ from torch.utils.tensorboard import SummaryWriter
241
+ summary_dir = '{}/{}_summary'.format(args['save_dir'], args['save_name']) if args['save_name'] is not None \
242
+ else '{}/{}_{}_charlm_summary'.format(args['save_dir'], args['shorthand'], args['direction'])
243
+ writer = SummaryWriter(log_dir=summary_dir)
244
+
245
+ # evaluate model within epoch if eval_interval is set
246
+ eval_within_epoch = False
247
+ if args['eval_steps'] > 0:
248
+ eval_within_epoch = True
249
+
250
+ if args['wandb']:
251
+ import wandb
252
+ wandb_name = args['wandb_name'] if args['wandb_name'] else '%s_%s_charlm' % (args['shorthand'], args['direction'])
253
+ wandb.init(name=wandb_name, config=args)
254
+ wandb.run.define_metric('best_loss', summary='min')
255
+ wandb.run.define_metric('ppl', summary='min')
256
+
257
+ device = next(trainer.model.parameters()).device
258
+
259
+ best_loss = None
260
+ start_epoch = trainer.epoch # will default to 1 for a new trainer
261
+ for trainer.epoch in range(start_epoch, args['epochs']+1):
262
+ # load train data from train_dir if not empty, otherwise load from file
263
+ if args['train_dir'] is not None:
264
+ train_path = args['train_dir']
265
+ else:
266
+ train_path = args['train_file']
267
+ train_data = load_data(train_path, vocab, args['direction'])
268
+ dev_data = load_file(args['eval_file'], vocab, args['direction']) # dev must be a single file
269
+
270
+ # run over entire training set
271
+ for data_chunk in train_data:
272
+ batches = batchify(data_chunk, args['batch_size'], device)
273
+ hidden = None
274
+ total_loss = 0.0
275
+ total_batches = math.ceil((batches.size(1) - 1) / args['bptt_size'])
276
+ iteration, i = 0, 0
277
+ # over the data chunk
278
+ while i < batches.size(1) - 1 - 1:
279
+ trainer.model.train()
280
+ trainer.global_step += 1
281
+ start_time = time.time()
282
+ bptt = args['bptt_size'] if np.random.random() < 0.95 else args['bptt_size']/ 2.
283
+ # prevent excessively small or negative sequence lengths
284
+ seq_len = max(5, int(np.random.normal(bptt, 5)))
285
+ # prevent very large sequence length, must be <= 1.2 x bptt
286
+ seq_len = min(seq_len, int(args['bptt_size'] * 1.2))
287
+ data, target = get_batch(batches, i, seq_len)
288
+ lens = [data.size(1) for i in range(data.size(0))]
289
+
290
+ trainer.optimizer.zero_grad()
291
+ output, hidden, decoded = trainer.model.forward(data, lens, hidden)
292
+ loss = trainer.criterion(decoded.view(-1, len(vocab['char'])), target)
293
+ total_loss += loss.data.item()
294
+ loss.backward()
295
+
296
+ torch.nn.utils.clip_grad_norm_(trainer.params, args['max_grad_norm'])
297
+ trainer.optimizer.step()
298
+
299
+ hidden = repackage_hidden(hidden)
300
+
301
+ if (iteration + 1) % args['report_steps'] == 0:
302
+ cur_loss = total_loss / args['report_steps']
303
+ elapsed = time.time() - start_time
304
+ logger.info(
305
+ "| epoch {:5d} | {:5d}/{:5d} batches | sec/batch {:.6f} | loss {:5.2f} | ppl {:8.2f}".format(
306
+ trainer.epoch,
307
+ iteration + 1,
308
+ total_batches,
309
+ elapsed / args['report_steps'],
310
+ cur_loss,
311
+ math.exp(cur_loss),
312
+ )
313
+ )
314
+ if args['wandb']:
315
+ wandb.log({'train_loss': cur_loss}, step=trainer.global_step)
316
+ total_loss = 0.0
317
+
318
+ iteration += 1
319
+ i += seq_len
320
+
321
+ # evaluate if necessary
322
+ if eval_within_epoch and trainer.global_step % args['eval_steps'] == 0:
323
+ _, ppl, best_loss = evaluate_and_save(args, vocab, dev_data, trainer, best_loss, model_file, checkpoint_file, writer)
324
+ if args['wandb']:
325
+ wandb.log({'ppl': ppl, 'best_loss': best_loss, 'lr': get_current_lr(trainer, args)}, step=trainer.global_step)
326
+
327
+ # if eval_interval isn't provided, run evaluation after each epoch
328
+ if not eval_within_epoch or trainer.epoch == args['epochs']:
329
+ _, ppl, best_loss = evaluate_and_save(args, vocab, dev_data, trainer, best_loss, model_file, checkpoint_file, writer)
330
+ if args['wandb']:
331
+ wandb.log({'ppl': ppl, 'best_loss': best_loss, 'lr': get_current_lr(trainer, args)}, step=trainer.global_step)
332
+
333
+ if writer:
334
+ writer.close()
335
+ if args['wandb']:
336
+ wandb.finish()
337
+ return
338
+
339
+ def evaluate(args):
340
+ model_file = build_model_filename(args)
341
+
342
+ model = CharacterLanguageModel.load(model_file).to(args['device'])
343
+ vocab = model.vocab
344
+ data = load_data(args['eval_file'], vocab, args['direction'])
345
+ criterion = torch.nn.CrossEntropyLoss()
346
+
347
+ loss = evaluate_epoch(args, vocab, data, model, criterion)
348
+ logger.info(
349
+ "| best model | loss {:5.2f} | ppl {:8.2f}".format(
350
+ loss,
351
+ math.exp(loss),
352
+ )
353
+ )
354
+ return
355
+
356
+ if __name__ == '__main__':
357
+ main()
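
Since main() forwards its optional argument list to parse_args(), the trainer above can be driven programmatically as well as from the command line. A minimal sketch, assuming the corpus paths and the en_test shorthand below are hypothetical placeholders, and that flags such as --train_file, --eval_file, --direction and --mode are declared earlier in build_argparse (the code above references them, but their declarations fall before this excerpt):

    # Hypothetical invocation; all paths and the shorthand are placeholders.
    from stanza.models import charlm

    charlm.main([
        '--train_file', 'data/charlm/en_test.train.txt',  # raw text corpus
        '--eval_file', 'data/charlm/en_test.dev.txt',
        '--direction', 'forward',
        '--shorthand', 'en_test',
        '--mode', 'train',
    ])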
stanza/stanza/models/identity_lemmatizer.py ADDED
@@ -0,0 +1,66 @@
+"""
+An identity lemmatizer that mimics the behavior of a normal lemmatizer but directly uses the word as the lemma.
+"""
+
+import os
+import argparse
+import logging
+import random
+
+from stanza.models.lemma.data import DataLoader
+from stanza.models.lemma import scorer
+from stanza.models.common import utils
+from stanza.models.common.doc import *
+from stanza.utils.conll import CoNLL
+from stanza.models import _training_logging
+
+logger = logging.getLogger('stanza')
+
+def parse_args(args=None):
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--data_dir', type=str, default='data/lemma', help='Directory for all lemma data.')
+    parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.')
+    parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')
+    parser.add_argument('--output_file', type=str, default=None, help='Output CoNLL-U file.')
+    parser.add_argument('--gold_file', type=str, default=None, help='Gold CoNLL-U file.')
+
+    parser.add_argument('--mode', default='train', choices=['train', 'predict'])
+    parser.add_argument('--shorthand', type=str, help='Treebank shorthand')
+
+    parser.add_argument('--batch_size', type=int, default=50)
+    parser.add_argument('--seed', type=int, default=1234)
+
+    args = parser.parse_args(args=args)
+    return args
+
+def main(args=None):
+    args = parse_args(args=args)
+
+    random.seed(args.seed)
+
+    args = vars(args)
+
+    logger.info("[Launching identity lemmatizer...]")
+
+    if args['mode'] == 'train':
+        logger.info("[No training is required; will only generate evaluation output...]")
+
+    document = CoNLL.conll2doc(input_file=args['eval_file'])
+    batch = DataLoader(document, args['batch_size'], args, evaluation=True, conll_only=True)
+    system_pred_file = args['output_file']
+    gold_file = args['gold_file']
+
+    # use identity mapping for prediction
+    preds = batch.doc.get([TEXT])
+
+    # write to file and score
+    batch.doc.set([LEMMA], preds)
+    CoNLL.write_doc2conll(batch.doc, system_pred_file)
+    if gold_file is not None:
+        _, _, score = scorer.score(system_pred_file, gold_file)
+
+        logger.info("Lemma score:")
+        logger.info("{} {:.2f}".format(args['shorthand'], score*100))
+
+if __name__ == '__main__':
+    main()
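
As a quick sanity check, the script above can be exercised in predict mode directly from Python; the identity mapping means no model file is involved. A minimal sketch, assuming the CoNLL-U paths and the en_test shorthand below are hypothetical placeholders:

    # Hypothetical invocation; the file paths and shorthand are placeholders.
    from stanza.models import identity_lemmatizer

    identity_lemmatizer.main([
        '--mode', 'predict',
        '--eval_file', 'en_test.dev.in.conllu',      # input CoNLL-U
        '--output_file', 'en_test.dev.pred.conllu',  # predictions written here
        '--gold_file', 'en_test.dev.gold.conllu',    # optional; triggers scoring
        '--shorthand', 'en_test',
    ])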
stanza/stanza/models/lang_identifier.py ADDED
@@ -0,0 +1,236 @@
+"""
+Entry point for training and evaluating a Bi-LSTM language identifier.
+"""
+
+import argparse
+import json
+import logging
+import os
+import random
+import torch
+
+from datetime import datetime
+from stanza.models.common import utils
+from stanza.models.langid.data import DataLoader
+from stanza.models.langid.trainer import Trainer
+from stanza.utils.get_tqdm import get_tqdm
+
+tqdm = get_tqdm()
+
+logger = logging.getLogger('stanza')
+
+def parse_args(args=None):
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--batch_mode", help="custom settings when running in batch mode", action="store_true")
+    parser.add_argument("--batch_size", help="batch size for training", type=int, default=64)
+    parser.add_argument("--eval_length", help="length of strings to eval on", type=int, default=None)
+    parser.add_argument("--eval_set", help="eval on dev or test", default="test")
+    parser.add_argument("--data_dir", help="directory with train/dev/test data", default=None)
+    parser.add_argument("--load_name", help="path to load model from", default=None)
+    parser.add_argument("--mode", help="train or eval", default="train")
+    parser.add_argument("--num_epochs", help="number of epochs for training", type=int, default=50)
+    parser.add_argument("--randomize", help="take random substrings of samples", action="store_true")
+    parser.add_argument("--randomize_lengths_range", help="range of lengths to use when random sampling text",
+                        type=randomize_lengths_range, default="5,20")
+    parser.add_argument("--merge_labels_for_eval",
+                        help="merge some language labels for eval (e.g. \"zh-hans\" and \"zh-hant\" to \"zh\")",
+                        action="store_true")
+    parser.add_argument("--save_best_epochs", help="save model for every epoch with new best score", action="store_true")
+    parser.add_argument("--save_name", help="where to save model", default=None)
+    utils.add_device_args(parser)
+    args = parser.parse_args(args=args)
+    return args
+
+
+def randomize_lengths_range(range_list):
+    """
+    Range of lengths for random samples
+    """
+    range_boundaries = [int(x) for x in range_list.split(",")]
+    assert range_boundaries[0] < range_boundaries[1], f"Invalid range: ({range_boundaries[0]}, {range_boundaries[1]})"
+    return range_boundaries
+
+
+def main(args=None):
+    args = parse_args(args=args)
+    torch.manual_seed(0)
+    if args.mode == "train":
+        train_model(args)
+    else:
+        eval_model(args)
+
+
+def build_indexes(args):
+    tag_to_idx = {}
+    char_to_idx = {}
+    train_files = [f"{args.data_dir}/{x}" for x in os.listdir(args.data_dir) if "train" in x]
+    for train_file in train_files:
+        with open(train_file) as curr_file:
+            lines = curr_file.read().strip().split("\n")
+            examples = [json.loads(line) for line in lines if line.strip()]
+        for example in examples:
+            label = example["label"]
+            if label not in tag_to_idx:
+                tag_to_idx[label] = len(tag_to_idx)
+            sequence = example["text"]
+            for char in list(sequence):
+                if char not in char_to_idx:
+                    char_to_idx[char] = len(char_to_idx)
+    char_to_idx["UNK"] = len(char_to_idx)
+    char_to_idx["<PAD>"] = len(char_to_idx)
+
+    return tag_to_idx, char_to_idx
+
+
+def train_model(args):
+    # set up indexes
+    tag_to_idx, char_to_idx = build_indexes(args)
+    # load training data
+    train_data = DataLoader(args.device)
+    train_files = [f"{args.data_dir}/{x}" for x in os.listdir(args.data_dir) if "train" in x]
+    train_data.load_data(args.batch_size, train_files, char_to_idx, tag_to_idx, args.randomize)
+    # load dev data
+    dev_data = DataLoader(args.device)
+    dev_files = [f"{args.data_dir}/{x}" for x in os.listdir(args.data_dir) if "dev" in x]
+    dev_data.load_data(args.batch_size, dev_files, char_to_idx, tag_to_idx, randomize=False,
+                       max_length=args.eval_length)
+    # set up trainer
+    trainer_config = {
+        "model_path": args.save_name,
+        "char_to_idx": char_to_idx,
+        "tag_to_idx": tag_to_idx,
+        "batch_size": args.batch_size,
+        "lang_weights": train_data.lang_weights
+    }
+    if args.load_name:
+        trainer_config["load_name"] = args.load_name
+        logger.info(f"{datetime.now()}\tLoading model from: {args.load_name}")
+    trainer = Trainer(trainer_config, load_model=args.load_name is not None, device=args.device)
+    # run training
+    best_accuracy = 0.0
+    for epoch in range(1, args.num_epochs+1):
+        logger.info(f"{datetime.now()}\tEpoch {epoch}")
+        logger.info(f"{datetime.now()}\tNum training batches: {len(train_data.batches)}")
+
+        batches = train_data.batches
+        if not args.batch_mode:
+            batches = tqdm(batches)
+        for train_batch in batches:
+            inputs = (train_batch["sentences"], train_batch["targets"])
+            trainer.update(inputs)
+
+        logger.info(f"{datetime.now()}\tEpoch complete. Evaluating on dev data.")
+        curr_dev_accuracy, curr_confusion_matrix, curr_precisions, curr_recalls, curr_f1s = \
+            eval_trainer(trainer, dev_data, batch_mode=args.batch_mode)
+        logger.info(f"{datetime.now()}\tCurrent dev accuracy: {curr_dev_accuracy}")
+        if curr_dev_accuracy > best_accuracy:
+            logger.info(f"{datetime.now()}\tNew best score. Saving model.")
+            model_label = f"epoch{epoch}" if args.save_best_epochs else None
+            trainer.save(label=model_label)
+            with open(score_log_path(args.save_name), "w") as score_log_file:
+                for score_log in [{"dev_accuracy": curr_dev_accuracy}, curr_confusion_matrix, curr_precisions,
+                                  curr_recalls, curr_f1s]:
+                    score_log_file.write(json.dumps(score_log) + "\n")
+            best_accuracy = curr_dev_accuracy
+
+        # reload training data
+        logger.info(f"{datetime.now()}\tResampling training data.")
+        train_data.load_data(args.batch_size, train_files, char_to_idx, tag_to_idx, args.randomize)
+
+
+def score_log_path(file_path):
+    """
+    Helper that determines the corresponding log file (e.g. /path/to/demo.pt to /path/to/demo.json)
+    """
+    model_suffix = os.path.splitext(file_path)
+    if model_suffix[1]:
+        score_log_path = f"{file_path[:-len(model_suffix[1])]}.json"
+    else:
+        score_log_path = f"{file_path}.json"
+    return score_log_path
+
+
+def eval_model(args):
+    # set up trainer
+    trainer_config = {
+        "model_path": None,
+        "load_name": args.load_name,
+        "batch_size": args.batch_size
+    }
+    trainer = Trainer(trainer_config, load_model=True, device=args.device)
+    # load test data
+    test_data = DataLoader(args.device)
+    test_files = [f"{args.data_dir}/{x}" for x in os.listdir(args.data_dir) if args.eval_set in x]
+    test_data.load_data(args.batch_size, test_files, trainer.model.char_to_idx, trainer.model.tag_to_idx,
+                        randomize=False, max_length=args.eval_length)
+    curr_accuracy, curr_confusion_matrix, curr_precisions, curr_recalls, curr_f1s = \
+        eval_trainer(trainer, test_data, batch_mode=args.batch_mode, fine_grained=not args.merge_labels_for_eval)
+    logger.info(f"{datetime.now()}\t{args.eval_set} accuracy: {curr_accuracy}")
+    eval_save_path = args.save_name if args.save_name else score_log_path(args.load_name)
+    if not os.path.exists(eval_save_path) or args.save_name:
+        with open(eval_save_path, "w") as score_log_file:
+            for score_log in [{"dev_accuracy": curr_accuracy}, curr_confusion_matrix, curr_precisions,
+                              curr_recalls, curr_f1s]:
+                score_log_file.write(json.dumps(score_log) + "\n")
+
+
+def eval_trainer(trainer, dev_data, batch_mode=False, fine_grained=True):
+    """
+    Produce dev accuracy and confusion matrix for a trainer
+    """
+
+    # set up confusion matrix
+    tag_to_idx = dev_data.tag_to_idx
+    idx_to_tag = dev_data.idx_to_tag
+    confusion_matrix = {}
+    for row_label in tag_to_idx:
+        confusion_matrix[row_label] = {}
+        for col_label in tag_to_idx:
+            confusion_matrix[row_label][col_label] = 0
+
+    # process dev batches
+    batches = dev_data.batches
+    if not batch_mode:
+        batches = tqdm(batches)
+    for dev_batch in batches:
+        inputs = (dev_batch["sentences"], dev_batch["targets"])
+        predictions = trainer.predict(inputs)
+        for target_idx, prediction in zip(dev_batch["targets"], predictions):
+            prediction_label = idx_to_tag[prediction] if fine_grained else idx_to_tag[prediction].split("-")[0]
+            confusion_matrix[idx_to_tag[target_idx]][prediction_label] += 1
+
+    # calculate dev accuracy
+    total_examples = sum([sum([confusion_matrix[i][j] for j in confusion_matrix[i]]) for i in confusion_matrix])
+    total_correct = sum([confusion_matrix[i][i] for i in confusion_matrix])
+    dev_accuracy = float(total_correct) / float(total_examples)
+
+    # calculate precision, recall, F1
+    precision_scores = {"type": "precision"}
+    recall_scores = {"type": "recall"}
+    f1_scores = {"type": "f1"}
+    for prediction_label in tag_to_idx:
+        total = sum([confusion_matrix[k][prediction_label] for k in tag_to_idx])
+        if total != 0.0:
+            precision_scores[prediction_label] = float(confusion_matrix[prediction_label][prediction_label])/float(total)
+        else:
+            precision_scores[prediction_label] = 0.0
+    for target_label in tag_to_idx:
+        total = sum([confusion_matrix[target_label][k] for k in tag_to_idx])
+        if total != 0:
+            recall_scores[target_label] = float(confusion_matrix[target_label][target_label])/float(total)
+        else:
+            recall_scores[target_label] = 0.0
+    for label in tag_to_idx:
+        if precision_scores[label] == 0.0 and recall_scores[label] == 0.0:
+            f1_scores[label] = 0.0
+        else:
+            f1_scores[label] = \
+                2.0 * (precision_scores[label] * recall_scores[label]) / (precision_scores[label] + recall_scores[label])
+
+    return dev_accuracy, confusion_matrix, precision_scores, recall_scores, f1_scores
+
+
+if __name__ == "__main__":
+    main()
+
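
build_indexes() above implies the expected data format: one JSON object per line with "text" and "label" keys, stored under --data_dir in files whose names contain "train", "dev", or "test". A minimal sketch of preparing such files and launching training, assuming the directory, labels, and save path below are hypothetical placeholders:

    # Hypothetical data preparation; directory, labels, and save path are placeholders.
    import json, os

    os.makedirs('data/langid', exist_ok=True)
    with open('data/langid/tiny.train.jsonl', 'w') as fout:
        fout.write(json.dumps({"text": "This is English.", "label": "en"}) + "\n")
        fout.write(json.dumps({"text": "Das ist Deutsch.", "label": "de"}) + "\n")
    with open('data/langid/tiny.dev.jsonl', 'w') as fout:
        fout.write(json.dumps({"text": "This is English.", "label": "en"}) + "\n")

    from stanza.models import lang_identifier
    lang_identifier.main(['--data_dir', 'data/langid',
                          '--save_name', 'saved_models/langid/demo.pt',
                          '--num_epochs', '2'])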
stanza/stanza/models/mwt_expander.py ADDED
@@ -0,0 +1,322 @@
+"""
+Entry point for training and evaluating a multi-word token (MWT) expander.
+
+This MWT expander combines a neural sequence-to-sequence architecture with a dictionary
+to decode the token into multiple words.
+For details please refer to paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf
+
+In the case of a dataset where all of the MWTs split exactly into the words
+composing them, a classifier over the characters is used instead of the seq2seq
+"""
+
+import sys
+import os
+import shutil
+import time
+from datetime import datetime
+import argparse
+import logging
+import math
+import numpy as np
+import random
+import torch
+from torch import nn, optim
+import copy
+
+from stanza.models.mwt.data import DataLoader, BinaryDataLoader
+from stanza.models.mwt.utils import mwts_composed_of_words
+from stanza.models.mwt.vocab import Vocab
+from stanza.models.mwt.trainer import Trainer
+from stanza.models.mwt import scorer
+from stanza.models.common import utils
+import stanza.models.common.seq2seq_constant as constant
+from stanza.models.common.doc import Document
+from stanza.utils.conll import CoNLL
+from stanza.models import _training_logging
+
+logger = logging.getLogger('stanza')
+
+def build_argparse():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--data_dir', type=str, default='data/mwt', help='Directory of MWT data.')
+    parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.')
+    parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')
+    parser.add_argument('--output_file', type=str, default=None, help='Output CoNLL-U file.')
+    parser.add_argument('--gold_file', type=str, default=None, help='Gold CoNLL-U file.')
+
+    parser.add_argument('--mode', default='train', choices=['train', 'predict'])
+    parser.add_argument('--lang', type=str, help='Language')
+    parser.add_argument('--shorthand', type=str, help="Treebank shorthand")
+
+    parser.add_argument('--no_dict', dest='ensemble_dict', action='store_false', help='Do not ensemble dictionary with seq2seq. By default ensemble a dict.')
+    parser.add_argument('--ensemble_early_stop', action='store_true', help='Early stopping based on ensemble performance.')
+    parser.add_argument('--dict_only', action='store_true', help='Only train a dictionary-based MWT expander.')
+
+    parser.add_argument('--hidden_dim', type=int, default=100)
+    parser.add_argument('--emb_dim', type=int, default=50)
+    parser.add_argument('--num_layers', type=int, default=None, help='Number of layers in model encoder. Defaults to 1 for seq2seq, 2 for classifier')
+    parser.add_argument('--emb_dropout', type=float, default=0.5)
+    parser.add_argument('--dropout', type=float, default=0.5)
+    parser.add_argument('--max_dec_len', type=int, default=50)
+    parser.add_argument('--beam_size', type=int, default=1)
+    parser.add_argument('--attn_type', default='soft', choices=['soft', 'mlp', 'linear', 'deep'], help='Attention type')
+    parser.add_argument('--no_copy', dest='copy', action='store_false', help='Do not use copy mechanism in MWT expansion. By default copy mechanism is used to improve generalization.')
+
+    parser.add_argument('--augment_apos', default=0.01, type=float, help='At training time, how much to augment |\'| to |"| |’| |ʼ|')
+    parser.add_argument('--force_exact_pieces', default=None, action='store_true', help='If possible, make the text of the pieces of the MWT add up to the token itself. (By default, this is determined from the dataset.)')
+    parser.add_argument('--no_force_exact_pieces', dest='force_exact_pieces', action='store_false', help="Don't make the text of the pieces of the MWT add up to the token itself. (By default, this is determined from the dataset.)")
+
+    parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.')
+    parser.add_argument('--optim', type=str, default='adam', help='sgd, adagrad, adam or adamax.')
+    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
+    parser.add_argument('--lr_decay', type=float, default=0.9)
+    parser.add_argument('--decay_epoch', type=int, default=30, help="Decay the lr starting from this epoch.")
+    parser.add_argument('--num_epoch', type=int, default=30)
+    parser.add_argument('--batch_size', type=int, default=50)
+    parser.add_argument('--max_grad_norm', type=float, default=5.0, help='Gradient clipping.')
+    parser.add_argument('--log_step', type=int, default=20, help='Print log every k steps.')
+    parser.add_argument('--save_dir', type=str, default='saved_models/mwt', help='Root dir for saving models.')
+    parser.add_argument('--save_name', type=str, default=None, help="File name to save the model")
+    parser.add_argument('--save_each_name', type=str, default=None, help="Save each model in sequence to this pattern. Mostly for testing")
+
+    parser.add_argument('--seed', type=int, default=1234)
+    utils.add_device_args(parser)
+
+    parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name')
+    parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. Will default to the dataset short name')
+    return parser
+
+def parse_args(args=None):
+    parser = build_argparse()
+    args = parser.parse_args(args=args)
+
+    if args.wandb_name:
+        args.wandb = True
+
+    return args
+
+def main(args=None):
+    args = parse_args(args=args)
+
+    utils.set_random_seed(args.seed)
+
+    args = vars(args)
+    logger.info("Running MWT expander in {} mode".format(args['mode']))
+
+    if args['mode'] == 'train':
+        train(args)
+    else:
+        evaluate(args)
+
+def train(args):
+    # load data
+    logger.debug('max_dec_len: %d' % args['max_dec_len'])
+    logger.debug("Loading data with batch size {}...".format(args['batch_size']))
+    train_doc = CoNLL.conll2doc(input_file=args['train_file'])
+    train_batch = DataLoader(train_doc, args['batch_size'], args, evaluation=False)
+    vocab = train_batch.vocab
+    args['vocab_size'] = vocab.size
+    dev_doc = CoNLL.conll2doc(input_file=args['eval_file'])
+    dev_batch = DataLoader(dev_doc, args['batch_size'], args, vocab=vocab, evaluation=True)
+
+    utils.ensure_dir(args['save_dir'])
+    save_name = args['save_name'] if args['save_name'] else '{}_mwt_expander.pt'.format(args['shorthand'])
+    model_file = os.path.join(args['save_dir'], save_name)
+
+    save_each_name = None
+    if args['save_each_name']:
+        save_each_name = os.path.join(args['save_dir'], args['save_each_name'])
+        save_each_name = utils.build_save_each_filename(save_each_name)
+
+    # pred and gold path
+    system_pred_file = args['output_file']
+    gold_file = args['gold_file']
+
+    # skip training if the language does not have training or dev data
+    if len(train_batch) == 0:
+        logger.warning("Skip training because no data available...")
+        return
+    dev_mwt = dev_doc.get_mwt_expansions(False)
+    if len(dev_batch) == 0 and not args.get('dict_only', False):
+        logger.warning("Training data available, but dev data has no MWTs. Only training a dict-based MWT")
+        args['dict_only'] = True
+
+    if args['force_exact_pieces'] and not mwts_composed_of_words(train_doc):
+        raise ValueError("Cannot train model with --force_exact_pieces, as the MWTs in this dataset are not entirely composed of their subwords")
+
+    if args['force_exact_pieces'] is None and mwts_composed_of_words(train_doc):
+        # the force_exact_pieces mechanism trains a separate version of the MWT expander in the Trainer
+        # (the training loop here does not need to change)
+        # in this model, a classifier distinguishes whether or not a location is a split
+        # and the text is copied exactly from the input rather than created via seq2seq
+        # this behavior can be turned off at training time with --no_force_exact_pieces
+        logger.info("Training set MWTs are entirely composed of their subwords. Training the MWT expander to match that paradigm as closely as possible")
+        args['force_exact_pieces'] = True
+
+    if args['force_exact_pieces']:
+        logger.info("Reconverting to BinaryDataLoader")
+        train_batch = BinaryDataLoader(train_doc, args['batch_size'], args, evaluation=False)
+        vocab = train_batch.vocab
+        args['vocab_size'] = vocab.size
+        dev_batch = BinaryDataLoader(dev_doc, args['batch_size'], args, vocab=vocab, evaluation=True)
+
+    if args['num_layers'] is None:
+        if args['force_exact_pieces']:
+            args['num_layers'] = 2
+        else:
+            args['num_layers'] = 1
+
+    # train a dictionary-based MWT expander
+    trainer = Trainer(args=args, vocab=vocab, device=args['device'])
+    logger.info("Training dictionary-based MWT expander...")
+    trainer.train_dict(train_batch.doc.get_mwt_expansions(evaluation=False))
+    logger.info("Evaluating on dev set...")
+    dev_preds = trainer.predict_dict(dev_batch.doc.get_mwt_expansions(evaluation=True))
+    doc = copy.deepcopy(dev_batch.doc)
+    doc.set_mwt_expansions(dev_preds, fake_dependencies=True)
+    CoNLL.write_doc2conll(doc, system_pred_file)
+    _, _, dev_f = scorer.score(system_pred_file, gold_file)
+    logger.info("Dev F1 = {:.2f}".format(dev_f * 100))
+
+    if args.get('dict_only', False):
+        # save dictionaries
+        trainer.save(model_file)
+    else:
+        # train a seq2seq model
+        logger.info("Training seq2seq-based MWT expander...")
+        global_step = 0
+        steps_per_epoch = math.ceil(len(train_batch) / args['batch_size'])
+        max_steps = steps_per_epoch * args['num_epoch']
+        dev_score_history = []
+        best_dev_preds = []
+        current_lr = args['lr']
+        global_start_time = time.time()
+        format_str = '{}: step {}/{} (epoch {}/{}), loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'
+
+        if args['wandb']:
+            import wandb
+            wandb_name = args['wandb_name'] if args['wandb_name'] else "%s_mwt" % args['shorthand']
+            wandb.init(name=wandb_name, config=args)
+            wandb.run.define_metric('train_loss', summary='min')
+            wandb.run.define_metric('dev_score', summary='max')
+
+        # start training
+        for epoch in range(1, args['num_epoch']+1):
+            train_loss = 0
+            for i, batch in enumerate(train_batch.to_loader()):
+                start_time = time.time()
+                global_step += 1
+                loss = trainer.update(batch, eval=False)  # update step
+                train_loss += loss
+                if global_step % args['log_step'] == 0:
+                    duration = time.time() - start_time
+                    logger.info(format_str.format(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), global_step,
+                                                  max_steps, epoch, args['num_epoch'], loss, duration, current_lr))
+
+            if save_each_name:
+                trainer.save(save_each_name % epoch)
+                logger.info("Saved epoch %d model to %s" % (epoch, save_each_name % epoch))
+
+            # eval on dev
+            logger.info("Evaluating on dev set...")
+            dev_preds = []
+            for i, batch in enumerate(dev_batch.to_loader()):
+                preds = trainer.predict(batch)
+                dev_preds += preds
+            if args.get('ensemble_dict', False) and args.get('ensemble_early_stop', False):
+                logger.info("[Ensembling dict with seq2seq model...]")
+                dev_preds = trainer.ensemble(dev_batch.doc.get_mwt_expansions(evaluation=True), dev_preds)
+            doc = copy.deepcopy(dev_batch.doc)
+            doc.set_mwt_expansions(dev_preds, fake_dependencies=True)
+            CoNLL.write_doc2conll(doc, system_pred_file)
+            _, _, dev_score = scorer.score(system_pred_file, gold_file)
+            train_loss = train_loss / train_batch.num_examples * args['batch_size']  # avg loss per batch
+            logger.info("epoch {}: train_loss = {:.6f}, dev_score = {:.4f}".format(epoch, train_loss, dev_score))
+
+            if args['wandb']:
+                wandb.log({'train_loss': train_loss, 'dev_score': dev_score})
+
+            # save best model
+            if epoch == 1 or dev_score > max(dev_score_history):
+                trainer.save(model_file)
+                logger.info("new best model saved.")
+                best_dev_preds = dev_preds
+
+            # lr schedule
+            if epoch > args['decay_epoch'] and dev_score <= dev_score_history[-1]:
+                current_lr *= args['lr_decay']
+                trainer.change_lr(current_lr)
+
+            dev_score_history += [dev_score]
+
+        logger.info("Training ended with {} epochs.".format(epoch))
+
+        if args['wandb']:
+            wandb.finish()
+
+        best_f, best_epoch = max(dev_score_history)*100, np.argmax(dev_score_history)+1
+        logger.info("Best dev F1 = {:.2f}, at epoch = {}".format(best_f, best_epoch))
+
+        # try ensembling with dict if necessary
+        if args.get('ensemble_dict', False):
+            logger.info("[Ensembling dict with seq2seq model...]")
+            dev_preds = trainer.ensemble(dev_batch.doc.get_mwt_expansions(evaluation=True), best_dev_preds)
+            doc = copy.deepcopy(dev_batch.doc)
+            doc.set_mwt_expansions(dev_preds, fake_dependencies=True)
+            CoNLL.write_doc2conll(doc, system_pred_file)
+            _, _, dev_score = scorer.score(system_pred_file, gold_file)
+            logger.info("Ensemble dev F1 = {:.2f}".format(dev_score*100))
+            best_f = max(best_f, dev_score * 100)
+
+def evaluate(args):
+    # file paths
+    system_pred_file = args['output_file']
+    gold_file = args['gold_file']
+    model_file = args['save_name'] if args['save_name'] else '{}_mwt_expander.pt'.format(args['shorthand'])
+    if args['save_dir'] and not model_file.startswith(args['save_dir']) and not os.path.exists(model_file):
+        model_file = os.path.join(args['save_dir'], model_file)
+
+    # load model
+    trainer = Trainer(model_file=model_file, device=args['device'])
+    loaded_args, vocab = trainer.args, trainer.vocab
+
+    for k in args:
+        if k.endswith('_dir') or k.endswith('_file') or k in ['shorthand']:
+            loaded_args[k] = args[k]
+    logger.debug('max_dec_len: %d' % loaded_args['max_dec_len'])
+
+    # load data
+    logger.debug("Loading data with batch size {}...".format(args['batch_size']))
+    doc = CoNLL.conll2doc(input_file=args['eval_file'])
+    batch = DataLoader(doc, args['batch_size'], loaded_args, vocab=vocab, evaluation=True)
+
+    if len(batch) > 0:
+        dict_preds = trainer.predict_dict(batch.doc.get_mwt_expansions(evaluation=True))
+        # decide trainer type and run eval
+        if loaded_args['dict_only']:
+            preds = dict_preds
+        else:
+            logger.info("Running the seq2seq model...")
+            preds = []
+            for i, b in enumerate(batch.to_loader()):
+                preds += trainer.predict(b)
+
+            if loaded_args.get('ensemble_dict', False):
+                preds = trainer.ensemble(batch.doc.get_mwt_expansions(evaluation=True), preds)
+    else:
+        # skip eval if dev data does not exist
+        preds = []
+
+    # write to file and score
+    doc = copy.deepcopy(batch.doc)
+    doc.set_mwt_expansions(preds, fake_dependencies=True)
+    CoNLL.write_doc2conll(doc, system_pred_file)
+
+    if gold_file is not None:
+        _, _, score = scorer.score(system_pred_file, gold_file)
+
+        logger.info("MWT expansion score: {} {:.2f}".format(args['shorthand'], score*100))
+
+
+if __name__ == '__main__':
+    main()
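
Both modes run through main(), so a trained expander can be scored end to end with the flags from build_argparse. A minimal sketch for predict mode, assuming the paths and the en_test shorthand below are hypothetical placeholders (given only --shorthand, evaluate() resolves the model to saved_models/mwt/en_test_mwt_expander.pt by default):

    # Hypothetical invocation; the paths and shorthand are placeholders.
    from stanza.models import mwt_expander

    mwt_expander.main([
        '--mode', 'predict',
        '--shorthand', 'en_test',
        '--eval_file', 'en_test.dev.in.conllu',
        '--output_file', 'en_test.dev.pred.conllu',
        '--gold_file', 'en_test.dev.gold.conllu',  # optional; enables scoring
    ])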
stanza/stanza/models/ner_tagger.py ADDED
@@ -0,0 +1,492 @@
+"""
+Entry point for training and evaluating an NER tagger.
+
+This tagger uses BiLSTM layers with character and word-level representations, and a CRF decoding layer
+to produce NER predictions.
+For details please refer to paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf.
+"""
+
+import sys
+import os
+import time
+from datetime import datetime
+import argparse
+import logging
+import numpy as np
+import random
+import re
+import json
+import torch
+from torch import nn, optim
+
+from stanza.models.ner.data import DataLoader
+from stanza.models.ner.trainer import Trainer
+from stanza.models.ner import scorer
+from stanza.models.common import utils
+from stanza.models.common.pretrain import Pretrain
+from stanza.utils.conll import CoNLL
+from stanza.models.common.doc import *
+from stanza.models import _training_logging
+
+from stanza.models.common.peft_config import add_peft_args, resolve_peft_args
+from stanza.utils.confusion import confusion_to_weighted_f1, format_confusion
+
+logger = logging.getLogger('stanza')
+
+def build_argparse():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--data_dir', type=str, default='data/ner', help='Directory of NER data.')
+    parser.add_argument('--wordvec_dir', type=str, default='extern_data/word2vec', help='Directory of word vectors')
+    parser.add_argument('--wordvec_file', type=str, default='', help='File that contains word vectors')
+    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
+    parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.')
+    parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')
+    parser.add_argument('--eval_output_file', type=str, default=None, help='Where to write results: text, gold, pred. If None, no results file printed')
+
+    parser.add_argument('--mode', default='train', choices=['train', 'predict'])
+    parser.add_argument('--finetune', action='store_true', help='Load existing model during `train` mode from `save_dir` path')
+    parser.add_argument('--finetune_load_name', type=str, default=None, help='Model to load when finetuning')
+    parser.add_argument('--train_classifier_only', action='store_true',
+                        help='When applying a transfer-learning approach and training only the classifier layer, this freezes gradient propagation for all other layers.')
+    parser.add_argument('--shorthand', type=str, help="Treebank shorthand")
+
+    parser.add_argument('--hidden_dim', type=int, default=256)
+    parser.add_argument('--char_hidden_dim', type=int, default=100)
+    parser.add_argument('--word_emb_dim', type=int, default=100)
+    parser.add_argument('--char_emb_dim', type=int, default=100)
+    parser.add_argument('--num_layers', type=int, default=1)
+    parser.add_argument('--char_num_layers', type=int, default=1)
+    parser.add_argument('--pretrain_max_vocab', type=int, default=100000)
+    parser.add_argument('--word_dropout', type=float, default=0.01, help="How often to remove a word at training time. Set to a small value to train unk when finetuning word embeddings")
+    parser.add_argument('--locked_dropout', type=float, default=0.0)
+    parser.add_argument('--dropout', type=float, default=0.5)
+    parser.add_argument('--rec_dropout', type=float, default=0, help="Word recurrent dropout")
+    parser.add_argument('--char_rec_dropout', type=float, default=0, help="Character recurrent dropout")
+    parser.add_argument('--char_dropout', type=float, default=0, help="Character-level language model dropout")
+    parser.add_argument('--no_char', dest='char', action='store_false', help="Turn off training a character model.")
+    parser.add_argument('--charlm', action='store_true', help="Turn on contextualized char embedding using pretrained character-level language model.")
+    parser.add_argument('--charlm_save_dir', type=str, default='saved_models/charlm', help="Root dir for pretrained character-level language model.")
+    parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.")
+    parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")
+    parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm")
+    parser.add_argument('--char_lowercase', dest='char_lowercase', action='store_true', help="Use lowercased characters in character model.")
+    parser.add_argument('--no_lowercase', dest='lowercase', action='store_false', help="Use cased word vectors.")
+    parser.add_argument('--no_emb_finetune', dest='emb_finetune', action='store_false', help="Turn off finetuning of the embedding matrix.")
+    parser.add_argument('--emb_finetune_known_only', dest='emb_finetune_known_only', action='store_true', help="Finetune the embedding matrix only for words in the embedding. (Default: finetune words not in the embedding as well) This may be useful for very large datasets where obscure words are only trained once in a while, such as French-WikiNER")
+    parser.add_argument('--no_input_transform', dest='input_transform', action='store_false', help="Do not use input transformation layer before tagger lstm.")
+    parser.add_argument('--scheme', type=str, default='bioes', help="The tagging scheme to use: bio or bioes.")
+    parser.add_argument('--train_scheme', type=str, default=None, help="The tagging scheme to use when training: bio or bioes. Overrides --scheme for the training set")
+
+    parser.add_argument('--bert_model', type=str, default=None, help="Use an external bert model (requires the transformers package)")
+    parser.add_argument('--no_bert_model', dest='bert_model', action="store_const", const=None, help="Don't use bert")
+    parser.add_argument('--bert_hidden_layers', type=int, default=None, help="How many layers of hidden state to use from the transformer")
+    parser.add_argument('--bert_finetune', default=False, action='store_true', help='Finetune the bert (or other transformer)')
+    parser.add_argument('--gradient_checkpointing', default=False, action='store_true', help='Checkpoint intermediate gradients between layers to save memory at the cost of training steps')
+    parser.add_argument('--no_bert_finetune', dest='bert_finetune', action='store_false', help="Don't finetune the bert (or other transformer)")
+    parser.add_argument('--bert_learning_rate', default=1.0, type=float, help='Scale the learning rate for transformer finetuning by this much')
+    parser.add_argument('--second_optim', type=str, default=None, help='Once the first optimizer has converged, tune the model again with: sgd, adagrad, adam or adamax.')
+    parser.add_argument('--second_bert_learning_rate', default=0, type=float, help='Secondary stage transformer finetuning learning rate scale')
+
+    parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.')
+    parser.add_argument('--optim', type=str, default='sgd', help='sgd, adagrad, adam or adamax.')
+    parser.add_argument('--lr', type=float, default=0.1, help='Learning rate.')
+    parser.add_argument('--min_lr', type=float, default=1e-4, help='Minimum learning rate to stop training.')
+    parser.add_argument('--second_lr', type=float, default=5e-3, help='Secondary learning rate')
+    parser.add_argument('--momentum', type=float, default=0, help='Momentum for SGD.')
+    parser.add_argument('--lr_decay', type=float, default=0.5, help="LR decay rate.")
+    parser.add_argument('--patience', type=int, default=3, help="Patience for LR decay.")
+
+    parser.add_argument('--connect_output_layers', action='store_true', default=False, help='Connect one output layer to the input of the next output layer. By default, those layers are all separate')
+    parser.add_argument('--predict_tagset', type=int, default=None, help='Which tagset to predict if there are multiple tagsets. Will default to 0. Default of None allows the model to remember the value from training time, but be overridden at test time')
+
+    parser.add_argument('--ignore_tag_scores', type=str, default=None, help="Which tags to ignore, if any, when scoring dev & test sets")
+
+    parser.add_argument('--max_steps', type=int, default=200000)
+    parser.add_argument('--max_steps_no_improve', type=int, default=2500, help="If the model doesn't improve after this many steps, give up or switch to the second optimizer.")
+    parser.add_argument('--eval_interval', type=int, default=500)
+    parser.add_argument('--batch_size', type=int, default=32)
+    parser.add_argument('--max_batch_words', type=int, default=800, help='Long sentences can overwhelm even a large GPU when finetuning a transformer on otherwise reasonable batch sizes. This cuts off those batches early')
+    parser.add_argument('--max_grad_norm', type=float, default=5.0, help='Gradient clipping.')
+    parser.add_argument('--log_step', type=int, default=20, help='Print log every k steps.')
+    parser.add_argument('--log_norms', action='store_true', default=False, help='Log the norms of all the parameters (noisy!)')
+    parser.add_argument('--save_dir', type=str, default='saved_models/ner', help='Root dir for saving models.')
+    parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_{finetune}_nertagger.pt", help="File name to save the model")
+
+    parser.add_argument('--seed', type=int, default=1234)
+    utils.add_device_args(parser)
+
+    parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name')
+    parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. Will default to the dataset short name')
+    return parser
+
+def parse_args(args=None):
+    parser = build_argparse()
+    add_peft_args(parser)
+    args = parser.parse_args(args=args)
+    resolve_peft_args(args, logger)
+
+    if args.wandb_name:
+        args.wandb = True
+
+    args = vars(args)
+    return args
+
+def main(args=None):
+    args = parse_args(args=args)
+
+    utils.set_random_seed(args['seed'])
+
+    logger.info("Running NER tagger in {} mode".format(args['mode']))
+
+    if args['mode'] == 'train':
+        return train(args)
+    else:
+        evaluate(args)
+
+def load_pretrain(args):
+    # load pretrained vectors
+    if args['wordvec_pretrain_file']:
+        pretrain_file = args['wordvec_pretrain_file']
+        pretrain = Pretrain(pretrain_file, None, args['pretrain_max_vocab'], save_to_file=False)
+    else:
+        if len(args['wordvec_file']) == 0:
+            vec_file = utils.get_wordvec_file(args['wordvec_dir'], args['shorthand'])
+        else:
+            vec_file = args['wordvec_file']
+        # do not save pretrained embeddings individually
+        pretrain = Pretrain(None, vec_file, args['pretrain_max_vocab'], save_to_file=False)
+    return pretrain
+
+def model_file_name(args):
+    return utils.standard_model_file_name(args, "nertagger")
+
+def get_known_tags(tags):
+    """
+    Tags are stored in the dataset as a list of list of tags
+
+    This returns a sorted list for each column of tags in the dataset
+    """
+    max_columns = max(len(word) for sent in tags for word in sent)
+    known_tags = [set() for _ in range(max_columns)]
+    for sent in tags:
+        for word in sent:
+            for tag_idx, tag in enumerate(word):
+                known_tags[tag_idx].add(tag)
+    return [sorted(x) for x in known_tags]
+
+def warn_missing_tags(tag_vocab, data_tags, error_msg, bioes_to_bio=False):
+    """
+    Check for tags missing from the tag_vocab.
+
+    Given a tag_vocab and the known tags in the format used by
+    ner.data, go through the tags in the dataset and look for any
+    which aren't in the tag_vocab.
+
+    error_msg is something like "training set" or "eval file" to
+    indicate where the missing tags came from.
+    """
+    tag_depth = max(max(len(tags) for tags in sentence) for sentence in data_tags)
+
+    if tag_depth != len(tag_vocab.lens()):
+        logger.warning("Test dataset has a different number of tag types compared to the model: %d vs %d", tag_depth, len(tag_vocab.lens()))
+    for tag_set_idx in range(min(tag_depth, len(tag_vocab.lens()))):
+        tag_set = tag_vocab.items(tag_set_idx)
+        if len(tag_vocab.lens()) > 1:
+            current_error_msg = error_msg + " tag set %d" % tag_set_idx
+        else:
+            current_error_msg = error_msg
+
+        current_tags = set([word[tag_set_idx] for sentence in data_tags for word in sentence])
+        if bioes_to_bio:
+            current_tags = set([re.sub("^E-", "I-", re.sub("^S-", "B-", x)) for x in current_tags])
+        utils.warn_missing_tags(tag_set, current_tags, current_error_msg)
+
+def train(args):
+    model_file = model_file_name(args)
+
+    save_dir, save_name = os.path.split(model_file)
+    utils.ensure_dir(save_dir)
+    if args['save_dir'] is None:
+        args['save_dir'] = save_dir
+    args['save_name'] = save_name
+
+    utils.log_training_args(args, logger)
+
+    pretrain = None
+    vocab = None
+    trainer = None
+
+    if args['finetune'] and args['finetune_load_name']:
+        logger.warning('Finetune is ON. Using model from "{}"'.format(args['finetune_load_name']))
+        _, trainer, vocab = load_model(args, args['finetune_load_name'])
+    elif args['finetune'] and os.path.exists(model_file):
+        logger.warning('Finetune is ON. Using model from "{}"'.format(model_file))
+        _, trainer, vocab = load_model(args, model_file)
+    else:
+        if args['finetune']:
+            raise FileNotFoundError('Finetune is set to true but model file is not found: {}'.format(model_file))
+
+        pretrain = load_pretrain(args)
+
+    if pretrain is not None:
+        word_emb_dim = pretrain.emb.shape[1]
+        if args['word_emb_dim'] and args['word_emb_dim'] != word_emb_dim:
+            logger.warning("Embedding file has a dimension of {}. Model will be built with that size instead of {}".format(word_emb_dim, args['word_emb_dim']))
+        args['word_emb_dim'] = word_emb_dim
+
+    if args['charlm']:
+        if args['charlm_shorthand'] is None:
+            raise ValueError("CharLM Shorthand is required for loading pretrained CharLM model...")
+        logger.info('Using pretrained contextualized char embedding')
+        if not args['charlm_forward_file']:
+            args['charlm_forward_file'] = '{}/{}_forward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand'])
+        if not args['charlm_backward_file']:
+            args['charlm_backward_file'] = '{}/{}_backward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand'])
+
+    # load data
+    logger.info("Loading training data with batch size %d from %s", args['batch_size'], args['train_file'])
+    with open(args['train_file']) as fin:
+        train_doc = Document(json.load(fin))
+    logger.info("Loaded %d sentences of training data", len(train_doc.sentences))
+    if len(train_doc.sentences) == 0:
+        raise ValueError("File %s exists but has no usable training data" % args['train_file'])
+    train_batch = DataLoader(train_doc, args['batch_size'], args, pretrain, vocab=vocab, evaluation=False, scheme=args.get('train_scheme'), max_batch_words=args['max_batch_words'])
+    vocab = train_batch.vocab
+    logger.info("Loading dev data from %s", args['eval_file'])
+    with open(args['eval_file']) as fin:
+        dev_doc = Document(json.load(fin))
+    logger.info("Loaded %d sentences of dev data", len(dev_doc.sentences))
+    if len(dev_doc.sentences) == 0:
+        raise ValueError("File %s exists but has no usable dev data" % args['eval_file'])
+    dev_batch = DataLoader(dev_doc, args['batch_size'], args, pretrain, vocab=vocab, evaluation=True)
+
+    train_tags = get_known_tags(train_batch.tags)
+    logger.info("Training data has %d columns of tags", len(train_tags))
+    for tag_idx, tags in enumerate(train_tags):
+        logger.info("Tags present in training set at column %d:\n  Tags without BIES markers: %s\n  Tags with B-, I-, E-, or S-: %s",
+                    tag_idx,
+                    " ".join(sorted(set(i for i in tags if i[:2] not in ('B-', 'I-', 'E-', 'S-')))),
+                    " ".join(sorted(set(i[2:] for i in tags if i[:2] in ('B-', 'I-', 'E-', 'S-')))))
+
+    # skip training if the language does not have training or dev data
+    if len(train_batch) == 0 or len(dev_batch) == 0:
+        logger.info("Skip training because no data available...")
+        return
+
+    logger.info("Training tagger...")
+    if trainer is None:  # init if model was not loaded previously from file
+        trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, device=args['device'],
+                          train_classifier_only=args['train_classifier_only'])
+
+    if args['finetune']:
+        warn_missing_tags(trainer.vocab['tag'], train_batch.tags, "training set")
+        # the evaluation will coerce the tags to the proper scheme,
+        # so we won't need to alert for not having S- or E- tags
+        bioes_to_bio = args['train_scheme'] == 'bio' and args['scheme'] == 'bioes'
+        warn_missing_tags(trainer.vocab['tag'], dev_batch.tags, "dev set", bioes_to_bio=bioes_to_bio)
+
+    # TODO: might still want to add multiple layers of tag evaluation to the scorer
+    dev_gold_tags = [[x[trainer.args['predict_tagset']] for x in tags] for tags in dev_batch.tags]
+
+    logger.info(trainer.model)
+
+    global_step = 0
+    max_steps = args['max_steps']
+    dev_score_history = []
+    best_dev_preds = []
+    current_lr = trainer.optimizer.param_groups[0]['lr']
+    global_start_time = time.time()
+    format_str = '{}: step {}/{}, loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'
+
+    # LR scheduling
+    if args['lr_decay'] > 0:
+        # learning rate changes on plateau -- no improvement on model for patience number of epochs
+        # change is made as a factor of the learning rate decay
+        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(trainer.optimizer, mode='max', factor=args['lr_decay'],
+                                                               patience=args['patience'], verbose=True, min_lr=args['min_lr'])
+    else:
+        scheduler = None
+
+    if args['wandb']:
+        import wandb
+        wandb_name = args['wandb_name'] if args['wandb_name'] else "%s_ner" % args['shorthand']
+        wandb.init(name=wandb_name, config=args)
+        wandb.run.define_metric('train_loss', summary='min')
+        wandb.run.define_metric('dev_score', summary='max')
+        # track gradients!
+        wandb.watch(trainer.model, log_freq=4, log="gradients")
+
+    # start training
+    last_best_step = 0
+    train_loss = 0
+    is_second_optim = False
+    while True:
+        should_stop = False
+        for i, batch in enumerate(train_batch):
+            start_time = time.time()
+            global_step += 1
+            loss = trainer.update(batch, eval=False)  # update step
+            train_loss += loss
+            if global_step % args['log_step'] == 0:
+                duration = time.time() - start_time
+                logger.info(format_str.format(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), global_step,
+                                              max_steps, loss, duration, current_lr))
+            if global_step % args['eval_interval'] == 0:
+                # eval on dev
+                logger.info("Evaluating on dev set...")
+                dev_preds = []
+                for batch in dev_batch:
+                    preds = trainer.predict(batch)
+                    dev_preds += preds
+                _, _, dev_score, _ = scorer.score_by_entity(dev_preds, dev_gold_tags, ignore_tags=args['ignore_tag_scores'])
+
+                train_loss = train_loss / args['eval_interval']  # avg loss per batch
+                logger.info("step {}: train_loss = {:.6f}, dev_score = {:.4f}".format(global_step, train_loss, dev_score))
+                if args['wandb']:
+                    wandb.log({'train_loss': train_loss, 'dev_score': dev_score})
+                train_loss = 0
+
+                # save best model
+                if len(dev_score_history) == 0 or dev_score > max(dev_score_history):
+                    trainer.save(model_file)
+                    last_best_step = global_step
+                    logger.info("New best model saved.")
+                    best_dev_preds = dev_preds
+
+                dev_score_history += [dev_score]
357
+ logger.info("")
358
+
359
+ # lr schedule
360
+ if scheduler is not None:
361
+ scheduler.step(dev_score)
362
+
363
+ if args['log_norms']:
364
+ trainer.model.log_norms()
365
+
366
+ # check stopping
367
+ current_lr = trainer.optimizer.param_groups[0]['lr']
368
+ if (global_step - last_best_step) >= args['max_steps_no_improve'] or global_step >= args['max_steps'] or current_lr <= args['min_lr']:
369
+ if (global_step - last_best_step) >= args['max_steps_no_improve']:
370
+ logger.info("{} steps without improvement...".format((global_step - last_best_step)))
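+ # if a second optimizer is configured, reload the best checkpoint and keep
+ # training with that optimizer instead of stopping right away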
371
+ if not is_second_optim and args['second_optim'] is not None:
372
+ logger.info("Switching to second optimizer: {}".format(args['second_optim']))
373
+ logger.info('Reloading best model to continue from current local optimum')
374
+ trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, device=args['device'],
375
+ train_classifier_only=args['train_classifier_only'], model_file=model_file, second_optim=True)
376
+ is_second_optim = True
377
+ last_best_step = global_step
378
+ current_lr = trainer.optimizer.param_groups[0]['lr']
379
+ else:
380
+ logger.info("stopping...")
381
+ should_stop = True
382
+ break
383
+
384
+ if should_stop:
385
+ break
386
+
387
+ train_batch.reshuffle()
388
+
389
+ logger.info("Training ended with {} steps.".format(global_step))
390
+
391
+ if args['wandb']:
392
+ wandb.finish()
393
+
394
+ if len(dev_score_history) > 0:
395
+ best_f, best_eval = max(dev_score_history)*100, np.argmax(dev_score_history)+1
396
+ logger.info("Best dev F1 = {:.2f}, at iteration = {}".format(best_f, best_eval * args['eval_interval']))
397
+ else:
398
+ logger.info("Dev set never evaluated. Saving final model.")
399
+ trainer.save(model_file)
400
+
401
+ return trainer
402
+
403
+ def write_ner_results(filename, batch, preds, predict_tagset):
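+ # writes one "word<TAB>gold<TAB>pred" line per token, with a blank line between sentences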
404
+ if len(batch.tags) != len(preds):
405
+ raise ValueError("Unexpected batch vs pred lengths: %d vs %d" % (len(batch.tags), len(preds)))
406
+
407
+ with open(filename, "w", encoding="utf-8") as fout:
408
+ tag_idx = 0
409
+ for b in batch:
410
+ # b[0] is words, b[5] is orig_idx
411
+ # a namedtuple would make this cleaner without being much slower
412
+ text = utils.unsort(b[0], b[5])
413
+ for sentence in text:
414
+ # TODO: if we change the predict_tagset mechanism, will have to change this
415
+ sentence_gold = [x[predict_tagset] for x in batch.tags[tag_idx]]
416
+ sentence_pred = preds[tag_idx]
417
+ tag_idx += 1
418
+ for word, gold, pred in zip(sentence, sentence_gold, sentence_pred):
419
+ fout.write("%s\t%s\t%s\n" % (word, gold, pred))
420
+ fout.write("\n")
421
+
422
+ def evaluate(args):
423
+ # file paths
424
+ model_file = model_file_name(args)
425
+
426
+ loaded_args, trainer, vocab = load_model(args, model_file)
427
+ return evaluate_model(loaded_args, trainer, vocab, args['eval_file'])
428
+
429
+ def evaluate_model(loaded_args, trainer, vocab, eval_file):
430
+ if loaded_args['log_norms']:
431
+ trainer.model.log_norms()
432
+
433
+ model_file = os.path.join(loaded_args['save_dir'], loaded_args['save_name'])
434
+ logger.debug("Loaded model for eval from %s", model_file)
435
+ logger.debug("Using the %d tagset for evaluation", loaded_args['predict_tagset'])
436
+
437
+ # load data
438
+ logger.info("Loading data with batch size {}...".format(loaded_args['batch_size']))
439
+ with open(eval_file) as fin:
440
+ doc = Document(json.load(fin))
441
+ batch = DataLoader(doc, loaded_args['batch_size'], loaded_args, vocab=vocab, evaluation=True, bert_tokenizer=trainer.model.bert_tokenizer)
442
+ bioes_to_bio = loaded_args['train_scheme'] == 'bio' and loaded_args['scheme'] == 'bioes'
443
+ warn_missing_tags(trainer.vocab['tag'], batch.tags, "eval_file", bioes_to_bio=bioes_to_bio)
444
+
445
+ logger.info("Start evaluation...")
446
+ preds = []
447
+ for i, b in enumerate(batch):
448
+ preds += trainer.predict(b)
449
+
450
+ gold_tags = batch.tags
451
+ # TODO: might still want to add multiple layers of tag evaluation to the scorer
452
+ gold_tags = [[x[trainer.args['predict_tagset']] for x in tags] for tags in gold_tags]
453
+
454
+ _, _, score, entity_f1 = scorer.score_by_entity(preds, gold_tags, ignore_tags=loaded_args['ignore_tag_scores'])
455
+ _, _, _, confusion = scorer.score_by_token(preds, gold_tags, ignore_tags=loaded_args['ignore_tag_scores'])
456
+ logger.info("Weighted f1 for non-O tokens: %5f", confusion_to_weighted_f1(confusion, exclude=["O"]))
457
+
458
+ logger.info("NER tagger score: %s %s %s %.2f", loaded_args['shorthand'], model_file, eval_file, score*100)
459
+ entity_f1_lines = ["%s: %.2f" % (x, y*100) for x, y in entity_f1.items()]
460
+ logger.info("NER Entity F1 scores:\n %s", "\n ".join(entity_f1_lines))
461
+ logger.info("NER token confusion matrix:\n{}".format(format_confusion(confusion)))
462
+
463
+ if loaded_args['eval_output_file']:
464
+ write_ner_results(loaded_args['eval_output_file'], batch, preds, trainer.args['predict_tagset'])
465
+
466
+ return confusion
467
+
468
+ def load_model(args, model_file):
469
+ # load model
470
+ charlm_args = {}
471
+ if 'charlm_forward_file' in args:
472
+ charlm_args['charlm_forward_file'] = args['charlm_forward_file']
473
+ if 'charlm_backward_file' in args:
474
+ charlm_args['charlm_backward_file'] = args['charlm_backward_file']
475
+ if args['predict_tagset'] is not None:
476
+ charlm_args['predict_tagset'] = args['predict_tagset']
477
+ pretrain = load_pretrain(args)
478
+ trainer = Trainer(args=charlm_args, model_file=model_file, pretrain=pretrain, device=args['device'], train_classifier_only=args['train_classifier_only'])
479
+ loaded_args, vocab = trainer.args, trainer.vocab
480
+
481
+ # load config
482
+ for k in args:
483
+ if k.endswith('_dir') or k.endswith('_file') or k in ['batch_size', 'ignore_tag_scores', 'log_norms', 'mode', 'scheme', 'shorthand']:
484
+ loaded_args[k] = args[k]
485
+ save_dir, save_name = os.path.split(model_file)
486
+ args['save_dir'] = save_dir
487
+ args['save_name'] = save_name
488
+ return loaded_args, trainer, vocab
489
+
490
+
491
+ if __name__ == '__main__':
492
+ main()
stanza/stanza/models/tagger.py ADDED
@@ -0,0 +1,461 @@
1
+ """
2
+ Entry point for training and evaluating a POS/morphological features tagger.
3
+
4
+ This tagger uses highway BiLSTM layers with character and word-level representations, and biaffine classifiers
5
+ to produce consistent POS and UFeats predictions.
6
+ For details please refer to paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf.
7
+ """
8
+
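+ # Example training invocation (illustrative file names):
+ #   python -m stanza.models.tagger --mode train --shorthand en_ewt \
+ #     --train_file en_ewt.train.conllu --eval_file en_ewt.dev.conllu \
+ #     --output_file en_ewt.dev.pred.conllu
+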
9
+ import argparse
10
+ import logging
11
+ import os
12
+ import time
13
+ import zipfile
14
+
15
+ import numpy as np
16
+ import torch
17
+ from torch import nn, optim
18
+
19
+ from stanza.models.pos.data import Dataset, ShuffledDataset
20
+ from stanza.models.pos.trainer import Trainer
21
+ from stanza.models.pos import scorer
22
+ from stanza.models.common import utils
23
+ from stanza.models.common import pretrain
24
+ from stanza.models.common.doc import *
25
+ from stanza.models.common.foundation_cache import FoundationCache
26
+ from stanza.models.common.peft_config import add_peft_args, resolve_peft_args
27
+ from stanza.models import _training_logging
28
+ from stanza.utils.conll import CoNLL
29
+
30
+ logger = logging.getLogger('stanza')
31
+
32
+ def build_argparse():
33
+ parser = argparse.ArgumentParser()
34
+ parser.add_argument('--data_dir', type=str, default='data/pos', help='Root dir for saving models.')
35
+ parser.add_argument('--wordvec_dir', type=str, default='extern_data/wordvec', help='Directory of word vectors.')
36
+ parser.add_argument('--wordvec_file', type=str, default=None, help='Word vectors filename.')
37
+ parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
38
+ parser.add_argument('--train_file', type=str, default=None, help='Input file for training.')
39
+ parser.add_argument('--eval_file', type=str, default=None, help='Input file for scoring.')
40
+ parser.add_argument('--output_file', type=str, default=None, help='Output CoNLL-U file.')
41
+ parser.add_argument('--no_gold_labels', dest='gold_labels', action='store_false', help="Don't score the eval file - perhaps it has no gold labels, for example. Cannot be used at training time")
42
+
43
+ parser.add_argument('--mode', default='train', choices=['train', 'predict'])
44
+ parser.add_argument('--lang', type=str, help='Language')
45
+ parser.add_argument('--shorthand', type=str, help="Treebank shorthand")
46
+
47
+ parser.add_argument('--hidden_dim', type=int, default=200)
48
+ parser.add_argument('--char_hidden_dim', type=int, default=400)
49
+ parser.add_argument('--deep_biaff_hidden_dim', type=int, default=400)
50
+ parser.add_argument('--composite_deep_biaff_hidden_dim', type=int, default=100)
51
+ parser.add_argument('--word_emb_dim', type=int, default=75, help='Dimension of the finetuned word embedding. Set to 0 to turn off')
52
+ parser.add_argument('--word_cutoff', type=int, default=7, help='How common a word must be to include it in the finetuned word embedding')
53
+ parser.add_argument('--char_emb_dim', type=int, default=100)
54
+ parser.add_argument('--tag_emb_dim', type=int, default=50)
55
+ parser.add_argument('--charlm_transform_dim', type=int, default=None, help='Transform the pretrained charlm to this dimension. If not set, no transform is used')
56
+ parser.add_argument('--transformed_dim', type=int, default=125)
57
+ parser.add_argument('--num_layers', type=int, default=2)
58
+ parser.add_argument('--char_num_layers', type=int, default=1)
59
+ parser.add_argument('--pretrain_max_vocab', type=int, default=250000)
60
+ parser.add_argument('--word_dropout', type=float, default=0.33)
61
+ parser.add_argument('--dropout', type=float, default=0.5)
62
+ parser.add_argument('--rec_dropout', type=float, default=0, help="Recurrent dropout")
63
+ parser.add_argument('--char_rec_dropout', type=float, default=0, help="Recurrent dropout")
64
+
65
+ # TODO: refactor charlm arguments for models which use it?
66
+ parser.add_argument('--no_char', dest='char', action='store_false', help="Turn off character model.")
67
+ parser.add_argument('--char_bidirectional', dest='char_bidirectional', action='store_true', help="Use a bidirectional version of the non-pretrained charlm. Doesn't help much, makes the models larger")
68
+ parser.add_argument('--char_lowercase', dest='char_lowercase', action='store_true', help="Use lowercased characters in character model.")
69
+ parser.add_argument('--charlm', action='store_true', help="Turn on contextualized char embedding using pretrained character-level language model.")
70
+ parser.add_argument('--charlm_save_dir', type=str, default='saved_models/charlm', help="Root dir for pretrained character-level language model.")
71
+ parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.")
72
+ parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")
73
+ parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm")
74
+
75
+ parser.add_argument('--bert_model', type=str, default=None, help="Use an external bert model (requires the transformers package)")
76
+ parser.add_argument('--no_bert_model', dest='bert_model', action="store_const", const=None, help="Don't use bert")
77
+ parser.add_argument('--bert_hidden_layers', type=int, default=None, help="How many layers of hidden state to use from the transformer")
78
+ parser.add_argument('--bert_finetune', default=False, action='store_true', help='Finetune the bert (or other transformer)')
79
+ parser.add_argument('--no_bert_finetune', dest='bert_finetune', action='store_false', help="Don't finetune the bert (or other transformer)")
80
+ parser.add_argument('--bert_learning_rate', default=1.0, type=float, help='Scale the learning rate for transformer finetuning by this much')
81
+
82
+ parser.add_argument('--no_pretrain', dest='pretrain', action='store_false', help="Turn off pretrained embeddings.")
83
+ parser.add_argument('--share_hid', action='store_true', help="Share hidden representations for UPOS, XPOS and UFeats.")
84
+ parser.set_defaults(share_hid=False)
85
+
86
+ parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.')
87
+ parser.add_argument('--optim', type=str, default='adam', help='sgd, adagrad, adam, adamw, adamax, or adadelta. madgrad as an optional dependency')
88
+ parser.add_argument('--second_optim', type=str, default='amsgrad', help='Optimizer for the second half of training. Default is Adam with AMSGrad')
89
+ parser.add_argument('--second_optim_reload', default=False, action='store_true', help='Reload the best model instead of continuing from current model if the first optimizer stalls out. This does not seem to help, but might be useful for further experiments')
90
+ parser.add_argument('--no_second_optim', action='store_const', const=None, dest='second_optim', help="Don't use a second optimizer - only use the first optimizer")
91
+ parser.add_argument('--lr', type=float, default=3e-3, help='Learning rate')
92
+ parser.add_argument('--second_lr', type=float, default=None, help='Alternate learning rate for the second optimizer')
93
+ parser.add_argument('--initial_weight_decay', type=float, default=None, help='Optimizer weight decay for the first optimizer')
94
+ parser.add_argument('--second_weight_decay', type=float, default=None, help='Optimizer weight decay for the second optimizer')
95
+ parser.add_argument('--beta2', type=float, default=0.95)
96
+
97
+ parser.add_argument('--max_steps', type=int, default=50000)
98
+ parser.add_argument('--eval_interval', type=int, default=100)
99
+ parser.add_argument('--fix_eval_interval', dest='adapt_eval_interval', action='store_false', \
100
+ help="Use fixed evaluation interval for all treebanks, otherwise by default the interval will be increased for larger treebanks.")
101
+ parser.add_argument('--max_steps_before_stop', type=int, default=3000, help='Switch to the second optimizer, or terminate early, after this many steps without dev score improvement')
102
+ parser.add_argument('--batch_size', type=int, default=250)
103
+ parser.add_argument('--batch_maximum_tokens', type=int, default=5000, help='When run in a Pipeline, limit a batch to this many tokens to help avoid OOM for long sentences')
104
+ parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Gradient clipping.')
105
+ parser.add_argument('--log_step', type=int, default=20, help='Print log every k steps.')
106
+ parser.add_argument('--log_norms', action='store_true', default=False, help='Log the norms of all the parameters (noisy!)')
107
+ parser.add_argument('--save_dir', type=str, default='saved_models/pos', help='Root dir for saving models.')
108
+ parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_tagger.pt", help="File name to save the model")
109
+ parser.add_argument('--save_each', default=False, action='store_true', help="Save each checkpoint to its own model. Will take up a bunch of space")
110
+
111
+ parser.add_argument('--seed', type=int, default=1234)
112
+ add_peft_args(parser)
113
+ utils.add_device_args(parser)
114
+
115
+ parser.add_argument('--augment_nopunct', type=float, default=None, help='Augment the training data by copying this fraction of punct-ending sentences as non-punct. Default of None will aim for roughly 50%%')
116
+
117
+ parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name')
118
+ parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. Will default to the dataset short name')
119
+ return parser
120
+
121
+ def parse_args(args=None):
122
+ parser = build_argparse()
123
+ args = parser.parse_args(args=args)
124
+ resolve_peft_args(args, logger)
125
+
126
+ if args.augment_nopunct is None:
127
+ args.augment_nopunct = 0.25
128
+
129
+ if args.wandb_name:
130
+ args.wandb = True
131
+
132
+ args = vars(args)
133
+ return args
134
+
135
+ def main(args=None):
136
+ args = parse_args(args=args)
137
+
138
+ utils.set_random_seed(args['seed'])
139
+
140
+ logger.info("Running tagger in {} mode".format(args['mode']))
141
+
142
+ if args['mode'] == 'train':
143
+ train(args)
144
+ else:
145
+ evaluate(args)
146
+
147
+ def model_file_name(args):
148
+ return utils.standard_model_file_name(args, "tagger")
149
+
150
+ def save_each_file_name(args):
151
+ model_file = model_file_name(args)
152
+ pieces = os.path.splitext(model_file)
153
+ return pieces[0] + "_%05d" + pieces[1]
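+ # e.g. saved_models/pos/en_ewt_tagger.pt -> saved_models/pos/en_ewt_tagger_%05d.pt,
+ # with %05d filled in with the global step at each checkpoint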
154
+
155
+ def load_pretrain(args):
156
+ pt = None
157
+ if args['pretrain']:
158
+ pretrain_file = pretrain.find_pretrain_file(args['wordvec_pretrain_file'], args['save_dir'], args['shorthand'], args['lang'])
159
+ if os.path.exists(pretrain_file):
160
+ vec_file = None
161
+ else:
162
+ vec_file = args['wordvec_file'] if args['wordvec_file'] else utils.get_wordvec_file(args['wordvec_dir'], args['shorthand'])
163
+ pt = pretrain.Pretrain(pretrain_file, vec_file, args['pretrain_max_vocab'])
164
+ return pt
165
+
166
+ def get_eval_type(dev_batch):
167
+ """
168
+ If there is only one column to score in the dev set, use that instead of AllTags
169
+ """
170
+ if dev_batch.has_xpos and not dev_batch.has_upos and not dev_batch.has_feats:
171
+ return "XPOS"
172
+ elif dev_batch.has_upos and not dev_batch.has_xpos and not dev_batch.has_feats:
173
+ return "UPOS"
174
+ else:
175
+ return "AllTags"
176
+
177
+ def load_training_data(args, pretrain):
178
+ train_docs = []
179
+ raw_train_files = args['train_file'].split(";")
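+ # train_file may be a ;-separated list, e.g. "a.conllu;b.zip" (illustrative names)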
180
+ train_files = []
181
+ for train_file in raw_train_files:
182
+ if zipfile.is_zipfile(train_file):
183
+ logger.info("Decompressing %s" % train_file)
184
+ with zipfile.ZipFile(train_file) as zin:
185
+ for zipped_train_file in zin.namelist():
186
+ with zin.open(zipped_train_file) as fin:
187
+ logger.info("Reading %s from %s" % (zipped_train_file, train_file))
188
+ train_str = fin.read()
189
+ train_str = train_str.decode("utf-8")
190
+ train_file_data, _, _ = CoNLL.conll2dict(input_str=train_str)
191
+ logger.info("Train File {} from {}, Data Size: {}".format(zipped_train_file, train_file, len(train_file_data)))
192
+ train_docs.append(Document(train_file_data))
193
+ train_files.append("%s %s" % (train_file, zipped_train_file))
194
+ else:
195
+ logger.info("Reading %s" % train_file)
196
+ # train_data is now a list of sentences, where each sentence is a
197
+ # list of words, in which each word is a dict of conll attributes
198
+ train_file_data, _, _ = CoNLL.conll2dict(input_file=train_file)
199
+ logger.info("Train File {}, Data Size: {}".format(train_file, len(train_file_data)))
200
+ train_docs.append(Document(train_file_data))
201
+ train_files.append(train_file)
202
+ if sum(len(x.sentences) for x in train_docs) == 0:
203
+ raise RuntimeError("Training data for the tagger is empty: %s" % args['train_file'])
204
+ # we want to ensure that the model is able to output _ for empty columns,
205
+ # but create batches such that if a doc has upos/xpos tags we include them all.
206
+ # therefore, we create separate datasets and loaders for each input training file,
207
+ # which ensures the system sees batches both with upos available
208
+ # and with upos unavailable, depending on the availability in each file.
209
+ vocab = Dataset.init_vocab(train_docs, args)
210
+ train_data = [Dataset(i, args, pretrain, vocab=vocab, evaluation=False)
211
+ for i in train_docs]
212
+ for train_file, td in zip(train_files, train_data):
213
+ if not td.has_upos:
214
+ logger.info("No UPOS in %s" % train_file)
215
+ if not td.has_xpos:
216
+ logger.info("No XPOS in %s" % train_file)
217
+ if not td.has_feats:
218
+ logger.info("No feats in %s" % train_file)
219
+
220
+ # reject partially tagged upos or xpos documents
221
+ # otherwise, the model will learn to output blanks for some words,
222
+ # which is probably a confusing result
223
+ # (and definitely throws off the depparse)
224
+ # another option would be to treat those as masked out
225
+ for td_idx, td in enumerate(train_data):
226
+ if td.has_upos:
227
+ upos_data = td.doc.get(UPOS, as_sentences=True)
228
+ for sentence_idx, sentence in enumerate(upos_data):
229
+ for word_idx, upos in enumerate(sentence):
230
+ if upos == '_' or upos is None:
231
+ conll = "{:C}".format(td.doc.sentences[sentence_idx])
232
+ raise RuntimeError("Found a blank tag in the UPOS at sentence %d word %d of %s.\n%s" % ((sentence_idx+1), (word_idx+1), train_files[td_idx], conll))
233
+
234
+ # here we make sure the model will learn to output _ for empty columns
235
+ # if *any* dataset has data for the upos, xpos, or feature column,
236
+ # we consider that data enough to train the model on that column
237
+ # otherwise, we want to train the model to always output blanks
238
+ if not any(td.has_upos for td in train_data):
239
+ for td in train_data:
240
+ td.has_upos = True
241
+ if not any(td.has_xpos for td in train_data):
242
+ for td in train_data:
243
+ td.has_xpos = True
244
+ if not any(td.has_feats for td in train_data):
245
+ for td in train_data:
246
+ td.has_feats = True
247
+ # calculate the batches
248
+ train_batches = ShuffledDataset(train_data, args["batch_size"])
249
+ return vocab, train_data, train_batches
250
+
251
+ def train(args):
252
+ model_file = model_file_name(args)
253
+ utils.ensure_dir(os.path.split(model_file)[0])
254
+
255
+ if args['save_each']:
256
+ # so models.pt -> models_0001.pt, etc
257
+ model_save_each_file = save_each_file_name(args)
258
+ logger.info("Saving each checkpoint to %s" % model_save_each_file)
259
+
260
+ # load pretrained vectors if needed
261
+ pretrain = load_pretrain(args)
262
+
263
+ if args['charlm']:
264
+ if args['charlm_shorthand'] is None:
265
+ raise ValueError("CharLM Shorthand is required for loading pretrained CharLM model...")
266
+ logger.info('Using pretrained contextualized char embedding')
267
+ if not args['charlm_forward_file']:
268
+ args['charlm_forward_file'] = '{}/{}_forward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand'])
269
+ if not args['charlm_backward_file']:
270
+ args['charlm_backward_file'] = '{}/{}_backward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand'])
271
+
272
+ # load data
273
+ logger.info("Loading data with batch size {}...".format(args['batch_size']))
274
+ vocab, train_data, train_batches = load_training_data(args, pretrain)
275
+
276
+ dev_doc = CoNLL.conll2doc(input_file=args['eval_file'])
277
+ dev_data = Dataset(dev_doc, args, pretrain, vocab=vocab, evaluation=True, sort_during_eval=True)
278
+ dev_batch = dev_data.to_loader(batch_size=args["batch_size"])
279
+
280
+ eval_type = get_eval_type(dev_data)
281
+
282
+ # pred and gold path
283
+ system_pred_file = args['output_file']
284
+
285
+ # skip training if the language does not have training or dev data
286
+ # sum(...) to check if all of the training files are empty
287
+ if sum(len(td) for td in train_data) == 0 or len(dev_data) == 0:
288
+ logger.info("Skip training because no data available...")
289
+ return
290
+
291
+ if args['wandb']:
292
+ import wandb
293
+ wandb_name = args['wandb_name'] if args['wandb_name'] else "%s_tagger" % args['shorthand']
294
+ wandb.init(name=wandb_name, config=args)
295
+ wandb.run.define_metric('train_loss', summary='min')
296
+ wandb.run.define_metric('dev_score', summary='max')
297
+
298
+ logger.info("Training tagger...")
299
+ foundation_cache = FoundationCache()
300
+ trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, device=args['device'], foundation_cache=foundation_cache)
301
+
302
+ global_step = 0
303
+ max_steps = args['max_steps']
304
+ dev_score_history = []
305
+ best_dev_preds = []
306
+ current_lr = args['lr']
307
+ global_start_time = time.time()
308
+ format_str = 'Finished STEP {}/{}, loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'
309
+
310
+ if args['adapt_eval_interval']:
311
+ args['eval_interval'] = utils.get_adaptive_eval_interval(dev_data.num_examples, 2000, args['eval_interval'])
312
+ logger.info("Evaluating the model every {} steps...".format(args['eval_interval']))
313
+
314
+ if args['save_each']:
315
+ logger.info("Saving initial checkpoint to %s" % (model_save_each_file % global_step))
316
+ trainer.save(model_save_each_file % global_step)
317
+
318
+ using_amsgrad = False
319
+ last_best_step = 0
320
+ # start training
321
+ train_loss = 0
322
+ if args['log_norms']:
323
+ trainer.model.log_norms()
324
+ while True:
325
+ do_break = False
326
+ for i, batch in enumerate(train_batches):
327
+ start_time = time.time()
328
+ global_step += 1
329
+ loss = trainer.update(batch, eval=False) # update step
330
+ train_loss += loss
331
+ if global_step % args['log_step'] == 0:
332
+ duration = time.time() - start_time
333
+ logger.info(format_str.format(global_step, max_steps, loss, duration, current_lr))
334
+ if args['log_norms']:
335
+ trainer.model.log_norms()
336
+
337
+ if global_step % args['eval_interval'] == 0:
338
+ # eval on dev
339
+ logger.info("Evaluating on dev set...")
340
+ dev_preds = []
341
+ indices = []
342
+ for batch in dev_batch:
343
+ preds = trainer.predict(batch)
344
+ dev_preds += preds
345
+ indices.extend(batch[-1])
346
+ dev_preds = utils.unsort(dev_preds, indices)
347
+ dev_data.doc.set([UPOS, XPOS, FEATS], [y for x in dev_preds for y in x])
348
+ CoNLL.write_doc2conll(dev_data.doc, system_pred_file)
349
+
350
+ _, _, dev_score = scorer.score(system_pred_file, args['eval_file'], eval_type=eval_type)
351
+
352
+ train_loss = train_loss / args['eval_interval'] # avg loss per batch
353
+ logger.info("step {}: train_loss = {:.6f}, dev_score = {:.4f}".format(global_step, train_loss, dev_score))
354
+
355
+ if args['wandb']:
356
+ wandb.log({'train_loss': train_loss, 'dev_score': dev_score})
357
+
358
+ train_loss = 0
359
+
360
+ if args['save_each']:
361
+ logger.info("Saving checkpoint to %s" % (model_save_each_file % global_step))
362
+ trainer.save(model_save_each_file % global_step)
363
+
364
+ # save best model
365
+ if len(dev_score_history) == 0 or dev_score > max(dev_score_history):
366
+ last_best_step = global_step
367
+ trainer.save(model_file)
368
+ logger.info("New best model saved.")
369
+ best_dev_preds = dev_preds
370
+
371
+ dev_score_history += [dev_score]
372
+
373
+ if global_step - last_best_step >= args['max_steps_before_stop']:
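+ # no recent improvement: either switch to the second optimizer
+ # (AMSGrad by default) and continue, or terminate training early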
374
+ if not using_amsgrad and args['second_optim'] is not None:
375
+ logger.info("Switching to second optimizer: {}".format(args['second_optim']))
376
+ if args['second_optim_reload']:
377
+ logger.info('Reloading best model to continue from current local optimum')
378
+ trainer = Trainer(args=args, vocab=trainer.vocab, pretrain=pretrain, model_file=model_file, device=args['device'], foundation_cache=foundation_cache)
379
+ last_best_step = global_step
380
+ using_amsgrad = True
381
+ lr = args['second_lr']
382
+ if lr is None:
383
+ lr = args['lr']
384
+ trainer.optimizer = utils.get_optimizer(args['second_optim'], trainer.model, lr=lr, betas=(.9, args['beta2']), eps=1e-6, weight_decay=args['second_weight_decay'])
385
+ else:
386
+ logger.info("Early termination: have not improved in {} steps".format(args['max_steps_before_stop']))
387
+ do_break = True
388
+ break
389
+
390
+ if global_step >= args['max_steps']:
391
+ do_break = True
392
+ break
393
+
394
+ if do_break: break
395
+
396
+ logger.info("Training ended with {} steps.".format(global_step))
397
+
398
+ if args['wandb']:
399
+ wandb.finish()
400
+
401
+ if len(dev_score_history) > 0:
402
+ best_f, best_eval = max(dev_score_history)*100, np.argmax(dev_score_history)+1
403
+ logger.info("Best dev F1 = {:.2f}, at iteration = {}".format(best_f, best_eval * args['eval_interval']))
404
+ else:
405
+ logger.info("Dev set never evaluated. Saving final model.")
406
+ trainer.save(model_file)
407
+
408
+
409
+ def evaluate(args):
410
+ # file paths
411
+ model_file = model_file_name(args)
412
+
413
+ pretrain = load_pretrain(args)
414
+
415
+ load_args = {'charlm_forward_file': args.get('charlm_forward_file', None),
416
+ 'charlm_backward_file': args.get('charlm_backward_file', None)}
417
+
418
+ # load model
419
+ logger.info("Loading model from: {}".format(model_file))
420
+ trainer = Trainer(pretrain=pretrain, model_file=model_file, device=args['device'], args=load_args)
421
+ evaluate_trainer(args, trainer, pretrain)
422
+
423
+ def evaluate_trainer(args, trainer, pretrain):
424
+ system_pred_file = args['output_file']
425
+ loaded_args, vocab = trainer.args, trainer.vocab
426
+
427
+ # load config
428
+ for k in args:
429
+ if k.endswith('_dir') or k.endswith('_file') or k in ['shorthand'] or k == 'mode':
430
+ loaded_args[k] = args[k]
431
+
432
+ # load data
433
+ logger.info("Loading data with batch size {}...".format(args['batch_size']))
434
+ doc = CoNLL.conll2doc(input_file=args['eval_file'])
435
+ dev_data = Dataset(doc, loaded_args, pretrain, vocab=vocab, evaluation=True, sort_during_eval=True)
436
+ dev_batch = dev_data.to_loader(batch_size=args['batch_size'])
437
+ eval_type = get_eval_type(dev_data)
438
+ if len(dev_batch) > 0:
439
+ logger.info("Start evaluation...")
440
+ preds = []
441
+ indices = []
442
+ with torch.no_grad():
443
+ for b in dev_batch:
444
+ preds += trainer.predict(b)
445
+ indices.extend(b[-1])
446
+ else:
447
+ # skip eval if dev data does not exist
448
+ preds = []
+ indices = []
449
+ preds = utils.unsort(preds, indices)
450
+
451
+ # write to file and score
452
+ dev_data.doc.set([UPOS, XPOS, FEATS], [y for x in preds for y in x])
453
+ CoNLL.write_doc2conll(dev_data.doc, system_pred_file)
454
+
455
+ if args['gold_labels']:
456
+ _, _, score = scorer.score(system_pred_file, args['eval_file'], eval_type=eval_type)
457
+
458
+ logger.info("POS Tagger score: %s %.2f", args['shorthand'], score*100)
459
+
460
+ if __name__ == '__main__':
461
+ main()
stanza/stanza/models/tokenizer.py ADDED
@@ -0,0 +1,258 @@
1
+ """
2
+ Entry point for training and evaluating a neural tokenizer.
3
+
4
+ This tokenizer treats tokenization and sentence segmentation as a tagging problem, and uses a combination of
5
+ recurrent and convolutional architectures.
6
+ For details please refer to paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf.
7
+
8
+ Updated: this version of the tokenizer model incorporates a dictionary feature, which is especially useful for languages
9
+ with multi-syllable words such as Vietnamese, Chinese, or Thai. In summary, a lexicon containing all unique words found in
10
+ the training dataset and in an external lexicon (if any) is created during training and saved alongside the model.
11
+ From this lexicon, a dictionary is created which includes "words", "prefixes", and "suffixes" sets. During data preparation,
12
+ dictionary features are extracted at each character position, "looking ahead" and "looking backward" to see whether any word
13
+ formed around that position is found in the dictionary. The window size (i.e. the dictionary feature length) is set to the
14
+ 95th percentile of word lengths in the lexicon, which eliminates the less frequent but long words (and avoids a
15
+ high-dimensional feature vector). The prefix and suffix sets are used to stop early during the window-dictionary check.
16
+ """
17
+
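+ # Illustrative example of the dictionary feature: for a (hypothetical) lexicon
+ # containing "ab" and "abc", the prefix set holds "a" and "ab"; at each character
+ # position the model checks windows of text to each side for membership in the
+ # word set, stopping early once a window is not a known prefix or suffix.
+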
18
+ import argparse
19
+ from copy import copy
20
+ import logging
21
+ import random
22
+ import numpy as np
23
+ import os
24
+ import torch
25
+ import json
26
+ from stanza.models.common import utils
27
+ from stanza.models.tokenization.trainer import Trainer
28
+ from stanza.models.tokenization.data import DataLoader, TokenizationDataset
29
+ from stanza.models.tokenization.utils import load_mwt_dict, eval_model, output_predictions, load_lexicon, create_dictionary
30
+ from stanza.models import _training_logging
31
+
32
+ logger = logging.getLogger('stanza')
33
+
34
+ def build_argparse():
35
+ """
36
+ Build the argument parser for tokenizer training and evaluation.
37
+ """
38
+ parser = argparse.ArgumentParser()
39
+ parser.add_argument('--txt_file', type=str, help="Input plaintext file")
40
+ parser.add_argument('--label_file', type=str, default=None, help="Character-level label file")
41
+ parser.add_argument('--mwt_json_file', type=str, default=None, help="JSON file for MWT expansions")
42
+ parser.add_argument('--conll_file', type=str, default=None, help="CoNLL file for output")
43
+ parser.add_argument('--dev_txt_file', type=str, help="(Train only) Input plaintext file for the dev set")
44
+ parser.add_argument('--dev_label_file', type=str, default=None, help="(Train only) Character-level label file for the dev set")
45
+ parser.add_argument('--dev_conll_gold', type=str, default=None, help="(Train only) CoNLL-U file for the dev set for early stopping")
46
+ parser.add_argument('--lang', type=str, help="Language")
47
+ parser.add_argument('--shorthand', type=str, help="UD treebank shorthand")
48
+
49
+ parser.add_argument('--mode', default='train', choices=['train', 'predict'])
50
+ parser.add_argument('--skip_newline', action='store_true', help="Whether to skip newline characters in input. Particularly useful for languages like Chinese.")
51
+
52
+ parser.add_argument('--emb_dim', type=int, default=32, help="Dimension of unit embeddings")
53
+ parser.add_argument('--hidden_dim', type=int, default=64, help="Dimension of hidden units")
54
+ parser.add_argument('--conv_filters', type=str, default="1,9", help="Configuration of conv filters: ',,' separates layers and ',' separates filter sizes within the same layer.")
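+ # e.g. (illustrative): "1,9" is a single conv layer with filter sizes 1 and 9,
+ # while "1,9,,3" would add a second layer with filter size 3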
55
+ parser.add_argument('--no-residual', dest='residual', action='store_false', help="Add linear residual connections")
56
+ parser.add_argument('--no-hierarchical', dest='hierarchical', action='store_false', help="\"Hierarchical\" RNN tokenizer")
57
+ parser.add_argument('--hier_invtemp', type=float, default=0.5, help="Inverse temperature used in propagating tokenization predictions between RNN layers")
58
+ parser.add_argument('--input_dropout', action='store_true', help="Dropout input embeddings as well")
59
+ parser.add_argument('--conv_res', type=str, default=None, help="Convolutional residual layers for the RNN")
60
+ parser.add_argument('--rnn_layers', type=int, default=1, help="Layers of RNN in the tokenizer")
61
+ parser.add_argument('--use_dictionary', action='store_true', help="Use dictionary feature. The lexicon is created using the training data and external dict (if any) expected to be found under the same folder of training dataset, formatted as SHORTHAND-externaldict.txt where each line in this file is a word. For example, data/tokenize/zh_gsdsimp-externaldict.txt")
62
+
63
+ parser.add_argument('--max_grad_norm', type=float, default=1.0, help="Maximum gradient norm to clip to")
64
+ parser.add_argument('--anneal', type=float, default=.999, help="Anneal the learning rate by this amount when dev performance deteriorates")
65
+ parser.add_argument('--anneal_after', type=int, default=2000, help="Anneal the learning rate no earlier than this step")
66
+ parser.add_argument('--lr0', type=float, default=2e-3, help="Initial learning rate")
67
+ parser.add_argument('--dropout', type=float, default=0.33, help="Dropout probability")
68
+ parser.add_argument('--unit_dropout', type=float, default=0.33, help="Unit dropout probability")
69
+ parser.add_argument('--feat_dropout', type=float, default=0.05, help="Features dropout probability for each element in feature vector")
70
+ parser.add_argument('--feat_unit_dropout', type=float, default=0.33, help="The whole feature of units dropout probability")
71
+ parser.add_argument('--tok_noise', type=float, default=0.02, help="Probability of adding noise to the input of the higher RNN")
72
+ parser.add_argument('--sent_drop_prob', type=float, default=0.2, help="Probability to drop sentences at the end of batches during training uniformly at random. Idea is to fake paragraph endings.")
73
+ parser.add_argument('--last_char_drop_prob', type=float, default=0.2, help="Probability to drop the last char of a block of text during training, uniformly at random. Idea is to fake a document ending w/o sentence final punctuation, hopefully to avoid the tokenizer learning to always tokenize the last character as a period")
74
+ parser.add_argument('--weight_decay', type=float, default=0.0, help="Weight decay")
75
+ parser.add_argument('--max_seqlen', type=int, default=100, help="Maximum sequence length to consider at a time")
76
+ parser.add_argument('--batch_size', type=int, default=32, help="Batch size to use")
77
+ parser.add_argument('--epochs', type=int, default=10, help="Total epochs to train the model for")
78
+ parser.add_argument('--steps', type=int, default=50000, help="Steps to train the model for, if unspecified use epochs")
79
+ parser.add_argument('--report_steps', type=int, default=20, help="Update step interval to report loss")
80
+ parser.add_argument('--shuffle_steps', type=int, default=100, help="Step interval to shuffle each paragraph in the generator")
81
+ parser.add_argument('--eval_steps', type=int, default=200, help="Step interval to evaluate the model on the dev set for early stopping")
82
+ parser.add_argument('--max_steps_before_stop', type=int, default=5000, help='Terminate early after this many steps if the dev scores are not improving')
83
+ parser.add_argument('--save_name', type=str, default=None, help="File name to save the model")
84
+ parser.add_argument('--load_name', type=str, default=None, help="File name to load a saved model")
85
+ parser.add_argument('--save_dir', type=str, default='saved_models/tokenize', help="Directory to save models in")
86
+ utils.add_device_args(parser)
87
+ parser.add_argument('--seed', type=int, default=1234)
88
+
89
+ parser.add_argument('--use_mwt', dest='use_mwt', default=None, action='store_true', help='Whether or not to include mwt output layers. If set to None, this will be determined by examining the training data for MWTs')
90
+ parser.add_argument('--no_use_mwt', dest='use_mwt', action='store_false', help='Whether or not to include mwt output layers')
91
+
92
+ parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name')
93
+ parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. Will default to the dataset short name')
94
+ return parser
95
+
96
+ def parse_args(args=None):
97
+ parser = build_argparse()
98
+ args = parser.parse_args(args=args)
99
+
100
+ if args.wandb_name:
101
+ args.wandb = True
102
+
103
+ args = vars(args)
104
+ return args
105
+
106
+ def model_file_name(args):
107
+ if args['save_name'] is not None:
108
+ save_name = args['save_name']
109
+ else:
110
+ save_name = args['shorthand'] + "_tokenizer.pt"
111
+
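+ # if save_name is itself an existing path and save_dir/save_name is not,
+ # use save_name directly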
112
+ if not os.path.exists(os.path.join(args['save_dir'], save_name)) and os.path.exists(save_name):
113
+ return save_name
114
+ return os.path.join(args['save_dir'], save_name)
115
+
116
+ def main(args=None):
117
+ args = parse_args(args=args)
118
+
119
+ utils.set_random_seed(args['seed'])
120
+
121
+ logger.info("Running tokenizer in {} mode".format(args['mode']))
122
+
123
+ args['feat_funcs'] = ['space_before', 'capitalized', 'numeric', 'end_of_para', 'start_of_para']
124
+ args['feat_dim'] = len(args['feat_funcs'])
125
+ args['save_name'] = model_file_name(args)
126
+ utils.ensure_dir(os.path.split(args['save_name'])[0])
127
+
128
+ if args['mode'] == 'train':
129
+ train(args)
130
+ else:
131
+ evaluate(args)
132
+
133
+ def train(args):
134
+ if args['use_dictionary']:
135
+ # load lexicon
136
+ lexicon, args['num_dict_feat'] = load_lexicon(args)
137
+ # create the dictionary
138
+ dictionary = create_dictionary(lexicon)
139
+ # adjust the feat_dim
140
+ args['feat_dim'] += args['num_dict_feat']*2
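+ # num_dict_feat features each for the "look ahead" and "look backward" windows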
141
+ else:
142
+ args['num_dict_feat'] = 0
143
+ lexicon = None
144
+ dictionary = None
145
+
146
+ mwt_dict = load_mwt_dict(args['mwt_json_file'])
147
+
148
+ train_input_files = {
149
+ 'txt': args['txt_file'],
150
+ 'label': args['label_file']
151
+ }
152
+ train_batches = DataLoader(args, input_files=train_input_files, dictionary=dictionary)
153
+ vocab = train_batches.vocab
154
+
155
+ args['vocab_size'] = len(vocab)
156
+
157
+ dev_input_files = {
158
+ 'txt': args['dev_txt_file'],
159
+ 'label': args['dev_label_file']
160
+ }
161
+ dev_batches = TokenizationDataset(args, input_files=dev_input_files, vocab=vocab, evaluation=True, dictionary=dictionary)
162
+
163
+ if args['use_mwt'] is None:
164
+ args['use_mwt'] = train_batches.has_mwt()
165
+ logger.info("Found {}mwts in the training data. Setting use_mwt to {}".format(("" if args['use_mwt'] else "no "), args['use_mwt']))
166
+
167
+ trainer = Trainer(args=args, vocab=vocab, lexicon=lexicon, dictionary=dictionary, device=args['device'])
168
+
169
+ if args['load_name'] is not None:
170
+ load_name = os.path.join(args['save_dir'], args['load_name'])
171
+ trainer.load(load_name)
172
+ trainer.change_lr(args['lr0'])
173
+
174
+ N = len(train_batches)
175
+ steps = args['steps'] if args['steps'] is not None else int(N * args['epochs'] / args['batch_size'] + .5)
176
+ lr = args['lr0']
177
+
178
+ prev_dev_score = -1
179
+ best_dev_score = -1
180
+ best_dev_step = -1
181
+
182
+ if args['wandb']:
183
+ import wandb
184
+ wandb_name = args['wandb_name'] if args['wandb_name'] else "%s_tokenizer" % args['shorthand']
185
+ wandb.init(name=wandb_name, config=args)
186
+ wandb.run.define_metric('train_loss', summary='min')
187
+ wandb.run.define_metric('dev_score', summary='max')
188
+
189
+
190
+ for step in range(1, steps+1):
191
+ batch = train_batches.next(unit_dropout=args['unit_dropout'], feat_unit_dropout=args['feat_unit_dropout'])
192
+
193
+ loss = trainer.update(batch)
194
+ if step % args['report_steps'] == 0:
195
+ logger.info("Step {:6d}/{:6d} Loss: {:.3f}".format(step, steps, loss))
196
+ if args['wandb']:
197
+ wandb.log({'train_loss': loss}, step=step)
198
+
199
+ if args['shuffle_steps'] > 0 and step % args['shuffle_steps'] == 0:
200
+ train_batches.shuffle()
201
+
202
+ if step % args['eval_steps'] == 0:
203
+ dev_score = eval_model(args, trainer, dev_batches, vocab, mwt_dict)
204
+ if args['wandb']:
205
+ wandb.log({'dev_score': dev_score}, step=step)
206
+ reports = ['Dev score: {:6.3f}'.format(dev_score * 100)]
207
+ if step >= args['anneal_after'] and dev_score < prev_dev_score:
208
+ reports += ['lr: {:.6f} -> {:.6f}'.format(lr, lr * args['anneal'])]
209
+ lr *= args['anneal']
210
+ trainer.change_lr(lr)
211
+
212
+ prev_dev_score = dev_score
213
+
214
+ if dev_score > best_dev_score:
215
+ reports += ['New best dev score!']
216
+ best_dev_score = dev_score
217
+ best_dev_step = step
218
+ trainer.save(args['save_name'])
219
+ elif best_dev_step > 0 and step - best_dev_step > args['max_steps_before_stop']:
220
+ reports += ['Stopping training after {} steps with no improvement'.format(step - best_dev_step)]
221
+ logger.info('\t'.join(reports))
222
+ break
223
+
224
+ logger.info('\t'.join(reports))
225
+
226
+ if args['wandb']:
227
+ wandb.finish()
228
+
229
+ if best_dev_step > -1:
230
+ logger.info('Best dev score={} at step {}'.format(best_dev_score, best_dev_step))
231
+ else:
232
+ logger.info('Dev set never evaluated. Saving final model')
233
+ trainer.save(args['save_name'])
234
+
235
+ def evaluate(args):
236
+ mwt_dict = load_mwt_dict(args['mwt_json_file'])
237
+ trainer = Trainer(model_file=args['load_name'] or args['save_name'], device=args['device'])
238
+ loaded_args, vocab = trainer.args, trainer.vocab
239
+
240
+ for k in loaded_args:
241
+ if not k.endswith('_file') and k not in ['device', 'mode', 'save_dir', 'load_name', 'save_name']:
242
+ args[k] = loaded_args[k]
243
+
244
+ eval_input_files = {
245
+ 'txt': args['txt_file'],
246
+ 'label': args['label_file']
247
+ }
248
+
249
+
250
+ batches = TokenizationDataset(args, input_files=eval_input_files, vocab=vocab, evaluation=True, dictionary=trainer.dictionary)
251
+
252
+ oov_count, N, _, _ = output_predictions(args['conll_file'], trainer, batches, vocab, mwt_dict, args['max_seqlen'])
253
+
254
+ logger.info("OOV rate: {:6.3f}% ({:6d}/{:6d})".format(oov_count / N * 100, oov_count, N))
255
+
256
+
257
+ if __name__ == '__main__':
258
+ main()
stanza/stanza/models/wl_coref.py ADDED
@@ -0,0 +1,226 @@
1
+ """
2
+ Runs experiments with CorefModel.
3
+
4
+ Try 'python wl_coref.py -h' for more details.
5
+
6
+ Code based on
7
+
8
+ https://github.com/KarelDO/wl-coref/tree/master
9
+ https://arxiv.org/abs/2310.06165
10
+
11
+ This was a fork of
12
+
13
+ https://github.com/vdobrovolskii/wl-coref
14
+ https://aclanthology.org/2021.emnlp-main.605/
15
+
16
+ If you use Stanza's coref module in your work, please cite the following:
17
+
18
+ @misc{doosterlinck2023cawcoref,
19
+ title={CAW-coref: Conjunction-Aware Word-level Coreference Resolution},
20
+ author={Karel D'Oosterlinck and Semere Kiros Bitew and Brandon Papineau and Christopher Potts and Thomas Demeester and Chris Develder},
21
+ year={2023},
22
+ eprint={2310.06165},
23
+ archivePrefix={arXiv},
24
+ primaryClass={cs.CL},
25
+ url = "https://arxiv.org/abs/2310.06165",
26
+ }
27
+
28
+ @inproceedings{dobrovolskii-2021-word,
29
+ title = "Word-Level Coreference Resolution",
30
+ author = "Dobrovolskii, Vladimir",
31
+ booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing",
32
+ month = nov,
33
+ year = "2021",
34
+ address = "Online and Punta Cana, Dominican Republic",
35
+ publisher = "Association for Computational Linguistics",
36
+ url = "https://aclanthology.org/2021.emnlp-main.605",
37
+ pages = "7670--7675"
38
+ }
39
+ """
40
+
41
+ import argparse
42
+ from contextlib import contextmanager
43
+ import datetime
44
+ import logging
45
+ import os
46
+ import random
47
+ import sys
48
+ import dataclasses
49
+ import time
50
+
51
+
52
+ import numpy as np # type: ignore
53
+ import torch # type: ignore
54
+
55
+ from stanza.models.common.utils import set_random_seed
56
+ from stanza.models.coref.model import CorefModel
57
+
58
+
59
+ logger = logging.getLogger('stanza')
60
+
61
+ @contextmanager
62
+ def output_running_time():
63
+ """ Prints the time elapsed in the context """
64
+ start = int(time.time())
65
+ try:
66
+ yield
67
+ finally:
68
+ end = int(time.time())
69
+ delta = datetime.timedelta(seconds=end - start)
70
+ logger.info(f"Total running time: {delta}")
71
+
72
+
73
+ def deterministic() -> None:
74
+ torch.backends.cudnn.deterministic = True # type: ignore
75
+ torch.backends.cudnn.benchmark = False # type: ignore
76
+
77
+
78
+ if __name__ == "__main__":
79
+ argparser = argparse.ArgumentParser()
80
+ argparser.add_argument("mode", choices=("train", "eval"))
81
+ argparser.add_argument("experiment")
82
+ argparser.add_argument("--config-file", default="config.toml")
83
+ argparser.add_argument("--data-split", choices=("train", "dev", "test"),
84
+ default="test",
85
+ help="Data split to be used for evaluation."
86
+ " Defaults to 'test'."
87
+ " Ignored in 'train' mode.")
88
+ argparser.add_argument("--batch-size", type=int,
89
+ help="Adjust to override the config value of anaphoricity "
90
+ "batch size if you are experiencing out-of-memory "
91
+ "issues")
92
+ argparser.add_argument("--disable_singletons", action="store_true",
93
+ help="don't predict singletons")
94
+ argparser.add_argument("--full_pairwise", action="store_true",
95
+ help="use speaker and document embeddings")
96
+ argparser.add_argument("--hidden_size", type=int,
97
+ help="Adjust the anaphoricity scorer hidden size")
98
+ argparser.add_argument("--rough_k", type=int,
99
+ help="Adjust the number of dummies to keep")
100
+ argparser.add_argument("--n_hidden_layers", type=int,
101
+ help="Adjust the anaphoricity scorer hidden layers")
102
+ argparser.add_argument("--dummy_mix", type=float,
103
+ help="Adjust the dummy mix")
104
+ argparser.add_argument("--bert_finetune_begin_epoch", type=float,
105
+ help="Adjust the bert finetune begin epoch")
106
+ argparser.add_argument("--warm_start", action="store_true",
107
+ help="If set, the training will resume from the"
108
+ " last checkpoint saved if any. Ignored in"
109
+ " evaluation modes."
110
+ " Incompatible with '--weights'.")
111
+ argparser.add_argument("--weights",
112
+ help="Path to file with weights to load."
113
+ " If not supplied, in 'eval' mode the latest"
114
+ " weights of the experiment will be loaded;"
115
+ " in 'train' mode no weights will be loaded.")
116
+ argparser.add_argument("--word-level", action="store_true",
117
+ help="If set, output word-level conll-formatted"
118
+ " files in evaluation modes. Ignored in"
119
+ " 'train' mode.")
120
+ argparser.add_argument("--learning_rate", default=None, type=float,
121
+ help="If set, update the learning rate for the model")
122
+ argparser.add_argument("--bert_learning_rate", default=None, type=float,
123
+ help="If set, update the learning rate for the transformer")
124
+ argparser.add_argument("--save_dir", default=None,
125
+ help="If set, update the save directory for writing models")
126
+ argparser.add_argument("--save_name", default=None,
127
+ help="If set, update the save name for writing models (otherwise, section name)")
128
+ argparser.add_argument("--score_lang", default=None,
129
+ help="only score a particular language for eval")
130
+ argparser.add_argument("--log_norms", action="store_true", default=None,
131
+ help="If set, log all of the trainable norms each epoch. Very noisy!")
132
+ argparser.add_argument("--seed", type=int, default=2020,
133
+ help="Random seed to set")
134
+
135
+ argparser.add_argument("--train_data", default=None, help="File to use for train data")
136
+ argparser.add_argument("--dev_data", default=None, help="File to use for dev data")
137
+ argparser.add_argument("--test_data", default=None, help="File to use for test data")
138
+
139
+ argparser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name', default=False)
140
+ argparser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. Will default to the dataset short name')
141
+
142
+ args = argparser.parse_args()
143
+
144
+ if args.warm_start and args.weights is not None:
145
+ raise ValueError("The following options are incompatible: '--warm_start' and '--weights'")
146
+
147
+ set_random_seed(args.seed)
148
+ deterministic()
149
+ config = CorefModel._load_config(args.config_file, args.experiment)
150
+ if args.batch_size:
151
+ config.a_scoring_batch_size = args.batch_size
152
+ if args.hidden_size:
153
+ config.hidden_size = args.hidden_size
154
+ if args.n_hidden_layers:
155
+ config.n_hidden_layers = args.n_hidden_layers
156
+ if args.learning_rate is not None:
157
+ config.learning_rate = args.learning_rate
158
+ if args.bert_learning_rate is not None:
159
+ config.bert_learning_rate = args.bert_learning_rate
160
+ if args.bert_finetune_begin_epoch is not None:
161
+ config.bert_finetune_begin_epoch = args.bert_finetune_begin_epoch
162
+ if args.dummy_mix is not None:
163
+ config.dummy_mix = args.dummy_mix
164
+
165
+ if args.save_dir is not None:
166
+ config.save_dir = args.save_dir
167
+ if args.save_name:
168
+ config.save_name = args.save_name
169
+ else:
170
+ config.save_name = args.experiment
171
+
172
+ if args.rough_k is not None:
173
+ config.rough_k = args.rough_k
174
+ if args.log_norms is not None:
175
+ config.log_norms = args.log_norms
176
+ if args.full_pairwise:
177
+ config.full_pairwise = args.full_pairwise
178
+ if args.disable_singletons:
179
+ config.singletons = False
180
+ if args.train_data:
181
+ config.train_data = args.train_data
182
+ if args.dev_data:
183
+ config.dev_data = args.dev_data
184
+ if args.test_data:
185
+ config.test_data = args.test_data
186
+
187
+ # if wandb, generate wandb configuration
188
+ if args.mode == "train":
189
+ if args.wandb:
190
+ import wandb
191
+ wandb_name = args.wandb_name if args.wandb_name else f"wl_coref_{args.experiment}"
192
+ wandb.init(name=wandb_name, config=dataclasses.asdict(config), project="stanza")
193
+ wandb.run.define_metric('train_c_loss', summary='min')
194
+ wandb.run.define_metric('train_s_loss', summary='min')
195
+ wandb.run.define_metric('dev_score', summary='max')
196
+
197
+ model = CorefModel(config=config)
198
+ if args.weights is not None or args.warm_start:
199
+ model.load_weights(path=args.weights, map_location="cpu",
200
+ noexception=args.warm_start)
201
+ with output_running_time():
202
+ model.train(args.wandb)
203
+ else:
204
+ config_update = {
205
+ 'log_norms': args.log_norms if args.log_norms is not None else False
206
+ }
207
+ if args.test_data:
208
+ config_update['test_data'] = args.test_data
209
+
210
+ if args.weights is None and config.save_name is not None:
211
+ args.weights = config.save_name
212
+ if not os.path.exists(args.weights) and os.path.exists(args.weights + ".pt"):
213
+ args.weights = args.weights + ".pt"
214
+ elif not os.path.exists(args.weights) and config.save_dir and os.path.exists(os.path.join(config.save_dir, args.weights)):
215
+ args.weights = os.path.join(config.save_dir, args.weights)
216
+ elif not os.path.exists(args.weights) and config.save_dir and os.path.exists(os.path.join(config.save_dir, args.weights + ".pt")):
217
+ args.weights = os.path.join(config.save_dir, args.weights + ".pt")
218
+ model = CorefModel.load_model(path=args.weights, map_location="cpu",
219
+ ignore={"bert_optimizer", "general_optimizer",
220
+ "bert_scheduler", "general_scheduler"},
221
+ config_update=config_update)
222
+ results = model.evaluate(data_split=args.data_split,
223
+ word_level_conll=args.word_level,
224
+ eval_lang=args.score_lang)
225
+        # logger.info("mean loss")
226
+ print("\t".join([str(round(i, 3)) for i in results]))
stanza/stanza/pipeline/__init__.py ADDED
File without changes
stanza/stanza/pipeline/constituency_processor.py ADDED
@@ -0,0 +1,81 @@
1
+ """
2
+ Processor that attaches a constituency tree to a sentence
3
+ """
4
+
5
+ from stanza.models.constituency.trainer import Trainer
6
+
7
+ from stanza.models.common import doc
8
+ from stanza.models.common.utils import sort_with_indices, unsort
9
+ from stanza.utils.get_tqdm import get_tqdm
10
+ from stanza.pipeline._constants import *
11
+ from stanza.pipeline.processor import UDProcessor, register_processor
12
+
13
+ tqdm = get_tqdm()
14
+
15
+ @register_processor(CONSTITUENCY)
16
+ class ConstituencyProcessor(UDProcessor):
17
+ # set of processor requirements this processor fulfills
18
+ PROVIDES_DEFAULT = set([CONSTITUENCY])
19
+ # set of processor requirements for this processor
20
+ REQUIRES_DEFAULT = set([TOKENIZE, POS])
21
+
22
+ # default batch size, measured in sentences
23
+ DEFAULT_BATCH_SIZE = 50
24
+
25
+ def _set_up_requires(self):
26
+ self._pretagged = self._config.get('pretagged')
27
+ if self._pretagged:
28
+ self._requires = set()
29
+ else:
30
+ self._requires = self.__class__.REQUIRES_DEFAULT
31
+
32
+ def _set_up_model(self, config, pipeline, device):
33
+ # set up model
34
+ # pretrain and charlm paths are args from the config
35
+ # bert (if used) will be chosen from the model save file
36
+ args = {
37
+ "wordvec_pretrain_file": config.get('pretrain_path', None),
38
+ "charlm_forward_file": config.get('forward_charlm_path', None),
39
+ "charlm_backward_file": config.get('backward_charlm_path', None),
40
+ "device": device,
41
+ }
42
+ trainer = Trainer.load(filename=config['model_path'],
43
+ args=args,
44
+ foundation_cache=pipeline.foundation_cache)
45
+ self._trainer = trainer
46
+ self._model = trainer.model
47
+ self._model.eval()
48
+ # batch size counted as sentences
49
+ self._batch_size = int(config.get('batch_size', ConstituencyProcessor.DEFAULT_BATCH_SIZE))
50
+ self._tqdm = 'tqdm' in config and config['tqdm']
51
+
52
+ def _set_up_final_config(self, config):
53
+ loaded_args = self._model.args
54
+ loaded_args = {k: v for k, v in loaded_args.items() if not UDProcessor.filter_out_option(k)}
55
+ loaded_args.update(config)
56
+ self._config = loaded_args
57
+
58
+ def process(self, document):
59
+ sentences = document.sentences
60
+
61
+ if self._model.uses_xpos():
62
+ words = [[(w.text, w.xpos) for w in s.words] for s in sentences]
63
+ else:
64
+ words = [[(w.text, w.upos) for w in s.words] for s in sentences]
65
+ words, original_indices = sort_with_indices(words, key=len, reverse=True)
66
+ if self._tqdm:
67
+ words = tqdm(words)
68
+
69
+ trees = self._model.parse_tagged_words(words, self._batch_size)
70
+ trees = unsort(trees, original_indices)
71
+ document.set(CONSTITUENCY, trees, to_sentence=True)
72
+ return document
73
+
74
+ def get_constituents(self):
75
+ """
76
+ Return a set of the constituents known by this model
77
+
78
+ For a pipeline, this can be queried with
79
+ pipeline.processors["constituency"].get_constituents()
80
+ """
81
+ return set(self._model.constituents)
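A minimal usage sketch (assuming the default English models are downloaded); the last line is the query named in the docstring above:

import stanza

nlp = stanza.Pipeline("en", processors="tokenize,pos,constituency")
doc = nlp("This is a short test sentence.")
print(doc.sentences[0].constituency)                       # bracketed parse tree
print(nlp.processors["constituency"].get_constituents())   # e.g. {'NP', 'VP', ...}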
stanza/stanza/pipeline/core.py ADDED
@@ -0,0 +1,509 @@
1
+ """
2
+ Pipeline that runs tokenize,mwt,pos,lemma,depparse
3
+ """
4
+
5
+ import argparse
6
+ import collections
7
+ from enum import Enum
8
+ import io
9
+ import itertools
10
+ import sys
11
+ import logging
12
+ import json
13
+ import os
14
+
15
+ from stanza.pipeline._constants import *
16
+ from stanza.models.common.constant import langcode_to_lang
17
+ from stanza.models.common.doc import Document
18
+ from stanza.models.common.foundation_cache import FoundationCache
19
+ from stanza.models.common.utils import default_device
20
+ from stanza.pipeline.processor import Processor, ProcessorRequirementsException
21
+ from stanza.pipeline.registry import NAME_TO_PROCESSOR_CLASS, PIPELINE_NAMES, PROCESSOR_VARIANTS
22
+ from stanza.pipeline.langid_processor import LangIDProcessor
23
+ from stanza.pipeline.tokenize_processor import TokenizeProcessor
24
+ from stanza.pipeline.mwt_processor import MWTProcessor
25
+ from stanza.pipeline.pos_processor import POSProcessor
26
+ from stanza.pipeline.lemma_processor import LemmaProcessor
27
+ from stanza.pipeline.constituency_processor import ConstituencyProcessor
28
+ from stanza.pipeline.coref_processor import CorefProcessor
29
+ from stanza.pipeline.depparse_processor import DepparseProcessor
30
+ from stanza.pipeline.sentiment_processor import SentimentProcessor
31
+ from stanza.pipeline.ner_processor import NERProcessor
32
+ from stanza.resources.common import DEFAULT_MODEL_DIR, DEFAULT_RESOURCES_URL, DEFAULT_RESOURCES_VERSION, ModelSpecification, add_dependencies, add_mwt, download_models, download_resources_json, flatten_processor_list, load_resources_json, maintain_processor_list, process_pipeline_parameters, set_logging_level, sort_processors
33
+ from stanza.resources.default_packages import PACKAGES
34
+ from stanza.utils.conll import CoNLL, CoNLLError
35
+ from stanza.utils.helper_func import make_table
36
+
37
+ logger = logging.getLogger('stanza')
38
+
39
+ class DownloadMethod(Enum):
40
+ """
41
+ Determines a couple options on how to download resources for the pipeline.
42
+
43
+ NONE will not download anything, including HF transformers, probably resulting in failure if the resources aren't already in place.
44
+ REUSE_RESOURCES will reuse the existing resources.json and models, but will download any missing models.
45
+ DOWNLOAD_RESOURCES will download a new resources.json and will overwrite any out of date models.
46
+ """
47
+ NONE = 1
48
+ REUSE_RESOURCES = 2
49
+ DOWNLOAD_RESOURCES = 3
50
+
51
+ class LanguageNotDownloadedError(FileNotFoundError):
52
+ def __init__(self, lang, lang_dir, model_path):
53
+ super().__init__(f'Could not find the model file {model_path}. The expected model directory {lang_dir} is missing. Perhaps you need to run stanza.download("{lang}")')
54
+ self.lang = lang
55
+ self.lang_dir = lang_dir
56
+ self.model_path = model_path
57
+
58
+ class UnsupportedProcessorError(FileNotFoundError):
59
+ def __init__(self, processor, lang):
60
+ super().__init__(f'Processor {processor} is not known for language {lang}. If you have created your own model, please specify the {processor}_model_path parameter when creating the pipeline.')
61
+ self.processor = processor
62
+ self.lang = lang
63
+
64
+ class IllegalPackageError(ValueError):
65
+ def __init__(self, msg):
66
+ super().__init__(msg)
67
+
68
+ class PipelineRequirementsException(Exception):
69
+ """
70
+ Exception indicating one or more requirements failures while attempting to build a pipeline.
71
+ Contains a ProcessorRequirementsException list.
72
+ """
73
+
74
+ def __init__(self, processor_req_fails):
75
+ self._processor_req_fails = processor_req_fails
76
+ self.build_message()
77
+
78
+ @property
79
+ def processor_req_fails(self):
80
+ return self._processor_req_fails
81
+
82
+ def build_message(self):
83
+ err_msg = io.StringIO()
84
+ print(*[req_fail.message for req_fail in self.processor_req_fails], sep='\n', file=err_msg)
85
+ self.message = '\n\n' + err_msg.getvalue()
86
+
87
+ def __str__(self):
88
+ return self.message
89
+
90
+ def build_default_config_option(model_specs):
91
+ """
92
+ Build a config option for a couple situations: lemma=identity, processor is a variant
93
+
94
+ Returns the option name and value
95
+
96
+ Refactored from build_default_config so that we can reuse it when
97
+ downloading all models
98
+ """
99
+ # handle case when processor variants are used
100
+ if any(model_spec.package in PROCESSOR_VARIANTS[model_spec.processor] for model_spec in model_specs):
101
+ if len(model_specs) > 1:
102
+ raise IllegalPackageError("Variant processor selected for {}, but multiple packages requested".format(model_spec.processor))
103
+ return f"{model_specs[0].processor}_with_{model_specs[0].package}", True
104
+ # handle case when identity is specified as lemmatizer
105
+ elif any(model_spec.processor == LEMMA and model_spec.package == 'identity' for model_spec in model_specs):
106
+ if len(model_specs) > 1:
107
+ raise IllegalPackageError("Identity processor selected for lemma, but multiple packages requested")
108
+ return f"{LEMMA}_use_identity", True
109
+ return None
110
+
111
+ def filter_variants(model_specs):
112
+ return [(key, value) for (key, value) in model_specs if build_default_config_option(value) is None]
113
+
114
+ # given a language and models path, build a default configuration
115
+ def build_default_config(resources, lang, model_dir, load_list):
116
+ default_config = {}
117
+ for processor, model_specs in load_list:
118
+ option = build_default_config_option(model_specs)
119
+ if option is not None:
120
+ # if an option is set for the model_specs, keep that option and ignore
121
+ # the rest of the model spec
122
+ default_config[option[0]] = option[1]
123
+ continue
124
+
125
+ model_paths = [os.path.join(model_dir, lang, processor, model_spec.package + '.pt') for model_spec in model_specs]
126
+ dependencies = [model_spec.dependencies for model_spec in model_specs]
127
+
128
+ # Special case for NER: load multiple models at once
129
+ # The pattern will be:
130
+ # a list of ner_model_path
131
+ # a list of ner_dependencies
132
+ # where each item in ner_dependencies is a map
133
+ # the map may contain forward_charlm_path, backward_charlm_path, or any other deps
134
+ # The user will be able to override the defaults using a semicolon separated string
135
+ # TODO: at least use the same config pattern for all other models
136
+ if processor == NER:
137
+ default_config[f"{processor}_model_path"] = model_paths
138
+ dependency_paths = []
139
+ for dependency_block in dependencies:
140
+ if not dependency_block:
141
+ dependency_paths.append({})
142
+ continue
143
+ dependency_paths.append({})
144
+ for dependency in dependency_block:
145
+ dep_processor, dep_model = dependency
146
+ dependency_paths[-1][f"{dep_processor}_path"] = os.path.join(model_dir, lang, dep_processor, dep_model + '.pt')
147
+ default_config[f"{processor}_dependencies"] = dependency_paths
148
+ continue
149
+
150
+ if len(model_specs) > 1:
151
+ raise IllegalPackageError("Specified multiple packages for {}, which currently only handles one package".format(processor))
152
+
153
+ default_config[f"{processor}_model_path"] = model_paths[0]
154
+ if not dependencies[0]: continue
155
+ for dependency in dependencies[0]:
156
+ dep_processor, dep_model = dependency
157
+ default_config[f"{processor}_{dep_processor}_path"] = os.path.join(
158
+ model_dir, lang, dep_processor, dep_model + '.pt'
159
+ )
160
+
161
+ return default_config
162
+
163
+ def normalize_download_method(download_method):
164
+ """
165
+ Turn None -> DownloadMethod.NONE, strings to the corresponding enum
166
+ """
167
+ if download_method is None:
168
+ return DownloadMethod.NONE
169
+ elif isinstance(download_method, str):
170
+ try:
171
+ return DownloadMethod[download_method.upper()]
172
+ except KeyError as e:
173
+ raise ValueError("Unknown download method %s" % download_method) from e
174
+ return download_method
175
+
176
+ class Pipeline:
177
+
178
+ def __init__(self,
179
+ lang='en',
180
+ dir=DEFAULT_MODEL_DIR,
181
+ package='default',
182
+ processors={},
183
+ logging_level=None,
184
+ verbose=None,
185
+ use_gpu=None,
186
+ model_dir=None,
187
+ download_method=DownloadMethod.DOWNLOAD_RESOURCES,
188
+ resources_url=DEFAULT_RESOURCES_URL,
189
+ resources_branch=None,
190
+ resources_version=DEFAULT_RESOURCES_VERSION,
191
+ resources_filepath=None,
192
+ proxies=None,
193
+ foundation_cache=None,
194
+ device=None,
195
+ allow_unknown_language=False,
196
+ **kwargs):
197
+ self.lang, self.dir, self.kwargs = lang, dir, kwargs
198
+ if model_dir is not None and dir == DEFAULT_MODEL_DIR:
199
+ self.dir = model_dir
200
+
201
+ # set global logging level
202
+ set_logging_level(logging_level, verbose)
203
+
204
+ self.download_method = normalize_download_method(download_method)
205
+ if (self.download_method is DownloadMethod.DOWNLOAD_RESOURCES or
206
+ (self.download_method is DownloadMethod.REUSE_RESOURCES and not os.path.exists(os.path.join(self.dir, "resources.json")))):
207
+ logger.info("Checking for updates to resources.json in case models have been updated. Note: this behavior can be turned off with download_method=None or download_method=DownloadMethod.REUSE_RESOURCES")
208
+ download_resources_json(self.dir,
209
+ resources_url=resources_url,
210
+ resources_branch=resources_branch,
211
+ resources_version=resources_version,
212
+ resources_filepath=resources_filepath,
213
+ proxies=proxies)
214
+
215
+ # processors can use this to save on the effort of loading
216
+ # large sub-models, such as pretrained embeddings, bert, etc
217
+ if foundation_cache is None:
218
+ self.foundation_cache = FoundationCache(local_files_only=(self.download_method is DownloadMethod.NONE))
219
+ else:
220
+ self.foundation_cache = FoundationCache(foundation_cache, local_files_only=(self.download_method is DownloadMethod.NONE))
221
+
222
+ # process different pipeline parameters
223
+ lang, self.dir, package, processors = process_pipeline_parameters(lang, self.dir, package, processors)
224
+
225
+ # Load resources.json to obtain latest packages.
226
+ logger.debug('Loading resource file...')
227
+ resources = load_resources_json(self.dir, resources_filepath)
228
+ if lang in resources:
229
+ if 'alias' in resources[lang]:
230
+ logger.info(f'"{lang}" is an alias for "{resources[lang]["alias"]}"')
231
+ lang = resources[lang]['alias']
232
+ lang_name = resources[lang]['lang_name'] if 'lang_name' in resources[lang] else ''
233
+ elif allow_unknown_language:
234
+ logger.warning("Trying to create pipeline for unsupported language: %s", lang)
235
+ lang_name = langcode_to_lang(lang)
236
+ else:
237
+ logger.warning("Unsupported language: %s If trying to add a new language, consider using allow_unknown_language=True", lang)
238
+ lang_name = langcode_to_lang(lang)
239
+
240
+ # Maintain load list
241
+ if lang in resources:
242
+ self.load_list = maintain_processor_list(resources, lang, package, processors, maybe_add_mwt=(not kwargs.get("tokenize_pretokenized")))
243
+ self.load_list = add_dependencies(resources, lang, self.load_list)
244
+ if self.download_method is not DownloadMethod.NONE:
245
+ # skip processors which aren't downloaded from our collection
246
+ download_list = [x for x in self.load_list if x[0] in resources.get(lang, {})]
247
+ # skip variants
248
+ download_list = filter_variants(download_list)
249
+ # gather up the model list...
250
+ download_list = flatten_processor_list(download_list)
251
+ # download_models will skip models we already have
252
+ download_models(download_list,
253
+ resources=resources,
254
+ lang=lang,
255
+ model_dir=self.dir,
256
+ resources_version=resources_version,
257
+ proxies=proxies,
258
+ log_info=False)
259
+ elif allow_unknown_language:
260
+ self.load_list = [(proc, [ModelSpecification(processor=proc, package='default', dependencies=None)])
261
+ for proc in list(processors.keys())]
262
+ else:
263
+ self.load_list = []
264
+ self.load_list = self.update_kwargs(kwargs, self.load_list)
265
+ if len(self.load_list) == 0:
266
+ if lang not in resources or PACKAGES not in resources[lang]:
267
+ raise ValueError(f'No processors to load for language {lang}. Language {lang} is currently unsupported')
268
+ else:
269
+ raise ValueError('No processors to load for language {}. Please check if your language or package is correctly set.'.format(lang))
270
+ load_table = make_table(['Processor', 'Package'], [(row[0], ";".join(model_spec.package for model_spec in row[1])) for row in self.load_list])
271
+ logger.info(f'Loading these models for language: {lang} ({lang_name}):\n{load_table}')
272
+
273
+ self.config = build_default_config(resources, lang, self.dir, self.load_list)
274
+ self.config.update(kwargs)
275
+
276
+ # Load processors
277
+ self.processors = {}
278
+
279
+ # configs that are the same for all processors
280
+ pipeline_level_configs = {'lang': lang, 'mode': 'predict'}
281
+
282
+ if device is None:
283
+ if use_gpu is None or use_gpu == True:
284
+ device = default_device()
285
+ else:
286
+ device = 'cpu'
287
+ if use_gpu == True and device == 'cpu':
288
+ logger.warning("GPU requested, but is not available!")
289
+ self.device = device
290
+ logger.info("Using device: {}".format(self.device))
291
+
292
+ # set up processors
293
+ pipeline_reqs_exceptions = []
294
+ for item in self.load_list:
295
+ processor_name, _ = item
296
+ logger.info('Loading: ' + processor_name)
297
+ curr_processor_config = self.filter_config(processor_name, self.config)
298
+ curr_processor_config.update(pipeline_level_configs)
299
+ # TODO: this is obviously a hack
300
+ # a better solution overall would be to make a pretagged version of the pos annotator
301
+ # and then subsequent modules can use those tags without knowing where those tags came from
302
+ if "pretagged" in self.config and "pretagged" not in curr_processor_config:
303
+ curr_processor_config["pretagged"] = self.config["pretagged"]
304
+ logger.debug('With settings: ')
305
+ logger.debug(curr_processor_config)
306
+ try:
307
+ # try to build processor, throw an exception if there is a requirements issue
308
+ self.processors[processor_name] = NAME_TO_PROCESSOR_CLASS[processor_name](config=curr_processor_config,
309
+ pipeline=self,
310
+ device=self.device)
311
+ except ProcessorRequirementsException as e:
312
+ # if there was a requirements issue, add it to list which will be printed at end
313
+ pipeline_reqs_exceptions.append(e)
314
+ # add the broken processor to the loaded processors for the sake of analyzing the validity of the
315
+ # entire proposed pipeline, but at this point the pipeline will not be built successfully
316
+ self.processors[processor_name] = e.err_processor
317
+ except FileNotFoundError as e:
318
+ # For a FileNotFoundError, we try to guess if there's
319
+ # a missing model directory or file. If so, we
320
+ # suggest the user try to download the models
321
+ if 'model_path' in curr_processor_config:
322
+ model_path = curr_processor_config['model_path']
323
+ if e.filename == model_path or (isinstance(model_path, (tuple, list)) and e.filename in model_path):
324
+ model_path = e.filename
325
+ model_dir, model_name = os.path.split(model_path)
326
+ lang_dir = os.path.dirname(model_dir)
327
+ if lang_dir and not os.path.exists(lang_dir):
328
+ # model files for this language can't be found in the expected directory
329
+ raise LanguageNotDownloadedError(lang, lang_dir, model_path) from e
330
+ if processor_name not in resources[lang]:
331
+ # user asked for a model which doesn't exist for this language?
332
+ raise UnsupportedProcessorError(processor_name, lang) from e
333
+ if not os.path.exists(model_path):
334
+ model_name, _ = os.path.splitext(model_name)
335
+ # TODO: before recommending this, check that such a thing exists in resources.json.
336
+ # currently that case is handled by ignoring the model, anyway
337
+ raise FileNotFoundError('Could not find model file %s, although there are other models downloaded for language %s. Perhaps you need to download a specific model. Try: stanza.download(lang="%s",package=None,processors={"%s":"%s"})' % (model_path, lang, lang, processor_name, model_name)) from e
338
+
339
+ # if we couldn't find a more suitable description of the
340
+ # FileNotFoundError, just raise the old error
341
+ raise
342
+
343
+ # if there are any processor exceptions, throw an exception to indicate pipeline build failure
344
+ if pipeline_reqs_exceptions:
345
+ logger.info('\n')
346
+ raise PipelineRequirementsException(pipeline_reqs_exceptions)
347
+
348
+ logger.info("Done loading processors!")
349
+
350
+ @staticmethod
351
+ def update_kwargs(kwargs, processor_list):
352
+ processor_dict = {processor: [{'package': model_spec.package, 'dependencies': model_spec.dependencies} for model_spec in model_specs]
353
+ for (processor, model_specs) in processor_list}
354
+ for key, value in kwargs.items():
355
+ pieces = key.split('_', 1)
356
+ if len(pieces) == 1:
357
+ continue
358
+ k, v = pieces
359
+ if v == 'model_path':
360
+                package = value if len(value) < 25 else value[:10] + '...' + value[-10:]
361
+ original_spec = processor_dict.get(k, [])
362
+ if len(original_spec) > 0:
363
+ dependencies = original_spec[0].get('dependencies')
364
+ else:
365
+ dependencies = None
366
+ processor_dict[k] = [{'package': package, 'dependencies': dependencies}]
367
+ processor_list = [(processor, [ModelSpecification(processor=processor, package=model_spec['package'], dependencies=model_spec['dependencies']) for model_spec in processor_dict[processor]]) for processor in processor_dict]
368
+ processor_list = sort_processors(processor_list)
369
+ return processor_list
370
+
371
+ @staticmethod
372
+ def filter_config(prefix, config_dict):
373
+ filtered_dict = {}
374
+ for key in config_dict.keys():
375
+ pieces = key.split('_', 1) # split tokenize_pretokenize to tokenize+pretokenize
376
+ if len(pieces) == 1:
377
+ continue
378
+ k, v = pieces
379
+ if k == prefix:
380
+ filtered_dict[v] = config_dict[key]
381
+ return filtered_dict
382
+
383
+ @property
384
+ def loaded_processors(self):
385
+ """
386
+ Return all currently loaded processors in execution order.
387
+ :return: list of Processor instances
388
+ """
389
+ return [self.processors[processor_name] for processor_name in PIPELINE_NAMES if self.processors.get(processor_name)]
390
+
391
+ def process(self, doc, processors=None):
392
+ """
393
+ Run the pipeline
394
+
395
+ processors: allow for a list of processors used by this pipeline action
396
+ can be list, tuple, set, or comma separated string
397
+ if None, use all the processors this pipeline knows about
398
+ MWT is added if necessary
399
+ otherwise, no care is taken to make sure prerequisites are followed...
400
+ some of the annotators, such as depparse, will check, but others
401
+ will fail in some unusual manner or just have really bad results
402
+ """
403
+ assert any([isinstance(doc, str), isinstance(doc, list),
404
+ isinstance(doc, Document)]), 'input should be either str, list or Document'
405
+
406
+ # empty bulk process
407
+ if isinstance(doc, list) and len(doc) == 0:
408
+ return []
409
+
410
+ # determine whether we are in bulk processing mode for multiple documents
411
+ bulk=(isinstance(doc, list) and len(doc) > 0 and isinstance(doc[0], Document))
412
+
413
+ # various options to limit the processors used by this pipeline action
414
+ if processors is None:
415
+ processors = PIPELINE_NAMES
416
+ elif not isinstance(processors, (str, list, tuple, set)):
417
+ raise ValueError("Cannot process {} as a list of processors to run".format(type(processors)))
418
+ else:
419
+ if isinstance(processors, str):
420
+ processors = {x for x in processors.split(",")}
421
+ else:
422
+ processors = set(processors)
423
+ if TOKENIZE in processors and MWT in self.processors and MWT not in processors:
424
+ logger.debug("Requested processors for pipeline did not have mwt, but pipeline needs mwt, so mwt is added")
425
+ processors.add(MWT)
426
+ processors = [x for x in PIPELINE_NAMES if x in processors]
427
+
428
+ for processor_name in processors:
429
+ if self.processors.get(processor_name):
430
+ process = self.processors[processor_name].bulk_process if bulk else self.processors[processor_name].process
431
+ doc = process(doc)
432
+ return doc
433
+
434
+ def bulk_process(self, docs, *args, **kwargs):
435
+ """
436
+ Run the pipeline in bulk processing mode
437
+
438
+ Expects a list of str or a list of Docs
439
+ """
440
+ # Wrap each text as a Document unless it is already such a document
441
+ docs = [doc if isinstance(doc, Document) else Document([], text=doc) for doc in docs]
442
+ return self.process(docs, *args, **kwargs)
443
+
444
+ def stream(self, docs, batch_size=50, *args, **kwargs):
445
+ """
446
+ Go through an iterator of documents in batches, yield processed documents
447
+
448
+ sentence indices will be counted across the entire iterator
449
+ """
450
+ if not isinstance(docs, collections.abc.Iterator):
451
+ docs = iter(docs)
452
+ def next_batch():
453
+ batch = []
454
+ for _ in range(batch_size):
455
+ try:
456
+ next_doc = next(docs)
457
+ batch.append(next_doc)
458
+ except StopIteration:
459
+ return batch
460
+ return batch
461
+
462
+ sentence_start_index = 0
463
+ batch = next_batch()
464
+ while batch:
465
+ batch = self.bulk_process(batch, *args, **kwargs)
466
+ for doc in batch:
467
+ doc.reindex_sentences(sentence_start_index)
468
+ sentence_start_index += len(doc.sentences)
469
+ yield doc
470
+ batch = next_batch()
471
+
472
+ def __str__(self):
473
+ """
474
+ Assemble the processors in order to make a simple description of the pipeline
475
+ """
476
+ processors = ["%s=%s" % (x, str(self.processors[x])) for x in PIPELINE_NAMES if x in self.processors]
477
+ return "<Pipeline: %s>" % ", ".join(processors)
478
+
479
+ def __call__(self, doc, processors=None):
480
+ return self.process(doc, processors)
481
+
482
+ def main():
483
+ # TODO: can add a bunch more arguments
484
+ parser = argparse.ArgumentParser()
485
+ parser.add_argument('--lang', type=str, default='en', help='Language of the pipeline to use')
486
+ parser.add_argument('--input_file', type=str, required=True, help='Input file to read')
487
+ parser.add_argument('--processors', type=str, default='tokenize,pos,lemma,depparse', help='Processors to use')
488
+ args, extra_args = parser.parse_known_args()
489
+
490
+ try:
491
+ doc = CoNLL.conll2doc(args.input_file)
492
+ extra_args = {
493
+ "tokenize_pretokenized": True
494
+ }
495
+ except CoNLLError:
496
+ logger.debug("Input file %s does not appear to be a conllu file. Will read it as a text file")
497
+ with open(args.input_file, encoding="utf-8") as fin:
498
+ doc = fin.read()
499
+ extra_args = {}
500
+
501
+ pipe = Pipeline(args.lang, processors=args.processors, **extra_args)
502
+
503
+ doc = pipe(doc)
504
+
505
+ print("{:C}".format(doc))
506
+
507
+
508
+ if __name__ == '__main__':
509
+ main()
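A usage sketch tying the options above together (English models assumed to be available locally):

import stanza
from stanza.pipeline.core import DownloadMethod

# reuse the existing resources.json and models rather than checking for updates
nlp = stanza.Pipeline("en", processors="tokenize,pos,lemma,depparse",
                      download_method=DownloadMethod.REUSE_RESOURCES)

# run only a subset of the loaded processors for this one call;
# mwt would be added automatically if this pipeline needed it
doc = nlp.process("Stanza groups processors into a pipeline.", processors="tokenize,pos")

# stream an iterator of texts in batches; sentence indices
# are counted across the entire iterator
for processed in nlp.stream(["First document.", "Second document."], batch_size=50):
    print(processed.sentences[0].text)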
stanza/stanza/pipeline/coref_processor.py ADDED
@@ -0,0 +1,154 @@
1
+ """
2
+ Processor that attaches coref annotations to a document
3
+ """
4
+
5
+ from stanza.models.common.utils import misc_to_space_after
6
+ from stanza.models.coref.coref_chain import CorefMention, CorefChain
7
+
8
+ from stanza.pipeline._constants import *
9
+ from stanza.pipeline.processor import UDProcessor, register_processor
10
+
11
+ def extract_text(document, sent_id, start_word, end_word):
12
+ sentence = document.sentences[sent_id]
13
+ tokens = []
14
+
15
+ # the coref model indexes the words from 0,
16
+ # whereas the ids we are looking at on the tokens start from 1
17
+ # here we will switch to ID space
18
+ start_word = start_word + 1
19
+ end_word = end_word + 1
20
+
21
+ # For each position between start and end word:
22
+ # If a word is part of an MWT, and the entire token
23
+ # is inside the range, we use that Token's text for that span
24
+ # This will let us easily handle words which are split into pieces
25
+ # Otherwise, we only take the text of the word itself
26
+ next_idx = start_word
27
+ while next_idx < end_word:
28
+ word = sentence.words[next_idx-1]
29
+ parent_token = word.parent
30
+ if isinstance(parent_token.id, int) or len(parent_token.id) == 1:
31
+ tokens.append(parent_token)
32
+ next_idx += 1
33
+ elif parent_token.id[0] >= start_word and parent_token.id[1] < end_word:
34
+ tokens.append(parent_token)
35
+ next_idx = parent_token.id[1] + 1
36
+ else:
37
+ tokens.append(word)
38
+ next_idx += 1
39
+
40
+ # We use the SpaceAfter or SpacesAfter attribute on each Word or Token
41
+ # we chose in the above loop to separate the text pieces
42
+ text = []
43
+ for token in tokens:
44
+ text.append(token.text)
45
+ text.append(misc_to_space_after(token.misc))
46
+ # the last token space_after will be discarded
47
+ # so that we don't have stray WS at the end of the mention text
48
+ text = text[:-1]
49
+ return "".join(text)
50
+
51
+
52
+ @register_processor(COREF)
53
+ class CorefProcessor(UDProcessor):
54
+ # set of processor requirements this processor fulfills
55
+ PROVIDES_DEFAULT = set([COREF])
56
+ # set of processor requirements for this processor
57
+ REQUIRES_DEFAULT = set([TOKENIZE])
58
+
59
+ def _set_up_model(self, config, pipeline, device):
60
+ try:
61
+ from stanza.models.coref.model import CorefModel
62
+ except ImportError:
63
+ raise ImportError("Please install the transformers and peft libraries before using coref! Try `pip install -e .[transformers]`.")
64
+
65
+ # set up model
66
+ # currently, the model has everything packaged in it
67
+ # (except its config)
68
+ # TODO: separate any pretrains if possible
69
+ # TODO: add device parameter to the load mechanism
70
+ config_update = {'log_norms': False,
71
+ 'device': device}
72
+ model = CorefModel.load_model(path=config['model_path'],
73
+ ignore={"bert_optimizer", "general_optimizer",
74
+ "bert_scheduler", "general_scheduler"},
75
+ config_update=config_update,
76
+ foundation_cache=pipeline.foundation_cache)
77
+ if config.get('batch_size', None):
78
+ model.config.a_scoring_batch_size = int(config['batch_size'])
79
+ model.training = False
80
+
81
+ self._model = model
82
+
83
+ def process(self, document):
84
+ sentences = document.sentences
85
+
86
+ cased_words = []
87
+ sent_ids = []
88
+ word_pos = []
89
+ for sent_idx, sentence in enumerate(sentences):
90
+ for word_idx, word in enumerate(sentence.words):
91
+ cased_words.append(word.text)
92
+ sent_ids.append(sent_idx)
93
+ word_pos.append(word_idx)
94
+
95
+ coref_input = {
96
+ "document_id": "wb_doc_1",
97
+ "cased_words": cased_words,
98
+ "sent_id": sent_ids
99
+ }
100
+ coref_input = self._model.build_doc(coref_input)
101
+ results = self._model.run(coref_input)
102
+ clusters = []
103
+ for span_cluster in results.span_clusters:
104
+ if len(span_cluster) == 0:
105
+ continue
106
+ span_cluster = sorted(span_cluster)
107
+
108
+ for span in span_cluster:
109
+ # check there are no sentence crossings before
110
+ # manipulating the spans, since we will expect it to
111
+ # be this way for multiple usages of the spans
112
+ sent_id = sent_ids[span[0]]
113
+ if sent_ids[span[1]-1] != sent_id:
114
+ raise ValueError("The coref model predicted a span that crossed two sentences! Please send this example to us on our github")
115
+
116
+ # treat the longest span as the representative
117
+ # break ties using the first one
118
+ # IF there is the POS processor, and it adds upos tags
119
+ # to the sentence, ties are broken first by maximum
120
+ # number of UPOS and then earliest in the document
121
+ max_len = 0
122
+ best_span = None
123
+ max_propn = 0
124
+ for span_idx, span in enumerate(span_cluster):
125
+ sent_id = sent_ids[span[0]]
126
+ sentence = sentences[sent_id]
127
+ start_word = word_pos[span[0]]
128
+ # fiddle -1 / +1 so as to avoid problems with coref
129
+ # clusters that end at exactly the end of a document
130
+ end_word = word_pos[span[1]-1] + 1
131
+ # very UD specific test for most number of proper nouns in a mention
132
+ # will do nothing if POS is not active (they will all be None)
133
+ num_propn = sum(word.pos == 'PROPN' for word in sentence.words[start_word:end_word])
134
+
135
+ if ((span[1] - span[0] > max_len) or
136
+                    (span[1] - span[0] == max_len and num_propn > max_propn)):
137
+ max_len = span[1] - span[0]
138
+ best_span = span_idx
139
+ max_propn = num_propn
140
+
141
+ mentions = []
142
+ for span in span_cluster:
143
+ sent_id = sent_ids[span[0]]
144
+ start_word = word_pos[span[0]]
145
+ end_word = word_pos[span[1]-1] + 1
146
+ mentions.append(CorefMention(sent_id, start_word, end_word))
147
+ representative = mentions[best_span]
148
+ representative_text = extract_text(document, representative.sentence, representative.start_word, representative.end_word)
149
+
150
+ chain = CorefChain(len(clusters), mentions, representative_text, best_span)
151
+ clusters.append(chain)
152
+
153
+ document.coref = clusters
154
+ return document
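A sketch of reading the chains this processor attaches (English coref package assumed, and the transformers/peft extras noted in `_set_up_model` must be installed); the `mentions` attribute name is assumed from the constructor call above:

import stanza

nlp = stanza.Pipeline("en", processors="tokenize,coref")
doc = nlp("Barack Obama was born in Hawaii. He was elected president in 2008.")
for chain in doc.coref:
    print(chain.representative_text)
    for mention in chain.mentions:
        print("  sentence", mention.sentence, "words", mention.start_word, "-", mention.end_word)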
stanza/stanza/pipeline/depparse_processor.py ADDED
@@ -0,0 +1,78 @@
1
+ """
2
+ Processor for performing dependency parsing
3
+ """
4
+
5
+ import torch
6
+
7
+ from stanza.models.common import doc
8
+ from stanza.models.common.utils import unsort
9
+ from stanza.models.common.vocab import VOCAB_PREFIX
10
+ from stanza.models.depparse.data import DataLoader
11
+ from stanza.models.depparse.trainer import Trainer
12
+ from stanza.pipeline._constants import *
13
+ from stanza.pipeline.processor import UDProcessor, register_processor
14
+
15
+ # these imports trigger the "register_variant" decorations
16
+ from stanza.pipeline.external.corenlp_converter_depparse import ConverterDepparse
17
+
18
+ DEFAULT_SEPARATE_BATCH=150
19
+
20
+ @register_processor(name=DEPPARSE)
21
+ class DepparseProcessor(UDProcessor):
22
+
23
+ # set of processor requirements this processor fulfills
24
+ PROVIDES_DEFAULT = set([DEPPARSE])
25
+ # set of processor requirements for this processor
26
+ REQUIRES_DEFAULT = set([TOKENIZE, POS, LEMMA])
27
+
28
+ def __init__(self, config, pipeline, device):
29
+ self._pretagged = None
30
+ super().__init__(config, pipeline, device)
31
+
32
+ def _set_up_requires(self):
33
+ self._pretagged = self._config.get('pretagged')
34
+ if self._pretagged:
35
+ self._requires = set()
36
+ else:
37
+ self._requires = self.__class__.REQUIRES_DEFAULT
38
+
39
+ def _set_up_model(self, config, pipeline, device):
40
+ self._pretrain = pipeline.foundation_cache.load_pretrain(config['pretrain_path']) if 'pretrain_path' in config else None
41
+ args = {'charlm_forward_file': config.get('forward_charlm_path', None),
42
+ 'charlm_backward_file': config.get('backward_charlm_path', None)}
43
+ self._trainer = Trainer(args=args, pretrain=self.pretrain, model_file=config['model_path'], device=device, foundation_cache=pipeline.foundation_cache)
44
+
45
+ def get_known_relations(self):
46
+ """
47
+ Return a list of relations which this processor can produce
48
+ """
49
+ keys = [k for k in self.vocab['deprel']._unit2id.keys() if k not in VOCAB_PREFIX]
50
+ return keys
51
+
52
+ def process(self, document):
53
+ if hasattr(self, '_variant'):
54
+ return self._variant.process(document)
55
+
56
+ if any(word.upos is None and word.xpos is None for sentence in document.sentences for word in sentence.words):
57
+ raise ValueError("POS not run before depparse!")
58
+ try:
59
+ batch = DataLoader(document, self.config['batch_size'], self.config, self.pretrain, vocab=self.vocab, evaluation=True,
60
+ sort_during_eval=self.config.get('sort_during_eval', True),
61
+ min_length_to_batch_separately=self.config.get('min_length_to_batch_separately', DEFAULT_SEPARATE_BATCH))
62
+ with torch.no_grad():
63
+ preds = []
64
+ for i, b in enumerate(batch):
65
+ preds += self.trainer.predict(b)
66
+ if batch.data_orig_idx is not None:
67
+ preds = unsort(preds, batch.data_orig_idx)
68
+ batch.doc.set((doc.HEAD, doc.DEPREL), [y for x in preds for y in x])
69
+ # build dependencies based on predictions
70
+ for sentence in batch.doc.sentences:
71
+ sentence.build_dependencies()
72
+ return batch.doc
73
+ except RuntimeError as e:
74
+ if str(e).startswith("CUDA out of memory. Tried to allocate"):
75
+ new_message = str(e) + " ... You may be able to compensate for this by separating long sentences into their own batch with a parameter such as depparse_min_length_to_batch_separately=150 or by limiting the overall batch size with depparse_batch_size=400."
76
+ raise RuntimeError(new_message) from e
77
+ else:
78
+ raise
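The out-of-memory message above names pipeline-level knobs; a hedged sketch of setting them, along with the `get_known_relations` query defined earlier:

import stanza

nlp = stanza.Pipeline("en", processors="tokenize,pos,lemma,depparse",
                      depparse_batch_size=400,
                      depparse_min_length_to_batch_separately=150)
print(nlp.processors["depparse"].get_known_relations())   # e.g. ['nsubj', 'obj', ...]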
stanza/stanza/pipeline/langid_processor.py ADDED
@@ -0,0 +1,127 @@
1
+ """
2
+ Processor for determining language of text.
3
+ """
4
+
5
+ import emoji
6
+ import re
7
+ import stanza
8
+ import torch
9
+
10
+ from stanza.models.common.doc import Document
11
+ from stanza.models.langid.model import LangIDBiLSTM
12
+ from stanza.pipeline._constants import *
13
+ from stanza.pipeline.processor import UDProcessor, register_processor
14
+
15
+
16
+ @register_processor(name=LANGID)
17
+ class LangIDProcessor(UDProcessor):
18
+ """
19
+ Class for detecting language of text.
20
+ """
21
+
22
+ # set of processor requirements this processor fulfills
23
+ PROVIDES_DEFAULT = set([LANGID])
24
+
25
+ # set of processor requirements for this processor
26
+ REQUIRES_DEFAULT = set([])
27
+
28
+ # default max sequence length
29
+ MAX_SEQ_LENGTH_DEFAULT = 1000
30
+
31
+ def _set_up_model(self, config, pipeline, device):
32
+ batch_size = config.get("batch_size", 64)
33
+ self._model = LangIDBiLSTM.load(path=config["model_path"], device=device,
34
+ batch_size=batch_size, lang_subset=config.get("lang_subset"))
35
+ self._char_index = self._model.char_to_idx
36
+ self._clean_text = config.get("clean_text")
37
+
38
+ def _text_to_tensor(self, docs):
39
+ """
40
+        Map a list of strings to a batch tensor. Assumes all docs are the same length.
41
+ """
42
+
43
+ device = next(self._model.parameters()).device
44
+ all_docs = []
45
+ for doc in docs:
46
+ doc_chars = [self._char_index.get(c, self._char_index["UNK"]) for c in list(doc)]
47
+ all_docs.append(doc_chars)
48
+ return torch.tensor(all_docs, device=device, dtype=torch.long)
49
+
50
+ def _id_langs(self, batch_tensor):
51
+ """
52
+ Identify languages for each sequence in a batch tensor
53
+ """
54
+ predictions = self._model.prediction_scores(batch_tensor)
55
+ prediction_labels = [self._model.idx_to_tag[prediction] for prediction in predictions]
56
+
57
+ return prediction_labels
58
+
59
+ # regexes for cleaning text
60
+ http_regex = re.compile(r"https?:\/\/t\.co/[a-zA-Z0-9]+")
61
+ handle_regex = re.compile("@[a-zA-Z0-9_]+")
62
+ hashtag_regex = re.compile("#[a-zA-Z]+")
63
+ punctuation_regex = re.compile("[!.]+")
64
+ all_regexes = [http_regex, handle_regex, hashtag_regex, punctuation_regex]
65
+
66
+ @staticmethod
67
+ def clean_text(text):
68
+ """
69
+        Process text to improve language id performance. The main emphasis is on tweets: this method removes
70
+        shortened urls, hashtags, handles, punctuation, and emoji.
71
+ """
72
+
73
+ for regex in LangIDProcessor.all_regexes:
74
+ text = regex.sub(" ", text)
75
+
76
+ text = emoji.emojize(text)
77
+ text = emoji.replace_emoji(text, replace=' ')
78
+
79
+ if text.strip():
80
+ text = text.strip()
81
+
82
+ return text
83
+
84
+ def _process_list(self, docs):
85
+ """
86
+ Identify language of list of strings or Documents
87
+ """
88
+
89
+ if len(docs) == 0:
90
+            # TODO: what standard do we want for bad input, such as empty list?
92
+            # TODO: more handling of bad input
92
+ return
93
+
94
+ if isinstance(docs[0], str):
95
+ docs = [Document([], text) for text in docs]
96
+
97
+ docs_by_length = {}
98
+ for doc in docs:
99
+ text = LangIDProcessor.clean_text(doc.text) if self._clean_text else doc.text
100
+ doc_length = len(text)
101
+ if doc_length not in docs_by_length:
102
+ docs_by_length[doc_length] = []
103
+ docs_by_length[doc_length].append((doc, text))
104
+
105
+ for doc_length in docs_by_length:
106
+ inputs = [doc[1] for doc in docs_by_length[doc_length]]
107
+ predictions = self._id_langs(self._text_to_tensor(inputs))
108
+ for doc, lang in zip(docs_by_length[doc_length], predictions):
109
+ doc[0].lang = lang
110
+
111
+ return docs
112
+
113
+ def process(self, doc):
114
+ """
115
+ Handle single str or Document
116
+ """
117
+
118
+ wrapped_doc = [doc]
119
+ return self._process_list(wrapped_doc)[0]
120
+
121
+ def bulk_process(self, docs):
122
+ """
123
+ Handle list of strings or Documents
124
+ """
125
+
126
+ return self._process_list(docs)
127
+
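A short identification sketch (the multilingual langid model must be downloaded); `langid_clean_text` is filtered down to the `clean_text` config read in `_set_up_model`:

import stanza

nlp = stanza.Pipeline(lang="multilingual", processors="langid", langid_clean_text=True)
docs = nlp.bulk_process(["Hello world.", "Hallo Welt.", "Bonjour le monde."])
print([doc.lang for doc in docs])   # e.g. ['en', 'de', 'fr']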
stanza/stanza/pipeline/lemma_processor.py ADDED
@@ -0,0 +1,126 @@
1
+ """
2
+ Processor for performing lemmatization
3
+ """
4
+
5
+ from itertools import compress
6
+
7
+ import torch
8
+
9
+ from stanza.models.common import doc
10
+ from stanza.models.lemma.data import DataLoader
11
+ from stanza.models.lemma.trainer import Trainer
12
+ from stanza.pipeline._constants import *
13
+ from stanza.pipeline.processor import UDProcessor, register_processor
14
+
15
+ WORD_TAGS = [doc.TEXT, doc.UPOS]
16
+
17
+ @register_processor(name=LEMMA)
18
+ class LemmaProcessor(UDProcessor):
19
+
20
+ # set of processor requirements this processor fulfills
21
+ PROVIDES_DEFAULT = set([LEMMA])
22
+ # set of processor requirements for this processor
23
+    # pos will be added later for non-identity lemmatizer
24
+ REQUIRES_DEFAULT = set([TOKENIZE])
25
+ # default batch size
26
+ DEFAULT_BATCH_SIZE = 5000
27
+
28
+ def __init__(self, config, pipeline, device):
29
+ # run lemmatizer in identity mode
30
+ self._use_identity = None
31
+ self._pretagged = None
32
+ super().__init__(config, pipeline, device)
33
+
34
+ @property
35
+ def use_identity(self):
36
+ return self._use_identity
37
+
38
+ def _set_up_model(self, config, pipeline, device):
39
+ if config.get('use_identity') in ['True', True]:
40
+ self._use_identity = True
41
+ self._config = config
42
+ self.config['batch_size'] = LemmaProcessor.DEFAULT_BATCH_SIZE
43
+ else:
44
+ # the lemmatizer only looks at one word when making
45
+ # decisions, not the surrounding context
46
+ # therefore, we can save some time by remembering what
47
+ # we did the last time we saw any given word,pos
48
+ # since a long running program will remember everything
49
+ # (unless we go back and make it smarter)
50
+ # we make this an option, not the default
51
+ # TODO: need to update the cache to skip the contextual lemmatizer
52
+ self.store_results = config.get('store_results', False)
53
+ self._use_identity = False
54
+ args = {'charlm_forward_file': config.get('forward_charlm_path', None),
55
+ 'charlm_backward_file': config.get('backward_charlm_path', None)}
56
+ lemma_classifier_args = dict(args)
57
+ lemma_classifier_args['wordvec_pretrain_file'] = config.get('pretrain_path', None)
58
+ self._trainer = Trainer(args=args, model_file=config['model_path'], device=device, foundation_cache=pipeline.foundation_cache, lemma_classifier_args=lemma_classifier_args)
59
+
60
+ def _set_up_requires(self):
61
+ self._pretagged = self._config.get('pretagged', None)
62
+ if self._pretagged:
63
+ self._requires = set()
64
+ elif self.config.get('pos') and not self.use_identity:
65
+ self._requires = LemmaProcessor.REQUIRES_DEFAULT.union(set([POS]))
66
+ else:
67
+ self._requires = LemmaProcessor.REQUIRES_DEFAULT
68
+
69
+ def process(self, document):
70
+ if not self.use_identity:
71
+ batch = DataLoader(document, self.config['batch_size'], self.config, vocab=self.vocab, evaluation=True, expand_unk_vocab=True)
72
+ else:
73
+ batch = DataLoader(document, self.config['batch_size'], self.config, evaluation=True, conll_only=True)
74
+ if self.use_identity:
75
+ preds = [word.text for sent in batch.doc.sentences for word in sent.words]
76
+ elif self.config.get('dict_only', False):
77
+ preds = self.trainer.predict_dict(batch.doc.get([doc.TEXT, doc.UPOS]))
78
+ else:
79
+ if self.config.get('ensemble_dict', False):
80
+ # skip the seq2seq model when we can
81
+ skip = self.trainer.skip_seq2seq(batch.doc.get([doc.TEXT, doc.UPOS]))
82
+ # although there is no explicit use of caseless or lemma_caseless in this processor,
83
+ # it shows up in the config which gets passed to the DataLoader,
84
+ # possibly affecting its results
85
+ seq2seq_batch = DataLoader(document, self.config['batch_size'], self.config, vocab=self.vocab,
86
+ evaluation=True, skip=skip, expand_unk_vocab=True)
87
+ else:
88
+ seq2seq_batch = batch
89
+
90
+ with torch.no_grad():
91
+ preds = []
92
+ edits = []
93
+ for i, b in enumerate(seq2seq_batch):
94
+ ps, es = self.trainer.predict(b, self.config['beam_size'], seq2seq_batch.vocab)
95
+ preds += ps
96
+ if es is not None:
97
+ edits += es
98
+
99
+ if self.config.get('ensemble_dict', False):
100
+ word_tags = batch.doc.get(WORD_TAGS)
101
+ words = [x[0] for x in word_tags]
102
+ preds = self.trainer.postprocess([x for x, y in zip(words, skip) if not y], preds, edits=edits)
103
+ if self.store_results:
104
+ new_word_tags = compress(word_tags, map(lambda x: not x, skip))
105
+ new_predictions = [(x[0], x[1], y) for x, y in zip(new_word_tags, preds)]
106
+ self.trainer.train_dict(new_predictions, update_word_dict=False)
107
+ # expand seq2seq predictions to the same size as all words
108
+ i = 0
109
+ preds1 = []
110
+ for s in skip:
111
+ if s:
112
+ preds1.append('')
113
+ else:
114
+ preds1.append(preds[i])
115
+ i += 1
116
+ preds = self.trainer.ensemble(word_tags, preds1)
117
+ else:
118
+ preds = self.trainer.postprocess(batch.doc.get([doc.TEXT]), preds, edits=edits)
119
+
120
+ if self.trainer.has_contextual_lemmatizers():
121
+ preds = self.trainer.update_contextual_preds(batch.doc, preds)
122
+
123
+ # map empty string lemmas to '_'
124
+ preds = [max([(len(x), x), (0, '_')])[1] for x in preds]
125
+ batch.doc.set([doc.LEMMA], preds)
126
+ return batch.doc
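The identity path above can be selected from the pipeline; a minimal sketch, where `lemma_use_identity` is filtered down to the `use_identity` config read in `_set_up_model`:

import stanza

# copy each token's surface form into its lemma slot, skipping the seq2seq model
nlp = stanza.Pipeline("en", processors="tokenize,lemma", lemma_use_identity=True)
doc = nlp("The cats were running.")
print([(word.text, word.lemma) for word in doc.sentences[0].words])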
stanza/stanza/pipeline/multilingual.py ADDED
@@ -0,0 +1,188 @@
1
+ """
2
+ Class for running multilingual pipelines
3
+ """
4
+
5
+ from collections import OrderedDict
6
+ import copy
7
+ import logging
8
+ from typing import Union
9
+
10
+ from stanza.models.common.doc import Document
11
+ from stanza.models.common.utils import default_device
12
+ from stanza.pipeline.core import Pipeline, DownloadMethod
13
+ from stanza.pipeline._constants import *
14
+ from stanza.resources.common import DEFAULT_MODEL_DIR, get_language_resources, load_resources_json
15
+
16
+ logger = logging.getLogger('stanza')
17
+
18
+ class MultilingualPipeline:
19
+ """
20
+ Pipeline for handling multilingual data. Takes in text, detects language, and routes request to pipeline for that
21
+ language.
22
+
23
+ You can specify options to individual language pipelines with the lang_configs field.
24
+ For example, if you want English pipelines to have NER, but want to turn that off for French, you can do:
25
+ lang_configs = {"en": {"processors": "tokenize,pos,lemma,depparse,ner"},
26
+ "fr": {"processors": "tokenize,pos,lemma,depparse"}}
27
+ pipeline = MultilingualPipeline(lang_configs=lang_configs)
28
+
29
+ You can also pass in a defaultdict created in such a way that it provides default parameters for each language.
30
+ For example, in order to only get tokenization for each language:
31
+ (remembering that the Pipeline will automagically add MWT to a language which uses MWT):
32
+ from collections import defaultdict
33
+ lang_configs = defaultdict(lambda: dict(processors="tokenize"))
34
+ pipeline = MultilingualPipeline(lang_configs=lang_configs)
35
+
36
+ download_method can be set as in Pipeline to turn off downloading
37
+ of the .json config or turn off downloading of everything
38
+ """
39
+
40
+ def __init__(self,
41
+ model_dir: str = DEFAULT_MODEL_DIR,
42
+ lang_id_config: dict = None,
43
+ lang_configs: dict = None,
44
+ ld_batch_size: int = 64,
45
+ max_cache_size: int = 10,
46
+ use_gpu: bool = None,
47
+ restrict: bool = False,
48
+ device: str = None,
49
+ download_method: DownloadMethod = DownloadMethod.DOWNLOAD_RESOURCES,
50
+ # python 3.6 compatibility - maybe want to update to 3.7 at some point
51
+ processors: Union[str, list] = None,
52
+ ):
53
+ # set up configs and cache for various language pipelines
54
+ self.model_dir = model_dir
55
+ self.lang_id_config = {} if lang_id_config is None else copy.deepcopy(lang_id_config)
56
+ self.lang_configs = {} if lang_configs is None else copy.deepcopy(lang_configs)
57
+ self.max_cache_size = max_cache_size
58
+ # OrderedDict so we can use it as a LRU cache
59
+ # most recent Pipeline goes to the end, pop the oldest one
60
+ # when we run out of space
61
+ self.pipeline_cache = OrderedDict()
62
+ if processors is None:
63
+ self.default_processors = None
64
+ elif isinstance(processors, str):
65
+ self.default_processors = [x.strip() for x in processors.split(",")]
66
+ else:
67
+ self.default_processors = list(processors)
68
+
69
+ self.download_method = download_method
70
+ if 'download_method' not in self.lang_id_config:
71
+ self.lang_id_config['download_method'] = self.download_method
72
+
73
+ # if lang is not in any of the lang_configs, update them to
74
+ # include the lang parameter. otherwise, the default language
75
+ # will always be used...
76
+ for lang in self.lang_configs:
77
+ if 'lang' not in self.lang_configs[lang]:
78
+ self.lang_configs[lang]['lang'] = lang
79
+
80
+ if restrict and 'langid_lang_subset' not in self.lang_id_config:
81
+ known_langs = sorted(self.lang_configs.keys())
82
+            if len(known_langs) == 0:
83
+ logger.warning("MultilingualPipeline asked to restrict to lang_configs, but lang_configs was empty. Ignoring...")
84
+ else:
85
+ logger.debug("Restricting MultilingualPipeline to %s", known_langs)
86
+ self.lang_id_config['langid_lang_subset'] = known_langs
87
+
88
+ # set use_gpu
89
+ if device is None:
90
+ if use_gpu is None or use_gpu == True:
91
+ device = default_device()
92
+ else:
93
+ device = 'cpu'
94
+ self.device = device
95
+
96
+ # build language id pipeline
97
+ self.lang_id_pipeline = Pipeline(dir=self.model_dir, lang='multilingual', processors="langid",
98
+ device=self.device, **self.lang_id_config)
99
+ # load the resources so that we can refer to it later when building a new pipeline
100
+ # note that it was either downloaded or not based on download_method when building the lang_id_pipeline
101
+ self.resources = load_resources_json(self.model_dir)
102
+
103
+ def _update_pipeline_cache(self, lang):
104
+ """
105
+ Do any necessary updates to the pipeline cache for this language. This includes building a new
106
+ pipeline for the lang, and possibly clearing out a language with the old last access date.
107
+ """
108
+
109
+ # update request history
110
+ if lang in self.pipeline_cache:
111
+ self.pipeline_cache.move_to_end(lang, last=True)
112
+
113
+ # update language configs
114
+ # try/except to allow for a defaultdict
115
+ try:
116
+ lang_config = self.lang_configs[lang]
117
+ except KeyError:
118
+ lang_config = {'lang': lang}
119
+ self.lang_configs[lang] = lang_config
120
+
121
+ # if a defaultdict is passed in, the defaultdict might not contain 'lang'
122
+ # so even though we tried adding 'lang' in the constructor, we'll check again here
123
+ if 'lang' not in lang_config:
124
+ lang_config['lang'] = lang
125
+
126
+ if 'download_method' not in lang_config:
127
+ lang_config['download_method'] = self.download_method
128
+
129
+ if 'processors' not in lang_config:
130
+ if self.default_processors:
131
+ lang_resources = get_language_resources(self.resources, lang)
132
+ lang_processors = [x for x in self.default_processors if x in lang_resources]
133
+ if lang_processors != self.default_processors:
134
+ logger.info("Not all requested processors %s available for %s. Loading %s instead", self.default_processors, lang, lang_processors)
135
+ lang_config['processors'] = ",".join(lang_processors)
136
+
137
+ if 'device' not in lang_config:
138
+ lang_config['device'] = self.device
139
+
140
+ # update pipeline cache
141
+ if lang not in self.pipeline_cache:
142
+ logger.debug("Loading unknown language in MultilingualPipeline: %s", lang)
143
+ # clear least recently used lang from pipeline cache
144
+ if len(self.pipeline_cache) == self.max_cache_size:
145
+ self.pipeline_cache.popitem(last=False)
146
+ self.pipeline_cache[lang] = Pipeline(dir=self.model_dir, **self.lang_configs[lang])
147
+
148
+ def process(self, doc):
149
+ """
150
+ Run language detection on a string, a Document, or a list of either, then route each item to its language-specific pipeline
151
+ """
152
+
153
+ # only return a list if given a list
154
+ singleton_input = not isinstance(doc, list)
155
+ if singleton_input:
156
+ docs = [doc]
157
+ else:
158
+ docs = doc
159
+
160
+ if docs and isinstance(docs[0], str):
161
+ docs = [Document([], text=text) for text in docs]
162
+
163
+ # run language identification
164
+ docs_w_langid = self.lang_id_pipeline.process(docs)
165
+
166
+ # create language specific batches, store global idx with each doc
167
+ lang_batches = {}
168
+ for doc_idx, doc in enumerate(docs_w_langid):
169
+ logger.debug("Language for document %d: %s", doc_idx, doc.lang)
170
+ if doc.lang not in lang_batches:
171
+ lang_batches[doc.lang] = []
172
+ lang_batches[doc.lang].append(doc)
173
+
174
+ # run through each language, submit a batch to the language specific pipeline
175
+ for lang in lang_batches.keys():
176
+ self._update_pipeline_cache(lang)
177
+ self.pipeline_cache[lang](lang_batches[lang])
178
+
179
+ # only return a list if given a list
180
+ if singleton_input:
181
+ return docs_w_langid[0]
182
+ else:
183
+ return docs_w_langid
184
+
185
+ def __call__(self, doc):
186
+ doc = self.process(doc)
187
+ return doc
188
+
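
The class above exposes both `process` and `__call__`, so it can be driven directly. A minimal usage sketch (assuming the langid and per-language models can be downloaded on first use):

```python
from stanza.pipeline.multilingual import MultilingualPipeline

# restrict the requested annotators; unavailable ones are dropped per language
nlp = MultilingualPipeline(processors="tokenize")

docs = nlp(["Hello world.", "C'est une phrase française."])
for doc in docs:
    # each Document was routed to (and annotated by) its own language pipeline
    print(doc.lang, doc.sentences[0].text)
```

Passing a single string or `Document` returns a single annotated `Document` rather than a list, mirroring the `singleton_input` handling in `process`.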
stanza/stanza/pipeline/mwt_processor.py ADDED
@@ -0,0 +1,59 @@
1
+ """
2
+ Processor for performing multi-word-token expansion
3
+ """
4
+
5
+ import io
6
+
7
+ import torch
8
+
9
+ from stanza.models.mwt.data import DataLoader
10
+ from stanza.models.mwt.trainer import Trainer
11
+ from stanza.pipeline._constants import *
12
+ from stanza.pipeline.processor import UDProcessor, register_processor
13
+
14
+ @register_processor(MWT)
15
+ class MWTProcessor(UDProcessor):
16
+
17
+ # set of processor requirements this processor fulfills
18
+ PROVIDES_DEFAULT = set([MWT])
19
+ # set of processor requirements for this processor
20
+ REQUIRES_DEFAULT = set([TOKENIZE])
21
+
22
+ def _set_up_model(self, config, pipeline, device):
23
+ self._trainer = Trainer(model_file=config['model_path'], device=device)
24
+
25
+ def build_batch(self, document):
26
+ return DataLoader(document, self.config['batch_size'], self.config, vocab=self.vocab, evaluation=True, expand_unk_vocab=True)
27
+
28
+ def process(self, document):
29
+ batch = self.build_batch(document)
30
+
31
+ # process the rest
32
+ expansions = batch.doc.get_mwt_expansions(evaluation=True)
33
+ if len(batch) > 0:
34
+ # decide trainer type and run eval
35
+ if self.config['dict_only']:
36
+ preds = self.trainer.predict_dict(expansions)
37
+ else:
38
+ with torch.no_grad():
39
+ preds = []
40
+ for i, b in enumerate(batch.to_loader()):
41
+ preds += self.trainer.predict(b, never_decode_unk=True, vocab=batch.vocab)
42
+
43
+ if self.config.get('ensemble_dict', False):
44
+ preds = self.trainer.ensemble(expansions, preds)
45
+ else:
46
+ # no MWT expansions to predict in this document
47
+ preds = []
48
+
49
+ batch.doc.set_mwt_expansions(preds, process_manual_expanded=False)
50
+ return batch.doc
51
+
52
+ def bulk_process(self, docs):
53
+ """
54
+ MWT processor counts some statistics on the individual docs, so we need to separately redo those stats
55
+ """
56
+ docs = super().bulk_process(docs)
57
+ for doc in docs:
58
+ doc._count_words()
59
+ return docs
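
A minimal sketch of the expansion this processor performs, assuming the French models are available; French contractions such as "au" are multi-word tokens that expand into separate syntactic words:

```python
import stanza

nlp = stanza.Pipeline("fr", processors="tokenize,mwt")
doc = nlp("Je vais au magasin.")
for token in doc.sentences[0].tokens:
    # "au" expands into the two words "à" and "le"
    print(token.text, "->", [word.text for word in token.words])
```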
stanza/stanza/pipeline/pos_processor.py ADDED
@@ -0,0 +1,89 @@
1
+ """
2
+ Processor for performing part-of-speech tagging
3
+ """
4
+
5
+ import torch
6
+
7
+ from stanza.models.common import doc
8
+ from stanza.models.common.utils import unsort
9
+ from stanza.models.common.vocab import VOCAB_PREFIX, CompositeVocab
10
+ from stanza.models.pos.data import Dataset
11
+ from stanza.models.pos.trainer import Trainer
12
+ from stanza.pipeline._constants import *
13
+ from stanza.pipeline.processor import UDProcessor, register_processor
14
+ from stanza.utils.get_tqdm import get_tqdm
15
+
16
+ tqdm = get_tqdm()
17
+
18
+ @register_processor(name=POS)
19
+ class POSProcessor(UDProcessor):
20
+
21
+ # set of processor requirements this processor fulfills
22
+ PROVIDES_DEFAULT = set([POS])
23
+ # set of processor requirements for this processor
24
+ REQUIRES_DEFAULT = set([TOKENIZE])
25
+
26
+ def _set_up_model(self, config, pipeline, device):
27
+ # get pretrained word vectors
28
+ self._pretrain = pipeline.foundation_cache.load_pretrain(config['pretrain_path']) if 'pretrain_path' in config else None
29
+ args = {'charlm_forward_file': config.get('forward_charlm_path', None),
30
+ 'charlm_backward_file': config.get('backward_charlm_path', None)}
31
+ # set up trainer
32
+ self._trainer = Trainer(pretrain=self.pretrain, model_file=config['model_path'], device=device, args=args, foundation_cache=pipeline.foundation_cache)
33
+ self._tqdm = 'tqdm' in config and config['tqdm']
34
+
35
+ def __str__(self):
36
+ return "POSProcessor(%s)" % self.config['model_path']
37
+
38
+ def get_known_xpos(self):
39
+ """
40
+ Returns the xpos tags known by this model
41
+ """
42
+ if isinstance(self.vocab['xpos'], CompositeVocab):
43
+ if len(self.vocab['xpos']) == 1:
44
+ return [k for k in self.vocab['xpos'][0]._unit2id.keys() if k not in VOCAB_PREFIX]
45
+ else:
46
+ return {k: v.keys() - VOCAB_PREFIX for k, v in self.vocab['xpos']._unit2id.items()}
47
+ return [k for k in self.vocab['xpos']._unit2id.keys() if k not in VOCAB_PREFIX]
48
+
49
+ def is_composite_xpos(self):
50
+ """
51
+ Returns if the xpos tags are part of a composite vocab
52
+ """
53
+ return isinstance(self.vocab['xpos'], CompositeVocab)
54
+
55
+ def get_known_upos(self):
56
+ """
57
+ Returns the upos tags known by this model
58
+ """
59
+ keys = [k for k in self.vocab['upos']._unit2id.keys() if k not in VOCAB_PREFIX]
60
+ return keys
61
+
62
+ def get_known_feats(self):
63
+ """
64
+ Returns the features known by this model
65
+ """
66
+ values = {k: v.keys() - VOCAB_PREFIX for k, v in self.vocab['feats']._unit2id.items()}
67
+ return values
68
+
69
+ def process(self, document):
70
+ # currently, POS models are saved w/o the batch_maximum_tokens flag
71
+ maximum_tokens = self.config.get('batch_maximum_tokens', 5000)
72
+
73
+ dataset = Dataset(
74
+ document, self.config, self.pretrain, vocab=self.vocab, evaluation=True,
75
+ sort_during_eval=True)
76
+ batch = iter(dataset.to_length_limited_loader(batch_size=self.config['batch_size'], maximum_tokens=maximum_tokens))
77
+ preds = []
78
+
79
+ idx = []
80
+ with torch.no_grad():
81
+ if self._tqdm:
82
+ batch = tqdm(batch)
83
+ for i, b in enumerate(batch):
84
+ idx.extend(b[-1])
85
+ preds += self.trainer.predict(b)
86
+
87
+ preds = unsort(preds, idx)
88
+ dataset.doc.set([doc.UPOS, doc.XPOS, doc.FEATS], [y for x in preds for y in x])
89
+ return dataset.doc
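
Once a pipeline is loaded, the inventory methods above can be queried from its `pos` processor. A minimal sketch, assuming the English models are available:

```python
import stanza

nlp = stanza.Pipeline("en", processors="tokenize,pos")
pos = nlp.processors["pos"]

print(sorted(pos.get_known_upos()))   # universal tags, e.g. ['ADJ', 'ADP', ..., 'VERB', 'X']
print(pos.is_composite_xpos())        # False for a single-unit xpos vocab
print(sorted(pos.get_known_feats()))  # feature names such as 'Number', 'Tense', ...
```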
stanza/stanza/pipeline/processor.py ADDED
@@ -0,0 +1,293 @@
1
+ """
2
+ Base classes for processors
3
+ """
4
+
5
+ from abc import ABC, abstractmethod
6
+
7
+ from stanza.models.common.doc import Document
8
+ from stanza.pipeline.registry import NAME_TO_PROCESSOR_CLASS, PIPELINE_NAMES, PROCESSOR_VARIANTS
9
+
10
+ class ProcessorRequirementsException(Exception):
11
+ """ Exception indicating a processor's requirements will not be met """
12
+
13
+ def __init__(self, processors_list, err_processor, provided_reqs):
14
+ self._err_processor = err_processor
15
+ # mark the broken processor as inactive, drop resources
16
+ self.err_processor.mark_inactive()
17
+ self._processors_list = processors_list
18
+ self._provided_reqs = provided_reqs
19
+ self.build_message()
20
+
21
+ @property
22
+ def err_processor(self):
23
+ """ The processor that raised the exception """
24
+ return self._err_processor
25
+
26
+ @property
27
+ def processor_type(self):
28
+ return type(self.err_processor).__name__
29
+
30
+ @property
31
+ def processors_list(self):
32
+ return self._processors_list
33
+
34
+ @property
35
+ def provided_reqs(self):
36
+ return self._provided_reqs
37
+
38
+ def build_message(self):
39
+ self.message = (f"---\nPipeline Requirements Error!\n"
40
+ f"\tProcessor: {self.processor_type}\n"
41
+ f"\tPipeline processors list: {','.join(self.processors_list)}\n"
42
+ f"\tProcessor Requirements: {self.err_processor.requires}\n"
43
+ f"\t\t- fulfilled: {self.err_processor.requires.intersection(self.provided_reqs)}\n"
44
+ f"\t\t- missing: {self.err_processor.requires - self.provided_reqs}\n"
45
+ f"\nThe processors list provided for this pipeline is invalid. Please make sure all "
46
+ f"prerequisites are met for every processor.\n\n")
47
+
48
+ def __str__(self):
49
+ return self.message
50
+
51
+
52
+ class Processor(ABC):
53
+ """ Base class for all processors """
54
+
55
+ def __init__(self, config, pipeline, device):
56
+ # overall config for the processor
57
+ self._config = config
58
+ # pipeline building this processor (presently processors are only meant to exist in one pipeline)
59
+ self._pipeline = pipeline
60
+ self._set_up_variants(config, device)
61
+ # run set up process
62
+ # set up what annotations are required based on config
63
+ if not self._set_up_variant_requires():
64
+ self._set_up_requires()
65
+ # set up what annotations are provided based on config
66
+ self._set_up_provides()
67
+ # given pipeline constructing this processor, check if requirements are met, throw exception if not
68
+ self._check_requirements()
69
+
70
+ if hasattr(self, '_variant') and self._variant.OVERRIDE:
71
+ self.process = self._variant.process
72
+
73
+ def __str__(self):
74
+ """
75
+ Simple description of the processor: name(model)
76
+ """
77
+ name = self.__class__.__name__
78
+ model = None
79
+ if self._config is not None:
80
+ model = self._config.get('model_path')
81
+ if model is None:
82
+ return name
83
+ else:
84
+ return "{}({})".format(name, model)
85
+
86
+
87
+ @abstractmethod
88
+ def process(self, doc):
89
+ """ Process a Document. This is the main method of a processor. """
90
+ pass
91
+
92
+ def bulk_process(self, docs):
93
+ """ Process a list of Documents. This should be replaced with a more efficient implementation if possible. """
94
+
95
+ if hasattr(self, '_variant'):
96
+ return self._variant.bulk_process(docs)
97
+
98
+ return [self.process(doc) for doc in docs]
99
+
100
+ def _set_up_provides(self):
101
+ """ Set up what processor requirements this processor fulfills. Default is to use a class defined list. """
102
+ self._provides = self.__class__.PROVIDES_DEFAULT
103
+
104
+ def _set_up_requires(self):
105
+ """ Set up requirements for this processor. Default is to use a class defined list. """
106
+ self._requires = self.__class__.REQUIRES_DEFAULT
107
+
108
+ def _set_up_variant_requires(self):
109
+ """
110
+ If this has a variant with its own requirements, use those instead
111
+
112
+ Returns True iff the _requires is set from the _variant
113
+ """
114
+ if not hasattr(self, '_variant'):
115
+ return False
116
+ if hasattr(self._variant, '_set_up_requires'):
117
+ self._variant._set_up_requires()
118
+ self._requires = self._variant._requires
119
+ return True
120
+ if hasattr(self._variant.__class__, 'REQUIRES_DEFAULT'):
121
+ self._requires = self._variant.__class__.REQUIRES_DEFAULT
122
+ return True
123
+ return False
124
+
125
+ def _set_up_variants(self, config, device):
126
+ processor_name = list(self.__class__.PROVIDES_DEFAULT)[0]
127
+ if any(config.get(f'with_{variant}', False) for variant in PROCESSOR_VARIANTS[processor_name]):
128
+ self._trainer = None
129
+ variant_name = [variant for variant in PROCESSOR_VARIANTS[processor_name] if config.get(f'with_{variant}', False)][0]
130
+ self._variant = PROCESSOR_VARIANTS[processor_name][variant_name](config)
131
+
132
+ @property
133
+ def config(self):
134
+ """ Configurations for the processor """
135
+ return self._config
136
+
137
+ @property
138
+ def pipeline(self):
139
+ """ The pipeline that this processor belongs to """
140
+ return self._pipeline
141
+
142
+ @property
143
+ def provides(self):
144
+ return self._provides
145
+
146
+ @property
147
+ def requires(self):
148
+ return self._requires
149
+
150
+ def _check_requirements(self):
151
+ """ Given a list of fulfilled requirements, check if all of this processor's requirements are met or not. """
152
+ if not self.config.get("check_requirements", True):
153
+ return
154
+ provided_reqs = set.union(*[processor.provides for processor in self.pipeline.loaded_processors]+[set([])])
155
+ if self.requires - provided_reqs:
156
+ load_names = [item[0] for item in self.pipeline.load_list]
157
+ raise ProcessorRequirementsException(load_names, self, provided_reqs)
158
+
159
+
160
+ class ProcessorVariant(ABC):
161
+ """ Base class for all processor variants """
162
+
163
+ OVERRIDE = False # Set to true to override all the processing from the processor
164
+
165
+ @abstractmethod
166
+ def process(self, doc):
167
+ """
168
+ Process a document that is potentially preprocessed by the processor.
169
+ This is the main method of a processor variant.
170
+
171
+ If `OVERRIDE` is set to True, all preprocessing by the processor is bypassed; the processor variant
172
+ serves as a drop-in replacement for the entire processor and must be able to interpret all the configs
173
+ that are typically handled by the processor it replaces.
174
+ """
175
+ pass
176
+
177
+ def bulk_process(self, docs):
178
+ """ Process a list of Documents. This should be replaced with a more efficient implementation if possible. """
179
+
180
+ return [self.process(doc) for doc in docs]
181
+
182
+ class UDProcessor(Processor):
183
+ """ Base class for the neural UD Processors (tokenize,mwt,pos,lemma,depparse,sentiment,constituency) """
184
+
185
+ def __init__(self, config, pipeline, device):
186
+ super().__init__(config, pipeline, device)
187
+
188
+ # UD model resources, set up is processor specific
189
+ self._pretrain = None
190
+ self._trainer = None
191
+ self._vocab = None
192
+ if not hasattr(self, '_variant'):
193
+ self._set_up_model(config, pipeline, device)
194
+
195
+ # build the final config for the processor
196
+ self._set_up_final_config(config)
197
+
198
+ @abstractmethod
199
+ def _set_up_model(self, config, pipeline, device):
200
+ pass
201
+
202
+ def _set_up_final_config(self, config):
203
+ """ Finalize the configurations for this processor, based off of values from a UD model. """
204
+ # set configurations from loaded model
205
+ if self._trainer is not None:
206
+ loaded_args, self._vocab = self._trainer.args, self._trainer.vocab
207
+ # filter out unneeded args from model
208
+ loaded_args = {k: v for k, v in loaded_args.items() if not UDProcessor.filter_out_option(k)}
209
+ else:
210
+ loaded_args = {}
211
+ loaded_args.update(config)
212
+ self._config = loaded_args
213
+
214
+ def mark_inactive(self):
215
+ """ Drop memory intensive resources if keeping this processor around for reasons other than running it. """
216
+ self._trainer = None
217
+ self._vocab = None
218
+
219
+ @property
220
+ def pretrain(self):
221
+ return self._pretrain
222
+
223
+ @property
224
+ def trainer(self):
225
+ return self._trainer
226
+
227
+ @property
228
+ def vocab(self):
229
+ return self._vocab
230
+
231
+ @staticmethod
232
+ def filter_out_option(option):
233
+ """ Filter out non-processor configurations """
234
+ options_to_filter = ['device', 'cpu', 'cuda', 'dev_conll_gold', 'epochs', 'lang', 'mode', 'save_name', 'shorthand']
235
+ if option.endswith('_file') or option.endswith('_dir'):
236
+ return True
237
+ elif option in options_to_filter:
238
+ return True
239
+ else:
240
+ return False
241
+
242
+ def bulk_process(self, docs):
243
+ """
244
+ Most processors operate on the sentence level, where each sentence is processed independently and processors can benefit
245
+ a lot from the ability to combine sentences from multiple documents for faster batched processing. This is a transparent
246
+ implementation that allows these processors to batch process a list of Documents as if they were from a single Document.
247
+ """
248
+
249
+ if hasattr(self, '_variant'):
250
+ return self._variant.bulk_process(docs)
251
+
252
+ combined_sents = [sent for doc in docs for sent in doc.sentences]
253
+ combined_doc = Document([])
254
+ combined_doc.sentences = combined_sents
255
+ combined_doc.num_tokens = sum(doc.num_tokens for doc in docs)
256
+ combined_doc.num_words = sum(doc.num_words for doc in docs)
257
+
258
+ self.process(combined_doc) # annotations are attached to sentence objects
259
+
260
+ return docs
261
+
262
+ class ProcessorRegisterException(Exception):
263
+ """ Exception indicating processor or processor registration failure """
264
+
265
+ def __init__(self, processor_class, expected_parent):
266
+ self._processor_class = processor_class
267
+ self._expected_parent = expected_parent
268
+ self.build_message()
269
+
270
+ def build_message(self):
271
+ self.message = f"Failed to register '{self._processor_class}'. It must be a subclass of '{self._expected_parent}'."
272
+
273
+ def __str__(self):
274
+ return self.message
275
+
276
+ def register_processor(name):
277
+ def wrapper(Cls):
278
+ if not issubclass(Cls, Processor):
279
+ raise ProcessorRegisterException(Cls, Processor)
280
+
281
+ NAME_TO_PROCESSOR_CLASS[name] = Cls
282
+ PIPELINE_NAMES.append(name)
283
+ return Cls
284
+ return wrapper
285
+
286
+ def register_processor_variant(name, variant):
287
+ def wrapper(Cls):
288
+ if not issubclass(Cls, ProcessorVariant):
289
+ raise ProcessorRegisterException(Cls, ProcessorVariant)
290
+
291
+ PROCESSOR_VARIANTS[name][variant] = Cls
292
+ return Cls
293
+ return wrapper
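
The `register_processor` decorator above is the extension point for custom annotators. A minimal sketch of a toy processor; the `lowercase` name and its behavior are illustrative, not part of the library:

```python
from stanza.pipeline.processor import Processor, register_processor

@register_processor("lowercase")
class LowercaseProcessor(Processor):
    """Toy processor that lowercases every word in the document."""
    PROVIDES_DEFAULT = set(["lowercase"])
    REQUIRES_DEFAULT = set(["tokenize"])

    def process(self, doc):
        doc.text = doc.text.lower()
        for sent in doc.sentences:
            for word in sent.words:
                word.text = word.text.lower()
        return doc
```

Because `process` is the only abstract method, subclassing `Processor` directly (rather than `UDProcessor`) sidesteps the model-loading machinery entirely.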
stanza/stanza/pipeline/registry.py ADDED
@@ -0,0 +1,8 @@
1
+ from collections import defaultdict
2
+
3
+ # these two get filled by register_processor
4
+ NAME_TO_PROCESSOR_CLASS = dict()
5
+ PIPELINE_NAMES = []
6
+
7
+ # this gets filled by register_processor_variant
8
+ PROCESSOR_VARIANTS = defaultdict(dict)
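
These registries are what the decorators in `processor.py` populate. A minimal sketch of registering a tokenizer variant, mirroring the external tokenizers imported by `tokenize_processor.py`; the `whitespace` variant here is hypothetical:

```python
from stanza.models.common import doc
from stanza.pipeline.processor import ProcessorVariant, register_processor_variant

@register_processor_variant("tokenize", "whitespace")
class WhitespaceTokenizerVariant(ProcessorVariant):
    OVERRIDE = True  # fully replace the neural tokenizer

    def __init__(self, config):
        pass

    def process(self, text):
        # one sentence per line, tokens split on whitespace
        sentences = []
        offset = 0
        for line in text.split("\n"):
            words = []
            for w in line.split():
                start = text.index(w, offset)
                offset = start + len(w)
                words.append({doc.ID: (len(words) + 1,), doc.TEXT: w,
                              doc.MISC: f"start_char={start}|end_char={offset}"})
            if words:
                sentences.append(words)
        return doc.Document(sentences, text)
```

Selecting such a variant goes through the pipeline config, which sets the `tokenize_with_whitespace` flag checked by `_set_up_variants` in `processor.py`.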
stanza/stanza/pipeline/sentiment_processor.py ADDED
@@ -0,0 +1,78 @@
1
+ """Processor that attaches a sentiment score to a sentence
2
+
3
+ The model used is generally a model trained on the Stanford
4
+ Sentiment Treebank or some similar dataset. When run, this processor
5
+ attaches a score in the form of a string to each sentence in the
6
+ document.
7
+
8
+ TODO: a possible way to generalize this would be to make it a
9
+ ClassifierProcessor and have "sentiment" be an option.
10
+ """
11
+
12
+ import dataclasses
13
+ import torch
14
+
15
+ from types import SimpleNamespace
16
+
17
+ from stanza.models.classifiers.trainer import Trainer
18
+
19
+ from stanza.pipeline._constants import *
20
+ from stanza.pipeline.processor import UDProcessor, register_processor
21
+
22
+ @register_processor(SENTIMENT)
23
+ class SentimentProcessor(UDProcessor):
24
+ # set of processor requirements this processor fulfills
25
+ PROVIDES_DEFAULT = set([SENTIMENT])
26
+ # set of processor requirements for this processor
27
+ # TODO: a constituency based model needs CONSTITUENCY as well
28
+ # issue: by the time we load the model in Processor.__init__,
29
+ # the requirements are already prepared
30
+ REQUIRES_DEFAULT = set([TOKENIZE])
31
+
32
+ # default batch size, measured in words per batch
33
+ DEFAULT_BATCH_SIZE = 5000
34
+
35
+ def _set_up_model(self, config, pipeline, device):
36
+ # get pretrained word vectors
37
+ pretrain_path = config.get('pretrain_path', None)
38
+ forward_charlm_path = config.get('forward_charlm_path', None)
39
+ backward_charlm_path = config.get('backward_charlm_path', None)
40
+ # elmo does not have a convenient way to download intermediate
41
+ # models the way stanza downloads charlms & pretrains or
42
+ # transformers downloads bert etc
43
+ # however, elmo in general is not as good as using a
44
+ # transformer, so it is unlikely we will ever fix this
45
+ args = SimpleNamespace(device = device,
46
+ charlm_forward_file = forward_charlm_path,
47
+ charlm_backward_file = backward_charlm_path,
48
+ wordvec_pretrain_file = pretrain_path,
49
+ elmo_model = None,
50
+ use_elmo = False,
51
+ save_dir = None)
52
+ filename = config['model_path']
53
+ if filename is None:
54
+ raise FileNotFoundError("No model specified for the sentiment processor. Perhaps it is not supported for the language. {}".format(config))
55
+ # set up model
56
+ trainer = Trainer.load(filename=filename,
57
+ args=args,
58
+ foundation_cache=pipeline.foundation_cache)
59
+ self._trainer = trainer
60
+ self._model = trainer.model
61
+ self._model_type = self._model.config.model_type
62
+ # batch size counted as words
63
+ self._batch_size = config.get('batch_size', SentimentProcessor.DEFAULT_BATCH_SIZE)
64
+
65
+ def _set_up_final_config(self, config):
66
+ loaded_args = dataclasses.asdict(self._model.config)
67
+ loaded_args = {k: v for k, v in loaded_args.items() if not UDProcessor.filter_out_option(k)}
68
+ loaded_args.update(config)
69
+ self._config = loaded_args
70
+
71
+
72
+ def process(self, document):
73
+ sentences = self._model.extract_sentences(document)
74
+ with torch.no_grad():
75
+ labels = self._model.label_sentences(sentences, batch_size=self._batch_size)
76
+ # TODO: allow a classifier processor for any attribute, not just sentiment
77
+ document.set(SENTIMENT, labels, to_sentence=True)
78
+ return document
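
A minimal sketch of running this processor end to end, assuming the English models are available; each sentence receives a sentiment label (negative/neutral/positive for the default English model):

```python
import stanza

nlp = stanza.Pipeline("en", processors="tokenize,sentiment")
doc = nlp("I loved this movie. It was a terrible waste of time.")
for i, sentence in enumerate(doc.sentences):
    print(i, sentence.sentiment)
```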
stanza/stanza/pipeline/tokenize_processor.py ADDED
@@ -0,0 +1,185 @@
1
+ """
2
+ Processor for performing tokenization
3
+ """
4
+
5
+ import io
6
+ import logging
7
+
8
+ import torch
9
+
10
+ from stanza.models.tokenization.data import TokenizationDataset
11
+ from stanza.models.tokenization.trainer import Trainer
12
+ from stanza.models.tokenization.utils import output_predictions
13
+ from stanza.pipeline._constants import *
14
+ from stanza.pipeline.processor import UDProcessor, register_processor
15
+ from stanza.pipeline.registry import PROCESSOR_VARIANTS
16
+ from stanza.models.common import doc
17
+
18
+ # these imports trigger the "register_processor_variant" decorations
19
+ from stanza.pipeline.external.jieba import JiebaTokenizer
20
+ from stanza.pipeline.external.spacy import SpacyTokenizer
21
+ from stanza.pipeline.external.sudachipy import SudachiPyTokenizer
22
+ from stanza.pipeline.external.pythainlp import PyThaiNLPTokenizer
23
+
24
+ logger = logging.getLogger('stanza')
25
+
26
+ TOKEN_TOO_LONG_REPLACEMENT = "<UNK>"
27
+
28
+ # class for running the tokenizer
29
+ @register_processor(name=TOKENIZE)
30
+ class TokenizeProcessor(UDProcessor):
31
+
32
+ # set of processor requirements this processor fulfills
33
+ PROVIDES_DEFAULT = set([TOKENIZE])
34
+ # set of processor requirements for this processor
35
+ REQUIRES_DEFAULT = set([])
36
+ # default max sequence length
37
+ MAX_SEQ_LENGTH_DEFAULT = 1000
38
+
39
+ def _set_up_model(self, config, pipeline, device):
40
+ # set up trainer
41
+ if config.get('pretokenized'):
42
+ self._trainer = None
43
+ else:
44
+ self._trainer = Trainer(model_file=config['model_path'], device=device)
45
+
46
+ # get and typecheck the postprocessor
47
+ postprocessor = config.get('postprocessor')
48
+ if postprocessor and callable(postprocessor):
49
+ self._postprocessor = postprocessor
50
+ elif not postprocessor:
51
+ self._postprocessor = None
52
+ else:
53
+ raise ValueError("Tokenizer recieved 'postprocessor' option of unrecognized type; postprocessor must be callable. Got %s" % postprocessor)
54
+
55
+ def process_pre_tokenized_text(self, input_src):
56
+ """
57
+ Pretokenized text can be provided in two ways:
58
+
59
+ 1.) str, tokenized by whitespace, sentence split by newline
60
+ 2.) list of token lists, each token list represents a sentence
61
+
62
+ Either form is converted into the dictionary data structure used to build a Document.
63
+ """
64
+
65
+ document = []
66
+ if isinstance(input_src, str):
67
+ sentences = [sent.strip().split() for sent in input_src.strip().split('\n') if len(sent.strip()) > 0]
68
+ elif isinstance(input_src, list):
69
+ sentences = input_src
70
+ idx = 0
71
+ for sentence in sentences:
72
+ sent = []
73
+ for token_id, token in enumerate(sentence):
74
+ sent.append({doc.ID: (token_id + 1, ), doc.TEXT: token, doc.MISC: f'start_char={idx}|end_char={idx + len(token)}'})
75
+ idx += len(token) + 1
76
+ document.append(sent)
77
+ raw_text = ' '.join([' '.join(sentence) for sentence in sentences])
78
+ return raw_text, document
79
+
80
+ def process(self, document):
81
+ if not (isinstance(document, str) or isinstance(document, doc.Document) or (self.config.get('pretokenized') or self.config.get('no_ssplit', False))):
82
+ raise ValueError("If neither 'pretokenized' or 'no_ssplit' option is enabled, the input to the TokenizerProcessor must be a string or a Document object. Got %s" % str(type(document)))
83
+
84
+ if isinstance(document, doc.Document):
85
+ if self.config.get('pretokenized'):
86
+ return document
87
+ document = document.text
88
+
89
+ if self.config.get('pretokenized'):
90
+ raw_text, document = self.process_pre_tokenized_text(document)
91
+ return doc.Document(document, raw_text)
92
+
93
+ if hasattr(self, '_variant'):
94
+ return self._variant.process(document)
95
+
96
+ raw_text = '\n\n'.join(document) if isinstance(document, list) else document
97
+
98
+ max_seq_len = self.config.get('max_seqlen', TokenizeProcessor.MAX_SEQ_LENGTH_DEFAULT)
99
+
100
+ # set up batches
101
+ batches = TokenizationDataset(self.config, input_text=raw_text, vocab=self.vocab, evaluation=True, dictionary=self.trainer.dictionary)
102
+ # get dict data
103
+ with torch.no_grad():
104
+ _, _, _, document = output_predictions(None, self.trainer, batches, self.vocab, None,
105
+ max_seq_len,
106
+ orig_text=raw_text,
107
+ no_ssplit=self.config.get('no_ssplit', False),
108
+ num_workers = self.config.get('num_workers', 0),
109
+ postprocessor = self._postprocessor)
110
+
111
+ # replace excessively long tokens with <UNK> to avoid downstream GPU memory issues in POS
112
+ for sentence in document:
113
+ for token in sentence:
114
+ if len(token['text']) > max_seq_len:
115
+ token['text'] = TOKEN_TOO_LONG_REPLACEMENT
116
+
117
+ return doc.Document(document, raw_text)
118
+
119
+ def bulk_process(self, docs):
120
+ """
121
+ The tokenizer cannot use UDProcessor's sentence-level cross-document batching interface, and requires special handling.
122
+ Essentially, this method concatenates the text of multiple documents with "\n\n", tokenizes it with the neural tokenizer,
123
+ then splits the result into the original Documents and recovers the original character offsets.
124
+ """
125
+ if hasattr(self, '_variant'):
126
+ return self._variant.bulk_process(docs)
127
+
128
+ if self.config.get('pretokenized'):
129
+ res = []
130
+ for document in docs:
131
+ raw_text, document = self.process_pre_tokenized_text(document.text)
132
+ res.append(doc.Document(document, raw_text))
133
+ return res
134
+
135
+ combined_text = '\n\n'.join([thisdoc.text for thisdoc in docs])
136
+ processed_combined = self.process(doc.Document([], text=combined_text))
137
+
138
+ # postprocess sentences and tokens to reset back pointers and char offsets
139
+ charoffset = 0
140
+ sentst = senten = 0
141
+ for thisdoc in docs:
142
+ while senten < len(processed_combined.sentences) and processed_combined.sentences[senten].tokens[-1].end_char - charoffset <= len(thisdoc.text):
143
+ senten += 1
144
+
145
+ sentences = processed_combined.sentences[sentst:senten]
146
+ thisdoc.sentences = sentences
147
+ for sent in sentences:
148
+ # fix doc back pointers for sentences
149
+ sent._doc = thisdoc
150
+
151
+ # fix char offsets for tokens and words
152
+ for token in sent.tokens:
153
+ token._start_char -= charoffset
154
+ token._end_char -= charoffset
155
+ if token.words: # not-yet-processed MWT can leave empty tokens
156
+ for word in token.words:
157
+ word._start_char -= charoffset
158
+ word._end_char -= charoffset
159
+
160
+ # Here we need to fix up the SpacesAfter for the very last token
161
+ # and the SpacesBefore for the first token of the next doc
162
+ # After all, we had connected the text with \n\n
163
+ # Need to be careful about this - in a case such as
164
+ # " -text one- "
165
+ # " -text two- "
166
+ # We want the SpacesBefore for the second document to reflect
167
+ # the extra space at the start of its text
168
+ # and the SpacesAfter for the first document to reflect
169
+ # the whitespace after its text
170
+ if len(sentences) > 0:
171
+ last_token = sentences[-1].tokens[-1]
172
+ last_whitespace = thisdoc.text[last_token.end_char:]
173
+ last_token.spaces_after = last_whitespace
174
+
175
+ first_token = sentences[0].tokens[0]
176
+ first_whitespace = thisdoc.text[:first_token.start_char]
177
+ first_token.spaces_before = first_whitespace
178
+
179
+ thisdoc.num_tokens = sum(len(sent.tokens) for sent in sentences)
180
+ thisdoc.num_words = sum(len(sent.words) for sent in sentences)
181
+ sentst = senten
182
+
183
+ charoffset += len(thisdoc.text) + 2
184
+
185
+ return docs
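
A minimal sketch of the `pretokenized` path handled by `process_pre_tokenized_text` above, assuming the English models are available:

```python
import stanza

nlp = stanza.Pipeline("en", processors="tokenize", tokenize_pretokenized=True)

# 1) whitespace-tokenized string, one sentence per line
doc = nlp("This is a test .\nAnother sentence .")

# 2) list of token lists, one inner list per sentence
doc = nlp([["This", "is", "a", "test", "."], ["Another", "sentence", "."]])
print([[word.text for word in sent.words] for sent in doc.sentences])
```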
stanza/stanza/protobuf/CoreNLP_pb2.py ADDED
@@ -0,0 +1,686 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # source: CoreNLP.proto
4
+ """Generated protocol buffer code."""
5
+ from google.protobuf.internal import enum_type_wrapper
6
+ from google.protobuf import descriptor as _descriptor
7
+ from google.protobuf import descriptor_pool as _descriptor_pool
8
+ from google.protobuf import message as _message
9
+ from google.protobuf import reflection as _reflection
10
+ from google.protobuf import symbol_database as _symbol_database
11
+ # @@protoc_insertion_point(imports)
12
+
13
+ _sym_db = _symbol_database.Default()
14
+
15
+
16
+
17
+
18
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rCoreNLP.proto\x12\x19\x65\x64u.stanford.nlp.pipeline\"\xe1\x05\n\x08\x44ocument\x12\x0c\n\x04text\x18\x01 \x02(\t\x12\x35\n\x08sentence\x18\x02 \x03(\x0b\x32#.edu.stanford.nlp.pipeline.Sentence\x12\x39\n\ncorefChain\x18\x03 \x03(\x0b\x32%.edu.stanford.nlp.pipeline.CorefChain\x12\r\n\x05\x64ocID\x18\x04 \x01(\t\x12\x0f\n\x07\x64ocDate\x18\x07 \x01(\t\x12\x10\n\x08\x63\x61lendar\x18\x08 \x01(\x04\x12;\n\x11sentencelessToken\x18\x05 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x33\n\tcharacter\x18\n \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12/\n\x05quote\x18\x06 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Quote\x12\x37\n\x08mentions\x18\t \x03(\x0b\x32%.edu.stanford.nlp.pipeline.NERMention\x12#\n\x1bhasEntityMentionsAnnotation\x18\r \x01(\x08\x12\x0e\n\x06xmlDoc\x18\x0b \x01(\x08\x12\x34\n\x08sections\x18\x0c \x03(\x0b\x32\".edu.stanford.nlp.pipeline.Section\x12<\n\x10mentionsForCoref\x18\x0e \x03(\x0b\x32\".edu.stanford.nlp.pipeline.Mention\x12!\n\x19hasCorefMentionAnnotation\x18\x0f \x01(\x08\x12\x1a\n\x12hasCorefAnnotation\x18\x10 \x01(\x08\x12+\n#corefMentionToEntityMentionMappings\x18\x11 \x03(\x05\x12+\n#entityMentionToCorefMentionMappings\x18\x12 \x03(\x05*\x05\x08\x64\x10\x80\x02\"\xf3\x0f\n\x08Sentence\x12/\n\x05token\x18\x01 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x18\n\x10tokenOffsetBegin\x18\x02 \x02(\r\x12\x16\n\x0etokenOffsetEnd\x18\x03 \x02(\r\x12\x15\n\rsentenceIndex\x18\x04 \x01(\r\x12\x1c\n\x14\x63haracterOffsetBegin\x18\x05 \x01(\r\x12\x1a\n\x12\x63haracterOffsetEnd\x18\x06 \x01(\r\x12\x37\n\tparseTree\x18\x07 \x01(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12@\n\x12\x62inarizedParseTree\x18\x1f \x01(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12@\n\x12\x61nnotatedParseTree\x18 \x01(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12\x11\n\tsentiment\x18! 
\x01(\t\x12=\n\x0fkBestParseTrees\x18\" \x03(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12\x45\n\x11\x62\x61sicDependencies\x18\x08 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12I\n\x15\x63ollapsedDependencies\x18\t \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12T\n collapsedCCProcessedDependencies\x18\n \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12K\n\x17\x61lternativeDependencies\x18\r \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12?\n\x0copenieTriple\x18\x0e \x03(\x0b\x32).edu.stanford.nlp.pipeline.RelationTriple\x12<\n\tkbpTriple\x18\x10 \x03(\x0b\x32).edu.stanford.nlp.pipeline.RelationTriple\x12\x45\n\x10\x65ntailedSentence\x18\x0f \x03(\x0b\x32+.edu.stanford.nlp.pipeline.SentenceFragment\x12\x43\n\x0e\x65ntailedClause\x18# \x03(\x0b\x32+.edu.stanford.nlp.pipeline.SentenceFragment\x12H\n\x14\x65nhancedDependencies\x18\x11 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12P\n\x1c\x65nhancedPlusPlusDependencies\x18\x12 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12\x33\n\tcharacter\x18\x13 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x11\n\tparagraph\x18\x0b \x01(\r\x12\x0c\n\x04text\x18\x0c \x01(\t\x12\x12\n\nlineNumber\x18\x14 \x01(\r\x12\x1e\n\x16hasRelationAnnotations\x18\x33 \x01(\x08\x12\x31\n\x06\x65ntity\x18\x34 \x03(\x0b\x32!.edu.stanford.nlp.pipeline.Entity\x12\x35\n\x08relation\x18\x35 \x03(\x0b\x32#.edu.stanford.nlp.pipeline.Relation\x12$\n\x1chasNumerizedTokensAnnotation\x18\x36 \x01(\x08\x12\x37\n\x08mentions\x18\x37 \x03(\x0b\x32%.edu.stanford.nlp.pipeline.NERMention\x12<\n\x10mentionsForCoref\x18\x38 \x03(\x0b\x32\".edu.stanford.nlp.pipeline.Mention\x12\"\n\x1ahasCorefMentionsAnnotation\x18\x39 \x01(\x08\x12\x12\n\nsentenceID\x18: \x01(\t\x12\x13\n\x0bsectionDate\x18; \x01(\t\x12\x14\n\x0csectionIndex\x18< \x01(\r\x12\x13\n\x0bsectionName\x18= \x01(\t\x12\x15\n\rsectionAuthor\x18> \x01(\t\x12\r\n\x05\x64ocID\x18? \x01(\t\x12\x15\n\rsectionQuoted\x18@ \x01(\x08\x12#\n\x1bhasEntityMentionsAnnotation\x18\x41 \x01(\x08\x12\x1f\n\x17hasKBPTriplesAnnotation\x18\x44 \x01(\x08\x12\"\n\x1ahasOpenieTriplesAnnotation\x18\x45 \x01(\x08\x12\x14\n\x0c\x63hapterIndex\x18\x42 \x01(\r\x12\x16\n\x0eparagraphIndex\x18\x43 \x01(\r\x12=\n\x10\x65nhancedSentence\x18\x46 \x01(\x0b\x32#.edu.stanford.nlp.pipeline.Sentence\x12\x0f\n\x07speaker\x18G \x01(\t\x12\x13\n\x0bspeakerType\x18H \x01(\t*\x05\x08\x64\x10\x80\x02\"\xf6\x0c\n\x05Token\x12\x0c\n\x04word\x18\x01 \x01(\t\x12\x0b\n\x03pos\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\t\x12\x10\n\x08\x63\x61tegory\x18\x04 \x01(\t\x12\x0e\n\x06\x62\x65\x66ore\x18\x05 \x01(\t\x12\r\n\x05\x61\x66ter\x18\x06 \x01(\t\x12\x14\n\x0coriginalText\x18\x07 \x01(\t\x12\x0b\n\x03ner\x18\x08 \x01(\t\x12\x11\n\tcoarseNER\x18> \x01(\t\x12\x16\n\x0e\x66ineGrainedNER\x18? 
\x01(\t\x12\x15\n\rnerLabelProbs\x18\x42 \x03(\t\x12\x15\n\rnormalizedNER\x18\t \x01(\t\x12\r\n\x05lemma\x18\n \x01(\t\x12\x11\n\tbeginChar\x18\x0b \x01(\r\x12\x0f\n\x07\x65ndChar\x18\x0c \x01(\r\x12\x11\n\tutterance\x18\r \x01(\r\x12\x0f\n\x07speaker\x18\x0e \x01(\t\x12\x13\n\x0bspeakerType\x18M \x01(\t\x12\x12\n\nbeginIndex\x18\x0f \x01(\r\x12\x10\n\x08\x65ndIndex\x18\x10 \x01(\r\x12\x17\n\x0ftokenBeginIndex\x18\x11 \x01(\r\x12\x15\n\rtokenEndIndex\x18\x12 \x01(\r\x12\x34\n\ntimexValue\x18\x13 \x01(\x0b\x32 .edu.stanford.nlp.pipeline.Timex\x12\x15\n\rhasXmlContext\x18\x15 \x01(\x08\x12\x12\n\nxmlContext\x18\x16 \x03(\t\x12\x16\n\x0e\x63orefClusterID\x18\x17 \x01(\r\x12\x0e\n\x06\x61nswer\x18\x18 \x01(\t\x12\x15\n\rheadWordIndex\x18\x1a \x01(\r\x12\x35\n\x08operator\x18\x1b \x01(\x0b\x32#.edu.stanford.nlp.pipeline.Operator\x12\x35\n\x08polarity\x18\x1c \x01(\x0b\x32#.edu.stanford.nlp.pipeline.Polarity\x12\x14\n\x0cpolarity_dir\x18\' \x01(\t\x12-\n\x04span\x18\x1d \x01(\x0b\x32\x1f.edu.stanford.nlp.pipeline.Span\x12\x11\n\tsentiment\x18\x1e \x01(\t\x12\x16\n\x0equotationIndex\x18\x1f \x01(\x05\x12\x42\n\x0e\x63onllUFeatures\x18 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.MapStringString\x12\x11\n\tcoarseTag\x18! \x01(\t\x12\x38\n\x0f\x63onllUTokenSpan\x18\" \x01(\x0b\x32\x1f.edu.stanford.nlp.pipeline.Span\x12\x12\n\nconllUMisc\x18# \x01(\t\x12G\n\x13\x63onllUSecondaryDeps\x18$ \x01(\x0b\x32*.edu.stanford.nlp.pipeline.MapStringString\x12\x17\n\x0fwikipediaEntity\x18% \x01(\t\x12\x11\n\tisNewline\x18& \x01(\x08\x12\x0e\n\x06gender\x18\x33 \x01(\t\x12\x10\n\x08trueCase\x18\x34 \x01(\t\x12\x14\n\x0ctrueCaseText\x18\x35 \x01(\t\x12\x13\n\x0b\x63hineseChar\x18\x36 \x01(\t\x12\x12\n\nchineseSeg\x18\x37 \x01(\t\x12\x16\n\x0e\x63hineseXMLChar\x18< \x01(\t\x12\x11\n\tarabicSeg\x18L \x01(\t\x12\x13\n\x0bsectionName\x18\x38 \x01(\t\x12\x15\n\rsectionAuthor\x18\x39 \x01(\t\x12\x13\n\x0bsectionDate\x18: \x01(\t\x12\x17\n\x0fsectionEndLabel\x18; \x01(\t\x12\x0e\n\x06parent\x18= \x01(\t\x12\x19\n\x11\x63orefMentionIndex\x18@ \x03(\r\x12\x1a\n\x12\x65ntityMentionIndex\x18\x41 \x01(\r\x12\r\n\x05isMWT\x18\x43 \x01(\x08\x12\x12\n\nisFirstMWT\x18\x44 \x01(\x08\x12\x0f\n\x07mwtText\x18\x45 \x01(\t\x12\x0f\n\x07mwtMisc\x18N \x01(\t\x12\x14\n\x0cnumericValue\x18\x46 \x01(\x04\x12\x13\n\x0bnumericType\x18G \x01(\t\x12\x1d\n\x15numericCompositeValue\x18H \x01(\x04\x12\x1c\n\x14numericCompositeType\x18I \x01(\t\x12\x1c\n\x14\x63odepointOffsetBegin\x18J \x01(\r\x12\x1a\n\x12\x63odepointOffsetEnd\x18K \x01(\r\x12\r\n\x05index\x18O \x01(\r\x12\x12\n\nemptyIndex\x18P \x01(\r*\x05\x08\x64\x10\x80\x02\"\xe4\x03\n\x05Quote\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05\x62\x65gin\x18\x02 \x01(\r\x12\x0b\n\x03\x65nd\x18\x03 \x01(\r\x12\x15\n\rsentenceBegin\x18\x05 \x01(\r\x12\x13\n\x0bsentenceEnd\x18\x06 \x01(\r\x12\x12\n\ntokenBegin\x18\x07 \x01(\r\x12\x10\n\x08tokenEnd\x18\x08 \x01(\r\x12\r\n\x05\x64ocid\x18\t \x01(\t\x12\r\n\x05index\x18\n \x01(\r\x12\x0e\n\x06\x61uthor\x18\x0b \x01(\t\x12\x0f\n\x07mention\x18\x0c \x01(\t\x12\x14\n\x0cmentionBegin\x18\r \x01(\r\x12\x12\n\nmentionEnd\x18\x0e \x01(\r\x12\x13\n\x0bmentionType\x18\x0f \x01(\t\x12\x14\n\x0cmentionSieve\x18\x10 \x01(\t\x12\x0f\n\x07speaker\x18\x11 \x01(\t\x12\x14\n\x0cspeakerSieve\x18\x12 \x01(\t\x12\x18\n\x10\x63\x61nonicalMention\x18\x13 \x01(\t\x12\x1d\n\x15\x63\x61nonicalMentionBegin\x18\x14 \x01(\r\x12\x1b\n\x13\x63\x61nonicalMentionEnd\x18\x15 \x01(\r\x12N\n\x1a\x61ttributionDependencyGraph\x18\x16 
\x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\"\xc7\x01\n\tParseTree\x12\x33\n\x05\x63hild\x18\x01 \x03(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12\r\n\x05value\x18\x02 \x01(\t\x12\x17\n\x0fyieldBeginIndex\x18\x03 \x01(\r\x12\x15\n\ryieldEndIndex\x18\x04 \x01(\r\x12\r\n\x05score\x18\x05 \x01(\x01\x12\x37\n\tsentiment\x18\x06 \x01(\x0e\x32$.edu.stanford.nlp.pipeline.Sentiment\"\x9b\x04\n\x0f\x44\x65pendencyGraph\x12=\n\x04node\x18\x01 \x03(\x0b\x32/.edu.stanford.nlp.pipeline.DependencyGraph.Node\x12=\n\x04\x65\x64ge\x18\x02 \x03(\x0b\x32/.edu.stanford.nlp.pipeline.DependencyGraph.Edge\x12\x10\n\x04root\x18\x03 \x03(\rB\x02\x10\x01\x12/\n\x05token\x18\x04 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x14\n\x08rootNode\x18\x05 \x03(\rB\x02\x10\x01\x1aX\n\x04Node\x12\x15\n\rsentenceIndex\x18\x01 \x02(\r\x12\r\n\x05index\x18\x02 \x02(\r\x12\x16\n\x0e\x63opyAnnotation\x18\x03 \x01(\r\x12\x12\n\nemptyIndex\x18\x04 \x01(\r\x1a\xd6\x01\n\x04\x45\x64ge\x12\x0e\n\x06source\x18\x01 \x02(\r\x12\x0e\n\x06target\x18\x02 \x02(\r\x12\x0b\n\x03\x64\x65p\x18\x03 \x01(\t\x12\x0f\n\x07isExtra\x18\x04 \x01(\x08\x12\x12\n\nsourceCopy\x18\x05 \x01(\r\x12\x12\n\ntargetCopy\x18\x06 \x01(\r\x12\x13\n\x0bsourceEmpty\x18\x08 \x01(\r\x12\x13\n\x0btargetEmpty\x18\t \x01(\r\x12>\n\x08language\x18\x07 \x01(\x0e\x32#.edu.stanford.nlp.pipeline.Language:\x07Unknown\"\xc6\x02\n\nCorefChain\x12\x0f\n\x07\x63hainID\x18\x01 \x02(\x05\x12\x43\n\x07mention\x18\x02 \x03(\x0b\x32\x32.edu.stanford.nlp.pipeline.CorefChain.CorefMention\x12\x16\n\x0erepresentative\x18\x03 \x02(\r\x1a\xc9\x01\n\x0c\x43orefMention\x12\x11\n\tmentionID\x18\x01 \x01(\x05\x12\x13\n\x0bmentionType\x18\x02 \x01(\t\x12\x0e\n\x06number\x18\x03 \x01(\t\x12\x0e\n\x06gender\x18\x04 \x01(\t\x12\x0f\n\x07\x61nimacy\x18\x05 \x01(\t\x12\x12\n\nbeginIndex\x18\x06 \x01(\r\x12\x10\n\x08\x65ndIndex\x18\x07 \x01(\r\x12\x11\n\theadIndex\x18\t \x01(\r\x12\x15\n\rsentenceIndex\x18\n \x01(\r\x12\x10\n\x08position\x18\x0b \x01(\r\"\xef\x08\n\x07Mention\x12\x11\n\tmentionID\x18\x01 \x01(\x05\x12\x13\n\x0bmentionType\x18\x02 \x01(\t\x12\x0e\n\x06number\x18\x03 \x01(\t\x12\x0e\n\x06gender\x18\x04 \x01(\t\x12\x0f\n\x07\x61nimacy\x18\x05 \x01(\t\x12\x0e\n\x06person\x18\x06 \x01(\t\x12\x12\n\nstartIndex\x18\x07 \x01(\r\x12\x10\n\x08\x65ndIndex\x18\t \x01(\r\x12\x11\n\theadIndex\x18\n \x01(\x05\x12\x12\n\nheadString\x18\x0b \x01(\t\x12\x11\n\tnerString\x18\x0c \x01(\t\x12\x13\n\x0boriginalRef\x18\r \x01(\x05\x12\x1a\n\x12goldCorefClusterID\x18\x0e \x01(\x05\x12\x16\n\x0e\x63orefClusterID\x18\x0f \x01(\x05\x12\x12\n\nmentionNum\x18\x10 \x01(\x05\x12\x0f\n\x07sentNum\x18\x11 \x01(\x05\x12\r\n\x05utter\x18\x12 \x01(\x05\x12\x11\n\tparagraph\x18\x13 \x01(\x05\x12\x11\n\tisSubject\x18\x14 \x01(\x08\x12\x16\n\x0eisDirectObject\x18\x15 \x01(\x08\x12\x18\n\x10isIndirectObject\x18\x16 \x01(\x08\x12\x1b\n\x13isPrepositionObject\x18\x17 \x01(\x08\x12\x0f\n\x07hasTwin\x18\x18 \x01(\x08\x12\x0f\n\x07generic\x18\x19 \x01(\x08\x12\x13\n\x0bisSingleton\x18\x1a \x01(\x08\x12\x1a\n\x12hasBasicDependency\x18\x1b \x01(\x08\x12\x1d\n\x15hasEnhancedDependency\x18\x1c \x01(\x08\x12\x1b\n\x13hasContextParseTree\x18\x1d \x01(\x08\x12?\n\x0fheadIndexedWord\x18\x1e \x01(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12=\n\rdependingVerb\x18\x1f \x01(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12\x38\n\x08headWord\x18 \x01(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12;\n\x0bspeakerInfo\x18! 
\x01(\x0b\x32&.edu.stanford.nlp.pipeline.SpeakerInfo\x12=\n\rsentenceWords\x18\x32 \x03(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12<\n\x0coriginalSpan\x18\x33 \x03(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12\x12\n\ndependents\x18\x34 \x03(\t\x12\x19\n\x11preprocessedTerms\x18\x35 \x03(\t\x12\x13\n\x0b\x61ppositions\x18\x36 \x03(\x05\x12\x1c\n\x14predicateNominatives\x18\x37 \x03(\x05\x12\x18\n\x10relativePronouns\x18\x38 \x03(\x05\x12\x13\n\x0blistMembers\x18\x39 \x03(\x05\x12\x15\n\rbelongToLists\x18: \x03(\x05\"X\n\x0bIndexedWord\x12\x13\n\x0bsentenceNum\x18\x01 \x01(\x05\x12\x12\n\ntokenIndex\x18\x02 \x01(\x05\x12\r\n\x05\x64ocID\x18\x03 \x01(\x05\x12\x11\n\tcopyCount\x18\x04 \x01(\r\"4\n\x0bSpeakerInfo\x12\x13\n\x0bspeakerName\x18\x01 \x01(\t\x12\x10\n\x08mentions\x18\x02 \x03(\x05\"\"\n\x04Span\x12\r\n\x05\x62\x65gin\x18\x01 \x02(\r\x12\x0b\n\x03\x65nd\x18\x02 \x02(\r\"w\n\x05Timex\x12\r\n\x05value\x18\x01 \x01(\t\x12\x10\n\x08\x61ltValue\x18\x02 \x01(\t\x12\x0c\n\x04text\x18\x03 \x01(\t\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x0b\n\x03tid\x18\x05 \x01(\t\x12\x12\n\nbeginPoint\x18\x06 \x01(\r\x12\x10\n\x08\x65ndPoint\x18\x07 \x01(\r\"\xdb\x01\n\x06\x45ntity\x12\x11\n\theadStart\x18\x06 \x01(\r\x12\x0f\n\x07headEnd\x18\x07 \x01(\r\x12\x13\n\x0bmentionType\x18\x08 \x01(\t\x12\x16\n\x0enormalizedName\x18\t \x01(\t\x12\x16\n\x0eheadTokenIndex\x18\n \x01(\r\x12\x0f\n\x07\x63orefID\x18\x0b \x01(\t\x12\x10\n\x08objectID\x18\x01 \x01(\t\x12\x13\n\x0b\x65xtentStart\x18\x02 \x01(\r\x12\x11\n\textentEnd\x18\x03 \x01(\r\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x0f\n\x07subtype\x18\x05 \x01(\t\"\xb7\x01\n\x08Relation\x12\x0f\n\x07\x61rgName\x18\x06 \x03(\t\x12.\n\x03\x61rg\x18\x07 \x03(\x0b\x32!.edu.stanford.nlp.pipeline.Entity\x12\x11\n\tsignature\x18\x08 \x01(\t\x12\x10\n\x08objectID\x18\x01 \x01(\t\x12\x13\n\x0b\x65xtentStart\x18\x02 \x01(\r\x12\x11\n\textentEnd\x18\x03 \x01(\r\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x0f\n\x07subtype\x18\x05 \x01(\t\"\xb2\x01\n\x08Operator\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x1b\n\x13quantifierSpanBegin\x18\x02 \x02(\x05\x12\x19\n\x11quantifierSpanEnd\x18\x03 \x02(\x05\x12\x18\n\x10subjectSpanBegin\x18\x04 \x02(\x05\x12\x16\n\x0esubjectSpanEnd\x18\x05 \x02(\x05\x12\x17\n\x0fobjectSpanBegin\x18\x06 \x02(\x05\x12\x15\n\robjectSpanEnd\x18\x07 \x02(\x05\"\xa9\x04\n\x08Polarity\x12K\n\x12projectEquivalence\x18\x01 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12Q\n\x18projectForwardEntailment\x18\x02 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12Q\n\x18projectReverseEntailment\x18\x03 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12H\n\x0fprojectNegation\x18\x04 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12K\n\x12projectAlternation\x18\x05 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12\x45\n\x0cprojectCover\x18\x06 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12L\n\x13projectIndependence\x18\x07 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\"\xdd\x02\n\nNERMention\x12\x15\n\rsentenceIndex\x18\x01 \x01(\r\x12%\n\x1dtokenStartInSentenceInclusive\x18\x02 \x02(\r\x12#\n\x1btokenEndInSentenceExclusive\x18\x03 \x02(\r\x12\x0b\n\x03ner\x18\x04 \x02(\t\x12\x15\n\rnormalizedNER\x18\x05 \x01(\t\x12\x12\n\nentityType\x18\x06 \x01(\t\x12/\n\x05timex\x18\x07 \x01(\x0b\x32 .edu.stanford.nlp.pipeline.Timex\x12\x17\n\x0fwikipediaEntity\x18\x08 \x01(\t\x12\x0e\n\x06gender\x18\t \x01(\t\x12\x1a\n\x12\x65ntityMentionIndex\x18\n 
\x01(\r\x12#\n\x1b\x63\x61nonicalEntityMentionIndex\x18\x0b \x01(\r\x12\x19\n\x11\x65ntityMentionText\x18\x0c \x01(\t\"Y\n\x10SentenceFragment\x12\x12\n\ntokenIndex\x18\x01 \x03(\r\x12\x0c\n\x04root\x18\x02 \x01(\r\x12\x14\n\x0c\x61ssumedTruth\x18\x03 \x01(\x08\x12\r\n\x05score\x18\x04 \x01(\x01\":\n\rTokenLocation\x12\x15\n\rsentenceIndex\x18\x01 \x01(\r\x12\x12\n\ntokenIndex\x18\x02 \x01(\r\"\x9a\x03\n\x0eRelationTriple\x12\x0f\n\x07subject\x18\x01 \x01(\t\x12\x10\n\x08relation\x18\x02 \x01(\t\x12\x0e\n\x06object\x18\x03 \x01(\t\x12\x12\n\nconfidence\x18\x04 \x01(\x01\x12?\n\rsubjectTokens\x18\r \x03(\x0b\x32(.edu.stanford.nlp.pipeline.TokenLocation\x12@\n\x0erelationTokens\x18\x0e \x03(\x0b\x32(.edu.stanford.nlp.pipeline.TokenLocation\x12>\n\x0cobjectTokens\x18\x0f \x03(\x0b\x32(.edu.stanford.nlp.pipeline.TokenLocation\x12\x38\n\x04tree\x18\x08 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12\x0e\n\x06istmod\x18\t \x01(\x08\x12\x10\n\x08prefixBe\x18\n \x01(\x08\x12\x10\n\x08suffixBe\x18\x0b \x01(\x08\x12\x10\n\x08suffixOf\x18\x0c \x01(\x08\"-\n\x0fMapStringString\x12\x0b\n\x03key\x18\x01 \x03(\t\x12\r\n\x05value\x18\x02 \x03(\t\"*\n\x0cMapIntString\x12\x0b\n\x03key\x18\x01 \x03(\r\x12\r\n\x05value\x18\x02 \x03(\t\"\xfc\x01\n\x07Section\x12\x11\n\tcharBegin\x18\x01 \x02(\r\x12\x0f\n\x07\x63harEnd\x18\x02 \x02(\r\x12\x0e\n\x06\x61uthor\x18\x03 \x01(\t\x12\x17\n\x0fsentenceIndexes\x18\x04 \x03(\r\x12\x10\n\x08\x64\x61tetime\x18\x05 \x01(\t\x12\x30\n\x06quotes\x18\x06 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Quote\x12\x17\n\x0f\x61uthorCharBegin\x18\x07 \x01(\r\x12\x15\n\rauthorCharEnd\x18\x08 \x01(\r\x12\x30\n\x06xmlTag\x18\t \x02(\x0b\x32 .edu.stanford.nlp.pipeline.Token\"\xe4\x01\n\x0eSemgrexRequest\x12\x0f\n\x07semgrex\x18\x01 \x03(\t\x12\x45\n\x05query\x18\x02 \x03(\x0b\x32\x36.edu.stanford.nlp.pipeline.SemgrexRequest.Dependencies\x1az\n\x0c\x44\x65pendencies\x12/\n\x05token\x18\x01 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x39\n\x05graph\x18\x02 \x02(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\"\xfb\x05\n\x0fSemgrexResponse\x12\x46\n\x06result\x18\x01 \x03(\x0b\x32\x36.edu.stanford.nlp.pipeline.SemgrexResponse.GraphResult\x1a-\n\tNamedNode\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x12\n\nmatchIndex\x18\x02 \x02(\x05\x1a+\n\rNamedRelation\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x0c\n\x04reln\x18\x02 \x02(\t\x1a\x80\x01\n\tNamedEdge\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x0e\n\x06source\x18\x02 \x02(\x05\x12\x0e\n\x06target\x18\x03 \x02(\x05\x12\x0c\n\x04reln\x18\x04 \x01(\t\x12\x0f\n\x07isExtra\x18\x05 \x01(\x08\x12\x12\n\nsourceCopy\x18\x06 \x01(\r\x12\x12\n\ntargetCopy\x18\x07 \x01(\r\x1a\x95\x02\n\x05Match\x12\x12\n\nmatchIndex\x18\x01 \x02(\x05\x12\x42\n\x04node\x18\x02 \x03(\x0b\x32\x34.edu.stanford.nlp.pipeline.SemgrexResponse.NamedNode\x12\x46\n\x04reln\x18\x03 \x03(\x0b\x32\x38.edu.stanford.nlp.pipeline.SemgrexResponse.NamedRelation\x12\x42\n\x04\x65\x64ge\x18\x06 \x03(\x0b\x32\x34.edu.stanford.nlp.pipeline.SemgrexResponse.NamedEdge\x12\x12\n\ngraphIndex\x18\x04 \x01(\x05\x12\x14\n\x0csemgrexIndex\x18\x05 \x01(\x05\x1aP\n\rSemgrexResult\x12?\n\x05match\x18\x01 \x03(\x0b\x32\x30.edu.stanford.nlp.pipeline.SemgrexResponse.Match\x1aW\n\x0bGraphResult\x12H\n\x06result\x18\x01 \x03(\x0b\x32\x38.edu.stanford.nlp.pipeline.SemgrexResponse.SemgrexResult\"\xf0\x01\n\x0fSsurgeonRequest\x12\x45\n\x08ssurgeon\x18\x01 \x03(\x0b\x32\x33.edu.stanford.nlp.pipeline.SsurgeonRequest.Ssurgeon\x12\x39\n\x05graph\x18\x02 
\x03(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x1a[\n\x08Ssurgeon\x12\x0f\n\x07semgrex\x18\x01 \x01(\t\x12\x11\n\toperation\x18\x02 \x03(\t\x12\n\n\x02id\x18\x03 \x01(\t\x12\r\n\x05notes\x18\x04 \x01(\t\x12\x10\n\x08language\x18\x05 \x01(\t\"\xbc\x01\n\x10SsurgeonResponse\x12J\n\x06result\x18\x01 \x03(\x0b\x32:.edu.stanford.nlp.pipeline.SsurgeonResponse.SsurgeonResult\x1a\\\n\x0eSsurgeonResult\x12\x39\n\x05graph\x18\x01 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12\x0f\n\x07\x63hanged\x18\x02 \x01(\x08\"W\n\x12TokensRegexRequest\x12\x30\n\x03\x64oc\x18\x01 \x02(\x0b\x32#.edu.stanford.nlp.pipeline.Document\x12\x0f\n\x07pattern\x18\x02 \x03(\t\"\xa7\x03\n\x13TokensRegexResponse\x12J\n\x05match\x18\x01 \x03(\x0b\x32;.edu.stanford.nlp.pipeline.TokensRegexResponse.PatternMatch\x1a\x39\n\rMatchLocation\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05\x62\x65gin\x18\x02 \x01(\x05\x12\x0b\n\x03\x65nd\x18\x03 \x01(\x05\x1a\xb3\x01\n\x05Match\x12\x10\n\x08sentence\x18\x01 \x02(\x05\x12K\n\x05match\x18\x02 \x02(\x0b\x32<.edu.stanford.nlp.pipeline.TokensRegexResponse.MatchLocation\x12K\n\x05group\x18\x03 \x03(\x0b\x32<.edu.stanford.nlp.pipeline.TokensRegexResponse.MatchLocation\x1aS\n\x0cPatternMatch\x12\x43\n\x05match\x18\x01 \x03(\x0b\x32\x34.edu.stanford.nlp.pipeline.TokensRegexResponse.Match\"\xae\x01\n\x19\x44\x65pendencyEnhancerRequest\x12\x35\n\x08\x64ocument\x18\x01 \x02(\x0b\x32#.edu.stanford.nlp.pipeline.Document\x12\x37\n\x08language\x18\x02 \x01(\x0e\x32#.edu.stanford.nlp.pipeline.LanguageH\x00\x12\x1a\n\x10relativePronouns\x18\x03 \x01(\tH\x00\x42\x05\n\x03ref\"\xb4\x01\n\x12\x46lattenedParseTree\x12\x41\n\x05nodes\x18\x01 \x03(\x0b\x32\x32.edu.stanford.nlp.pipeline.FlattenedParseTree.Node\x1a[\n\x04Node\x12\x12\n\x08openNode\x18\x01 \x01(\x08H\x00\x12\x13\n\tcloseNode\x18\x02 \x01(\x08H\x00\x12\x0f\n\x05value\x18\x03 \x01(\tH\x00\x12\r\n\x05score\x18\x04 \x01(\x01\x42\n\n\x08\x63ontents\"\xf6\x01\n\x15\x45valuateParserRequest\x12N\n\x08treebank\x18\x01 \x03(\x0b\x32<.edu.stanford.nlp.pipeline.EvaluateParserRequest.ParseResult\x1a\x8c\x01\n\x0bParseResult\x12;\n\x04gold\x18\x01 \x02(\x0b\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\x12@\n\tpredicted\x18\x02 \x03(\x0b\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\"E\n\x16\x45valuateParserResponse\x12\n\n\x02\x66\x31\x18\x01 \x02(\x01\x12\x0f\n\x07kbestF1\x18\x02 \x01(\x01\x12\x0e\n\x06treeF1\x18\x03 \x03(\x01\"\xc8\x01\n\x0fTsurgeonRequest\x12H\n\noperations\x18\x01 \x03(\x0b\x32\x34.edu.stanford.nlp.pipeline.TsurgeonRequest.Operation\x12<\n\x05trees\x18\x02 \x03(\x0b\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\x1a-\n\tOperation\x12\x0e\n\x06tregex\x18\x01 \x02(\t\x12\x10\n\x08tsurgeon\x18\x02 \x03(\t\"P\n\x10TsurgeonResponse\x12<\n\x05trees\x18\x01 \x03(\x0b\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\"\x85\x01\n\x11MorphologyRequest\x12\x46\n\x05words\x18\x01 \x03(\x0b\x32\x37.edu.stanford.nlp.pipeline.MorphologyRequest.TaggedWord\x1a(\n\nTaggedWord\x12\x0c\n\x04word\x18\x01 \x02(\t\x12\x0c\n\x04xpos\x18\x02 \x01(\t\"\x9a\x01\n\x12MorphologyResponse\x12I\n\x05words\x18\x01 \x03(\x0b\x32:.edu.stanford.nlp.pipeline.MorphologyResponse.WordTagLemma\x1a\x39\n\x0cWordTagLemma\x12\x0c\n\x04word\x18\x01 \x02(\t\x12\x0c\n\x04xpos\x18\x02 \x01(\t\x12\r\n\x05lemma\x18\x03 \x02(\t\"Z\n\x1a\x44\x65pendencyConverterRequest\x12<\n\x05trees\x18\x01 \x03(\x0b\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\"\x90\x02\n\x1b\x44\x65pendencyConverterResponse\x12`\n\x0b\x63onversions\x18\x01 
\x03(\x0b\x32K.edu.stanford.nlp.pipeline.DependencyConverterResponse.DependencyConversion\x1a\x8e\x01\n\x14\x44\x65pendencyConversion\x12\x39\n\x05graph\x18\x01 \x02(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12;\n\x04tree\x18\x02 \x01(\x0b\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree*\xa3\x01\n\x08Language\x12\x0b\n\x07Unknown\x10\x00\x12\x07\n\x03\x41ny\x10\x01\x12\n\n\x06\x41rabic\x10\x02\x12\x0b\n\x07\x43hinese\x10\x03\x12\x0b\n\x07\x45nglish\x10\x04\x12\n\n\x06German\x10\x05\x12\n\n\x06\x46rench\x10\x06\x12\n\n\x06Hebrew\x10\x07\x12\x0b\n\x07Spanish\x10\x08\x12\x14\n\x10UniversalEnglish\x10\t\x12\x14\n\x10UniversalChinese\x10\n*h\n\tSentiment\x12\x13\n\x0fSTRONG_NEGATIVE\x10\x00\x12\x11\n\rWEAK_NEGATIVE\x10\x01\x12\x0b\n\x07NEUTRAL\x10\x02\x12\x11\n\rWEAK_POSITIVE\x10\x03\x12\x13\n\x0fSTRONG_POSITIVE\x10\x04*\x93\x01\n\x14NaturalLogicRelation\x12\x0f\n\x0b\x45QUIVALENCE\x10\x00\x12\x16\n\x12\x46ORWARD_ENTAILMENT\x10\x01\x12\x16\n\x12REVERSE_ENTAILMENT\x10\x02\x12\x0c\n\x08NEGATION\x10\x03\x12\x0f\n\x0b\x41LTERNATION\x10\x04\x12\t\n\x05\x43OVER\x10\x05\x12\x10\n\x0cINDEPENDENCE\x10\x06\x42*\n\x19\x65\x64u.stanford.nlp.pipelineB\rCoreNLPProtos')
19
+
20
+ _LANGUAGE = DESCRIPTOR.enum_types_by_name['Language']
21
+ Language = enum_type_wrapper.EnumTypeWrapper(_LANGUAGE)
22
+ _SENTIMENT = DESCRIPTOR.enum_types_by_name['Sentiment']
23
+ Sentiment = enum_type_wrapper.EnumTypeWrapper(_SENTIMENT)
24
+ _NATURALLOGICRELATION = DESCRIPTOR.enum_types_by_name['NaturalLogicRelation']
25
+ NaturalLogicRelation = enum_type_wrapper.EnumTypeWrapper(_NATURALLOGICRELATION)
26
+ Unknown = 0
27
+ Any = 1
28
+ Arabic = 2
29
+ Chinese = 3
30
+ English = 4
31
+ German = 5
32
+ French = 6
33
+ Hebrew = 7
34
+ Spanish = 8
35
+ UniversalEnglish = 9
36
+ UniversalChinese = 10
37
+ STRONG_NEGATIVE = 0
38
+ WEAK_NEGATIVE = 1
39
+ NEUTRAL = 2
40
+ WEAK_POSITIVE = 3
41
+ STRONG_POSITIVE = 4
42
+ EQUIVALENCE = 0
43
+ FORWARD_ENTAILMENT = 1
44
+ REVERSE_ENTAILMENT = 2
45
+ NEGATION = 3
46
+ ALTERNATION = 4
47
+ COVER = 5
48
+ INDEPENDENCE = 6
49
+
50
+
51
+ _DOCUMENT = DESCRIPTOR.message_types_by_name['Document']
52
+ _SENTENCE = DESCRIPTOR.message_types_by_name['Sentence']
53
+ _TOKEN = DESCRIPTOR.message_types_by_name['Token']
54
+ _QUOTE = DESCRIPTOR.message_types_by_name['Quote']
55
+ _PARSETREE = DESCRIPTOR.message_types_by_name['ParseTree']
56
+ _DEPENDENCYGRAPH = DESCRIPTOR.message_types_by_name['DependencyGraph']
57
+ _DEPENDENCYGRAPH_NODE = _DEPENDENCYGRAPH.nested_types_by_name['Node']
58
+ _DEPENDENCYGRAPH_EDGE = _DEPENDENCYGRAPH.nested_types_by_name['Edge']
59
+ _COREFCHAIN = DESCRIPTOR.message_types_by_name['CorefChain']
60
+ _COREFCHAIN_COREFMENTION = _COREFCHAIN.nested_types_by_name['CorefMention']
61
+ _MENTION = DESCRIPTOR.message_types_by_name['Mention']
62
+ _INDEXEDWORD = DESCRIPTOR.message_types_by_name['IndexedWord']
63
+ _SPEAKERINFO = DESCRIPTOR.message_types_by_name['SpeakerInfo']
64
+ _SPAN = DESCRIPTOR.message_types_by_name['Span']
65
+ _TIMEX = DESCRIPTOR.message_types_by_name['Timex']
66
+ _ENTITY = DESCRIPTOR.message_types_by_name['Entity']
67
+ _RELATION = DESCRIPTOR.message_types_by_name['Relation']
68
+ _OPERATOR = DESCRIPTOR.message_types_by_name['Operator']
69
+ _POLARITY = DESCRIPTOR.message_types_by_name['Polarity']
70
+ _NERMENTION = DESCRIPTOR.message_types_by_name['NERMention']
71
+ _SENTENCEFRAGMENT = DESCRIPTOR.message_types_by_name['SentenceFragment']
72
+ _TOKENLOCATION = DESCRIPTOR.message_types_by_name['TokenLocation']
73
+ _RELATIONTRIPLE = DESCRIPTOR.message_types_by_name['RelationTriple']
74
+ _MAPSTRINGSTRING = DESCRIPTOR.message_types_by_name['MapStringString']
75
+ _MAPINTSTRING = DESCRIPTOR.message_types_by_name['MapIntString']
76
+ _SECTION = DESCRIPTOR.message_types_by_name['Section']
77
+ _SEMGREXREQUEST = DESCRIPTOR.message_types_by_name['SemgrexRequest']
78
+ _SEMGREXREQUEST_DEPENDENCIES = _SEMGREXREQUEST.nested_types_by_name['Dependencies']
79
+ _SEMGREXRESPONSE = DESCRIPTOR.message_types_by_name['SemgrexResponse']
80
+ _SEMGREXRESPONSE_NAMEDNODE = _SEMGREXRESPONSE.nested_types_by_name['NamedNode']
81
+ _SEMGREXRESPONSE_NAMEDRELATION = _SEMGREXRESPONSE.nested_types_by_name['NamedRelation']
82
+ _SEMGREXRESPONSE_NAMEDEDGE = _SEMGREXRESPONSE.nested_types_by_name['NamedEdge']
83
+ _SEMGREXRESPONSE_MATCH = _SEMGREXRESPONSE.nested_types_by_name['Match']
84
+ _SEMGREXRESPONSE_SEMGREXRESULT = _SEMGREXRESPONSE.nested_types_by_name['SemgrexResult']
85
+ _SEMGREXRESPONSE_GRAPHRESULT = _SEMGREXRESPONSE.nested_types_by_name['GraphResult']
86
+ _SSURGEONREQUEST = DESCRIPTOR.message_types_by_name['SsurgeonRequest']
87
+ _SSURGEONREQUEST_SSURGEON = _SSURGEONREQUEST.nested_types_by_name['Ssurgeon']
88
+ _SSURGEONRESPONSE = DESCRIPTOR.message_types_by_name['SsurgeonResponse']
89
+ _SSURGEONRESPONSE_SSURGEONRESULT = _SSURGEONRESPONSE.nested_types_by_name['SsurgeonResult']
90
+ _TOKENSREGEXREQUEST = DESCRIPTOR.message_types_by_name['TokensRegexRequest']
91
+ _TOKENSREGEXRESPONSE = DESCRIPTOR.message_types_by_name['TokensRegexResponse']
92
+ _TOKENSREGEXRESPONSE_MATCHLOCATION = _TOKENSREGEXRESPONSE.nested_types_by_name['MatchLocation']
93
+ _TOKENSREGEXRESPONSE_MATCH = _TOKENSREGEXRESPONSE.nested_types_by_name['Match']
94
+ _TOKENSREGEXRESPONSE_PATTERNMATCH = _TOKENSREGEXRESPONSE.nested_types_by_name['PatternMatch']
95
+ _DEPENDENCYENHANCERREQUEST = DESCRIPTOR.message_types_by_name['DependencyEnhancerRequest']
96
+ _FLATTENEDPARSETREE = DESCRIPTOR.message_types_by_name['FlattenedParseTree']
97
+ _FLATTENEDPARSETREE_NODE = _FLATTENEDPARSETREE.nested_types_by_name['Node']
98
+ _EVALUATEPARSERREQUEST = DESCRIPTOR.message_types_by_name['EvaluateParserRequest']
99
+ _EVALUATEPARSERREQUEST_PARSERESULT = _EVALUATEPARSERREQUEST.nested_types_by_name['ParseResult']
100
+ _EVALUATEPARSERRESPONSE = DESCRIPTOR.message_types_by_name['EvaluateParserResponse']
101
+ _TSURGEONREQUEST = DESCRIPTOR.message_types_by_name['TsurgeonRequest']
102
+ _TSURGEONREQUEST_OPERATION = _TSURGEONREQUEST.nested_types_by_name['Operation']
103
+ _TSURGEONRESPONSE = DESCRIPTOR.message_types_by_name['TsurgeonResponse']
104
+ _MORPHOLOGYREQUEST = DESCRIPTOR.message_types_by_name['MorphologyRequest']
105
+ _MORPHOLOGYREQUEST_TAGGEDWORD = _MORPHOLOGYREQUEST.nested_types_by_name['TaggedWord']
106
+ _MORPHOLOGYRESPONSE = DESCRIPTOR.message_types_by_name['MorphologyResponse']
107
+ _MORPHOLOGYRESPONSE_WORDTAGLEMMA = _MORPHOLOGYRESPONSE.nested_types_by_name['WordTagLemma']
108
+ _DEPENDENCYCONVERTERREQUEST = DESCRIPTOR.message_types_by_name['DependencyConverterRequest']
109
+ _DEPENDENCYCONVERTERRESPONSE = DESCRIPTOR.message_types_by_name['DependencyConverterResponse']
110
+ _DEPENDENCYCONVERTERRESPONSE_DEPENDENCYCONVERSION = _DEPENDENCYCONVERTERRESPONSE.nested_types_by_name['DependencyConversion']
111
+ Document = _reflection.GeneratedProtocolMessageType('Document', (_message.Message,), {
112
+ 'DESCRIPTOR' : _DOCUMENT,
113
+ '__module__' : 'CoreNLP_pb2'
114
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.Document)
115
+ })
116
+ _sym_db.RegisterMessage(Document)
117
+
118
+ Sentence = _reflection.GeneratedProtocolMessageType('Sentence', (_message.Message,), {
119
+ 'DESCRIPTOR' : _SENTENCE,
120
+ '__module__' : 'CoreNLP_pb2'
121
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.Sentence)
122
+ })
123
+ _sym_db.RegisterMessage(Sentence)
124
+
125
+ Token = _reflection.GeneratedProtocolMessageType('Token', (_message.Message,), {
126
+ 'DESCRIPTOR' : _TOKEN,
127
+ '__module__' : 'CoreNLP_pb2'
128
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.Token)
129
+ })
130
+ _sym_db.RegisterMessage(Token)
131
+
132
+ Quote = _reflection.GeneratedProtocolMessageType('Quote', (_message.Message,), {
133
+ 'DESCRIPTOR' : _QUOTE,
134
+ '__module__' : 'CoreNLP_pb2'
135
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.Quote)
136
+ })
137
+ _sym_db.RegisterMessage(Quote)
138
+
139
+ ParseTree = _reflection.GeneratedProtocolMessageType('ParseTree', (_message.Message,), {
140
+ 'DESCRIPTOR' : _PARSETREE,
141
+ '__module__' : 'CoreNLP_pb2'
142
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.ParseTree)
143
+ })
144
+ _sym_db.RegisterMessage(ParseTree)
145
+
146
+ DependencyGraph = _reflection.GeneratedProtocolMessageType('DependencyGraph', (_message.Message,), {
147
+
148
+ 'Node' : _reflection.GeneratedProtocolMessageType('Node', (_message.Message,), {
149
+ 'DESCRIPTOR' : _DEPENDENCYGRAPH_NODE,
150
+ '__module__' : 'CoreNLP_pb2'
151
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.DependencyGraph.Node)
152
+ })
153
+ ,
154
+
155
+ 'Edge' : _reflection.GeneratedProtocolMessageType('Edge', (_message.Message,), {
156
+ 'DESCRIPTOR' : _DEPENDENCYGRAPH_EDGE,
157
+ '__module__' : 'CoreNLP_pb2'
158
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.DependencyGraph.Edge)
159
+ })
160
+ ,
161
+ 'DESCRIPTOR' : _DEPENDENCYGRAPH,
162
+ '__module__' : 'CoreNLP_pb2'
163
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.DependencyGraph)
164
+ })
165
+ _sym_db.RegisterMessage(DependencyGraph)
166
+ _sym_db.RegisterMessage(DependencyGraph.Node)
167
+ _sym_db.RegisterMessage(DependencyGraph.Edge)
168
+
169
+ CorefChain = _reflection.GeneratedProtocolMessageType('CorefChain', (_message.Message,), {
170
+
171
+ 'CorefMention' : _reflection.GeneratedProtocolMessageType('CorefMention', (_message.Message,), {
172
+ 'DESCRIPTOR' : _COREFCHAIN_COREFMENTION,
173
+ '__module__' : 'CoreNLP_pb2'
174
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.CorefChain.CorefMention)
175
+ })
176
+ ,
177
+ 'DESCRIPTOR' : _COREFCHAIN,
178
+ '__module__' : 'CoreNLP_pb2'
179
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.CorefChain)
180
+ })
181
+ _sym_db.RegisterMessage(CorefChain)
182
+ _sym_db.RegisterMessage(CorefChain.CorefMention)
183
+
184
+ Mention = _reflection.GeneratedProtocolMessageType('Mention', (_message.Message,), {
185
+ 'DESCRIPTOR' : _MENTION,
186
+ '__module__' : 'CoreNLP_pb2'
187
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.Mention)
188
+ })
189
+ _sym_db.RegisterMessage(Mention)
190
+
191
+ IndexedWord = _reflection.GeneratedProtocolMessageType('IndexedWord', (_message.Message,), {
192
+ 'DESCRIPTOR' : _INDEXEDWORD,
193
+ '__module__' : 'CoreNLP_pb2'
194
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.IndexedWord)
195
+ })
196
+ _sym_db.RegisterMessage(IndexedWord)
197
+
198
+ SpeakerInfo = _reflection.GeneratedProtocolMessageType('SpeakerInfo', (_message.Message,), {
199
+ 'DESCRIPTOR' : _SPEAKERINFO,
200
+ '__module__' : 'CoreNLP_pb2'
201
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SpeakerInfo)
202
+ })
203
+ _sym_db.RegisterMessage(SpeakerInfo)
204
+
205
+ Span = _reflection.GeneratedProtocolMessageType('Span', (_message.Message,), {
206
+ 'DESCRIPTOR' : _SPAN,
207
+ '__module__' : 'CoreNLP_pb2'
208
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.Span)
209
+ })
210
+ _sym_db.RegisterMessage(Span)
211
+
212
+ Timex = _reflection.GeneratedProtocolMessageType('Timex', (_message.Message,), {
213
+ 'DESCRIPTOR' : _TIMEX,
214
+ '__module__' : 'CoreNLP_pb2'
215
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.Timex)
216
+ })
217
+ _sym_db.RegisterMessage(Timex)
218
+
219
+ Entity = _reflection.GeneratedProtocolMessageType('Entity', (_message.Message,), {
220
+ 'DESCRIPTOR' : _ENTITY,
221
+ '__module__' : 'CoreNLP_pb2'
222
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.Entity)
223
+ })
224
+ _sym_db.RegisterMessage(Entity)
225
+
226
+ Relation = _reflection.GeneratedProtocolMessageType('Relation', (_message.Message,), {
227
+ 'DESCRIPTOR' : _RELATION,
228
+ '__module__' : 'CoreNLP_pb2'
229
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.Relation)
230
+ })
231
+ _sym_db.RegisterMessage(Relation)
232
+
233
+ Operator = _reflection.GeneratedProtocolMessageType('Operator', (_message.Message,), {
234
+ 'DESCRIPTOR' : _OPERATOR,
235
+ '__module__' : 'CoreNLP_pb2'
236
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.Operator)
237
+ })
238
+ _sym_db.RegisterMessage(Operator)
239
+
240
+ Polarity = _reflection.GeneratedProtocolMessageType('Polarity', (_message.Message,), {
241
+ 'DESCRIPTOR' : _POLARITY,
242
+ '__module__' : 'CoreNLP_pb2'
243
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.Polarity)
244
+ })
245
+ _sym_db.RegisterMessage(Polarity)
246
+
247
+ NERMention = _reflection.GeneratedProtocolMessageType('NERMention', (_message.Message,), {
248
+ 'DESCRIPTOR' : _NERMENTION,
249
+ '__module__' : 'CoreNLP_pb2'
250
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.NERMention)
251
+ })
252
+ _sym_db.RegisterMessage(NERMention)
253
+
254
+ SentenceFragment = _reflection.GeneratedProtocolMessageType('SentenceFragment', (_message.Message,), {
255
+ 'DESCRIPTOR' : _SENTENCEFRAGMENT,
256
+ '__module__' : 'CoreNLP_pb2'
257
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SentenceFragment)
258
+ })
259
+ _sym_db.RegisterMessage(SentenceFragment)
260
+
261
+ TokenLocation = _reflection.GeneratedProtocolMessageType('TokenLocation', (_message.Message,), {
262
+ 'DESCRIPTOR' : _TOKENLOCATION,
263
+ '__module__' : 'CoreNLP_pb2'
264
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.TokenLocation)
265
+ })
266
+ _sym_db.RegisterMessage(TokenLocation)
267
+
268
+ RelationTriple = _reflection.GeneratedProtocolMessageType('RelationTriple', (_message.Message,), {
269
+ 'DESCRIPTOR' : _RELATIONTRIPLE,
270
+ '__module__' : 'CoreNLP_pb2'
271
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.RelationTriple)
272
+ })
273
+ _sym_db.RegisterMessage(RelationTriple)
274
+
275
+ MapStringString = _reflection.GeneratedProtocolMessageType('MapStringString', (_message.Message,), {
276
+ 'DESCRIPTOR' : _MAPSTRINGSTRING,
277
+ '__module__' : 'CoreNLP_pb2'
278
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.MapStringString)
279
+ })
280
+ _sym_db.RegisterMessage(MapStringString)
281
+
282
+ MapIntString = _reflection.GeneratedProtocolMessageType('MapIntString', (_message.Message,), {
283
+ 'DESCRIPTOR' : _MAPINTSTRING,
284
+ '__module__' : 'CoreNLP_pb2'
285
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.MapIntString)
286
+ })
287
+ _sym_db.RegisterMessage(MapIntString)
288
+
289
+ Section = _reflection.GeneratedProtocolMessageType('Section', (_message.Message,), {
290
+ 'DESCRIPTOR' : _SECTION,
291
+ '__module__' : 'CoreNLP_pb2'
292
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.Section)
293
+ })
294
+ _sym_db.RegisterMessage(Section)
295
+
296
+ SemgrexRequest = _reflection.GeneratedProtocolMessageType('SemgrexRequest', (_message.Message,), {
297
+
298
+ 'Dependencies' : _reflection.GeneratedProtocolMessageType('Dependencies', (_message.Message,), {
299
+ 'DESCRIPTOR' : _SEMGREXREQUEST_DEPENDENCIES,
300
+ '__module__' : 'CoreNLP_pb2'
301
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SemgrexRequest.Dependencies)
302
+ })
303
+ ,
304
+ 'DESCRIPTOR' : _SEMGREXREQUEST,
305
+ '__module__' : 'CoreNLP_pb2'
306
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SemgrexRequest)
307
+ })
308
+ _sym_db.RegisterMessage(SemgrexRequest)
309
+ _sym_db.RegisterMessage(SemgrexRequest.Dependencies)
310
+
311
+ SemgrexResponse = _reflection.GeneratedProtocolMessageType('SemgrexResponse', (_message.Message,), {
312
+
313
+ 'NamedNode' : _reflection.GeneratedProtocolMessageType('NamedNode', (_message.Message,), {
314
+ 'DESCRIPTOR' : _SEMGREXRESPONSE_NAMEDNODE,
315
+ '__module__' : 'CoreNLP_pb2'
316
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SemgrexResponse.NamedNode)
317
+ })
318
+ ,
319
+
320
+ 'NamedRelation' : _reflection.GeneratedProtocolMessageType('NamedRelation', (_message.Message,), {
321
+ 'DESCRIPTOR' : _SEMGREXRESPONSE_NAMEDRELATION,
322
+ '__module__' : 'CoreNLP_pb2'
323
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SemgrexResponse.NamedRelation)
324
+ })
325
+ ,
326
+
327
+ 'NamedEdge' : _reflection.GeneratedProtocolMessageType('NamedEdge', (_message.Message,), {
328
+ 'DESCRIPTOR' : _SEMGREXRESPONSE_NAMEDEDGE,
329
+ '__module__' : 'CoreNLP_pb2'
330
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SemgrexResponse.NamedEdge)
331
+ })
332
+ ,
333
+
334
+ 'Match' : _reflection.GeneratedProtocolMessageType('Match', (_message.Message,), {
335
+ 'DESCRIPTOR' : _SEMGREXRESPONSE_MATCH,
336
+ '__module__' : 'CoreNLP_pb2'
337
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SemgrexResponse.Match)
338
+ })
339
+ ,
340
+
341
+ 'SemgrexResult' : _reflection.GeneratedProtocolMessageType('SemgrexResult', (_message.Message,), {
342
+ 'DESCRIPTOR' : _SEMGREXRESPONSE_SEMGREXRESULT,
343
+ '__module__' : 'CoreNLP_pb2'
344
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SemgrexResponse.SemgrexResult)
345
+ })
346
+ ,
347
+
348
+ 'GraphResult' : _reflection.GeneratedProtocolMessageType('GraphResult', (_message.Message,), {
349
+ 'DESCRIPTOR' : _SEMGREXRESPONSE_GRAPHRESULT,
350
+ '__module__' : 'CoreNLP_pb2'
351
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SemgrexResponse.GraphResult)
352
+ })
353
+ ,
354
+ 'DESCRIPTOR' : _SEMGREXRESPONSE,
355
+ '__module__' : 'CoreNLP_pb2'
356
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SemgrexResponse)
357
+ })
358
+ _sym_db.RegisterMessage(SemgrexResponse)
359
+ _sym_db.RegisterMessage(SemgrexResponse.NamedNode)
360
+ _sym_db.RegisterMessage(SemgrexResponse.NamedRelation)
361
+ _sym_db.RegisterMessage(SemgrexResponse.NamedEdge)
362
+ _sym_db.RegisterMessage(SemgrexResponse.Match)
363
+ _sym_db.RegisterMessage(SemgrexResponse.SemgrexResult)
364
+ _sym_db.RegisterMessage(SemgrexResponse.GraphResult)
365
+
366
+ SsurgeonRequest = _reflection.GeneratedProtocolMessageType('SsurgeonRequest', (_message.Message,), {
367
+
368
+ 'Ssurgeon' : _reflection.GeneratedProtocolMessageType('Ssurgeon', (_message.Message,), {
369
+ 'DESCRIPTOR' : _SSURGEONREQUEST_SSURGEON,
370
+ '__module__' : 'CoreNLP_pb2'
371
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SsurgeonRequest.Ssurgeon)
372
+ })
373
+ ,
374
+ 'DESCRIPTOR' : _SSURGEONREQUEST,
375
+ '__module__' : 'CoreNLP_pb2'
376
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SsurgeonRequest)
377
+ })
378
+ _sym_db.RegisterMessage(SsurgeonRequest)
379
+ _sym_db.RegisterMessage(SsurgeonRequest.Ssurgeon)
380
+
381
+ SsurgeonResponse = _reflection.GeneratedProtocolMessageType('SsurgeonResponse', (_message.Message,), {
382
+
383
+ 'SsurgeonResult' : _reflection.GeneratedProtocolMessageType('SsurgeonResult', (_message.Message,), {
384
+ 'DESCRIPTOR' : _SSURGEONRESPONSE_SSURGEONRESULT,
385
+ '__module__' : 'CoreNLP_pb2'
386
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SsurgeonResponse.SsurgeonResult)
387
+ })
388
+ ,
389
+ 'DESCRIPTOR' : _SSURGEONRESPONSE,
390
+ '__module__' : 'CoreNLP_pb2'
391
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SsurgeonResponse)
392
+ })
393
+ _sym_db.RegisterMessage(SsurgeonResponse)
394
+ _sym_db.RegisterMessage(SsurgeonResponse.SsurgeonResult)
395
+
396
+ TokensRegexRequest = _reflection.GeneratedProtocolMessageType('TokensRegexRequest', (_message.Message,), {
397
+ 'DESCRIPTOR' : _TOKENSREGEXREQUEST,
398
+ '__module__' : 'CoreNLP_pb2'
399
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.TokensRegexRequest)
400
+ })
401
+ _sym_db.RegisterMessage(TokensRegexRequest)
402
+
403
+ TokensRegexResponse = _reflection.GeneratedProtocolMessageType('TokensRegexResponse', (_message.Message,), {
404
+
405
+ 'MatchLocation' : _reflection.GeneratedProtocolMessageType('MatchLocation', (_message.Message,), {
406
+ 'DESCRIPTOR' : _TOKENSREGEXRESPONSE_MATCHLOCATION,
407
+ '__module__' : 'CoreNLP_pb2'
408
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.TokensRegexResponse.MatchLocation)
409
+ })
410
+ ,
411
+
412
+ 'Match' : _reflection.GeneratedProtocolMessageType('Match', (_message.Message,), {
413
+ 'DESCRIPTOR' : _TOKENSREGEXRESPONSE_MATCH,
414
+ '__module__' : 'CoreNLP_pb2'
415
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.TokensRegexResponse.Match)
416
+ })
417
+ ,
418
+
419
+ 'PatternMatch' : _reflection.GeneratedProtocolMessageType('PatternMatch', (_message.Message,), {
420
+ 'DESCRIPTOR' : _TOKENSREGEXRESPONSE_PATTERNMATCH,
421
+ '__module__' : 'CoreNLP_pb2'
422
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.TokensRegexResponse.PatternMatch)
423
+ })
424
+ ,
425
+ 'DESCRIPTOR' : _TOKENSREGEXRESPONSE,
426
+ '__module__' : 'CoreNLP_pb2'
427
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.TokensRegexResponse)
428
+ })
429
+ _sym_db.RegisterMessage(TokensRegexResponse)
430
+ _sym_db.RegisterMessage(TokensRegexResponse.MatchLocation)
431
+ _sym_db.RegisterMessage(TokensRegexResponse.Match)
432
+ _sym_db.RegisterMessage(TokensRegexResponse.PatternMatch)
433
+
434
+ DependencyEnhancerRequest = _reflection.GeneratedProtocolMessageType('DependencyEnhancerRequest', (_message.Message,), {
435
+ 'DESCRIPTOR' : _DEPENDENCYENHANCERREQUEST,
436
+ '__module__' : 'CoreNLP_pb2'
437
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.DependencyEnhancerRequest)
438
+ })
439
+ _sym_db.RegisterMessage(DependencyEnhancerRequest)
440
+
441
+ FlattenedParseTree = _reflection.GeneratedProtocolMessageType('FlattenedParseTree', (_message.Message,), {
442
+
443
+ 'Node' : _reflection.GeneratedProtocolMessageType('Node', (_message.Message,), {
444
+ 'DESCRIPTOR' : _FLATTENEDPARSETREE_NODE,
445
+ '__module__' : 'CoreNLP_pb2'
446
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.FlattenedParseTree.Node)
447
+ })
448
+ ,
449
+ 'DESCRIPTOR' : _FLATTENEDPARSETREE,
450
+ '__module__' : 'CoreNLP_pb2'
451
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.FlattenedParseTree)
452
+ })
453
+ _sym_db.RegisterMessage(FlattenedParseTree)
454
+ _sym_db.RegisterMessage(FlattenedParseTree.Node)
455
+
456
+ EvaluateParserRequest = _reflection.GeneratedProtocolMessageType('EvaluateParserRequest', (_message.Message,), {
457
+
458
+ 'ParseResult' : _reflection.GeneratedProtocolMessageType('ParseResult', (_message.Message,), {
459
+ 'DESCRIPTOR' : _EVALUATEPARSERREQUEST_PARSERESULT,
460
+ '__module__' : 'CoreNLP_pb2'
461
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.EvaluateParserRequest.ParseResult)
462
+ })
463
+ ,
464
+ 'DESCRIPTOR' : _EVALUATEPARSERREQUEST,
465
+ '__module__' : 'CoreNLP_pb2'
466
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.EvaluateParserRequest)
467
+ })
468
+ _sym_db.RegisterMessage(EvaluateParserRequest)
469
+ _sym_db.RegisterMessage(EvaluateParserRequest.ParseResult)
470
+
471
+ EvaluateParserResponse = _reflection.GeneratedProtocolMessageType('EvaluateParserResponse', (_message.Message,), {
472
+ 'DESCRIPTOR' : _EVALUATEPARSERRESPONSE,
473
+ '__module__' : 'CoreNLP_pb2'
474
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.EvaluateParserResponse)
475
+ })
476
+ _sym_db.RegisterMessage(EvaluateParserResponse)
477
+
478
+ TsurgeonRequest = _reflection.GeneratedProtocolMessageType('TsurgeonRequest', (_message.Message,), {
479
+
480
+ 'Operation' : _reflection.GeneratedProtocolMessageType('Operation', (_message.Message,), {
481
+ 'DESCRIPTOR' : _TSURGEONREQUEST_OPERATION,
482
+ '__module__' : 'CoreNLP_pb2'
483
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.TsurgeonRequest.Operation)
484
+ })
485
+ ,
486
+ 'DESCRIPTOR' : _TSURGEONREQUEST,
487
+ '__module__' : 'CoreNLP_pb2'
488
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.TsurgeonRequest)
489
+ })
490
+ _sym_db.RegisterMessage(TsurgeonRequest)
491
+ _sym_db.RegisterMessage(TsurgeonRequest.Operation)
492
+
493
+ TsurgeonResponse = _reflection.GeneratedProtocolMessageType('TsurgeonResponse', (_message.Message,), {
494
+ 'DESCRIPTOR' : _TSURGEONRESPONSE,
495
+ '__module__' : 'CoreNLP_pb2'
496
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.TsurgeonResponse)
497
+ })
498
+ _sym_db.RegisterMessage(TsurgeonResponse)
499
+
500
+ MorphologyRequest = _reflection.GeneratedProtocolMessageType('MorphologyRequest', (_message.Message,), {
501
+
502
+ 'TaggedWord' : _reflection.GeneratedProtocolMessageType('TaggedWord', (_message.Message,), {
503
+ 'DESCRIPTOR' : _MORPHOLOGYREQUEST_TAGGEDWORD,
504
+ '__module__' : 'CoreNLP_pb2'
505
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.MorphologyRequest.TaggedWord)
506
+ })
507
+ ,
508
+ 'DESCRIPTOR' : _MORPHOLOGYREQUEST,
509
+ '__module__' : 'CoreNLP_pb2'
510
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.MorphologyRequest)
511
+ })
512
+ _sym_db.RegisterMessage(MorphologyRequest)
513
+ _sym_db.RegisterMessage(MorphologyRequest.TaggedWord)
514
+
515
+ MorphologyResponse = _reflection.GeneratedProtocolMessageType('MorphologyResponse', (_message.Message,), {
516
+
517
+ 'WordTagLemma' : _reflection.GeneratedProtocolMessageType('WordTagLemma', (_message.Message,), {
518
+ 'DESCRIPTOR' : _MORPHOLOGYRESPONSE_WORDTAGLEMMA,
519
+ '__module__' : 'CoreNLP_pb2'
520
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.MorphologyResponse.WordTagLemma)
521
+ })
522
+ ,
523
+ 'DESCRIPTOR' : _MORPHOLOGYRESPONSE,
524
+ '__module__' : 'CoreNLP_pb2'
525
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.MorphologyResponse)
526
+ })
527
+ _sym_db.RegisterMessage(MorphologyResponse)
528
+ _sym_db.RegisterMessage(MorphologyResponse.WordTagLemma)
529
+
530
+ DependencyConverterRequest = _reflection.GeneratedProtocolMessageType('DependencyConverterRequest', (_message.Message,), {
531
+ 'DESCRIPTOR' : _DEPENDENCYCONVERTERREQUEST,
532
+ '__module__' : 'CoreNLP_pb2'
533
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.DependencyConverterRequest)
534
+ })
535
+ _sym_db.RegisterMessage(DependencyConverterRequest)
536
+
537
+ DependencyConverterResponse = _reflection.GeneratedProtocolMessageType('DependencyConverterResponse', (_message.Message,), {
538
+
539
+ 'DependencyConversion' : _reflection.GeneratedProtocolMessageType('DependencyConversion', (_message.Message,), {
540
+ 'DESCRIPTOR' : _DEPENDENCYCONVERTERRESPONSE_DEPENDENCYCONVERSION,
541
+ '__module__' : 'CoreNLP_pb2'
542
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.DependencyConverterResponse.DependencyConversion)
543
+ })
544
+ ,
545
+ 'DESCRIPTOR' : _DEPENDENCYCONVERTERRESPONSE,
546
+ '__module__' : 'CoreNLP_pb2'
547
+ # @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.DependencyConverterResponse)
548
+ })
549
+ _sym_db.RegisterMessage(DependencyConverterResponse)
550
+ _sym_db.RegisterMessage(DependencyConverterResponse.DependencyConversion)
551
+
552
+ if _descriptor._USE_C_DESCRIPTORS == False:
553
+
554
+ DESCRIPTOR._options = None
555
+ DESCRIPTOR._serialized_options = b'\n\031edu.stanford.nlp.pipelineB\rCoreNLPProtos'
556
+ _DEPENDENCYGRAPH.fields_by_name['root']._options = None
557
+ _DEPENDENCYGRAPH.fields_by_name['root']._serialized_options = b'\020\001'
558
+ _DEPENDENCYGRAPH.fields_by_name['rootNode']._options = None
559
+ _DEPENDENCYGRAPH.fields_by_name['rootNode']._serialized_options = b'\020\001'
560
+ _LANGUAGE._serialized_start=13457
561
+ _LANGUAGE._serialized_end=13620
562
+ _SENTIMENT._serialized_start=13622
563
+ _SENTIMENT._serialized_end=13726
564
+ _NATURALLOGICRELATION._serialized_start=13729
565
+ _NATURALLOGICRELATION._serialized_end=13876
566
+ _DOCUMENT._serialized_start=45
567
+ _DOCUMENT._serialized_end=782
568
+ _SENTENCE._serialized_start=785
569
+ _SENTENCE._serialized_end=2820
570
+ _TOKEN._serialized_start=2823
571
+ _TOKEN._serialized_end=4477
572
+ _QUOTE._serialized_start=4480
573
+ _QUOTE._serialized_end=4964
574
+ _PARSETREE._serialized_start=4967
575
+ _PARSETREE._serialized_end=5166
576
+ _DEPENDENCYGRAPH._serialized_start=5169
577
+ _DEPENDENCYGRAPH._serialized_end=5708
578
+ _DEPENDENCYGRAPH_NODE._serialized_start=5403
579
+ _DEPENDENCYGRAPH_NODE._serialized_end=5491
580
+ _DEPENDENCYGRAPH_EDGE._serialized_start=5494
581
+ _DEPENDENCYGRAPH_EDGE._serialized_end=5708
582
+ _COREFCHAIN._serialized_start=5711
583
+ _COREFCHAIN._serialized_end=6037
584
+ _COREFCHAIN_COREFMENTION._serialized_start=5836
585
+ _COREFCHAIN_COREFMENTION._serialized_end=6037
586
+ _MENTION._serialized_start=6040
587
+ _MENTION._serialized_end=7175
588
+ _INDEXEDWORD._serialized_start=7177
589
+ _INDEXEDWORD._serialized_end=7265
590
+ _SPEAKERINFO._serialized_start=7267
591
+ _SPEAKERINFO._serialized_end=7319
592
+ _SPAN._serialized_start=7321
593
+ _SPAN._serialized_end=7355
594
+ _TIMEX._serialized_start=7357
595
+ _TIMEX._serialized_end=7476
596
+ _ENTITY._serialized_start=7479
597
+ _ENTITY._serialized_end=7698
598
+ _RELATION._serialized_start=7701
599
+ _RELATION._serialized_end=7884
600
+ _OPERATOR._serialized_start=7887
601
+ _OPERATOR._serialized_end=8065
602
+ _POLARITY._serialized_start=8068
603
+ _POLARITY._serialized_end=8621
604
+ _NERMENTION._serialized_start=8624
605
+ _NERMENTION._serialized_end=8973
606
+ _SENTENCEFRAGMENT._serialized_start=8975
607
+ _SENTENCEFRAGMENT._serialized_end=9064
608
+ _TOKENLOCATION._serialized_start=9066
609
+ _TOKENLOCATION._serialized_end=9124
610
+ _RELATIONTRIPLE._serialized_start=9127
611
+ _RELATIONTRIPLE._serialized_end=9537
612
+ _MAPSTRINGSTRING._serialized_start=9539
613
+ _MAPSTRINGSTRING._serialized_end=9584
614
+ _MAPINTSTRING._serialized_start=9586
615
+ _MAPINTSTRING._serialized_end=9628
616
+ _SECTION._serialized_start=9631
617
+ _SECTION._serialized_end=9883
618
+ _SEMGREXREQUEST._serialized_start=9886
619
+ _SEMGREXREQUEST._serialized_end=10114
620
+ _SEMGREXREQUEST_DEPENDENCIES._serialized_start=9992
621
+ _SEMGREXREQUEST_DEPENDENCIES._serialized_end=10114
622
+ _SEMGREXRESPONSE._serialized_start=10117
623
+ _SEMGREXRESPONSE._serialized_end=10880
624
+ _SEMGREXRESPONSE_NAMEDNODE._serialized_start=10208
625
+ _SEMGREXRESPONSE_NAMEDNODE._serialized_end=10253
626
+ _SEMGREXRESPONSE_NAMEDRELATION._serialized_start=10255
627
+ _SEMGREXRESPONSE_NAMEDRELATION._serialized_end=10298
628
+ _SEMGREXRESPONSE_NAMEDEDGE._serialized_start=10301
629
+ _SEMGREXRESPONSE_NAMEDEDGE._serialized_end=10429
630
+ _SEMGREXRESPONSE_MATCH._serialized_start=10432
631
+ _SEMGREXRESPONSE_MATCH._serialized_end=10709
632
+ _SEMGREXRESPONSE_SEMGREXRESULT._serialized_start=10711
633
+ _SEMGREXRESPONSE_SEMGREXRESULT._serialized_end=10791
634
+ _SEMGREXRESPONSE_GRAPHRESULT._serialized_start=10793
635
+ _SEMGREXRESPONSE_GRAPHRESULT._serialized_end=10880
636
+ _SSURGEONREQUEST._serialized_start=10883
637
+ _SSURGEONREQUEST._serialized_end=11123
638
+ _SSURGEONREQUEST_SSURGEON._serialized_start=11032
639
+ _SSURGEONREQUEST_SSURGEON._serialized_end=11123
640
+ _SSURGEONRESPONSE._serialized_start=11126
641
+ _SSURGEONRESPONSE._serialized_end=11314
642
+ _SSURGEONRESPONSE_SSURGEONRESULT._serialized_start=11222
643
+ _SSURGEONRESPONSE_SSURGEONRESULT._serialized_end=11314
644
+ _TOKENSREGEXREQUEST._serialized_start=11316
645
+ _TOKENSREGEXREQUEST._serialized_end=11403
646
+ _TOKENSREGEXRESPONSE._serialized_start=11406
647
+ _TOKENSREGEXRESPONSE._serialized_end=11829
648
+ _TOKENSREGEXRESPONSE_MATCHLOCATION._serialized_start=11505
649
+ _TOKENSREGEXRESPONSE_MATCHLOCATION._serialized_end=11562
650
+ _TOKENSREGEXRESPONSE_MATCH._serialized_start=11565
651
+ _TOKENSREGEXRESPONSE_MATCH._serialized_end=11744
652
+ _TOKENSREGEXRESPONSE_PATTERNMATCH._serialized_start=11746
653
+ _TOKENSREGEXRESPONSE_PATTERNMATCH._serialized_end=11829
654
+ _DEPENDENCYENHANCERREQUEST._serialized_start=11832
655
+ _DEPENDENCYENHANCERREQUEST._serialized_end=12006
656
+ _FLATTENEDPARSETREE._serialized_start=12009
657
+ _FLATTENEDPARSETREE._serialized_end=12189
658
+ _FLATTENEDPARSETREE_NODE._serialized_start=12098
659
+ _FLATTENEDPARSETREE_NODE._serialized_end=12189
660
+ _EVALUATEPARSERREQUEST._serialized_start=12192
661
+ _EVALUATEPARSERREQUEST._serialized_end=12438
662
+ _EVALUATEPARSERREQUEST_PARSERESULT._serialized_start=12298
663
+ _EVALUATEPARSERREQUEST_PARSERESULT._serialized_end=12438
664
+ _EVALUATEPARSERRESPONSE._serialized_start=12440
665
+ _EVALUATEPARSERRESPONSE._serialized_end=12509
666
+ _TSURGEONREQUEST._serialized_start=12512
667
+ _TSURGEONREQUEST._serialized_end=12712
668
+ _TSURGEONREQUEST_OPERATION._serialized_start=12667
669
+ _TSURGEONREQUEST_OPERATION._serialized_end=12712
670
+ _TSURGEONRESPONSE._serialized_start=12714
671
+ _TSURGEONRESPONSE._serialized_end=12794
672
+ _MORPHOLOGYREQUEST._serialized_start=12797
673
+ _MORPHOLOGYREQUEST._serialized_end=12930
674
+ _MORPHOLOGYREQUEST_TAGGEDWORD._serialized_start=12890
675
+ _MORPHOLOGYREQUEST_TAGGEDWORD._serialized_end=12930
676
+ _MORPHOLOGYRESPONSE._serialized_start=12933
677
+ _MORPHOLOGYRESPONSE._serialized_end=13087
678
+ _MORPHOLOGYRESPONSE_WORDTAGLEMMA._serialized_start=13030
679
+ _MORPHOLOGYRESPONSE_WORDTAGLEMMA._serialized_end=13087
680
+ _DEPENDENCYCONVERTERREQUEST._serialized_start=13089
681
+ _DEPENDENCYCONVERTERREQUEST._serialized_end=13179
682
+ _DEPENDENCYCONVERTERRESPONSE._serialized_start=13182
683
+ _DEPENDENCYCONVERTERRESPONSE._serialized_end=13454
684
+ _DEPENDENCYCONVERTERRESPONSE_DEPENDENCYCONVERSION._serialized_start=13312
685
+ _DEPENDENCYCONVERTERRESPONSE_DEPENDENCYCONVERSION._serialized_end=13454
686
+ # @@protoc_insertion_point(module_scope)
stanza/stanza/protobuf/__init__.py ADDED
@@ -0,0 +1,52 @@
1
+ from __future__ import absolute_import
2
+
3
+ from io import BytesIO
4
+ import warnings
5
+
6
+ from google.protobuf.internal.encoder import _EncodeVarint
7
+ from google.protobuf.internal.decoder import _DecodeVarint
8
+ from google.protobuf.message import DecodeError
9
+ from .CoreNLP_pb2 import *
10
+
11
+ def parseFromDelimitedString(obj, buf, offset=0):
12
+ """
13
+ Stanford CoreNLP uses the Java "writeDelimitedTo" function, which
14
+ writes the size (and offset) of the buffer before writing the object.
15
+ This function handles parsing one such message from @buf, starting at @offset.
16
+
17
+ @returns how many bytes of @buf were consumed.
18
+ """
19
+ size, pos = _DecodeVarint(buf, offset)
20
+ try:
21
+ obj.ParseFromString(buf[offset+pos:offset+pos+size])
22
+ except DecodeError as e:
23
+ warnings.warn("Failed to decode a serialized output from CoreNLP server. An incomplete or empty object will be returned.", \
24
+ RuntimeWarning)
25
+ return pos+size
26
+
27
+ def writeToDelimitedString(obj, stream=None):
28
+ """
29
+ Stanford CoreNLP uses the Java "writeDelimitedTo" function, which
30
+ writes the size (and offset) of the buffer before writing the object.
31
+ This function mirrors that behavior when serializing an object.
32
+
33
+ @returns the stream that the object was written to.
34
+ """
35
+ if stream is None:
36
+ stream = BytesIO()
37
+
38
+ _EncodeVarint(stream.write, obj.ByteSize(), True)
39
+ stream.write(obj.SerializeToString())
40
+ return stream
41
+
42
+ def to_text(sentence):
43
+ """
44
+ Helper routine that converts a Sentence protobuf to a string from
45
+ its tokens.
46
+ """
47
+ text = ""
48
+ for i, tok in enumerate(sentence.token):
49
+ if i != 0:
50
+ text += tok.before
51
+ text += tok.word
52
+ return text
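As a quick sanity check of the two delimited-stream helpers above, the sketch below round-trips a Document proto through a byte buffer. It assumes only what the code shows: writeToDelimitedString returns a BytesIO whose contents begin with the varint-encoded size, and parseFromDelimitedString reports how many bytes it consumed; Document.text is taken to be a plain string field, as in the CoreNLP proto.

from stanza.protobuf import Document, parseFromDelimitedString, writeToDelimitedString

doc = Document()
doc.text = "Stanza talks to CoreNLP."  # assumed to be a string field on Document

# the returned BytesIO starts with the varint-encoded message size
buf = writeToDelimitedString(doc).getvalue()

# parse into a fresh message; the return value is the number of bytes consumed
round_trip = Document()
consumed = parseFromDelimitedString(round_trip, buf)
assert consumed == len(buf)
assert round_trip.text == doc.text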
stanza/stanza/resources/common.py ADDED
@@ -0,0 +1,619 @@
1
+ """
2
+ Common utilities for Stanza resources.
3
+ """
4
+
5
+ from collections import defaultdict, namedtuple
6
+ import errno
7
+ import hashlib
8
+ import json
9
+ import logging
10
+ import os
11
+ from pathlib import Path
12
+ import requests
13
+ import shutil
14
+ import tempfile
15
+ import zipfile
16
+
17
+ from tqdm.auto import tqdm
18
+
19
+ from stanza.utils.helper_func import make_table
20
+ from stanza.pipeline._constants import TOKENIZE, MWT, POS, LEMMA, DEPPARSE, NER, SENTIMENT
21
+ from stanza.pipeline.registry import PIPELINE_NAMES, PROCESSOR_VARIANTS
22
+ from stanza.resources.default_packages import PACKAGES
23
+ from stanza._version import __resources_version__
24
+
25
+ logger = logging.getLogger('stanza')
26
+
27
+ # set home dir for default
28
+ HOME_DIR = str(Path.home())
29
+ STANFORDNLP_RESOURCES_URL = 'https://nlp.stanford.edu/software/stanza/stanza-resources/'
30
+ STANZA_RESOURCES_GITHUB = 'https://raw.githubusercontent.com/stanfordnlp/stanza-resources/'
31
+ DEFAULT_RESOURCES_URL = os.getenv('STANZA_RESOURCES_URL', STANZA_RESOURCES_GITHUB + 'main')
32
+ DEFAULT_RESOURCES_VERSION = os.getenv(
33
+ 'STANZA_RESOURCES_VERSION',
34
+ __resources_version__
35
+ )
36
+ DEFAULT_MODEL_URL = os.getenv('STANZA_MODEL_URL', 'default')
37
+ DEFAULT_MODEL_DIR = os.getenv(
38
+ 'STANZA_RESOURCES_DIR',
39
+ os.path.join(HOME_DIR, 'stanza_resources')
40
+ )
41
+
42
+ PRETRAIN_NAMES = ("pretrain", "forward_charlm", "backward_charlm")
43
+
44
+ class ResourcesFileNotFoundError(FileNotFoundError):
45
+ def __init__(self, resources_filepath):
46
+ super().__init__(f"Resources file not found at: {resources_filepath} Try to download the model again.")
47
+ self.resources_filepath = resources_filepath
48
+
49
+ class UnknownLanguageError(ValueError):
50
+ def __init__(self, unknown):
51
+ super().__init__(f"Unknown language requested: {unknown}")
52
+ self.unknown_language = unknown
53
+
54
+ class UnknownProcessorError(ValueError):
55
+ def __init__(self, unknown):
56
+ super().__init__(f"Unknown processor type requested: {unknown}")
57
+ self.unknown_processor = unknown
58
+
59
+ ModelSpecification = namedtuple('ModelSpecification', ['processor', 'package', 'dependencies'])
60
+
61
+ def ensure_dir(path):
62
+ """
63
+ Create dir in case it does not exist.
64
+ """
65
+ Path(path).mkdir(parents=True, exist_ok=True)
66
+
67
+ def get_md5(path):
68
+ """
69
+ Get the MD5 value of a path.
70
+ """
71
+ try:
72
+ with open(path, 'rb') as fin:
73
+ data = fin.read()
74
+ except OSError as e:
75
+ if not e.filename:
76
+ e.filename = path
77
+ raise
78
+ return hashlib.md5(data).hexdigest()
79
+
80
+ def unzip(path, filename):
81
+ """
82
+ Fully unzip a file `filename` that's in the directory `path`.
83
+ """
84
+ logger.debug(f'Unzip: {path}/{filename}...')
85
+ with zipfile.ZipFile(os.path.join(path, filename)) as f:
86
+ f.extractall(path)
87
+
88
+ def get_root_from_zipfile(filename):
89
+ """
90
+ Get the root directory from an archived zip file.
91
+ """
92
+ zf = zipfile.ZipFile(filename, "r")
93
+ assert len(zf.filelist) > 0, \
94
+ f"Zip file at f{filename} seems to be corrupted. Please check it."
95
+ return os.path.dirname(zf.filelist[0].filename)
96
+
97
+ def file_exists(path, md5):
98
+ """
99
+ Check if the file at `path` exists and matches the provided md5 value.
100
+ """
101
+ return os.path.exists(path) and get_md5(path) == md5
102
+
103
+ def assert_file_exists(path, md5=None, alternate_md5=None):
104
+ if not os.path.exists(path):
105
+ raise FileNotFoundError(errno.ENOENT, "Cannot find expected file", path)
106
+ if md5:
107
+ file_md5 = get_md5(path)
108
+ if file_md5 != md5:
109
+ if file_md5 == alternate_md5:
110
+ logger.debug("Found a possibly older version of file %s, md5 %s instead of %s", path, alternate_md5, md5)
111
+ else:
112
+ raise ValueError("md5 for %s is %s, expected %s" % (path, file_md5, md5))
113
+
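A minimal sketch of how the two checksum helpers above fit together; the path is a hypothetical local file, not anything shipped with Stanza.

from stanza.resources.common import assert_file_exists, get_md5

path = "/tmp/some_model.pt"            # hypothetical file
digest = get_md5(path)                 # hex md5 of the file contents
assert_file_exists(path, md5=digest)   # passes silently; a mismatch raises ValueError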
114
+ def download_file(url, path, proxies, raise_for_status=False):
115
+ """
116
+ Download a URL into a file as specified by `path`.
117
+ """
118
+ verbose = logger.level in [0, 10, 20]
119
+ r = requests.get(url, stream=True, proxies=proxies)
120
+ if raise_for_status:
121
+ r.raise_for_status()
122
+ with open(path, 'wb') as f:
123
+ file_size = int(r.headers.get('content-length', 0))
124
+ default_chunk_size = 131072
125
+ desc = 'Downloading ' + url
126
+ with tqdm(total=file_size, unit='B', unit_scale=True, \
127
+ disable=not verbose, desc=desc) as pbar:
128
+ for chunk in r.iter_content(chunk_size=default_chunk_size):
129
+ if chunk:
130
+ f.write(chunk)
131
+ f.flush()
132
+ pbar.update(len(chunk))
133
+ return r.status_code
134
+
135
+ def request_file(url, path, proxies=None, md5=None, raise_for_status=False, log_info=True, alternate_md5=None):
136
+ """
137
+ A complete wrapper over download_file() that also makes sure the directory of
138
+ `path` exists, and skips the download if a file matching the md5 value already exists.
139
+
140
+ alternate_md5 allows for an alternate md5 that is acceptable (such as if an older version of a file is okay)
141
+ """
142
+ basedir = Path(path).parent
143
+ ensure_dir(basedir)
144
+ if file_exists(path, md5):
145
+ if log_info:
146
+ logger.info(f'File exists: {path}')
147
+ else:
148
+ logger.debug(f'File exists: {path}')
149
+ return
150
+ # We write data first to a temporary directory,
151
+ # then use os.replace() so that multiple processes
152
+ # running at the same time don't clobber each other
153
+ # with partially downloaded files
154
+ # This was especially common with resources.json
155
+ with tempfile.TemporaryDirectory(dir=basedir) as temp:
156
+ temppath = os.path.join(temp, os.path.split(path)[-1])
157
+ download_file(url, temppath, proxies, raise_for_status)
158
+ os.replace(temppath, path)
159
+ assert_file_exists(path, md5, alternate_md5)
160
+ if log_info:
161
+ logger.info(f'Downloaded file to {path}')
162
+ else:
163
+ logger.debug(f'Downloaded file to {path}')
164
+
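A hedged usage sketch for request_file(); the URL and destination below are placeholders. The point of the temporary-directory dance above is that the file is os.replace()d into its final path only once fully downloaded, so concurrent processes never observe a partial file.

from stanza.resources.common import request_file

request_file(
    "https://example.com/models/model.pt",  # placeholder URL
    "/tmp/stanza_demo/model.pt",            # placeholder destination
    proxies=None,
    md5=None,                               # None skips checksum verification
)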
165
+ def sort_processors(processor_list):
166
+ sorted_list = []
167
+ for processor in PIPELINE_NAMES:
168
+ for item in processor_list:
169
+ if item[0] == processor:
170
+ sorted_list.append(item)
171
+ # going just by processors in PIPELINE_NAMES, this drops any names
172
+ # which are not an official processor but might still be useful
173
+ # check the list and append them to the end
174
+ # this is especially useful when downloading pretrain or charlm models
175
+ for processor in processor_list:
176
+ for item in sorted_list:
177
+ if processor[0] == item[0]:
178
+ break
179
+ else:
180
+ sorted_list.append(item)
181
+ return sorted_list
182
+
183
+ def add_mwt(processors, resources, lang):
184
+ """Add mwt if tokenize is passed without mwt.
185
+
186
+ If tokenize is in the list, but mwt is not, and there is a corresponding
187
+ tokenize and mwt pair in the resources file, mwt is added so no missing
188
+ mwt errors are raised.
189
+
190
+ TODO: how does this handle EWT in English?
191
+ """
192
+ value = processors[TOKENIZE]
193
+ if value in resources[lang][PACKAGES] and MWT in resources[lang][PACKAGES][value]:
194
+ logger.warning("Language %s package %s expects mwt, which has been added", lang, value)
195
+ processors[MWT] = value
196
+ elif (value in resources[lang][TOKENIZE] and MWT in resources[lang] and value in resources[lang][MWT]):
197
+ logger.warning("Language %s package %s expects mwt, which has been added", lang, value)
198
+ processors[MWT] = value
199
+
200
+ def maintain_processor_list(resources, lang, package, processors, allow_pretrain=False, maybe_add_mwt=True):
201
+ """
202
+ Given a parsed resources file, a language, and an optional package
203
+ and/or processors, expands the package into the list of processors
204
+
205
+ Returns a list of processors
206
+ Each item in the list of processors is a pair:
207
+ name, then a list of ModelSpecification
208
+ so, for example:
209
+ [['pos', [ModelSpecification(processor='pos', package='gsd', dependencies=None)]],
210
+ ['depparse', [ModelSpecification(processor='depparse', package='gsd', dependencies=None)]]]
211
+ """
212
+ processor_list = defaultdict(list)
213
+ # resolve processor models
214
+ if processors:
215
+ logger.debug(f'Processing parameter "processors"...')
216
+ if maybe_add_mwt and TOKENIZE in processors and MWT not in processors:
217
+ add_mwt(processors, resources, lang)
218
+ for key, plist in processors.items():
219
+ if not isinstance(key, str):
220
+ raise ValueError("Processor names must be strings")
221
+ if not isinstance(plist, (tuple, list, str)):
222
+ raise ValueError("Processor values must be strings")
223
+ if isinstance(plist, str):
224
+ plist = [plist]
225
+ if key not in PIPELINE_NAMES:
226
+ if not allow_pretrain or key not in PRETRAIN_NAMES:
227
+ raise UnknownProcessorError(key)
228
+ for value in plist:
229
+ # check if keys and values can be found
230
+ if key in resources[lang] and value in resources[lang][key]:
231
+ logger.debug(f'Found {key}: {value}.')
232
+ processor_list[key].append(value)
233
+ # allow values to be default in some cases
234
+ elif value in resources[lang][PACKAGES] and key in resources[lang][PACKAGES][value]:
235
+ logger.debug(
236
+ f'Found {key}: {resources[lang][PACKAGES][value][key]}.'
237
+ )
238
+ processor_list[key].append(resources[lang][PACKAGES][value][key])
239
+ # optional defaults will be activated if specifically turned on
240
+ elif value in resources[lang][PACKAGES] and 'optional' in resources[lang][PACKAGES][value] and key in resources[lang][PACKAGES][value]['optional']:
241
+ logger.debug(
242
+ f"Found {key}: {resources[lang][PACKAGES][value]['optional'][key]}."
243
+ )
244
+ processor_list[key].append(resources[lang][PACKAGES][value]['optional'][key])
245
+ # allow processors to be set to variants that we didn't implement
246
+ elif value in PROCESSOR_VARIANTS[key]:
247
+ logger.debug(
248
+ f'Found {key}: {value}. '
249
+ f'Using external {value} variant for the {key} processor.'
250
+ )
251
+ processor_list[key].append(value)
252
+ # allow lemma to be set to "identity"
253
+ elif key == LEMMA and value == 'identity':
254
+ logger.debug(
255
+ f'Found {key}: {value}. Using identity lemmatizer.'
256
+ )
257
+ processor_list[key].append(value)
258
+ # not a processor in the officially supported processor list
259
+ elif key not in resources[lang]:
260
+ logger.debug(
261
+ f'{key}: {value} is not officially supported by Stanza, '
262
+ f'loading it anyway.'
263
+ )
264
+ processor_list[key].append(value)
265
+ # cannot find the package for a processor and warn user
266
+ else:
267
+ logger.warning(
268
+ f'Cannot find {key}: {value} in the official model list. '
269
+ f'Ignoring it.'
270
+ )
271
+ # resolve package
272
+ if package:
273
+ logger.debug(f'Processing parameter "package"...')
274
+ if PACKAGES in resources[lang] and package in resources[lang][PACKAGES]:
275
+ for key, value in resources[lang][PACKAGES][package].items():
276
+ if key != 'optional' and key not in processor_list:
277
+ logger.debug(f'Found {key}: {value}.')
278
+ processor_list[key].append(value)
279
+ else:
280
+ flag = False
281
+ for key in PIPELINE_NAMES:
282
+ if key not in resources[lang]: continue
283
+ if package in resources[lang][key]:
284
+ flag = True
285
+ if key not in processor_list:
286
+ logger.debug(f'Found {key}: {package}.')
287
+ processor_list[key].append(package)
288
+ else:
289
+ logger.debug(
290
+ f'{key}: {package} is overwritten by '
291
+ f'{key}: {processors[key]}.'
292
+ )
293
+ if not flag: logger.warning(f'Cannot find package: {package}.')
294
+ processor_list = [[key, [ModelSpecification(processor=key, package=value, dependencies=None) for value in plist]] for key, plist in processor_list.items()]
295
+ processor_list = sort_processors(processor_list)
296
+ return processor_list
297
+
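To make the return shape concrete, here is a minimal sketch against a toy stand-in for a parsed resources file; the dict layout only mirrors how this function indexes resources and is not a real resources.json.

from stanza.resources.common import maintain_processor_list

toy_resources = {"de": {"pos": {"gsd": {}}, "depparse": {"gsd": {}}}}  # hypothetical
plist = maintain_processor_list(
    toy_resources, "de", package=None,
    processors={"pos": "gsd", "depparse": "gsd"},
    maybe_add_mwt=False)
# [['pos', [ModelSpecification(processor='pos', package='gsd', dependencies=None)]],
#  ['depparse', [ModelSpecification(processor='depparse', package='gsd', dependencies=None)]]]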
298
+ def add_dependencies(resources, lang, processor_list):
299
+ """
300
+ Expand the processor_list as given in maintain_processor_list to have the dependencies
301
+
302
+ Still a list of model types to ModelSpecifications
303
+ the dependencies are tuples: name and package
304
+ for example:
305
+ [['pos', (ModelSpecification(processor='pos', package='gsd', dependencies=(('pretrain', 'gsd'),)),)],
306
+ ['depparse', (ModelSpecification(processor='depparse', package='gsd', dependencies=(('pretrain', 'gsd'),)),)]]
307
+ """
308
+ lang_resources = resources[lang]
309
+ for item in processor_list:
310
+ processor, model_specs = item
311
+ new_model_specs = []
312
+ for model_spec in model_specs:
313
+ # skip dependency checking for external variants of processors and identity lemmatizer
314
+ if not any([
315
+ model_spec.package in PROCESSOR_VARIANTS[processor],
316
+ processor == LEMMA and model_spec.package == 'identity'
317
+ ]):
318
+ dependencies = lang_resources.get(processor, {}).get(model_spec.package, {}).get('dependencies', [])
319
+ dependencies = [(dependency['model'], dependency['package']) for dependency in dependencies]
320
+ model_spec = model_spec._replace(dependencies=tuple(dependencies))
321
+ logger.debug("Found dependencies %s for processor %s model %s", dependencies, processor, model_spec.package)
322
+ new_model_specs.append(model_spec)
323
+ item[1] = tuple(new_model_specs)
324
+ return processor_list
325
+
326
+ def flatten_processor_list(processor_list):
327
+ """
328
+ The flattened processor list is just a list of types & packages
329
+
330
+ For example:
331
+ [['pos', 'gsd'], ['depparse', 'gsd'], ['pretrain', 'gsd']]
332
+ """
333
+ flattened_processor_list = []
334
+ dependencies_list = []
335
+ for item in processor_list:
336
+ processor, model_specs = item
337
+ for model_spec in model_specs:
338
+ package = model_spec.package
339
+ dependencies = model_spec.dependencies
340
+ flattened_processor_list.append([processor, package])
341
+ if dependencies:
342
+ dependencies_list += [tuple(dependency) for dependency in dependencies]
343
+ dependencies_list = [list(item) for item in set(dependencies_list)]
344
+ for processor, package in dependencies_list:
345
+ logger.debug(f'Found dependency {processor}: {package}.')
346
+ flattened_processor_list += dependencies_list
347
+ return flattened_processor_list
348
+
349
+ def set_logging_level(logging_level, verbose):
350
+ # Check verbose for easy logging control
351
+ if verbose == False:
352
+ logging_level = 'ERROR'
353
+ elif verbose == True:
354
+ logging_level = 'INFO'
355
+
356
+ if logging_level is None:
357
+ # default logging level of INFO is set in stanza.__init__
358
+ # but the user may have set it via the logging API
359
+ # it should NOT be 0, but let's check to be sure...
360
+ if logger.level == 0:
361
+ logger.setLevel('INFO')
362
+ return logger.level
363
+
364
+ # Set logging level
365
+ logging_level = logging_level.upper()
366
+ all_levels = ['DEBUG', 'INFO', 'WARNING', 'WARN', 'ERROR', 'CRITICAL', 'FATAL']
367
+ if logging_level not in all_levels:
368
+ raise ValueError(
369
+ f"Unrecognized logging level for pipeline: "
370
+ f"{logging_level}. Must be one of {', '.join(all_levels)}."
371
+ )
372
+ logger.setLevel(logging_level)
373
+ return logger.level
374
+
375
+ def process_pipeline_parameters(lang, model_dir, package, processors):
376
+ # Check parameter types and convert values to lower case
377
+ if isinstance(lang, str):
378
+ lang = lang.strip().lower()
379
+ elif lang is not None:
380
+ raise TypeError(
381
+ f"The parameter 'lang' should be str, "
382
+ f"but got {type(lang).__name__} instead."
383
+ )
384
+
385
+ if isinstance(model_dir, str):
386
+ model_dir = model_dir.strip()
387
+ elif model_dir is not None:
388
+ raise TypeError(
389
+ f"The parameter 'model_dir' should be str, "
390
+ f"but got {type(model_dir).__name__} instead."
391
+ )
392
+
393
+ if isinstance(processors, (str, list, tuple)):
394
+ # Special case: processors is str, compatible with older version
395
+ # also allow for setting alternate packages for these processors
396
+ # via the package argument
397
+ if package is None:
398
+ # each processor will be 'default' for this language
399
+ package = defaultdict(lambda: 'default')
400
+ elif isinstance(package, str):
401
+ # same, but now the named package will be the default instead
402
+ default = package
403
+ package = defaultdict(lambda: default)
404
+ elif isinstance(package, dict):
405
+ # the dictionary of packages will be used to build the processors dict
406
+ # any processor not specified in package will be 'default'
407
+ package = defaultdict(lambda: 'default', package)
408
+ else:
409
+ raise TypeError(
410
+ f"The parameter 'package' should be None, str, or dict, "
411
+ f"but got {type(package).__name__} instead."
412
+ )
413
+ if isinstance(processors, str):
414
+ processors = [x.strip().lower() for x in processors.split(",")]
415
+ processors = {
416
+ processor: package[processor] for processor in processors
417
+ }
418
+ package = None
419
+ elif isinstance(processors, dict):
420
+ processors = {
421
+ k.strip().lower(): ([v_i.strip().lower() for v_i in v] if isinstance(v, (tuple, list)) else v.strip().lower())
422
+ for k, v in processors.items()
423
+ }
424
+ elif processors is not None:
425
+ raise TypeError(
426
+ f"The parameter 'processors' should be dict or str, "
427
+ f"but got {type(processors).__name__} instead."
428
+ )
429
+
430
+ if isinstance(package, str):
431
+ package = package.strip().lower()
432
+ elif package is not None:
433
+ raise TypeError(
434
+ f"The parameter 'package' should be str, or a dict if 'processors' is a str, "
435
+ f"but got {type(package).__name__} instead."
436
+ )
437
+
438
+ return lang, model_dir, package, processors
439
+
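A small illustration of the normalization performed above; the values are arbitrary.

from stanza.resources.common import process_pipeline_parameters

lang, model_dir, package, processors = process_pipeline_parameters(
    " EN ", None, "ewt", "tokenize,pos")
assert lang == "en"
assert processors == {"tokenize": "ewt", "pos": "ewt"}
assert package is None  # the string package was folded into the processors dict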
440
+ def download_resources_json(model_dir=DEFAULT_MODEL_DIR,
441
+ resources_url=DEFAULT_RESOURCES_URL,
442
+ resources_branch=None,
443
+ resources_version=DEFAULT_RESOURCES_VERSION,
444
+ resources_filepath=None,
445
+ proxies=None):
446
+ """
447
+ Downloads resources.json to obtain the latest list of available packages.
448
+ """
449
+ if resources_url == DEFAULT_RESOURCES_URL and resources_branch is not None:
450
+ resources_url = STANZA_RESOURCES_GITHUB + resources_branch
451
+ # handle short name for resources urls; otherwise treat it as url
452
+ if resources_url.lower() in ('stanford', 'stanfordnlp'):
453
+ resources_url = STANFORDNLP_RESOURCES_URL
454
+ resources_url = f'{resources_url}/resources_{resources_version}.json'
455
+ logger.debug('Downloading resource file from %s', resources_url)
456
+ if resources_filepath is None:
457
+ resources_filepath = os.path.join(model_dir, 'resources.json')
458
+ # make request
459
+ request_file(
460
+ resources_url,
461
+ resources_filepath,
462
+ proxies,
463
+ raise_for_status=True
464
+ )
465
+
466
+
467
+ def load_resources_json(model_dir=DEFAULT_MODEL_DIR, resources_filepath=None):
468
+ """
469
+ Unpack the resources json file from the given model_dir
470
+ """
471
+ if resources_filepath is None:
472
+ resources_filepath = os.path.join(model_dir, 'resources.json')
473
+ if not os.path.exists(resources_filepath):
474
+ raise ResourcesFileNotFoundError(resources_filepath)
475
+ with open(resources_filepath, encoding="utf-8") as fin:
476
+ resources = json.load(fin)
477
+ return resources
478
+
479
+ def get_language_resources(resources, lang):
480
+ """
481
+ Get the resources for a lang from an already loaded resources json, following 'alias' if needed
482
+ """
483
+ if lang not in resources:
484
+ return None
485
+
486
+ lang_resources = resources[lang]
487
+ while 'alias' in lang_resources:
488
+ lang = lang_resources['alias']
489
+ lang_resources = resources[lang]
490
+
491
+ return lang_resources
492
+
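A toy illustration of the alias chase; the dict below is a stand-in for a real resources.json.

from stanza.resources.common import get_language_resources

resources = {"german": {"alias": "de"}, "de": {"lang_name": "German"}}  # hypothetical
assert get_language_resources(resources, "german") == {"lang_name": "German"}
assert get_language_resources(resources, "xx") is None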
493
+ def list_available_languages(model_dir=DEFAULT_MODEL_DIR,
494
+ resources_url=DEFAULT_RESOURCES_URL,
495
+ resources_branch=None,
496
+ resources_version=DEFAULT_RESOURCES_VERSION,
497
+ proxies=None):
498
+ """
499
+ List the non-alias languages in the resources file
500
+ """
501
+ download_resources_json(model_dir, resources_url, resources_branch, resources_version, resources_filepath=None, proxies=proxies)
502
+ resources = load_resources_json(model_dir)
503
+ # the isinstance(str) check skips top-level string fields such as "url"
504
+ # the 'alias' check skips alias entries, e.g. "german" is an alias of "de"
505
+ languages = [lang for lang in resources
506
+ if not isinstance(resources[lang], str) and 'alias' not in resources[lang]]
507
+ languages = sorted(languages)
508
+ return languages
509
+
+def expand_model_url(resources, model_url):
+    """
+    Returns the url in the resources dict if model_url is default, or returns the model_url
+    """
+    return resources['url'] if model_url.lower() == 'default' else model_url
+
+def download_models(download_list,
+                    resources,
+                    lang,
+                    model_dir=DEFAULT_MODEL_DIR,
+                    resources_version=DEFAULT_RESOURCES_VERSION,
+                    model_url=DEFAULT_MODEL_URL,
+                    proxies=None,
+                    log_info=True):
+    lang_name = resources.get(lang, {}).get('lang_name', lang)
+    download_table = make_table(['Processor', 'Package'], download_list)
+    if log_info:
+        log_msg = logger.info
+    else:
+        log_msg = logger.debug
+    log_msg(
+        f'Downloading these customized packages for language: '
+        f'{lang} ({lang_name})...\n{download_table}'
+    )
+
+    url = expand_model_url(resources, model_url)
+
+    # Download packages
+    for key, value in download_list:
+        try:
+            request_file(
+                url.format(resources_version=resources_version, lang=lang, filename=f"{key}/{value}.pt"),
+                os.path.join(model_dir, lang, key, f'{value}.pt'),
+                proxies,
+                md5=resources[lang][key][value]['md5'],
+                log_info=log_info,
+                alternate_md5=resources[lang][key][value].get('alternate_md5', None)
+            )
+        except KeyError as e:
+            raise ValueError(
+                f'Cannot find the following processor and model name combination: '
+                f'{key}, {value}. Please check if you have provided the correct model name.'
+            ) from e
+
+# main download function
+def download(
+        lang='en',
+        model_dir=DEFAULT_MODEL_DIR,
+        package='default',
+        processors={},
+        logging_level=None,
+        verbose=None,
+        resources_url=DEFAULT_RESOURCES_URL,
+        resources_branch=None,
+        resources_version=DEFAULT_RESOURCES_VERSION,
+        model_url=DEFAULT_MODEL_URL,
+        proxies=None,
+        download_json=True
+    ):
+    # set global logging level
+    set_logging_level(logging_level, verbose)
+    # process different pipeline parameters
+    lang, model_dir, package, processors = process_pipeline_parameters(
+        lang, model_dir, package, processors
+    )
+
+    if download_json or not os.path.exists(os.path.join(model_dir, 'resources.json')):
+        if not download_json:
+            logger.warning("Asked to skip downloading resources.json, but the file does not exist. Downloading anyway")
+        download_resources_json(model_dir, resources_url, resources_branch, resources_version, resources_filepath=None, proxies=proxies)
+
+    resources = load_resources_json(model_dir)
+    if lang not in resources:
+        raise UnknownLanguageError(lang)
+    if 'alias' in resources[lang]:
+        logger.info(f'"{lang}" is an alias for "{resources[lang]["alias"]}"')
+        lang = resources[lang]['alias']
+    lang_name = resources.get(lang, {}).get('lang_name', lang)
+    url = expand_model_url(resources, model_url)
+
+    # Default: download zipfile and unzip
+    if package == 'default' and (processors is None or len(processors) == 0):
+        logger.info(
+            f'Downloading default packages for language: {lang} ({lang_name}) ...'
+        )
+        # want the URL to become, for example:
+        # https://huggingface.co/stanfordnlp/stanza-af/resolve/v1.3.0/models/default.zip
+        # so we hopefully start from
+        # https://huggingface.co/stanfordnlp/stanza-{lang}/resolve/v{resources_version}/models/{filename}
+        request_file(
+            url.format(resources_version=resources_version, lang=lang, filename="default.zip"),
+            os.path.join(model_dir, lang, 'default.zip'),
+            proxies,
+            md5=resources[lang]['default_md5'],
+        )
+        unzip(os.path.join(model_dir, lang), 'default.zip')
+    # Customize: maintain download list
+    else:
+        download_list = maintain_processor_list(resources, lang, package, processors, allow_pretrain=True)
+        download_list = add_dependencies(resources, lang, download_list)
+        download_list = flatten_processor_list(download_list)
+        download_models(download_list=download_list,
+                        resources=resources,
+                        lang=lang,
+                        model_dir=model_dir,
+                        resources_version=resources_version,
+                        model_url=model_url,
+                        proxies=proxies,
+                        log_info=True)
+    logger.info(f'Finished downloading models and saved to {model_dir}')
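A usage sketch for the two branches above (hypothetical, not part of this commit):

    import stanza

    # default package: fetches <lang>/default.zip and unzips it in one request
    stanza.download("de")

    # customized list: resolves the requested processor/package pairs plus their
    # pretrain/charlm dependencies, then downloads the individual .pt files
    stanza.download("en", processors={"ner": "conll03"})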
stanza/stanza/resources/default_packages.py ADDED
@@ -0,0 +1,909 @@
+"""
+Constants for default packages, default pretrains, charlms, etc.
+
+Separated from prepare_resources.py so that other modules can use the
+same lists / maps without importing the resources script and possibly
+causing a circular import
+"""
+
+import copy
+
+# all languages will have a map which represents the available packages
+PACKAGES = "packages"
+
+# default treebank for languages
+default_treebanks = {
+    "af": "afribooms",
+    # currently not publicly released! sent to us from the group developing this resource
+    "ang": "nerthus",
+    "ar": "padt",
+    "be": "hse",
+    "bg": "btb",
+    "bxr": "bdt",
+    "ca": "ancora",
+    "cop": "scriptorium",
+    "cs": "pdt",
+    "cu": "proiel",
+    "cy": "ccg",
+    "da": "ddt",
+    "de": "gsd",
+    "el": "gdt",
+    "en": "combined",
+    "es": "combined",
+    "et": "edt",
+    "eu": "bdt",
+    "fa": "perdt",
+    "fi": "tdt",
+    "fo": "farpahc",
+    "fr": "combined",
+    "fro": "profiterole",
+    "ga": "idt",
+    "gd": "arcosg",
+    "gl": "ctg",
+    "got": "proiel",
+    "grc": "perseus",
+    "gv": "cadhan",
+    "hbo": "ptnk",
+    "he": "combined",
+    "hi": "hdtb",
+    "hr": "set",
+    "hsb": "ufal",
+    "hu": "szeged",
+    "hy": "armtdp",
+    "hyw": "armtdp",
+    "id": "gsd",
+    "is": "icepahc",
+    "it": "combined",
+    "ja": "gsd",
+    "ka": "glc",
+    "kk": "ktb",
+    "kmr": "mg",
+    "ko": "kaist",
+    "kpv": "lattice",
+    "ky": "ktmu",
+    "la": "ittb",
+    "lij": "glt",
+    "lt": "alksnis",
+    "lv": "lvtb",
+    "lzh": "kyoto",
+    "mr": "ufal",
+    "mt": "mudt",
+    "my": "ucsy",
+    "myv": "jr",
+    "nb": "bokmaal",
+    "nds": "lsdc",
+    "nl": "alpino",
+    "nn": "nynorsk",
+    "olo": "kkpp",
+    "orv": "torot",
+    "ota": "boun",
+    "pcm": "nsc",
+    "pl": "pdb",
+    "pt": "bosque",
+    "qaf": "arabizi",
+    "qpm": "philotis",
+    "qtd": "sagt",
+    "ro": "rrt",
+    "ru": "syntagrus",
+    "sa": "vedic",
+    "sd": "isra",
+    "sk": "snk",
+    "sl": "ssj",
+    "sme": "giella",
+    "sq": "combined",
+    "sr": "set",
+    "sv": "talbanken",
+    "swl": "sslc",
+    "ta": "ttb",
+    "te": "mtg",
+    "th": "orchid",
+    "tr": "imst",
+    "ug": "udt",
+    "uk": "iu",
+    "ur": "udtb",
+    "vi": "vtb",
+    "wo": "wtb",
+    "xcl": "caval",
+    "zh-hans": "gsdsimp",
+    "zh-hant": "gsd",
+    "multilingual": "ud"
+}
+
+no_pretrain_languages = set([
+    "cop",
+    "olo",
+    "orv",
+    "pcm",
+    "qaf",  # the QAF treebank is code-switched and Romanized, so not easy to reuse existing resources
+    "qpm",  # have talked about deriving this from a language neighboring Pomak, but that hasn't happened yet
+    "qtd",
+    "swl",
+
+    "multilingual",  # special case so that all languages with a default treebank are represented somewhere
+])
+
+
+# in some cases, we give the pretrain a name other than the original
+# name for the UD dataset
+# we will eventually do this for all of the pretrains
+specific_default_pretrains = {
+    "af": "fasttextwiki",
+    "ang": "nerthus",
+    "ar": "conll17",
+    "be": "fasttextwiki",
+    "bg": "conll17",
+    "bxr": "fasttextwiki",
+    "ca": "conll17",
+    "cs": "conll17",
+    "cu": "conll17",
+    "cy": "fasttext157",
+    "da": "conll17",
+    "de": "conll17",
+    "el": "conll17",
+    "en": "conll17",
+    "es": "conll17",
+    "et": "conll17",
+    "eu": "conll17",
+    "fa": "conll17",
+    "fi": "conll17",
+    "fo": "fasttextwiki",
+    "fr": "conll17",
+    "fro": "conll17",
+    "ga": "conll17",
+    "gd": "fasttextwiki",
+    "gl": "conll17",
+    "got": "fasttextwiki",
+    "grc": "conll17",
+    "gv": "fasttext157",
+    "hbo": "utah",
+    "he": "conll17",
+    "hi": "conll17",
+    "hr": "conll17",
+    "hsb": "fasttextwiki",
+    "hu": "conll17",
+    "hy": "isprasglove",
+    "hyw": "isprasglove",
+    "id": "conll17",
+    "is": "fasttext157",
+    "it": "conll17",
+    "ja": "conll17",
+    "ka": "fasttext157",
+    "kk": "fasttext157",
+    "kmr": "fasttextwiki",
+    "ko": "conll17",
+    "kpv": "fasttextwiki",
+    "ky": "fasttext157",
+    "la": "conll17",
+    "lij": "fasttextwiki",
+    "lt": "fasttextwiki",
+    "lv": "conll17",
+    "lzh": "fasttextwiki",
+    "mr": "fasttextwiki",
+    "mt": "fasttextwiki",
+    "my": "ucsy",
+    "myv": "mokha",
+    "nb": "conll17",
+    "nds": "fasttext157",
+    "nl": "conll17",
+    "nn": "conll17",
+    "ota": "conll17",
+    "pl": "conll17",
+    "pt": "conll17",
+    "ro": "conll17",
+    "ru": "conll17",
+    "sa": "fasttext157",
+    "sd": "isra",
+    "sk": "conll17",
+    "sl": "conll17",
+    "sme": "fasttextwiki",
+    "sq": "fasttext157",
+    "sr": "fasttextwiki",
+    "sv": "conll17",
+    "ta": "fasttextwiki",
+    "te": "fasttextwiki",
+    "th": "fasttext157",
+    "tr": "conll17",
+    "ug": "conll17",
+    "uk": "conll17",
+    "ur": "conll17",
+    "vi": "conll17",
+    "wo": "fasttextwiki",
+    "xcl": "caval",
+    "zh-hans": "fasttext157",
+    "zh-hant": "conll17",
+}
+
+
+def build_default_pretrains(default_treebanks):
+    default_pretrains = dict(default_treebanks)
+    for lang in no_pretrain_languages:
+        default_pretrains.pop(lang, None)
+    for lang in specific_default_pretrains.keys():
+        default_pretrains[lang] = specific_default_pretrains[lang]
+    return default_pretrains
+
+default_pretrains = build_default_pretrains(default_treebanks)
+
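A few hypothetical sanity checks of the resolution order above: entries in no_pretrain_languages are dropped, and explicit entries in specific_default_pretrains override the treebank name:

    assert "cop" not in default_pretrains          # listed in no_pretrain_languages
    assert default_pretrains["en"] == "conll17"    # renamed via specific_default_pretrains
    assert default_pretrains["ang"] == "nerthus"   # happens to coincide with the treebank name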
+pos_pretrains = {
+    "en": {
+        "craft": "biomed",
+        "genia": "biomed",
+        "mimic": "mimic",
+    },
+}
+
+depparse_pretrains = pos_pretrains
+
+ner_pretrains = {
+    "ar": {
+        "aqmar": "fasttextwiki",
+    },
+    "de": {
+        "conll03": "fasttextwiki",
+        # the bert version of germeval uses the smaller vector file
+        "germeval2014": "fasttextwiki",
+    },
+    "en": {
+        "anatem": "biomed",
+        "bc4chemd": "biomed",
+        "bc5cdr": "biomed",
+        "bionlp13cg": "biomed",
+        "jnlpba": "biomed",
+        "linnaeus": "biomed",
+        "ncbi_disease": "biomed",
+        "s800": "biomed",
+
+        "ontonotes": "fasttextcrawl",
+        # the stanza-train sample NER model should use the default NER pretrain
+        # for English, that is the same as ontonotes
+        "sample": "fasttextcrawl",
+
+        "conll03": "glove",
+
+        "i2b2": "mimic",
+        "radiology": "mimic",
+    },
+    "es": {
+        "ancora": "fasttextwiki",
+        "conll02": "fasttextwiki",
+    },
+    "nl": {
+        "conll02": "fasttextwiki",
+        "wikiner": "fasttextwiki",
+    },
+    "ru": {
+        "wikiner": "fasttextwiki",
+    },
+    "th": {
+        "lst20": "fasttext157",
+    },
+}
+
+
+# default charlms for languages
+default_charlms = {
+    "af": "oscar",
+    "ang": "nerthus1024",
+    "ar": "ccwiki",
+    "bg": "conll17",
+    "da": "oscar",
+    "de": "newswiki",
+    "en": "1billion",
+    "es": "newswiki",
+    "fa": "conll17",
+    "fi": "conll17",
+    "fr": "newswiki",
+    "he": "oscar",
+    "hi": "oscar",
+    "id": "oscar2023",
+    "it": "conll17",
+    "ja": "conll17",
+    "kk": "oscar",
+    "mr": "l3cube",
+    "my": "oscar",
+    "nb": "conll17",
+    "nl": "ccwiki",
+    "pl": "oscar",
+    "pt": "oscar2023",
+    "ru": "newswiki",
+    "sd": "isra",
+    "sv": "conll17",
+    "te": "oscar2022",
+    "th": "oscar",
+    "tr": "conll17",
+    "uk": "conll17",
+    "vi": "conll17",
+    "zh-hans": "gigaword"
+}
+
+pos_charlms = {
+    "en": {
+        # none of the English charlms help with craft or genia
+        "craft": None,
+        "genia": None,
+        "mimic": "mimic",
+    },
+    "tr": {  # no idea why, but this particular one goes down in dev score
+        "boun": None,
+    },
+}
+
+depparse_charlms = copy.deepcopy(pos_charlms)
+
+lemma_charlms = copy.deepcopy(pos_charlms)
+
+ner_charlms = {
+    "en": {
+        "conll03": "1billion",
+        "ontonotes": "1billion",
+        "anatem": "pubmed",
+        "bc4chemd": "pubmed",
+        "bc5cdr": "pubmed",
+        "bionlp13cg": "pubmed",
+        "i2b2": "mimic",
+        "jnlpba": "pubmed",
+        "linnaeus": "pubmed",
+        "ncbi_disease": "pubmed",
+        "radiology": "mimic",
+        "s800": "pubmed",
+    },
+    "hu": {
+        "combined": None,
+    },
+    "nn": {
+        "norne": None,
+    },
+}
+
+# default ner for languages
+default_ners = {
+    "af": "nchlt",
+    "ar": "aqmar_charlm",
+    "bg": "bsnlp19",
+    "da": "ddt",
+    "de": "germeval2014",
+    "en": "ontonotes-ww-multi_charlm",
+    "es": "conll02",
+    "fa": "arman",
+    "fi": "turku",
+    "fr": "wikinergold_charlm",
+    "he": "iahlt_charlm",
+    "hu": "combined",
+    "hy": "armtdp",
+    "it": "fbk",
+    "ja": "gsd",
+    "kk": "kazNERD",
+    "mr": "l3cube",
+    "my": "ucsy",
+    "nb": "norne",
+    "nl": "conll02",
+    "nn": "norne",
+    "pl": "nkjp",
+    "ru": "wikiner",
+    "sd": "siner",
+    "sv": "suc3shuffle",
+    "th": "lst20",
+    "tr": "starlang",
+    "uk": "languk",
+    "vi": "vlsp",
+    "zh-hans": "ontonotes",
+}
+
+# a few languages have sentiment classifier models
+default_sentiment = {
+    "en": "sstplus_charlm",
+    "de": "sb10k_charlm",
+    "es": "tass2020_charlm",
+    "mr": "l3cube_charlm",
+    "vi": "vsfc_charlm",
+    "zh-hans": "ren_charlm",
+}
+
+# also, a few languages (very few, currently) have constituency parser models
+default_constituency = {
+    "da": "arboretum_charlm",
+    "de": "spmrl_charlm",
+    "en": "ptb3-revised_charlm",
+    "es": "combined_charlm",
+    "id": "icon_charlm",
+    "it": "vit_charlm",
+    "ja": "alt_charlm",
+    "pt": "cintil_charlm",
+    #"tr": "starlang_charlm",
+    "vi": "vlsp22_charlm",
+    "zh-hans": "ctb-51_charlm",
+}
+
+optional_constituency = {
+    "tr": "starlang_charlm",
+}
+
+# an alternate tokenizer for languages which aren't trained from a base UD source
+default_tokenizer = {
+    "my": "alt",
+}
+
+# ideally we would have a less expensive model as the base model
+#default_coref = {
+#    "en": "ontonotes_roberta-large_finetuned",
+#}
+
+optional_coref = {
+    "ca": "udcoref_xlm-roberta-lora",
+    "cs": "udcoref_xlm-roberta-lora",
+    "de": "udcoref_xlm-roberta-lora",
+    "en": "udcoref_xlm-roberta-lora",
+    "es": "udcoref_xlm-roberta-lora",
+    "fr": "udcoref_xlm-roberta-lora",
+    "hi": "deeph_muril-large-cased-lora",
+    # UD Coref has both nb and nn datasets for Norwegian
+    "nb": "udcoref_xlm-roberta-lora",
+    "nn": "udcoref_xlm-roberta-lora",
+    "pl": "udcoref_xlm-roberta-lora",
+    "ru": "udcoref_xlm-roberta-lora",
+}
+
+"""
+default transformers to use for various languages
+
+we try to document why we choose a particular model in each case
+"""
+TRANSFORMERS = {
+    # We tested three candidate AR models on POS, Depparse, and NER
+    #
+    # POS: padt dev set scores, AllTags
+    # depparse: padt dev set scores, LAS
+    # NER: dev scores on a random split of AQMAR, entity scores
+    #
+    #                                             pos    depparse  ner
+    # none (pt & charlm only)                    94.08   83.49    84.19
+    # asafaya/bert-base-arabic                   95.10   84.96    85.98
+    # aubmindlab/bert-base-arabertv2             95.33   85.28    84.93
+    # aubmindlab/araelectra-base-discriminator   95.66   85.83    86.10
+    "ar": "aubmindlab/araelectra-base-discriminator",
+
+    # https://huggingface.co/Maltehb/danish-bert-botxo
+    # contrary to normal expectations, this hurts F1
+    # on a dev split by about 1 F1
+    # "da": "Maltehb/danish-bert-botxo",
+    #
+    # the multilingual bert is a marginal improvement for conparse
+    #
+    # December 2022 update:
+    # there are quite a few Danish transformers available on HuggingFace
+    # here are the results of training a constituency parser with adadelta/adamw
+    # on each of them:
+    #
+    # no bert                               0.8245   0.8230
+    # alexanderfalk/danbert-small-cased     0.8236   0.8286
+    # Geotrend/distilbert-base-da-cased     0.8268   0.8306
+    # sarnikowski/convbert-small-da-cased   0.8322   0.8341
+    # bert-base-multilingual-cased          0.8341   0.8342
+    # vesteinn/ScandiBERT-no-faroese        0.8373   0.8408
+    # Maltehb/danish-bert-botxo             0.8383   0.8408
+    # vesteinn/ScandiBERT                   0.8421   0.8475
+    #
+    # Also, two models have token windows too short for use with the
+    # Danish dataset:
+    #   jonfd/electra-small-nordic
+    #   Maltehb/aelaectra-danish-electra-small-cased
+    #
+    "da": "vesteinn/ScandiBERT",
+
+    # As of April 2022, the bert models available have a weird
+    # tokenizer issue where soft hyphen causes it to crash.
+    # We attempt to compensate for that in the dev branch
+    #
+    # NER scores
+    # model                                          dev     test
+    # xlm-roberta-large                             86.56   85.23
+    # bert-base-german-cased                        87.59   86.95
+    # dbmdz/bert-base-german-cased                  88.27   87.47
+    # german-nlp-group/electra-base-german-uncased  88.60   87.09
+    #
+    # constituency scores w/ peft, March 2024 model, in-order
+    # model                 dev     test
+    # xlm-roberta-base     95.17   93.34
+    # xlm-roberta-large    95.86   94.46  (!!!)
+    # bert-base            95.24   93.24
+    # dbmdz/bert           95.32   93.33
+    # german/electra       95.72   94.05
+    #
+    # POS scores
+    # model                 dev     test
+    # None                 88.65   87.28
+    # xlm-roberta-large    89.21   88.11
+    # bert-base            89.52   88.42
+    # dbmdz/bert           89.67   88.54
+    # german/electra       89.98   88.66
+    #
+    # depparse scores, LAS
+    # model                 dev     test
+    # None                 87.76   84.37
+    # xlm-roberta-large    89.00   85.79
+    # bert-base            88.72   85.40
+    # dbmdz/bert           88.70   85.14
+    # german/electra       89.21   86.06
+    "de": "german-nlp-group/electra-base-german-uncased",
+
+    # experiments on various forms of roberta & electra
+    # https://huggingface.co/roberta-base
+    # https://huggingface.co/roberta-large
+    # https://huggingface.co/google/electra-small-discriminator
+    # https://huggingface.co/google/electra-base-discriminator
+    # https://huggingface.co/google/electra-large-discriminator
+    #
+    # experiments using the different models for POS tagging,
+    # dev set, including WV and charlm, AllTags score:
+    #   roberta-base:   95.67
+    #   roberta-large:  95.98
+    #   electra-small:  95.31
+    #   electra-base:   95.90
+    #   electra-large:  96.01
+    #
+    # depparse scores, dev set, no finetuning, with WV and charlm
+    #                    UAS    LAS    CLAS   MLAS   BLEX
+    #   roberta-base:   93.16  91.20  89.87  89.38  89.87
+    #   roberta-large:  93.47  91.56  90.13  89.71  90.13
+    #   electra-small:  92.17  90.02  88.25  87.66  88.25
+    #   electra-base:   93.42  91.44  90.10  89.67  90.10
+    #   electra-large:  94.07  92.17  90.99  90.53  90.99
+    #
+    # conparse scores, dev & test set, with WV and charlm
+    #   roberta_base:   96.05  95.60
+    #   roberta_large:  95.95  95.60
+    #   electra-small:  95.33  95.04
+    #   electra-base:   96.09  95.98
+    #   electra-large:  96.25  96.14
+    #
+    # conparse scores w/ finetune, dev & test set, with WV and charlm
+    #   roberta_base:   96.07  95.81
+    #   roberta_large:  96.37  96.41  (!!!)
+    #   electra-small:  95.62  95.36
+    #   electra-base:   96.21  95.94
+    #   electra-large:  96.40  96.32
+    #
+    "en": "google/electra-large-discriminator",
+
+    # TODO need to test, possibly compare with others
+    "es": "bertin-project/bertin-roberta-base-spanish",
+
+    # NER scores for a couple Persian options:
+    # none:
+    #   dev:  2022-04-23 01:44:53 INFO: fa_arman 79.46
+    #   test: 2022-04-23 01:45:03 INFO: fa_arman 80.06
+    #
+    # HooshvareLab/bert-fa-zwnj-base
+    #   dev:  2022-04-23 02:43:44 INFO: fa_arman 80.87
+    #   test: 2022-04-23 02:44:07 INFO: fa_arman 80.81
+    #
+    # HooshvareLab/roberta-fa-zwnj-base
+    #   dev:  2022-04-23 16:23:25 INFO: fa_arman 81.23
+    #   test: 2022-04-23 16:23:48 INFO: fa_arman 81.11
+    #
+    # HooshvareLab/bert-base-parsbert-uncased
+    #   dev:  2022-04-26 10:42:09 INFO: fa_arman 82.49
+    #   test: 2022-04-26 10:42:31 INFO: fa_arman 83.16
+    "fa": 'HooshvareLab/bert-base-parsbert-uncased',
+
+    # NER scores for a couple options:
+    # none:
+    #   dev:  2022-03-04 INFO: fi_turku 83.45
+    #   test: 2022-03-04 INFO: fi_turku 86.25
+    #
+    # bert-base-multilingual-cased
+    #   dev:  2022-03-04 INFO: fi_turku 85.23
+    #   test: 2022-03-04 INFO: fi_turku 89.00
+    #
+    # TurkuNLP/bert-base-finnish-cased-v1:
+    #   dev:  2022-03-04 INFO: fi_turku 88.41
+    #   test: 2022-03-04 INFO: fi_turku 91.36
+    "fi": "TurkuNLP/bert-base-finnish-cased-v1",
+
+    # POS dev set tagging results for French:
+    # No bert:
+    #   98.60  100.00  98.55  98.04
+    # dbmdz/electra-base-french-europeana-cased-discriminator
+    #   98.70  100.00  98.69  98.24
+    # benjamin/roberta-base-wechsel-french
+    #   98.71  100.00  98.75  98.26
+    # camembert/camembert-large
+    #   98.75  100.00  98.75  98.30
+    # camembert-base
+    #   98.78  100.00  98.77  98.33
+    #
+    # GSD depparse dev set results for French:
+    # No bert:
+    #   95.83  94.52  91.34  91.10  91.34
+    # camembert/camembert-large
+    #   96.80  95.71  93.37  93.13  93.37
+    # TODO: the rest of the chart
+    "fr": "camembert/camembert-large",
+
+    # Ancient Greek has a surprising number of transformers, considering
+    # Model             POS      Depparse LAS
+    # None              0.8812   0.7684
+    # Microbert M       0.8883   0.7706
+    # Microbert MX      0.8910   0.7755
+    # Microbert MXP     0.8916   0.7742
+    # Pranaydeeps Bert  0.9139   0.7987
+    "grc": "pranaydeeps/Ancient-Greek-BERT",
+
+    # a couple possibilities to experiment with for Hebrew
+    # dev scores for POS and depparse
+    # https://huggingface.co/imvladikon/alephbertgimmel-base-512
+    #   UPOS   XPOS   UFeats  AllTags
+    #   97.25  97.25  92.84   91.81
+    #   UAS    LAS    CLAS    MLAS   BLEX
+    #   94.42  92.47  89.49   88.82  89.49
+    #
+    # https://huggingface.co/onlplab/alephbert-base
+    #   UPOS   XPOS   UFeats  AllTags
+    #   97.37  97.37  92.50   91.55
+    #   UAS    LAS    CLAS    MLAS   BLEX
+    #   94.06  92.12  88.80   88.13  88.80
+    #
+    # https://huggingface.co/avichr/heBERT
+    #   UPOS   XPOS   UFeats  AllTags
+    #   97.09  97.09  92.36   91.28
+    #   UAS    LAS    CLAS    MLAS   BLEX
+    #   94.29  92.30  88.99   88.38  88.99
+    "he": "imvladikon/alephbertgimmel-base-512",
+
+    # can also experiment with xlm-roberta
+    # on a coref dataset from IITH, span F1:
+    #                      dev       test
+    # xlm-roberta-large    0.63635   0.66579
+    # muril-large          0.65369   0.68290
+    "hi": "google/muril-large-cased",
+
+    # https://huggingface.co/xlm-roberta-base
+    # Scores by entity for armtdp NER on 18 labels:
+    #   no bert          : 86.68
+    #   xlm-roberta-base : 89.31
+    "hy": "xlm-roberta-base",
+
+    # Indonesian POS experiments: dev set of GSD
+    # python3 stanza/utils/training/run_pos.py id_gsd --no_bert
+    # python3 stanza/utils/training/run_pos.py id_gsd --bert_model ...
+    # also ran on the ICON constituency dataset
+    # model                                      POS     CON
+    # no_bert                                   89.95   84.74
+    # flax-community/indonesian-roberta-large   89.78 (!) xxx
+    # flax-community/indonesian-roberta-base    90.14   xxx
+    # indobenchmark/indobert-base-p2            90.09
+    # indobenchmark/indobert-base-p1            90.14
+    # indobenchmark/indobert-large-p1           90.19
+    # indolem/indobert-base-uncased             90.21   88.60
+    # cahya/bert-base-indonesian-1.5G           90.32   88.15
+    # cahya/roberta-base-indonesian-1.5G        90.40   87.27
+    "id": "indolem/indobert-base-uncased",
+
+    # from https://github.com/idb-ita/GilBERTo
+    # annoyingly, it doesn't handle cased text
+    # supposedly there is an argument "do_lower_case"
+    # but that still leaves a lot of unk tokens
+    # "it": "idb-ita/gilberto-uncased-from-camembert",
+    #
+    # from https://github.com/musixmatchresearch/umberto
+    # on NER, this gets 88.37 dev and 91.02 test
+    # another option is dbmdz/bert-base-italian-cased,
+    # which gets 87.27 dev and 90.32 test
+    #
+    # in-order constituency parser on the VIT dev set:
+    #   dbmdz/bert-base-italian-cased                        0.8079
+    #   dbmdz/bert-base-italian-xxl-cased:                   0.8195
+    #   Musixmatch/umberto-commoncrawl-cased-v1:             0.8256
+    #   dbmdz/electra-base-italian-xxl-cased-discriminator:  0.8314
+    #
+    # FBK NER dev set:
+    #   dbmdz/bert-base-italian-cased:                       87.76
+    #   Musixmatch/umberto-commoncrawl-cased-v1:             88.62
+    #   dbmdz/bert-base-italian-xxl-cased:                   88.84
+    #   dbmdz/electra-base-italian-xxl-cased-discriminator:  89.91
+    #
+    # combined UD POS dev set:                               UPOS   XPOS   UFeats  AllTags
+    #   dbmdz/bert-base-italian-cased:                       98.62  98.53  98.06   97.49
+    #   dbmdz/bert-base-italian-xxl-cased:                   98.61  98.54  98.07   97.58
+    #   dbmdz/electra-base-italian-xxl-cased-discriminator:  98.64  98.54  98.14   97.61
+    #   Musixmatch/umberto-commoncrawl-cased-v1:             98.56  98.45  98.13   97.62
+    "it": "dbmdz/electra-base-italian-xxl-cased-discriminator",
+
+    # for Japanese
+    # there are others that would also work,
+    # but they require different tokenizers instead of being
+    # plug & play
+    #
+    # Constituency scores on ALT (in-order)
+    #   no bert: 90.68 dev, 91.40 test
+    #   rinna:   91.54 dev, 91.89 test
+    "ja": "rinna/japanese-roberta-base",
+
+    # could also try:
+    # l3cube-pune/marathi-bert-v2
+    # or
+    # https://huggingface.co/l3cube-pune/hindi-marathi-dev-roberta
+    # l3cube-pune/hindi-marathi-dev-roberta
+    #
+    # depparse ufal dev scores:
+    #   no transformer               74.89  63.70  57.43  53.01  57.43
+    #   l3cube-pune/marathi-roberta  76.48  66.21  61.20  57.60  61.20
+    "mr": "l3cube-pune/marathi-roberta",
+
+    # https://huggingface.co/allegro/herbert-base-cased
+    # Scores by entity on the NKJP NER task:
+    #   no bert (dev/test):                      88.64/88.75
+    #   herbert-base-cased (dev/test):           91.48/91.02
+    #   herbert-large-cased (dev/test):          92.25/91.62
+    #   sdadas/polish-roberta-large-v2 (dev/test): 92.66/91.22
+    "pl": "allegro/herbert-base-cased",
+
+    # experiments on the cintil conparse dataset
+    # ran a variety of transformer settings
+    # found the following dev set scores after 400 iterations:
+    #   Geotrend/distilbert-base-pt-cased : not plug & play
+    #   no bert:                                      0.9082
+    #   xlm-roberta-base:                             0.9109
+    #   xlm-roberta-large:                            0.9254
+    #   adalbertojunior/distilbert-portuguese-cased:  0.9300
+    #   neuralmind/bert-base-portuguese-cased:        0.9307
+    #   neuralmind/bert-large-portuguese-cased:       0.9343
+    "pt": "neuralmind/bert-large-portuguese-cased",
+
+    # the hope is to eventually build our own using a large text collection
+    "sd": "google/muril-large-cased",
+
+    # Tamil options: quite a few, need to run a bunch of experiments
+    #                               dev pos   dev depparse las
+    #   no transformer               82.82     69.12
+    #   ai4bharat/indic-bert         82.98     70.47
+    #   lgessler/microbert-tamil-mxp 83.21     69.28
+    #   monsoon-nlp/tamillion        83.37     69.28
+    #   l3cube-pune/tamil-bert       85.27     72.53
+    #   d42kw01f/Tamil-RoBERTa       85.59     70.55
+    #   google/muril-base-cased      85.67     72.68
+    #   google/muril-large-cased     86.30     72.45
+    "ta": "google/muril-large-cased",
+
+
+    # https://huggingface.co/dbmdz/bert-base-turkish-128k-cased
+    # helps the Turkish model quite a bit
+    "tr": "dbmdz/bert-base-turkish-128k-cased",
+
+    # from https://github.com/VinAIResearch/PhoBERT
+    # "vi": "vinai/phobert-base",
+    # using 6 or 7 layers of phobert-large is slightly
+    # more effective for constituency parsing than
+    # using 4 layers of phobert-base
+    # ... going beyond 4 layers of phobert-base
+    # does not help the scores
+    "vi": "vinai/phobert-large",
+
+    # https://github.com/ymcui/Chinese-BERT-wwm
+    # there's also hfl/chinese-roberta-wwm-ext-large
+    # or hfl/chinese-electra-base-discriminator
+    # or hfl/chinese-electra-180g-large-discriminator,
+    # which works better than the below roberta on constituency
+    # "zh-hans": "hfl/chinese-roberta-wwm-ext",
+    "zh-hans": "hfl/chinese-electra-180g-large-discriminator",
+}
+
+TRANSFORMER_LAYERS = {
+    # not clear what the best number is without more experiments,
+    # but more than 4 is working better than just 4
+    "vi": 7,
+}
+
+TRANSFORMER_NICKNAMES = {
+    # ar
+    "asafaya/bert-base-arabic": "asafaya-bert",
+    "aubmindlab/araelectra-base-discriminator": "aubmind-electra",
+    "aubmindlab/bert-base-arabertv2": "aubmind-bert",
+
+    # da
+    "vesteinn/ScandiBERT": "scandibert",
+
+    # de
+    "bert-base-german-cased": "bert-base-german-cased",
+    "dbmdz/bert-base-german-cased": "dbmdz-bert-german-cased",
+    "german-nlp-group/electra-base-german-uncased": "german-nlp-electra",
+
+    # en
+    "bert-base-multilingual-cased": "mbert",
+    "xlm-roberta-large": "xlm-roberta-large",
+    "google/electra-large-discriminator": "electra-large",
+    "microsoft/deberta-v3-large": "deberta-v3-large",
+
+    # es
+    "bertin-project/bertin-roberta-base-spanish": "bertin-roberta",
+
+    # fa
+    "HooshvareLab/bert-base-parsbert-uncased": "parsbert",
+
+    # fi
+    "TurkuNLP/bert-base-finnish-cased-v1": "bert",
+
+    # fr
+    "benjamin/roberta-base-wechsel-french": "wechsel-roberta",
+    "camembert-base": "camembert-base",
+    "camembert/camembert-large": "camembert-large",
+    "dbmdz/electra-base-french-europeana-cased-discriminator": "dbmdz-electra",
+
+    # grc
+    "pranaydeeps/Ancient-Greek-BERT": "grc-pranaydeeps",
+    "lgessler/microbert-ancient-greek-m": "grc-microbert-m",
+    "lgessler/microbert-ancient-greek-mx": "grc-microbert-mx",
+    "lgessler/microbert-ancient-greek-mxp": "grc-microbert-mxp",
+    "altsoph/bert-base-ancientgreek-uncased": "grc-altsoph",
+
+    # he
+    "imvladikon/alephbertgimmel-base-512": "alephbertgimmel",
+
+    # hy
+    "xlm-roberta-base": "xlm-roberta-base",
+
+    # id
+    "indolem/indobert-base-uncased": "indobert",
+    "indobenchmark/indobert-large-p1": "indobenchmark-large-p1",
+    "indobenchmark/indobert-base-p1": "indobenchmark-base-p1",
+    "indobenchmark/indobert-lite-large-p1": "indobenchmark-lite-large-p1",
+    "indobenchmark/indobert-lite-base-p1": "indobenchmark-lite-base-p1",
+    "indobenchmark/indobert-large-p2": "indobenchmark-large-p2",
+    "indobenchmark/indobert-base-p2": "indobenchmark-base-p2",
+    "indobenchmark/indobert-lite-large-p2": "indobenchmark-lite-large-p2",
+    "indobenchmark/indobert-lite-base-p2": "indobenchmark-lite-base-p2",
+
+    # it
+    "dbmdz/electra-base-italian-xxl-cased-discriminator": "electra",
+
+    # ja
+    "rinna/japanese-roberta-base": "rinna-roberta",
+
+    # mr
+    "l3cube-pune/marathi-roberta": "l3cube-marathi-roberta",
+
+    # pl
+    "allegro/herbert-base-cased": "herbert",
+
+    # pt
+    "neuralmind/bert-large-portuguese-cased": "bertimbau",
+
+    # ta: tamil
+    "monsoon-nlp/tamillion": "tamillion",
+    "lgessler/microbert-tamil-m": "ta-microbert-m",
+    "lgessler/microbert-tamil-mxp": "ta-microbert-mxp",
+    "l3cube-pune/tamil-bert": "l3cube-tamil-bert",
+    "d42kw01f/Tamil-RoBERTa": "ta-d42kw01f-roberta",
+
+    # tr
+    "dbmdz/bert-base-turkish-128k-cased": "bert",
+
+    # vi
+    "vinai/phobert-base": "phobert-base",
+    "vinai/phobert-large": "phobert-large",
+
+    # zh
+    "google-bert/bert-base-chinese": "google-bert-chinese",
+    "hfl/chinese-roberta-wwm-ext": "hfl-roberta-chinese",
+    "hfl/chinese-electra-180g-large-discriminator": "electra-large",
+
+    # multi-lingual Indic
+    "ai4bharat/indic-bert": "indic-bert",
+    "google/muril-base-cased": "muril-base-cased",
+    "google/muril-large-cased": "muril-large-cased",
+}
+
+def known_nicknames():
+    """
+    Return a list of all the transformer nicknames
+
+    We return a list so that we can sort them in decreasing key length
+    """
+    nicknames = list(TRANSFORMER_NICKNAMES.values())
+
+    # previously unspecified transformers get "transformer" as the nickname
+    nicknames.append("transformer")
+
+    nicknames = sorted(nicknames, key=lambda x: -len(x))
+
+    return nicknames
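One reason known_nicknames() sorts by decreasing length: a caller stripping a nickname suffix must try longer names such as "electra-large" before a shorter nickname could shadow part of them. A hypothetical helper in the style of split_package from prepare_resources.py:

    def strip_nickname(package):
        # longest nicknames first, so multi-part names match before shorter ones
        for nickname in known_nicknames():
            if package.endswith("_" + nickname):
                return package[:-(len(nickname) + 1)]
        return package

    assert strip_nickname("ptb3-revised_electra-large") == "ptb3-revised"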
stanza/stanza/resources/installation.py ADDED
@@ -0,0 +1,148 @@
+"""
+Functions for setting up the CoreNLP environment.
+"""
+
+import os
+import logging
+import zipfile
+import shutil
+
+from stanza.resources.common import HOME_DIR, request_file, unzip, \
+    get_root_from_zipfile, set_logging_level
+
+logger = logging.getLogger('stanza')
+
+DEFAULT_CORENLP_MODEL_URL = os.getenv(
+    'CORENLP_MODEL_URL',
+    'https://huggingface.co/stanfordnlp/corenlp-{model}/resolve/{tag}/stanford-corenlp-models-{model}.jar'
+)
+BACKUP_CORENLP_MODEL_URL = "http://nlp.stanford.edu/software/stanford-corenlp-{version}-models-{model}.jar"
+
+DEFAULT_CORENLP_URL = os.getenv(
+    'CORENLP_URL',
+    'https://huggingface.co/stanfordnlp/CoreNLP/resolve/{tag}/stanford-corenlp-latest.zip'
+)
+
+DEFAULT_CORENLP_DIR = os.getenv(
+    'CORENLP_HOME',
+    os.path.join(HOME_DIR, 'stanza_corenlp')
+)
+
+AVAILABLE_MODELS = set(['arabic', 'chinese', 'english-extra', 'english-kbp', 'french', 'german', 'hungarian', 'italian', 'spanish'])
+
+
+def download_corenlp_models(model, version, dir=DEFAULT_CORENLP_DIR, url=DEFAULT_CORENLP_MODEL_URL, logging_level='INFO', proxies=None, force=True):
+    """
+    An automatic way to download the CoreNLP models.
+
+    Args:
+        model: the name of the model, can be one of 'arabic', 'chinese', 'english-extra',
+            'english-kbp', 'french', 'german', 'hungarian', 'italian', 'spanish'
+        version: the version of the model
+        dir: the directory to download the CoreNLP model into; alternatively can be
+            set up with the environment variable $CORENLP_HOME
+        url: the link to download CoreNLP models.
+            It will need {model} and either {version} or {tag} to properly format the URL
+        logging_level: logging level to use during installation
+        force: download the model even if the model file already exists
+    """
+    dir = os.path.expanduser(dir)
+    if not model or not version:
+        raise ValueError(
+            "Both model and model version should be specified."
+        )
+    logger.info(f"Downloading {model} models (version {version}) into directory {dir}")
+    model = model.strip().lower()
+    if model not in AVAILABLE_MODELS:
+        raise KeyError(
+            f'{model} is currently not supported. '
+            f'Must be one of: {list(AVAILABLE_MODELS)}.'
+        )
+    # for example:
+    # https://huggingface.co/stanfordnlp/CoreNLP/resolve/v4.2.2/stanford-corenlp-models-french.jar
+    tag = version if version == 'main' else 'v' + version
+    download_url = url.format(tag=tag, model=model, version=version)
+    model_path = os.path.join(dir, f'stanford-corenlp-{version}-models-{model}.jar')
+
+    if os.path.exists(model_path) and not force:
+        logger.warning(
+            f"Model file {model_path} already exists. "
+            f"Please download this model to a new directory.")
+        return
+
+    try:
+        request_file(
+            download_url,
+            model_path,
+            proxies
+        )
+    except (KeyboardInterrupt, SystemExit):
+        raise
+    except Exception as e:
+        raise RuntimeError(
+            "Downloading CoreNLP model file failed. "
+            "Please try manual downloading at: https://stanfordnlp.github.io/CoreNLP/."
+        ) from e
+
+
+def install_corenlp(dir=DEFAULT_CORENLP_DIR, url=DEFAULT_CORENLP_URL, logging_level=None, proxies=None, version="main"):
+    """
+    A fully automatic way to install and set up the CoreNLP library
+    in order to use the client functionality.
+
+    Args:
+        dir: the directory to install CoreNLP into; alternatively can be
+            set up with the environment variable $CORENLP_HOME
+        url: the link to download the CoreNLP package.
+            Needs a {version} or {tag} parameter to specify the version
+        logging_level: logging level to use during installation
+    """
+    dir = os.path.expanduser(dir)
+    set_logging_level(logging_level=logging_level, verbose=None)
+    if os.path.exists(dir) and len(os.listdir(dir)) > 0:
+        logger.warning(
+            f"Directory {dir} already exists. "
+            f"Please install CoreNLP to a new directory.")
+        return
+
+    logger.info(f"Installing CoreNLP package into {dir}")
+    # First download the URL package
+    logger.debug(f"Download to destination file: {os.path.join(dir, 'corenlp.zip')}")
+    tag = version if version == 'main' else 'v' + version
+    url = url.format(version=version, tag=tag)
+    try:
+        request_file(url, os.path.join(dir, 'corenlp.zip'), proxies)
+    except (KeyboardInterrupt, SystemExit):
+        raise
+    except Exception as e:
+        raise RuntimeError(
+            "Downloading CoreNLP zip file failed. "
+            "Please try manual installation: https://stanfordnlp.github.io/CoreNLP/."
+        ) from e
+
+    # Unzip corenlp into dir
+    logger.debug("Unzipping downloaded zip file...")
+    unzip(dir, 'corenlp.zip')
+
+    # By default CoreNLP will be unzipped into a version-dependent folder,
+    # e.g., stanford-corenlp-4.0.0. We need some hack around that and move
+    # files back into our designated folder
+    logger.debug(f"Moving files into the designated folder at: {dir}")
+    corenlp_dirname = get_root_from_zipfile(os.path.join(dir, 'corenlp.zip'))
+    corenlp_dirname = os.path.join(dir, corenlp_dirname)
+    for f in os.listdir(corenlp_dirname):
+        shutil.move(os.path.join(corenlp_dirname, f), dir)
+
+    # Remove original zip and folder
+    logger.debug("Removing downloaded zip file...")
+    os.remove(os.path.join(dir, 'corenlp.zip'))
+    shutil.rmtree(corenlp_dirname)
+
+    # Warn user to set up env
+    if dir != DEFAULT_CORENLP_DIR:
+        logger.warning(
+            f"For a customized installation location, please set the `CORENLP_HOME` "
+            f"environment variable to the location of the installation. "
+            f"In Unix, this is done with `export CORENLP_HOME={dir}`.")
+
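A typical setup sequence built on the two functions above (hypothetical usage; the version number is illustrative):

    from stanza.resources.installation import install_corenlp, download_corenlp_models

    install_corenlp(dir="~/corenlp")   # fetch and unpack the latest CoreNLP zip
    download_corenlp_models(model="french", version="4.5.5", dir="~/corenlp")

    # afterwards, export CORENLP_HOME=~/corenlp so the client can find the jars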
stanza/stanza/resources/prepare_resources.py ADDED
@@ -0,0 +1,670 @@
+"""
+Converts a directory of models organized by type into a directory organized by language.
+
+Also produces the resources.json file.
+
+For example, on the cluster, you can do this:
+
+python3 -m stanza.resources.prepare_resources --input_dir /u/nlp/software/stanza/models/current-models-1.5.0 --output_dir /u/nlp/software/stanza/models/1.5.0 > resources.out 2>&1
+nlprun -a stanza-1.2 -q john "python3 -m stanza.resources.prepare_resources --input_dir /u/nlp/software/stanza/models/current-models-1.5.0 --output_dir /u/nlp/software/stanza/models/1.5.0" -o resources.out
+"""
+
+import argparse
+from collections import defaultdict
+import json
+import os
+from pathlib import Path
+import hashlib
+import shutil
+import zipfile
+
+from stanza import __resources_version__
+from stanza.models.common.constant import lcode2lang, two_to_three_letters, three_to_two_letters
+from stanza.resources.default_packages import PACKAGES, TRANSFORMERS, TRANSFORMER_NICKNAMES
+from stanza.resources.default_packages import *
+from stanza.utils.datasets.prepare_lemma_classifier import DATASET_MAPPING as LEMMA_CLASSIFIER_DATASETS
+from stanza.utils.get_tqdm import get_tqdm
+
+tqdm = get_tqdm()
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input_dir', type=str, default="/u/nlp/software/stanza/models/current-models-%s" % __resources_version__, help='Input dir for various models. Defaults to the recommended home on the nlp cluster')
+    parser.add_argument('--output_dir', type=str, default="/u/nlp/software/stanza/models/%s" % __resources_version__, help='Output dir for various models.')
+    parser.add_argument('--packages_only', action='store_true', default=False, help='Only build the package maps instead of rebuilding everything')
+    parser.add_argument('--lang', type=str, default=None, help='Only process this language or a comma-separated list of languages. If left blank, will prepare all languages. To use this argument, a previously prepared resources.json with all of the languages is necessary.')
+    args = parser.parse_args()
+    args.input_dir = os.path.abspath(args.input_dir)
+    args.output_dir = os.path.abspath(args.output_dir)
+    if args.lang is not None:
+        args.lang = ",".join(args.lang.strip().split())
+    return args
+
+
+allowed_empty_languages = [
+    # we don't have a lot of Thai support yet
+    "th",
+    # only tokenize and NER for Myanmar right now (soon...)
+    "my",
+]
+
+# map processor name to file ending
+# the order of this dict determines the order in which default.zip files are built
+# changing it will necessitate rebuilding all of the default.zip files
+# not a disaster, but it would involve a bunch of uploading
+processor_to_ending = {
+    "tokenize": "tokenizer",
+    "mwt": "mwt_expander",
+    "lemma": "lemmatizer",
+    "pos": "tagger",
+    "depparse": "parser",
+    "pretrain": "pretrain",
+    "ner": "nertagger",
+    "forward_charlm": "forward_charlm",
+    "backward_charlm": "backward_charlm",
+    "sentiment": "sentiment",
+    "constituency": "constituency",
+    "coref": "coref",
+    "langid": "langid",
+}
+ending_to_processor = {j: i for i, j in processor_to_ending.items()}
+PROCESSORS = list(processor_to_ending.keys())
+
+def ensure_dir(dir):
+    Path(dir).mkdir(parents=True, exist_ok=True)
+
+
+def copy_file(src, dst):
+    ensure_dir(Path(dst).parent)
+    shutil.copy2(src, dst)
+
+
+def get_md5(path):
+    with open(path, 'rb') as fin:
+        data = fin.read()
+    return hashlib.md5(data).hexdigest()
+
+
+def split_model_name(model):
+    """
+    Split model names by _
+
+    Takes into account packages with _ and processor types with _
+    """
+    model = model[:-3].replace('.', '_')
+    # sort by key length so that nertagger is checked before tagger, for example
+    for processor in sorted(ending_to_processor.keys(), key=lambda x: -len(x)):
+        if model.endswith(processor):
+            model = model[:-(len(processor)+1)]
+            processor = ending_to_processor[processor]
+            break
+    else:
+        raise AssertionError(f"Could not find a processor type in {model}")
+    lang, package = model.split('_', 1)
+    return lang, package, processor
+
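Two hypothetical checks of the splitting logic above: the length sort guarantees that "nertagger" wins over "tagger", and the '.'-to-'_' replacement makes dotted and underscored filenames equivalent:

    assert split_model_name("en_ontonotes_charlm_nertagger.pt") == ("en", "ontonotes_charlm", "ner")
    assert split_model_name("zh-hans_gsdsimp.tagger.pt") == ("zh-hans", "gsdsimp", "pos")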
+def split_package(package):
+    if package.endswith("_finetuned"):
+        package = package[:-10]
+
+    if package.endswith("_nopretrain"):
+        package = package[:-11]
+        return package, False, False
+    if package.endswith("_nocharlm"):
+        package = package[:-9]
+        return package, True, False
+    if package.endswith("_charlm"):
+        package = package[:-7]
+        return package, True, True
+    underscore = package.rfind("_")
+    if underscore >= 0:
+        # +1 to skip the underscore
+        nickname = package[underscore+1:]
+        if nickname in known_nicknames():
+            return package[:underscore], True, True
+
+    # guess it was a model which wasn't built with the new naming convention of putting the pretrain type at the end
+    # assume WV and charlm... if the language / package doesn't allow for one, that should be caught later
+    return package, True, True
+
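The return value is (package, uses_pretrain, uses_charlm); some hypothetical cases:

    assert split_package("combined_nopretrain") == ("combined", False, False)
    assert split_package("combined_nocharlm") == ("combined", True, False)
    assert split_package("ptb3-revised_charlm") == ("ptb3-revised", True, True)
    # a trailing transformer nickname also implies pretrain + charlm
    assert split_package("ontonotes-ww-multi_electra-large") == ("ontonotes-ww-multi", True, True)
    # anything unrecognized falls back to assuming both
    assert split_package("gsd") == ("gsd", True, True)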
129
+ def get_pretrain_package(lang, package, model_pretrains, default_pretrains):
130
+ package, uses_pretrain, _ = split_package(package)
131
+
132
+ if not uses_pretrain or lang in no_pretrain_languages:
133
+ return None
134
+ elif model_pretrains is not None and lang in model_pretrains and package in model_pretrains[lang]:
135
+ return model_pretrains[lang][package]
136
+ elif lang in default_pretrains:
137
+ return default_pretrains[lang]
138
+
139
+ raise RuntimeError("pretrain not specified for lang %s package %s" % (lang, package))
140
+
141
+ def get_charlm_package(lang, package, model_charlms, default_charlms):
142
+ package, _, uses_charlm = split_package(package)
143
+
144
+ if not uses_charlm:
145
+ return None
146
+
147
+ if model_charlms is not None and lang in model_charlms and package in model_charlms[lang]:
148
+ return model_charlms[lang][package]
149
+ else:
150
+ return default_charlms.get(lang, None)
151
+
152
+ def get_con_dependencies(lang, package):
153
+ # so far, this invariant is true:
154
+ # constituency models use the default pretrain and charlm for the language
155
+ # sometimes there is no charlm for a language that has constituency, though
156
+ pretrain_package = get_pretrain_package(lang, package, None, default_pretrains)
157
+ dependencies = [{'model': 'pretrain', 'package': pretrain_package}]
158
+
159
+ charlm_package = default_charlms.get(lang, None)
160
+ if charlm_package is not None:
161
+ dependencies.append({'model': 'forward_charlm', 'package': charlm_package})
162
+ dependencies.append({'model': 'backward_charlm', 'package': charlm_package})
163
+
164
+ return dependencies
165
+
166
+ def get_pos_charlm_package(lang, package):
167
+ return get_charlm_package(lang, package, pos_charlms, default_charlms)
168
+
169
+ def get_pos_dependencies(lang, package):
170
+ dependencies = []
171
+
172
+ pretrain_package = get_pretrain_package(lang, package, pos_pretrains, default_pretrains)
173
+ if pretrain_package is not None:
174
+ dependencies.append({'model': 'pretrain', 'package': pretrain_package})
175
+
176
+ charlm_package = get_pos_charlm_package(lang, package)
177
+ if charlm_package is not None:
178
+ dependencies.append({'model': 'forward_charlm', 'package': charlm_package})
179
+ dependencies.append({'model': 'backward_charlm', 'package': charlm_package})
180
+
181
+ return dependencies
182
+
183
+ def get_lemma_pretrain_package(lang, package):
184
+ package, uses_pretrain, uses_charlm = split_package(package)
185
+ if not uses_pretrain:
186
+ return None
187
+ if not uses_charlm:
188
+ # currently the contextual lemma classifier is only active
189
+ # for the charlm lemmatizers
190
+ return None
191
+ if "%s_%s" % (lang, package) not in LEMMA_CLASSIFIER_DATASETS:
192
+ return None
193
+ return get_pretrain_package(lang, package, {}, default_pretrains)
194
+
195
+ def get_lemma_charlm_package(lang, package):
196
+ return get_charlm_package(lang, package, lemma_charlms, default_charlms)
197
+
198
+ def get_lemma_dependencies(lang, package):
199
+ dependencies = []
200
+
201
+ pretrain_package = get_lemma_pretrain_package(lang, package)
202
+ if pretrain_package is not None:
203
+ dependencies.append({'model': 'pretrain', 'package': pretrain_package})
204
+
205
+ charlm_package = get_lemma_charlm_package(lang, package)
206
+ if charlm_package is not None:
207
+ dependencies.append({'model': 'forward_charlm', 'package': charlm_package})
208
+ dependencies.append({'model': 'backward_charlm', 'package': charlm_package})
209
+
210
+ return dependencies
211
+
212
+
213
+ def get_depparse_charlm_package(lang, package):
214
+ return get_charlm_package(lang, package, depparse_charlms, default_charlms)
215
+
216
+ def get_depparse_dependencies(lang, package):
217
+ dependencies = []
218
+
219
+ pretrain_package = get_pretrain_package(lang, package, depparse_pretrains, default_pretrains)
220
+ if pretrain_package is not None:
221
+ dependencies.append({'model': 'pretrain', 'package': pretrain_package})
222
+
223
+ charlm_package = get_depparse_charlm_package(lang, package)
224
+ if charlm_package is not None:
225
+ dependencies.append({'model': 'forward_charlm', 'package': charlm_package})
226
+ dependencies.append({'model': 'backward_charlm', 'package': charlm_package})
227
+
228
+ return dependencies
229
+
230
+ def get_ner_charlm_package(lang, package):
231
+ return get_charlm_package(lang, package, ner_charlms, default_charlms)
232
+
233
+ def get_ner_pretrain_package(lang, package):
234
+ return get_pretrain_package(lang, package, ner_pretrains, default_pretrains)
235
+
236
+ def get_ner_dependencies(lang, package):
237
+ dependencies = []
238
+
239
+ pretrain_package = get_ner_pretrain_package(lang, package)
240
+ if pretrain_package is not None:
241
+ dependencies.append({'model': 'pretrain', 'package': pretrain_package})
242
+
243
+ charlm_package = get_ner_charlm_package(lang, package)
244
+ if charlm_package is not None:
245
+ dependencies.append({'model': 'forward_charlm', 'package': charlm_package})
246
+ dependencies.append({'model': 'backward_charlm', 'package': charlm_package})
247
+
248
+ return dependencies
249
+
250
+ def get_sentiment_dependencies(lang, package):
251
+ """
252
+ Return a list of dependencies for the sentiment model
253
+
254
+ Generally this will be pretrain, forward & backward charlm
255
+ So far, this invariant is true:
256
+ sentiment models use the default pretrain for the language
257
+ also, they all use the default charlm for a language
258
+ """
259
+ pretrain_package = get_pretrain_package(lang, package, None, default_pretrains)
260
+ dependencies = [{'model': 'pretrain', 'package': pretrain_package}]
261
+
262
+ charlm_package = default_charlms.get(lang, None)
263
+ if charlm_package is not None:
264
+ dependencies.append({'model': 'forward_charlm', 'package': charlm_package})
265
+ dependencies.append({'model': 'backward_charlm', 'package': charlm_package})
266
+
267
+ return dependencies
268
+
269
+ def get_dependencies(processor, lang, package):
270
+ """
271
+ Get the dependencies for a particular lang/package based on the package name
272
+
273
+ The package can include descriptors such as _nopretrain, _nocharlm, _charlm
274
+ which inform whether or not this particular model uses charlm or pretrain
275
+ """
276
+ if processor == 'depparse':
277
+ return get_depparse_dependencies(lang, package)
278
+ elif processor == 'lemma':
279
+ return get_lemma_dependencies(lang, package)
280
+ elif processor == 'pos':
281
+ return get_pos_dependencies(lang, package)
282
+ elif processor == 'ner':
283
+ return get_ner_dependencies(lang, package)
284
+ elif processor == 'sentiment':
285
+ return get_sentiment_dependencies(lang, package)
286
+ elif processor == 'constituency':
287
+ return get_con_dependencies(lang, package)
288
+ return {}
289
+
290
+ def process_dirs(args):
291
+ dirs = sorted(os.listdir(args.input_dir))
292
+ resources = {}
293
+ if args.lang:
294
+ resources = json.load(open(os.path.join(args.output_dir, 'resources.json')))
295
+ # this one language gets overridden
296
+ # if this is not done, and we reuse the old resources,
297
+ # any models which were deleted will still be in the resources
298
+ for lang in args.lang.split(","):
299
+ resources[lang] = {}
300
+
301
+ for model_dir in dirs:
302
+ print(f"Processing models in {model_dir}")
303
+ models = sorted(os.listdir(os.path.join(args.input_dir, model_dir)))
304
+ for model in tqdm(models):
305
+ if not model.endswith('.pt'): continue
306
+ # get processor
307
+ lang, package, processor = split_model_name(model)
308
+ if args.lang and lang not in args.lang.split(","):
309
+ continue
310
+
311
+ # copy file
312
+ input_path = os.path.join(args.input_dir, model_dir, model)
313
+ output_path = os.path.join(args.output_dir, lang, "models", processor, package + '.pt')
314
+ copy_file(input_path, output_path)
315
+ # maintain md5
316
+ md5 = get_md5(output_path)
317
+ # maintain dependencies
318
+ dependencies = get_dependencies(processor, lang, package)
319
+ # maintain resources
320
+ if lang not in resources: resources[lang] = {}
321
+ if processor not in resources[lang]: resources[lang][processor] = {}
322
+ if dependencies:
323
+ resources[lang][processor][package] = {'md5': md5, 'dependencies': dependencies}
324
+ else:
325
+ resources[lang][processor][package] = {'md5': md5}
326
+ print("Processed initial model directories. Writing preliminary resources.json")
327
+ json.dump(resources, open(os.path.join(args.output_dir, 'resources.json'), 'w'), indent=2)
328
+
329
+ def get_default_pos_package(lang, ud_package):
330
+ charlm_package = get_pos_charlm_package(lang, ud_package)
331
+ if charlm_package is not None:
332
+ return ud_package + "_charlm"
333
+ if lang in no_pretrain_languages:
334
+ return ud_package + "_nopretrain"
335
+ return ud_package + "_nocharlm"
336
+
337
+ def get_default_depparse_package(lang, ud_package):
338
+ charlm_package = get_depparse_charlm_package(lang, ud_package)
339
+ if charlm_package is not None:
340
+ return ud_package + "_charlm"
341
+ if lang in no_pretrain_languages:
342
+ return ud_package + "_nopretrain"
343
+ return ud_package + "_nocharlm"
344
+
345
+ def process_default_zips(args):
346
+ resources = json.load(open(os.path.join(args.output_dir, 'resources.json')))
347
+ for lang in resources:
348
+ # check url, alias, and lang_name in case we are rerunning this step on an already built resources.json
349
+ if lang == 'url':
350
+ continue
351
+ if 'alias' in resources[lang]:
352
+ continue
353
+ if all(k in ("backward_charlm", "forward_charlm", "pretrain", "lang_name") for k in resources[lang].keys()):
354
+ continue
355
+ if lang not in default_treebanks:
356
+ raise AssertionError(f'{lang} not in default treebanks!!!')
357
+
358
+ if args.lang and lang not in args.lang.split(","):
359
+ continue
360
+
361
+ print(f'Preparing default models for language {lang}')
362
+
363
+ models_needed = defaultdict(set)
364
+
365
+ packages = resources[lang][PACKAGES]["default"]
366
+ for processor, package in packages.items():
367
+ if processor == 'lemma' and package == 'identity':
368
+ continue
369
+ if processor == 'optional':
370
+ continue
371
+ models_needed[processor].add(package)
372
+ dependencies = get_dependencies(processor, lang, package)
373
+ for dependency in dependencies:
374
+ models_needed[dependency['model']].add(dependency['package'])
375
+
376
+ model_files = []
377
+ for processor in PROCESSORS:
378
+ if processor in models_needed:
379
+ for package in sorted(models_needed[processor]):
380
+ filename = os.path.join(args.output_dir, lang, "models", processor, package + '.pt')
381
+ if os.path.exists(filename):
382
+ print(" Model {} package {}: file {}".format(processor, package, filename))
383
+ model_files.append((filename, processor, package))
384
+ else:
385
+ raise FileNotFoundError(f"Processor {processor} package {package} needed for {lang} but cannot be found at {filename}")
386
+
387
+ with zipfile.ZipFile(os.path.join(args.output_dir, lang, 'models', 'default.zip'), 'w', zipfile.ZIP_DEFLATED) as zipf:
388
+ for filename, processor, package in model_files:
389
+ zipf.write(filename=filename, arcname=os.path.join(processor, package + '.pt'))
390
+
391
+ default_md5 = get_md5(os.path.join(args.output_dir, lang, 'models', 'default.zip'))
392
+ resources[lang]['default_md5'] = default_md5
393
+
394
+ print("Processed default model zips. Writing resources.json")
395
+ json.dump(resources, open(os.path.join(args.output_dir, 'resources.json'), 'w'), indent=2)
396
+
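get_md5 is defined earlier in this script and not shown here; as a sketch, assuming it simply hashes the file contents, an equivalent would be:

import hashlib

def get_md5_sketch(path):
    # hash in chunks so large model files never need to fit in memory at once
    digest = hashlib.md5()
    with open(path, "rb") as fin:
        for chunk in iter(lambda: fin.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()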
397
+ def get_default_processors(resources, lang):
398
+ """
399
+ Build a default package for this language
400
+
401
+ Will add each of pos, lemma, depparse, etc if those are available
402
+ Uses the existing models scraped from the language directories into resources.json, as relevant
403
+ """
404
+ if lang == "multilingual":
405
+ return {"langid": "ud"}
406
+
407
+ default_package = default_treebanks[lang]
408
+ default_processors = {}
409
+ if lang in default_tokenizer:
410
+ default_processors['tokenize'] = default_tokenizer[lang]
411
+ else:
412
+ default_processors['tokenize'] = default_package
413
+
414
+ if 'mwt' in resources[lang] and default_processors['tokenize'] in resources[lang]['mwt']:
415
+ # if this doesn't happen, we just skip MWT
416
+ default_processors['mwt'] = default_package
417
+
418
+ if 'lemma' in resources[lang]:
419
+ expected_lemma = default_package + "_nocharlm"
420
+ if expected_lemma in resources[lang]['lemma']:
421
+ default_processors['lemma'] = expected_lemma
422
+ elif lang not in allowed_empty_languages:
423
+ default_processors['lemma'] = 'identity'
424
+
425
+ if 'pos' in resources[lang]:
426
+ default_processors['pos'] = get_default_pos_package(lang, default_package)
427
+ if default_processors['pos'] not in resources[lang]['pos']:
428
+ raise AssertionError("Expected POS model not in resources: %s" % default_processors['pos'])
429
+ elif lang not in allowed_empty_languages:
430
+ raise AssertionError("Expected to find POS models for language %s" % lang)
431
+
432
+ if 'depparse' in resources[lang]:
433
+ default_processors['depparse'] = get_default_depparse_package(lang, default_package)
434
+ if default_processors['depparse'] not in resources[lang]['depparse']:
435
+ raise AssertionError("Expected depparse model not in resources: %s" % default_processors['depparse'])
436
+ elif lang not in allowed_empty_languages:
437
+ raise AssertionError("Expected to find depparse models for language %s" % lang)
438
+
439
+ if lang in default_ners:
440
+ default_processors['ner'] = default_ners[lang]
441
+
442
+ if lang in default_sentiment:
443
+ default_processors['sentiment'] = default_sentiment[lang]
444
+
445
+ if lang in default_constituency:
446
+ default_processors['constituency'] = default_constituency[lang]
447
+
448
+ optional = get_default_optional_processors(resources, lang)
449
+ if optional:
450
+ default_processors['optional'] = optional
451
+
452
+ return default_processors
453
+
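The returned dict maps processor names to package names. For a hypothetical language whose default treebank is "ewt", with a charlm available and an NER model registered, it would look roughly like this (package names illustrative):

default_processors_sketch = {
    "tokenize": "ewt",
    "mwt": "ewt",              # only if an MWT model exists for the tokenizer package
    "lemma": "ewt_nocharlm",
    "pos": "ewt_charlm",
    "depparse": "ewt_charlm",
    "ner": "ontonotes_charlm", # only if the language appears in default_ners
}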
454
+ def get_default_optional_processors(resources, lang):
455
+ optional_processors = {}
456
+ if lang in optional_constituency:
457
+ optional_processors['constituency'] = optional_constituency[lang]
458
+
459
+ if lang in optional_coref:
460
+ optional_processors['coref'] = optional_coref[lang]
461
+
462
+ return optional_processors
463
+
464
+ def update_processor_add_transformer(resources, lang, current_processors, processor, transformer):
465
+ if processor not in current_processors:
466
+ return
467
+
468
+ new_model = current_processors[processor].replace('_charlm', "_" + transformer).replace('_nocharlm', "_" + transformer)
469
+ if new_model in resources[lang][processor]:
470
+ current_processors[processor] = new_model
471
+ else:
472
+ print("WARNING: wanted to use %s for %s accurate %s, but that model does not exist" % (new_model, lang, processor))
473
+
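The upgrade is a plain string substitution on the package name; for example, with an illustrative transformer nickname:

model = "ewt_charlm"
transformer = "electra-large"
new_model = model.replace("_charlm", "_" + transformer).replace("_nocharlm", "_" + transformer)
assert new_model == "ewt_electra-large"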
474
+ def get_default_accurate(resources, lang):
475
+ """
476
+ A package that, if available, uses charlm and transformer models for each processor
477
+ """
478
+ default_processors = get_default_processors(resources, lang)
479
+
480
+ if 'lemma' in default_processors and default_processors['lemma'] != 'identity':
481
+ lemma_model = default_processors['lemma']
482
+ lemma_model = lemma_model.replace('_nocharlm', '_charlm')
483
+ charlm_package = get_lemma_charlm_package(lang, lemma_model)
484
+ if charlm_package is not None:
485
+ if lemma_model in resources[lang]['lemma']:
486
+ default_processors['lemma'] = lemma_model
487
+ else:
488
+ print("WARNING: wanted to use %s for %s default_accurate lemma, but that model does not exist" % (lemma_model, lang))
489
+
490
+ transformer = TRANSFORMER_NICKNAMES.get(TRANSFORMERS.get(lang, None), None)
491
+ if transformer is not None:
492
+ for processor in ('pos', 'depparse', 'constituency', 'sentiment'):
493
+ update_processor_add_transformer(resources, lang, default_processors, processor, transformer)
494
+ if 'ner' in default_processors and (default_processors['ner'].endswith("_charlm") or default_processors['ner'].endswith("_nocharlm")):
495
+ update_processor_add_transformer(resources, lang, default_processors, "ner", transformer)
496
+
497
+ optional = get_optional_accurate(resources, lang)
498
+ if optional:
499
+ default_processors['optional'] = optional
500
+
501
+ return default_processors
502
+
503
+ def get_optional_accurate(resources, lang):
504
+ optional_processors = get_default_optional_processors(resources, lang)
505
+
506
+ transformer = TRANSFORMER_NICKNAMES.get(TRANSFORMERS.get(lang, None), None)
507
+ if transformer is not None:
508
+ for processor in ('pos', 'depparse', 'constituency', 'sentiment'):
509
+ update_processor_add_transformer(resources, lang, optional_processors, processor, transformer)
510
+
511
+ if lang in optional_coref:
512
+ optional_processors['coref'] = optional_coref[lang]
513
+
514
+ return optional_processors
515
+
516
+
517
+ def get_default_fast(resources, lang):
518
+ """
519
+ Build a packages entry which only has the nocharlm models
520
+
521
+ Will make it easy for people to use the lower tier of models
522
+
523
+ We do this by building the same default package as normal,
524
+ then switching everything out for the lower tier model when possible.
525
+ We also remove constituency, as it is super slow.
526
+ Note that for a language which doesn't have a charlm,
527
+ we wind up building the same package for default and default_fast
528
+ """
529
+ default_processors = get_default_processors(resources, lang)
530
+
531
+ # this is a slow model and we don't have non-charlm versions of it yet
532
+ if 'constituency' in default_processors:
533
+ default_processors.pop('constituency')
534
+
535
+ for processor, model in default_processors.items():
536
+ if "_charlm" in model:
537
+ nocharlm = model.replace("_charlm", "_nocharlm")
538
+ if nocharlm not in resources[lang][processor]:
539
+ print("WARNING: wanted to use %s for %s default_fast processor %s, but that model does not exist" % (nocharlm, lang, processor))
540
+ else:
541
+ default_processors[processor] = nocharlm
542
+
543
+ return default_processors
544
+
545
+ def process_packages(args):
546
+ """
547
+ Build a package for a language's default processors and all of the treebanks specifically used for that language
548
+ """
549
+ resources = json.load(open(os.path.join(args.output_dir, 'resources.json')))
550
+
551
+ for lang in resources:
552
+ # check url, alias, and lang_name in case we are rerunning this step on an already built resources.json
553
+ if lang == 'url':
554
+ continue
555
+ if 'alias' in resources[lang]:
556
+ continue
557
+ if all(k in ("backward_charlm", "forward_charlm", "pretrain", "lang_name") for k in resources[lang].keys()):
558
+ continue
559
+ if lang not in default_treebanks:
560
+ raise AssertionError(f'{lang} not in default treebanks!!!')
561
+
562
+ if args.lang and lang not in args.lang.split(","):
563
+ continue
564
+
565
+ default_processors = get_default_processors(resources, lang)
566
+
567
+ # TODO: eventually we can remove default_processors
568
+ # For now, we want to keep this so that v1.5.1 is compatible
569
+ # with the next iteration of resources files
570
+ resources[lang]['default_processors'] = default_processors
571
+ resources[lang][PACKAGES] = {}
572
+ resources[lang][PACKAGES]['default'] = default_processors
573
+
574
+ if lang not in no_pretrain_languages and lang != "multilingual":
575
+ default_fast = get_default_fast(resources, lang)
576
+ resources[lang][PACKAGES]['default_fast'] = default_fast
577
+
578
+ default_accurate = get_default_accurate(resources, lang)
579
+ resources[lang][PACKAGES]['default_accurate'] = default_accurate
580
+
581
+ # Now we loop over each of the tokenizers for this language
582
+ # ... we use this as a proxy for the available UD treebanks
583
+ # This loop also catches things such as "craft" which are
584
+ # included treebanks that aren't UD
585
+ # We then create a package in the packages dict for each of those treebanks
586
+ if 'tokenize' in resources[lang]:
587
+ for package in resources[lang]['tokenize']:
588
+ processors = {"tokenize": package}
589
+ if "mwt" in resources[lang] and package in resources[lang]["mwt"]:
590
+ processors["mwt"] = package
591
+
592
+ if "pos" in resources[lang]:
593
+ if package + "_charlm" in resources[lang]["pos"]:
594
+ processors["pos"] = package + "_charlm"
595
+ elif package + "_nocharlm" in resources[lang]["pos"]:
596
+ processors["pos"] = package + "_nocharlm"
597
+
598
+ if "lemma" in resources[lang] and "pos" in processors:
599
+ lemma_package = package + "_nocharlm"
600
+ if lemma_package in resources[lang]["lemma"]:
601
+ processors["lemma"] = lemma_package
602
+
603
+ if "depparse" in resources[lang] and "pos" in processors:
604
+ depparse_package = None
605
+ if package + "_charlm" in resources[lang]["depparse"]:
606
+ depparse_package = package + "_charlm"
607
+ elif package + "_nocharlm" in resources[lang]["depparse"]:
608
+ depparse_package = package + "_nocharlm"
609
+ # we want to set the lemma first if it's identity
610
+ # THEN set the depparse
611
+ if depparse_package is not None:
612
+ if "lemma" not in processors:
613
+ processors["lemma"] = "identity"
614
+ processors["depparse"] = depparse_package
615
+
616
+ resources[lang][PACKAGES][package] = processors
617
+
618
+ print("Processed packages. Writing resources.json")
619
+ json.dump(resources, open(os.path.join(args.output_dir, 'resources.json'), 'w'), indent=2)
620
+
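After this step each language entry holds a packages dict keyed by package name; schematically (package contents are illustrative):

packages_sketch = {
    "default":          {"tokenize": "ewt", "mwt": "ewt", "pos": "ewt_charlm"},
    "default_fast":     {"tokenize": "ewt", "mwt": "ewt", "pos": "ewt_nocharlm"},
    "default_accurate": {"tokenize": "ewt", "mwt": "ewt", "pos": "ewt_electra-large"},
    # plus one entry per available tokenizer/treebank, e.g. "gum"
    "gum":              {"tokenize": "gum", "mwt": "gum", "pos": "gum_charlm"},
}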
621
+ def process_lcode(args):
622
+ resources = json.load(open(os.path.join(args.output_dir, 'resources.json')))
623
+ resources_new = {}
624
+ resources_new["multilingual"] = resources["multilingual"]
625
+ for lang in resources:
626
+ if lang == 'multilingual':
627
+ continue
628
+ if 'alias' in resources[lang]:
629
+ continue
630
+ if lang not in lcode2lang:
631
+ print(lang + ' not found in lcode2lang!')
632
+ continue
633
+ lang_name = lcode2lang[lang]
634
+ resources[lang]['lang_name'] = lang_name
635
+ resources_new[lang.lower()] = resources[lang.lower()]
636
+ resources_new[lang_name.lower()] = {'alias': lang.lower()}
637
+ if lang.lower() in two_to_three_letters:
638
+ resources_new[two_to_three_letters[lang.lower()]] = {'alias': lang.lower()}
639
+ elif lang.lower() in three_to_two_letters:
640
+ resources_new[three_to_two_letters[lang.lower()]] = {'alias': lang.lower()}
641
+ print("Processed lcode aliases. Writing resources.json")
642
+ json.dump(resources_new, open(os.path.join(args.output_dir, 'resources.json'), 'w'), indent=2)
643
+
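The effect is that a language can then be looked up by its canonical code, its full name, or its alternate ISO code; for English, for example, the aliases come out as:

lcode_aliases_sketch = {
    "english": {"alias": "en"},
    "eng":     {"alias": "en"},  # via two_to_three_letters
}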
644
+
645
+ def process_misc(args):
646
+ resources = json.load(open(os.path.join(args.output_dir, 'resources.json')))
647
+ resources['no'] = {'alias': 'nb'}
648
+ resources['zh'] = {'alias': 'zh-hans'}
649
+ # This is intended to be unformatted. expand_model_url in common.py will fill in the raw string
650
+ # with the appropriate values in order to find the needed model file on huggingface
651
+ resources['url'] = 'https://huggingface.co/stanfordnlp/stanza-{lang}/resolve/v{resources_version}/models/{filename}'
652
+ print("Finalized misc attributes. Writing resources.json")
653
+ json.dump(resources, open(os.path.join(args.output_dir, 'resources.json'), 'w'), indent=2)
654
+
655
+
656
+ def main():
657
+ args = parse_args()
658
+ print("Converting models from %s to %s" % (args.input_dir, args.output_dir))
659
+ if not args.packages_only:
660
+ process_dirs(args)
661
+ process_packages(args)
662
+ if not args.packages_only:
663
+ process_default_zips(args)
664
+ process_lcode(args)
665
+ process_misc(args)
666
+
667
+
668
+ if __name__ == '__main__':
669
+ main()
670
+
stanza/stanza/server/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ from stanza.protobuf import to_text
2
+ from stanza.protobuf import Document, Sentence, Token, IndexedWord, Span
3
+ from stanza.protobuf import ParseTree, DependencyGraph, CorefChain
4
+ from stanza.protobuf import Mention, NERMention, Entity, Relation, RelationTriple, Timex
5
+ from stanza.protobuf import Quote, SpeakerInfo
6
+ from stanza.protobuf import Operator, Polarity
7
+ from stanza.protobuf import SentenceFragment, TokenLocation
8
+ from stanza.protobuf import MapStringString, MapIntString
9
+ from .client import CoreNLPClient, AnnotationException, TimeoutException, PermanentlyFailedException, StartServer
10
+ from .annotator import Annotator
stanza/stanza/server/annotator.py ADDED
@@ -0,0 +1,138 @@
1
+ """
2
+ Defines a base class that can be used to annotate.
3
+ """
4
+ import io
5
+ from multiprocessing import Process
6
+ from http.server import BaseHTTPRequestHandler, HTTPServer
7
+ from http import client as HTTPStatus
8
+
9
+ from stanza.protobuf import Document, parseFromDelimitedString, writeToDelimitedString
10
+
11
+ class Annotator(Process):
12
+ """
13
+ This annotator base class hosts a lightweight server that accepts
14
+ annotation requests from CoreNLP.
15
+ Each annotator simply defines 3 functions: requires, provides and annotate.
16
+
17
+ This class takes care of defining appropriate endpoints to interface
18
+ with CoreNLP.
19
+ """
20
+ @property
21
+ def name(self):
22
+ """
23
+ Name of the annotator (used by CoreNLP)
24
+ """
25
+ raise NotImplementedError()
26
+
27
+ @property
28
+ def requires(self):
29
+ """
30
+ Requires has to specify all the annotations required before we
31
+ are called.
32
+ """
33
+ raise NotImplementedError()
34
+
35
+ @property
36
+ def provides(self):
37
+ """
38
+ The set of annotations guaranteed to be provided when we are done.
39
+ NOTE: these annotations are either fully qualified Java
40
+ class names or refer to nested classes of
41
+ edu.stanford.nlp.ling.CoreAnnotations (as is the case below).
42
+ """
43
+ raise NotImplementedError()
44
+
45
+ def annotate(self, ann):
46
+ """
47
+ @ann: is a protobuf annotation object.
48
+ Actually populate @ann with tokens.
49
+ """
50
+ raise NotImplementedError()
51
+
52
+ @property
53
+ def properties(self):
54
+ """
55
+ Defines a Java property to define this annotator to CoreNLP.
56
+ """
57
+ return {
58
+ "customAnnotatorClass.{}".format(self.name): "edu.stanford.nlp.pipeline.GenericWebServiceAnnotator",
59
+ "generic.endpoint": "http://{}:{}".format(self.host, self.port),
60
+ "generic.requires": ",".join(self.requires),
61
+ "generic.provides": ",".join(self.provides),
62
+ }
63
+
64
+ class _Handler(BaseHTTPRequestHandler):
65
+ annotator = None
66
+
67
+ def __init__(self, request, client_address, server):
68
+ BaseHTTPRequestHandler.__init__(self, request, client_address, server)
69
+
70
+ def do_GET(self):
71
+ """
72
+ Handle a ping request
73
+ """
74
+ if not self.path.endswith("/"): self.path += "/"
75
+ if self.path == "/ping/":
76
+ msg = "pong".encode("UTF-8")
77
+
78
+ self.send_response(HTTPStatus.OK)
79
+ self.send_header("Content-Type", "text/application")
80
+ self.send_header("Content-Length", len(msg))
81
+ self.end_headers()
82
+ self.wfile.write(msg)
83
+ else:
84
+ self.send_response(HTTPStatus.BAD_REQUEST)
85
+ self.end_headers()
86
+
87
+ def do_POST(self):
88
+ """
89
+ Handle an annotate request
90
+ """
91
+ if not self.path.endswith("/"): self.path += "/"
92
+ if self.path == "/annotate/":
93
+ # Read message
94
+ length = int(self.headers.get('content-length'))
95
+ msg = self.rfile.read(length)
96
+
97
+ # Do the annotation
98
+ doc = Document()
99
+ parseFromDelimitedString(doc, msg)
100
+ self.annotator.annotate(doc)
101
+
102
+ with io.BytesIO() as stream:
103
+ writeToDelimitedString(doc, stream)
104
+ msg = stream.getvalue()
105
+
106
+ # write message
107
+ self.send_response(HTTPStatus.OK)
108
+ self.send_header("Content-Type", "application/x-protobuf")
109
+ self.send_header("Content-Length", len(msg))
110
+ self.end_headers()
111
+ self.wfile.write(msg)
112
+
113
+ else:
114
+ self.send_response(HTTPStatus.BAD_REQUEST)
115
+ self.end_headers()
116
+
117
+ def __init__(self, host="", port=8432):
118
+ """
119
+ Launches a server endpoint to communicate with CoreNLP
120
+ """
121
+ Process.__init__(self)
122
+ self.host, self.port = host, port
123
+ self._Handler.annotator = self
124
+
125
+ def run(self):
126
+ """
127
+ Runs the server using Python's simple HTTPServer.
128
+ TODO: make this multithreaded.
129
+ """
130
+ httpd = HTTPServer((self.host, self.port), self._Handler)
131
+ sa = httpd.socket.getsockname()
132
+ serve_message = "Serving HTTP on {host} port {port} (http://{host}:{port}/) ..."
133
+ print(serve_message.format(host=sa[0], port=sa[1]))
134
+ try:
135
+ httpd.serve_forever()
136
+ except KeyboardInterrupt:
137
+ print("\nKeyboard interrupt received, exiting.")
138
+ httpd.shutdown()
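A rough subclass sketch (the annotator name, the requires/provides strings, and the lowercasing logic are all illustrative, not an annotator shipped with stanza):

class LowercaseLemmaAnnotator(Annotator):
    """Toy annotator: set each token's lemma to its lowercased word."""

    @property
    def name(self):
        return "lowercase_lemma"

    @property
    def requires(self):
        # nested classes of edu.stanford.nlp.ling.CoreAnnotations, per the docstring above
        return ["TextAnnotation", "TokensAnnotation", "SentencesAnnotation"]

    @property
    def provides(self):
        return ["LemmaAnnotation"]

    def annotate(self, ann):
        # ann is a CoreNLP Document proto
        for sentence in ann.sentence:
            for token in sentence.token:
                token.lemma = token.word.lower()

Since Annotator subclasses Process, calling start() on an instance launches the endpoint in the background, and its properties dict can then be merged into the properties handed to CoreNLPClient.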
stanza/stanza/server/client.py ADDED
@@ -0,0 +1,779 @@
1
+ """
2
+ Client for accessing Stanford CoreNLP in Python
3
+ """
4
+
5
+ import atexit
6
+ import contextlib
7
+ import enum
8
+ import io
9
+ import os
10
+ import re
11
+ import requests
12
+ import logging
13
+ import json
14
+ import shlex
15
+ import socket
16
+ import subprocess
17
+ import time
18
+ import sys
19
+ import uuid
20
+
21
+ from datetime import datetime
22
+ from pathlib import Path
23
+ from urllib.parse import urlparse
24
+
25
+ from stanza.protobuf import Document, parseFromDelimitedString, writeToDelimitedString, to_text
26
+ __author__ = 'arunchaganty, kelvinguu, vzhong, wmonroe4'
27
+
28
+ logger = logging.getLogger('stanza')
29
+
30
+ # pattern that tmp server props files should follow
31
+ SERVER_PROPS_TMP_FILE_PATTERN = re.compile('corenlp_server-(.*).props')
32
+
33
+ # Check if str is CoreNLP supported language
34
+ CORENLP_LANGS = ['ar', 'arabic', 'chinese', 'zh', 'english', 'en', 'french', 'fr', 'de', 'german', 'hu', 'hungarian',
35
+ 'it', 'italian', 'es', 'spanish']
36
+
37
+ # map shorthands to full language names
38
+ LANGUAGE_SHORTHANDS_TO_FULL = {
39
+ "ar": "arabic",
40
+ "zh": "chinese",
41
+ "en": "english",
42
+ "fr": "french",
43
+ "de": "german",
44
+ "hu": "hungarian",
45
+ "it": "italian",
46
+ "es": "spanish"
47
+ }
48
+
49
+
50
+ def is_corenlp_lang(props_str):
51
+ """ Check if a string references a CoreNLP language """
52
+ return props_str.lower() in CORENLP_LANGS
53
+
54
+
55
+ # Validate CoreNLP properties
56
+ CORENLP_OUTPUT_VALS = ["conll", "conllu", "json", "serialized", "text", "xml", "inlinexml"]
57
+
58
+
59
+ def validate_corenlp_props(properties=None, annotators=None, output_format=None):
60
+ """ Do basic checks to validate CoreNLP properties """
61
+ if output_format and output_format.lower() not in CORENLP_OUTPUT_VALS:
62
+ raise ValueError(f"{output_format} not a valid CoreNLP outputFormat value! Choose from: {CORENLP_OUTPUT_VALS}")
63
+ if type(properties) == dict:
64
+ if "outputFormat" in properties and properties["outputFormat"].lower() not in CORENLP_OUTPUT_VALS:
65
+ raise ValueError(f"{properties['outputFormat']} not a valid CoreNLP outputFormat value! Choose from: "
66
+ f"{CORENLP_OUTPUT_VALS}")
67
+
68
+
69
+ class AnnotationException(Exception):
70
+ """ Exception raised when there was an error communicating with the CoreNLP server. """
71
+ pass
72
+
73
+
74
+ class TimeoutException(AnnotationException):
75
+ """ Exception raised when the CoreNLP server timed out. """
76
+ pass
77
+
78
+
79
+ class ShouldRetryException(Exception):
80
+ """ Exception raised if the service should retry the request. """
81
+ pass
82
+
83
+
84
+ class PermanentlyFailedException(Exception):
85
+ """ Exception raised if the service should NOT retry the request. """
86
+ pass
87
+
88
+ class StartServer(enum.Enum):
89
+ DONT_START = 0
90
+ FORCE_START = 1
91
+ TRY_START = 2
92
+
93
+
94
+ def clean_props_file(props_file):
95
+ # remove the temp server props file, if one was created
96
+ if props_file:
97
+ if os.path.isfile(props_file) and SERVER_PROPS_TMP_FILE_PATTERN.match(os.path.basename(props_file)):
98
+ os.remove(props_file)
99
+
100
+
101
+ class RobustService(object):
102
+ """ Service that resuscitates itself if it is not available. """
103
+ CHECK_ALIVE_TIMEOUT = 120
104
+
105
+ def __init__(self, start_cmd, stop_cmd, endpoint, stdout=None,
106
+ stderr=None, be_quiet=False, host=None, port=None, ignore_binding_error=False):
107
+ self.start_cmd = start_cmd and shlex.split(start_cmd)
108
+ self.stop_cmd = stop_cmd and shlex.split(stop_cmd)
109
+ self.endpoint = endpoint
110
+ self.stdout = stdout
111
+ self.stderr = stderr
112
+
113
+ self.server = None
114
+ self.is_active = False
115
+ self.be_quiet = be_quiet
116
+ self.host = host
117
+ self.port = port
118
+ self.ignore_binding_error = ignore_binding_error
119
+ atexit.register(self.atexit_kill)
120
+
121
+ def is_alive(self):
122
+ try:
123
+ if not self.ignore_binding_error and self.server is not None and self.server.poll() is not None:
124
+ return False
125
+ return requests.get(self.endpoint + "/ping").ok
126
+ except requests.exceptions.ConnectionError as e:
127
+ raise ShouldRetryException(e)
128
+
129
+ def start(self):
130
+ if self.start_cmd:
131
+ if self.host and self.port:
132
+ with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
133
+ try:
134
+ sock.bind((self.host, self.port))
135
+ except socket.error as e:
136
+ if self.ignore_binding_error:
137
+ logger.info(f"Connecting to existing CoreNLP server at {self.host}:{self.port}")
138
+ self.server = None
139
+ return
140
+ else:
141
+ raise PermanentlyFailedException("Error: unable to start the CoreNLP server on port %d "
142
+ "(possibly something is already running there)" % self.port) from e
143
+ if self.be_quiet:
144
+ # Issue #26: subprocess.DEVNULL isn't supported in python 2.7.
145
+ if hasattr(subprocess, 'DEVNULL'):
146
+ stderr = subprocess.DEVNULL
147
+ else:
148
+ stderr = open(os.devnull, 'w')
149
+ stdout = stderr
150
+ else:
151
+ stdout = self.stdout
152
+ stderr = self.stderr
153
+ logger.info(f"Starting server with command: {' '.join(self.start_cmd)}")
154
+ try:
155
+ self.server = subprocess.Popen(self.start_cmd,
156
+ stderr=stderr,
157
+ stdout=stdout)
158
+ except FileNotFoundError as e:
159
+ raise FileNotFoundError("When trying to run CoreNLP, a FileNotFoundError occurred, which frequently means Java was not installed or was not in the classpath.") from e
160
+
161
+ def atexit_kill(self):
162
+ # make some kind of effort to stop the service (such as a
163
+ # CoreNLP server) at the end of the program. not waiting so
164
+ # that the python script exiting isn't delayed
165
+ if self.server and self.server.poll() is None:
166
+ self.server.terminate()
167
+
168
+ def stop(self):
169
+ if self.server:
170
+ self.server.terminate()
171
+ try:
172
+ self.server.wait(5)
173
+ except subprocess.TimeoutExpired:
174
+ # Resorting to more aggressive measures...
175
+ self.server.kill()
176
+ try:
177
+ self.server.wait(5)
178
+ except subprocess.TimeoutExpired:
179
+ # oh well
180
+ pass
181
+ self.server = None
182
+ if self.stop_cmd:
183
+ subprocess.run(self.stop_cmd, check=True)
184
+ self.is_active = False
185
+
186
+ def __enter__(self):
187
+ self.start()
188
+ return self
189
+
190
+ def __exit__(self, _, __, ___):
191
+ self.stop()
192
+
193
+ def ensure_alive(self):
194
+ # Check if the service is active and alive
195
+ if self.is_active:
196
+ try:
197
+ if self.is_alive():
198
+ return
199
+ else:
200
+ self.stop()
201
+ except ShouldRetryException:
202
+ pass
203
+
204
+ # If not, try to start up the service.
205
+ if self.server is None:
206
+ self.start()
207
+
208
+ # Wait for the service to start up.
209
+ start_time = time.time()
210
+ while True:
211
+ try:
212
+ if self.is_alive():
213
+ break
214
+ except ShouldRetryException:
215
+ pass
216
+
217
+ if time.time() - start_time < self.CHECK_ALIVE_TIMEOUT:
218
+ time.sleep(1)
219
+ else:
220
+ raise PermanentlyFailedException("Timed out waiting for service to come alive.")
221
+
222
+ # At this point we are guaranteed that the service is alive.
223
+ self.is_active = True
224
+
225
+
226
+ def resolve_classpath(classpath=None):
227
+ """
228
+ Returns the classpath to use for corenlp.
229
+
230
+ Prefers to use the given classpath parameter, if available. If
231
+ not, uses the CORENLP_HOME environment variable. Resolves $CLASSPATH
232
+ (the exact string) in either the classpath parameter or $CORENLP_HOME.
233
+ """
234
+ if classpath == '$CLASSPATH' or (classpath is None and os.getenv("CORENLP_HOME", None) == '$CLASSPATH'):
235
+ classpath = os.getenv("CLASSPATH")
236
+ elif classpath is None:
237
+ classpath = os.getenv("CORENLP_HOME", os.path.join(str(Path.home()), 'stanza_corenlp'))
238
+
239
+ if not os.path.exists(classpath):
240
+ raise FileNotFoundError("Please install CoreNLP by running `stanza.install_corenlp()`. If you have installed it, please define "
241
+ "$CORENLP_HOME to be location of your CoreNLP distribution or pass in a classpath parameter. "
242
+ "$CORENLP_HOME={}".format(os.getenv("CORENLP_HOME")))
243
+ classpath = os.path.join(classpath, "*")
244
+ return classpath
245
+
246
+
247
+ class CoreNLPClient(RobustService):
248
+ """ A client to the Stanford CoreNLP server. """
249
+
250
+ DEFAULT_ENDPOINT = "http://localhost:9000"
251
+ DEFAULT_TIMEOUT = 60000
252
+ DEFAULT_THREADS = 5
253
+ DEFAULT_OUTPUT_FORMAT = "serialized"
254
+ DEFAULT_MEMORY = "5G"
255
+ DEFAULT_MAX_CHAR_LENGTH = 100000
256
+
257
+ def __init__(self, start_server=StartServer.FORCE_START,
258
+ endpoint=DEFAULT_ENDPOINT,
259
+ timeout=DEFAULT_TIMEOUT,
260
+ threads=DEFAULT_THREADS,
261
+ annotators=None,
262
+ pretokenized=False,
263
+ output_format=None,
264
+ properties=None,
265
+ stdout=None,
266
+ stderr=None,
267
+ memory=DEFAULT_MEMORY,
268
+ be_quiet=False,
269
+ max_char_length=DEFAULT_MAX_CHAR_LENGTH,
270
+ preload=True,
271
+ classpath=None,
272
+ **kwargs):
273
+
274
+ # whether or not server should be started by client
275
+ self.start_server = start_server
276
+ self.server_props_path = None
277
+ self.server_start_time = None
278
+ self.server_host = None
279
+ self.server_port = None
280
+ self.server_classpath = None
281
+ # validate properties
282
+ validate_corenlp_props(properties=properties, annotators=annotators, output_format=output_format)
283
+ # set up client defaults
284
+ self.properties = properties
285
+ self.annotators = annotators
286
+ self.pretokenized = pretokenized
287
+ self.output_format = output_format
288
+ self._setup_client_defaults()
289
+ # start the server
290
+ if isinstance(start_server, bool):
291
+ warning_msg = f"Setting 'start_server' to a boolean value when constructing {self.__class__.__name__} is deprecated and will stop" + \
292
+ " to function in a future version of stanza. Please consider switching to using a value from stanza.server.StartServer."
293
+ logger.warning(warning_msg)
294
+ start_server = StartServer.FORCE_START if start_server is True else StartServer.DONT_START
295
+
296
+ # start the server
297
+ if start_server is StartServer.FORCE_START or start_server is StartServer.TRY_START:
298
+ # record info for server start
299
+ self.server_start_time = datetime.now()
300
+ # set up default properties for server
301
+ self._setup_server_defaults()
302
+ host, port = urlparse(endpoint).netloc.split(":")
303
+ port = int(port)
304
+ assert host == "localhost", "If starting a server, endpoint must be localhost"
305
+ classpath = resolve_classpath(classpath)
306
+ start_cmd = f"java -Xmx{memory} -cp '{classpath}' edu.stanford.nlp.pipeline.StanfordCoreNLPServer " \
307
+ f"-port {port} -timeout {timeout} -threads {threads} -maxCharLength {max_char_length} " \
308
+ f"-quiet {be_quiet} "
309
+
310
+ self.server_classpath = classpath
311
+ self.server_host = host
312
+ self.server_port = port
313
+
314
+ # set up server defaults
315
+ if self.server_props_path is not None:
316
+ start_cmd += f" -serverProperties {self.server_props_path}"
317
+
318
+ # possibly set pretokenized
319
+ if self.pretokenized:
320
+ start_cmd += f" -preTokenized"
321
+
322
+ # set annotators for server default
323
+ if self.annotators is not None:
324
+ annotators_str = self.annotators if type(annotators) == str else ",".join(annotators)
325
+ start_cmd += f" -annotators {annotators_str}"
326
+
327
+ # specify what to preload, if anything
328
+ if preload:
329
+ if type(preload) == bool:
330
+ # -preload flag means to preload all default annotators
331
+ start_cmd += " -preload"
332
+ elif type(preload) == list:
333
+ # turn list into comma separated list string, only preload these annotators
334
+ start_cmd += f" -preload {','.join(preload)}"
335
+ elif type(preload) == str:
336
+ # comma separated list of annotators
337
+ start_cmd += f" -preload {preload}"
338
+
339
+ # set outputFormat for server default
340
+ # if no output format requested by user, set to serialized
341
+ start_cmd += f" -outputFormat {self.output_format}"
342
+
343
+ # additional options for server:
344
+ # - server_id
345
+ # - ssl
346
+ # - status_port
347
+ # - uriContext
348
+ # - strict
349
+ # - key
350
+ # - username
351
+ # - password
352
+ # - blockList
353
+ for kw in ['ssl', 'strict']:
354
+ if kwargs.get(kw) is not None:
355
+ start_cmd += f" -{kw}"
356
+ for kw in ['status_port', 'uriContext', 'key', 'username', 'password', 'blockList', 'server_id']:
357
+ if kwargs.get(kw) is not None:
358
+ start_cmd += f" -{kw} {kwargs.get(kw)}"
359
+ stop_cmd = None
360
+ else:
361
+ start_cmd = stop_cmd = None
362
+ host = port = None
363
+
364
+ super(CoreNLPClient, self).__init__(start_cmd, stop_cmd, endpoint,
365
+ stdout, stderr, be_quiet, host=host, port=port, ignore_binding_error=(start_server == StartServer.TRY_START))
366
+
367
+ self.timeout = timeout
368
+
369
+ def _setup_client_defaults(self):
370
+ """
371
+ Do some processing of annotators and output_format specified for the client.
372
+ If interacting with an externally started server, these will be defaults for annotate() calls.
373
+ :return: None
374
+ """
375
+ # normalize annotators to str
376
+ if self.annotators is not None:
377
+ self.annotators = self.annotators if type(self.annotators) == str else ",".join(self.annotators)
378
+
379
+ # handle case where no output format is specified
380
+ if self.output_format is None:
381
+ if type(self.properties) == dict and 'outputFormat' in self.properties:
382
+ self.output_format = self.properties['outputFormat']
383
+ else:
384
+ self.output_format = CoreNLPClient.DEFAULT_OUTPUT_FORMAT
385
+
386
+ def _setup_server_defaults(self):
387
+ """
388
+ Set up the default properties for the server.
389
+
390
+ The properties argument can take on one of 3 value types
391
+
392
+ 1. File path on system or in CLASSPATH (e.g. /path/to/server.props or StanfordCoreNLP-french.properties
393
+ 2. Name of a Stanford CoreNLP supported language (e.g. french or fr)
394
+ 3. Python dictionary (properties written to tmp file for Java server, erased at end)
395
+
396
+ In addition, an annotators list and output_format can be specified directly with arguments. These
397
+ will overwrite any settings in the specified properties.
398
+
399
+ If no properties are specified, the standard Stanford CoreNLP English server will be launched. The outputFormat
400
+ will be set to 'serialized' and use the ProtobufAnnotationSerializer.
401
+ """
402
+
403
+ # ensure properties is str or dict
404
+ if self.properties is None or (not isinstance(self.properties, str) and not isinstance(self.properties, dict)):
405
+ if self.properties is not None:
406
+ logger.warning('invalid value passed for properties (not a str or dict); setting properties = {}')
407
+ self.properties = {}
408
+ # check if properties is a string, pass on to server which can handle
409
+ if isinstance(self.properties, str):
410
+ # try to translate to Stanford CoreNLP language name, or assume properties is a path
411
+ if is_corenlp_lang(self.properties):
412
+ if self.properties.lower() in LANGUAGE_SHORTHANDS_TO_FULL:
413
+ self.properties = LANGUAGE_SHORTHANDS_TO_FULL[self.properties]
414
+ logger.info(
415
+ f"Using CoreNLP default properties for: {self.properties}. Make sure to have "
416
+ f"{self.properties} models jar (available for download here: "
417
+ f"https://stanfordnlp.github.io/CoreNLP/) in CLASSPATH")
418
+ else:
419
+ if not os.path.isfile(self.properties):
420
+ logger.warning(f"{self.properties} does not correspond to a file path. Make sure this file is in "
421
+ f"your CLASSPATH.")
422
+ self.server_props_path = self.properties
423
+ elif isinstance(self.properties, dict):
424
+ # make a copy
425
+ server_start_properties = dict(self.properties)
426
+ if self.annotators is not None:
427
+ server_start_properties['annotators'] = self.annotators
428
+ if self.output_format is not None and isinstance(self.output_format, str):
429
+ server_start_properties['outputFormat'] = self.output_format
430
+ # write desired server start properties to tmp file
431
+ # set up to erase on exit
432
+ tmp_path = write_corenlp_props(server_start_properties)
433
+ logger.info(f"Writing properties to tmp file: {tmp_path}")
434
+ atexit.register(clean_props_file, tmp_path)
435
+ self.server_props_path = tmp_path
436
+
437
+ def _request(self, buf, properties, reset_default=False, **kwargs):
438
+ """
439
+ Send a request to the CoreNLP server.
440
+
441
+ :param (str | bytes) buf: data to be sent with the request
442
+ :param (dict) properties: properties that the server expects
443
+ :return: request result
444
+ """
445
+ if self.start_server is not StartServer.DONT_START:
446
+ self.ensure_alive()
447
+
448
+ try:
449
+ input_format = properties.get("inputFormat", "text")
450
+ if input_format == "text":
451
+ ctype = "text/plain; charset=utf-8"
452
+ elif input_format == "serialized":
453
+ ctype = "application/x-protobuf"
454
+ else:
455
+ raise ValueError("Unrecognized inputFormat " + input_format)
456
+ # handle auth
457
+ if 'username' in kwargs and 'password' in kwargs:
458
+ kwargs['auth'] = requests.auth.HTTPBasicAuth(kwargs['username'], kwargs['password'])
459
+ kwargs.pop('username')
460
+ kwargs.pop('password')
461
+ r = requests.post(self.endpoint,
462
+ params={'properties': str(properties), 'resetDefault': str(reset_default).lower()},
463
+ data=buf, headers={'content-type': ctype},
464
+ timeout=(self.timeout*2)/1000, **kwargs)
465
+ r.raise_for_status()
466
+ return r
467
+ except requests.exceptions.Timeout as e:
468
+ raise TimeoutException("Timeout requesting to CoreNLPServer. Maybe server is unavailable or your document is too long")
469
+ except requests.exceptions.RequestException as e:
470
+ if e.response is not None and e.response.text is not None:
471
+ raise AnnotationException(e.response.text) from e
472
+ elif e.args:
473
+ raise AnnotationException(e.args[0]) from e
474
+ raise AnnotationException() from e
475
+
476
+ def annotate(self, text, annotators=None, output_format=None, properties=None, reset_default=None, **kwargs):
477
+ """
478
+ Send a request to the CoreNLP server.
479
+
480
+ :param (str | unicode) text: raw text for the CoreNLPServer to parse
481
+ :param (list | string) annotators: list of annotators to use
482
+ :param (str) output_format: output type from server: serialized, json, text, conll, conllu, or xml
483
+ :param (dict) properties: additional request properties (written on top of defaults)
484
+ :param (bool) reset_default: don't use server defaults
485
+
486
+ Precedence for settings:
487
+
488
+ 1. annotators and output_format args
489
+ 2. Values from properties dict
490
+ 3. Client defaults self.annotators and self.output_format (set during client construction)
491
+ 4. Server defaults
492
+
493
+ Additional request parameters (apart from CoreNLP pipeline properties) such as 'username' and 'password'
494
+ can be specified with the kwargs.
495
+
496
+ :return: request result
497
+ """
498
+
499
+ # validate request properties
500
+ validate_corenlp_props(properties=properties, annotators=annotators, output_format=output_format)
501
+ # set request properties
502
+ request_properties = {}
503
+
504
+ # start with client defaults
505
+ if self.annotators is not None:
506
+ request_properties['annotators'] = self.annotators
507
+ if self.output_format is not None:
508
+ request_properties['outputFormat'] = self.output_format
509
+
510
+ # add values from properties arg
511
+ # handle str case
512
+ if type(properties) == str:
513
+ if is_corenlp_lang(properties):
514
+ properties = {'pipelineLanguage': properties.lower()}
515
+ if reset_default is None:
516
+ reset_default = True
517
+ else:
518
+ raise ValueError(f"Unrecognized properties keyword {properties}")
519
+
520
+ if type(properties) == dict:
521
+ request_properties.update(properties)
522
+
523
+ # if annotators list is specified, override with that
524
+ # also can use the annotators field the object was created with
525
+ if annotators is not None and (type(annotators) == str or type(annotators) == list):
526
+ request_properties['annotators'] = annotators if type(annotators) == str else ",".join(annotators)
527
+
528
+ # if output format is specified, override with that
529
+ if output_format is not None and type(output_format) == str:
530
+ request_properties['outputFormat'] = output_format
531
+
532
+ # make the request
533
+ # if reset_default was not set explicitly (or via the pipelineLanguage case above), default it to False
534
+ if reset_default is None:
535
+ reset_default = False
536
+ r = self._request(text.encode('utf-8'), request_properties, reset_default, **kwargs)
537
+ if request_properties["outputFormat"] == "json":
538
+ return r.json()
539
+ elif request_properties["outputFormat"] == "serialized":
540
+ doc = Document()
541
+ parseFromDelimitedString(doc, r.content)
542
+ return doc
543
+ elif request_properties["outputFormat"] in ["text", "conllu", "conll", "xml"]:
544
+ return r.text
545
+ else:
546
+ return r
547
+
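A minimal usage sketch of the client (the sentence is arbitrary; with output_format="json" the return value is a plain dict, as the branch above shows):

from stanza.server import CoreNLPClient

with CoreNLPClient(annotators="tokenize,ssplit,pos", output_format="json", memory="4G") as client:
    ann = client.annotate("Stanford is in California.")
    for sentence in ann["sentences"]:
        print([token["word"] for token in sentence["tokens"]])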
548
+ def update(self, doc, annotators=None, properties=None):
549
+ if properties is None:
550
+ properties = {}
551
+ properties.update({
552
+ 'inputFormat': 'serialized',
553
+ 'outputFormat': 'serialized',
554
+ 'serializer': 'edu.stanford.nlp.pipeline.ProtobufAnnotationSerializer'
555
+ })
556
+ if annotators:
557
+ properties['annotators'] = annotators if type(annotators) == str else ",".join(annotators)
558
+ with io.BytesIO() as stream:
559
+ writeToDelimitedString(doc, stream)
560
+ msg = stream.getvalue()
561
+
562
+ r = self._request(msg, properties)
563
+ doc = Document()
564
+ parseFromDelimitedString(doc, r.content)
565
+ return doc
566
+
567
+ def tokensregex(self, text, pattern, filter=False, to_words=False, annotators=None, properties=None):
568
+ # this is required for some reason
569
+ matches = self.__regex('/tokensregex', text, pattern, filter, annotators, properties)
570
+ if to_words:
571
+ matches = regex_matches_to_indexed_words(matches)
572
+ return matches
573
+
574
+ def semgrex(self, text, pattern, filter=False, to_words=False, annotators=None, properties=None):
575
+ matches = self.__regex('/semgrex', text, pattern, filter, annotators, properties)
576
+ if to_words:
577
+ matches = regex_matches_to_indexed_words(matches)
578
+ return matches
579
+
580
+ def fill_tree_proto(self, tree, proto_tree):
581
+ if tree.label:
582
+ proto_tree.value = tree.label
583
+ for child in tree.children:
584
+ proto_child = proto_tree.child.add()
585
+ self.fill_tree_proto(child, proto_child)
586
+
587
+ def tregex(self, text=None, pattern=None, filter=False, annotators=None, properties=None, trees=None):
588
+ # parse is not included by default in some of the pipelines,
589
+ # so we may need to manually override the annotators
590
+ # to include parse in order for tregex to do anything
591
+ if annotators is None and self.annotators is not None:
592
+ assert isinstance(self.annotators, str)
593
+ pieces = self.annotators.split(",")
594
+ if "parse" not in pieces:
595
+ annotators = self.annotators + ",parse"
596
+ else:
597
+ annotators = "tokenize,ssplit,pos,parse"
598
+ if pattern is None:
599
+ raise ValueError("Cannot have None as a pattern for tregex")
600
+
601
+ # TODO: we could also allow for passing in a complete document,
602
+ # along with the original text, so that the spans returns are more accurate
603
+ if trees is not None:
604
+ if properties is None:
605
+ properties = {}
606
+ properties['inputFormat'] = 'serialized'
607
+ if 'serializer' not in properties:
608
+ properties['serializer'] = 'edu.stanford.nlp.pipeline.ProtobufAnnotationSerializer'
609
+ doc = Document()
610
+ full_text = []
611
+ for tree_idx, tree in enumerate(trees):
612
+ sentence = doc.sentence.add()
613
+ sentence.sentenceIndex = tree_idx
614
+ sentence.tokenOffsetBegin = len(full_text)
615
+ leaves = tree.leaf_labels()
616
+ full_text.extend(leaves)
617
+ sentence.tokenOffsetEnd = len(full_text)
618
+ self.fill_tree_proto(tree, sentence.parseTree)
619
+ for word in leaves:
620
+ token = sentence.token.add()
621
+ # the other side uses both value and word, weirdly enough
622
+ token.value = word
623
+ token.word = word
624
+ # without the actual tokenization, at least we can
625
+ # stop the words from running together
626
+ token.after = " "
627
+ doc.text = " ".join(full_text)
628
+ with io.BytesIO() as stream:
629
+ writeToDelimitedString(doc, stream)
630
+ text = stream.getvalue()
631
+
632
+ return self.__regex('/tregex', text, pattern, filter, annotators, properties)
633
+
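A sketch of a tregex call (the pattern is illustrative); per the override above, parse is appended to the annotators automatically if the client was built without it:

with CoreNLPClient(annotators="tokenize,ssplit,pos,parse", output_format="json") as client:
    # find noun phrases immediately dominating a plural noun
    matches = client.tregex("The dogs chased the cats.", "NP < NNS")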
634
+ def __regex(self, path, text, pattern, filter, annotators=None, properties=None):
635
+ """
636
+ Send a regex-related request to the CoreNLP server.
637
+
638
+ :param (str | unicode) path: the path for the regex endpoint
639
+ :param text: raw text for the CoreNLPServer to apply the regex
640
+ :param (str | unicode) pattern: regex pattern
641
+ :param (bool) filter: if true, return sentences containing matches; if false, return the matches themselves
642
+ :param properties: additional request properties (written on top of defaults)
643
+ :return: request result
644
+ """
645
+ if self.start_server is not StartServer.DONT_START:
646
+ self.ensure_alive()
647
+ if properties is None:
648
+ properties = {}
649
+ properties.update({
650
+ 'inputFormat': 'text',
651
+ 'serializer': 'edu.stanford.nlp.pipeline.ProtobufAnnotationSerializer'
652
+ })
653
+ if annotators:
654
+ properties['annotators'] = ",".join(annotators) if isinstance(annotators, list) else annotators
655
+
656
+ # force output for regex requests to be json
657
+ properties['outputFormat'] = 'json'
658
+ # if the server is trying to send back character offsets, it
659
+ # should send back codepoints counts as well in case the text
660
+ # has extra wide characters
661
+ properties['tokenize.codepoint'] = 'true'
662
+
663
+ try:
664
+ # an error occurs unless the properties are put in params
665
+ input_format = properties.get("inputFormat", "text")
666
+ if input_format == "text":
667
+ ctype = "text/plain; charset=utf-8"
668
+ elif input_format == "serialized":
669
+ ctype = "application/x-protobuf"
670
+ else:
671
+ raise ValueError("Unrecognized inputFormat " + input_format)
672
+ # change request method from `get` to `post` as required by CoreNLP
673
+ r = requests.post(
674
+ self.endpoint + path, params={
675
+ 'pattern': pattern,
676
+ 'filter': filter,
677
+ 'properties': str(properties)
678
+ },
679
+ data=text.encode('utf-8') if isinstance(text, str) else text,
680
+ headers={'content-type': ctype},
681
+ timeout=(self.timeout*2)/1000,
682
+ )
683
+ r.raise_for_status()
684
+ if r.encoding is None:
685
+ r.encoding = "utf-8"
686
+ return json.loads(r.text)
687
+ except requests.HTTPError as e:
688
+ if r.text.startswith("Timeout"):
689
+ raise TimeoutException(r.text)
690
+ else:
691
+ raise AnnotationException(r.text)
692
+ except json.JSONDecodeError:
693
+ raise AnnotationException(r.text)
694
+
695
+
696
+ def scenegraph(self, text, properties=None):
697
+ """
698
+ Send a request to the server which processes the text using SceneGraph
699
+
700
+ This will require a new CoreNLP release, 4.5.5 or later
701
+ """
702
+ # since we're using requests ourselves,
703
+ # check if the server has started or not
704
+ if self.start_server is not StartServer.DONT_START:
705
+ self.ensure_alive()
706
+
707
+ if properties is None:
708
+ properties = {}
709
+ # the only thing the scenegraph knows how to use is text
710
+ properties['inputFormat'] = 'text'
711
+ ctype = "text/plain; charset=utf-8"
712
+ # the json output format is much more useful
713
+ properties['outputFormat'] = 'json'
714
+ try:
715
+ r = requests.post(
716
+ self.endpoint + "/scenegraph",
717
+ params={
718
+ 'properties': str(properties)
719
+ },
720
+ data=text.encode('utf-8') if isinstance(text, str) else text,
721
+ headers={'content-type': ctype},
722
+ timeout=(self.timeout*2)/1000,
723
+ )
724
+ r.raise_for_status()
725
+ if r.encoding is None:
726
+ r.encoding = "utf-8"
727
+ return json.loads(r.text)
728
+ except requests.HTTPError as e:
729
+ if r.text.startswith("Timeout"):
730
+ raise TimeoutException(r.text)
731
+ else:
732
+ raise AnnotationException(r.text)
733
+ except json.JSONDecodeError:
734
+ raise AnnotationException(r.text)
735
+
736
+
737
+ def read_corenlp_props(props_path):
738
+ """ Read a Stanford CoreNLP properties file into a dict """
739
+ props_dict = {}
740
+ with open(props_path) as props_file:
741
+ entry_lines = [entry_line for entry_line in props_file.read().split('\n')
742
+ if entry_line.strip() and not entry_line.startswith('#')]
743
+ for entry_line in entry_lines:
744
+ k = entry_line.split('=')[0]
745
+ k_len = len(k+"=")
746
+ v = entry_line[k_len:]
747
+ props_dict[k.strip()] = v
748
+ return props_dict
749
+
750
+
751
+ def write_corenlp_props(props_dict, file_path=None):
752
+ """ Write a Stanford CoreNLP properties dict to a file """
753
+ if file_path is None:
754
+ file_path = f"corenlp_server-{uuid.uuid4().hex[:16]}.props"
755
+ # confirm tmp file path matches pattern
756
+ assert SERVER_PROPS_TMP_FILE_PATTERN.match(file_path)
757
+ with open(file_path, 'w') as props_file:
758
+ for k, v in props_dict.items():
759
+ if isinstance(v, list):
760
+ writeable_v = ",".join(v)
761
+ else:
762
+ writeable_v = v
763
+ props_file.write(f'{k} = {writeable_v}\n\n')
764
+ return file_path
765
+
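A small round trip through these two helpers (note that read_corenlp_props keeps the space that write_corenlp_props puts after the '='):

import os

props = {"annotators": ["tokenize", "ssplit"], "outputFormat": "json"}
path = write_corenlp_props(props)          # corenlp_server-<hex>.props
restored = read_corenlp_props(path)
assert restored["annotators"].strip() == "tokenize,ssplit"
os.remove(path)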
766
+
767
+ def regex_matches_to_indexed_words(matches):
768
+ """
769
+ Transforms tokensregex and semgrex matches to indexed words.
770
+ :param matches: unprocessed regex matches
771
+ :return: flat array of indexed words
772
+ """
773
+ words = [dict(v, **dict([('sentence', i)]))
774
+ for i, s in enumerate(matches['sentences'])
775
+ for k, v in s.items() if k != 'length']
776
+ return words
777
+
778
+
779
+ __all__ = ["CoreNLPClient", "AnnotationException", "TimeoutException", "to_text"]
stanza/stanza/server/semgrex.py ADDED
@@ -0,0 +1,170 @@
1
+ """Invokes the Java semgrex on a document
2
+
3
+ The server client has a method "semgrex" which sends text to Java
4
+ CoreNLP for processing with a semgrex (SEMantic GRaph regEX) query:
5
+
6
+ https://nlp.stanford.edu/nlp/javadoc/javanlp/edu/stanford/nlp/semgraph/semgrex/SemgrexPattern.html
7
+
8
+ However, this operates on text using the CoreNLP tools, which means
9
+ the dependency graphs may not align with stanza's depparse module, and
10
+ this also limits the languages for which it can be used. This module
11
+ allows for running semgrex commands on the graphs produced by
12
+ depparse.
13
+
14
+ To use, first process text into a doc using stanza.Pipeline
15
+
16
+ Next, pass the processed doc and a list of semgrex patterns to
17
+ process_doc in this module. It will run the java semgrex module as a
18
+ subprocess and return the result in the form of a SemgrexResponse,
19
+ whose description is in the proto file included with stanza.
20
+
21
+ A minimal example is the main method of this module.
22
+
23
+ Note that launching the subprocess is potentially quite expensive
24
+ relative to the search if used many times on small documents. Ideally
25
+ larger texts would be processed, and all of the desired semgrex
26
+ patterns would be run at once. The worst thing to do would be to call
27
+ this multiple times on a large document, one invocation per semgrex
28
+ pattern, as that would serialize the document each time.
29
+ Included here is a context manager which allows for keeping the same
30
+ java process open for multiple requests. This saves on the subprocess
31
+ launching time. It is still important not to wastefully serialize the
32
+ same document over and over, though.
33
+ """
34
+
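As a compact sketch of the usage described above (the sentence and the pattern are illustrative; the pattern finds verbs with a nominal subject):

import stanza
from stanza.server.semgrex import Semgrex

nlp = stanza.Pipeline("en", processors="tokenize,pos,lemma,depparse")
doc = nlp("Banning Uro fixed the format.")
with Semgrex() as sem:
    # keeping the context manager open reuses one java process across queries
    response = sem.process(doc, "{pos:/VB.*/}=verb >nsubj {}=subject")
    print(response)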
35
+ import argparse
36
+ import copy
37
+
38
+ import stanza
39
+ from stanza.protobuf import SemgrexRequest, SemgrexResponse
40
+ from stanza.server.java_protobuf_requests import send_request, add_token, add_word_to_graph, JavaProtobufContext, convert_networkx_graph
41
+ from stanza.utils.conll import CoNLL
42
+
43
+ SEMGREX_JAVA = "edu.stanford.nlp.semgraph.semgrex.ProcessSemgrexRequest"
44
+
45
+ def send_semgrex_request(request):
46
+ return send_request(request, SemgrexResponse, SEMGREX_JAVA)
47
+
48
+ def build_request(doc, semgrex_patterns, enhanced=False):
49
+ request = SemgrexRequest()
50
+ if isinstance(semgrex_patterns, str):
51
+ semgrex_patterns = [semgrex_patterns]
52
+ for semgrex in semgrex_patterns:
53
+ request.semgrex.append(semgrex)
54
+
55
+ for sent_idx, sentence in enumerate(doc.sentences):
56
+ query = request.query.add()
57
+ if enhanced:
58
+ # tokens will be added on to the graph object
59
+ convert_networkx_graph(query.graph, sentence, sent_idx)
60
+ else:
61
+ word_idx = 0
62
+ for token in sentence.tokens:
63
+ for word in token.words:
64
+ add_token(query.token, word, token)
65
+                 add_word_to_graph(query.graph, word, sent_idx, word_idx)
+
+                 word_idx = word_idx + 1
+
+     return request
+
+ def process_doc(doc, *semgrex_patterns, enhanced=False):
+     """
+     Returns the result of processing the given semgrex expression on the stanza doc.
+
+     Currently the return is a SemgrexResponse from CoreNLP.proto
+     """
+     request = build_request(doc, semgrex_patterns, enhanced=enhanced)
+
+     return send_semgrex_request(request)
+
+ class Semgrex(JavaProtobufContext):
+     """
+     Semgrex context window
+
+     This is a context window which keeps a process open. Should allow
+     for multiple requests without launching new java processes each time.
+     """
+     def __init__(self, classpath=None):
+         super(Semgrex, self).__init__(classpath, SemgrexResponse, SEMGREX_JAVA)
+
+     def process(self, doc, *semgrex_patterns):
+         """
+         Apply each of the semgrex patterns to each of the dependency trees in doc
+         """
+         request = build_request(doc, semgrex_patterns)
+         return self.process_request(request)
+
+ def annotate_doc(doc, semgrex_result, semgrex_patterns, matches_only):
+     """
+     Put comments on the sentences which describe the matching semgrex patterns
+     """
+     doc = copy.deepcopy(doc)
+     if isinstance(semgrex_patterns, str):
+         semgrex_patterns = [semgrex_patterns]
+     matching_sentences = []
+     for sentence, graph_result in zip(doc.sentences, semgrex_result.result):
+         sentence_matched = False
+         for semgrex_pattern, pattern_result in zip(semgrex_patterns, graph_result.result):
+             semgrex_pattern = semgrex_pattern.replace("\n", " ")
+             if len(pattern_result.match) == 0:
+                 sentence.add_comment("# semgrex pattern |%s| did not match!" % semgrex_pattern)
+             else:
+                 sentence_matched = True
+                 for match in pattern_result.match:
+                     match_word = "%d:%s" % (match.matchIndex, sentence.words[match.matchIndex-1].text)
+                     if len(match.node) == 0:
+                         node_matches = ""
+                     else:
+                         node_matches = ["%s=%d:%s" % (node.name, node.matchIndex, sentence.words[node.matchIndex-1].text)
+                                         for node in match.node]
+                         node_matches = " " + " ".join(node_matches)
+                     sentence.add_comment("# semgrex pattern |%s| matched at %s%s" % (semgrex_pattern, match_word, node_matches))
+         if sentence_matched:
+             matching_sentences.append(sentence)
+     if matches_only:
+         doc.sentences = matching_sentences
+     return doc
+
+
+ def main():
+     """
+     Runs a toy example, or can run a given semgrex expression on the given input file.
+
+     For example:
+       python3 -m stanza.server.semgrex --input_file demo/semgrex_sample.conllu
+
+     --matches_only to only print sentences that match the semgrex pattern
+     --no_print_input to not print the input
+     """
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--input_file', type=str, default=None, help="Input file to process (otherwise will process a sample text)")
+     parser.add_argument('semgrex', type=str, nargs="*", default=["{}=source >obj=zzz {}=target"], help="Semgrex to apply to the text. The default looks for sentences with objects")
+     parser.add_argument('--semgrex_file', type=str, default=None, help="File to read semgrex patterns from - relevant in case the pattern you want to use doesn't work well on the command line, for example")
+     parser.add_argument('--print_input', dest='print_input', action='store_true', default=False, help="Print the input alongside the output - gets kind of noisy")
+     parser.add_argument('--no_print_input', dest='print_input', action='store_false', help="Don't print the input alongside the output")
+     parser.add_argument('--matches_only', action='store_true', default=False, help="Only print the matching sentences")
+     parser.add_argument('--enhanced', action='store_true', default=False, help='Use the enhanced dependencies instead of the basic')
+     args = parser.parse_args()
+
+     if args.semgrex_file:
+         with open(args.semgrex_file) as fin:
+             args.semgrex = [x.strip() for x in fin.readlines() if x.strip()]
+
+     if args.input_file:
+         doc = CoNLL.conll2doc(input_file=args.input_file, ignore_gapping=False)
+     else:
+         nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma,depparse')
+         doc = nlp('Uro ruined modern. Fortunately, Wotc banned him.')
+
+     if args.print_input:
+         print("{:C}".format(doc))
+         print()
+         print("-" * 75)
+         print()
+     semgrex_result = process_doc(doc, *args.semgrex, enhanced=args.enhanced)
+     doc = annotate_doc(doc, semgrex_result, args.semgrex, args.matches_only)
+     print("{:C}".format(doc))
+
+ if __name__ == '__main__':
+     main()
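
For orientation, a minimal usage sketch of the two entry points above: process_doc launches a fresh Java process per call, while the Semgrex context manager keeps one process alive across requests. The pattern, the sentence, and the "$CLASSPATH" setting are illustrative assumptions; CoreNLP must be on the Java classpath either way.

    import stanza
    from stanza.server.semgrex import Semgrex, process_doc

    nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma,depparse')
    doc = nlp('Banning Uro fixed modern.')

    # one-shot request: starts and stops a Java process just for this call
    response = process_doc(doc, '{}=source >obj=target {}')

    # reusable context: one Java process serves all of the requests
    with Semgrex(classpath='$CLASSPATH') as sem:
        for pattern in ('{}=source >obj=target {}', '{pos:NNP}=name'):
            result = sem.process(doc, pattern)
            # result.result holds one entry per sentence; each of those holds
            # one entry per pattern, mirroring the loops in annotate_doc above
            print(pattern, '->', len(result.result[0].result[0].match), 'match(es)')
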
stanza/stanza/server/ssurgeon.py ADDED
@@ -0,0 +1,310 @@
+ """Invokes the Java ssurgeon on a document
+
+ "ssurgeon" sends text to Java CoreNLP for processing with a ssurgeon
+ (Semantic graph SURGEON) query
+
+ The main program in this file gives a very short intro to how to use it.
+ """
+
+
+ import argparse
+ from collections import namedtuple
+ import copy
+ import os
+ import re
+ import sys
+
+ from stanza.models.common.utils import misc_to_space_after, space_after_to_misc
+ from stanza.protobuf import SsurgeonRequest, SsurgeonResponse
+ from stanza.server import java_protobuf_requests
+ from stanza.utils.conll import CoNLL
+
+ from stanza.models.common.doc import ID, TEXT, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC, START_CHAR, END_CHAR, NER, Word, Token, Sentence
+
+ SSURGEON_JAVA = "edu.stanford.nlp.semgraph.semgrex.ssurgeon.ProcessSsurgeonRequest"
+
+ SsurgeonEdit = namedtuple("SsurgeonEdit",
+                           "semgrex_pattern ssurgeon_edits ssurgeon_id notes language",
+                           defaults=[None, None, "UniversalEnglish"])
+
+ def parse_ssurgeon_edits(ssurgeon_text):
+     ssurgeon_text = ssurgeon_text.strip()
+     ssurgeon_blocks = re.split("\n\n+", ssurgeon_text)
+     ssurgeon_edits = []
+     for idx, block in enumerate(ssurgeon_blocks):
+         lines = block.split("\n")
+         comments = [line[1:].strip() for line in lines if line.startswith("#")]
+         notes = " ".join(comments)
+         lines = [x.strip() for x in lines if x.strip() and not x.startswith("#")]
+         if len(lines) == 0:
+             # was a block of entirely comments
+             continue
+         semgrex = lines[0]
+         ssurgeon = lines[1:]
+         ssurgeon_edits.append(SsurgeonEdit(semgrex, ssurgeon, "%d" % (idx + 1), notes))
+     return ssurgeon_edits
+
+ def read_ssurgeon_edits(edit_file):
+     with open(edit_file, encoding="utf-8") as fin:
+         return parse_ssurgeon_edits(fin.read())
+
+ def send_ssurgeon_request(request):
+     return java_protobuf_requests.send_request(request, SsurgeonResponse, SSURGEON_JAVA)
+
+ def build_request(doc, ssurgeon_edits):
+     request = SsurgeonRequest()
+
+     for ssurgeon in ssurgeon_edits:
+         ssurgeon_proto = request.ssurgeon.add()
+         ssurgeon_proto.semgrex = ssurgeon.semgrex_pattern
+         for operation in ssurgeon.ssurgeon_edits:
+             ssurgeon_proto.operation.append(operation)
+         if ssurgeon.ssurgeon_id is not None:
+             ssurgeon_proto.id = ssurgeon.ssurgeon_id
+         if ssurgeon.notes is not None:
+             ssurgeon_proto.notes = ssurgeon.notes
+         if ssurgeon.language is not None:
+             ssurgeon_proto.language = ssurgeon.language
+
+     try:
+         for sent_idx, sentence in enumerate(doc.sentences):
+             graph = request.graph.add()
+             word_idx = 0
+             for token in sentence.tokens:
+                 for word in token.words:
+                     java_protobuf_requests.add_token(graph.token, word, token)
+                     java_protobuf_requests.add_word_to_graph(graph, word, sent_idx, word_idx)
+
+                     word_idx = word_idx + 1
+     except Exception as e:
+         raise RuntimeError("Failed to process sentence {}:\n{:C}".format(sent_idx, sentence)) from e
+
+     return request
+
+ def build_request_one_operation(doc, semgrex_pattern, ssurgeon_edits, ssurgeon_id=None, notes=None):
+     ssurgeon_edit = SsurgeonEdit(semgrex_pattern, ssurgeon_edits, ssurgeon_id, notes)
+     return build_request(doc, [ssurgeon_edit])
+
+ def process_doc(doc, ssurgeon_edits):
+     """
+     Returns the result of processing the given semgrex expression and ssurgeon edits on the stanza doc.
+
+     Currently the return is a SsurgeonResponse from CoreNLP.proto
+     """
+     request = build_request(doc, ssurgeon_edits)
+
+     return send_ssurgeon_request(request)
+
+ def process_doc_one_operation(doc, semgrex_pattern, ssurgeon_edits, ssurgeon_id=None, notes=None):
+     request = build_request_one_operation(doc, semgrex_pattern, ssurgeon_edits, ssurgeon_id, notes)
+
+     return send_ssurgeon_request(request)
+
+ def build_word_entry(word_index, graph_word):
+     word_entry = {
+         ID: word_index,
+         TEXT: graph_word.word if graph_word.word else None,
+         LEMMA: graph_word.lemma if graph_word.lemma else None,
+         UPOS: graph_word.coarseTag if graph_word.coarseTag else None,
+         XPOS: graph_word.pos if graph_word.pos else None,
+         FEATS: java_protobuf_requests.features_to_string(graph_word.conllUFeatures),
+         DEPS: None,
+         NER: graph_word.ner if graph_word.ner else None,
+         MISC: None,
+         START_CHAR: None,   # TODO: fix this? one problem is the text positions
+         END_CHAR: None,     # might change across all of the sentences
+         # presumably python will complain if this conflicts
+         # with one of the constants above
+         "is_mwt": graph_word.isMWT,
+         "is_first_mwt": graph_word.isFirstMWT,
+         "mwt_text": graph_word.mwtText,
+         "mwt_misc": graph_word.mwtMisc,
+     }
+     # TODO: do "before" as well
+     word_entry[MISC] = space_after_to_misc(graph_word.after)
+     if graph_word.conllUMisc:
+         word_entry[MISC] = java_protobuf_requests.substitute_space_misc(graph_word.conllUMisc, word_entry[MISC])
+     return word_entry
+
+ def convert_response_to_doc(doc, semgrex_response):
+     doc = copy.deepcopy(doc)
+     try:
+         for sent_idx, (sentence, ssurgeon_result) in enumerate(zip(doc.sentences, semgrex_response.result)):
+             # EditNode is currently bugged... :/
+             # TODO: change this after next CoreNLP release (after 4.5.3)
+             #if not ssurgeon_result.changed:
+             #    continue
+
+             ssurgeon_graph = ssurgeon_result.graph
+             tokens = []
+             for graph_node, graph_word in zip(ssurgeon_graph.node, ssurgeon_graph.token):
+                 word_entry = build_word_entry(graph_node.index, graph_word)
+                 tokens.append(word_entry)
+             tokens.sort(key=lambda x: x[ID])
+             for root in ssurgeon_graph.root:
+                 tokens[root-1][HEAD] = 0
+                 tokens[root-1][DEPREL] = "root"
+             for edge in ssurgeon_graph.edge:
+                 # can't do anything about the extra dependencies for now
+                 # TODO: put them all in .deps
+                 if edge.isExtra:
+                     continue
+                 tokens[edge.target-1][HEAD] = edge.source
+                 tokens[edge.target-1][DEPREL] = edge.dep
+
+             # for any MWT, produce a token_entry which represents the word range
+             mwt_tokens = []
+             for word_start_idx, word in enumerate(tokens):
+                 if not word["is_first_mwt"]:
+                     if word["is_mwt"]:
+                         word[MISC] = java_protobuf_requests.remove_space_misc(word[MISC])
+                     mwt_tokens.append(word)
+                     continue
+                 word_end_idx = word_start_idx + 1
+                 while word_end_idx < len(tokens) and tokens[word_end_idx]["is_mwt"] and not tokens[word_end_idx]["is_first_mwt"]:
+                     word_end_idx += 1
+                 mwt_token_entry = {
+                     # the tokens don't fencepost the way lists do
+                     ID: (tokens[word_start_idx][ID], tokens[word_end_idx-1][ID]),
+                     TEXT: word["mwt_text"],
+                     NER: word[NER],
+                     # use the SpaceAfter=No (or not) from the last word in the token
+                     MISC: None,
+                 }
+                 mwt_token_entry[MISC] = java_protobuf_requests.misc_space_pieces(tokens[word_end_idx-1][MISC])
+                 if tokens[word_end_idx-1]["mwt_misc"]:
+                     mwt_token_entry[MISC] = java_protobuf_requests.substitute_space_misc(tokens[word_end_idx-1]["mwt_misc"], mwt_token_entry[MISC])
+                 word[MISC] = java_protobuf_requests.remove_space_misc(word[MISC])
+                 mwt_tokens.append(mwt_token_entry)
+                 mwt_tokens.append(word)
+
+             old_comments = list(sentence.comments)
+             sentence = Sentence(mwt_tokens, doc)
+
+             token_text = []
+             for token_idx, token in enumerate(sentence.tokens):
+                 token_text.append(token.text)
+                 if token_idx == len(sentence.tokens) - 1:
+                     break
+                 token_text.append(token.spaces_after)
+
+             sentence_text = "".join(token_text)
+
+             for comment in old_comments:
+                 if comment.startswith("# text ") or comment.startswith("#text ") or comment.startswith("# text=") or comment.startswith("#text="):
+                     sentence.add_comment("# text = " + sentence_text)
+                 else:
+                     sentence.add_comment(comment)
+
+             doc.sentences[sent_idx] = sentence
+
+             sentence.rebuild_dependencies()
+     except Exception as e:
+         raise RuntimeError("Ssurgeon could not process sentence {}\nSsurgeon result:\n{}\nOriginal sentence:\n{:C}".format(sent_idx, ssurgeon_result, sentence)) from e
+     return doc
+
+ class Ssurgeon(java_protobuf_requests.JavaProtobufContext):
+     """
+     Ssurgeon context window
+
+     This is a context window which keeps a process open. Should allow
+     for multiple requests without launching new java processes each time.
+     """
+     def __init__(self, classpath=None):
+         super(Ssurgeon, self).__init__(classpath, SsurgeonResponse, SSURGEON_JAVA)
+
+     def process(self, doc, ssurgeon_edits):
+         """
+         Apply each of the ssurgeon patterns to each of the dependency trees in doc
+         """
+         request = build_request(doc, ssurgeon_edits)
+         return self.process_request(request)
+
+     def process_one_operation(self, doc, semgrex_pattern, ssurgeon_edits, ssurgeon_id=None, notes=None):
+         """
+         Convenience method - build one operation, then apply it
+         """
+         request = build_request_one_operation(doc, semgrex_pattern, ssurgeon_edits, ssurgeon_id, notes)
+         return self.process_request(request)
+
+ SAMPLE_DOC = """
+ # sent_id = 271
+ # text = Hers is easy to clean.
+ # previous = What did the dealer like about Alex's car?
+ # comment = extraction/raising via "tough extraction" and clausal subject
+ 1	Hers	hers	PRON	PRP	Gender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs	3	nsubj	_	_
+ 2	is	be	AUX	VBZ	Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin	3	cop	_	_
+ 3	easy	easy	ADJ	JJ	Degree=Pos	0	root	_	_
+ 4	to	to	PART	TO	_	5	mark	_	_
+ 5	clean	clean	VERB	VB	VerbForm=Inf	3	csubj	_	SpaceAfter=No
+ 6	.	.	PUNCT	.	_	5	punct	_	_
+ """
+
+ def main():
+     # for Windows, so that we aren't randomly printing garbage (or just failing to print)
+     try:
+         sys.stdout.reconfigure(encoding='utf-8')
+     except AttributeError:
+         # TODO: deprecate 3.6 support after the next release
+         pass
+
+     # The default semgrex detects sentences in the UD_English-Pronouns dataset which have both nsubj and csubj on the same word.
+     # The default ssurgeon transforms the unwanted csubj to advcl
+     # See https://github.com/UniversalDependencies/docs/issues/923
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--input_file', type=str, default=None, help="Input file to process (otherwise will process a sample text)")
+     parser.add_argument('--output_file', type=str, default=None, help="Output file (otherwise will write to stdout)")
+     parser.add_argument('--input_dir', type=str, default=None, help="Input dir to process instead of a single file. Allows for reusing the Java program")
+     parser.add_argument('--input_filter', type=str, default=".*[.]conllu", help="Only process files from the input_dir that match this filter - regex, not shell filter. Default: %(default)s")
+     parser.add_argument('--no_input_filter', action='store_const', const=None, dest='input_filter', help="Remove the default input filename filter")
+     parser.add_argument('--output_dir', type=str, default=None, help="Output dir for writing files, necessary if using --input_dir")
+     parser.add_argument('--edit_file', type=str, default=None, help="File to get semgrex and ssurgeon rules from")
+     parser.add_argument('--semgrex', type=str, default="{}=source >nsubj {} >csubj=bad {}", help="Semgrex to apply to the text. A default detects words which have both an nsubj and a csubj. Default: %(default)s")
+     parser.add_argument('ssurgeon', type=str, default=["relabelNamedEdge -edge bad -reln advcl"], nargs="*", help="Ssurgeon edits to apply based on the Semgrex. Can have multiple edits in a row. A default exists to transform csubj into advcl. Default: %(default)s")
+     parser.add_argument('--print_input', dest='print_input', action='store_true', default=False, help="Print the input alongside the output - gets kind of noisy. Default: %(default)s")
+     parser.add_argument('--no_print_input', dest='print_input', action='store_false', help="Don't print the input alongside the output")
+     args = parser.parse_args()
+
+     if args.edit_file:
+         ssurgeon_edits = read_ssurgeon_edits(args.edit_file)
+     else:
+         ssurgeon_edits = [SsurgeonEdit(args.semgrex, args.ssurgeon)]
+
+     if args.input_file:
+         docs = [CoNLL.conll2doc(input_file=args.input_file)]
+         outputs = [args.output_file]
+         input_output = zip(docs, outputs)
+     elif args.input_dir:
+         if not args.output_dir:
+             raise ValueError("Cannot process multiple files without knowing where to send them - please set --output_dir in order to use --input_dir")
+         if not os.path.exists(args.output_dir):
+             os.makedirs(args.output_dir)
+         def read_docs():
+             for doc_filename in os.listdir(args.input_dir):
+                 if args.input_filter:
+                     if not re.match(args.input_filter, doc_filename):
+                         continue
+                 doc_path = os.path.join(args.input_dir, doc_filename)
+                 output_path = os.path.join(args.output_dir, doc_filename)
+                 print("Processing %s to %s" % (doc_path, output_path))
+                 yield CoNLL.conll2doc(input_file=doc_path), output_path
+         input_output = read_docs()
+     else:
+         docs = [CoNLL.conll2doc(input_str=SAMPLE_DOC)]
+         outputs = [None]
+         input_output = zip(docs, outputs)
+
+     for doc, output in input_output:
+         if args.print_input:
+             print("{:C}".format(doc))
+         ssurgeon_request = build_request(doc, ssurgeon_edits)
+         ssurgeon_response = send_ssurgeon_request(ssurgeon_request)
+         updated_doc = convert_response_to_doc(doc, ssurgeon_response)
+         if output is not None:
+             with open(output, "w", encoding="utf-8") as fout:
+                 fout.write("{:C}\n\n".format(updated_doc))
+         else:
+             print("{:C}\n".format(updated_doc))
+
+ if __name__ == '__main__':
+     main()
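
A short sketch of driving this module from Python rather than the command line. The edit text illustrates the block format parse_ssurgeon_edits expects: blocks separated by blank lines, "#" lines collected into the notes, the first non-comment line the semgrex pattern, and each remaining line one ssurgeon operation. The input path and the "$CLASSPATH" setting are placeholder assumptions; the pattern and operation are the module's own defaults.

    from stanza.server.ssurgeon import Ssurgeon, convert_response_to_doc, parse_ssurgeon_edits
    from stanza.utils.conll import CoNLL

    EDIT_TEXT = """
    # demote the unwanted second subject to advcl
    {}=source >nsubj {} >csubj=bad {}
    relabelNamedEdge -edge bad -reln advcl
    """

    doc = CoNLL.conll2doc(input_file="corpus.conllu")   # hypothetical input file
    edits = parse_ssurgeon_edits(EDIT_TEXT)
    with Ssurgeon(classpath="$CLASSPATH") as ssurgeon:
        response = ssurgeon.process(doc, edits)
    # turn the protobuf response back into a stanza Document
    updated_doc = convert_response_to_doc(doc, response)
    print("{:C}".format(updated_doc))
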
stanza/stanza/tests/__init__.py ADDED
@@ -0,0 +1,111 @@
+ """
+ Utilities for testing
+ """
+
+ import os
+ import re
+
+ # Environment Variables
+ # set this to specify working directory of tests
+ TEST_HOME_VAR = 'STANZA_TEST_HOME'
+
+ # Global Variables
+ TEST_DIR_BASE_NAME = 'stanza_test'
+
+ TEST_WORKING_DIR = os.getenv(TEST_HOME_VAR, None)
+ if not TEST_WORKING_DIR:
+     TEST_WORKING_DIR = os.path.join(os.getcwd(), TEST_DIR_BASE_NAME)
+
+ TEST_MODELS_DIR = f'{TEST_WORKING_DIR}/models'
+ TEST_CORENLP_DIR = f'{TEST_WORKING_DIR}/corenlp_dir'
+
+ # server resources
+ SERVER_TEST_PROPS = f'{TEST_WORKING_DIR}/scripts/external_server.properties'
+
+ # language resources
+ LANGUAGE_RESOURCES = {}
+
+ TOKENIZE_MODEL = 'tokenizer.pt'
+ MWT_MODEL = 'mwt_expander.pt'
+ POS_MODEL = 'tagger.pt'
+ POS_PRETRAIN = 'pretrain.pt'
+ LEMMA_MODEL = 'lemmatizer.pt'
+ DEPPARSE_MODEL = 'parser.pt'
+ DEPPARSE_PRETRAIN = 'pretrain.pt'
+
+ MODEL_FILES = [TOKENIZE_MODEL, MWT_MODEL, POS_MODEL, POS_PRETRAIN, LEMMA_MODEL, DEPPARSE_MODEL, DEPPARSE_PRETRAIN]
+
+ # English resources
+ EN_KEY = 'en'
+ EN_SHORTHAND = 'en_ewt'
+ # models
+ EN_MODELS_DIR = f'{TEST_WORKING_DIR}/models/{EN_SHORTHAND}_models'
+ EN_MODEL_FILES = [f'{EN_MODELS_DIR}/{EN_SHORTHAND}_{model_fname}' for model_fname in MODEL_FILES]
+
+ # French resources
+ FR_KEY = 'fr'
+ FR_SHORTHAND = 'fr_gsd'
+ # regression file paths
+ FR_TEST_IN = f'{TEST_WORKING_DIR}/in/fr_gsd.test.txt'
+ FR_TEST_OUT = f'{TEST_WORKING_DIR}/out/fr_gsd.test.txt.out'
+ FR_TEST_GOLD_OUT = f'{TEST_WORKING_DIR}/out/fr_gsd.test.txt.out.gold'
+ # models
+ FR_MODELS_DIR = f'{TEST_WORKING_DIR}/models/{FR_SHORTHAND}_models'
+ FR_MODEL_FILES = [f'{FR_MODELS_DIR}/{FR_SHORTHAND}_{model_fname}' for model_fname in MODEL_FILES]
+
+ # Other language resources
+ AR_SHORTHAND = 'ar_padt'
+ DE_SHORTHAND = 'de_gsd'
+ KK_SHORTHAND = 'kk_ktb'
+ KO_SHORTHAND = 'ko_gsd'
+
+
+ # utils for clean up
+ # only allow removal of dirs/files in this approved list
+ REMOVABLE_PATHS = ['en_ewt_models', 'en_ewt_tokenizer.pt', 'en_ewt_mwt_expander.pt', 'en_ewt_tagger.pt',
+                    'en_ewt.pretrain.pt', 'en_ewt_lemmatizer.pt', 'en_ewt_parser.pt', 'fr_gsd_models',
+                    'fr_gsd_tokenizer.pt', 'fr_gsd_mwt_expander.pt', 'fr_gsd_tagger.pt', 'fr_gsd.pretrain.pt',
+                    'fr_gsd_lemmatizer.pt', 'fr_gsd_parser.pt', 'ar_padt_models', 'ar_padt_tokenizer.pt',
+                    'ar_padt_mwt_expander.pt', 'ar_padt_tagger.pt', 'ar_padt.pretrain.pt', 'ar_padt_lemmatizer.pt',
+                    'ar_padt_parser.pt', 'de_gsd_models', 'de_gsd_tokenizer.pt', 'de_gsd_mwt_expander.pt',
+                    'de_gsd_tagger.pt', 'de_gsd.pretrain.pt', 'de_gsd_lemmatizer.pt', 'de_gsd_parser.pt',
+                    'kk_ktb_models', 'kk_ktb_tokenizer.pt', 'kk_ktb_mwt_expander.pt', 'kk_ktb_tagger.pt',
+                    'kk_ktb.pretrain.pt', 'kk_ktb_lemmatizer.pt', 'kk_ktb_parser.pt', 'ko_gsd_models',
+                    'ko_gsd_tokenizer.pt', 'ko_gsd_mwt_expander.pt', 'ko_gsd_tagger.pt', 'ko_gsd.pretrain.pt',
+                    'ko_gsd_lemmatizer.pt', 'ko_gsd_parser.pt']
+
+
+ def safe_rm(path_to_rm):
+     """
+     Safely remove a directory of files or a file
+     1.) check path exists, files are files, dirs are dirs
+     2.) only remove things on approved list REMOVABLE_PATHS
+     3.) assert no longer exists
+     """
+     # just return if path doesn't exist
+     if not os.path.exists(path_to_rm):
+         return
+     # handle directory
+     if os.path.isdir(path_to_rm):
+         files_to_rm = [f'{path_to_rm}/{fname}' for fname in os.listdir(path_to_rm)]
+         dir_to_rm = path_to_rm
+     else:
+         files_to_rm = [path_to_rm]
+         dir_to_rm = None
+     # clear out files
+     for file_to_rm in files_to_rm:
+         if os.path.isfile(file_to_rm) and os.path.basename(file_to_rm) in REMOVABLE_PATHS:
+             os.remove(file_to_rm)
+             assert not os.path.exists(file_to_rm), f'Error removing: {file_to_rm}'
+     # clear out directory
+     if dir_to_rm is not None and os.path.isdir(dir_to_rm):
+         os.rmdir(dir_to_rm)
+         assert not os.path.exists(dir_to_rm), f'Error removing: {dir_to_rm}'
+
+ def compare_ignoring_whitespace(predicted, expected):
+     predicted = re.sub('[ \t]+', ' ', predicted.strip())
+     predicted = re.sub('\r\n', '\n', predicted)
+     expected = re.sub('[ \t]+', ' ', expected.strip())
+     expected = re.sub('\r\n', '\n', expected)
+     assert predicted == expected
+
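
One note on the whitelist design above: safe_rm silently skips any file whose basename is not in REMOVABLE_PATHS, so a stray file in a model directory makes the final os.rmdir fail on a non-empty directory instead of deleting unexpected data. A sketch of a typical cleanup call (the path is illustrative):

    from stanza.tests import TEST_WORKING_DIR, safe_rm

    # deletes the whitelisted model files, then the now-empty directory
    safe_rm(f'{TEST_WORKING_DIR}/models/en_ewt_models')
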
stanza/stanza/tests/data/tiny_emb.txt ADDED
@@ -0,0 +1,4 @@
+ 3 4
+ unban 1 2 3 4
+ mox 5 6 7 8
+ opal 9 10 11 12
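
The header line of this fixture is a word2vec-style text embedding header: 3 vocabulary entries of 4 dimensions each, with one word plus its vector per following line. A minimal stand-alone reader, shown only to document the format (the test suite itself consumes a .pt conversion of this file, tiny_emb.pt, through stanza's pretrain module):

    import numpy as np

    def read_text_embedding(path):
        """Read a word2vec-style text embedding file such as tiny_emb.txt."""
        with open(path, encoding="utf-8") as fin:
            n_words, dim = map(int, fin.readline().split())
            vocab, vectors = [], []
            for line in fin:
                pieces = line.rstrip("\n").split(" ")
                vocab.append(pieces[0])
                vectors.append([float(x) for x in pieces[1:]])
        assert len(vocab) == n_words and all(len(v) == dim for v in vectors)
        return vocab, np.array(vectors)
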
stanza/stanza/tests/datasets/test_common.py ADDED
@@ -0,0 +1,76 @@
+ """
+ Test conllu manipulation routines in stanza/utils/datasets/common.py
+ """
+
+ import pytest
+
+
+ from stanza.utils.datasets.common import maybe_add_fake_dependencies
+ # from stanza.tests import *
+
+ pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
+
+ DEPS_EXAMPLE="""
+ # text = Sh'reyan's antennae are hella thicc
+ 1	Sh'reyan	Sh'reyan	PROPN	NNP	Number=Sing	3	nmod:poss	3:nmod:poss	SpaceAfter=No
+ 2	's	's	PART	POS	_	1	case	1:case	_
+ 3	antennae	antenna	NOUN	NNS	Number=Plur	6	nsubj	6:nsubj	_
+ 4	are	be	VERB	VBP	Mood=Ind|Tense=Pres|VerbForm=Fin	6	cop	6:cop	_
+ 5	hella	hella	ADV	RB	_	6	advmod	6:advmod	_
+ 6	thicc	thicc	ADJ	JJ	Degree=Pos	0	root	0:root	_
+ """.strip().split("\n")
+
+
+ ONLY_ROOT_EXAMPLE="""
+ # text = Sh'reyan's antennae are hella thicc
+ 1	Sh'reyan	Sh'reyan	PROPN	NNP	Number=Sing	_	_	_	SpaceAfter=No
+ 2	's	's	PART	POS	_	_	_	_	_
+ 3	antennae	antenna	NOUN	NNS	Number=Plur	_	_	_	_
+ 4	are	be	VERB	VBP	Mood=Ind|Tense=Pres|VerbForm=Fin	_	_	_	_
+ 5	hella	hella	ADV	RB	_	_	_	_	_
+ 6	thicc	thicc	ADJ	JJ	Degree=Pos	0	root	0:root	_
+ """.strip().split("\n")
+
+ ONLY_ROOT_EXPECTED="""
+ # text = Sh'reyan's antennae are hella thicc
+ 1	Sh'reyan	Sh'reyan	PROPN	NNP	Number=Sing	6	dep	_	SpaceAfter=No
+ 2	's	's	PART	POS	_	1	dep	_	_
+ 3	antennae	antenna	NOUN	NNS	Number=Plur	1	dep	_	_
+ 4	are	be	VERB	VBP	Mood=Ind|Tense=Pres|VerbForm=Fin	1	dep	_	_
+ 5	hella	hella	ADV	RB	_	1	dep	_	_
+ 6	thicc	thicc	ADJ	JJ	Degree=Pos	0	root	0:root	_
+ """.strip().split("\n")
+
+ NO_DEPS_EXAMPLE="""
+ # text = Sh'reyan's antennae are hella thicc
+ 1	Sh'reyan	Sh'reyan	PROPN	NNP	Number=Sing	_	_	_	SpaceAfter=No
+ 2	's	's	PART	POS	_	_	_	_	_
+ 3	antennae	antenna	NOUN	NNS	Number=Plur	_	_	_	_
+ 4	are	be	VERB	VBP	Mood=Ind|Tense=Pres|VerbForm=Fin	_	_	_	_
+ 5	hella	hella	ADV	RB	_	_	_	_	_
+ 6	thicc	thicc	ADJ	JJ	Degree=Pos	_	_	_	_
+ """.strip().split("\n")
+
+ NO_DEPS_EXPECTED="""
+ # text = Sh'reyan's antennae are hella thicc
+ 1	Sh'reyan	Sh'reyan	PROPN	NNP	Number=Sing	0	root	_	SpaceAfter=No
+ 2	's	's	PART	POS	_	1	dep	_	_
+ 3	antennae	antenna	NOUN	NNS	Number=Plur	1	dep	_	_
+ 4	are	be	VERB	VBP	Mood=Ind|Tense=Pres|VerbForm=Fin	1	dep	_	_
+ 5	hella	hella	ADV	RB	_	1	dep	_	_
+ 6	thicc	thicc	ADJ	JJ	Degree=Pos	1	dep	_	_
+ """.strip().split("\n")
+
+
+ def test_fake_deps_no_change():
+     result = maybe_add_fake_dependencies(DEPS_EXAMPLE)
+     assert result == DEPS_EXAMPLE
+
+ def test_fake_deps_all_tokens():
+     result = maybe_add_fake_dependencies(NO_DEPS_EXAMPLE)
+     assert result == NO_DEPS_EXPECTED
+
+
+ def test_fake_deps_only_root():
+     result = maybe_add_fake_dependencies(ONLY_ROOT_EXAMPLE)
+     assert result == ONLY_ROOT_EXPECTED
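
As the fixtures spell out, maybe_add_fake_dependencies takes the lines of a single CoNLL-U sentence and fills in missing HEAD/DEPREL columns: when no root exists the first word becomes the root, otherwise the first unattached word is attached to the existing root, and every other unattached word hangs off the first word with the placeholder relation dep. A sketch of applying it file-wide (the filename is hypothetical):

    from stanza.utils.datasets.common import maybe_add_fake_dependencies

    with open("incomplete.conllu", encoding="utf-8") as fin:
        sentences = fin.read().strip().split("\n\n")

    # each sentence block is patched independently, line by line
    patched = ["\n".join(maybe_add_fake_dependencies(sentence.split("\n")))
               for sentence in sentences]
    print("\n\n".join(patched))
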
stanza/stanza/tests/datasets/test_vietnamese_renormalization.py ADDED
@@ -0,0 +1,35 @@
+ import pytest
+ import os
+
+ from stanza.utils.datasets.vietnamese import renormalize
+
+ pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
+
+ def test_replace_all():
+     text = "SỌAmple tụy test file"
+     expected = "SOẠmple tuỵ test file"
+
+     assert renormalize.replace_all(text) == expected
+
+ def test_replace_file(tmp_path):
+     text = "SỌAmple tụy test file"
+     expected = "SOẠmple tuỵ test file"
+
+     orig = tmp_path / "orig.txt"
+     converted = tmp_path / "converted.txt"
+
+     with open(orig, "w", encoding="utf-8") as fout:
+         for i in range(10):
+             fout.write(text)
+             fout.write("\n")
+
+     renormalize.convert_file(orig, converted)
+
+     assert os.path.exists(converted)
+     with open(converted, encoding="utf-8") as fin:
+         lines = fin.readlines()
+
+     assert len(lines) == 10
+     for line in lines:
+         assert line.strip() == expected
+
stanza/stanza/tests/depparse/__init__.py ADDED
File without changes
stanza/stanza/tests/depparse/test_depparse_data.py ADDED
@@ -0,0 +1,62 @@
+ """
+ Test some pieces of the depparse dataloader
+ """
+ import pytest
+ from stanza.models.depparse.data import data_to_batches
+
+ pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
+
+ def make_fake_data(*lengths):
+     data = []
+     for i, length in enumerate(lengths):
+         word = chr(ord('A') + i)
+         chunk = [[word] * length]
+         data.append(chunk)
+     return data
+
+ def check_batches(batched_data, expected_sizes, expected_order):
+     for chunk, size in zip(batched_data, expected_sizes):
+         assert sum(len(x[0]) for x in chunk) == size
+     word_order = []
+     for chunk in batched_data:
+         for sentence in chunk:
+             word_order.append(sentence[0][0])
+     assert word_order == expected_order
+
+ def test_data_to_batches_eval_mode():
+     """
+     Tests the chunking of batches in eval_mode
+
+     A few options are tested, such as whether or not to sort and the maximum sentence size
+     """
+     data = make_fake_data(1, 2, 3)
+     batched_data = data_to_batches(data, batch_size=5, eval_mode=True, sort_during_eval=True, min_length_to_batch_separately=None)
+     check_batches(batched_data[0], [5, 1], ['C', 'B', 'A'])
+
+     data = make_fake_data(1, 2, 6)
+     batched_data = data_to_batches(data, batch_size=5, eval_mode=True, sort_during_eval=True, min_length_to_batch_separately=None)
+     check_batches(batched_data[0], [6, 3], ['C', 'B', 'A'])
+
+     data = make_fake_data(3, 2, 1)
+     batched_data = data_to_batches(data, batch_size=5, eval_mode=True, sort_during_eval=True, min_length_to_batch_separately=None)
+     check_batches(batched_data[0], [5, 1], ['A', 'B', 'C'])
+
+     data = make_fake_data(3, 5, 2)
+     batched_data = data_to_batches(data, batch_size=5, eval_mode=True, sort_during_eval=True, min_length_to_batch_separately=None)
+     check_batches(batched_data[0], [5, 5], ['B', 'A', 'C'])
+
+     data = make_fake_data(3, 5, 2)
+     batched_data = data_to_batches(data, batch_size=5, eval_mode=True, sort_during_eval=False, min_length_to_batch_separately=3)
+     check_batches(batched_data[0], [3, 5, 2], ['A', 'B', 'C'])
+
+     data = make_fake_data(4, 1, 1)
+     batched_data = data_to_batches(data, batch_size=5, eval_mode=True, sort_during_eval=False, min_length_to_batch_separately=3)
+     check_batches(batched_data[0], [4, 2], ['A', 'B', 'C'])
+
+     data = make_fake_data(1, 4, 1)
+     batched_data = data_to_batches(data, batch_size=5, eval_mode=True, sort_during_eval=False, min_length_to_batch_separately=3)
+     check_batches(batched_data[0], [1, 4, 1], ['A', 'B', 'C'])
+
+ if __name__ == '__main__':
+     test_data_to_batches_eval_mode()
+
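
The helpers above encode the contract compactly; spelled out, a direct call looks like this (mirroring the first case in the test; the assumption that the second returned element tracks the original sentence order is inferred from the test only ever indexing element 0):

    from stanza.models.depparse.data import data_to_batches

    # three "sentences" of lengths 1, 2, 3, in the same nested shape make_fake_data builds
    data = [[["A"]], [["B", "B"]], [["C", "C", "C"]]]
    batches, original_order = data_to_batches(data, batch_size=5, eval_mode=True,
                                              sort_during_eval=True,
                                              min_length_to_batch_separately=None)
    # sorted longest-first: one batch of 5 words (C + B), then one of 1 word (A)
    print([sum(len(sent[0]) for sent in batch) for batch in batches])
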
stanza/stanza/tests/depparse/test_parser.py ADDED
@@ -0,0 +1,211 @@
+ """
+ Run the parser for a couple iterations on some fake data
+
+ Uses a couple sentences of UD_English-EWT as training/dev data
+ """
+
+ import os
+ import pytest
+
+ import torch
+
+ from stanza.models import parser
+ from stanza.models.common import pretrain
+ from stanza.models.depparse.trainer import Trainer
+ from stanza.tests import TEST_WORKING_DIR
+
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+ TRAIN_DATA = """
+ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0003
+ # text = DPA: Iraqi authorities announced that they had busted up 3 terrorist cells operating in Baghdad.
+ 1	DPA	DPA	PROPN	NNP	Number=Sing	0	root	0:root	SpaceAfter=No
+ 2	:	:	PUNCT	:	_	1	punct	1:punct	_
+ 3	Iraqi	Iraqi	ADJ	JJ	Degree=Pos	4	amod	4:amod	_
+ 4	authorities	authority	NOUN	NNS	Number=Plur	5	nsubj	5:nsubj	_
+ 5	announced	announce	VERB	VBD	Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin	1	parataxis	1:parataxis	_
+ 6	that	that	SCONJ	IN	_	9	mark	9:mark	_
+ 7	they	they	PRON	PRP	Case=Nom|Number=Plur|Person=3|PronType=Prs	9	nsubj	9:nsubj	_
+ 8	had	have	AUX	VBD	Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin	9	aux	9:aux	_
+ 9	busted	bust	VERB	VBN	Tense=Past|VerbForm=Part	5	ccomp	5:ccomp	_
+ 10	up	up	ADP	RP	_	9	compound:prt	9:compound:prt	_
+ 11	3	3	NUM	CD	NumForm=Digit|NumType=Card	13	nummod	13:nummod	_
+ 12	terrorist	terrorist	ADJ	JJ	Degree=Pos	13	amod	13:amod	_
+ 13	cells	cell	NOUN	NNS	Number=Plur	9	obj	9:obj	_
+ 14	operating	operate	VERB	VBG	VerbForm=Ger	13	acl	13:acl	_
+ 15	in	in	ADP	IN	_	16	case	16:case	_
+ 16	Baghdad	Baghdad	PROPN	NNP	Number=Sing	14	obl	14:obl:in	SpaceAfter=No
+ 17	.	.	PUNCT	.	_	1	punct	1:punct	_
+
+ # sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0004
+ # text = Two of them were being run by 2 officials of the Ministry of the Interior!
+ 1	Two	two	NUM	CD	NumForm=Word|NumType=Card	6	nsubj:pass	6:nsubj:pass	_
+ 2	of	of	ADP	IN	_	3	case	3:case	_
+ 3	them	they	PRON	PRP	Case=Acc|Number=Plur|Person=3|PronType=Prs	1	nmod	1:nmod:of	_
+ 4	were	be	AUX	VBD	Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin	6	aux	6:aux	_
+ 5	being	be	AUX	VBG	VerbForm=Ger	6	aux:pass	6:aux:pass	_
+ 6	run	run	VERB	VBN	Tense=Past|VerbForm=Part|Voice=Pass	0	root	0:root	_
+ 7	by	by	ADP	IN	_	9	case	9:case	_
+ 8	2	2	NUM	CD	NumForm=Digit|NumType=Card	9	nummod	9:nummod	_
+ 9	officials	official	NOUN	NNS	Number=Plur	6	obl	6:obl:by	_
+ 10	of	of	ADP	IN	_	12	case	12:case	_
+ 11	the	the	DET	DT	Definite=Def|PronType=Art	12	det	12:det	_
+ 12	Ministry	Ministry	PROPN	NNP	Number=Sing	9	nmod	9:nmod:of	_
+ 13	of	of	ADP	IN	_	15	case	15:case	_
+ 14	the	the	DET	DT	Definite=Def|PronType=Art	15	det	15:det	_
+ 15	Interior	Interior	PROPN	NNP	Number=Sing	12	nmod	12:nmod:of	SpaceAfter=No
+ 16	!	!	PUNCT	.	_	6	punct	6:punct	_
+
+ """.lstrip()
+
+
+ DEV_DATA = """
+ 1	From	from	ADP	IN	_	3	case	3:case	_
+ 2	the	the	DET	DT	Definite=Def|PronType=Art	3	det	3:det	_
+ 3	AP	AP	PROPN	NNP	Number=Sing	4	obl	4:obl:from	_
+ 4	comes	come	VERB	VBZ	Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin	0	root	0:root	_
+ 5	this	this	DET	DT	Number=Sing|PronType=Dem	6	det	6:det	_
+ 6	story	story	NOUN	NN	Number=Sing	4	nsubj	4:nsubj	_
+ 7	:	:	PUNCT	:	_	4	punct	4:punct	_
+
+ """.lstrip()
+
+
+
+ class TestParser:
+     @pytest.fixture(scope="class")
+     def wordvec_pretrain_file(self):
+         return f'{TEST_WORKING_DIR}/in/tiny_emb.pt'
+
+     def run_training(self, tmp_path, wordvec_pretrain_file, train_text, dev_text, augment_nopunct=False, extra_args=None):
+         """
+         Run the training for a few iterations, load & return the model
+         """
+         train_file = str(tmp_path / "train.conllu")
+         dev_file = str(tmp_path / "dev.conllu")
+         pred_file = str(tmp_path / "pred.conllu")
+
+         save_name = "test_parser.pt"
+         save_file = str(tmp_path / save_name)
+
+         with open(train_file, "w", encoding="utf-8") as fout:
+             fout.write(train_text)
+
+         with open(dev_file, "w", encoding="utf-8") as fout:
+             fout.write(dev_text)
+
+         args = ["--wordvec_pretrain_file", wordvec_pretrain_file,
+                 "--train_file", train_file,
+                 "--eval_file", dev_file,
+                 "--output_file", pred_file,
+                 "--log_step", "10",
+                 "--eval_interval", "20",
+                 "--max_steps", "100",
+                 "--shorthand", "en_test",
+                 "--save_dir", str(tmp_path),
+                 "--save_name", save_name,
+                 # in case we are doing a bert test
+                 "--bert_start_finetuning", "10",
+                 "--bert_warmup_steps", "10",
+                 "--lang", "en"]
+         if not augment_nopunct:
+             args.extend(["--augment_nopunct", "0.0"])
+         if extra_args is not None:
+             args = args + extra_args
+         trainer = parser.main(args)
+
+         assert os.path.exists(save_file)
+         pt = pretrain.Pretrain(wordvec_pretrain_file)
+         # test loading the saved model
+         saved_model = Trainer(pretrain=pt, model_file=save_file)
+         return trainer
+
+     def test_train(self, tmp_path, wordvec_pretrain_file):
+         """
+         Simple test of a few 'epochs' of parser training
+         """
+         self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA)
+
+     def test_with_bert_nlayers(self, tmp_path, wordvec_pretrain_file):
+         self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_hidden_layers', '2'])
+
+     def test_with_bert_finetuning(self, tmp_path, wordvec_pretrain_file):
+         trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_finetune', '--bert_hidden_layers', '2'])
+         assert 'bert_optimizer' in trainer.optimizer.keys()
+         assert 'bert_scheduler' in trainer.scheduler.keys()
+
+     def test_with_bert_finetuning_resaved(self, tmp_path, wordvec_pretrain_file):
+         """
+         Check that if we save, then load, then save a model with a finetuned bert, that bert isn't lost
+         """
+         trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_finetune', '--bert_hidden_layers', '2'])
+         assert 'bert_optimizer' in trainer.optimizer.keys()
+         assert 'bert_scheduler' in trainer.scheduler.keys()
+
+         save_name = trainer.args['save_name']
+         filename = tmp_path / save_name
+         assert os.path.exists(filename)
+         checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
+         assert any(x.startswith("bert_model") for x in checkpoint['model'].keys())
+
+         # Test loading the saved model, saving it, and still having bert in it
+         # even if we have set bert_finetune to False for this incarnation
+         pt = pretrain.Pretrain(wordvec_pretrain_file)
+         args = {"bert_finetune": False}
+         saved_model = Trainer(pretrain=pt, model_file=filename, args=args)
+
+         saved_model.save(filename)
+
+         # This is the part that would fail if the force_bert_saved option did not exist
+         checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
+         assert any(x.startswith("bert_model") for x in checkpoint['model'].keys())
+
+     def test_with_peft(self, tmp_path, wordvec_pretrain_file):
+         trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_finetune', '--bert_hidden_layers', '2', '--use_peft'])
+         assert 'bert_optimizer' in trainer.optimizer.keys()
+         assert 'bert_scheduler' in trainer.scheduler.keys()
+
+     def test_single_optimizer_checkpoint(self, tmp_path, wordvec_pretrain_file):
+         trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--optim', 'adam'])
+
+         save_dir = trainer.args['save_dir']
+         save_name = trainer.args['save_name']
+         checkpoint_name = trainer.args["checkpoint_save_name"]
+
+         assert os.path.exists(os.path.join(save_dir, save_name))
+         assert checkpoint_name is not None
+         assert os.path.exists(checkpoint_name)
+
+         assert len(trainer.optimizer) == 1
+         for opt in trainer.optimizer.values():
+             assert isinstance(opt, torch.optim.Adam)
+
+         pt = pretrain.Pretrain(wordvec_pretrain_file)
+         checkpoint = Trainer(args=trainer.args, pretrain=pt, model_file=checkpoint_name)
+         assert checkpoint.optimizer is not None
+         assert len(checkpoint.optimizer) == 1
+         for opt in checkpoint.optimizer.values():
+             assert isinstance(opt, torch.optim.Adam)
+
+     def test_two_optimizers_checkpoint(self, tmp_path, wordvec_pretrain_file):
+         trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--optim', 'adam', '--second_optim', 'sgd', '--second_optim_start_step', '40'])
+
+         save_dir = trainer.args['save_dir']
+         save_name = trainer.args['save_name']
+         checkpoint_name = trainer.args["checkpoint_save_name"]
+
+         assert os.path.exists(os.path.join(save_dir, save_name))
+         assert checkpoint_name is not None
+         assert os.path.exists(checkpoint_name)
+
+         assert len(trainer.optimizer) == 1
+         for opt in trainer.optimizer.values():
+             assert isinstance(opt, torch.optim.SGD)
+
+         pt = pretrain.Pretrain(wordvec_pretrain_file)
+         checkpoint = Trainer(args=trainer.args, pretrain=pt, model_file=checkpoint_name)
+         assert checkpoint.optimizer is not None
+         assert len(checkpoint.optimizer) == 1
+         for opt in checkpoint.optimizer.values():
+             assert isinstance(opt, torch.optim.SGD)
+
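
Outside the test, the same load path the assertions above exercise can be used to inspect a saved model directly. The filenames mirror what run_training writes and are otherwise placeholders:

    import torch

    from stanza.models.common import pretrain
    from stanza.models.depparse.trainer import Trainer

    # raw checkpoint: a dict whose 'model' entry holds the state dict
    checkpoint = torch.load("test_parser.pt", lambda storage, loc: storage, weights_only=True)
    print(any(k.startswith("bert_model") for k in checkpoint['model']))

    # full reload, exactly as run_training does after training
    pt = pretrain.Pretrain("tiny_emb.pt")
    trainer = Trainer(pretrain=pt, model_file="test_parser.pt")
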
stanza/stanza/tests/langid/__init__.py ADDED
File without changes
stanza/stanza/tests/langid/test_multilingual.py ADDED
@@ -0,0 +1,104 @@
+ """
+ Tests specifically for the MultilingualPipeline
+ """
+
+ from collections import defaultdict
+
+ import pytest
+
+ from stanza.pipeline.multilingual import MultilingualPipeline
+
+ from stanza.tests import TEST_MODELS_DIR
+
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+ def run_multilingual_pipeline(en_has_dependencies=True, fr_has_dependencies=True, **kwargs):
+     english_text = "This is an English sentence."
+     english_words = ["This", "is", "an", "English", "sentence", "."]
+     english_deps_gold = "\n".join((
+         "('This', 5, 'nsubj')",
+         "('is', 5, 'cop')",
+         "('an', 5, 'det')",
+         "('English', 5, 'amod')",
+         "('sentence', 0, 'root')",
+         "('.', 5, 'punct')"
+     ))
+     if not en_has_dependencies:
+         english_deps_gold = ""
+
+     french_text = "C'est une phrase française."
+     french_words = ["C'", "est", "une", "phrase", "française", "."]
+     french_deps_gold = "\n".join((
+         "(\"C'\", 4, 'nsubj')",
+         "('est', 4, 'cop')",
+         "('une', 4, 'det')",
+         "('phrase', 0, 'root')",
+         "('française', 4, 'amod')",
+         "('.', 4, 'punct')"
+     ))
+     if not fr_has_dependencies:
+         french_deps_gold = ""
+
+     if 'lang_configs' in kwargs:
+         nlp = MultilingualPipeline(model_dir=TEST_MODELS_DIR, download_method=None, **kwargs)
+     else:
+         lang_configs = {"en": {"processors": "tokenize,pos,lemma,depparse"},
+                         "fr": {"processors": "tokenize,pos,lemma,depparse"}}
+         nlp = MultilingualPipeline(model_dir=TEST_MODELS_DIR, download_method=None, lang_configs=lang_configs, **kwargs)
+     docs = [english_text, french_text]
+     docs = nlp(docs)
+
+     assert docs[0].lang == "en"
+     assert len(docs[0].sentences) == 1
+     assert [x.text for x in docs[0].sentences[0].words] == english_words
+     assert docs[0].sentences[0].dependencies_string() == english_deps_gold
+
+     assert len(docs[1].sentences) == 1
+     assert docs[1].lang == "fr"
+     assert [x.text for x in docs[1].sentences[0].words] == french_words
+     assert docs[1].sentences[0].dependencies_string() == french_deps_gold
+
+
+ def test_multilingual_pipeline():
+     """
+     Basic test of multilingual pipeline
+     """
+     run_multilingual_pipeline()
+
+ def test_multilingual_pipeline_small_cache():
+     """
+     Test with the cache size 1
+     """
+     run_multilingual_pipeline(max_cache_size=1)
+
+
+ def test_multilingual_config():
+     """
+     Test with only tokenize for the EN pipeline
+     """
+     lang_configs = {
+         "en": {"processors": "tokenize"}
+     }
+
+     run_multilingual_pipeline(en_has_dependencies=False, lang_configs=lang_configs)
+
+ def test_multilingual_processors_limited():
+     """
+     Test loading an available subset of processors
+     """
+     run_multilingual_pipeline(en_has_dependencies=False, fr_has_dependencies=False, lang_configs={}, processors="tokenize")
+     run_multilingual_pipeline(en_has_dependencies=True, fr_has_dependencies=False, lang_configs={"en": {"processors": "tokenize,pos,lemma,depparse"}}, processors="tokenize")
+     # this should not fail, as it will drop the zzzzzzzzzz processor for the languages which don't have it
+     run_multilingual_pipeline(en_has_dependencies=False, fr_has_dependencies=False, lang_configs={}, processors="tokenize,zzzzzzzzzz")
+
+
+ def test_defaultdict_config():
+     """
+     Test that you can pass in a defaultdict for the lang_configs argument
+     """
+     lang_configs = defaultdict(lambda: dict(processors="tokenize"))
+     run_multilingual_pipeline(en_has_dependencies=False, fr_has_dependencies=False, lang_configs=lang_configs)
+
+     lang_configs = defaultdict(lambda: dict(processors="tokenize"))
+     lang_configs["en"] = {"processors": "tokenize,pos,lemma,depparse"}
+     run_multilingual_pipeline(en_has_dependencies=True, fr_has_dependencies=False, lang_configs=lang_configs)
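
For readers skimming past the test scaffolding, the underlying API is simply this: language identification routes each input to a per-language pipeline, with per-language settings supplied through lang_configs. The sketch below mirrors the tests, except that model_dir and download_method are left at their defaults, so models are downloaded on first use:

    from stanza.pipeline.multilingual import MultilingualPipeline

    lang_configs = {"en": {"processors": "tokenize,pos,lemma,depparse"},
                    "fr": {"processors": "tokenize,pos,lemma,depparse"}}
    nlp = MultilingualPipeline(lang_configs=lang_configs)

    docs = nlp(["This is an English sentence.", "C'est une phrase française."])
    for doc in docs:
        print(doc.lang, [word.text for word in doc.sentences[0].words])
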
stanza/stanza/tests/lemma_classifier/__init__.py ADDED
File without changes
stanza/stanza/tests/lemma_classifier/test_training.py ADDED
@@ -0,0 +1,53 @@
+ import glob
+ import os
+
+ import pytest
+
+ pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+ from stanza.models.lemma_classifier import train_lstm_model
+ from stanza.models.lemma_classifier import train_transformer_model
+ from stanza.models.lemma_classifier.base_model import LemmaClassifier
+ from stanza.models.lemma_classifier.evaluate_models import evaluate_model
+
+ from stanza.tests import TEST_WORKING_DIR
+ from stanza.tests.lemma_classifier.test_data_preparation import convert_english_dataset
+
+ @pytest.fixture(scope="module")
+ def pretrain_file():
+     return f'{TEST_WORKING_DIR}/in/tiny_emb.pt'
+
+ def test_train_lstm(tmp_path, pretrain_file):
+     converted_files = convert_english_dataset(tmp_path)
+
+     save_name = str(tmp_path / 'lemma.pt')
+
+     train_file = converted_files[0]
+     eval_file = converted_files[1]
+     train_args = ['--wordvec_pretrain_file', pretrain_file,
+                   '--save_name', save_name,
+                   '--train_file', train_file,
+                   '--eval_file', eval_file]
+     trainer = train_lstm_model.main(train_args)
+
+     evaluate_model(trainer.model, eval_file)
+     # test that loading the model works
+     model = LemmaClassifier.load(save_name, None)
+
+ def test_train_transformer(tmp_path, pretrain_file):
+     converted_files = convert_english_dataset(tmp_path)
+
+     save_name = str(tmp_path / 'lemma.pt')
+
+     train_file = converted_files[0]
+     eval_file = converted_files[1]
+     train_args = ['--bert_model', 'hf-internal-testing/tiny-bert',
+                   '--save_name', save_name,
+                   '--train_file', train_file,
+                   '--eval_file', eval_file]
+     trainer = train_transformer_model.main(train_args)
+
+     evaluate_model(trainer.model, eval_file)
+
+     # test that loading the model works
+     model = LemmaClassifier.load(save_name, None)
stanza/stanza/tests/mwt/__init__.py ADDED
File without changes