Add files using upload-large-folder tool
- stanza/demo/Stanza_CoreNLP_Interface.ipynb +485 -0
- stanza/demo/en_test.conllu.txt +79 -0
- stanza/demo/semgrex visualization.ipynb +367 -0
- stanza/images/stanza-logo.png +0 -0
- stanza/stanza/__init__.py +27 -0
- stanza/stanza/models/charlm.py +357 -0
- stanza/stanza/models/identity_lemmatizer.py +66 -0
- stanza/stanza/models/lang_identifier.py +236 -0
- stanza/stanza/models/mwt_expander.py +322 -0
- stanza/stanza/models/ner_tagger.py +492 -0
- stanza/stanza/models/tagger.py +461 -0
- stanza/stanza/models/tokenizer.py +258 -0
- stanza/stanza/models/wl_coref.py +226 -0
- stanza/stanza/pipeline/__init__.py +0 -0
- stanza/stanza/pipeline/constituency_processor.py +81 -0
- stanza/stanza/pipeline/core.py +509 -0
- stanza/stanza/pipeline/coref_processor.py +154 -0
- stanza/stanza/pipeline/depparse_processor.py +78 -0
- stanza/stanza/pipeline/langid_processor.py +127 -0
- stanza/stanza/pipeline/lemma_processor.py +126 -0
- stanza/stanza/pipeline/multilingual.py +188 -0
- stanza/stanza/pipeline/mwt_processor.py +59 -0
- stanza/stanza/pipeline/pos_processor.py +89 -0
- stanza/stanza/pipeline/processor.py +293 -0
- stanza/stanza/pipeline/registry.py +8 -0
- stanza/stanza/pipeline/sentiment_processor.py +78 -0
- stanza/stanza/pipeline/tokenize_processor.py +185 -0
- stanza/stanza/protobuf/CoreNLP_pb2.py +686 -0
- stanza/stanza/protobuf/__init__.py +52 -0
- stanza/stanza/resources/common.py +619 -0
- stanza/stanza/resources/default_packages.py +909 -0
- stanza/stanza/resources/installation.py +148 -0
- stanza/stanza/resources/prepare_resources.py +670 -0
- stanza/stanza/server/__init__.py +10 -0
- stanza/stanza/server/annotator.py +138 -0
- stanza/stanza/server/client.py +779 -0
- stanza/stanza/server/semgrex.py +170 -0
- stanza/stanza/server/ssurgeon.py +310 -0
- stanza/stanza/tests/__init__.py +111 -0
- stanza/stanza/tests/data/tiny_emb.txt +4 -0
- stanza/stanza/tests/datasets/test_common.py +76 -0
- stanza/stanza/tests/datasets/test_vietnamese_renormalization.py +35 -0
- stanza/stanza/tests/depparse/__init__.py +0 -0
- stanza/stanza/tests/depparse/test_depparse_data.py +62 -0
- stanza/stanza/tests/depparse/test_parser.py +211 -0
- stanza/stanza/tests/langid/__init__.py +0 -0
- stanza/stanza/tests/langid/test_multilingual.py +104 -0
- stanza/stanza/tests/lemma_classifier/__init__.py +0 -0
- stanza/stanza/tests/lemma_classifier/test_training.py +53 -0
- stanza/stanza/tests/mwt/__init__.py +0 -0
stanza/demo/Stanza_CoreNLP_Interface.ipynb
ADDED
@@ -0,0 +1,485 @@
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Stanza-CoreNLP-Interface.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2-4lzQTC9yxG",
        "colab_type": "text"
      },
      "source": [
        "# Stanza: A Tutorial on the Python CoreNLP Interface\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "While the Stanza library implements accurate neural network modules for basic functionalities such as part-of-speech tagging and dependency parsing, the [Stanford CoreNLP Java library](https://stanfordnlp.github.io/CoreNLP/) has been developed for years and offers more complementary features such as coreference resolution and relation extraction. To unlock these features, the Stanza library also offers an officially maintained Python interface to the CoreNLP Java library. This interface allows you to get NLP annotations from CoreNLP by writing native Python code.\n",
        "\n",
        "\n",
        "This tutorial walks you through the installation, setup and basic usage of this Python CoreNLP interface. If you want to learn how to use the neural network components in Stanza, please refer to other tutorials."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YpKwWeVkASGt",
        "colab_type": "text"
      },
      "source": [
        "## 1. Installation\n",
        "\n",
        "Before the installation starts, please make sure that you have Python 3 and Java installed on your computer. Since Colab already has them installed, we'll skip this procedure in this notebook."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "k1Az2ECuAfG8",
        "colab_type": "text"
      },
      "source": [
        "### Installing Stanza\n",
        "\n",
        "Installing and importing Stanza are as simple as running the following commands:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xiFwYAgW4Mss",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Install stanza; note that the prefix \"!\" is not needed if you are running in a terminal\n",
        "!pip install stanza\n",
        "\n",
        "# Import stanza\n",
        "import stanza"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2zFvaA8_A32_",
        "colab_type": "text"
      },
      "source": [
        "### Setting up Stanford CoreNLP\n",
        "\n",
        "In order for the interface to work, the Stanford CoreNLP library has to be installed and a `CORENLP_HOME` environment variable has to point to the installation location.\n",
        "\n",
        "Here we are going to show you how to download and install the CoreNLP library on your machine with Stanza's installation command:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MgK6-LPV-OdA",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Download the Stanford CoreNLP package with Stanza's installation command\n",
        "# This'll take several minutes, depending on the network speed\n",
        "corenlp_dir = './corenlp'\n",
        "stanza.install_corenlp(dir=corenlp_dir)\n",
        "\n",
        "# Set the CORENLP_HOME environment variable to point to the installation location\n",
        "import os\n",
        "os.environ[\"CORENLP_HOME\"] = corenlp_dir"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Jdq8MT-NAhKj",
        "colab_type": "text"
      },
      "source": [
        "That's all for the installation! 🎉 We can now double-check that the installation was successful by listing the files in the CoreNLP directory. You should be able to see a number of `.jar` files by running the following command:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "K5eIOaJp_tuo",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Examine the CoreNLP installation folder to make sure the installation is successful\n",
        "!ls $CORENLP_HOME"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S0xb9BHt__gx",
        "colab_type": "text"
      },
      "source": [
        "**Note 1**:\n",
        "If you want to use the interface in a terminal (instead of a Colab notebook), you can set the `CORENLP_HOME` environment variable with:\n",
        "\n",
        "```bash\n",
        "export CORENLP_HOME=path_to_corenlp_dir\n",
        "```\n",
        "\n",
        "Here we instead set this variable with the Python `os` library, simply because the `export` command is not well supported in Colab notebooks.\n",
        "\n",
        "\n",
        "**Note 2**:\n",
        "The `stanza.install_corenlp()` function is only available since Stanza v1.1.1. If you are using an earlier version of Stanza, please check out our [manual installation page](https://stanfordnlp.github.io/stanza/client_setup.html#manual-installation) for how to install CoreNLP on your computer.\n",
        "\n",
        "**Note 3**:\n",
        "Besides the installation function, we also provide a `stanza.download_corenlp_models()` function to help you download additional CoreNLP models for different languages that are not shipped with the default installation. Check out our [automated installation page](https://stanfordnlp.github.io/stanza/client_setup.html#automated-installation) for more information on how to use it."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xJsuO6D8D05q",
        "colab_type": "text"
      },
      "source": [
        "## 2. Annotating Text with the CoreNLP Interface"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dZNHxXHkH1K2",
        "colab_type": "text"
      },
      "source": [
        "### Constructing CoreNLPClient\n",
        "\n",
        "At a high level, the CoreNLP Python interface works by first starting a background Java CoreNLP server process, and then initializing a client instance in Python, which passes text to the background server process and accepts the returned annotation results.\n",
        "\n",
        "We wrap these functionalities in a `CoreNLPClient` class. Therefore, we need to start by importing this class from Stanza."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LS4OKnqJ8wui",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Import client module\n",
        "from stanza.server import CoreNLPClient"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WP4Dz6PIJHeL",
        "colab_type": "text"
      },
      "source": [
        "After the import is done, we can construct a `CoreNLPClient` instance. The constructor method takes a Python list of annotator names as an argument. Here let's explore some basic annotators including tokenization, sentence splitting, part-of-speech tagging, lemmatization and named entity recognition (NER).\n",
        "\n",
        "Additionally, the client constructor accepts a `memory` argument, which specifies how much memory will be allocated to the background Java process. An `endpoint` option can be used to specify the port used for communication between the server and the client. The default port is 9000. However, since this port is already occupied by a system process in Colab, we'll manually set it to 9001 in the following example.\n",
        "\n",
        "Also, here we manually set `be_quiet=True` to avoid an IO issue in the Colab notebook. You should be able to use `be_quiet=False` on your own computer, which will print detailed logging information from CoreNLP during usage.\n",
        "\n",
        "For more options in constructing the clients, please refer to the [CoreNLP Client Options List](https://stanfordnlp.github.io/stanza/corenlp_client.html#corenlp-client-options)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mbOBugvd9JaM",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Construct a CoreNLPClient with some basic annotators, a memory allocation of 4GB, and port number 9001\n",
        "client = CoreNLPClient(\n",
        "    annotators=['tokenize', 'ssplit', 'pos', 'lemma', 'ner'],\n",
        "    memory='4G',\n",
        "    endpoint='http://localhost:9001',\n",
        "    be_quiet=True)\n",
        "print(client)\n",
        "\n",
        "# Start the background server and wait for some time\n",
        "# Note that in practice this is totally optional, as by default the server will be started when the first annotation is performed\n",
        "client.start()\n",
        "import time; time.sleep(10)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kgTiVjNydmIW",
        "colab_type": "text"
      },
      "source": [
        "After the above code block finishes executing, if you print the background processes, you should be able to find the Java CoreNLP server running."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "spZrJ-oFdkdF",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Print background processes and look for java\n",
        "# You should be able to see a StanfordCoreNLPServer java process running in the background\n",
        "!ps -o pid,cmd | grep java"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KxJeJ0D2LoOs",
        "colab_type": "text"
      },
      "source": [
        "### Annotating Text\n",
        "\n",
        "Annotating a piece of text is as simple as passing the text into the `annotate` function of the client object. After the annotation is complete, a `Document` object will be returned with all annotations.\n",
        "\n",
        "Note that although in general annotations are very fast, the first annotation might take a while to complete in the notebook. Please be patient."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "s194RnNg5z95",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Annotate some text\n",
        "text = \"Albert Einstein was a German-born theoretical physicist. He developed the theory of relativity.\"\n",
        "document = client.annotate(text)\n",
        "print(type(document))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "semmA3e0TcM1",
        "colab_type": "text"
      },
      "source": [
        "## 3. Accessing Annotations\n",
        "\n",
        "Annotations can be accessed from the returned `Document` object.\n",
        "\n",
        "A `Document` contains a list of `Sentence`s, which contain a list of `Token`s. Here let's first explore the annotations stored in all tokens."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lIO4B5d6Rk4I",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Iterate over all tokens in all sentences, and print out the word, lemma, pos and ner tags\n",
        "print(\"{:12s}\\t{:12s}\\t{:6s}\\t{}\".format(\"Word\", \"Lemma\", \"POS\", \"NER\"))\n",
        "\n",
        "for i, sent in enumerate(document.sentence):\n",
        "    print(\"[Sentence {}]\".format(i+1))\n",
        "    for t in sent.token:\n",
        "        print(\"{:12s}\\t{:12s}\\t{:6s}\\t{}\".format(t.word, t.lemma, t.pos, t.ner))\n",
        "    print(\"\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "msrJfvu8VV9m",
        "colab_type": "text"
      },
      "source": [
        "Alternatively, you can also browse the NER results by iterating over the entity mentions in the sentences. For example:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ezEjc9LeV2Xs",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Iterate over all detected entity mentions\n",
        "print(\"{:30s}\\t{}\".format(\"Mention\", \"Type\"))\n",
        "\n",
        "for sent in document.sentence:\n",
        "    for m in sent.mentions:\n",
        "        print(\"{:30s}\\t{}\".format(m.entityMentionText, m.entityType))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ueGzBZ3hWzkN",
        "colab_type": "text"
      },
      "source": [
        "To print all annotations a sentence, token or mention has, you can simply print the corresponding object."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4_S8o2BHXIed",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Print annotations of a token\n",
        "print(document.sentence[0].token[0])\n",
        "\n",
        "# Print annotations of a mention\n",
        "print(document.sentence[0].mentions[0])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Qp66wjZ10xia",
        "colab_type": "text"
      },
      "source": [
        "**Note**: Since the Stanza CoreNLP client interface simply ports the CoreNLP annotation results to native Python objects, for a comprehensive list of available annotators and how their annotation results can be accessed, you will need to visit the [Stanford CoreNLP website](https://stanfordnlp.github.io/CoreNLP/)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IPqzMK90X0w3",
        "colab_type": "text"
      },
      "source": [
        "## 4. Shutting Down the CoreNLP Server\n",
        "\n",
        "To shut down the background CoreNLP server process, simply call the `stop` function of the client. Note that once a server is shut down, you'll have to restart it with the `start()` function before any annotation can be requested."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xrJq8lZ3Nw7b",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Shut down the background CoreNLP server\n",
        "client.stop()\n",
        "\n",
        "time.sleep(10)\n",
        "!ps -o pid,cmd | grep java"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "23Vwa_ifYfF7",
        "colab_type": "text"
      },
      "source": [
        "### More Information\n",
        "\n",
        "For more information on how to use the `CoreNLPClient`, please go to the [CoreNLPClient documentation page](https://stanfordnlp.github.io/stanza/corenlp_client.html)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YUrVT6kA_Bzx",
        "colab_type": "text"
      },
      "source": [
        "## 5. Simplifying Client Usage with the Python `with` statement\n",
        "\n",
        "In the above demo, we explicitly called the `client.start()` and `client.stop()` functions to start and stop a client-server connection. However, doing this in practice is usually suboptimal, since you may forget to call the `stop()` function at the end, leaving an unused server process occupying your machine's memory.\n",
        "\n",
        "A simple solution is to use the client interface with the [Python `with` statement](https://docs.python.org/3/reference/compound_stmts.html#the-with-statement). The `with` statement provides an elegant way to automatically start and stop the server process in your Python program, without you having to worry about it. The following code snippet demonstrates how to establish a client, annotate an example text and then stop the server with a simple `with` statement. Note that we **always recommend** using the `with` statement when working with the Stanza CoreNLP client interface."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "H0ct2-R4AvJh",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "print(\"Starting a server with the Python \\\"with\\\" statement...\")\n",
        "with CoreNLPClient(annotators=['tokenize', 'ssplit', 'pos', 'lemma', 'ner'],\n",
        "                   memory='4G', endpoint='http://localhost:9001', be_quiet=True) as client:\n",
        "    text = \"Albert Einstein was a German-born theoretical physicist.\"\n",
        "    document = client.annotate(text)\n",
        "\n",
        "    print(\"{:30s}\\t{}\".format(\"Mention\", \"Type\"))\n",
        "    for sent in document.sentence:\n",
        "        for m in sent.mentions:\n",
        "            print(\"{:30s}\\t{}\".format(m.entityMentionText, m.entityType))\n",
        "\n",
        "print(\"\\nThe server should be stopped upon exit from the \\\"with\\\" statement.\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W435Lwc4YqKb",
        "colab_type": "text"
      },
      "source": [
        "## 6. Other Resources\n",
        "\n",
        "- [Stanza Homepage](https://stanfordnlp.github.io/stanza/)\n",
        "- [FAQs](https://stanfordnlp.github.io/stanza/faq.html)\n",
        "- [GitHub Repo](https://github.com/stanfordnlp/stanza)\n",
        "- [Reporting Issues](https://github.com/stanfordnlp/stanza/issues)\n"
      ]
    }
  ]
}
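The tutorial's Note 3 mentions `stanza.download_corenlp_models()` without showing it in action. Below is a minimal sketch (not part of the diff) of how the add-on models can be combined with a language-specific `properties` setting; the `'french'` model name and `'4.5.3'` version string are illustrative assumptions, so check the client setup page for current values:

```python
# Sketch: annotating non-English text with add-on CoreNLP models.
# Assumes Stanza >= 1.1.1; model name and version below are examples only.
import os
import stanza
from stanza.server import CoreNLPClient

corenlp_dir = './corenlp'
stanza.install_corenlp(dir=corenlp_dir)
# Fetch the French models into the same CoreNLP directory
stanza.download_corenlp_models(model='french', version='4.5.3', dir=corenlp_dir)
os.environ["CORENLP_HOME"] = corenlp_dir

# 'properties' selects the language-specific CoreNLP properties file
with CoreNLPClient(properties='french',
                   annotators=['tokenize', 'ssplit', 'pos'],
                   endpoint='http://localhost:9001', be_quiet=True) as client:
    doc = client.annotate("Albert Einstein était un physicien théoricien.")
    for sent in doc.sentence:
        print([(t.word, t.pos) for t in sent.token])
```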
stanza/demo/en_test.conllu.txt
ADDED
@@ -0,0 +1,79 @@
# newdoc id = weblog-blogspot.com_zentelligence_20040423000200_ENG_20040423_000200
# sent_id = weblog-blogspot.com_zentelligence_20040423000200_ENG_20040423_000200-0001
# newpar id = weblog-blogspot.com_zentelligence_20040423000200_ENG_20040423_000200-p0001
# text = What if Google Morphed Into GoogleOS?
1	What	what	PRON	WP	PronType=Int	0	root	0:root	_
2	if	if	SCONJ	IN	_	4	mark	4:mark	_
3	Google	Google	PROPN	NNP	Number=Sing	4	nsubj	4:nsubj	_
4	Morphed	morph	VERB	VBD	Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin	1	advcl	1:advcl:if	_
5	Into	into	ADP	IN	_	6	case	6:case	_
6	GoogleOS	GoogleOS	PROPN	NNP	Number=Sing	4	obl	4:obl:into	SpaceAfter=No
7	?	?	PUNCT	.	_	4	punct	4:punct	_

# sent_id = weblog-blogspot.com_zentelligence_20040423000200_ENG_20040423_000200-0002
# text = What if Google expanded on its search-engine (and now e-mail) wares into a full-fledged operating system?
1	What	what	PRON	WP	PronType=Int	0	root	0:root	_
2	if	if	SCONJ	IN	_	4	mark	4:mark	_
3	Google	Google	PROPN	NNP	Number=Sing	4	nsubj	4:nsubj	_
4	expanded	expand	VERB	VBD	Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin	1	advcl	1:advcl:if	_
5	on	on	ADP	IN	_	15	case	15:case	_
6	its	its	PRON	PRP$	Gender=Neut|Number=Sing|Person=3|Poss=Yes|PronType=Prs	15	nmod:poss	15:nmod:poss	_
7	search	search	NOUN	NN	Number=Sing	9	compound	9:compound	SpaceAfter=No
8	-	-	PUNCT	HYPH	_	9	punct	9:punct	SpaceAfter=No
9	engine	engine	NOUN	NN	Number=Sing	15	compound	15:compound	_
10	(	(	PUNCT	-LRB-	_	9	punct	9:punct	SpaceAfter=No
11	and	and	CCONJ	CC	_	13	cc	13:cc	_
12	now	now	ADV	RB	_	13	advmod	13:advmod	_
13	e-mail	e-mail	NOUN	NN	Number=Sing	9	conj	9:conj:and|15:compound	SpaceAfter=No
14	)	)	PUNCT	-RRB-	_	15	punct	15:punct	_
15	wares	wares	NOUN	NNS	Number=Plur	4	obl	4:obl:on	_
16	into	into	ADP	IN	_	22	case	22:case	_
17	a	a	DET	DT	Definite=Ind|PronType=Art	22	det	22:det	_
18	full	full	ADV	RB	_	20	advmod	20:advmod	SpaceAfter=No
19	-	-	PUNCT	HYPH	_	20	punct	20:punct	SpaceAfter=No
20	fledged	fledged	ADJ	JJ	Degree=Pos	22	amod	22:amod	_
21	operating	operating	NOUN	NN	Number=Sing	22	compound	22:compound	_
22	system	system	NOUN	NN	Number=Sing	4	obl	4:obl:into	SpaceAfter=No
23	?	?	PUNCT	.	_	4	punct	4:punct	_

# sent_id = weblog-blogspot.com_zentelligence_20040423000200_ENG_20040423_000200-0003
# text = [via Microsoft Watch from Mary Jo Foley ]
1	[	[	PUNCT	-LRB-	_	4	punct	4:punct	SpaceAfter=No
2	via	via	ADP	IN	_	4	case	4:case	_
3	Microsoft	Microsoft	PROPN	NNP	Number=Sing	4	compound	4:compound	_
4	Watch	Watch	PROPN	NNP	Number=Sing	0	root	0:root	_
5	from	from	ADP	IN	_	6	case	6:case	_
6	Mary	Mary	PROPN	NNP	Number=Sing	4	nmod	4:nmod:from	_
7	Jo	Jo	PROPN	NNP	Number=Sing	6	flat	6:flat	_
8	Foley	Foley	PROPN	NNP	Number=Sing	6	flat	6:flat	_
9	]	]	PUNCT	-RRB-	_	4	punct	4:punct	_

# newdoc id = weblog-blogspot.com_marketview_20050511222700_ENG_20050511_222700
# sent_id = weblog-blogspot.com_marketview_20050511222700_ENG_20050511_222700-0001
# newpar id = weblog-blogspot.com_marketview_20050511222700_ENG_20050511_222700-p0001
# text = (And, by the way, is anybody else just a little nostalgic for the days when that was a good thing?)
1	(	(	PUNCT	-LRB-	_	14	punct	14:punct	SpaceAfter=No
2	And	and	CCONJ	CC	_	14	cc	14:cc	SpaceAfter=No
3	,	,	PUNCT	,	_	14	punct	14:punct	_
4	by	by	ADP	IN	_	6	case	6:case	_
5	the	the	DET	DT	Definite=Def|PronType=Art	6	det	6:det	_
6	way	way	NOUN	NN	Number=Sing	14	obl	14:obl:by	SpaceAfter=No
7	,	,	PUNCT	,	_	14	punct	14:punct	_
8	is	be	AUX	VBZ	Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin	14	cop	14:cop	_
9	anybody	anybody	PRON	NN	Number=Sing	14	nsubj	14:nsubj	_
10	else	else	ADJ	JJ	Degree=Pos	9	amod	9:amod	_
11	just	just	ADV	RB	_	13	advmod	13:advmod	_
12	a	a	DET	DT	Definite=Ind|PronType=Art	13	det	13:det	_
13	little	little	ADJ	JJ	Degree=Pos	14	obl:npmod	14:obl:npmod	_
14	nostalgic	nostalgic	NOUN	NN	Number=Sing	0	root	0:root	_
15	for	for	ADP	IN	_	17	case	17:case	_
16	the	the	DET	DT	Definite=Def|PronType=Art	17	det	17:det	_
17	days	day	NOUN	NNS	Number=Plur	14	nmod	14:nmod:for|23:obl:npmod	_
18	when	when	ADV	WRB	PronType=Rel	23	advmod	17:ref	_
19	that	that	PRON	DT	Number=Sing|PronType=Dem	23	nsubj	23:nsubj	_
20	was	be	AUX	VBD	Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin	23	cop	23:cop	_
21	a	a	DET	DT	Definite=Ind|PronType=Art	23	det	23:det	_
22	good	good	ADJ	JJ	Degree=Pos	23	amod	23:amod	_
23	thing	thing	NOUN	NN	Number=Sing	17	acl:relcl	17:acl:relcl	SpaceAfter=No
24	?	?	PUNCT	.	_	14	punct	14:punct	SpaceAfter=No
25	)	)	PUNCT	-RRB-	_	14	punct	14:punct	_
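The file above is test data in CoNLL-U format (tab-separated token rows with UD annotations). As a small sketch of how such a file can be read back into a Stanza `Document` with the library's own CoNLL utilities — the path below assumes the file sits where the diff places it:

```python
# Sketch: load the demo CoNLL-U file with Stanza's CoNLL reader
from stanza.utils.conll import CoNLL

doc = CoNLL.conll2doc("stanza/demo/en_test.conllu.txt")
print(doc.sentences[0].text)  # What if Google Morphed Into GoogleOS?
for word in doc.sentences[0].words:
    # id, surface form, universal POS, head index, dependency relation
    print(word.id, word.text, word.upos, word.head, word.deprel)
```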
stanza/demo/semgrex visualization.ipynb
ADDED
|
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "2787d5f5",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"import stanza\n",
|
| 11 |
+
"from stanza.server.semgrex import Semgrex\n",
|
| 12 |
+
"from stanza.models.common.constant import is_right_to_left\n",
|
| 13 |
+
"import spacy\n",
|
| 14 |
+
"from spacy import displacy\n",
|
| 15 |
+
"from spacy.tokens import Doc\n",
|
| 16 |
+
"from IPython.display import display, HTML\n",
|
| 17 |
+
"\n",
|
| 18 |
+
"\n",
|
| 19 |
+
"\"\"\"\n",
|
| 20 |
+
"IMPORTANT: For the code in this module to run, you must have corenlp and Java installed on your machine. Additionally,\n",
|
| 21 |
+
"set an environment variable CLASSPATH equal to the path of your corenlp directory.\n",
|
| 22 |
+
"\n",
|
| 23 |
+
"Example: CLASSPATH=C:\\\\Users\\\\Alex\\\\PycharmProjects\\\\pythonProject\\\\stanford-corenlp-4.5.0\\\\stanford-corenlp-4.5.0\\\\*\n",
|
| 24 |
+
"\"\"\"\n",
|
| 25 |
+
"\n",
|
| 26 |
+
"%env CLASSPATH=C:\\\\stanford-corenlp-4.5.2\\\\stanford-corenlp-4.5.2\\\\*\n",
|
| 27 |
+
"def get_sentences_html(doc, language):\n",
|
| 28 |
+
" \"\"\"\n",
|
| 29 |
+
" Returns a list of the HTML strings of the dependency visualizations of a given stanza doc object.\n",
|
| 30 |
+
"\n",
|
| 31 |
+
" The 'language' arg is the two-letter language code for the document to be processed.\n",
|
| 32 |
+
"\n",
|
| 33 |
+
" First converts the stanza doc object to a spacy doc object and uses displacy to generate an HTML\n",
|
| 34 |
+
" string for each sentence of the doc object.\n",
|
| 35 |
+
" \"\"\"\n",
|
| 36 |
+
" html_strings = []\n",
|
| 37 |
+
"\n",
|
| 38 |
+
" # blank model - we don't use any of the model features, just the visualization\n",
|
| 39 |
+
" nlp = spacy.blank(\"en\")\n",
|
| 40 |
+
" sentences_to_visualize = []\n",
|
| 41 |
+
" for sentence in doc.sentences:\n",
|
| 42 |
+
" words, lemmas, heads, deps, tags = [], [], [], [], []\n",
|
| 43 |
+
" if is_right_to_left(language): # order of words displayed is reversed, dependency arcs remain intact\n",
|
| 44 |
+
" sent_len = len(sentence.words)\n",
|
| 45 |
+
" for word in reversed(sentence.words):\n",
|
| 46 |
+
" words.append(word.text)\n",
|
| 47 |
+
" lemmas.append(word.lemma)\n",
|
| 48 |
+
" deps.append(word.deprel)\n",
|
| 49 |
+
" tags.append(word.upos)\n",
|
| 50 |
+
" if word.head == 0: # spaCy head indexes are formatted differently than that of Stanza\n",
|
| 51 |
+
" heads.append(sent_len - word.id)\n",
|
| 52 |
+
" else:\n",
|
| 53 |
+
" heads.append(sent_len - word.head)\n",
|
| 54 |
+
" else: # left to right rendering\n",
|
| 55 |
+
" for word in sentence.words:\n",
|
| 56 |
+
" words.append(word.text)\n",
|
| 57 |
+
" lemmas.append(word.lemma)\n",
|
| 58 |
+
" deps.append(word.deprel)\n",
|
| 59 |
+
" tags.append(word.upos)\n",
|
| 60 |
+
" if word.head == 0:\n",
|
| 61 |
+
" heads.append(word.id - 1)\n",
|
| 62 |
+
" else:\n",
|
| 63 |
+
" heads.append(word.head - 1)\n",
|
| 64 |
+
" document_result = Doc(nlp.vocab, words=words, lemmas=lemmas, heads=heads, deps=deps, pos=tags)\n",
|
| 65 |
+
" sentences_to_visualize.append(document_result)\n",
|
| 66 |
+
"\n",
|
| 67 |
+
" for line in sentences_to_visualize: # render all sentences through displaCy\n",
|
| 68 |
+
" html_strings.append(displacy.render(line, style=\"dep\",\n",
|
| 69 |
+
" options={\"compact\": True, \"word_spacing\": 30, \"distance\": 100,\n",
|
| 70 |
+
" \"arrow_spacing\": 20}, jupyter=False))\n",
|
| 71 |
+
" return html_strings\n",
|
| 72 |
+
"\n",
|
| 73 |
+
"\n",
|
| 74 |
+
"def find_nth(haystack, needle, n):\n",
|
| 75 |
+
" \"\"\"\n",
|
| 76 |
+
" Returns the starting index of the nth occurrence of the substring 'needle' in the string 'haystack'.\n",
|
| 77 |
+
" \"\"\"\n",
|
| 78 |
+
" start = haystack.find(needle)\n",
|
| 79 |
+
" while start >= 0 and n > 1:\n",
|
| 80 |
+
" start = haystack.find(needle, start + len(needle))\n",
|
| 81 |
+
" n -= 1\n",
|
| 82 |
+
" return start\n",
|
| 83 |
+
"\n",
|
| 84 |
+
"\n",
|
| 85 |
+
"def round_base(num, base=10):\n",
|
| 86 |
+
" \"\"\"\n",
|
| 87 |
+
" Rounding a number to its nearest multiple of the base. round_base(49.2, base=50) = 50.\n",
|
| 88 |
+
" \"\"\"\n",
|
| 89 |
+
" return base * round(num/base)\n",
|
| 90 |
+
"\n",
|
| 91 |
+
"\n",
|
| 92 |
+
"def process_sentence_html(orig_html, semgrex_sentence):\n",
|
| 93 |
+
" \"\"\"\n",
|
| 94 |
+
" Takes a semgrex sentence object and modifies the HTML of the original sentence's deprel visualization,\n",
|
| 95 |
+
" highlighting words involved in the search queries and adding the label of the word inside of the semgrex match.\n",
|
| 96 |
+
"\n",
|
| 97 |
+
" Returns the modified html string of the sentence's deprel visualization.\n",
|
| 98 |
+
" \"\"\"\n",
|
| 99 |
+
" tracker = {} # keep track of which words have multiple labels\n",
|
| 100 |
+
" DEFAULT_TSPAN_COUNT = 2 # the original displacy html assigns two <tspan> objects per <text> object\n",
|
| 101 |
+
" CLOSING_TSPAN_LEN = 8 # </tspan> is 8 chars long\n",
|
| 102 |
+
" colors = ['#4477AA', '#66CCEE', '#228833', '#CCBB44', '#EE6677', '#AA3377', '#BBBBBB']\n",
|
| 103 |
+
" css_bolded_class = \"<style> .bolded{font-weight: bold;} </style>\\n\"\n",
|
| 104 |
+
" found_index = orig_html.find(\"\\n\") # returns index where the opening <svg> ends\n",
|
| 105 |
+
" # insert the new style class into html string\n",
|
| 106 |
+
" orig_html = orig_html[: found_index + 1] + css_bolded_class + orig_html[found_index + 1:]\n",
|
| 107 |
+
"\n",
|
| 108 |
+
" # Add color to words in the match, bold words in the match\n",
|
| 109 |
+
" for query in semgrex_sentence.result:\n",
|
| 110 |
+
" for i, match in enumerate(query.match):\n",
|
| 111 |
+
" color = colors[i]\n",
|
| 112 |
+
" paired_dy = 2\n",
|
| 113 |
+
" for node in match.node:\n",
|
| 114 |
+
" name, match_index = node.name, node.matchIndex\n",
|
| 115 |
+
" # edit existing <tspan> to change color and bold the text\n",
|
| 116 |
+
" start = find_nth(orig_html, \"<text\", match_index) # finds start of svg <text> of interest\n",
|
| 117 |
+
" if match_index not in tracker: # if we've already bolded and colored, keep the first color\n",
|
| 118 |
+
" tspan_start = orig_html.find(\"<tspan\",\n",
|
| 119 |
+
" start) # finds start of the first svg <tspan> inside of the <text>\n",
|
| 120 |
+
" tspan_end = orig_html.find(\"</tspan>\", start) # finds start of the end of the above <tspan>\n",
|
| 121 |
+
" tspan_substr = orig_html[tspan_start: tspan_end + CLOSING_TSPAN_LEN + 1] + \"\\n\"\n",
|
| 122 |
+
" # color words in the hit and bold words in the hit\n",
|
| 123 |
+
" edited_tspan = tspan_substr.replace('class=\"displacy-word\"', 'class=\"bolded\"').replace(\n",
|
| 124 |
+
" 'fill=\"currentColor\"', f'fill=\"{color}\"')\n",
|
| 125 |
+
" # insert edited <tspan> object into html string\n",
|
| 126 |
+
" orig_html = orig_html[: tspan_start] + edited_tspan + orig_html[tspan_end + CLOSING_TSPAN_LEN + 2:]\n",
|
| 127 |
+
" tracker[match_index] = DEFAULT_TSPAN_COUNT\n",
|
| 128 |
+
"\n",
|
| 129 |
+
" # next, we have to insert the new <tspan> object for the label\n",
|
| 130 |
+
" # Copy old <tspan> to copy formatting when creating new <tspan> later\n",
|
| 131 |
+
" prev_tspan_start = find_nth(orig_html[start:], \"<tspan\",\n",
|
| 132 |
+
" tracker[match_index] - 1) + start # find the previous <tspan> start index\n",
|
| 133 |
+
" prev_tspan_end = find_nth(orig_html[start:], \"</tspan>\",\n",
|
| 134 |
+
" tracker[match_index] - 1) + start # find the prev </tspan> start index\n",
|
| 135 |
+
" prev_tspan = orig_html[prev_tspan_start: prev_tspan_end + CLOSING_TSPAN_LEN + 1]\n",
|
| 136 |
+
"\n",
|
| 137 |
+
" # Find spot to insert new tspan\n",
|
| 138 |
+
" closing_tspan_start = find_nth(orig_html[start:], \"</tspan>\", tracker[match_index]) + start\n",
|
| 139 |
+
" up_to_new_tspan = orig_html[: closing_tspan_start + CLOSING_TSPAN_LEN + 1]\n",
|
| 140 |
+
" rest_need_add_newline = orig_html[closing_tspan_start + CLOSING_TSPAN_LEN + 1:]\n",
|
| 141 |
+
"\n",
|
| 142 |
+
" # Calculate proper x value in svg\n",
|
| 143 |
+
" x_value_start = prev_tspan.find('x=\"')\n",
|
| 144 |
+
" x_value_end = prev_tspan[x_value_start + 3:].find('\"') + 3 # 3 is the length of the 'x=\"' substring\n",
|
| 145 |
+
" x_value = prev_tspan[x_value_start + 3: x_value_end + x_value_start]\n",
|
| 146 |
+
"\n",
|
| 147 |
+
" # Calculate proper y value in svg\n",
|
| 148 |
+
" DEFAULT_DY_VAL, dy = 2, 2\n",
|
| 149 |
+
" if paired_dy != DEFAULT_DY_VAL and node == match.node[\n",
|
| 150 |
+
" 1]: # we're on the second node and need to adjust height to match the paired node\n",
|
| 151 |
+
" dy = paired_dy\n",
|
| 152 |
+
" if node == match.node[0]:\n",
|
| 153 |
+
" paired_node_level = 2\n",
|
| 154 |
+
" if match.node[1].matchIndex in tracker: # check if we need to adjust heights of labels\n",
|
| 155 |
+
" paired_node_level = tracker[match.node[1].matchIndex]\n",
|
| 156 |
+
" dif = tracker[match_index] - paired_node_level\n",
|
| 157 |
+
" if dif > 0: # current node has more labels\n",
|
| 158 |
+
" paired_dy = DEFAULT_DY_VAL * dif + 1\n",
|
| 159 |
+
" dy = DEFAULT_DY_VAL\n",
|
| 160 |
+
" else: # paired node has more labels, adjust this label down\n",
|
| 161 |
+
" dy = DEFAULT_DY_VAL * (abs(dif) + 1)\n",
|
| 162 |
+
" paired_dy = DEFAULT_DY_VAL\n",
|
| 163 |
+
"\n",
|
| 164 |
+
" # Insert new <tspan> object\n",
|
| 165 |
+
" new_tspan = f' <tspan class=\"displacy-word\" dy=\"{dy}em\" fill=\"{color}\" x={x_value}>{name[: 3].title()}.</tspan>\\n' # abbreviate label names to 3 chars\n",
|
| 166 |
+
" orig_html = up_to_new_tspan + new_tspan + rest_need_add_newline\n",
|
| 167 |
+
" tracker[match_index] += 1\n",
|
| 168 |
+
" return orig_html\n",
|
| 169 |
+
"\n",
|
| 170 |
+
"\n",
|
| 171 |
+
"def render_html_strings(edited_html_strings):\n",
|
| 172 |
+
" \"\"\"\n",
|
| 173 |
+
" Renders the HTML to make the edits visible\n",
|
| 174 |
+
" \"\"\"\n",
|
| 175 |
+
" for html_string in edited_html_strings:\n",
|
| 176 |
+
" display(HTML(html_string))\n",
|
| 177 |
+
"\n",
|
| 178 |
+
"\n",
|
| 179 |
+
"def visualize_search_doc(doc, semgrex_queries, lang_code, start_match=0, end_match=10):\n",
|
| 180 |
+
" \"\"\"\n",
|
| 181 |
+
" Visualizes the semgrex results of running semgrex search on a stanza doc object with the given list of\n",
|
| 182 |
+
" semgrex queries. Returns a list of the edited HTML strings from the doc. Each element in the list represents\n",
|
| 183 |
+
" the HTML to render one of the sentences in the document.\n",
|
| 184 |
+
"\n",
|
| 185 |
+
" 'lang_code' is the two-letter language abbreviation for the language that the stanza doc object is written in.\n",
|
| 186 |
+
"\n",
|
| 187 |
+
"\n",
|
| 188 |
+
" 'start_match' and 'end_match' determine which matches to visualize. Works similar to splices, so that\n",
|
| 189 |
+
" start_match=0 and end_match=10 will display the first 10 semgrex matches.\n",
|
| 190 |
+
" \"\"\"\n",
|
| 191 |
+
" matches_count = 0 # Limits number of visualizations\n",
|
| 192 |
+
" with Semgrex(classpath=\"$CLASSPATH\") as sem:\n",
|
| 193 |
+
" edited_html_strings = []\n",
|
| 194 |
+
" semgrex_results = sem.process(doc, *semgrex_queries)\n",
|
| 195 |
+
" # one html string for each sentence\n",
|
| 196 |
+
" unedited_html_strings = get_sentences_html(doc, lang_code)\n",
|
| 197 |
+
" for i in range(len(unedited_html_strings)):\n",
|
| 198 |
+
"\n",
|
| 199 |
+
" if matches_count >= end_match: # we've collected enough matches, stop early\n",
|
| 200 |
+
" break\n",
|
| 201 |
+
"\n",
|
| 202 |
+
" # check if sentence has matches, if not then do not visualize\n",
|
| 203 |
+
" has_none = True\n",
|
| 204 |
+
" for query in semgrex_results.result[i].result:\n",
|
| 205 |
+
" for match in query.match:\n",
|
| 206 |
+
" if match:\n",
|
| 207 |
+
" has_none = False\n",
|
| 208 |
+
"\n",
|
| 209 |
+
" # Process HTML if queries have matches\n",
|
| 210 |
+
" if not has_none:\n",
|
| 211 |
+
" if start_match <= matches_count < end_match:\n",
|
| 212 |
+
" edited_string = process_sentence_html(unedited_html_strings[i], semgrex_results.result[i])\n",
|
| 213 |
+
" edited_string = adjust_dep_arrows(edited_string)\n",
|
| 214 |
+
" edited_html_strings.append(edited_string)\n",
|
| 215 |
+
" matches_count += 1\n",
|
| 216 |
+
"\n",
|
| 217 |
+
" render_html_strings(edited_html_strings)\n",
|
| 218 |
+
" return edited_html_strings\n",
|
| 219 |
+
"\n",
|
| 220 |
+
"\n",
|
| 221 |
+
"def visualize_search_str(text, semgrex_queries, lang_code):\n",
|
| 222 |
+
" \"\"\"\n",
|
| 223 |
+
" Visualizes the deprel of the semgrex results from running semgrex search on a string with the given list of\n",
|
| 224 |
+
" semgrex queries. Returns a list of the edited HTML strings. Each element in the list represents\n",
|
| 225 |
+
" the HTML to render one of the sentences in the document.\n",
|
| 226 |
+
"\n",
|
| 227 |
+
" Internally, this function converts the string into a stanza doc object before processing the doc object.\n",
|
| 228 |
+
"\n",
|
| 229 |
+
" 'lang_code' is the two-letter language abbreviation for the language that the stanza doc object is written in.\n",
|
| 230 |
+
" \"\"\"\n",
|
| 231 |
+
" nlp = stanza.Pipeline(lang_code, processors=\"tokenize, pos, lemma, depparse\")\n",
|
| 232 |
+
" doc = nlp(text)\n",
|
| 233 |
+
" return visualize_search_doc(doc, semgrex_queries, lang_code)\n",
|
| 234 |
+
"\n",
|
| 235 |
+
"\n",
|
| 236 |
+
"def adjust_dep_arrows(raw_html):\n",
|
| 237 |
+
" \"\"\"\n",
|
| 238 |
+
" The default spaCy dependency visualization has misaligned arrows.\n",
|
| 239 |
+
" We fix arrows by aligning arrow ends and bodies to the word that they are directed to. If a word has an\n",
|
| 240 |
+
" arrowhead that is pointing not directly on the word's center, align the arrowhead to match the center of the word.\n",
|
| 241 |
+
"\n",
|
| 242 |
+
" returns the edited html with fixed arrow placement\n",
|
| 243 |
+
" \"\"\"\n",
|
| 244 |
+
" HTML_ARROW_BEGINNING = '<g class=\"displacy-arrow\">'\n",
|
| 245 |
+
" HTML_ARROW_ENDING = \"</g>\"\n",
|
| 246 |
+
" HTML_ARROW_ENDING_LEN = 6 # there are 2 newline chars after the arrow ending\n",
|
| 247 |
+
" arrows_start_idx = find_nth(haystack=raw_html, needle='<g class=\"displacy-arrow\">', n=1)\n",
|
| 248 |
+
" words_html, arrows_html = raw_html[: arrows_start_idx], raw_html[arrows_start_idx:] # separate html for words and arrows\n",
|
| 249 |
+
" final_html = words_html # continually concatenate to this after processing each arrow\n",
|
| 250 |
+
" arrow_number = 1 # which arrow we're editing (1-indexed)\n",
|
| 251 |
+
" start_idx, end_of_class_idx = find_nth(haystack=arrows_html, needle=HTML_ARROW_BEGINNING, n=arrow_number), find_nth(arrows_html, HTML_ARROW_ENDING, arrow_number)\n",
|
| 252 |
+
" while start_idx != -1: # edit every arrow\n",
|
| 253 |
+
" arrow_section = arrows_html[start_idx: end_of_class_idx + HTML_ARROW_ENDING_LEN] # slice a single svg arrow object\n",
|
| 254 |
+
" if arrow_section[-1] == \"<\": # this is the last arrow in the HTML, don't cut the splice early\n",
|
| 255 |
+
" arrow_section = arrows_html[start_idx:]\n",
|
| 256 |
+
" edited_arrow_section = edit_dep_arrow(arrow_section)\n",
|
| 257 |
+
"\n",
|
| 258 |
+
" final_html = final_html + edited_arrow_section # continually update html with new arrow html until done\n",
|
| 259 |
+
"\n",
|
| 260 |
+
" # Prepare for next iteration\n",
|
| 261 |
+
" arrow_number += 1\n",
|
| 262 |
+
" start_idx = find_nth(arrows_html, '<g class=\"displacy-arrow\">', n=arrow_number)\n",
|
| 263 |
+
" end_of_class_idx = find_nth(arrows_html, \"</g>\", arrow_number)\n",
|
| 264 |
+
" return final_html\n",
|
| 265 |
+
"\n",
|
| 266 |
+
"\n",
|
| 267 |
+
"def edit_dep_arrow(arrow_html):\n",
|
| 268 |
+
" \"\"\"\n",
|
| 269 |
+
" The formatting of a displacy arrow in svg is the following:\n",
|
| 270 |
+
" <g class=\"displacy-arrow\">\n",
|
| 271 |
+
" <path class=\"displacy-arc\" id=\"arrow-c628889ffbf343e3848193a08606f10a-0-0\" stroke-width=\"2px\" d=\"M70,352.0 C70,177.0 390.0,177.0 390.0,352.0\" fill=\"none\" stroke=\"currentColor\"/>\n",
|
| 272 |
+
" <text dy=\"1.25em\" style=\"font-size: 0.8em; letter-spacing: 1px\">\n",
|
| 273 |
+
" <textPath xlink:href=\"#arrow-c628889ffbf343e3848193a08606f10a-0-0\" class=\"displacy-label\" startOffset=\"50%\" side=\"left\" fill=\"currentColor\" text-anchor=\"middle\">csubj</textPath>\n",
|
| 274 |
+
" </text>\n",
|
| 275 |
+
" <path class=\"displacy-arrowhead\" d=\"M70,354.0 L62,342.0 78,342.0\" fill=\"currentColor\"/>\n",
|
| 276 |
+
" </g>\n",
|
| 277 |
+
"\n",
|
| 278 |
+
" We edit the 'd = ...' parts of the <path class ...> section to fix the arrow direction and length\n",
|
| 279 |
+
"\n",
|
| 280 |
+
" returns the arrow_html with distances fixed\n",
|
| 281 |
+
" \"\"\"\n",
|
| 282 |
+
" WORD_SPACING = 50 # words start at x=50 and are separated by 100s so their x values are multiples of 50\n",
|
| 283 |
+
" M_OFFSET = 4 # length of 'd=\"M' that we search for to extract the number from d=\"M70, for instance\n",
|
| 284 |
+
" ARROW_PIXEL_SIZE = 4\n",
|
| 285 |
+
" first_d_idx, second_d_idx = find_nth(arrow_html, 'd=\"M', 1), find_nth(arrow_html, 'd=\"M', 2) # find where d=\"M starts\n",
|
| 286 |
+
" first_d_cutoff, second_d_cutoff = arrow_html.find(\",\", first_d_idx), arrow_html.find(\",\", second_d_idx) # isolate the number after 'M' e.g. 'M70'\n",
|
| 287 |
+
" # gives svg x values of arrow body starting position and arrowhead position\n",
|
| 288 |
+
" arrow_position, arrowhead_position = float(arrow_html[first_d_idx + M_OFFSET: first_d_cutoff]), float(arrow_html[second_d_idx + M_OFFSET: second_d_cutoff])\n",
|
| 289 |
+
" # gives starting index of where 'fill=\"none\"' or 'fill=\"currentColor\"' begin, reference points to end the d= section\n",
|
| 290 |
+
" first_fill_start_idx, second_fill_start_idx = find_nth(arrow_html, \"fill\", n=1), find_nth(arrow_html, \"fill\", n=3)\n",
|
| 291 |
+
"\n",
|
| 292 |
+
" # isolate the d= ... section to edit\n",
|
| 293 |
+
" first_d, second_d = arrow_html[first_d_idx: first_fill_start_idx], arrow_html[second_d_idx: second_fill_start_idx]\n",
|
| 294 |
+
" first_d_split, second_d_split = first_d.split(\",\"), second_d.split(\",\")\n",
|
| 295 |
+
"\n",
|
| 296 |
+
" if arrow_position == arrowhead_position: # This arrow is incoming onto the word, center the arrow/head to word center\n",
|
| 297 |
+
" corrected_arrow_pos = corrected_arrowhead_pos = round_base(arrow_position, base=WORD_SPACING)\n",
|
| 298 |
+
"\n",
|
| 299 |
+
" # edit first_d -- arrow body\n",
|
| 300 |
+
" second_term = first_d_split[1].split(\" \")[0] + \" \" + str(corrected_arrow_pos)\n",
|
| 301 |
+
" first_d = 'd=\"M' + str(corrected_arrow_pos) + \",\" + second_term + \",\" + \",\".join(first_d_split[2:])\n",
|
+    "\n",
+    "        # edit second_d -- arrowhead\n",
+    "        second_term = second_d_split[1].split(\" \")[0] + \" L\" + str(corrected_arrowhead_pos - ARROW_PIXEL_SIZE)\n",
+    "        third_term = second_d_split[2].split(\" \")[0] + \" \" + str(corrected_arrowhead_pos + ARROW_PIXEL_SIZE)\n",
+    "        second_d = 'd=\"M' + str(corrected_arrowhead_pos) + \",\" + second_term + \",\" + third_term + \",\" + \",\".join(second_d_split[3:])\n",
+    "    else:  # This arrow is outgoing to another word, center the arrow/head to that word's center\n",
+    "        corrected_arrowhead_pos = round_base(arrowhead_position, base=WORD_SPACING)\n",
+    "\n",
+    "        # edit first_d -- arrow body\n",
+    "        third_term = first_d_split[2].split(\" \")[0] + \" \" + str(corrected_arrowhead_pos)\n",
+    "        fourth_term = first_d_split[3].split(\" \")[0] + \" \" + str(corrected_arrowhead_pos)\n",
+    "        terms = [first_d_split[0], first_d_split[1], third_term, fourth_term] + first_d_split[4:]\n",
+    "        first_d = \",\".join(terms)\n",
+    "\n",
+    "        # edit second_d -- arrow head\n",
+    "        first_term = f'd=\"M{corrected_arrowhead_pos}'\n",
+    "        second_term = second_d_split[1].split(\" \")[0] + \" L\" + str(corrected_arrowhead_pos - ARROW_PIXEL_SIZE)\n",
+    "        third_term = second_d_split[2].split(\" \")[0] + \" \" + str(corrected_arrowhead_pos + ARROW_PIXEL_SIZE)\n",
+    "        terms = [first_term, second_term, third_term] + second_d_split[3:]\n",
+    "        second_d = \",\".join(terms)\n",
+    "    # rebuild and return html\n",
+    "    return arrow_html[:first_d_idx] + first_d + \" \" + arrow_html[first_fill_start_idx:second_d_idx] + second_d + \" \" + arrow_html[second_fill_start_idx:]\n",
+    "\n",
+    "\n",
+    "def main():\n",
+    "    nlp = stanza.Pipeline(\"en\", processors=\"tokenize,pos,lemma,depparse\")\n",
+    "\n",
+    "    # doc = nlp(\"This a dummy sentence. Banning opal removed all artifact decks from the meta. I miss playing lantern. This is a dummy sentence.\")\n",
+    "    doc = nlp(\"Banning opal removed artifact decks from the meta. Banning tennis resulted in players banning people.\")\n",
+    "    # A single result .result[i].result[j] is a list of matches for sentence i on semgrex query j.\n",
+    "    queries = [\"{pos:NN}=object <obl {}=action\",\n",
+    "               \"{cpos:NOUN}=thing <obj {cpos:VERB}=action\"]\n",
+    "    res = visualize_search_doc(doc, queries, \"en\")\n",
+    "    print(res[0])  # see the first sentence's deprel visualization HTML\n",
+    "    print(\"---------------------------------------\")\n",
+    "    print(res[1])  # second sentence's deprel visualization HTML\n",
+    "    return\n",
+    "\n",
+    "\n",
+    "if __name__ == '__main__':\n",
+    "    main()\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
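
The string surgery above assumes displaCy-style SVG arc paths, where the arrow body is a cubic curve ('M x0,y C x0,y2 x1,y2 x1,y') and the arrowhead a small 'L' triangle, so each x coordinate sits after the last comma or space of its term. A standalone sketch of the same comma-split edit on a made-up 'd' string (the sample values are illustrative, not taken from the notebook's output):

    # Illustrative only: snap the x coordinates of a displaCy-like arrow body
    # to a corrected position, mirroring the comma-split editing done above.
    corrected_pos = 650
    first_d = 'd="M600,2 C600,-22.0 648.0,-22.0 648.0,2'   # hypothetical arrow body
    first_d_split = first_d.split(",")

    # terms 2 and 3 carry "<y> <x>" pairs; replace the trailing x with the corrected one
    third_term = first_d_split[2].split(" ")[0] + " " + str(corrected_pos)
    fourth_term = first_d_split[3].split(" ")[0] + " " + str(corrected_pos)
    first_d = ",".join([first_d_split[0], first_d_split[1], third_term, fourth_term] + first_d_split[4:])
    print(first_d)   # d="M600,2 C600,-22.0 650,-22.0 650,2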
stanza/images/stanza-logo.png ADDED
stanza/stanza/__init__.py ADDED
@@ -0,0 +1,27 @@
+from stanza.pipeline.core import DownloadMethod, Pipeline
+from stanza.pipeline.multilingual import MultilingualPipeline
+from stanza.models.common.doc import Document
+from stanza.resources.common import download
+from stanza.resources.installation import install_corenlp, download_corenlp_models
+from stanza._version import __version__, __resources_version__
+
+import logging
+logger = logging.getLogger('stanza')
+
+# if the client application hasn't set the log level, we set it
+# ourselves to INFO
+if logger.level == 0:
+    logger.setLevel(logging.INFO)
+
+log_handler = logging.StreamHandler()
+log_formatter = logging.Formatter(fmt="%(asctime)s %(levelname)s: %(message)s",
+                                  datefmt='%Y-%m-%d %H:%M:%S')
+log_handler.setFormatter(log_formatter)
+
+# also, if the client hasn't added any handlers for this logger
+# (or a default handler), we add a handler of our own
+#
+# client can later do
+#   logger.removeHandler(stanza.log_handler)
+if not logger.hasHandlers():
+    logger.addHandler(log_handler)
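
The handler comment above is the supported override hook; a minimal sketch of a client replacing stanza's default logging setup (the format and level chosen here are illustrative):

    import logging
    import stanza

    # Drop the stream handler stanza installed above, then configure logging
    # however the client prefers.
    stanza.logger.removeHandler(stanza.log_handler)
    logging.basicConfig(format="%(levelname)s %(name)s: %(message)s", level=logging.WARNING)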
stanza/stanza/models/charlm.py ADDED
@@ -0,0 +1,357 @@
+"""
+Entry point for training and evaluating a character-level neural language model.
+"""
+
+import argparse
+from copy import copy
+import logging
+import lzma
+import math
+import os
+import random
+import time
+from types import GeneratorType
+import numpy as np
+import torch
+
+from stanza.models.common.char_model import build_charlm_vocab, CharacterLanguageModel, CharacterLanguageModelTrainer
+from stanza.models.common.vocab import CharVocab
+from stanza.models.common import utils
+from stanza.models import _training_logging
+
+logger = logging.getLogger('stanza')
+
+def repackage_hidden(h):
+    """Wraps hidden states in new Tensors,
+    to detach them from their history."""
+    if isinstance(h, torch.Tensor):
+        return h.detach()
+    else:
+        return tuple(repackage_hidden(v) for v in h)
+
+def batchify(data, bsz, device):
+    # Work out how cleanly we can divide the dataset into bsz parts.
+    nbatch = data.size(0) // bsz
+    # Trim off any extra elements that wouldn't cleanly fit (remainders).
+    data = data.narrow(0, 0, nbatch * bsz)
+    # Evenly divide the data across the bsz batches.
+    data = data.view(bsz, -1) # batch_first is True
+    data = data.to(device)
+    return data
+
+def get_batch(source, i, seq_len):
+    seq_len = min(seq_len, source.size(1) - 1 - i)
+    data = source[:, i:i+seq_len]
+    target = source[:, i+1:i+1+seq_len].reshape(-1)
+    return data, target
+
+def load_file(filename, vocab, direction):
+    with utils.open_read_text(filename) as fin:
+        data = fin.read()
+
+    idx = vocab['char'].map(data)
+    if direction == 'backward': idx = idx[::-1]
+    return torch.tensor(idx)
+
+def load_data(path, vocab, direction):
+    if os.path.isdir(path):
+        filenames = sorted(os.listdir(path))
+        for filename in filenames:
+            logger.info('Loading data from {}'.format(filename))
+            data = load_file(os.path.join(path, filename), vocab, direction)
+            yield data
+    else:
+        data = load_file(path, vocab, direction)
+        yield data
+
+def build_argparse():
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--train_file', type=str, help="Input plaintext file")
+    parser.add_argument('--train_dir', type=str, help="If non-empty, load from directory with multiple training files")
+    parser.add_argument('--eval_file', type=str, help="Input plaintext file for the dev/test set")
+    parser.add_argument('--shorthand', type=str, help="UD treebank shorthand")
+
+    parser.add_argument('--mode', default='train', choices=['train', 'predict'])
+    parser.add_argument('--direction', default='forward', choices=['forward', 'backward'], help="Forward or backward language model")
+    parser.add_argument('--forward', action='store_const', dest='direction', const='forward', help="Train a forward language model")
+    parser.add_argument('--backward', action='store_const', dest='direction', const='backward', help="Train a backward language model")
+
+    parser.add_argument('--char_emb_dim', type=int, default=100, help="Dimension of unit embeddings")
+    parser.add_argument('--char_hidden_dim', type=int, default=1024, help="Dimension of hidden units")
+    parser.add_argument('--char_num_layers', type=int, default=1, help="Layers of RNN in the language model")
+    parser.add_argument('--char_dropout', type=float, default=0.05, help="Dropout probability")
+    parser.add_argument('--char_unit_dropout', type=float, default=1e-5, help="Randomly set an input char to UNK during training")
+    parser.add_argument('--char_rec_dropout', type=float, default=0.0, help="Recurrent dropout probability")
+
+    parser.add_argument('--batch_size', type=int, default=100, help="Batch size to use")
+    parser.add_argument('--bptt_size', type=int, default=250, help="Sequence length to consider at a time")
+    parser.add_argument('--epochs', type=int, default=50, help="Total epochs to train the model for")
+    parser.add_argument('--max_grad_norm', type=float, default=0.25, help="Maximum gradient norm to clip to")
+    parser.add_argument('--lr0', type=float, default=5, help="Initial learning rate")
+    parser.add_argument('--anneal', type=float, default=0.25, help="Anneal the learning rate by this amount when dev performance deteriorates")
+    parser.add_argument('--patience', type=int, default=1, help="Patience for annealing the learning rate")
+    parser.add_argument('--weight_decay', type=float, default=0.0, help="Weight decay")
+    parser.add_argument('--momentum', type=float, default=0.0, help='Momentum for SGD.')
+    parser.add_argument('--cutoff', type=int, default=1000, help="Frequency cutoff for char vocab. By default we assume a very large corpus.")
+
+    parser.add_argument('--report_steps', type=int, default=50, help="Update step interval to report loss")
+    parser.add_argument('--eval_steps', type=int, default=100000, help="Update step interval to run eval on dev; set to -1 to eval after each epoch")
+    parser.add_argument('--save_name', type=str, default=None, help="File name to save the model")
+    parser.add_argument('--vocab_save_name', type=str, default=None, help="File name to save the vocab")
+    parser.add_argument('--checkpoint_save_name', type=str, default=None, help="File name to save the most recent checkpoint")
+    parser.add_argument('--no_checkpoint', dest='checkpoint', action='store_false', help="Don't save checkpoints")
+    parser.add_argument('--save_dir', type=str, default='saved_models/charlm', help="Directory to save models in")
+    parser.add_argument('--summary', action='store_true', help='Use summary writer to record progress.')
+    utils.add_device_args(parser)
+    parser.add_argument('--seed', type=int, default=1234)
+
+    parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name')
+    parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. Will default to the dataset short name')
+    return parser
+
+def build_model_filename(args):
+    if args['save_name']:
+        save_name = args['save_name']
+    else:
+        save_name = '{}_{}_charlm.pt'.format(args['shorthand'], args['direction'])
+    model_file = os.path.join(args['save_dir'], save_name)
+    return model_file
+
+def parse_args(args=None):
+    parser = build_argparse()
+
+    args = parser.parse_args(args=args)
+
+    if args.wandb_name:
+        args.wandb = True
+
+    args = vars(args)
+    return args
+
+def main(args=None):
+    args = parse_args(args=args)
+
+    utils.set_random_seed(args['seed'])
+
+    logger.info("Running {} character-level language model in {} mode".format(args['direction'], args['mode']))
+
+    utils.ensure_dir(args['save_dir'])
+
+    if args['mode'] == 'train':
+        train(args)
+    else:
+        evaluate(args)
+
+def evaluate_epoch(args, vocab, data, model, criterion):
+    """
+    Run an evaluation over the entire dataset.
+    """
+    model.eval()
+    device = next(model.parameters()).device
+    hidden = None
+    total_loss = 0
+    if isinstance(data, GeneratorType):
+        data = list(data)
+        assert len(data) == 1, 'Only support single dev/test file'
+        data = data[0]
+    batches = batchify(data, args['batch_size'], device)
+    with torch.no_grad():
+        for i in range(0, batches.size(1) - 1, args['bptt_size']):
+            data, target = get_batch(batches, i, args['bptt_size'])
+            lens = [data.size(1) for i in range(data.size(0))]
+
+            output, hidden, decoded = model.forward(data, lens, hidden)
+            loss = criterion(decoded.view(-1, len(vocab['char'])), target)
+
+            hidden = repackage_hidden(hidden)
+            total_loss += data.size(1) * loss.data.item()
+    return total_loss / batches.size(1)
+
+def evaluate_and_save(args, vocab, data, trainer, best_loss, model_file, checkpoint_file, writer=None):
+    """
+    Run an evaluation over the entire dataset, print progress and save the model if necessary.
+    """
+    start_time = time.time()
+    loss = evaluate_epoch(args, vocab, data, trainer.model, trainer.criterion)
+    ppl = math.exp(loss)
+    elapsed = int(time.time() - start_time)
+    # TODO: step the scheduler less often when the eval frequency is higher
+    previous_lr = get_current_lr(trainer, args)
+    trainer.scheduler.step(loss)
+    current_lr = get_current_lr(trainer, args)
+    if previous_lr != current_lr:
+        logger.info("Updating learning rate to %f", current_lr)
+    logger.info(
+        "| eval checkpoint @ global step {:10d} | time elapsed {:6d}s | loss {:5.2f} | ppl {:8.2f}".format(
+            trainer.global_step,
+            elapsed,
+            loss,
+            ppl,
+        )
+    )
+    if best_loss is None or loss < best_loss:
+        best_loss = loss
+        trainer.save(model_file, full=False)
+        logger.info('new best model saved at step {:10d}'.format(trainer.global_step))
+    if writer:
+        writer.add_scalar('dev_loss', loss, global_step=trainer.global_step)
+        writer.add_scalar('dev_ppl', ppl, global_step=trainer.global_step)
+    if checkpoint_file:
+        trainer.save(checkpoint_file, full=True)
+        logger.info('new checkpoint saved at step {:10d}'.format(trainer.global_step))
+
+    return loss, ppl, best_loss
+
+def get_current_lr(trainer, args):
+    return trainer.scheduler.state_dict().get('_last_lr', [args['lr0']])[0]
+
+def load_char_vocab(vocab_file):
+    return {'char': CharVocab.load_state_dict(torch.load(vocab_file, lambda storage, loc: storage, weights_only=True))}
+
+def train(args):
+    utils.log_training_args(args, logger)
+    model_file = build_model_filename(args)
+
+    vocab_file = args['save_dir'] + '/' + args['vocab_save_name'] if args['vocab_save_name'] is not None \
+        else '{}/{}_vocab.pt'.format(args['save_dir'], args['shorthand'])
+
+    if args['checkpoint']:
+        checkpoint_file = utils.checkpoint_name(args['save_dir'], model_file, args['checkpoint_save_name'])
+    else:
+        checkpoint_file = None
+
+    if os.path.exists(vocab_file):
+        logger.info('Loading existing vocab file')
+        vocab = load_char_vocab(vocab_file)
+    else:
+        logger.info('Building and saving vocab')
+        vocab = {'char': build_charlm_vocab(args['train_file'] if args['train_dir'] is None else args['train_dir'], cutoff=args['cutoff'])}
+        torch.save(vocab['char'].state_dict(), vocab_file)
+    logger.info("Training model with vocab size: {}".format(len(vocab['char'])))
+
+    if checkpoint_file and os.path.exists(checkpoint_file):
+        logger.info('Loading existing checkpoint: %s' % checkpoint_file)
+        trainer = CharacterLanguageModelTrainer.load(args, checkpoint_file, finetune=True)
+    else:
+        trainer = CharacterLanguageModelTrainer.from_new_model(args, vocab)
+
+    writer = None
+    if args['summary']:
+        from torch.utils.tensorboard import SummaryWriter
+        summary_dir = '{}/{}_summary'.format(args['save_dir'], args['save_name']) if args['save_name'] is not None \
+            else '{}/{}_{}_charlm_summary'.format(args['save_dir'], args['shorthand'], args['direction'])
+        writer = SummaryWriter(log_dir=summary_dir)
+
+    # evaluate model within epoch if eval_interval is set
+    eval_within_epoch = False
+    if args['eval_steps'] > 0:
+        eval_within_epoch = True
+
+    if args['wandb']:
+        import wandb
+        wandb_name = args['wandb_name'] if args['wandb_name'] else '%s_%s_charlm' % (args['shorthand'], args['direction'])
+        wandb.init(name=wandb_name, config=args)
+        wandb.run.define_metric('best_loss', summary='min')
+        wandb.run.define_metric('ppl', summary='min')
+
+    device = next(trainer.model.parameters()).device
+
+    best_loss = None
+    start_epoch = trainer.epoch # will default to 1 for a new trainer
+    for trainer.epoch in range(start_epoch, args['epochs']+1):
+        # load train data from train_dir if not empty, otherwise load from file
+        if args['train_dir'] is not None:
+            train_path = args['train_dir']
+        else:
+            train_path = args['train_file']
+        train_data = load_data(train_path, vocab, args['direction'])
+        dev_data = load_file(args['eval_file'], vocab, args['direction']) # dev must be a single file
+
+        # run over entire training set
+        for data_chunk in train_data:
+            batches = batchify(data_chunk, args['batch_size'], device)
+            hidden = None
+            total_loss = 0.0
+            total_batches = math.ceil((batches.size(1) - 1) / args['bptt_size'])
+            iteration, i = 0, 0
+            # over the data chunk
+            while i < batches.size(1) - 1 - 1:
+                trainer.model.train()
+                trainer.global_step += 1
+                start_time = time.time()
+                bptt = args['bptt_size'] if np.random.random() < 0.95 else args['bptt_size'] / 2.
+                # prevent excessively small or negative sequence lengths
+                seq_len = max(5, int(np.random.normal(bptt, 5)))
+                # prevent very large sequence length, must be <= 1.2 x bptt
+                seq_len = min(seq_len, int(args['bptt_size'] * 1.2))
+                data, target = get_batch(batches, i, seq_len)
+                lens = [data.size(1) for i in range(data.size(0))]
+
+                trainer.optimizer.zero_grad()
+                output, hidden, decoded = trainer.model.forward(data, lens, hidden)
+                loss = trainer.criterion(decoded.view(-1, len(vocab['char'])), target)
+                total_loss += loss.data.item()
+                loss.backward()
+
+                torch.nn.utils.clip_grad_norm_(trainer.params, args['max_grad_norm'])
+                trainer.optimizer.step()
+
+                hidden = repackage_hidden(hidden)
+
+                if (iteration + 1) % args['report_steps'] == 0:
+                    cur_loss = total_loss / args['report_steps']
+                    elapsed = time.time() - start_time
+                    logger.info(
+                        "| epoch {:5d} | {:5d}/{:5d} batches | sec/batch {:.6f} | loss {:5.2f} | ppl {:8.2f}".format(
+                            trainer.epoch,
+                            iteration + 1,
+                            total_batches,
+                            elapsed / args['report_steps'],
+                            cur_loss,
+                            math.exp(cur_loss),
+                        )
+                    )
+                    if args['wandb']:
+                        wandb.log({'train_loss': cur_loss}, step=trainer.global_step)
+                    total_loss = 0.0
+
+                iteration += 1
+                i += seq_len
+
+                # evaluate if necessary
+                if eval_within_epoch and trainer.global_step % args['eval_steps'] == 0:
+                    _, ppl, best_loss = evaluate_and_save(args, vocab, dev_data, trainer, best_loss, model_file, checkpoint_file, writer)
+                    if args['wandb']:
+                        wandb.log({'ppl': ppl, 'best_loss': best_loss, 'lr': get_current_lr(trainer, args)}, step=trainer.global_step)
+
+        # if eval_interval isn't provided, run evaluation after each epoch
+        if not eval_within_epoch or trainer.epoch == args['epochs']:
+            _, ppl, best_loss = evaluate_and_save(args, vocab, dev_data, trainer, best_loss, model_file, checkpoint_file, writer)
+            if args['wandb']:
+                wandb.log({'ppl': ppl, 'best_loss': best_loss, 'lr': get_current_lr(trainer, args)}, step=trainer.global_step)
+
+    if writer:
+        writer.close()
+    if args['wandb']:
+        wandb.finish()
+    return
+
+def evaluate(args):
+    model_file = build_model_filename(args)
+
+    model = CharacterLanguageModel.load(model_file).to(args['device'])
+    vocab = model.vocab
+    data = load_data(args['eval_file'], vocab, args['direction'])
+    criterion = torch.nn.CrossEntropyLoss()
+
+    loss = evaluate_epoch(args, vocab, data, model, criterion)
+    logger.info(
+        "| best model | loss {:5.2f} | ppl {:8.2f}".format(
+            loss,
+            math.exp(loss),
+        )
+    )
+    return
+
+if __name__ == '__main__':
+    main()
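
Because parse_args(args=args) forwards an argv-style list to argparse, charlm.main can be driven from Python as well as from the command line. A minimal training sketch, assuming small plaintext files exist at these illustrative paths:

    from stanza.models import charlm

    # Train a tiny forward character LM; paths and shorthand are placeholders.
    charlm.main([
        "--train_file", "data/charlm/en_sample_train.txt",
        "--eval_file", "data/charlm/en_sample_dev.txt",
        "--shorthand", "en_sample",
        "--forward",
        "--epochs", "2",
        "--eval_steps", "-1",   # evaluate once per epoch instead of every 100k steps
    ])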
stanza/stanza/models/identity_lemmatizer.py ADDED
@@ -0,0 +1,66 @@
+"""
+An identity lemmatizer that mimics the behavior of a normal lemmatizer but directly uses the word as its lemma.
+"""
+
+import os
+import argparse
+import logging
+import random
+
+from stanza.models.lemma.data import DataLoader
+from stanza.models.lemma import scorer
+from stanza.models.common import utils
+from stanza.models.common.doc import *
+from stanza.utils.conll import CoNLL
+from stanza.models import _training_logging
+
+logger = logging.getLogger('stanza')
+
+def parse_args(args=None):
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--data_dir', type=str, default='data/lemma', help='Directory for all lemma data.')
+    parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.')
+    parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')
+    parser.add_argument('--output_file', type=str, default=None, help='Output CoNLL-U file.')
+    parser.add_argument('--gold_file', type=str, default=None, help='Gold CoNLL-U file.')
+
+    parser.add_argument('--mode', default='train', choices=['train', 'predict'])
+    parser.add_argument('--shorthand', type=str, help='Shorthand')
+
+    parser.add_argument('--batch_size', type=int, default=50)
+    parser.add_argument('--seed', type=int, default=1234)
+
+    args = parser.parse_args(args=args)
+    return args
+
+def main(args=None):
+    args = parse_args(args=args)
+
+    random.seed(args.seed)
+
+    args = vars(args)
+
+    logger.info("[Launching identity lemmatizer...]")
+
+    if args['mode'] == 'train':
+        logger.info("[No training is required; will only generate evaluation output...]")
+
+    document = CoNLL.conll2doc(input_file=args['eval_file'])
+    batch = DataLoader(document, args['batch_size'], args, evaluation=True, conll_only=True)
+    system_pred_file = args['output_file']
+    gold_file = args['gold_file']
+
+    # use identity mapping for prediction
+    preds = batch.doc.get([TEXT])
+
+    # write to file and score
+    batch.doc.set([LEMMA], preds)
+    CoNLL.write_doc2conll(batch.doc, system_pred_file)
+    if gold_file is not None:
+        _, _, score = scorer.score(system_pred_file, gold_file)
+
+        logger.info("Lemma score:")
+        logger.info("{} {:.2f}".format(args['shorthand'], score*100))
+
+if __name__ == '__main__':
+    main()
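
The same main(args=None) convention applies here; a sketch that copies each token's surface form into the lemma column and scores it against gold, with placeholder CoNLL-U paths:

    from stanza.models import identity_lemmatizer

    identity_lemmatizer.main([
        "--eval_file", "data/lemma/en_sample.dev.in.conllu",
        "--output_file", "/tmp/en_sample.dev.pred.conllu",
        "--gold_file", "data/lemma/en_sample.dev.gold.conllu",
        "--shorthand", "en_sample",
    ])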
stanza/stanza/models/lang_identifier.py ADDED
@@ -0,0 +1,236 @@
+"""
+Entry point for training and evaluating a Bi-LSTM language identifier.
+"""
+
+import argparse
+import json
+import logging
+import os
+import random
+import torch
+
+from datetime import datetime
+from stanza.models.common import utils
+from stanza.models.langid.data import DataLoader
+from stanza.models.langid.trainer import Trainer
+from stanza.utils.get_tqdm import get_tqdm
+
+tqdm = get_tqdm()
+
+logger = logging.getLogger('stanza')
+
+def parse_args(args=None):
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--batch_mode", help="custom settings when running in batch mode", action="store_true")
+    parser.add_argument("--batch_size", help="batch size for training", type=int, default=64)
+    parser.add_argument("--eval_length", help="length of strings to eval on", type=int, default=None)
+    parser.add_argument("--eval_set", help="eval on dev or test", default="test")
+    parser.add_argument("--data_dir", help="directory with train/dev/test data", default=None)
+    parser.add_argument("--load_name", help="path to load model from", default=None)
+    parser.add_argument("--mode", help="train or eval", default="train")
+    parser.add_argument("--num_epochs", help="number of epochs for training", type=int, default=50)
+    parser.add_argument("--randomize", help="take random substrings of samples", action="store_true")
+    parser.add_argument("--randomize_lengths_range", help="range of lengths to use when random sampling text",
+                        type=randomize_lengths_range, default="5,20")
+    parser.add_argument("--merge_labels_for_eval",
+                        help="merge some language labels for eval (e.g. \"zh-hans\" and \"zh-hant\" to \"zh\")",
+                        action="store_true")
+    parser.add_argument("--save_best_epochs", help="save model for every epoch with new best score", action="store_true")
+    parser.add_argument("--save_name", help="where to save model", default=None)
+    utils.add_device_args(parser)
+    args = parser.parse_args(args=args)
+    return args
+
+
+def randomize_lengths_range(range_list):
+    """
+    Range of lengths for random samples
+    """
+    range_boundaries = [int(x) for x in range_list.split(",")]
+    assert range_boundaries[0] < range_boundaries[1], f"Invalid range: ({range_boundaries[0]}, {range_boundaries[1]})"
+    return range_boundaries
+
+
+def main(args=None):
+    args = parse_args(args=args)
+    torch.manual_seed(0)
+    if args.mode == "train":
+        train_model(args)
+    else:
+        eval_model(args)
+
+
+def build_indexes(args):
+    tag_to_idx = {}
+    char_to_idx = {}
+    train_files = [f"{args.data_dir}/{x}" for x in os.listdir(args.data_dir) if "train" in x]
+    for train_file in train_files:
+        with open(train_file) as curr_file:
+            lines = curr_file.read().strip().split("\n")
+            examples = [json.loads(line) for line in lines if line.strip()]
+            for example in examples:
+                label = example["label"]
+                if label not in tag_to_idx:
+                    tag_to_idx[label] = len(tag_to_idx)
+                sequence = example["text"]
+                for char in list(sequence):
+                    if char not in char_to_idx:
+                        char_to_idx[char] = len(char_to_idx)
+    char_to_idx["UNK"] = len(char_to_idx)
+    char_to_idx["<PAD>"] = len(char_to_idx)
+
+    return tag_to_idx, char_to_idx
+
+
+def train_model(args):
+    # set up indexes
+    tag_to_idx, char_to_idx = build_indexes(args)
+    # load training data
+    train_data = DataLoader(args.device)
+    train_files = [f"{args.data_dir}/{x}" for x in os.listdir(args.data_dir) if "train" in x]
+    train_data.load_data(args.batch_size, train_files, char_to_idx, tag_to_idx, args.randomize)
+    # load dev data
+    dev_data = DataLoader(args.device)
+    dev_files = [f"{args.data_dir}/{x}" for x in os.listdir(args.data_dir) if "dev" in x]
+    dev_data.load_data(args.batch_size, dev_files, char_to_idx, tag_to_idx, randomize=False,
+                       max_length=args.eval_length)
+    # set up trainer
+    trainer_config = {
+        "model_path": args.save_name,
+        "char_to_idx": char_to_idx,
+        "tag_to_idx": tag_to_idx,
+        "batch_size": args.batch_size,
+        "lang_weights": train_data.lang_weights
+    }
+    if args.load_name:
+        trainer_config["load_name"] = args.load_name
+        logger.info(f"{datetime.now()}\tLoading model from: {args.load_name}")
+    trainer = Trainer(trainer_config, load_model=args.load_name is not None, device=args.device)
+    # run training
+    best_accuracy = 0.0
+    for epoch in range(1, args.num_epochs+1):
+        logger.info(f"{datetime.now()}\tEpoch {epoch}")
+        logger.info(f"{datetime.now()}\tNum training batches: {len(train_data.batches)}")
+
+        batches = train_data.batches
+        if not args.batch_mode:
+            batches = tqdm(batches)
+        for train_batch in batches:
+            inputs = (train_batch["sentences"], train_batch["targets"])
+            trainer.update(inputs)
+
+        logger.info(f"{datetime.now()}\tEpoch complete. Evaluating on dev data.")
+        curr_dev_accuracy, curr_confusion_matrix, curr_precisions, curr_recalls, curr_f1s = \
+            eval_trainer(trainer, dev_data, batch_mode=args.batch_mode)
+        logger.info(f"{datetime.now()}\tCurrent dev accuracy: {curr_dev_accuracy}")
+        if curr_dev_accuracy > best_accuracy:
+            logger.info(f"{datetime.now()}\tNew best score. Saving model.")
+            model_label = f"epoch{epoch}" if args.save_best_epochs else None
+            trainer.save(label=model_label)
+            with open(score_log_path(args.save_name), "w") as score_log_file:
+                for score_log in [{"dev_accuracy": curr_dev_accuracy}, curr_confusion_matrix, curr_precisions,
+                                  curr_recalls, curr_f1s]:
+                    score_log_file.write(json.dumps(score_log) + "\n")
+            best_accuracy = curr_dev_accuracy
+
+        # reload training data
+        logger.info(f"{datetime.now()}\tResampling training data.")
+        train_data.load_data(args.batch_size, train_files, char_to_idx, tag_to_idx, args.randomize)
+
+
+def score_log_path(file_path):
+    """
+    Helper that determines the corresponding log file (e.g. /path/to/demo.pt to /path/to/demo.json)
+    """
+    model_suffix = os.path.splitext(file_path)
+    if model_suffix[1]:
+        score_log_path = f"{file_path[:-len(model_suffix[1])]}.json"
+    else:
+        score_log_path = f"{file_path}.json"
+    return score_log_path
+
+
+def eval_model(args):
+    # set up trainer
+    trainer_config = {
+        "model_path": None,
+        "load_name": args.load_name,
+        "batch_size": args.batch_size
+    }
+    trainer = Trainer(trainer_config, load_model=True, device=args.device)
+    # load test data
+    test_data = DataLoader(args.device)
+    test_files = [f"{args.data_dir}/{x}" for x in os.listdir(args.data_dir) if args.eval_set in x]
+    test_data.load_data(args.batch_size, test_files, trainer.model.char_to_idx, trainer.model.tag_to_idx,
+                        randomize=False, max_length=args.eval_length)
+    curr_accuracy, curr_confusion_matrix, curr_precisions, curr_recalls, curr_f1s = \
+        eval_trainer(trainer, test_data, batch_mode=args.batch_mode, fine_grained=not args.merge_labels_for_eval)
+    logger.info(f"{datetime.now()}\t{args.eval_set} accuracy: {curr_accuracy}")
+    eval_save_path = args.save_name if args.save_name else score_log_path(args.load_name)
+    if not os.path.exists(eval_save_path) or args.save_name:
+        with open(eval_save_path, "w") as score_log_file:
+            for score_log in [{"dev_accuracy": curr_accuracy}, curr_confusion_matrix, curr_precisions,
+                              curr_recalls, curr_f1s]:
+                score_log_file.write(json.dumps(score_log) + "\n")
+
+
+
+def eval_trainer(trainer, dev_data, batch_mode=False, fine_grained=True):
+    """
+    Produce dev accuracy and confusion matrix for a trainer
+    """
+
+    # set up confusion matrix
+    tag_to_idx = dev_data.tag_to_idx
+    idx_to_tag = dev_data.idx_to_tag
+    confusion_matrix = {}
+    for row_label in tag_to_idx:
+        confusion_matrix[row_label] = {}
+        for col_label in tag_to_idx:
+            confusion_matrix[row_label][col_label] = 0
+
+    # process dev batches
+    batches = dev_data.batches
+    if not batch_mode:
+        batches = tqdm(batches)
+    for dev_batch in batches:
+        inputs = (dev_batch["sentences"], dev_batch["targets"])
+        predictions = trainer.predict(inputs)
+        for target_idx, prediction in zip(dev_batch["targets"], predictions):
+            prediction_label = idx_to_tag[prediction] if fine_grained else idx_to_tag[prediction].split("-")[0]
+            confusion_matrix[idx_to_tag[target_idx]][prediction_label] += 1
+
+    # calculate dev accuracy
+    total_examples = sum([sum([confusion_matrix[i][j] for j in confusion_matrix[i]]) for i in confusion_matrix])
+    total_correct = sum([confusion_matrix[i][i] for i in confusion_matrix])
+    dev_accuracy = float(total_correct) / float(total_examples)
+
+    # calculate precision, recall, F1
+    precision_scores = {"type": "precision"}
+    recall_scores = {"type": "recall"}
+    f1_scores = {"type": "f1"}
+    for prediction_label in tag_to_idx:
+        total = sum([confusion_matrix[k][prediction_label] for k in tag_to_idx])
+        if total != 0.0:
+            precision_scores[prediction_label] = float(confusion_matrix[prediction_label][prediction_label])/float(total)
+        else:
+            precision_scores[prediction_label] = 0.0
+    for target_label in tag_to_idx:
+        total = sum([confusion_matrix[target_label][k] for k in tag_to_idx])
+        if total != 0:
+            recall_scores[target_label] = float(confusion_matrix[target_label][target_label])/float(total)
+        else:
+            recall_scores[target_label] = 0.0
+    for label in tag_to_idx:
+        if precision_scores[label] == 0.0 and recall_scores[label] == 0.0:
+            f1_scores[label] = 0.0
+        else:
+            f1_scores[label] = \
+                2.0 * (precision_scores[label] * recall_scores[label]) / (precision_scores[label] + recall_scores[label])
+
+    return dev_accuracy, confusion_matrix, precision_scores, recall_scores, f1_scores
+
+
+if __name__ == "__main__":
+    main()
+
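
build_indexes above fixes the expected data layout: JSON-lines files in --data_dir whose names contain "train", "dev", or "test", each line an object with "text" and "label" keys. A training sketch with an illustrative directory:

    from stanza.models import lang_identifier

    # Expects e.g. data/langid/sample_train.jsonl containing lines such as
    # {"text": "Das ist ein Satz.", "label": "de"}
    lang_identifier.main([
        "--data_dir", "data/langid",
        "--save_name", "saved_models/langid/demo.pt",
        "--num_epochs", "5",
        "--randomize",
    ])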
stanza/stanza/models/mwt_expander.py ADDED
@@ -0,0 +1,322 @@
| 1 |
+
"""
|
| 2 |
+
Entry point for training and evaluating a multi-word token (MWT) expander.
|
| 3 |
+
|
| 4 |
+
This MWT expander combines a neural sequence-to-sequence architecture with a dictionary
|
| 5 |
+
to decode the token into multiple words.
|
| 6 |
+
For details please refer to paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf
|
| 7 |
+
|
| 8 |
+
In the case of a dataset where all of the MWT exactly split into the words
|
| 9 |
+
composing the MWT, a classifier over the characters is used instead of the seq2seq
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import sys
|
| 13 |
+
import os
|
| 14 |
+
import shutil
|
| 15 |
+
import time
|
| 16 |
+
from datetime import datetime
|
| 17 |
+
import argparse
|
| 18 |
+
import logging
|
| 19 |
+
import math
|
| 20 |
+
import numpy as np
|
| 21 |
+
import random
|
| 22 |
+
import torch
|
| 23 |
+
from torch import nn, optim
|
| 24 |
+
import copy
|
| 25 |
+
|
| 26 |
+
from stanza.models.mwt.data import DataLoader, BinaryDataLoader
|
| 27 |
+
from stanza.models.mwt.utils import mwts_composed_of_words
|
| 28 |
+
from stanza.models.mwt.vocab import Vocab
|
| 29 |
+
from stanza.models.mwt.trainer import Trainer
|
| 30 |
+
from stanza.models.mwt import scorer
|
| 31 |
+
from stanza.models.common import utils
|
| 32 |
+
import stanza.models.common.seq2seq_constant as constant
|
| 33 |
+
from stanza.models.common.doc import Document
|
| 34 |
+
from stanza.utils.conll import CoNLL
|
| 35 |
+
from stanza.models import _training_logging
|
| 36 |
+
|
| 37 |
+
logger = logging.getLogger('stanza')
|
| 38 |
+
|
| 39 |
+
def build_argparse():
|
| 40 |
+
parser = argparse.ArgumentParser()
|
| 41 |
+
parser.add_argument('--data_dir', type=str, default='data/mwt', help='Root dir for saving models.')
|
| 42 |
+
parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.')
|
| 43 |
+
parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')
|
| 44 |
+
parser.add_argument('--output_file', type=str, default=None, help='Output CoNLL-U file.')
|
| 45 |
+
parser.add_argument('--gold_file', type=str, default=None, help='Output CoNLL-U file.')
|
| 46 |
+
|
| 47 |
+
parser.add_argument('--mode', default='train', choices=['train', 'predict'])
|
| 48 |
+
parser.add_argument('--lang', type=str, help='Language')
|
| 49 |
+
parser.add_argument('--shorthand', type=str, help="Treebank shorthand")
|
| 50 |
+
|
| 51 |
+
parser.add_argument('--no_dict', dest='ensemble_dict', action='store_false', help='Do not ensemble dictionary with seq2seq. By default ensemble a dict.')
|
| 52 |
+
parser.add_argument('--ensemble_early_stop', action='store_true', help='Early stopping based on ensemble performance.')
|
| 53 |
+
parser.add_argument('--dict_only', action='store_true', help='Only train a dictionary-based MWT expander.')
|
| 54 |
+
|
| 55 |
+
parser.add_argument('--hidden_dim', type=int, default=100)
|
| 56 |
+
parser.add_argument('--emb_dim', type=int, default=50)
|
| 57 |
+
parser.add_argument('--num_layers', type=int, default=None, help='Number of layers in model encoder. Defaults to 1 for seq2seq, 2 for classifier')
|
| 58 |
+
parser.add_argument('--emb_dropout', type=float, default=0.5)
|
| 59 |
+
parser.add_argument('--dropout', type=float, default=0.5)
|
| 60 |
+
parser.add_argument('--max_dec_len', type=int, default=50)
|
| 61 |
+
parser.add_argument('--beam_size', type=int, default=1)
|
| 62 |
+
parser.add_argument('--attn_type', default='soft', choices=['soft', 'mlp', 'linear', 'deep'], help='Attention type')
|
| 63 |
+
parser.add_argument('--no_copy', dest='copy', action='store_false', help='Do not use copy mechanism in MWT expansion. By default copy mechanism is used to improve generalization.')
|
| 64 |
+
|
| 65 |
+
parser.add_argument('--augment_apos', default=0.01, type=float, help='At training time, how much to augment |\'| to |"| |’| |ʼ|')
|
| 66 |
+
parser.add_argument('--force_exact_pieces', default=None, action='store_true', help='If possible, make the text of the pieces of the MWT add up to the token itself. (By default, this is determined from the dataset.)')
|
| 67 |
+
parser.add_argument('--no_force_exact_pieces', dest='force_exact_pieces', action='store_false', help="Don't make the text of the pieces of the MWT add up to the token itself. (By default, this is determined from the dataset.)")
|
| 68 |
+
|
| 69 |
+
parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.')
|
| 70 |
+
parser.add_argument('--optim', type=str, default='adam', help='sgd, adagrad, adam or adamax.')
|
| 71 |
+
parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
|
| 72 |
+
parser.add_argument('--lr_decay', type=float, default=0.9)
|
| 73 |
+
parser.add_argument('--decay_epoch', type=int, default=30, help="Decay the lr starting from this epoch.")
|
| 74 |
+
parser.add_argument('--num_epoch', type=int, default=30)
|
| 75 |
+
parser.add_argument('--batch_size', type=int, default=50)
|
| 76 |
+
parser.add_argument('--max_grad_norm', type=float, default=5.0, help='Gradient clipping.')
|
| 77 |
+
parser.add_argument('--log_step', type=int, default=20, help='Print log every k steps.')
|
| 78 |
+
parser.add_argument('--save_dir', type=str, default='saved_models/mwt', help='Root dir for saving models.')
|
| 79 |
+
parser.add_argument('--save_name', type=str, default=None, help="File name to save the model")
|
| 80 |
+
parser.add_argument('--save_each_name', type=str, default=None, help="Save each model in sequence to this pattern. Mostly for testing")
|
| 81 |
+
|
| 82 |
+
parser.add_argument('--seed', type=int, default=1234)
|
| 83 |
+
utils.add_device_args(parser)
|
| 84 |
+
|
| 85 |
+
parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name')
|
| 86 |
+
parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. Will default to the dataset short name')
|
| 87 |
+
return parser
|
| 88 |
+
|
| 89 |
+
def parse_args(args=None):
|
| 90 |
+
parser = build_argparse()
|
| 91 |
+
args = parser.parse_args(args=args)
|
| 92 |
+
|
| 93 |
+
if args.wandb_name:
|
| 94 |
+
args.wandb = True
|
| 95 |
+
|
| 96 |
+
return args
|
| 97 |
+
|
| 98 |
+
def main(args=None):
|
| 99 |
+
args = parse_args(args=args)
|
| 100 |
+
|
| 101 |
+
utils.set_random_seed(args.seed)
|
| 102 |
+
|
| 103 |
+
args = vars(args)
|
| 104 |
+
logger.info("Running MWT expander in {} mode".format(args['mode']))
|
| 105 |
+
|
| 106 |
+
if args['mode'] == 'train':
|
| 107 |
+
train(args)
|
| 108 |
+
else:
|
| 109 |
+
evaluate(args)
|
| 110 |
+
|
| 111 |
+
def train(args):
|
| 112 |
+
# load data
|
| 113 |
+
logger.debug('max_dec_len: %d' % args['max_dec_len'])
|
| 114 |
+
logger.debug("Loading data with batch size {}...".format(args['batch_size']))
|
| 115 |
+
train_doc = CoNLL.conll2doc(input_file=args['train_file'])
|
| 116 |
+
train_batch = DataLoader(train_doc, args['batch_size'], args, evaluation=False)
|
| 117 |
+
vocab = train_batch.vocab
|
| 118 |
+
args['vocab_size'] = vocab.size
|
| 119 |
+
dev_doc = CoNLL.conll2doc(input_file=args['eval_file'])
|
| 120 |
+
dev_batch = DataLoader(dev_doc, args['batch_size'], args, vocab=vocab, evaluation=True)
|
| 121 |
+
|
| 122 |
+
utils.ensure_dir(args['save_dir'])
|
| 123 |
+
save_name = args['save_name'] if args['save_name'] else '{}_mwt_expander.pt'.format(args['shorthand'])
|
| 124 |
+
model_file = os.path.join(args['save_dir'], save_name)
|
| 125 |
+
|
| 126 |
+
save_each_name = None
|
| 127 |
+
if args['save_each_name']:
|
| 128 |
+
save_each_name = os.path.join(args['save_dir'], args['save_each_name'])
|
| 129 |
+
save_each_name = utils.build_save_each_filename(save_each_name)
|
| 130 |
+
|
| 131 |
+
# pred and gold path
|
| 132 |
+
system_pred_file = args['output_file']
|
| 133 |
+
gold_file = args['gold_file']
|
| 134 |
+
|
| 135 |
+
# skip training if the language does not have training or dev data
|
| 136 |
+
if len(train_batch) == 0:
|
| 137 |
+
logger.warning("Skip training because no data available...")
|
| 138 |
+
return
|
| 139 |
+
dev_mwt = dev_doc.get_mwt_expansions(False)
|
| 140 |
+
if len(dev_batch) == 0 and args.get('dict_only', False):
|
| 141 |
+
logger.warning("Training data available, but dev data has no MWTs. Only training a dict based MWT")
|
| 142 |
+
args['dict_only'] = True
|
| 143 |
+
|
| 144 |
+
if args['force_exact_pieces'] and not mwts_composed_of_words(train_doc):
|
| 145 |
+
raise ValueError("Cannot train model with --force_exact_pieces, as the MWT in this dataset are not entirely composed of their subwords")
|
| 146 |
+
|
| 147 |
+
if args['force_exact_pieces'] is None and mwts_composed_of_words(train_doc):
|
| 148 |
+
# the force_exact_pieces mechanism trains a separate version of the MWT expander in the Trainer
|
| 149 |
+
# (the training loop here does not need to change)
|
| 150 |
+
# in this model, a classifier distinguishes whether or not a location is a split
|
| 151 |
+
# and the text is copied exactly from the input rather than created via seq2seq
|
| 152 |
+
# this behavior can be turned off at training time with --no_force_exact_pieces
|
| 153 |
+
logger.info("Train MWTs entirely composed of their subwords. Training the MWT to match that paradigm as closely as possible")
|
| 154 |
+
args['force_exact_pieces'] = True
|
| 155 |
+
|
| 156 |
+
if args['force_exact_pieces']:
|
| 157 |
+
logger.info("Reconverting to BinaryDataLoader")
|
| 158 |
+
train_batch = BinaryDataLoader(train_doc, args['batch_size'], args, evaluation=False)
|
| 159 |
+
vocab = train_batch.vocab
|
| 160 |
+
args['vocab_size'] = vocab.size
|
| 161 |
+
dev_batch = BinaryDataLoader(dev_doc, args['batch_size'], args, vocab=vocab, evaluation=True)
|
| 162 |
+
|
| 163 |
+
if args['num_layers'] is None:
|
| 164 |
+
if args['force_exact_pieces']:
|
| 165 |
+
args['num_layers'] = 2
|
| 166 |
+
else:
|
| 167 |
+
args['num_layers'] = 1
|
| 168 |
+
|
| 169 |
+
# train a dictionary-based MWT expander
|
| 170 |
+
trainer = Trainer(args=args, vocab=vocab, device=args['device'])
|
| 171 |
+
logger.info("Training dictionary-based MWT expander...")
|
| 172 |
+
trainer.train_dict(train_batch.doc.get_mwt_expansions(evaluation=False))
|
| 173 |
+
logger.info("Evaluating on dev set...")
|
| 174 |
+
dev_preds = trainer.predict_dict(dev_batch.doc.get_mwt_expansions(evaluation=True))
|
| 175 |
+
doc = copy.deepcopy(dev_batch.doc)
|
| 176 |
+
doc.set_mwt_expansions(dev_preds, fake_dependencies=True)
|
| 177 |
+
CoNLL.write_doc2conll(doc, system_pred_file)
|
| 178 |
+
_, _, dev_f = scorer.score(system_pred_file, gold_file)
|
| 179 |
+
logger.info("Dev F1 = {:.2f}".format(dev_f * 100))
|
| 180 |
+
|
| 181 |
+
if args.get('dict_only', False):
|
| 182 |
+
# save dictionaries
|
| 183 |
+
trainer.save(model_file)
|
| 184 |
+
else:
|
| 185 |
+
# train a seq2seq model
|
| 186 |
+
logger.info("Training seq2seq-based MWT expander...")
|
| 187 |
+
global_step = 0
|
| 188 |
+
steps_per_epoch = math.ceil(len(train_batch) / args['batch_size'])
|
| 189 |
+
max_steps = steps_per_epoch * args['num_epoch']
|
| 190 |
+
dev_score_history = []
|
| 191 |
+
best_dev_preds = []
|
| 192 |
+
current_lr = args['lr']
|
| 193 |
+
global_start_time = time.time()
|
| 194 |
+
format_str = '{}: step {}/{} (epoch {}/{}), loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'
|
| 195 |
+
|
| 196 |
+
if args['wandb']:
|
| 197 |
+
import wandb
|
| 198 |
+
wandb_name = args['wandb_name'] if args['wandb_name'] else "%s_mwt" % args['shorthand']
|
| 199 |
+
wandb.init(name=wandb_name, config=args)
|
| 200 |
+
wandb.run.define_metric('train_loss', summary='min')
|
| 201 |
+
wandb.run.define_metric('dev_score', summary='max')
|
| 202 |
+
|
| 203 |
+
# start training
|
| 204 |
+
for epoch in range(1, args['num_epoch']+1):
|
| 205 |
+
train_loss = 0
|
| 206 |
+
for i, batch in enumerate(train_batch.to_loader()):
|
| 207 |
+
start_time = time.time()
|
| 208 |
+
global_step += 1
|
| 209 |
+
loss = trainer.update(batch, eval=False) # update step
|
| 210 |
+
train_loss += loss
|
| 211 |
+
if global_step % args['log_step'] == 0:
|
| 212 |
+
duration = time.time() - start_time
|
| 213 |
+
logger.info(format_str.format(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), global_step,\
|
| 214 |
+
max_steps, epoch, args['num_epoch'], loss, duration, current_lr))
|
| 215 |
+
|
| 216 |
+
if save_each_name:
|
| 217 |
+
trainer.save(save_each_name % epoch)
|
| 218 |
+
logger.info("Saved epoch %d model to %s" % (epoch, save_each_name % epoch))
|
| 219 |
+
|
| 220 |
+
# eval on dev
|
| 221 |
+
logger.info("Evaluating on dev set...")
|
| 222 |
+
dev_preds = []
|
| 223 |
+
for i, batch in enumerate(dev_batch.to_loader()):
|
| 224 |
+
preds = trainer.predict(batch)
|
| 225 |
+
dev_preds += preds
|
| 226 |
+
if args.get('ensemble_dict', False) and args.get('ensemble_early_stop', False):
|
| 227 |
+
logger.info("[Ensembling dict with seq2seq model...]")
|
| 228 |
+
dev_preds = trainer.ensemble(dev_batch.doc.get_mwt_expansions(evaluation=True), dev_preds)
|
| 229 |
+
doc = copy.deepcopy(dev_batch.doc)
|
| 230 |
+
doc.set_mwt_expansions(dev_preds, fake_dependencies=True)
|
| 231 |
+
CoNLL.write_doc2conll(doc, system_pred_file)
|
| 232 |
+
_, _, dev_score = scorer.score(system_pred_file, gold_file)
|
| 233 |
+
train_loss = train_loss / train_batch.num_examples * args['batch_size'] # avg loss per batch
|
| 234 |
+
logger.info("epoch {}: train_loss = {:.6f}, dev_score = {:.4f}".format(epoch, train_loss, dev_score))
|
| 235 |
+
|
| 236 |
+
if args['wandb']:
|
| 237 |
+
wandb.log({'train_loss': train_loss, 'dev_score': dev_score})
|
| 238 |
+
|
| 239 |
+
# save best model
|
| 240 |
+
if epoch == 1 or dev_score > max(dev_score_history):
|
| 241 |
+
trainer.save(model_file)
|
| 242 |
+
logger.info("new best model saved.")
|
| 243 |
+
best_dev_preds = dev_preds
|
| 244 |
+
|
| 245 |
+
# lr schedule
|
| 246 |
+
if epoch > args['decay_epoch'] and dev_score <= dev_score_history[-1]:
|
| 247 |
+
current_lr *= args['lr_decay']
|
| 248 |
+
trainer.change_lr(current_lr)
|
| 249 |
+
|
| 250 |
+
dev_score_history += [dev_score]
|
| 251 |
+
|
| 252 |
+
logger.info("Training ended with {} epochs.".format(epoch))
|
| 253 |
+
|
| 254 |
+
if args['wandb']:
|
| 255 |
+
wandb.finish()
|
| 256 |
+
|
| 257 |
+
best_f, best_epoch = max(dev_score_history)*100, np.argmax(dev_score_history)+1
|
| 258 |
+
logger.info("Best dev F1 = {:.2f}, at epoch = {}".format(best_f, best_epoch))
|
| 259 |
+
|
| 260 |
+
# try ensembling with dict if necessary
|
| 261 |
+
if args.get('ensemble_dict', False):
|
| 262 |
+
logger.info("[Ensembling dict with seq2seq model...]")
|
| 263 |
+
dev_preds = trainer.ensemble(dev_batch.doc.get_mwt_expansions(evaluation=True), best_dev_preds)
|
| 264 |
+
doc = copy.deepcopy(dev_batch.doc)
|
| 265 |
+
doc.set_mwt_expansions(dev_preds, fake_dependencies=True)
|
| 266 |
+
CoNLL.write_doc2conll(doc, system_pred_file)
|
| 267 |
+
_, _, dev_score = scorer.score(system_pred_file, gold_file)
|
| 268 |
+
logger.info("Ensemble dev F1 = {:.2f}".format(dev_score*100))
|
| 269 |
+
best_f = max(best_f, dev_score)
|
| 270 |
+
|
| 271 |
+
def evaluate(args):
|
| 272 |
+
# file paths
|
| 273 |
+
system_pred_file = args['output_file']
|
| 274 |
+
gold_file = args['gold_file']
|
| 275 |
+
model_file = args['save_name'] if args['save_name'] else '{}_mwt_expander.pt'.format(args['shorthand'])
|
| 276 |
+
if args['save_dir'] and not model_file.startswith(args['save_dir']) and not os.path.exists(model_file):
|
| 277 |
+
model_file = os.path.join(args['save_dir'], model_file)
|
| 278 |
+
|
| 279 |
+
# load model
|
| 280 |
+
trainer = Trainer(model_file=model_file, device=args['device'])
|
| 281 |
+
loaded_args, vocab = trainer.args, trainer.vocab
|
| 282 |
+
|
| 283 |
+
for k in args:
|
| 284 |
+
if k.endswith('_dir') or k.endswith('_file') or k in ['shorthand']:
|
| 285 |
+
loaded_args[k] = args[k]
|
| 286 |
+
logger.debug('max_dec_len: %d' % loaded_args['max_dec_len'])
|
| 287 |
+
|
| 288 |
+
# load data
|
| 289 |
+
logger.debug("Loading data with batch size {}...".format(args['batch_size']))
|
| 290 |
+
doc = CoNLL.conll2doc(input_file=args['eval_file'])
|
| 291 |
+
batch = DataLoader(doc, args['batch_size'], loaded_args, vocab=vocab, evaluation=True)
|
| 292 |
+
|
| 293 |
+
if len(batch) > 0:
|
| 294 |
+
dict_preds = trainer.predict_dict(batch.doc.get_mwt_expansions(evaluation=True))
|
| 295 |
+
# decide trainer type and run eval
|
| 296 |
+
if loaded_args['dict_only']:
|
| 297 |
+
preds = dict_preds
|
| 298 |
+
else:
|
| 299 |
+
logger.info("Running the seq2seq model...")
|
| 300 |
+
preds = []
|
| 301 |
+
for i, b in enumerate(batch.to_loader()):
|
| 302 |
+
preds += trainer.predict(b)
|
| 303 |
+
|
| 304 |
+
if loaded_args.get('ensemble_dict', False):
|
| 305 |
+
preds = trainer.ensemble(batch.doc.get_mwt_expansions(evaluation=True), preds)
|
| 306 |
+
else:
|
| 307 |
+
# skip eval if dev data does not exist
|
| 308 |
+
preds = []
|
| 309 |
+
|
| 310 |
+
# write to file and score
|
| 311 |
+
doc = copy.deepcopy(batch.doc)
|
| 312 |
+
doc.set_mwt_expansions(preds, fake_dependencies=True)
|
| 313 |
+
CoNLL.write_doc2conll(doc, system_pred_file)
|
| 314 |
+
|
| 315 |
+
if gold_file is not None:
|
| 316 |
+
_, _, score = scorer.score(system_pred_file, gold_file)
|
| 317 |
+
|
| 318 |
+
logger.info("MWT expansion score: {} {:.2f}".format(args['shorthand'], score*100))
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
if __name__ == '__main__':
|
| 322 |
+
main()
|
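As a usage sketch only: the evaluate() entry point above reads its settings from an args dict, so a prediction run presumably looks like the command below. The paths are placeholders and the exact flag spellings are assumptions inferred from the dict keys used in this file; the argparse block itself is earlier in the diff.

# Hypothetical invocation sketch; flags inferred from the args dict keys above.
# python -m stanza.models.mwt_expander --mode predict \
#        --eval_file data/mwt/en_ewt.dev.in.conllu \
#        --gold_file data/mwt/en_ewt.dev.gold.conllu \
#        --output_file /tmp/en_ewt.dev.pred.conllu \
#        --shorthand en_ewt --save_dir saved_models/mwt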
stanza/stanza/models/ner_tagger.py
ADDED
@@ -0,0 +1,492 @@
"""
Entry point for training and evaluating an NER tagger.

This tagger uses BiLSTM layers with character and word-level representations, and a CRF decoding layer
to produce NER predictions.
For details please refer to paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf.
"""

import sys
import os
import time
from datetime import datetime
import argparse
import logging
import numpy as np
import random
import re
import json
import torch
from torch import nn, optim

from stanza.models.ner.data import DataLoader
from stanza.models.ner.trainer import Trainer
from stanza.models.ner import scorer
from stanza.models.common import utils
from stanza.models.common.pretrain import Pretrain
from stanza.utils.conll import CoNLL
from stanza.models.common.doc import *
from stanza.models import _training_logging

from stanza.models.common.peft_config import add_peft_args, resolve_peft_args
from stanza.utils.confusion import confusion_to_weighted_f1, format_confusion

logger = logging.getLogger('stanza')

def build_argparse():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='data/ner', help='Directory of NER data.')
    parser.add_argument('--wordvec_dir', type=str, default='extern_data/word2vec', help='Directory of word vectors')
    parser.add_argument('--wordvec_file', type=str, default='', help='File that contains word vectors')
    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
    parser.add_argument('--train_file', type=str, default=None, help='Input file for data loader.')
    parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')
    parser.add_argument('--eval_output_file', type=str, default=None, help='Where to write results: text, gold, pred. If None, no results file printed')

    parser.add_argument('--mode', default='train', choices=['train', 'predict'])
    parser.add_argument('--finetune', action='store_true', help='Load existing model during `train` mode from `save_dir` path')
    parser.add_argument('--finetune_load_name', type=str, default=None, help='Model to load when finetuning')
    parser.add_argument('--train_classifier_only', action='store_true',
                        help='In case of applying Transfer-learning approach and training only the classifier layer this will freeze gradient propagation for all other layers.')
    parser.add_argument('--shorthand', type=str, help="Treebank shorthand")

    parser.add_argument('--hidden_dim', type=int, default=256)
    parser.add_argument('--char_hidden_dim', type=int, default=100)
    parser.add_argument('--word_emb_dim', type=int, default=100)
    parser.add_argument('--char_emb_dim', type=int, default=100)
    parser.add_argument('--num_layers', type=int, default=1)
    parser.add_argument('--char_num_layers', type=int, default=1)
    parser.add_argument('--pretrain_max_vocab', type=int, default=100000)
    parser.add_argument('--word_dropout', type=float, default=0.01, help="How often to remove a word at training time. Set to a small value to train unk when finetuning word embeddings")
    parser.add_argument('--locked_dropout', type=float, default=0.0)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--rec_dropout', type=float, default=0, help="Word recurrent dropout")
    parser.add_argument('--char_rec_dropout', type=float, default=0, help="Character recurrent dropout")
    parser.add_argument('--char_dropout', type=float, default=0, help="Character-level language model dropout")
    parser.add_argument('--no_char', dest='char', action='store_false', help="Turn off training a character model.")
    parser.add_argument('--charlm', action='store_true', help="Turn on contextualized char embedding using pretrained character-level language model.")
    parser.add_argument('--charlm_save_dir', type=str, default='saved_models/charlm', help="Root dir for pretrained character-level language model.")
    parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.")
    parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")
    parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm")
    parser.add_argument('--char_lowercase', dest='char_lowercase', action='store_true', help="Use lowercased characters in character model.")
    parser.add_argument('--no_lowercase', dest='lowercase', action='store_false', help="Use cased word vectors.")
    parser.add_argument('--no_emb_finetune', dest='emb_finetune', action='store_false', help="Turn off finetuning of the embedding matrix.")
    parser.add_argument('--emb_finetune_known_only', dest='emb_finetune_known_only', action='store_true', help="Finetune the embedding matrix only for words in the embedding. (Default: finetune words not in the embedding as well) This may be useful for very large datasets where obscure words are only trained once in a while, such as French-WikiNER")
    parser.add_argument('--no_input_transform', dest='input_transform', action='store_false', help="Do not use input transformation layer before tagger lstm.")
    parser.add_argument('--scheme', type=str, default='bioes', help="The tagging scheme to use: bio or bioes.")
    parser.add_argument('--train_scheme', type=str, default=None, help="The tagging scheme to use when training: bio or bioes. Overrides --scheme for the training set")

    parser.add_argument('--bert_model', type=str, default=None, help="Use an external bert model (requires the transformers package)")
    parser.add_argument('--no_bert_model', dest='bert_model', action="store_const", const=None, help="Don't use bert")
    parser.add_argument('--bert_hidden_layers', type=int, default=None, help="How many layers of hidden state to use from the transformer")
    parser.add_argument('--bert_finetune', default=False, action='store_true', help='Finetune the bert (or other transformer)')
    parser.add_argument('--gradient_checkpointing', default=False, action='store_true', help='Checkpoint intermediate gradients between layers to save memory at the cost of training steps')
    parser.add_argument('--no_bert_finetune', dest='bert_finetune', action='store_false', help="Don't finetune the bert (or other transformer)")
    parser.add_argument('--bert_learning_rate', default=1.0, type=float, help='Scale the learning rate for transformer finetuning by this much')
    parser.add_argument('--second_optim', type=str, default=None, help='once first optimizer converged, tune the model again. with: sgd, adagrad, adam or adamax.')
    parser.add_argument('--second_bert_learning_rate', default=0, type=float, help='Secondary stage transformer finetuning learning rate scale')

    parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.')
    parser.add_argument('--optim', type=str, default='sgd', help='sgd, adagrad, adam or adamax.')
    parser.add_argument('--lr', type=float, default=0.1, help='Learning rate.')
    parser.add_argument('--min_lr', type=float, default=1e-4, help='Minimum learning rate to stop training.')
    parser.add_argument('--second_lr', type=float, default=5e-3, help='Secondary learning rate')
    parser.add_argument('--momentum', type=float, default=0, help='Momentum for SGD.')
    parser.add_argument('--lr_decay', type=float, default=0.5, help="LR decay rate.")
    parser.add_argument('--patience', type=int, default=3, help="Patience for LR decay.")

    parser.add_argument('--connect_output_layers', action='store_true', default=False, help='Connect one output layer to the input of the next output layer. By default, those layers are all separate')
    parser.add_argument('--predict_tagset', type=int, default=None, help='Which tagset to predict if there are multiple tagsets. Will default to 0. Default of None allows the model to remember the value from training time, but be overridden at test time')

    parser.add_argument('--ignore_tag_scores', type=str, default=None, help="Which tags to ignore, if any, when scoring dev & test sets")

    parser.add_argument('--max_steps', type=int, default=200000)
    parser.add_argument('--max_steps_no_improve', type=int, default=2500, help='if the model doesn\'t improve after this many steps, give up or switch to new optimizer.')
    parser.add_argument('--eval_interval', type=int, default=500)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--max_batch_words', type=int, default=800, help='Long sentences can overwhelm even a large GPU when finetuning a transformer on otherwise reasonable batch sizes. This cuts off those batches early')
    parser.add_argument('--max_grad_norm', type=float, default=5.0, help='Gradient clipping.')
    parser.add_argument('--log_step', type=int, default=20, help='Print log every k steps.')
    parser.add_argument('--log_norms', action='store_true', default=False, help='Log the norms of all the parameters (noisy!)')
    parser.add_argument('--save_dir', type=str, default='saved_models/ner', help='Root dir for saving models.')
    parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_{finetune}_nertagger.pt", help="File name to save the model")

    parser.add_argument('--seed', type=int, default=1234)
    utils.add_device_args(parser)

    parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name')
    parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. Will default to the dataset short name')
    return parser

def parse_args(args=None):
    parser = build_argparse()
    add_peft_args(parser)
    args = parser.parse_args(args=args)
    resolve_peft_args(args, logger)

    if args.wandb_name:
        args.wandb = True

    args = vars(args)
    return args
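Since parse_args accepts an explicit argument list and returns a plain dict (note the args = vars(args) above), the tagger can also be configured programmatically. A minimal sketch; the shorthand and session name are just example values:

# Minimal sketch: parse flags from a list instead of sys.argv.
args = parse_args(['--shorthand', 'en_sample', '--mode', 'train', '--wandb_name', 'demo'])
assert args['shorthand'] == 'en_sample'
assert args['wandb'] is True   # setting --wandb_name implies --wandb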

def main(args=None):
    args = parse_args(args=args)

    utils.set_random_seed(args['seed'])

    logger.info("Running NER tagger in {} mode".format(args['mode']))

    if args['mode'] == 'train':
        return train(args)
    else:
        evaluate(args)

def load_pretrain(args):
    # load pretrained vectors
    if args['wordvec_pretrain_file']:
        pretrain_file = args['wordvec_pretrain_file']
        pretrain = Pretrain(pretrain_file, None, args['pretrain_max_vocab'], save_to_file=False)
    else:
        if len(args['wordvec_file']) == 0:
            vec_file = utils.get_wordvec_file(args['wordvec_dir'], args['shorthand'])
        else:
            vec_file = args['wordvec_file']
        # do not save pretrained embeddings individually
        pretrain = Pretrain(None, vec_file, args['pretrain_max_vocab'], save_to_file=False)
    return pretrain

def model_file_name(args):
    return utils.standard_model_file_name(args, "nertagger")

def get_known_tags(tags):
    """
    Tags are stored in the dataset as a list of list of tags

    This returns a sorted list for each column of tags in the dataset
    """
    max_columns = max(len(word) for sent in tags for word in sent)
    known_tags = [set() for _ in range(max_columns)]
    for sent in tags:
        for word in sent:
            for tag_idx, tag in enumerate(word):
                known_tags[tag_idx].add(tag)
    return [sorted(x) for x in known_tags]
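A small worked example of get_known_tags: with two tag columns per word, it collects one sorted list of the observed tags per column:

# Two sentences, each word carrying two tag columns.
tags = [[('B-PER', 'B-X'), ('O', 'O')],
        [('S-ORG', 'O')]]
get_known_tags(tags)
# -> [['B-PER', 'O', 'S-ORG'], ['B-X', 'O']]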

def warn_missing_tags(tag_vocab, data_tags, error_msg, bioes_to_bio=False):
    """
    Check for tags missing from the tag_vocab.

    Given a tag_vocab and the known tags in the format used by
    ner.data, go through the tags in the dataset and look for any
    which aren't in the tag_vocab.

    error_msg is something like "training set" or "eval file" to
    indicate where the missing tags came from.
    """
    tag_depth = max(max(len(tags) for tags in sentence) for sentence in data_tags)

    if tag_depth != len(tag_vocab.lens()):
        logger.warning("Test dataset has a different number of tag types compared to the model: %d vs %d", tag_depth, len(tag_vocab.lens()))
    for tag_set_idx in range(min(tag_depth, len(tag_vocab.lens()))):
        tag_set = tag_vocab.items(tag_set_idx)
        if len(tag_vocab.lens()) > 1:
            current_error_msg = error_msg + " tag set %d" % tag_set_idx
        else:
            current_error_msg = error_msg

        current_tags = set([word[tag_set_idx] for sentence in data_tags for word in sentence])
        if bioes_to_bio:
            current_tags = set([re.sub("^E-", "I-", re.sub("^S-", "B-", x)) for x in current_tags])
        utils.warn_missing_tags(tag_set, current_tags, current_error_msg)
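The bioes_to_bio branch above maps the BIOES single-token and end markers down to BIO before comparing against the vocab. Concretely:

import re
# S- (single) becomes B- (begin), E- (end) becomes I- (inside); B-, I-, and O are unchanged.
for tag in ['S-PER', 'E-ORG', 'B-LOC', 'I-LOC', 'O']:
    print(tag, '->', re.sub("^E-", "I-", re.sub("^S-", "B-", tag)))
# S-PER -> B-PER, E-ORG -> I-ORG, B-LOC -> B-LOC, I-LOC -> I-LOC, O -> O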

def train(args):
    model_file = model_file_name(args)

    save_dir, save_name = os.path.split(model_file)
    utils.ensure_dir(save_dir)
    if args['save_dir'] is None:
        args['save_dir'] = save_dir
    args['save_name'] = save_name

    utils.log_training_args(args, logger)

    pretrain = None
    vocab = None
    trainer = None

    if args['finetune'] and args['finetune_load_name']:
        logger.warning('Finetune is ON. Using model from "{}"'.format(args['finetune_load_name']))
        _, trainer, vocab = load_model(args, args['finetune_load_name'])
    elif args['finetune'] and os.path.exists(model_file):
        logger.warning('Finetune is ON. Using model from "{}"'.format(model_file))
        _, trainer, vocab = load_model(args, model_file)
    else:
        if args['finetune']:
            raise FileNotFoundError('Finetune is set to true but model file is not found: {}'.format(model_file))

        pretrain = load_pretrain(args)

        if pretrain is not None:
            word_emb_dim = pretrain.emb.shape[1]
            if args['word_emb_dim'] and args['word_emb_dim'] != word_emb_dim:
                logger.warning("Embedding file has a dimension of {}. Model will be built with that size instead of {}".format(word_emb_dim, args['word_emb_dim']))
            args['word_emb_dim'] = word_emb_dim

    if args['charlm']:
        if args['charlm_shorthand'] is None:
            raise ValueError("CharLM Shorthand is required for loading pretrained CharLM model...")
        logger.info('Using pretrained contextualized char embedding')
        if not args['charlm_forward_file']:
            args['charlm_forward_file'] = '{}/{}_forward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand'])
        if not args['charlm_backward_file']:
            args['charlm_backward_file'] = '{}/{}_backward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand'])

    # load data
    logger.info("Loading training data with batch size %d from %s", args['batch_size'], args['train_file'])
    with open(args['train_file']) as fin:
        train_doc = Document(json.load(fin))
    logger.info("Loaded %d sentences of training data", len(train_doc.sentences))
    if len(train_doc.sentences) == 0:
        raise ValueError("File %s exists but has no usable training data" % args['train_file'])
    train_batch = DataLoader(train_doc, args['batch_size'], args, pretrain, vocab=vocab, evaluation=False, scheme=args.get('train_scheme'), max_batch_words=args['max_batch_words'])
    vocab = train_batch.vocab
    logger.info("Loading dev data from %s", args['eval_file'])
    with open(args['eval_file']) as fin:
        dev_doc = Document(json.load(fin))
    logger.info("Loaded %d sentences of dev data", len(dev_doc.sentences))
    if len(dev_doc.sentences) == 0:
        raise ValueError("File %s exists but has no usable dev data" % args['train_file'])
    dev_batch = DataLoader(dev_doc, args['batch_size'], args, pretrain, vocab=vocab, evaluation=True)

    train_tags = get_known_tags(train_batch.tags)
    logger.info("Training data has %d columns of tags", len(train_tags))
    for tag_idx, tags in enumerate(train_tags):
        logger.info("Tags present in training set at column %d:\n Tags without BIES markers: %s\n Tags with B-, I-, E-, or S-: %s",
                    tag_idx,
                    " ".join(sorted(set(i for i in tags if i[:2] not in ('B-', 'I-', 'E-', 'S-')))),
                    " ".join(sorted(set(i[2:] for i in tags if i[:2] in ('B-', 'I-', 'E-', 'S-')))))

    # skip training if the language does not have training or dev data
    if len(train_batch) == 0 or len(dev_batch) == 0:
        logger.info("Skip training because no data available...")
        return

    logger.info("Training tagger...")
    if trainer is None: # init if model was not loaded previously from file
        trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, device=args['device'],
                          train_classifier_only=args['train_classifier_only'])

    if args['finetune']:
        warn_missing_tags(trainer.vocab['tag'], train_batch.tags, "training set")
    # the evaluation will coerce the tags to the proper scheme,
    # so we won't need to alert for not having S- or E- tags
    bioes_to_bio = args['train_scheme'] == 'bio' and args['scheme'] == 'bioes'
    warn_missing_tags(trainer.vocab['tag'], dev_batch.tags, "dev set", bioes_to_bio=bioes_to_bio)

    # TODO: might still want to add multiple layers of tag evaluation to the scorer
    dev_gold_tags = [[x[trainer.args['predict_tagset']] for x in tags] for tags in dev_batch.tags]

    logger.info(trainer.model)

    global_step = 0
    max_steps = args['max_steps']
    dev_score_history = []
    best_dev_preds = []
    current_lr = trainer.optimizer.param_groups[0]['lr']
    global_start_time = time.time()
    format_str = '{}: step {}/{}, loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'

    # LR scheduling
    if args['lr_decay'] > 0:
        # learning rate changes on plateau -- no improvement on model for patience number of epochs
        # change is made as a factor of the learning rate decay
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(trainer.optimizer, mode='max', factor=args['lr_decay'],
                                                               patience=args['patience'], verbose=True, min_lr=args['min_lr'])
    else:
        scheduler = None

    if args['wandb']:
        import wandb
        wandb_name = args['wandb_name'] if args['wandb_name'] else "%s_ner" % args['shorthand']
        wandb.init(name=wandb_name, config=args)
        wandb.run.define_metric('train_loss', summary='min')
        wandb.run.define_metric('dev_score', summary='max')
        # track gradients!
        wandb.watch(trainer.model, log_freq=4, log="gradients")

    # start training
    last_best_step = 0
    train_loss = 0
    is_second_optim = False
    while True:
        should_stop = False
        for i, batch in enumerate(train_batch):
            start_time = time.time()
            global_step += 1
            loss = trainer.update(batch, eval=False) # update step
            train_loss += loss
            if global_step % args['log_step'] == 0:
                duration = time.time() - start_time
                logger.info(format_str.format(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), global_step,
                                              max_steps, loss, duration, current_lr))
            if global_step % args['eval_interval'] == 0:
                # eval on dev
                logger.info("Evaluating on dev set...")
                dev_preds = []
                for batch in dev_batch:
                    preds = trainer.predict(batch)
                    dev_preds += preds
                _, _, dev_score, _ = scorer.score_by_entity(dev_preds, dev_gold_tags, ignore_tags=args['ignore_tag_scores'])

                train_loss = train_loss / args['eval_interval'] # avg loss per batch
                logger.info("step {}: train_loss = {:.6f}, dev_score = {:.4f}".format(global_step, train_loss, dev_score))
                if args['wandb']:
                    wandb.log({'train_loss': train_loss, 'dev_score': dev_score})
                train_loss = 0

                # save best model
                if len(dev_score_history) == 0 or dev_score > max(dev_score_history):
                    trainer.save(model_file)
                    last_best_step = global_step
                    logger.info("New best model saved.")
                    best_dev_preds = dev_preds

                dev_score_history += [dev_score]
                logger.info("")

                # lr schedule
                if scheduler is not None:
                    scheduler.step(dev_score)

                if args['log_norms']:
                    trainer.model.log_norms()

            # check stopping
            current_lr = trainer.optimizer.param_groups[0]['lr']
            if (global_step - last_best_step) >= args['max_steps_no_improve'] or global_step >= args['max_steps'] or current_lr <= args['min_lr']:
                if (global_step - last_best_step) >= args['max_steps_no_improve']:
                    logger.info("{} steps without improvement...".format((global_step - last_best_step)))
                if not is_second_optim and args['second_optim'] is not None:
                    logger.info("Switching to second optimizer: {}".format(args['second_optim']))
                    logger.info('Reloading best model to continue from current local optimum')
                    trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, device=args['device'],
                                      train_classifier_only=args['train_classifier_only'], model_file=model_file, second_optim=True)
                    is_second_optim = True
                    last_best_step = global_step
                    current_lr = trainer.optimizer.param_groups[0]['lr']
                else:
                    logger.info("stopping...")
                    should_stop = True
                    break

        if should_stop:
            break

        train_batch.reshuffle()

    logger.info("Training ended with {} steps.".format(global_step))

    if args['wandb']:
        wandb.finish()

    if len(dev_score_history) > 0:
        best_f, best_eval = max(dev_score_history)*100, np.argmax(dev_score_history)+1
        logger.info("Best dev F1 = {:.2f}, at iteration = {}".format(best_f, best_eval * args['eval_interval']))
    else:
        logger.info("Dev set never evaluated. Saving final model.")
        trainer.save(model_file)

    return trainer
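The plateau scheduler created in train() is the stock PyTorch one. A standalone sketch of the same configuration, with a toy model and made-up dev scores purely for illustration (factor, patience, and min_lr mirror the argparse defaults above):

import torch
from torch import nn, optim

model = nn.Linear(4, 2)                      # toy model
optimizer = optim.SGD(model.parameters(), lr=0.1)
# mode='max' because a higher dev F1 is better
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5,
                                                       patience=3, min_lr=1e-4)
for dev_score in [0.70, 0.71, 0.71, 0.71, 0.71, 0.71]:   # score stalls after step 2
    scheduler.step(dev_score)
print(optimizer.param_groups[0]['lr'])       # 0.05 once patience runs out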
def write_ner_results(filename, batch, preds, predict_tagset):
    if len(batch.tags) != len(preds):
        raise ValueError("Unexpected batch vs pred lengths: %d vs %d" % (len(batch.tags), len(preds)))

    with open(filename, "w", encoding="utf-8") as fout:
        tag_idx = 0
        for b in batch:
            # b[0] is words, b[5] is orig_idx
            # a namedtuple would make this cleaner without being much slower
            text = utils.unsort(b[0], b[5])
            for sentence in text:
                # TODO: if we change the predict_tagset mechanism, will have to change this
                sentence_gold = [x[predict_tagset] for x in batch.tags[tag_idx]]
                sentence_pred = preds[tag_idx]
                tag_idx += 1
                for word, gold, pred in zip(sentence, sentence_gold, sentence_pred):
                    fout.write("%s\t%s\t%s\n" % (word, gold, pred))
                fout.write("\n")
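write_ner_results emits one token per line as tab-separated text/gold/pred, with a blank line between sentences. A small reader sketch for that format; the helper name is ours, not part of the diff:

# Parse the TSV produced by write_ner_results back into sentences.
def read_ner_results(filename):
    sentences, current = [], []
    with open(filename, encoding="utf-8") as fin:
        for line in fin:
            line = line.rstrip("\n")
            if not line:                     # blank line ends a sentence
                if current:
                    sentences.append(current)
                    current = []
            else:
                word, gold, pred = line.split("\t")
                current.append((word, gold, pred))
    if current:
        sentences.append(current)
    return sentences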

def evaluate(args):
    # file paths
    model_file = model_file_name(args)

    loaded_args, trainer, vocab = load_model(args, model_file)
    return evaluate_model(loaded_args, trainer, vocab, args['eval_file'])

def evaluate_model(loaded_args, trainer, vocab, eval_file):
    if loaded_args['log_norms']:
        trainer.model.log_norms()

    model_file = os.path.join(loaded_args['save_dir'], loaded_args['save_name'])
    logger.debug("Loaded model for eval from %s", model_file)
    logger.debug("Using the %d tagset for evaluation", loaded_args['predict_tagset'])

    # load data
    logger.info("Loading data with batch size {}...".format(loaded_args['batch_size']))
    with open(eval_file) as fin:
        doc = Document(json.load(fin))
    batch = DataLoader(doc, loaded_args['batch_size'], loaded_args, vocab=vocab, evaluation=True, bert_tokenizer=trainer.model.bert_tokenizer)
    bioes_to_bio = loaded_args['train_scheme'] == 'bio' and loaded_args['scheme'] == 'bioes'
    warn_missing_tags(trainer.vocab['tag'], batch.tags, "eval_file", bioes_to_bio=bioes_to_bio)

    logger.info("Start evaluation...")
    preds = []
    for i, b in enumerate(batch):
        preds += trainer.predict(b)

    gold_tags = batch.tags
    # TODO: might still want to add multiple layers of tag evaluation to the scorer
    gold_tags = [[x[trainer.args['predict_tagset']] for x in tags] for tags in gold_tags]

    _, _, score, entity_f1 = scorer.score_by_entity(preds, gold_tags, ignore_tags=loaded_args['ignore_tag_scores'])
    _, _, _, confusion = scorer.score_by_token(preds, gold_tags, ignore_tags=loaded_args['ignore_tag_scores'])
    logger.info("Weighted f1 for non-O tokens: %5f", confusion_to_weighted_f1(confusion, exclude=["O"]))

    logger.info("NER tagger score: %s %s %s %.2f", loaded_args['shorthand'], model_file, eval_file, score*100)
    entity_f1_lines = ["%s: %.2f" % (x, y*100) for x, y in entity_f1.items()]
    logger.info("NER Entity F1 scores:\n %s", "\n ".join(entity_f1_lines))
    logger.info("NER token confusion matrix:\n{}".format(format_confusion(confusion)))

    if loaded_args['eval_output_file']:
        write_ner_results(loaded_args['eval_output_file'], batch, preds, trainer.args['predict_tagset'])

    return confusion

def load_model(args, model_file):
    # load model
    charlm_args = {}
    if 'charlm_forward_file' in args:
        charlm_args['charlm_forward_file'] = args['charlm_forward_file']
    if 'charlm_backward_file' in args:
        charlm_args['charlm_backward_file'] = args['charlm_backward_file']
    if args['predict_tagset'] is not None:
        charlm_args['predict_tagset'] = args['predict_tagset']
    pretrain = load_pretrain(args)
    trainer = Trainer(args=charlm_args, model_file=model_file, pretrain=pretrain, device=args['device'], train_classifier_only=args['train_classifier_only'])
    loaded_args, vocab = trainer.args, trainer.vocab

    # load config
    for k in args:
        if k.endswith('_dir') or k.endswith('_file') or k in ['batch_size', 'ignore_tag_scores', 'log_norms', 'mode', 'scheme', 'shorthand']:
            loaded_args[k] = args[k]
    save_dir, save_name = os.path.split(model_file)
    args['save_dir'] = save_dir
    args['save_name'] = save_name
    return loaded_args, trainer, vocab


if __name__ == '__main__':
    main()
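A sketch of running prediction with this script; main() accepts a flag list (see parse_args above), the flags come from build_argparse, and every path and the dataset shorthand below are placeholders:

# Hypothetical end-to-end prediction run; all paths are placeholders.
main(['--mode', 'predict',
      '--shorthand', 'en_sample',
      '--eval_file', 'data/ner/en_sample.test.json',
      '--wordvec_pretrain_file', 'saved_models/pretrain/en_sample.pt',
      '--save_dir', 'saved_models/ner',
      '--eval_output_file', '/tmp/en_sample.test.tsv'])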
stanza/stanza/models/tagger.py
ADDED
@@ -0,0 +1,461 @@
"""
Entry point for training and evaluating a POS/morphological features tagger.

This tagger uses highway BiLSTM layers with character and word-level representations, and biaffine classifiers
to produce consistent POS and UFeats predictions.
For details please refer to paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf.
"""

import argparse
import logging
import os
import time
import zipfile

import numpy as np
import torch
from torch import nn, optim

from stanza.models.pos.data import Dataset, ShuffledDataset
from stanza.models.pos.trainer import Trainer
from stanza.models.pos import scorer
from stanza.models.common import utils
from stanza.models.common import pretrain
from stanza.models.common.doc import *
from stanza.models.common.foundation_cache import FoundationCache
from stanza.models.common.peft_config import add_peft_args, resolve_peft_args
from stanza.models import _training_logging
from stanza.utils.conll import CoNLL

logger = logging.getLogger('stanza')

def build_argparse():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='data/pos', help='Root dir for saving models.')
    parser.add_argument('--wordvec_dir', type=str, default='extern_data/wordvec', help='Directory of word vectors.')
    parser.add_argument('--wordvec_file', type=str, default=None, help='Word vectors filename.')
    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
    parser.add_argument('--train_file', type=str, default=None, help='Input file for training.')
    parser.add_argument('--eval_file', type=str, default=None, help='Input file for scoring.')
    parser.add_argument('--output_file', type=str, default=None, help='Output CoNLL-U file.')
    parser.add_argument('--no_gold_labels', dest='gold_labels', action='store_false', help="Don't score the eval file - perhaps it has no gold labels, for example. Cannot be used at training time")

    parser.add_argument('--mode', default='train', choices=['train', 'predict'])
    parser.add_argument('--lang', type=str, help='Language')
    parser.add_argument('--shorthand', type=str, help="Treebank shorthand")

    parser.add_argument('--hidden_dim', type=int, default=200)
    parser.add_argument('--char_hidden_dim', type=int, default=400)
    parser.add_argument('--deep_biaff_hidden_dim', type=int, default=400)
    parser.add_argument('--composite_deep_biaff_hidden_dim', type=int, default=100)
    parser.add_argument('--word_emb_dim', type=int, default=75, help='Dimension of the finetuned word embedding. Set to 0 to turn off')
    parser.add_argument('--word_cutoff', type=int, default=7, help='How common a word must be to include it in the finetuned word embedding')
    parser.add_argument('--char_emb_dim', type=int, default=100)
    parser.add_argument('--tag_emb_dim', type=int, default=50)
    parser.add_argument('--charlm_transform_dim', type=int, default=None, help='Transform the pretrained charlm to this dimension. If not set, no transform is used')
    parser.add_argument('--transformed_dim', type=int, default=125)
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--char_num_layers', type=int, default=1)
    parser.add_argument('--pretrain_max_vocab', type=int, default=250000)
    parser.add_argument('--word_dropout', type=float, default=0.33)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--rec_dropout', type=float, default=0, help="Recurrent dropout")
    parser.add_argument('--char_rec_dropout', type=float, default=0, help="Recurrent dropout")

    # TODO: refactor charlm arguments for models which use it?
    parser.add_argument('--no_char', dest='char', action='store_false', help="Turn off character model.")
    parser.add_argument('--char_bidirectional', dest='char_bidirectional', action='store_true', help="Use a bidirectional version of the non-pretrained charlm. Doesn't help much, makes the models larger")
    parser.add_argument('--char_lowercase', dest='char_lowercase', action='store_true', help="Use lowercased characters in character model.")
    parser.add_argument('--charlm', action='store_true', help="Turn on contextualized char embedding using pretrained character-level language model.")
    parser.add_argument('--charlm_save_dir', type=str, default='saved_models/charlm', help="Root dir for pretrained character-level language model.")
    parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.")
    parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm")
    parser.add_argument('--charlm_backward_file', type=str, default=None, help="Exact path to use for backward charlm")

    parser.add_argument('--bert_model', type=str, default=None, help="Use an external bert model (requires the transformers package)")
    parser.add_argument('--no_bert_model', dest='bert_model', action="store_const", const=None, help="Don't use bert")
    parser.add_argument('--bert_hidden_layers', type=int, default=None, help="How many layers of hidden state to use from the transformer")
    parser.add_argument('--bert_finetune', default=False, action='store_true', help='Finetune the bert (or other transformer)')
    parser.add_argument('--no_bert_finetune', dest='bert_finetune', action='store_false', help="Don't finetune the bert (or other transformer)")
    parser.add_argument('--bert_learning_rate', default=1.0, type=float, help='Scale the learning rate for transformer finetuning by this much')

    parser.add_argument('--no_pretrain', dest='pretrain', action='store_false', help="Turn off pretrained embeddings.")
    parser.add_argument('--share_hid', action='store_true', help="Share hidden representations for UPOS, XPOS and UFeats.")
    parser.set_defaults(share_hid=False)

    parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.')
    parser.add_argument('--optim', type=str, default='adam', help='sgd, adagrad, adam, adamw, adamax, or adadelta. madgrad as an optional dependency')
    parser.add_argument('--second_optim', type=str, default='amsgrad', help='Optimizer for the second half of training. Default is Adam with AMSGrad')
    parser.add_argument('--second_optim_reload', default=False, action='store_true', help='Reload the best model instead of continuing from current model if the first optimizer stalls out. This does not seem to help, but might be useful for further experiments')
    parser.add_argument('--no_second_optim', action='store_const', const=None, dest='second_optim', help="Don't use a second optimizer - only use the first optimizer")
    parser.add_argument('--lr', type=float, default=3e-3, help='Learning rate')
    parser.add_argument('--second_lr', type=float, default=None, help='Alternate learning rate for the second optimizer')
    parser.add_argument('--initial_weight_decay', type=float, default=None, help='Optimizer weight decay for the first optimizer')
    parser.add_argument('--second_weight_decay', type=float, default=None, help='Optimizer weight decay for the second optimizer')
    parser.add_argument('--beta2', type=float, default=0.95)

    parser.add_argument('--max_steps', type=int, default=50000)
    parser.add_argument('--eval_interval', type=int, default=100)
    parser.add_argument('--fix_eval_interval', dest='adapt_eval_interval', action='store_false', \
                        help="Use fixed evaluation interval for all treebanks, otherwise by default the interval will be increased for larger treebanks.")
    parser.add_argument('--max_steps_before_stop', type=int, default=3000, help='Changes learning method or early terminates after this many steps if the dev scores are not improving')
    parser.add_argument('--batch_size', type=int, default=250)
    parser.add_argument('--batch_maximum_tokens', type=int, default=5000, help='When run in a Pipeline, limit a batch to this many tokens to help avoid OOM for long sentences')
    parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Gradient clipping.')
    parser.add_argument('--log_step', type=int, default=20, help='Print log every k steps.')
    parser.add_argument('--log_norms', action='store_true', default=False, help='Log the norms of all the parameters (noisy!)')
    parser.add_argument('--save_dir', type=str, default='saved_models/pos', help='Root dir for saving models.')
    parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_tagger.pt", help="File name to save the model")
    parser.add_argument('--save_each', default=False, action='store_true', help="Save each checkpoint to its own model. Will take up a bunch of space")

    parser.add_argument('--seed', type=int, default=1234)
    add_peft_args(parser)
    utils.add_device_args(parser)

    parser.add_argument('--augment_nopunct', type=float, default=None, help='Augment the training data by copying this fraction of punct-ending sentences as non-punct. Default of None will aim for roughly 50%%')

    parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name')
    parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. Will default to the dataset short name')
    return parser

def parse_args(args=None):
    parser = build_argparse()
    args = parser.parse_args(args=args)
    resolve_peft_args(args, logger)

    if args.augment_nopunct is None:
        args.augment_nopunct = 0.25

    if args.wandb_name:
        args.wandb = True

    args = vars(args)
    return args

def main(args=None):
    args = parse_args(args=args)

    utils.set_random_seed(args['seed'])

    logger.info("Running tagger in {} mode".format(args['mode']))

    if args['mode'] == 'train':
        train(args)
    else:
        evaluate(args)

def model_file_name(args):
    return utils.standard_model_file_name(args, "tagger")

def save_each_file_name(args):
    model_file = model_file_name(args)
    pieces = os.path.splitext(model_file)
    return pieces[0] + "_%05d" + pieces[1]
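save_each_file_name turns the model filename into a %-format pattern so each checkpoint gets a zero-padded step number. For example (the path is just an illustration):

import os
pieces = os.path.splitext("saved_models/pos/en_ewt_tagger.pt")
pattern = pieces[0] + "_%05d" + pieces[1]      # 'saved_models/pos/en_ewt_tagger_%05d.pt'
print(pattern % 500)                           # saved_models/pos/en_ewt_tagger_00500.pt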

def load_pretrain(args):
    pt = None
    if args['pretrain']:
        pretrain_file = pretrain.find_pretrain_file(args['wordvec_pretrain_file'], args['save_dir'], args['shorthand'], args['lang'])
        if os.path.exists(pretrain_file):
            vec_file = None
        else:
            vec_file = args['wordvec_file'] if args['wordvec_file'] else utils.get_wordvec_file(args['wordvec_dir'], args['shorthand'])
        pt = pretrain.Pretrain(pretrain_file, vec_file, args['pretrain_max_vocab'])
    return pt

def get_eval_type(dev_batch):
    """
    If there is only one column to score in the dev set, use that instead of AllTags
    """
    if dev_batch.has_xpos and not dev_batch.has_upos and not dev_batch.has_feats:
        return "XPOS"
    elif dev_batch.has_upos and not dev_batch.has_xpos and not dev_batch.has_feats:
        return "UPOS"
    else:
        return "AllTags"
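get_eval_type only inspects the has_* flags, so its behavior is easy to see with a stub; the namedtuple here is illustrative, not the real Dataset:

from collections import namedtuple

# Stub carrying just the flags get_eval_type inspects.
Flags = namedtuple('Flags', 'has_upos has_xpos has_feats')
print(get_eval_type(Flags(False, True, False)))   # XPOS
print(get_eval_type(Flags(True, False, False)))   # UPOS
print(get_eval_type(Flags(True, True, True)))     # AllTags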

def load_training_data(args, pretrain):
    train_docs = []
    raw_train_files = args['train_file'].split(";")
    train_files = []
    for train_file in raw_train_files:
        if zipfile.is_zipfile(train_file):
            logger.info("Decompressing %s" % train_file)
            with zipfile.ZipFile(train_file) as zin:
                for zipped_train_file in zin.namelist():
                    with zin.open(zipped_train_file) as fin:
                        logger.info("Reading %s from %s" % (zipped_train_file, train_file))
                        train_str = fin.read()
                        train_str = train_str.decode("utf-8")
                        train_file_data, _, _ = CoNLL.conll2dict(input_str=train_str)
                        logger.info("Train File {} from {}, Data Size: {}".format(zipped_train_file, train_file, len(train_file_data)))
                        train_docs.append(Document(train_file_data))
                        train_files.append("%s %s" % (train_file, zipped_train_file))
        else:
            logger.info("Reading %s" % train_file)
            # train_data is now a list of sentences, where each sentence is a
            # list of words, in which each word is a dict of conll attributes
            train_file_data, _, _ = CoNLL.conll2dict(input_file=train_file)
            logger.info("Train File {}, Data Size: {}".format(train_file, len(train_file_data)))
            train_docs.append(Document(train_file_data))
            train_files.append(train_file)
    if sum(len(x.sentences) for x in train_docs) == 0:
        raise RuntimeError("Training data for the tagger is empty: %s" % args['train_file'])
    # we want to ensure that the model is able to output _ for empty columns,
    # but create batches whereby if a doc has upos/xpos tags we include them all.
    # therefore, we create separate datasets and loaders for each input training file,
    # which will ensure the system is able to see batches with upos both available
    # and unavailable, depending on the availability in each file.
    vocab = Dataset.init_vocab(train_docs, args)
    train_data = [Dataset(i, args, pretrain, vocab=vocab, evaluation=False)
                  for i in train_docs]
    for train_file, td in zip(train_files, train_data):
        if not td.has_upos:
            logger.info("No UPOS in %s" % train_file)
        if not td.has_xpos:
            logger.info("No XPOS in %s" % train_file)
        if not td.has_feats:
            logger.info("No feats in %s" % train_file)

    # reject partially tagged upos or xpos documents
    # otherwise, the model will learn to output blanks for some words,
    # which is probably a confusing result
    # (and definitely throws off the depparse)
    # another option would be to treat those as masked out
    for td_idx, td in enumerate(train_data):
        if td.has_upos:
            upos_data = td.doc.get(UPOS, as_sentences=True)
            for sentence_idx, sentence in enumerate(upos_data):
                for word_idx, upos in enumerate(sentence):
                    if upos == '_' or upos is None:
                        conll = "{:C}".format(td.doc.sentences[sentence_idx])
                        raise RuntimeError("Found a blank tag in the UPOS at sentence %d word %d of %s.\n%s" % ((sentence_idx+1), (word_idx+1), train_files[td_idx], conll))

    # here we make sure the model will learn to output _ for empty columns
    # if *any* dataset has data for the upos, xpos, or feature column,
    # we consider that data enough to train the model on that column
    # otherwise, we want to train the model to always output blanks
    if not any(td.has_upos for td in train_data):
        for td in train_data:
            td.has_upos = True
    if not any(td.has_xpos for td in train_data):
        for td in train_data:
            td.has_xpos = True
    if not any(td.has_feats for td in train_data):
        for td in train_data:
            td.has_feats = True
    # calculate the batches
    train_batches = ShuffledDataset(train_data, args["batch_size"])
    return vocab, train_data, train_batches
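The zip branch in load_training_data uses only the standard library; the same pattern in isolation, with our own helper name and a placeholder path:

import zipfile

# Read every member of a zip archive as UTF-8 text, as load_training_data does.
def read_zipped_texts(path):
    texts = {}
    if zipfile.is_zipfile(path):
        with zipfile.ZipFile(path) as zin:
            for name in zin.namelist():
                with zin.open(name) as fin:
                    texts[name] = fin.read().decode("utf-8")
    return texts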

def train(args):
    model_file = model_file_name(args)
    utils.ensure_dir(os.path.split(model_file)[0])

    if args['save_each']:
        # so models.pt -> models_0001.pt, etc
        model_save_each_file = save_each_file_name(args)
        logger.info("Saving each checkpoint to %s" % model_save_each_file)

    # load pretrained vectors if needed
    pretrain = load_pretrain(args)

    if args['charlm']:
        if args['charlm_shorthand'] is None:
            raise ValueError("CharLM Shorthand is required for loading pretrained CharLM model...")
        logger.info('Using pretrained contextualized char embedding')
        if not args['charlm_forward_file']:
            args['charlm_forward_file'] = '{}/{}_forward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand'])
        if not args['charlm_backward_file']:
            args['charlm_backward_file'] = '{}/{}_backward_charlm.pt'.format(args['charlm_save_dir'], args['charlm_shorthand'])

    # load data
    logger.info("Loading data with batch size {}...".format(args['batch_size']))
    vocab, train_data, train_batches = load_training_data(args, pretrain)

    dev_doc = CoNLL.conll2doc(input_file=args['eval_file'])
    dev_data = Dataset(dev_doc, args, pretrain, vocab=vocab, evaluation=True, sort_during_eval=True)
    dev_batch = dev_data.to_loader(batch_size=args["batch_size"])

    eval_type = get_eval_type(dev_data)

    # pred and gold path
    system_pred_file = args['output_file']

    # skip training if the language does not have training or dev data
    # sum(...) to check if all of the training files are empty
    if sum(len(td) for td in train_data) == 0 or len(dev_data) == 0:
        logger.info("Skip training because no data available...")
        return

    if args['wandb']:
        import wandb
        wandb_name = args['wandb_name'] if args['wandb_name'] else "%s_tagger" % args['shorthand']
        wandb.init(name=wandb_name, config=args)
        wandb.run.define_metric('train_loss', summary='min')
        wandb.run.define_metric('dev_score', summary='max')

    logger.info("Training tagger...")
    foundation_cache = FoundationCache()
    trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, device=args['device'], foundation_cache=foundation_cache)

    global_step = 0
    max_steps = args['max_steps']
    dev_score_history = []
    best_dev_preds = []
    current_lr = args['lr']
    global_start_time = time.time()
    format_str = 'Finished STEP {}/{}, loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'

    if args['adapt_eval_interval']:
        args['eval_interval'] = utils.get_adaptive_eval_interval(dev_data.num_examples, 2000, args['eval_interval'])
        logger.info("Evaluating the model every {} steps...".format(args['eval_interval']))

    if args['save_each']:
        logger.info("Saving initial checkpoint to %s" % (model_save_each_file % global_step))
        trainer.save(model_save_each_file % global_step)

    using_amsgrad = False
    last_best_step = 0
    # start training
    train_loss = 0
    if args['log_norms']:
        trainer.model.log_norms()
    while True:
        do_break = False
        for i, batch in enumerate(train_batches):
            start_time = time.time()
            global_step += 1
            loss = trainer.update(batch, eval=False) # update step
            train_loss += loss
            if global_step % args['log_step'] == 0:
                duration = time.time() - start_time
                logger.info(format_str.format(global_step, max_steps, loss, duration, current_lr))
                if args['log_norms']:
                    trainer.model.log_norms()

            if global_step % args['eval_interval'] == 0:
                # eval on dev
                logger.info("Evaluating on dev set...")
                dev_preds = []
                indices = []
                for batch in dev_batch:
                    preds = trainer.predict(batch)
                    dev_preds += preds
                    indices.extend(batch[-1])
                dev_preds = utils.unsort(dev_preds, indices)
                dev_data.doc.set([UPOS, XPOS, FEATS], [y for x in dev_preds for y in x])
                CoNLL.write_doc2conll(dev_data.doc, system_pred_file)

                _, _, dev_score = scorer.score(system_pred_file, args['eval_file'], eval_type=eval_type)

                train_loss = train_loss / args['eval_interval'] # avg loss per batch
                logger.info("step {}: train_loss = {:.6f}, dev_score = {:.4f}".format(global_step, train_loss, dev_score))

                if args['wandb']:
                    wandb.log({'train_loss': train_loss, 'dev_score': dev_score})

                train_loss = 0

                if args['save_each']:
                    logger.info("Saving checkpoint to %s" % (model_save_each_file % global_step))
                    trainer.save(model_save_each_file % global_step)

                # save best model
                if len(dev_score_history) == 0 or dev_score > max(dev_score_history):
                    last_best_step = global_step
                    trainer.save(model_file)
                    logger.info("new best model saved.")
                    best_dev_preds = dev_preds

                dev_score_history += [dev_score]

            if global_step - last_best_step >= args['max_steps_before_stop']:
                if not using_amsgrad and args['second_optim'] is not None:
                    logger.info("Switching to second optimizer: {}".format(args['second_optim']))
                    if args['second_optim_reload']:
                        logger.info('Reloading best model to continue from current local optimum')
                        trainer = Trainer(args=args, vocab=trainer.vocab, pretrain=pretrain, model_file=model_file, device=args['device'], foundation_cache=foundation_cache)
                    last_best_step = global_step
                    using_amsgrad = True
                    lr = args['second_lr']
                    if lr is None:
                        lr = args['lr']
                    trainer.optimizer = utils.get_optimizer(args['second_optim'], trainer.model, lr=lr, betas=(.9, args['beta2']), eps=1e-6, weight_decay=args['second_weight_decay'])
                else:
                    logger.info("Early termination: have not improved in {} steps".format(args['max_steps_before_stop']))
                    do_break = True
                    break

            if global_step >= args['max_steps']:
                do_break = True
                break

        if do_break: break

    logger.info("Training ended with {} steps.".format(global_step))
|
| 397 |
+
|
| 398 |
+
if args['wandb']:
|
| 399 |
+
wandb.finish()
|
| 400 |
+
|
| 401 |
+
if len(dev_score_history) > 0:
|
| 402 |
+
best_f, best_eval = max(dev_score_history)*100, np.argmax(dev_score_history)+1
|
| 403 |
+
logger.info("Best dev F1 = {:.2f}, at iteration = {}".format(best_f, best_eval * args['eval_interval']))
|
| 404 |
+
else:
|
| 405 |
+
logger.info("Dev set never evaluated. Saving final model.")
|
| 406 |
+
trainer.save(model_file)
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
def evaluate(args):
|
| 410 |
+
# file paths
|
| 411 |
+
model_file = model_file_name(args)
|
| 412 |
+
|
| 413 |
+
pretrain = load_pretrain(args)
|
| 414 |
+
|
| 415 |
+
load_args = {'charlm_forward_file': args.get('charlm_forward_file', None),
|
| 416 |
+
'charlm_backward_file': args.get('charlm_backward_file', None)}
|
| 417 |
+
|
| 418 |
+
# load model
|
| 419 |
+
logger.info("Loading model from: {}".format(model_file))
|
| 420 |
+
trainer = Trainer(pretrain=pretrain, model_file=model_file, device=args['device'], args=load_args)
|
| 421 |
+
evaluate_trainer(args, trainer, pretrain)
|
| 422 |
+
|
| 423 |
+
def evaluate_trainer(args, trainer, pretrain):
|
| 424 |
+
system_pred_file = args['output_file']
|
| 425 |
+
loaded_args, vocab = trainer.args, trainer.vocab
|
| 426 |
+
|
| 427 |
+
# load config
|
| 428 |
+
for k in args:
|
| 429 |
+
if k.endswith('_dir') or k.endswith('_file') or k in ['shorthand'] or k == 'mode':
|
| 430 |
+
loaded_args[k] = args[k]
|
| 431 |
+
|
| 432 |
+
# load data
|
| 433 |
+
logger.info("Loading data with batch size {}...".format(args['batch_size']))
|
| 434 |
+
doc = CoNLL.conll2doc(input_file=args['eval_file'])
|
| 435 |
+
dev_data = Dataset(doc, loaded_args, pretrain, vocab=vocab, evaluation=True, sort_during_eval=True)
|
| 436 |
+
dev_batch = dev_data.to_loader(batch_size=args['batch_size'])
|
| 437 |
+
eval_type = get_eval_type(dev_data)
|
| 438 |
+
if len(dev_batch) > 0:
|
| 439 |
+
logger.info("Start evaluation...")
|
| 440 |
+
preds = []
|
| 441 |
+
indices = []
|
| 442 |
+
with torch.no_grad():
|
| 443 |
+
for b in dev_batch:
|
| 444 |
+
preds += trainer.predict(b)
|
| 445 |
+
indices.extend(b[-1])
|
| 446 |
+
else:
|
| 447 |
+
# skip eval if dev data does not exist
|
| 448 |
+
preds = []
|
| 449 |
+
preds = utils.unsort(preds, indices)
|
| 450 |
+
|
| 451 |
+
# write to file and score
|
| 452 |
+
dev_data.doc.set([UPOS, XPOS, FEATS], [y for x in preds for y in x])
|
| 453 |
+
CoNLL.write_doc2conll(dev_data.doc, system_pred_file)
|
| 454 |
+
|
| 455 |
+
if args['gold_labels']:
|
| 456 |
+
_, _, score = scorer.score(system_pred_file, args['eval_file'], eval_type=eval_type)
|
| 457 |
+
|
| 458 |
+
logger.info("POS Tagger score: %s %.2f", args['shorthand'], score*100)
|
| 459 |
+
|
| 460 |
+
if __name__ == '__main__':
|
| 461 |
+
main()
|
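A minimal sketch of driving the tagger training entry point above, for orientation rather than as part of the diff. The flags mirror the args[...] keys used in train() and evaluate(); the en_ewt shorthand, the data paths, and the --train_file flag are illustrative assumptions not confirmed by this hunk.

# Hedged sketch: run the tagger trainer as a module so its argparse entry
# point handles the flags. Paths and the treebank shorthand are placeholders.
import subprocess

subprocess.run([
    "python", "-m", "stanza.models.tagger",
    "--mode", "train",
    "--shorthand", "en_ewt",                             # assumed treebank shorthand
    "--train_file", "data/pos/en_ewt.train.in.conllu",   # assumed flag and path
    "--eval_file", "data/pos/en_ewt.dev.in.conllu",      # assumed path
    "--output_file", "data/pos/en_ewt.dev.pred.conllu",  # assumed path
], check=True)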
stanza/stanza/models/tokenizer.py
ADDED
@@ -0,0 +1,258 @@
+"""
+Entry point for training and evaluating a neural tokenizer.
+
+This tokenizer treats tokenization and sentence segmentation as a tagging problem, and uses a combination of
+recurrent and convolutional architectures.
+For details please refer to the paper: https://nlp.stanford.edu/pubs/qi2018universal.pdf.
+
+Updated: this version of the tokenizer model incorporates a dictionary feature, which is especially useful for
+languages with multi-syllable words such as Vietnamese, Chinese, or Thai. In summary, a lexicon containing all
+unique words found in the training dataset and in an external lexicon (if any) is created during training and
+saved alongside the model. From this lexicon, a dictionary is created which includes "words", "prefixes", and
+"suffixes" sets. During data preparation, dictionary features are extracted at each character position by
+"looking ahead" and "looking backward" to see whether any of the words formed are found in the dictionary. The
+window size (the dictionary feature length) is set at the 95th percentile of word lengths in the lexicon; this
+eliminates the infrequent but long words (and avoids a high-dimensional feature vector). Prefixes and suffixes
+are used to stop early during the window-dictionary checking process.
+"""
+
+import argparse
+from copy import copy
+import logging
+import random
+import numpy as np
+import os
+import torch
+import json
+from stanza.models.common import utils
+from stanza.models.tokenization.trainer import Trainer
+from stanza.models.tokenization.data import DataLoader, TokenizationDataset
+from stanza.models.tokenization.utils import load_mwt_dict, eval_model, output_predictions, load_lexicon, create_dictionary
+from stanza.models import _training_logging
+
+logger = logging.getLogger('stanza')
+
+def build_argparse():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--txt_file', type=str, help="Input plaintext file")
+    parser.add_argument('--label_file', type=str, default=None, help="Character-level label file")
+    parser.add_argument('--mwt_json_file', type=str, default=None, help="JSON file for MWT expansions")
+    parser.add_argument('--conll_file', type=str, default=None, help="CoNLL file for output")
+    parser.add_argument('--dev_txt_file', type=str, help="(Train only) Input plaintext file for the dev set")
+    parser.add_argument('--dev_label_file', type=str, default=None, help="(Train only) Character-level label file for the dev set")
+    parser.add_argument('--dev_conll_gold', type=str, default=None, help="(Train only) CoNLL-U file for the dev set for early stopping")
+    parser.add_argument('--lang', type=str, help="Language")
+    parser.add_argument('--shorthand', type=str, help="UD treebank shorthand")
+
+    parser.add_argument('--mode', default='train', choices=['train', 'predict'])
+    parser.add_argument('--skip_newline', action='store_true', help="Whether to skip newline characters in input. Particularly useful for languages like Chinese.")
+
+    parser.add_argument('--emb_dim', type=int, default=32, help="Dimension of unit embeddings")
+    parser.add_argument('--hidden_dim', type=int, default=64, help="Dimension of hidden units")
+    parser.add_argument('--conv_filters', type=str, default="1,9", help="Configuration of conv filters. ,, separates layers and , separates filter sizes in the same layer.")
+    parser.add_argument('--no-residual', dest='residual', action='store_false', help="Do not add linear residual connections")
+    parser.add_argument('--no-hierarchical', dest='hierarchical', action='store_false', help="\"Hierarchical\" RNN tokenizer")
+    parser.add_argument('--hier_invtemp', type=float, default=0.5, help="Inverse temperature used in propagating tokenization predictions between RNN layers")
+    parser.add_argument('--input_dropout', action='store_true', help="Dropout input embeddings as well")
+    parser.add_argument('--conv_res', type=str, default=None, help="Convolutional residual layers for the RNN")
+    parser.add_argument('--rnn_layers', type=int, default=1, help="Layers of RNN in the tokenizer")
+    parser.add_argument('--use_dictionary', action='store_true', help="Use the dictionary feature. The lexicon is created from the training data and an external dict (if any), expected to be found in the same folder as the training dataset and formatted as SHORTHAND-externaldict.txt, where each line in the file is a word. For example, data/tokenize/zh_gsdsimp-externaldict.txt")
+
+    parser.add_argument('--max_grad_norm', type=float, default=1.0, help="Maximum gradient norm to clip to")
+    parser.add_argument('--anneal', type=float, default=.999, help="Anneal the learning rate by this amount when dev performance deteriorates")
+    parser.add_argument('--anneal_after', type=int, default=2000, help="Anneal the learning rate no earlier than this step")
+    parser.add_argument('--lr0', type=float, default=2e-3, help="Initial learning rate")
+    parser.add_argument('--dropout', type=float, default=0.33, help="Dropout probability")
+    parser.add_argument('--unit_dropout', type=float, default=0.33, help="Unit dropout probability")
+    parser.add_argument('--feat_dropout', type=float, default=0.05, help="Features dropout probability for each element in feature vector")
+    parser.add_argument('--feat_unit_dropout', type=float, default=0.33, help="The whole feature of units dropout probability")
+    parser.add_argument('--tok_noise', type=float, default=0.02, help="Probability to induce noise to the input of the higher RNN")
+    parser.add_argument('--sent_drop_prob', type=float, default=0.2, help="Probability to drop sentences at the end of batches during training uniformly at random. Idea is to fake paragraph endings.")
+    parser.add_argument('--last_char_drop_prob', type=float, default=0.2, help="Probability to drop the last char of a block of text during training, uniformly at random. Idea is to fake a document ending w/o sentence final punctuation, hopefully to avoid the tokenizer learning to always tokenize the last character as a period")
+    parser.add_argument('--weight_decay', type=float, default=0.0, help="Weight decay")
+    parser.add_argument('--max_seqlen', type=int, default=100, help="Maximum sequence length to consider at a time")
+    parser.add_argument('--batch_size', type=int, default=32, help="Batch size to use")
+    parser.add_argument('--epochs', type=int, default=10, help="Total epochs to train the model for")
+    parser.add_argument('--steps', type=int, default=50000, help="Steps to train the model for, if unspecified use epochs")
+    parser.add_argument('--report_steps', type=int, default=20, help="Update step interval to report loss")
+    parser.add_argument('--shuffle_steps', type=int, default=100, help="Step interval to shuffle each paragraph in the generator")
+    parser.add_argument('--eval_steps', type=int, default=200, help="Step interval to evaluate the model on the dev set for early stopping")
+    parser.add_argument('--max_steps_before_stop', type=int, default=5000, help='Early terminates after this many steps if the dev scores are not improving')
+    parser.add_argument('--save_name', type=str, default=None, help="File name to save the model")
+    parser.add_argument('--load_name', type=str, default=None, help="File name to load a saved model")
+    parser.add_argument('--save_dir', type=str, default='saved_models/tokenize', help="Directory to save models in")
+    utils.add_device_args(parser)
+    parser.add_argument('--seed', type=int, default=1234)
+
+    parser.add_argument('--use_mwt', dest='use_mwt', default=None, action='store_true', help='Whether or not to include mwt output layers. If set to None, this will be determined by examining the training data for MWTs')
+    parser.add_argument('--no_use_mwt', dest='use_mwt', action='store_false', help='Whether or not to include mwt output layers')
+
+    parser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name')
+    parser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. Will default to the dataset short name')
+    return parser
+
+def parse_args(args=None):
+    """
+    If args == None, the system args are used.
+    """
+    parser = build_argparse()
+    args = parser.parse_args(args=args)
+
+    if args.wandb_name:
+        args.wandb = True
+
+    args = vars(args)
+    return args
+
+def model_file_name(args):
+    if args['save_name'] is not None:
+        save_name = args['save_name']
+    else:
+        save_name = args['shorthand'] + "_tokenizer.pt"
+
+    if not os.path.exists(os.path.join(args['save_dir'], save_name)) and os.path.exists(save_name):
+        return save_name
+    return os.path.join(args['save_dir'], save_name)
+
+def main(args=None):
+    args = parse_args(args=args)
+
+    utils.set_random_seed(args['seed'])
+
+    logger.info("Running tokenizer in {} mode".format(args['mode']))
+
+    args['feat_funcs'] = ['space_before', 'capitalized', 'numeric', 'end_of_para', 'start_of_para']
+    args['feat_dim'] = len(args['feat_funcs'])
+    args['save_name'] = model_file_name(args)
+    utils.ensure_dir(os.path.split(args['save_name'])[0])
+
+    if args['mode'] == 'train':
+        train(args)
+    else:
+        evaluate(args)
+
+def train(args):
+    if args['use_dictionary']:
+        # load lexicon
+        lexicon, args['num_dict_feat'] = load_lexicon(args)
+        # create the dictionary
+        dictionary = create_dictionary(lexicon)
+        # adjust the feat_dim
+        args['feat_dim'] += args['num_dict_feat'] * 2
+    else:
+        args['num_dict_feat'] = 0
+        lexicon = None
+        dictionary = None
+
+    mwt_dict = load_mwt_dict(args['mwt_json_file'])
+
+    train_input_files = {
+        'txt': args['txt_file'],
+        'label': args['label_file']
+    }
+    train_batches = DataLoader(args, input_files=train_input_files, dictionary=dictionary)
+    vocab = train_batches.vocab
+
+    args['vocab_size'] = len(vocab)
+
+    dev_input_files = {
+        'txt': args['dev_txt_file'],
+        'label': args['dev_label_file']
+    }
+    dev_batches = TokenizationDataset(args, input_files=dev_input_files, vocab=vocab, evaluation=True, dictionary=dictionary)
+
+    if args['use_mwt'] is None:
+        args['use_mwt'] = train_batches.has_mwt()
+        logger.info("Found {}mwts in the training data. Setting use_mwt to {}".format(("" if args['use_mwt'] else "no "), args['use_mwt']))
+
+    trainer = Trainer(args=args, vocab=vocab, lexicon=lexicon, dictionary=dictionary, device=args['device'])
+
+    if args['load_name'] is not None:
+        load_name = os.path.join(args['save_dir'], args['load_name'])
+        trainer.load(load_name)
+    trainer.change_lr(args['lr0'])
+
+    N = len(train_batches)
+    steps = args['steps'] if args['steps'] is not None else int(N * args['epochs'] / args['batch_size'] + .5)
+    lr = args['lr0']
+
+    prev_dev_score = -1
+    best_dev_score = -1
+    best_dev_step = -1
+
+    if args['wandb']:
+        import wandb
+        wandb_name = args['wandb_name'] if args['wandb_name'] else "%s_tokenizer" % args['shorthand']
+        wandb.init(name=wandb_name, config=args)
+        wandb.run.define_metric('train_loss', summary='min')
+        wandb.run.define_metric('dev_score', summary='max')
+
+
+    for step in range(1, steps+1):
+        batch = train_batches.next(unit_dropout=args['unit_dropout'], feat_unit_dropout=args['feat_unit_dropout'])
+
+        loss = trainer.update(batch)
+        if step % args['report_steps'] == 0:
+            logger.info("Step {:6d}/{:6d} Loss: {:.3f}".format(step, steps, loss))
+            if args['wandb']:
+                wandb.log({'train_loss': loss}, step=step)
+
+        if args['shuffle_steps'] > 0 and step % args['shuffle_steps'] == 0:
+            train_batches.shuffle()
+
+        if step % args['eval_steps'] == 0:
+            dev_score = eval_model(args, trainer, dev_batches, vocab, mwt_dict)
+            if args['wandb']:
+                wandb.log({'dev_score': dev_score}, step=step)
+            reports = ['Dev score: {:6.3f}'.format(dev_score * 100)]
+            if step >= args['anneal_after'] and dev_score < prev_dev_score:
+                reports += ['lr: {:.6f} -> {:.6f}'.format(lr, lr * args['anneal'])]
+                lr *= args['anneal']
+                trainer.change_lr(lr)
+
+            prev_dev_score = dev_score
+
+            if dev_score > best_dev_score:
+                reports += ['New best dev score!']
+                best_dev_score = dev_score
+                best_dev_step = step
+                trainer.save(args['save_name'])
+            elif best_dev_step > 0 and step - best_dev_step > args['max_steps_before_stop']:
+                reports += ['Stopping training after {} steps with no improvement'.format(step - best_dev_step)]
+                logger.info('\t'.join(reports))
+                break
+
+            logger.info('\t'.join(reports))
+
+    if args['wandb']:
+        wandb.finish()
+
+    if best_dev_step > -1:
+        logger.info('Best dev score={} at step {}'.format(best_dev_score, best_dev_step))
+    else:
+        logger.info('Dev set never evaluated. Saving final model')
+        trainer.save(args['save_name'])
+
+def evaluate(args):
+    mwt_dict = load_mwt_dict(args['mwt_json_file'])
+    trainer = Trainer(model_file=args['load_name'] or args['save_name'], device=args['device'])
+    loaded_args, vocab = trainer.args, trainer.vocab
+
+    for k in loaded_args:
+        if not k.endswith('_file') and k not in ['device', 'mode', 'save_dir', 'load_name', 'save_name']:
+            args[k] = loaded_args[k]
+
+    eval_input_files = {
+        'txt': args['txt_file'],
+        'label': args['label_file']
+    }
+
+
+    batches = TokenizationDataset(args, input_files=eval_input_files, vocab=vocab, evaluation=True, dictionary=trainer.dictionary)
+
+    oov_count, N, _, _ = output_predictions(args['conll_file'], trainer, batches, vocab, mwt_dict, args['max_seqlen'])
+
+    logger.info("OOV rate: {:6.3f}% ({:6d}/{:6d})".format(oov_count / N * 100, oov_count, N))
+
+
+if __name__ == '__main__':
+    main()
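Since main(args=None) above forwards an optional list to parser.parse_args, the tokenizer can also be driven in-process. A minimal sketch of training with the dictionary feature described in the module docstring; all flags appear in build_argparse() above, while the vi_vtb shorthand and data paths are illustrative assumptions.

from stanza.models import tokenizer

# Hedged sketch: every flag below is defined in build_argparse(); only the
# shorthand and the file paths are placeholder assumptions.
tokenizer.main([
    '--mode', 'train',
    '--lang', 'vi',
    '--shorthand', 'vi_vtb',                                  # assumed shorthand
    '--txt_file', 'data/tokenize/vi_vtb.train.txt',           # assumed paths
    '--label_file', 'data/tokenize/vi_vtb.train.toklabels',
    '--dev_txt_file', 'data/tokenize/vi_vtb.dev.txt',
    '--dev_label_file', 'data/tokenize/vi_vtb.dev.toklabels',
    '--use_dictionary',   # enables the lexicon/dictionary features
])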
stanza/stanza/models/wl_coref.py
ADDED
@@ -0,0 +1,226 @@
+"""
+Runs experiments with CorefModel.
+
+Try 'python wl_coref.py -h' for more details.
+
+Code based on
+
+https://github.com/KarelDO/wl-coref/tree/master
+https://arxiv.org/abs/2310.06165
+
+This was a fork of
+
+https://github.com/vdobrovolskii/wl-coref
+https://aclanthology.org/2021.emnlp-main.605/
+
+If you use Stanza's coref module in your work, please cite the following:
+
+@misc{doosterlinck2023cawcoref,
+    title={CAW-coref: Conjunction-Aware Word-level Coreference Resolution},
+    author={Karel D'Oosterlinck and Semere Kiros Bitew and Brandon Papineau and Christopher Potts and Thomas Demeester and Chris Develder},
+    year={2023},
+    eprint={2310.06165},
+    archivePrefix={arXiv},
+    primaryClass={cs.CL},
+    url = "https://arxiv.org/abs/2310.06165",
+}
+
+@inproceedings{dobrovolskii-2021-word,
+    title = "Word-Level Coreference Resolution",
+    author = "Dobrovolskii, Vladimir",
+    booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing",
+    month = nov,
+    year = "2021",
+    address = "Online and Punta Cana, Dominican Republic",
+    publisher = "Association for Computational Linguistics",
+    url = "https://aclanthology.org/2021.emnlp-main.605",
+    pages = "7670--7675"
+}
+"""
+
+import argparse
+from contextlib import contextmanager
+import datetime
+import logging
+import os
+import random
+import sys
+import dataclasses
+import time
+
+
+import numpy as np  # type: ignore
+import torch        # type: ignore
+
+from stanza.models.common.utils import set_random_seed
+from stanza.models.coref.model import CorefModel
+
+
+logger = logging.getLogger('stanza')
+
+@contextmanager
+def output_running_time():
+    """ Prints the time elapsed in the context """
+    start = int(time.time())
+    try:
+        yield
+    finally:
+        end = int(time.time())
+        delta = datetime.timedelta(seconds=end - start)
+        logger.info(f"Total running time: {delta}")
+
+
+def deterministic() -> None:
+    torch.backends.cudnn.deterministic = True  # type: ignore
+    torch.backends.cudnn.benchmark = False     # type: ignore
+
+
+if __name__ == "__main__":
+    argparser = argparse.ArgumentParser()
+    argparser.add_argument("mode", choices=("train", "eval"))
+    argparser.add_argument("experiment")
+    argparser.add_argument("--config-file", default="config.toml")
+    argparser.add_argument("--data-split", choices=("train", "dev", "test"),
+                           default="test",
+                           help="Data split to be used for evaluation."
+                                " Defaults to 'test'."
+                                " Ignored in 'train' mode.")
+    argparser.add_argument("--batch-size", type=int,
+                           help="Adjust to override the config value of anaphoricity "
+                                "batch size if you are experiencing out-of-memory "
+                                "issues")
+    argparser.add_argument("--disable_singletons", action="store_true",
+                           help="don't predict singletons")
+    argparser.add_argument("--full_pairwise", action="store_true",
+                           help="use speaker and document embeddings")
+    argparser.add_argument("--hidden_size", type=int,
+                           help="Adjust the anaphoricity scorer hidden size")
+    argparser.add_argument("--rough_k", type=int,
+                           help="Adjust the number of dummies to keep")
+    argparser.add_argument("--n_hidden_layers", type=int,
+                           help="Adjust the anaphoricity scorer hidden layers")
+    argparser.add_argument("--dummy_mix", type=float,
+                           help="Adjust the dummy mix")
+    argparser.add_argument("--bert_finetune_begin_epoch", type=float,
+                           help="Adjust the bert finetune begin epoch")
+    argparser.add_argument("--warm_start", action="store_true",
+                           help="If set, the training will resume from the"
+                                " last checkpoint saved if any. Ignored in"
+                                " evaluation modes."
+                                " Incompatible with '--weights'.")
+    argparser.add_argument("--weights",
+                           help="Path to file with weights to load."
+                                " If not supplied, in 'eval' mode the latest"
+                                " weights of the experiment will be loaded;"
+                                " in 'train' mode no weights will be loaded.")
+    argparser.add_argument("--word-level", action="store_true",
+                           help="If set, output word-level conll-formatted"
+                                " files in evaluation modes. Ignored in"
+                                " 'train' mode.")
+    argparser.add_argument("--learning_rate", default=None, type=float,
+                           help="If set, update the learning rate for the model")
+    argparser.add_argument("--bert_learning_rate", default=None, type=float,
+                           help="If set, update the learning rate for the transformer")
+    argparser.add_argument("--save_dir", default=None,
+                           help="If set, update the save directory for writing models")
+    argparser.add_argument("--save_name", default=None,
+                           help="If set, update the save name for writing models (otherwise, section name)")
+    argparser.add_argument("--score_lang", default=None,
+                           help="only score a particular language for eval")
+    argparser.add_argument("--log_norms", action="store_true", default=None,
+                           help="If set, log all of the trainable norms each epoch. Very noisy!")
+    argparser.add_argument("--seed", type=int, default=2020,
+                           help="Random seed to set")
+
+    argparser.add_argument("--train_data", default=None, help="File to use for train data")
+    argparser.add_argument("--dev_data", default=None, help="File to use for dev data")
+    argparser.add_argument("--test_data", default=None, help="File to use for test data")
+
+    argparser.add_argument('--wandb', action='store_true', help='Start a wandb session and write the results of training. Only applies to training. Use --wandb_name instead to specify a name', default=False)
+    argparser.add_argument('--wandb_name', default=None, help='Name of a wandb session to start when training. Will default to the dataset short name')
+
+    args = argparser.parse_args()
+
+    if args.warm_start and args.weights is not None:
+        raise ValueError("The following options are incompatible: '--warm_start' and '--weights'")
+
+    set_random_seed(args.seed)
+    deterministic()
+    config = CorefModel._load_config(args.config_file, args.experiment)
+    if args.batch_size:
+        config.a_scoring_batch_size = args.batch_size
+    if args.hidden_size:
+        config.hidden_size = args.hidden_size
+    if args.n_hidden_layers:
+        config.n_hidden_layers = args.n_hidden_layers
+    if args.learning_rate is not None:
+        config.learning_rate = args.learning_rate
+    if args.bert_learning_rate is not None:
+        config.bert_learning_rate = args.bert_learning_rate
+    if args.bert_finetune_begin_epoch is not None:
+        config.bert_finetune_begin_epoch = args.bert_finetune_begin_epoch
+    if args.dummy_mix is not None:
+        config.dummy_mix = args.dummy_mix
+
+    if args.save_dir is not None:
+        config.save_dir = args.save_dir
+    if args.save_name:
+        config.save_name = args.save_name
+    else:
+        config.save_name = args.experiment
+
+    if args.rough_k is not None:
+        config.rough_k = args.rough_k
+    if args.log_norms is not None:
+        config.log_norms = args.log_norms
+    if args.full_pairwise:
+        config.full_pairwise = args.full_pairwise
+    if args.disable_singletons:
+        config.singletons = False
+    if args.train_data:
+        config.train_data = args.train_data
+    if args.dev_data:
+        config.dev_data = args.dev_data
+    if args.test_data:
+        config.test_data = args.test_data
+
+    # if wandb, generate wandb configuration
+    if args.mode == "train":
+        if args.wandb:
+            import wandb
+            wandb_name = args.wandb_name if args.wandb_name else f"wl_coref_{args.experiment}"
+            wandb.init(name=wandb_name, config=dataclasses.asdict(config), project="stanza")
+            wandb.run.define_metric('train_c_loss', summary='min')
+            wandb.run.define_metric('train_s_loss', summary='min')
+            wandb.run.define_metric('dev_score', summary='max')
+
+        model = CorefModel(config=config)
+        if args.weights is not None or args.warm_start:
+            model.load_weights(path=args.weights, map_location="cpu",
+                               noexception=args.warm_start)
+        with output_running_time():
+            model.train(args.wandb)
+    else:
+        config_update = {
+            'log_norms': args.log_norms if args.log_norms is not None else False
+        }
+        if args.test_data:
+            config_update['test_data'] = args.test_data
+
+        if args.weights is None and config.save_name is not None:
+            args.weights = config.save_name
+            if not os.path.exists(args.weights) and os.path.exists(args.weights + ".pt"):
+                args.weights = args.weights + ".pt"
+            elif not os.path.exists(args.weights) and config.save_dir and os.path.exists(os.path.join(config.save_dir, args.weights)):
+                args.weights = os.path.join(config.save_dir, args.weights)
+            elif not os.path.exists(args.weights) and config.save_dir and os.path.exists(os.path.join(config.save_dir, args.weights + ".pt")):
+                args.weights = os.path.join(config.save_dir, args.weights + ".pt")
+        model = CorefModel.load_model(path=args.weights, map_location="cpu",
+                                      ignore={"bert_optimizer", "general_optimizer",
+                                              "bert_scheduler", "general_scheduler"},
+                                      config_update=config_update)
+        results = model.evaluate(data_split=args.data_split,
+                                 word_level_conll=args.word_level,
+                                 eval_lang=args.score_lang)
+        # logger.info(("mean loss", "))
+        print("\t".join([str(round(i, 3)) for i in results]))
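Because the script above only runs under __main__ (there is no importable main()), it is invoked as a module. A hedged sketch of a train/eval round trip; "roberta" stands in for a section name in config.toml and is an assumption, as are the data files configured in that section.

import subprocess

# Train an experiment named after a config.toml section, then score it on dev.
subprocess.run(["python", "-m", "stanza.models.wl_coref",
                "train", "roberta", "--config-file", "config.toml"], check=True)
subprocess.run(["python", "-m", "stanza.models.wl_coref",
                "eval", "roberta", "--data-split", "dev"], check=True)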
stanza/stanza/pipeline/__init__.py
ADDED
File without changes
stanza/stanza/pipeline/constituency_processor.py
ADDED
@@ -0,0 +1,81 @@
+"""
+Processor that attaches a constituency tree to a sentence
+"""
+
+from stanza.models.constituency.trainer import Trainer
+
+from stanza.models.common import doc
+from stanza.models.common.utils import sort_with_indices, unsort
+from stanza.utils.get_tqdm import get_tqdm
+from stanza.pipeline._constants import *
+from stanza.pipeline.processor import UDProcessor, register_processor
+
+tqdm = get_tqdm()
+
+@register_processor(CONSTITUENCY)
+class ConstituencyProcessor(UDProcessor):
+    # set of processor requirements this processor fulfills
+    PROVIDES_DEFAULT = set([CONSTITUENCY])
+    # set of processor requirements for this processor
+    REQUIRES_DEFAULT = set([TOKENIZE, POS])
+
+    # default batch size, measured in sentences
+    DEFAULT_BATCH_SIZE = 50
+
+    def _set_up_requires(self):
+        self._pretagged = self._config.get('pretagged')
+        if self._pretagged:
+            self._requires = set()
+        else:
+            self._requires = self.__class__.REQUIRES_DEFAULT
+
+    def _set_up_model(self, config, pipeline, device):
+        # set up model
+        # pretrain and charlm paths are args from the config
+        # bert (if used) will be chosen from the model save file
+        args = {
+            "wordvec_pretrain_file": config.get('pretrain_path', None),
+            "charlm_forward_file": config.get('forward_charlm_path', None),
+            "charlm_backward_file": config.get('backward_charlm_path', None),
+            "device": device,
+        }
+        trainer = Trainer.load(filename=config['model_path'],
+                               args=args,
+                               foundation_cache=pipeline.foundation_cache)
+        self._trainer = trainer
+        self._model = trainer.model
+        self._model.eval()
+        # batch size counted as sentences
+        self._batch_size = int(config.get('batch_size', ConstituencyProcessor.DEFAULT_BATCH_SIZE))
+        self._tqdm = 'tqdm' in config and config['tqdm']
+
+    def _set_up_final_config(self, config):
+        loaded_args = self._model.args
+        loaded_args = {k: v for k, v in loaded_args.items() if not UDProcessor.filter_out_option(k)}
+        loaded_args.update(config)
+        self._config = loaded_args
+
+    def process(self, document):
+        sentences = document.sentences
+
+        if self._model.uses_xpos():
+            words = [[(w.text, w.xpos) for w in s.words] for s in sentences]
+        else:
+            words = [[(w.text, w.upos) for w in s.words] for s in sentences]
+        words, original_indices = sort_with_indices(words, key=len, reverse=True)
+        if self._tqdm:
+            words = tqdm(words)
+
+        trees = self._model.parse_tagged_words(words, self._batch_size)
+        trees = unsort(trees, original_indices)
+        document.set(CONSTITUENCY, trees, to_sentence=True)
+        return document
+
+    def get_constituents(self):
+        """
+        Return a set of the constituents known by this model
+
+        For a pipeline, this can be queried with
+          pipeline.processors["constituency"].get_constituents()
+        """
+        return set(self._model.constituents)
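The get_constituents() docstring above shows how the processor is reached from a pipeline. A short sketch of both the parse output and the label query (models download on first use; the printed label set depends on the loaded package):

import stanza

nlp = stanza.Pipeline('en', processors='tokenize,pos,constituency')
doc = nlp('This is a test.')
print(doc.sentences[0].constituency)   # tree attached by process() above
print(sorted(nlp.processors['constituency'].get_constituents()))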
stanza/stanza/pipeline/core.py
ADDED
@@ -0,0 +1,509 @@
+"""
+Pipeline that runs tokenize,mwt,pos,lemma,depparse
+"""
+
+import argparse
+import collections
+from enum import Enum
+import io
+import itertools
+import sys
+import logging
+import json
+import os
+
+from stanza.pipeline._constants import *
+from stanza.models.common.constant import langcode_to_lang
+from stanza.models.common.doc import Document
+from stanza.models.common.foundation_cache import FoundationCache
+from stanza.models.common.utils import default_device
+from stanza.pipeline.processor import Processor, ProcessorRequirementsException
+from stanza.pipeline.registry import NAME_TO_PROCESSOR_CLASS, PIPELINE_NAMES, PROCESSOR_VARIANTS
+from stanza.pipeline.langid_processor import LangIDProcessor
+from stanza.pipeline.tokenize_processor import TokenizeProcessor
+from stanza.pipeline.mwt_processor import MWTProcessor
+from stanza.pipeline.pos_processor import POSProcessor
+from stanza.pipeline.lemma_processor import LemmaProcessor
+from stanza.pipeline.constituency_processor import ConstituencyProcessor
+from stanza.pipeline.coref_processor import CorefProcessor
+from stanza.pipeline.depparse_processor import DepparseProcessor
+from stanza.pipeline.sentiment_processor import SentimentProcessor
+from stanza.pipeline.ner_processor import NERProcessor
+from stanza.resources.common import DEFAULT_MODEL_DIR, DEFAULT_RESOURCES_URL, DEFAULT_RESOURCES_VERSION, ModelSpecification, add_dependencies, add_mwt, download_models, download_resources_json, flatten_processor_list, load_resources_json, maintain_processor_list, process_pipeline_parameters, set_logging_level, sort_processors
+from stanza.resources.default_packages import PACKAGES
+from stanza.utils.conll import CoNLL, CoNLLError
+from stanza.utils.helper_func import make_table
+
+logger = logging.getLogger('stanza')
+
+class DownloadMethod(Enum):
+    """
+    Determines a couple options on how to download resources for the pipeline.
+
+    NONE will not download anything, including HF transformers, probably resulting in failure if the resources aren't already in place.
+    REUSE_RESOURCES will reuse the existing resources.json and models, but will download any missing models.
+    DOWNLOAD_RESOURCES will download a new resources.json and will overwrite any out of date models.
+    """
+    NONE = 1
+    REUSE_RESOURCES = 2
+    DOWNLOAD_RESOURCES = 3
+
+class LanguageNotDownloadedError(FileNotFoundError):
+    def __init__(self, lang, lang_dir, model_path):
+        super().__init__(f'Could not find the model file {model_path}. The expected model directory {lang_dir} is missing. Perhaps you need to run stanza.download("{lang}")')
+        self.lang = lang
+        self.lang_dir = lang_dir
+        self.model_path = model_path
+
+class UnsupportedProcessorError(FileNotFoundError):
+    def __init__(self, processor, lang):
+        super().__init__(f'Processor {processor} is not known for language {lang}. If you have created your own model, please specify the {processor}_model_path parameter when creating the pipeline.')
+        self.processor = processor
+        self.lang = lang
+
+class IllegalPackageError(ValueError):
+    def __init__(self, msg):
+        super().__init__(msg)
+
+class PipelineRequirementsException(Exception):
+    """
+    Exception indicating one or more requirements failures while attempting to build a pipeline.
+    Contains a ProcessorRequirementsException list.
+    """
+
+    def __init__(self, processor_req_fails):
+        self._processor_req_fails = processor_req_fails
+        self.build_message()
+
+    @property
+    def processor_req_fails(self):
+        return self._processor_req_fails
+
+    def build_message(self):
+        err_msg = io.StringIO()
+        print(*[req_fail.message for req_fail in self.processor_req_fails], sep='\n', file=err_msg)
+        self.message = '\n\n' + err_msg.getvalue()
+
+    def __str__(self):
+        return self.message
+
+def build_default_config_option(model_specs):
+    """
+    Build a config option for a couple situations: lemma=identity, processor is a variant
+
+    Returns the option name and value
+
+    Refactored from build_default_config so that we can reuse it when
+    downloading all models
+    """
+    # handle case when processor variants are used
+    if any(model_spec.package in PROCESSOR_VARIANTS[model_spec.processor] for model_spec in model_specs):
+        if len(model_specs) > 1:
+            raise IllegalPackageError("Variant processor selected for {}, but multiple packages requested".format(model_specs[0].processor))
+        return f"{model_specs[0].processor}_with_{model_specs[0].package}", True
+    # handle case when identity is specified as lemmatizer
+    elif any(model_spec.processor == LEMMA and model_spec.package == 'identity' for model_spec in model_specs):
+        if len(model_specs) > 1:
+            raise IllegalPackageError("Identity processor selected for lemma, but multiple packages requested")
+        return f"{LEMMA}_use_identity", True
+    return None
+
+def filter_variants(model_specs):
+    return [(key, value) for (key, value) in model_specs if build_default_config_option(value) is None]
+
+# given a language and models path, build a default configuration
+def build_default_config(resources, lang, model_dir, load_list):
+    default_config = {}
+    for processor, model_specs in load_list:
+        option = build_default_config_option(model_specs)
+        if option is not None:
+            # if an option is set for the model_specs, keep that option and ignore
+            # the rest of the model spec
+            default_config[option[0]] = option[1]
+            continue
+
+        model_paths = [os.path.join(model_dir, lang, processor, model_spec.package + '.pt') for model_spec in model_specs]
+        dependencies = [model_spec.dependencies for model_spec in model_specs]
+
+        # Special case for NER: load multiple models at once
+        # The pattern will be:
+        #   a list of ner_model_path
+        #   a list of ner_dependencies
+        # where each item in ner_dependencies is a map
+        # the map may contain forward_charlm_path, backward_charlm_path, or any other deps
+        # The user will be able to override the defaults using a semicolon separated string
+        # TODO: at least use the same config pattern for all other models
+        if processor == NER:
+            default_config[f"{processor}_model_path"] = model_paths
+            dependency_paths = []
+            for dependency_block in dependencies:
+                if not dependency_block:
+                    dependency_paths.append({})
+                    continue
+                dependency_paths.append({})
+                for dependency in dependency_block:
+                    dep_processor, dep_model = dependency
+                    dependency_paths[-1][f"{dep_processor}_path"] = os.path.join(model_dir, lang, dep_processor, dep_model + '.pt')
+            default_config[f"{processor}_dependencies"] = dependency_paths
+            continue
+
+        if len(model_specs) > 1:
+            raise IllegalPackageError("Specified multiple packages for {}, which currently only handles one package".format(processor))
+
+        default_config[f"{processor}_model_path"] = model_paths[0]
+        if not dependencies[0]: continue
+        for dependency in dependencies[0]:
+            dep_processor, dep_model = dependency
+            default_config[f"{processor}_{dep_processor}_path"] = os.path.join(
+                model_dir, lang, dep_processor, dep_model + '.pt'
+            )
+
+    return default_config
+
+def normalize_download_method(download_method):
+    """
+    Turn None -> DownloadMethod.NONE, strings to the corresponding enum
+    """
+    if download_method is None:
+        return DownloadMethod.NONE
+    elif isinstance(download_method, str):
+        try:
+            return DownloadMethod[download_method.upper()]
+        except KeyError as e:
+            raise ValueError("Unknown download method %s" % download_method) from e
+    return download_method
+
+
+class Pipeline:
+
+    def __init__(self,
+                 lang='en',
+                 dir=DEFAULT_MODEL_DIR,
+                 package='default',
+                 processors={},
+                 logging_level=None,
+                 verbose=None,
+                 use_gpu=None,
+                 model_dir=None,
+                 download_method=DownloadMethod.DOWNLOAD_RESOURCES,
+                 resources_url=DEFAULT_RESOURCES_URL,
+                 resources_branch=None,
+                 resources_version=DEFAULT_RESOURCES_VERSION,
+                 resources_filepath=None,
+                 proxies=None,
+                 foundation_cache=None,
+                 device=None,
+                 allow_unknown_language=False,
+                 **kwargs):
+        self.lang, self.dir, self.kwargs = lang, dir, kwargs
+        if model_dir is not None and dir == DEFAULT_MODEL_DIR:
+            self.dir = model_dir
+
+        # set global logging level
+        set_logging_level(logging_level, verbose)
+
+        self.download_method = normalize_download_method(download_method)
+        if (self.download_method is DownloadMethod.DOWNLOAD_RESOURCES or
+            (self.download_method is DownloadMethod.REUSE_RESOURCES and not os.path.exists(os.path.join(self.dir, "resources.json")))):
+            logger.info("Checking for updates to resources.json in case models have been updated. Note: this behavior can be turned off with download_method=None or download_method=DownloadMethod.REUSE_RESOURCES")
+            download_resources_json(self.dir,
+                                    resources_url=resources_url,
+                                    resources_branch=resources_branch,
+                                    resources_version=resources_version,
+                                    resources_filepath=resources_filepath,
+                                    proxies=proxies)
+
+        # processors can use this to save on the effort of loading
+        # large sub-models, such as pretrained embeddings, bert, etc
+        if foundation_cache is None:
+            self.foundation_cache = FoundationCache(local_files_only=(self.download_method is DownloadMethod.NONE))
+        else:
+            self.foundation_cache = FoundationCache(foundation_cache, local_files_only=(self.download_method is DownloadMethod.NONE))
+
+        # process different pipeline parameters
+        lang, self.dir, package, processors = process_pipeline_parameters(lang, self.dir, package, processors)
+
+        # Load resources.json to obtain latest packages.
+        logger.debug('Loading resource file...')
+        resources = load_resources_json(self.dir, resources_filepath)
+        if lang in resources:
+            if 'alias' in resources[lang]:
+                logger.info(f'"{lang}" is an alias for "{resources[lang]["alias"]}"')
+                lang = resources[lang]['alias']
+            lang_name = resources[lang]['lang_name'] if 'lang_name' in resources[lang] else ''
+        elif allow_unknown_language:
+            logger.warning("Trying to create pipeline for unsupported language: %s", lang)
+            lang_name = langcode_to_lang(lang)
+        else:
+            logger.warning("Unsupported language: %s If trying to add a new language, consider using allow_unknown_language=True", lang)
+            lang_name = langcode_to_lang(lang)
+
+        # Maintain load list
+        if lang in resources:
+            self.load_list = maintain_processor_list(resources, lang, package, processors, maybe_add_mwt=(not kwargs.get("tokenize_pretokenized")))
+            self.load_list = add_dependencies(resources, lang, self.load_list)
+            if self.download_method is not DownloadMethod.NONE:
+                # skip processors which aren't downloaded from our collection
+                download_list = [x for x in self.load_list if x[0] in resources.get(lang, {})]
+                # skip variants
+                download_list = filter_variants(download_list)
+                # gather up the model list...
+                download_list = flatten_processor_list(download_list)
+                # download_models will skip models we already have
+                download_models(download_list,
+                                resources=resources,
+                                lang=lang,
+                                model_dir=self.dir,
+                                resources_version=resources_version,
+                                proxies=proxies,
+                                log_info=False)
+        elif allow_unknown_language:
+            self.load_list = [(proc, [ModelSpecification(processor=proc, package='default', dependencies=None)])
+                              for proc in list(processors.keys())]
+        else:
+            self.load_list = []
+        self.load_list = self.update_kwargs(kwargs, self.load_list)
+        if len(self.load_list) == 0:
+            if lang not in resources or PACKAGES not in resources[lang]:
+                raise ValueError(f'No processors to load for language {lang}. Language {lang} is currently unsupported')
+            else:
+                raise ValueError('No processors to load for language {}. Please check if your language or package is correctly set.'.format(lang))
+        load_table = make_table(['Processor', 'Package'], [(row[0], ";".join(model_spec.package for model_spec in row[1])) for row in self.load_list])
+        logger.info(f'Loading these models for language: {lang} ({lang_name}):\n{load_table}')
+
+        self.config = build_default_config(resources, lang, self.dir, self.load_list)
+        self.config.update(kwargs)
+
+        # Load processors
+        self.processors = {}
+
+        # configs that are the same for all processors
+        pipeline_level_configs = {'lang': lang, 'mode': 'predict'}
+
+        if device is None:
+            if use_gpu is None or use_gpu == True:
+                device = default_device()
+            else:
+                device = 'cpu'
+            if use_gpu == True and device == 'cpu':
+                logger.warning("GPU requested, but is not available!")
+        self.device = device
+        logger.info("Using device: {}".format(self.device))
+
+        # set up processors
+        pipeline_reqs_exceptions = []
+        for item in self.load_list:
+            processor_name, _ = item
+            logger.info('Loading: ' + processor_name)
+            curr_processor_config = self.filter_config(processor_name, self.config)
+            curr_processor_config.update(pipeline_level_configs)
+            # TODO: this is obviously a hack
+            # a better solution overall would be to make a pretagged version of the pos annotator
+            # and then subsequent modules can use those tags without knowing where those tags came from
+            if "pretagged" in self.config and "pretagged" not in curr_processor_config:
+                curr_processor_config["pretagged"] = self.config["pretagged"]
+            logger.debug('With settings: ')
+            logger.debug(curr_processor_config)
+            try:
+                # try to build processor, throw an exception if there is a requirements issue
+                self.processors[processor_name] = NAME_TO_PROCESSOR_CLASS[processor_name](config=curr_processor_config,
+                                                                                          pipeline=self,
+                                                                                          device=self.device)
+            except ProcessorRequirementsException as e:
+                # if there was a requirements issue, add it to list which will be printed at end
+                pipeline_reqs_exceptions.append(e)
+                # add the broken processor to the loaded processors for the sake of analyzing the validity of the
+                # entire proposed pipeline, but at this point the pipeline will not be built successfully
+                self.processors[processor_name] = e.err_processor
+            except FileNotFoundError as e:
+                # For a FileNotFoundError, we try to guess if there's
+                # a missing model directory or file. If so, we
+                # suggest the user try to download the models
|
| 321 |
+
if 'model_path' in curr_processor_config:
|
| 322 |
+
model_path = curr_processor_config['model_path']
|
| 323 |
+
if e.filename == model_path or (isinstance(model_path, (tuple, list)) and e.filename in model_path):
|
| 324 |
+
model_path = e.filename
|
| 325 |
+
model_dir, model_name = os.path.split(model_path)
|
| 326 |
+
lang_dir = os.path.dirname(model_dir)
|
| 327 |
+
if lang_dir and not os.path.exists(lang_dir):
|
| 328 |
+
# model files for this language can't be found in the expected directory
|
| 329 |
+
raise LanguageNotDownloadedError(lang, lang_dir, model_path) from e
|
| 330 |
+
if processor_name not in resources[lang]:
|
| 331 |
+
# user asked for a model which doesn't exist for this language?
|
| 332 |
+
raise UnsupportedProcessorError(processor_name, lang) from e
|
| 333 |
+
if not os.path.exists(model_path):
|
| 334 |
+
model_name, _ = os.path.splitext(model_name)
|
| 335 |
+
# TODO: before recommending this, check that such a thing exists in resources.json.
|
| 336 |
+
# currently that case is handled by ignoring the model, anyway
|
| 337 |
+
raise FileNotFoundError('Could not find model file %s, although there are other models downloaded for language %s. Perhaps you need to download a specific model. Try: stanza.download(lang="%s",package=None,processors={"%s":"%s"})' % (model_path, lang, lang, processor_name, model_name)) from e
|
| 338 |
+
|
| 339 |
+
# if we couldn't find a more suitable description of the
|
| 340 |
+
# FileNotFoundError, just raise the old error
|
| 341 |
+
raise
|
| 342 |
+
|
| 343 |
+
# if there are any processor exceptions, throw an exception to indicate pipeline build failure
|
| 344 |
+
if pipeline_reqs_exceptions:
|
| 345 |
+
logger.info('\n')
|
| 346 |
+
raise PipelineRequirementsException(pipeline_reqs_exceptions)
|
| 347 |
+
|
| 348 |
+
logger.info("Done loading processors!")
|
| 349 |
+
|
| 350 |
+
@staticmethod
|
| 351 |
+
def update_kwargs(kwargs, processor_list):
|
| 352 |
+
processor_dict = {processor: [{'package': model_spec.package, 'dependencies': model_spec.dependencies} for model_spec in model_specs]
|
| 353 |
+
for (processor, model_specs) in processor_list}
|
| 354 |
+
for key, value in kwargs.items():
|
| 355 |
+
pieces = key.split('_', 1)
|
| 356 |
+
if len(pieces) == 1:
|
| 357 |
+
continue
|
| 358 |
+
k, v = pieces
|
| 359 |
+
if v == 'model_path':
|
| 360 |
+
package = value if len(value) < 25 else value[:10]+ '...' + value[-10:]
|
| 361 |
+
original_spec = processor_dict.get(k, [])
|
| 362 |
+
if len(original_spec) > 0:
|
| 363 |
+
dependencies = original_spec[0].get('dependencies')
|
| 364 |
+
else:
|
| 365 |
+
dependencies = None
|
| 366 |
+
processor_dict[k] = [{'package': package, 'dependencies': dependencies}]
|
| 367 |
+
processor_list = [(processor, [ModelSpecification(processor=processor, package=model_spec['package'], dependencies=model_spec['dependencies']) for model_spec in processor_dict[processor]]) for processor in processor_dict]
|
| 368 |
+
processor_list = sort_processors(processor_list)
|
| 369 |
+
return processor_list
|
| 370 |
+
|
| 371 |
+
@staticmethod
|
| 372 |
+
def filter_config(prefix, config_dict):
|
| 373 |
+
filtered_dict = {}
|
| 374 |
+
for key in config_dict.keys():
|
| 375 |
+
pieces = key.split('_', 1) # split tokenize_pretokenize to tokenize+pretokenize
|
| 376 |
+
if len(pieces) == 1:
|
| 377 |
+
continue
|
| 378 |
+
k, v = pieces
|
| 379 |
+
if k == prefix:
|
| 380 |
+
filtered_dict[v] = config_dict[key]
|
| 381 |
+
return filtered_dict
|
| 382 |
+
|
| 383 |
+
@property
|
| 384 |
+
def loaded_processors(self):
|
| 385 |
+
"""
|
| 386 |
+
Return all currently loaded processors in execution order.
|
| 387 |
+
:return: list of Processor instances
|
| 388 |
+
"""
|
| 389 |
+
return [self.processors[processor_name] for processor_name in PIPELINE_NAMES if self.processors.get(processor_name)]
|
| 390 |
+
|
| 391 |
+
def process(self, doc, processors=None):
|
| 392 |
+
"""
|
| 393 |
+
Run the pipeline
|
| 394 |
+
|
| 395 |
+
processors: allow for a list of processors used by this pipeline action
|
| 396 |
+
can be list, tuple, set, or comma separated string
|
| 397 |
+
if None, use all the processors this pipeline knows about
|
| 398 |
+
MWT is added if necessary
|
| 399 |
+
otherwise, no care is taken to make sure prerequisites are followed...
|
| 400 |
+
some of the annotators, such as depparse, will check, but others
|
| 401 |
+
will fail in some unusual manner or just have really bad results
|
| 402 |
+
"""
|
| 403 |
+
assert any([isinstance(doc, str), isinstance(doc, list),
|
| 404 |
+
isinstance(doc, Document)]), 'input should be either str, list or Document'
|
| 405 |
+
|
| 406 |
+
# empty bulk process
|
| 407 |
+
if isinstance(doc, list) and len(doc) == 0:
|
| 408 |
+
return []
|
| 409 |
+
|
| 410 |
+
# determine whether we are in bulk processing mode for multiple documents
|
| 411 |
+
bulk=(isinstance(doc, list) and len(doc) > 0 and isinstance(doc[0], Document))
|
| 412 |
+
|
| 413 |
+
# various options to limit the processors used by this pipeline action
|
| 414 |
+
if processors is None:
|
| 415 |
+
processors = PIPELINE_NAMES
|
| 416 |
+
elif not isinstance(processors, (str, list, tuple, set)):
|
| 417 |
+
raise ValueError("Cannot process {} as a list of processors to run".format(type(processors)))
|
| 418 |
+
else:
|
| 419 |
+
if isinstance(processors, str):
|
| 420 |
+
processors = {x for x in processors.split(",")}
|
| 421 |
+
else:
|
| 422 |
+
processors = set(processors)
|
| 423 |
+
if TOKENIZE in processors and MWT in self.processors and MWT not in processors:
|
| 424 |
+
logger.debug("Requested processors for pipeline did not have mwt, but pipeline needs mwt, so mwt is added")
|
| 425 |
+
processors.add(MWT)
|
| 426 |
+
processors = [x for x in PIPELINE_NAMES if x in processors]
|
| 427 |
+
|
| 428 |
+
for processor_name in processors:
|
| 429 |
+
if self.processors.get(processor_name):
|
| 430 |
+
process = self.processors[processor_name].bulk_process if bulk else self.processors[processor_name].process
|
| 431 |
+
doc = process(doc)
|
| 432 |
+
return doc
|
| 433 |
+
|
| 434 |
+
def bulk_process(self, docs, *args, **kwargs):
|
| 435 |
+
"""
|
| 436 |
+
Run the pipeline in bulk processing mode
|
| 437 |
+
|
| 438 |
+
Expects a list of str or a list of Docs
|
| 439 |
+
"""
|
| 440 |
+
# Wrap each text as a Document unless it is already such a document
|
| 441 |
+
docs = [doc if isinstance(doc, Document) else Document([], text=doc) for doc in docs]
|
| 442 |
+
return self.process(docs, *args, **kwargs)
|
| 443 |
+
|
| 444 |
+
def stream(self, docs, batch_size=50, *args, **kwargs):
|
| 445 |
+
"""
|
| 446 |
+
Go through an iterator of documents in batches, yield processed documents
|
| 447 |
+
|
| 448 |
+
sentence indices will be counted across the entire iterator
|
| 449 |
+
"""
|
| 450 |
+
if not isinstance(docs, collections.abc.Iterator):
|
| 451 |
+
docs = iter(docs)
|
| 452 |
+
def next_batch():
|
| 453 |
+
batch = []
|
| 454 |
+
for _ in range(batch_size):
|
| 455 |
+
try:
|
| 456 |
+
next_doc = next(docs)
|
| 457 |
+
batch.append(next_doc)
|
| 458 |
+
except StopIteration:
|
| 459 |
+
return batch
|
| 460 |
+
return batch
|
| 461 |
+
|
| 462 |
+
sentence_start_index = 0
|
| 463 |
+
batch = next_batch()
|
| 464 |
+
while batch:
|
| 465 |
+
batch = self.bulk_process(batch, *args, **kwargs)
|
| 466 |
+
for doc in batch:
|
| 467 |
+
doc.reindex_sentences(sentence_start_index)
|
| 468 |
+
sentence_start_index += len(doc.sentences)
|
| 469 |
+
yield doc
|
| 470 |
+
batch = next_batch()
|
| 471 |
+
|
| 472 |
+
def __str__(self):
|
| 473 |
+
"""
|
| 474 |
+
Assemble the processors in order to make a simple description of the pipeline
|
| 475 |
+
"""
|
| 476 |
+
processors = ["%s=%s" % (x, str(self.processors[x])) for x in PIPELINE_NAMES if x in self.processors]
|
| 477 |
+
return "<Pipeline: %s>" % ", ".join(processors)
|
| 478 |
+
|
| 479 |
+
def __call__(self, doc, processors=None):
|
| 480 |
+
return self.process(doc, processors)
|
| 481 |
+
|
| 482 |
+
def main():
|
| 483 |
+
# TODO: can add a bunch more arguments
|
| 484 |
+
parser = argparse.ArgumentParser()
|
| 485 |
+
parser.add_argument('--lang', type=str, default='en', help='Language of the pipeline to use')
|
| 486 |
+
parser.add_argument('--input_file', type=str, required=True, help='Input file to read')
|
| 487 |
+
parser.add_argument('--processors', type=str, default='tokenize,pos,lemma,depparse', help='Processors to use')
|
| 488 |
+
args, extra_args = parser.parse_known_args()
|
| 489 |
+
|
| 490 |
+
try:
|
| 491 |
+
doc = CoNLL.conll2doc(args.input_file)
|
| 492 |
+
extra_args = {
|
| 493 |
+
"tokenize_pretokenized": True
|
| 494 |
+
}
|
| 495 |
+
except CoNLLError:
|
| 496 |
+
logger.debug("Input file %s does not appear to be a conllu file. Will read it as a text file")
|
| 497 |
+
with open(args.input_file, encoding="utf-8") as fin:
|
| 498 |
+
doc = fin.read()
|
| 499 |
+
extra_args = {}
|
| 500 |
+
|
| 501 |
+
pipe = Pipeline(args.lang, processors=args.processors, **extra_args)
|
| 502 |
+
|
| 503 |
+
doc = pipe(doc)
|
| 504 |
+
|
| 505 |
+
print("{:C}".format(doc))
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
if __name__ == '__main__':
|
| 509 |
+
main()
|
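Usage note (illustrative, not part of this commit): a minimal sketch of driving the Pipeline above through the stanza top-level API, assuming the default English models are available for download.

import stanza

# build a pipeline; models are fetched on first use per the DownloadMethod logic above
nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma,depparse')

# __call__ delegates to process()
doc = nlp("Stanza is a Python NLP library.")
print(doc.sentences[0].words[0].lemma)

# stream() batches an iterator and keeps sentence indices global across documents
for processed in nlp.stream(iter(["First document.", "Second document."]), batch_size=2):
    print(processed.sentences[0].text)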
stanza/stanza/pipeline/coref_processor.py
ADDED
@@ -0,0 +1,154 @@
"""
Processor that attaches coref annotations to a document
"""

from stanza.models.common.utils import misc_to_space_after
from stanza.models.coref.coref_chain import CorefMention, CorefChain

from stanza.pipeline._constants import *
from stanza.pipeline.processor import UDProcessor, register_processor

def extract_text(document, sent_id, start_word, end_word):
    sentence = document.sentences[sent_id]
    tokens = []

    # the coref model indexes the words from 0,
    # whereas the ids we are looking at on the tokens start from 1
    # here we will switch to ID space
    start_word = start_word + 1
    end_word = end_word + 1

    # For each position between start and end word:
    # If a word is part of an MWT, and the entire token
    # is inside the range, we use that Token's text for that span
    # This will let us easily handle words which are split into pieces
    # Otherwise, we only take the text of the word itself
    next_idx = start_word
    while next_idx < end_word:
        word = sentence.words[next_idx-1]
        parent_token = word.parent
        if isinstance(parent_token.id, int) or len(parent_token.id) == 1:
            tokens.append(parent_token)
            next_idx += 1
        elif parent_token.id[0] >= start_word and parent_token.id[1] < end_word:
            tokens.append(parent_token)
            next_idx = parent_token.id[1] + 1
        else:
            tokens.append(word)
            next_idx += 1

    # We use the SpaceAfter or SpacesAfter attribute on each Word or Token
    # we chose in the above loop to separate the text pieces
    text = []
    for token in tokens:
        text.append(token.text)
        text.append(misc_to_space_after(token.misc))
    # the last token space_after will be discarded
    # so that we don't have stray WS at the end of the mention text
    text = text[:-1]
    return "".join(text)


@register_processor(COREF)
class CorefProcessor(UDProcessor):
    # set of processor requirements this processor fulfills
    PROVIDES_DEFAULT = set([COREF])
    # set of processor requirements for this processor
    REQUIRES_DEFAULT = set([TOKENIZE])

    def _set_up_model(self, config, pipeline, device):
        try:
            from stanza.models.coref.model import CorefModel
        except ImportError:
            raise ImportError("Please install the transformers and peft libraries before using coref! Try `pip install -e .[transformers]`.")

        # set up model
        # currently, the model has everything packaged in it
        # (except its config)
        # TODO: separate any pretrains if possible
        # TODO: add device parameter to the load mechanism
        config_update = {'log_norms': False,
                         'device': device}
        model = CorefModel.load_model(path=config['model_path'],
                                      ignore={"bert_optimizer", "general_optimizer",
                                              "bert_scheduler", "general_scheduler"},
                                      config_update=config_update,
                                      foundation_cache=pipeline.foundation_cache)
        if config.get('batch_size', None):
            model.config.a_scoring_batch_size = int(config['batch_size'])
        model.training = False

        self._model = model

    def process(self, document):
        sentences = document.sentences

        cased_words = []
        sent_ids = []
        word_pos = []
        for sent_idx, sentence in enumerate(sentences):
            for word_idx, word in enumerate(sentence.words):
                cased_words.append(word.text)
                sent_ids.append(sent_idx)
                word_pos.append(word_idx)

        coref_input = {
            "document_id": "wb_doc_1",
            "cased_words": cased_words,
            "sent_id": sent_ids
        }
        coref_input = self._model.build_doc(coref_input)
        results = self._model.run(coref_input)
        clusters = []
        for span_cluster in results.span_clusters:
            if len(span_cluster) == 0:
                continue
            span_cluster = sorted(span_cluster)

            for span in span_cluster:
                # check there are no sentence crossings before
                # manipulating the spans, since we will expect it to
                # be this way for multiple usages of the spans
                sent_id = sent_ids[span[0]]
                if sent_ids[span[1]-1] != sent_id:
                    raise ValueError("The coref model predicted a span that crossed two sentences! Please send this example to us on our github")

            # treat the longest span as the representative
            # break ties using the first one
            # IF there is the POS processor, and it adds upos tags
            # to the sentence, ties are broken first by maximum
            # number of UPOS and then earliest in the document
            max_len = 0
            best_span = None
            max_propn = 0
            for span_idx, span in enumerate(span_cluster):
                sent_id = sent_ids[span[0]]
                sentence = sentences[sent_id]
                start_word = word_pos[span[0]]
                # fiddle -1 / +1 so as to avoid problems with coref
                # clusters that end at exactly the end of a document
                end_word = word_pos[span[1]-1] + 1
                # very UD specific test for most number of proper nouns in a mention
                # will do nothing if POS is not active (they will all be None)
                num_propn = sum(word.pos == 'PROPN' for word in sentence.words[start_word:end_word])

                if ((span[1] - span[0] > max_len) or
                    span[1] - span[0] == max_len and num_propn > max_propn):
                    max_len = span[1] - span[0]
                    best_span = span_idx
                    max_propn = num_propn

            mentions = []
            for span in span_cluster:
                sent_id = sent_ids[span[0]]
                start_word = word_pos[span[0]]
                end_word = word_pos[span[1]-1] + 1
                mentions.append(CorefMention(sent_id, start_word, end_word))
            representative = mentions[best_span]
            representative_text = extract_text(document, representative.sentence, representative.start_word, representative.end_word)

            chain = CorefChain(len(clusters), mentions, representative_text, best_span)
            clusters.append(chain)

        document.coref = clusters
        return document
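Usage note (illustrative, not part of this commit): a sketch of reading the chains that process() attaches above. The sentence/start_word/end_word fields on CorefMention are the ones the code itself reads; the mentions and representative_text attributes on CorefChain are assumed from the constructor arguments used in process().

import stanza

# coref needs the transformers/peft extras, as the ImportError message above notes
nlp = stanza.Pipeline('en', processors='tokenize,coref')
doc = nlp("Barack Obama was born in Hawaii. He was elected president in 2008.")

for chain in doc.coref:
    # the representative is the longest mention, ties broken by PROPN count
    print("representative:", chain.representative_text)
    for mention in chain.mentions:
        words = doc.sentences[mention.sentence].words[mention.start_word:mention.end_word]
        print("  mention:", " ".join(w.text for w in words))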
stanza/stanza/pipeline/depparse_processor.py
ADDED
@@ -0,0 +1,78 @@
"""
Processor for performing dependency parsing
"""

import torch

from stanza.models.common import doc
from stanza.models.common.utils import unsort
from stanza.models.common.vocab import VOCAB_PREFIX
from stanza.models.depparse.data import DataLoader
from stanza.models.depparse.trainer import Trainer
from stanza.pipeline._constants import *
from stanza.pipeline.processor import UDProcessor, register_processor

# these imports trigger the "register_variant" decorations
from stanza.pipeline.external.corenlp_converter_depparse import ConverterDepparse

DEFAULT_SEPARATE_BATCH = 150

@register_processor(name=DEPPARSE)
class DepparseProcessor(UDProcessor):

    # set of processor requirements this processor fulfills
    PROVIDES_DEFAULT = set([DEPPARSE])
    # set of processor requirements for this processor
    REQUIRES_DEFAULT = set([TOKENIZE, POS, LEMMA])

    def __init__(self, config, pipeline, device):
        self._pretagged = None
        super().__init__(config, pipeline, device)

    def _set_up_requires(self):
        self._pretagged = self._config.get('pretagged')
        if self._pretagged:
            self._requires = set()
        else:
            self._requires = self.__class__.REQUIRES_DEFAULT

    def _set_up_model(self, config, pipeline, device):
        self._pretrain = pipeline.foundation_cache.load_pretrain(config['pretrain_path']) if 'pretrain_path' in config else None
        args = {'charlm_forward_file': config.get('forward_charlm_path', None),
                'charlm_backward_file': config.get('backward_charlm_path', None)}
        self._trainer = Trainer(args=args, pretrain=self.pretrain, model_file=config['model_path'], device=device, foundation_cache=pipeline.foundation_cache)

    def get_known_relations(self):
        """
        Return a list of relations which this processor can produce
        """
        keys = [k for k in self.vocab['deprel']._unit2id.keys() if k not in VOCAB_PREFIX]
        return keys

    def process(self, document):
        if hasattr(self, '_variant'):
            return self._variant.process(document)

        if any(word.upos is None and word.xpos is None for sentence in document.sentences for word in sentence.words):
            raise ValueError("POS not run before depparse!")
        try:
            batch = DataLoader(document, self.config['batch_size'], self.config, self.pretrain, vocab=self.vocab, evaluation=True,
                               sort_during_eval=self.config.get('sort_during_eval', True),
                               min_length_to_batch_separately=self.config.get('min_length_to_batch_separately', DEFAULT_SEPARATE_BATCH))
            with torch.no_grad():
                preds = []
                for i, b in enumerate(batch):
                    preds += self.trainer.predict(b)
            if batch.data_orig_idx is not None:
                preds = unsort(preds, batch.data_orig_idx)
            batch.doc.set((doc.HEAD, doc.DEPREL), [y for x in preds for y in x])
            # build dependencies based on predictions
            for sentence in batch.doc.sentences:
                sentence.build_dependencies()
            return batch.doc
        except RuntimeError as e:
            if str(e).startswith("CUDA out of memory. Tried to allocate"):
                new_message = str(e) + " ... You may be able to compensate for this by separating long sentences into their own batch with a parameter such as depparse_min_length_to_batch_separately=150 or by limiting the overall batch size with depparse_batch_size=400."
                raise RuntimeError(new_message) from e
            else:
                raise
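Usage note (illustrative, not part of this commit): a minimal sketch of running the depparse processor above, assuming the default English models.

import stanza

# depparse normally requires tokenize, pos and lemma; with depparse_pretagged=True
# the requirements set above is emptied and existing tags are trusted instead
nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma,depparse')
doc = nlp("The quick brown fox jumps over the lazy dog.")
for word in doc.sentences[0].words:
    print(word.id, word.text, word.head, word.deprel)

# for very long sentences, the OOM message above suggests knobs such as
# depparse_min_length_to_batch_separately=150 or depparse_batch_size=400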
stanza/stanza/pipeline/langid_processor.py
ADDED
@@ -0,0 +1,127 @@
"""
Processor for determining language of text.
"""

import emoji
import re
import stanza
import torch

from stanza.models.common.doc import Document
from stanza.models.langid.model import LangIDBiLSTM
from stanza.pipeline._constants import *
from stanza.pipeline.processor import UDProcessor, register_processor


@register_processor(name=LANGID)
class LangIDProcessor(UDProcessor):
    """
    Class for detecting language of text.
    """

    # set of processor requirements this processor fulfills
    PROVIDES_DEFAULT = set([LANGID])

    # set of processor requirements for this processor
    REQUIRES_DEFAULT = set([])

    # default max sequence length
    MAX_SEQ_LENGTH_DEFAULT = 1000

    def _set_up_model(self, config, pipeline, device):
        batch_size = config.get("batch_size", 64)
        self._model = LangIDBiLSTM.load(path=config["model_path"], device=device,
                                        batch_size=batch_size, lang_subset=config.get("lang_subset"))
        self._char_index = self._model.char_to_idx
        self._clean_text = config.get("clean_text")

    def _text_to_tensor(self, docs):
        """
        Map list of strings to batch tensor. Assumes all docs are the same length.
        """

        device = next(self._model.parameters()).device
        all_docs = []
        for doc in docs:
            doc_chars = [self._char_index.get(c, self._char_index["UNK"]) for c in list(doc)]
            all_docs.append(doc_chars)
        return torch.tensor(all_docs, device=device, dtype=torch.long)

    def _id_langs(self, batch_tensor):
        """
        Identify languages for each sequence in a batch tensor
        """
        predictions = self._model.prediction_scores(batch_tensor)
        prediction_labels = [self._model.idx_to_tag[prediction] for prediction in predictions]

        return prediction_labels

    # regexes for cleaning text
    http_regex = re.compile(r"https?:\/\/t\.co/[a-zA-Z0-9]+")
    handle_regex = re.compile("@[a-zA-Z0-9_]+")
    hashtag_regex = re.compile("#[a-zA-Z]+")
    punctuation_regex = re.compile("[!.]+")
    all_regexes = [http_regex, handle_regex, hashtag_regex, punctuation_regex]

    @staticmethod
    def clean_text(text):
        """
        Process text to improve language id performance. The main emphasis is on tweets:
        this method removes shortened urls, hashtags, handles, punctuation, and emoji.
        """

        for regex in LangIDProcessor.all_regexes:
            text = regex.sub(" ", text)

        text = emoji.emojize(text)
        text = emoji.replace_emoji(text, replace=' ')

        if text.strip():
            text = text.strip()

        return text

    def _process_list(self, docs):
        """
        Identify language of list of strings or Documents
        """

        if len(docs) == 0:
            # TO DO: what standard do we want for bad input, such as empty list?
            # TO DO: more handling of bad input
            return

        if isinstance(docs[0], str):
            docs = [Document([], text) for text in docs]

        docs_by_length = {}
        for doc in docs:
            text = LangIDProcessor.clean_text(doc.text) if self._clean_text else doc.text
            doc_length = len(text)
            if doc_length not in docs_by_length:
                docs_by_length[doc_length] = []
            docs_by_length[doc_length].append((doc, text))

        for doc_length in docs_by_length:
            inputs = [doc[1] for doc in docs_by_length[doc_length]]
            predictions = self._id_langs(self._text_to_tensor(inputs))
            for doc, lang in zip(docs_by_length[doc_length], predictions):
                doc[0].lang = lang

        return docs

    def process(self, doc):
        """
        Handle single str or Document
        """

        wrapped_doc = [doc]
        return self._process_list(wrapped_doc)[0]

    def bulk_process(self, docs):
        """
        Handle list of strings or Documents
        """

        return self._process_list(docs)
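Usage note (illustrative, not part of this commit): a minimal sketch of running language identification, assuming the langid model distributed under the 'multilingual' language code.

import stanza

nlp = stanza.Pipeline(lang='multilingual', processors='langid')
docs = nlp.bulk_process(["Hello world.", "Bonjour le monde.", "Hallo Welt."])
for doc in docs:
    print(doc.lang, '<-', doc.text)

# langid_clean_text=True routes tweets and similar text through clean_text() above first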
stanza/stanza/pipeline/lemma_processor.py
ADDED
@@ -0,0 +1,126 @@
"""
Processor for performing lemmatization
"""

from itertools import compress

import torch

from stanza.models.common import doc
from stanza.models.lemma.data import DataLoader
from stanza.models.lemma.trainer import Trainer
from stanza.pipeline._constants import *
from stanza.pipeline.processor import UDProcessor, register_processor

WORD_TAGS = [doc.TEXT, doc.UPOS]

@register_processor(name=LEMMA)
class LemmaProcessor(UDProcessor):

    # set of processor requirements this processor fulfills
    PROVIDES_DEFAULT = set([LEMMA])
    # set of processor requirements for this processor
    # pos will be added later for non-identity lemmatizers
    REQUIRES_DEFAULT = set([TOKENIZE])
    # default batch size
    DEFAULT_BATCH_SIZE = 5000

    def __init__(self, config, pipeline, device):
        # run lemmatizer in identity mode
        self._use_identity = None
        self._pretagged = None
        super().__init__(config, pipeline, device)

    @property
    def use_identity(self):
        return self._use_identity

    def _set_up_model(self, config, pipeline, device):
        if config.get('use_identity') in ['True', True]:
            self._use_identity = True
            self._config = config
            self.config['batch_size'] = LemmaProcessor.DEFAULT_BATCH_SIZE
        else:
            # the lemmatizer only looks at one word when making
            # decisions, not the surrounding context
            # therefore, we can save some time by remembering what
            # we did the last time we saw any given word,pos
            # since a long running program will remember everything
            # (unless we go back and make it smarter)
            # we make this an option, not the default
            # TODO: need to update the cache to skip the contextual lemmatizer
            self.store_results = config.get('store_results', False)
            self._use_identity = False
            args = {'charlm_forward_file': config.get('forward_charlm_path', None),
                    'charlm_backward_file': config.get('backward_charlm_path', None)}
            lemma_classifier_args = dict(args)
            lemma_classifier_args['wordvec_pretrain_file'] = config.get('pretrain_path', None)
            self._trainer = Trainer(args=args, model_file=config['model_path'], device=device, foundation_cache=pipeline.foundation_cache, lemma_classifier_args=lemma_classifier_args)

    def _set_up_requires(self):
        self._pretagged = self._config.get('pretagged', None)
        if self._pretagged:
            self._requires = set()
        elif self.config.get('pos') and not self.use_identity:
            self._requires = LemmaProcessor.REQUIRES_DEFAULT.union(set([POS]))
        else:
            self._requires = LemmaProcessor.REQUIRES_DEFAULT

    def process(self, document):
        if not self.use_identity:
            batch = DataLoader(document, self.config['batch_size'], self.config, vocab=self.vocab, evaluation=True, expand_unk_vocab=True)
        else:
            batch = DataLoader(document, self.config['batch_size'], self.config, evaluation=True, conll_only=True)
        if self.use_identity:
            preds = [word.text for sent in batch.doc.sentences for word in sent.words]
        elif self.config.get('dict_only', False):
            preds = self.trainer.predict_dict(batch.doc.get([doc.TEXT, doc.UPOS]))
        else:
            if self.config.get('ensemble_dict', False):
                # skip the seq2seq model when we can
                skip = self.trainer.skip_seq2seq(batch.doc.get([doc.TEXT, doc.UPOS]))
                # although there is no explicit use of caseless or lemma_caseless in this processor,
                # it shows up in the config which gets passed to the DataLoader,
                # possibly affecting its results
                seq2seq_batch = DataLoader(document, self.config['batch_size'], self.config, vocab=self.vocab,
                                           evaluation=True, skip=skip, expand_unk_vocab=True)
            else:
                seq2seq_batch = batch

            with torch.no_grad():
                preds = []
                edits = []
                for i, b in enumerate(seq2seq_batch):
                    ps, es = self.trainer.predict(b, self.config['beam_size'], seq2seq_batch.vocab)
                    preds += ps
                    if es is not None:
                        edits += es

            if self.config.get('ensemble_dict', False):
                word_tags = batch.doc.get(WORD_TAGS)
                words = [x[0] for x in word_tags]
                preds = self.trainer.postprocess([x for x, y in zip(words, skip) if not y], preds, edits=edits)
                if self.store_results:
                    new_word_tags = compress(word_tags, map(lambda x: not x, skip))
                    new_predictions = [(x[0], x[1], y) for x, y in zip(new_word_tags, preds)]
                    self.trainer.train_dict(new_predictions, update_word_dict=False)
                # expand seq2seq predictions to the same size as all words
                i = 0
                preds1 = []
                for s in skip:
                    if s:
                        preds1.append('')
                    else:
                        preds1.append(preds[i])
                        i += 1
                preds = self.trainer.ensemble(word_tags, preds1)
            else:
                preds = self.trainer.postprocess(batch.doc.get([doc.TEXT]), preds, edits=edits)

            if self.trainer.has_contextual_lemmatizers():
                preds = self.trainer.update_contextual_preds(batch.doc, preds)

        # map empty string lemmas to '_'
        preds = [max([(len(x), x), (0, '_')])[1] for x in preds]
        batch.doc.set([doc.LEMMA], preds)
        return batch.doc
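Usage note (illustrative, not part of this commit): a minimal sketch of the lemma processor above, assuming the default English models.

import stanza

nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma')
doc = nlp("The mice were running.")
print([(w.text, w.lemma) for w in doc.sentences[0].words])

# lemma_use_identity=True takes the identity branch above: each word becomes its own lemma,
# which also drops the POS requirement from the processor
identity_nlp = stanza.Pipeline('en', processors='tokenize,lemma', lemma_use_identity=True)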
stanza/stanza/pipeline/multilingual.py
ADDED
@@ -0,0 +1,188 @@
"""
Class for running multilingual pipelines
"""

from collections import OrderedDict
import copy
import logging
from typing import Union

from stanza.models.common.doc import Document
from stanza.models.common.utils import default_device
from stanza.pipeline.core import Pipeline, DownloadMethod
from stanza.pipeline._constants import *
from stanza.resources.common import DEFAULT_MODEL_DIR, get_language_resources, load_resources_json

logger = logging.getLogger('stanza')

class MultilingualPipeline:
    """
    Pipeline for handling multilingual data. Takes in text, detects language, and routes request to pipeline for that
    language.

    You can specify options to individual language pipelines with the lang_configs field.
    For example, if you want English pipelines to have NER, but want to turn that off for French, you can do:
      lang_configs = {"en": {"processors": "tokenize,pos,lemma,depparse,ner"},
                      "fr": {"processors": "tokenize,pos,lemma,depparse"}}
      pipeline = MultilingualPipeline(lang_configs=lang_configs)

    You can also pass in a defaultdict created in such a way that it provides default parameters for each language.
    For example, in order to only get tokenization for each language
    (remembering that the Pipeline will automagically add MWT to a language which uses MWT):
      from collections import defaultdict
      lang_configs = defaultdict(lambda: dict(processors="tokenize"))
      pipeline = MultilingualPipeline(lang_configs=lang_configs)

    download_method can be set as in Pipeline to turn off downloading
    of the .json config or turn off downloading of everything
    """

    def __init__(self,
                 model_dir: str = DEFAULT_MODEL_DIR,
                 lang_id_config: dict = None,
                 lang_configs: dict = None,
                 ld_batch_size: int = 64,
                 max_cache_size: int = 10,
                 use_gpu: bool = None,
                 restrict: bool = False,
                 device: str = None,
                 download_method: DownloadMethod = DownloadMethod.DOWNLOAD_RESOURCES,
                 # python 3.6 compatibility - maybe want to update to 3.7 at some point
                 processors: Union[str, list] = None,
                ):
        # set up configs and cache for various language pipelines
        self.model_dir = model_dir
        self.lang_id_config = {} if lang_id_config is None else copy.deepcopy(lang_id_config)
        self.lang_configs = {} if lang_configs is None else copy.deepcopy(lang_configs)
        self.max_cache_size = max_cache_size
        # OrderedDict so we can use it as a LRU cache
        # most recent Pipeline goes to the end, pop the oldest one
        # when we run out of space
        self.pipeline_cache = OrderedDict()
        if processors is None:
            self.default_processors = None
        elif isinstance(processors, str):
            self.default_processors = [x.strip() for x in processors.split(",")]
        else:
            self.default_processors = list(processors)

        self.download_method = download_method
        if 'download_method' not in self.lang_id_config:
            self.lang_id_config['download_method'] = self.download_method

        # if lang is not in any of the lang_configs, update them to
        # include the lang parameter. otherwise, the default language
        # will always be used...
        for lang in self.lang_configs:
            if 'lang' not in self.lang_configs[lang]:
                self.lang_configs[lang]['lang'] = lang

        if restrict and 'langid_lang_subset' not in self.lang_id_config:
            known_langs = sorted(self.lang_configs.keys())
            if len(known_langs) == 0:
                logger.warning("MultilingualPipeline asked to restrict to lang_configs, but lang_configs was empty. Ignoring...")
            else:
                logger.debug("Restricting MultilingualPipeline to %s", known_langs)
                self.lang_id_config['langid_lang_subset'] = known_langs

        # set use_gpu
        if device is None:
            if use_gpu is None or use_gpu == True:
                device = default_device()
            else:
                device = 'cpu'
        self.device = device

        # build language id pipeline
        self.lang_id_pipeline = Pipeline(dir=self.model_dir, lang='multilingual', processors="langid",
                                         device=self.device, **self.lang_id_config)
        # load the resources so that we can refer to it later when building a new pipeline
        # note that it was either downloaded or not based on download_method when building the lang_id_pipeline
        self.resources = load_resources_json(self.model_dir)

    def _update_pipeline_cache(self, lang):
        """
        Do any necessary updates to the pipeline cache for this language. This includes building a new
        pipeline for the lang, and possibly clearing out a language with the old last access date.
        """

        # update request history
        if lang in self.pipeline_cache:
            self.pipeline_cache.move_to_end(lang, last=True)

        # update language configs
        # try/except to allow for a defaultdict
        try:
            lang_config = self.lang_configs[lang]
        except KeyError:
            lang_config = {'lang': lang}
            self.lang_configs[lang] = lang_config

        # if a defaultdict is passed in, the defaultdict might not contain 'lang'
        # so even though we tried adding 'lang' in the constructor, we'll check again here
        if 'lang' not in lang_config:
            lang_config['lang'] = lang

        if 'download_method' not in lang_config:
            lang_config['download_method'] = self.download_method

        if 'processors' not in lang_config:
            if self.default_processors:
                lang_resources = get_language_resources(self.resources, lang)
                lang_processors = [x for x in self.default_processors if x in lang_resources]
                if lang_processors != self.default_processors:
                    logger.info("Not all requested processors %s available for %s. Loading %s instead", self.default_processors, lang, lang_processors)
                lang_config['processors'] = ",".join(lang_processors)

        if 'device' not in lang_config:
            lang_config['device'] = self.device

        # update pipeline cache
        if lang not in self.pipeline_cache:
            logger.debug("Loading unknown language in MultilingualPipeline: %s", lang)
            # clear least recently used lang from pipeline cache
            if len(self.pipeline_cache) == self.max_cache_size:
                self.pipeline_cache.popitem(last=False)
            self.pipeline_cache[lang] = Pipeline(dir=self.model_dir, **self.lang_configs[lang])

    def process(self, doc):
        """
        Run language detection on a string, a Document, or a list of either, route to language specific pipeline
        """

        # only return a list if given a list
        singleton_input = not isinstance(doc, list)
        if singleton_input:
            docs = [doc]
        else:
            docs = doc

        if docs and isinstance(docs[0], str):
            docs = [Document([], text=text) for text in docs]

        # run language identification
        docs_w_langid = self.lang_id_pipeline.process(docs)

        # create language specific batches, store global idx with each doc
        lang_batches = {}
        for doc_idx, doc in enumerate(docs_w_langid):
            logger.debug("Language for document %d: %s", doc_idx, doc.lang)
            if doc.lang not in lang_batches:
                lang_batches[doc.lang] = []
            lang_batches[doc.lang].append(doc)

        # run through each language, submit a batch to the language specific pipeline
        for lang in lang_batches.keys():
            self._update_pipeline_cache(lang)
            self.pipeline_cache[lang](lang_batches[lang])

        # only return a list if given a list
        if singleton_input:
            return docs_w_langid[0]
        else:
            return docs_w_langid

    def __call__(self, doc):
        doc = self.process(doc)
        return doc
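Usage note (illustrative, not part of this commit): a runnable variant of the docstring examples above, assuming English and French models are available.

from stanza.pipeline.multilingual import MultilingualPipeline

# restrict=True narrows langid to exactly the configured language codes
lang_configs = {"en": {"processors": "tokenize,pos"},
                "fr": {"processors": "tokenize"}}
nlp = MultilingualPipeline(lang_configs=lang_configs, restrict=True, max_cache_size=2)

docs = nlp(["This is English text.", "C'est du texte français."])
for doc in docs:
    print(doc.lang, doc.sentences[0].text)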
stanza/stanza/pipeline/mwt_processor.py
ADDED
@@ -0,0 +1,59 @@
"""
Processor for performing multi-word-token expansion
"""

import io

import torch

from stanza.models.mwt.data import DataLoader
from stanza.models.mwt.trainer import Trainer
from stanza.pipeline._constants import *
from stanza.pipeline.processor import UDProcessor, register_processor

@register_processor(MWT)
class MWTProcessor(UDProcessor):

    # set of processor requirements this processor fulfills
    PROVIDES_DEFAULT = set([MWT])
    # set of processor requirements for this processor
    REQUIRES_DEFAULT = set([TOKENIZE])

    def _set_up_model(self, config, pipeline, device):
        self._trainer = Trainer(model_file=config['model_path'], device=device)

    def build_batch(self, document):
        return DataLoader(document, self.config['batch_size'], self.config, vocab=self.vocab, evaluation=True, expand_unk_vocab=True)

    def process(self, document):
        batch = self.build_batch(document)

        # process the rest
        expansions = batch.doc.get_mwt_expansions(evaluation=True)
        if len(batch) > 0:
            # decide trainer type and run eval
            if self.config['dict_only']:
                preds = self.trainer.predict_dict(expansions)
            else:
                with torch.no_grad():
                    preds = []
                    for i, b in enumerate(batch.to_loader()):
                        preds += self.trainer.predict(b, never_decode_unk=True, vocab=batch.vocab)

                if self.config.get('ensemble_dict', False):
                    preds = self.trainer.ensemble(expansions, preds)
        else:
            # skip eval if dev data does not exist
            preds = []

        batch.doc.set_mwt_expansions(preds, process_manual_expanded=False)
        return batch.doc

    def bulk_process(self, docs):
        """
        MWT processor counts some statistics on the individual docs, so we need to separately redo those stats
        """
        docs = super().bulk_process(docs)
        for doc in docs:
            doc._count_words()
        return docs
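Usage note (illustrative, not part of this commit): a minimal sketch of MWT expansion, assuming the default French models (French is a language where MWT applies).

import stanza

# French "du" is a multi-word token that MWT expands to "de" + "le"
nlp = stanza.Pipeline('fr', processors='tokenize,mwt')
doc = nlp("Je parle du projet.")
for token in doc.sentences[0].tokens:
    print(token.text, '->', [word.text for word in token.words])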
stanza/stanza/pipeline/pos_processor.py
ADDED
@@ -0,0 +1,89 @@
"""
Processor for performing part-of-speech tagging
"""

import torch

from stanza.models.common import doc
from stanza.models.common.utils import unsort
from stanza.models.common.vocab import VOCAB_PREFIX, CompositeVocab
from stanza.models.pos.data import Dataset
from stanza.models.pos.trainer import Trainer
from stanza.pipeline._constants import *
from stanza.pipeline.processor import UDProcessor, register_processor
from stanza.utils.get_tqdm import get_tqdm

tqdm = get_tqdm()

@register_processor(name=POS)
class POSProcessor(UDProcessor):

    # set of processor requirements this processor fulfills
    PROVIDES_DEFAULT = set([POS])
    # set of processor requirements for this processor
    REQUIRES_DEFAULT = set([TOKENIZE])

    def _set_up_model(self, config, pipeline, device):
        # get pretrained word vectors
        self._pretrain = pipeline.foundation_cache.load_pretrain(config['pretrain_path']) if 'pretrain_path' in config else None
        args = {'charlm_forward_file': config.get('forward_charlm_path', None),
                'charlm_backward_file': config.get('backward_charlm_path', None)}
        # set up trainer
        self._trainer = Trainer(pretrain=self.pretrain, model_file=config['model_path'], device=device, args=args, foundation_cache=pipeline.foundation_cache)
        self._tqdm = 'tqdm' in config and config['tqdm']

    def __str__(self):
        return "POSProcessor(%s)" % self.config['model_path']

    def get_known_xpos(self):
        """
        Returns the xpos tags known by this model
        """
        if isinstance(self.vocab['xpos'], CompositeVocab):
            if len(self.vocab['xpos']) == 1:
                return [k for k in self.vocab['xpos'][0]._unit2id.keys() if k not in VOCAB_PREFIX]
            else:
                return {k: v.keys() - VOCAB_PREFIX for k, v in self.vocab['xpos']._unit2id.items()}
        return [k for k in self.vocab['xpos']._unit2id.keys() if k not in VOCAB_PREFIX]

    def is_composite_xpos(self):
        """
        Returns if the xpos tags are part of a composite vocab
        """
        return isinstance(self.vocab['xpos'], CompositeVocab)

    def get_known_upos(self):
        """
        Returns the upos tags known by this model
        """
        keys = [k for k in self.vocab['upos']._unit2id.keys() if k not in VOCAB_PREFIX]
        return keys

    def get_known_feats(self):
        """
        Returns the features known by this model
        """
        values = {k: v.keys() - VOCAB_PREFIX for k, v in self.vocab['feats']._unit2id.items()}
        return values

    def process(self, document):
        # currently, POS models are saved w/o the batch_maximum_tokens flag
        maximum_tokens = self.config.get('batch_maximum_tokens', 5000)

        dataset = Dataset(
            document, self.config, self.pretrain, vocab=self.vocab, evaluation=True,
            sort_during_eval=True)
        batch = iter(dataset.to_length_limited_loader(batch_size=self.config['batch_size'], maximum_tokens=maximum_tokens))
        preds = []

        idx = []
        with torch.no_grad():
            if self._tqdm:
                batch = tqdm(batch)
            for i, b in enumerate(batch):
                idx.extend(b[-1])
                preds += self.trainer.predict(b)

        preds = unsort(preds, idx)
        dataset.doc.set([doc.UPOS, doc.XPOS, doc.FEATS], [y for x in preds for y in x])
        return dataset.doc
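Usage note (illustrative, not part of this commit): a minimal sketch of the POS processor above, assuming the default English models. The processors dict lookup follows the Pipeline code earlier in this commit.

import stanza

nlp = stanza.Pipeline('en', processors='tokenize,pos')
doc = nlp("Stanza tags words quickly.")
for word in doc.sentences[0].words:
    print(word.text, word.upos, word.xpos, word.feats)

# the processor instance also exposes its tag inventory
print(sorted(nlp.processors['pos'].get_known_upos()))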
stanza/stanza/pipeline/processor.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Base classes for processors
"""

from abc import ABC, abstractmethod

from stanza.models.common.doc import Document
from stanza.pipeline.registry import NAME_TO_PROCESSOR_CLASS, PIPELINE_NAMES, PROCESSOR_VARIANTS

class ProcessorRequirementsException(Exception):
    """ Exception indicating a processor's requirements will not be met """

    def __init__(self, processors_list, err_processor, provided_reqs):
        self._err_processor = err_processor
        # mark the broken processor as inactive, drop resources
        self.err_processor.mark_inactive()
        self._processors_list = processors_list
        self._provided_reqs = provided_reqs
        self.build_message()

    @property
    def err_processor(self):
        """ The processor that raised the exception """
        return self._err_processor

    @property
    def processor_type(self):
        return type(self.err_processor).__name__

    @property
    def processors_list(self):
        return self._processors_list

    @property
    def provided_reqs(self):
        return self._provided_reqs

    def build_message(self):
        self.message = (f"---\nPipeline Requirements Error!\n"
                        f"\tProcessor: {self.processor_type}\n"
                        f"\tPipeline processors list: {','.join(self.processors_list)}\n"
                        f"\tProcessor Requirements: {self.err_processor.requires}\n"
                        f"\t\t- fulfilled: {self.err_processor.requires.intersection(self.provided_reqs)}\n"
                        f"\t\t- missing: {self.err_processor.requires - self.provided_reqs}\n"
                        f"\nThe processors list provided for this pipeline is invalid. Please make sure all "
                        f"prerequisites are met for every processor.\n\n")

    def __str__(self):
        return self.message


class Processor(ABC):
    """ Base class for all processors """

    def __init__(self, config, pipeline, device):
        # overall config for the processor
        self._config = config
        # pipeline building this processor (presently processors are only meant to exist in one pipeline)
        self._pipeline = pipeline
        self._set_up_variants(config, device)
        # run set up process
        # set up what annotations are required based on config
        if not self._set_up_variant_requires():
            self._set_up_requires()
        # set up what annotations are provided based on config
        self._set_up_provides()
        # given pipeline constructing this processor, check if requirements are met, throw exception if not
        self._check_requirements()

        if hasattr(self, '_variant') and self._variant.OVERRIDE:
            self.process = self._variant.process

    def __str__(self):
        """
        Simple description of the processor: name(model)
        """
        name = self.__class__.__name__
        model = None
        if self._config is not None:
            model = self._config.get('model_path')
        if model is None:
            return name
        else:
            return "{}({})".format(name, model)


    @abstractmethod
    def process(self, doc):
        """ Process a Document. This is the main method of a processor. """
        pass

    def bulk_process(self, docs):
        """ Process a list of Documents. This should be replaced with a more efficient implementation if possible. """

        if hasattr(self, '_variant'):
            return self._variant.bulk_process(docs)

        return [self.process(doc) for doc in docs]

    def _set_up_provides(self):
        """ Set up what processor requirements this processor fulfills. Default is to use a class defined list. """
        self._provides = self.__class__.PROVIDES_DEFAULT

    def _set_up_requires(self):
        """ Set up requirements for this processor. Default is to use a class defined list. """
        self._requires = self.__class__.REQUIRES_DEFAULT

    def _set_up_variant_requires(self):
        """
        If this has a variant with its own requirements, use those instead

        Returns True iff the _requires is set from the _variant
        """
        if not hasattr(self, '_variant'):
            return False
        if hasattr(self._variant, '_set_up_requires'):
            self._variant._set_up_requires()
            self._requires = self._variant._requires
            return True
        if hasattr(self._variant.__class__, 'REQUIRES_DEFAULT'):
            self._requires = self._variant.__class__.REQUIRES_DEFAULT
            return True
        return False

    def _set_up_variants(self, config, device):
        processor_name = list(self.__class__.PROVIDES_DEFAULT)[0]
        if any(config.get(f'with_{variant}', False) for variant in PROCESSOR_VARIANTS[processor_name]):
            self._trainer = None
            variant_name = [variant for variant in PROCESSOR_VARIANTS[processor_name] if config.get(f'with_{variant}', False)][0]
            self._variant = PROCESSOR_VARIANTS[processor_name][variant_name](config)

    @property
    def config(self):
        """ Configurations for the processor """
        return self._config

    @property
    def pipeline(self):
        """ The pipeline that this processor belongs to """
        return self._pipeline

    @property
    def provides(self):
        return self._provides

    @property
    def requires(self):
        return self._requires

    def _check_requirements(self):
        """ Given a list of fulfilled requirements, check if all of this processor's requirements are met or not. """
        if not self.config.get("check_requirements", True):
            return
        provided_reqs = set.union(*[processor.provides for processor in self.pipeline.loaded_processors]+[set([])])
        if self.requires - provided_reqs:
            load_names = [item[0] for item in self.pipeline.load_list]
            raise ProcessorRequirementsException(load_names, self, provided_reqs)


class ProcessorVariant(ABC):
    """ Base class for all processor variants """

    OVERRIDE = False # Set to true to override all the processing from the processor

    @abstractmethod
    def process(self, doc):
        """
        Process a document that is potentially preprocessed by the processor.
        This is the main method of a processor variant.

        If `OVERRIDE` is set to True, all preprocessing by the processor would be bypassed, and the processor variant
        would serve as a drop-in replacement of the entire processor, and has to be able to interpret all the configs
        that are typically handled by the processor it replaces.
        """
        pass

    def bulk_process(self, docs):
        """ Process a list of Documents. This should be replaced with a more efficient implementation if possible. """

        return [self.process(doc) for doc in docs]

class UDProcessor(Processor):
    """ Base class for the neural UD Processors (tokenize,mwt,pos,lemma,depparse,sentiment,constituency) """

    def __init__(self, config, pipeline, device):
        super().__init__(config, pipeline, device)

        # UD model resources, set up is processor specific
        self._pretrain = None
        self._trainer = None
        self._vocab = None
        if not hasattr(self, '_variant'):
            self._set_up_model(config, pipeline, device)

        # build the final config for the processor
        self._set_up_final_config(config)

    @abstractmethod
    def _set_up_model(self, config, pipeline, device):
        pass

    def _set_up_final_config(self, config):
        """ Finalize the configurations for this processor, based off of values from a UD model. """
        # set configurations from loaded model
        if self._trainer is not None:
            loaded_args, self._vocab = self._trainer.args, self._trainer.vocab
            # filter out unneeded args from model
            loaded_args = {k: v for k, v in loaded_args.items() if not UDProcessor.filter_out_option(k)}
        else:
            loaded_args = {}
        loaded_args.update(config)
        self._config = loaded_args

    def mark_inactive(self):
        """ Drop memory intensive resources if keeping this processor around for reasons other than running it. """
        self._trainer = None
        self._vocab = None

    @property
    def pretrain(self):
        return self._pretrain

    @property
    def trainer(self):
        return self._trainer

    @property
    def vocab(self):
        return self._vocab

    @staticmethod
    def filter_out_option(option):
        """ Filter out non-processor configurations """
        options_to_filter = ['device', 'cpu', 'cuda', 'dev_conll_gold', 'epochs', 'lang', 'mode', 'save_name', 'shorthand']
        if option.endswith('_file') or option.endswith('_dir'):
            return True
        elif option in options_to_filter:
            return True
        else:
            return False

    def bulk_process(self, docs):
        """
        Most processors operate on the sentence level, where each sentence is processed independently and processors can benefit
        a lot from the ability to combine sentences from multiple documents for faster batched processing. This is a transparent
        implementation that allows these processors to batch process a list of Documents as if they were from a single Document.
        """

        if hasattr(self, '_variant'):
            return self._variant.bulk_process(docs)

        combined_sents = [sent for doc in docs for sent in doc.sentences]
        combined_doc = Document([])
        combined_doc.sentences = combined_sents
        combined_doc.num_tokens = sum(doc.num_tokens for doc in docs)
        combined_doc.num_words = sum(doc.num_words for doc in docs)

        self.process(combined_doc) # annotations are attached to sentence objects

        return docs

class ProcessorRegisterException(Exception):
    """ Exception indicating processor or processor registration failure """

    def __init__(self, processor_class, expected_parent):
        self._processor_class = processor_class
        self._expected_parent = expected_parent
        self.build_message()

    def build_message(self):
        self.message = f"Failed to register '{self._processor_class}'. It must be a subclass of '{self._expected_parent}'."

    def __str__(self):
        return self.message

def register_processor(name):
    def wrapper(Cls):
        if not issubclass(Cls, Processor):
            raise ProcessorRegisterException(Cls, Processor)

        NAME_TO_PROCESSOR_CLASS[name] = Cls
        PIPELINE_NAMES.append(name)
        return Cls
    return wrapper

def register_processor_variant(name, variant):
    def wrapper(Cls):
        if not issubclass(Cls, ProcessorVariant):
            raise ProcessorRegisterException(Cls, ProcessorVariant)

        PROCESSOR_VARIANTS[name][variant] = Cls
        return Cls
    return wrapper
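As a usage sketch of the registration decorators above: a toy processor, modeled on the custom-processor pattern from the Stanza documentation (the 'lowercase' name and the class body are illustrative, not part of this diff):

from stanza.pipeline.processor import Processor, register_processor

@register_processor('lowercase')
class LowercaseProcessor(Processor):
    """ A toy processor that lowercases the text of each word """
    # the provides / requires properties resolve these class attributes
    _provides = set(['lowercase'])
    _requires = set(['tokenize'])

    def __init__(self, config, pipeline, device):
        # no model to load for this trivial example
        pass

    def process(self, doc):
        doc.text = doc.text.lower()
        for sent in doc.sentences:
            for word in sent.words:
                word.text = word.text.lower()
        return doc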
stanza/stanza/pipeline/registry.py
ADDED
@@ -0,0 +1,8 @@
from collections import defaultdict

# these two get filled by register_processor
NAME_TO_PROCESSOR_CLASS = dict()
PIPELINE_NAMES = []

# this gets filled by register_processor_variant
PROCESSOR_VARIANTS = defaultdict(dict)
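A sketch of how these registries look once the decorated processor modules have been imported (variant names depend on which external wrapper modules import cleanly):

from stanza.pipeline.registry import NAME_TO_PROCESSOR_CLASS, PIPELINE_NAMES, PROCESSOR_VARIANTS
import stanza.pipeline.tokenize_processor  # importing a processor module fills the registries

print(NAME_TO_PROCESSOR_CLASS['tokenize'])   # the TokenizeProcessor class
print(PIPELINE_NAMES)                        # processor names in registration order
print(list(PROCESSOR_VARIANTS['tokenize']))  # e.g. ['jieba', 'spacy', 'sudachipy', 'pythainlp']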
stanza/stanza/pipeline/sentiment_processor.py
ADDED
@@ -0,0 +1,78 @@
"""Processor that attaches a sentiment score to a sentence

The model used is generally one trained on the Stanford
Sentiment Treebank or some similar dataset. When run, this processor
attaches a score in the form of a string to each sentence in the
document.

TODO: a possible way to generalize this would be to make it a
ClassifierProcessor and have "sentiment" be an option.
"""

import dataclasses
import torch

from types import SimpleNamespace

from stanza.models.classifiers.trainer import Trainer

from stanza.pipeline._constants import *
from stanza.pipeline.processor import UDProcessor, register_processor

@register_processor(SENTIMENT)
class SentimentProcessor(UDProcessor):
    # set of processor requirements this processor fulfills
    PROVIDES_DEFAULT = set([SENTIMENT])
    # set of processor requirements for this processor
    # TODO: a constituency based model needs CONSTITUENCY as well
    # issue: by the time we load the model in Processor.__init__,
    #   the requirements are already prepared
    REQUIRES_DEFAULT = set([TOKENIZE])

    # default batch size, measured in words per batch
    DEFAULT_BATCH_SIZE = 5000

    def _set_up_model(self, config, pipeline, device):
        # get pretrained word vectors
        pretrain_path = config.get('pretrain_path', None)
        forward_charlm_path = config.get('forward_charlm_path', None)
        backward_charlm_path = config.get('backward_charlm_path', None)
        # elmo does not have a convenient way to download intermediate
        # models the way stanza downloads charlms & pretrains or
        # transformers downloads bert etc
        # however, elmo in general is not as good as using a
        # transformer, so it is unlikely we will ever fix this
        args = SimpleNamespace(device = device,
                               charlm_forward_file = forward_charlm_path,
                               charlm_backward_file = backward_charlm_path,
                               wordvec_pretrain_file = pretrain_path,
                               elmo_model = None,
                               use_elmo = False,
                               save_dir = None)
        filename = config['model_path']
        if filename is None:
            raise FileNotFoundError("No model specified for the sentiment processor. Perhaps it is not supported for the language. {}".format(config))
        # set up model
        trainer = Trainer.load(filename=filename,
                               args=args,
                               foundation_cache=pipeline.foundation_cache)
        self._trainer = trainer
        self._model = trainer.model
        self._model_type = self._model.config.model_type
        # batch size counted as words
        self._batch_size = config.get('batch_size', SentimentProcessor.DEFAULT_BATCH_SIZE)

    def _set_up_final_config(self, config):
        loaded_args = dataclasses.asdict(self._model.config)
        loaded_args = {k: v for k, v in loaded_args.items() if not UDProcessor.filter_out_option(k)}
        loaded_args.update(config)
        self._config = loaded_args


    def process(self, document):
        sentences = self._model.extract_sentences(document)
        with torch.no_grad():
            labels = self._model.label_sentences(sentences, batch_size=self._batch_size)
        # TODO: allow a classifier processor for any attribute, not just sentiment
        document.set(SENTIMENT, labels, to_sentence=True)
        return document
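A minimal usage sketch for this processor, assuming the English models have been downloaded; for the default English model the predicted class ids map to 0 = negative, 1 = neutral, 2 = positive:

import stanza

nlp = stanza.Pipeline(lang='en', processors='tokenize,sentiment')
doc = nlp('I loved the acting. The plot made no sense at all.')
for sentence in doc.sentences:
    print(sentence.sentiment, sentence.text)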
stanza/stanza/pipeline/tokenize_processor.py
ADDED
@@ -0,0 +1,185 @@
"""
Processor for performing tokenization
"""

import io
import logging

import torch

from stanza.models.tokenization.data import TokenizationDataset
from stanza.models.tokenization.trainer import Trainer
from stanza.models.tokenization.utils import output_predictions
from stanza.pipeline._constants import *
from stanza.pipeline.processor import UDProcessor, register_processor
from stanza.pipeline.registry import PROCESSOR_VARIANTS
from stanza.models.common import doc

# these imports trigger the "register_variant" decorations
from stanza.pipeline.external.jieba import JiebaTokenizer
from stanza.pipeline.external.spacy import SpacyTokenizer
from stanza.pipeline.external.sudachipy import SudachiPyTokenizer
from stanza.pipeline.external.pythainlp import PyThaiNLPTokenizer

logger = logging.getLogger('stanza')

TOKEN_TOO_LONG_REPLACEMENT = "<UNK>"

# class for running the tokenizer
@register_processor(name=TOKENIZE)
class TokenizeProcessor(UDProcessor):

    # set of processor requirements this processor fulfills
    PROVIDES_DEFAULT = set([TOKENIZE])
    # set of processor requirements for this processor
    REQUIRES_DEFAULT = set([])
    # default max sequence length
    MAX_SEQ_LENGTH_DEFAULT = 1000

    def _set_up_model(self, config, pipeline, device):
        # set up trainer
        if config.get('pretokenized'):
            self._trainer = None
        else:
            self._trainer = Trainer(model_file=config['model_path'], device=device)

        # get and typecheck the postprocessor
        postprocessor = config.get('postprocessor')
        if postprocessor and callable(postprocessor):
            self._postprocessor = postprocessor
        elif not postprocessor:
            self._postprocessor = None
        else:
            raise ValueError("Tokenizer received 'postprocessor' option of unrecognized type; postprocessor must be callable. Got %s" % postprocessor)

    def process_pre_tokenized_text(self, input_src):
        """
        Pretokenized text can be provided in 2 manners:

        1.) str, tokenized by whitespace, sentence split by newline
        2.) list of token lists, each token list represents a sentence

        generate dictionary data structure
        """

        document = []
        if isinstance(input_src, str):
            sentences = [sent.strip().split() for sent in input_src.strip().split('\n') if len(sent.strip()) > 0]
        elif isinstance(input_src, list):
            sentences = input_src
        idx = 0
        for sentence in sentences:
            sent = []
            for token_id, token in enumerate(sentence):
                sent.append({doc.ID: (token_id + 1, ), doc.TEXT: token, doc.MISC: f'start_char={idx}|end_char={idx + len(token)}'})
                idx += len(token) + 1
            document.append(sent)
        raw_text = ' '.join([' '.join(sentence) for sentence in sentences])
        return raw_text, document

    def process(self, document):
        if not (isinstance(document, str) or isinstance(document, doc.Document) or (self.config.get('pretokenized') or self.config.get('no_ssplit', False))):
            raise ValueError("If neither 'pretokenized' nor 'no_ssplit' option is enabled, the input to the TokenizeProcessor must be a string or a Document object. Got %s" % str(type(document)))

        if isinstance(document, doc.Document):
            if self.config.get('pretokenized'):
                return document
            document = document.text

        if self.config.get('pretokenized'):
            raw_text, document = self.process_pre_tokenized_text(document)
            return doc.Document(document, raw_text)

        if hasattr(self, '_variant'):
            return self._variant.process(document)

        raw_text = '\n\n'.join(document) if isinstance(document, list) else document

        max_seq_len = self.config.get('max_seqlen', TokenizeProcessor.MAX_SEQ_LENGTH_DEFAULT)

        # set up batches
        batches = TokenizationDataset(self.config, input_text=raw_text, vocab=self.vocab, evaluation=True, dictionary=self.trainer.dictionary)
        # get dict data
        with torch.no_grad():
            _, _, _, document = output_predictions(None, self.trainer, batches, self.vocab, None,
                                                   max_seq_len,
                                                   orig_text=raw_text,
                                                   no_ssplit=self.config.get('no_ssplit', False),
                                                   num_workers = self.config.get('num_workers', 0),
                                                   postprocessor = self._postprocessor)

        # replace excessively long tokens with <UNK> to avoid downstream GPU memory issues in POS
        for sentence in document:
            for token in sentence:
                if len(token['text']) > max_seq_len:
                    token['text'] = TOKEN_TOO_LONG_REPLACEMENT

        return doc.Document(document, raw_text)

    def bulk_process(self, docs):
        """
        The tokenizer cannot use UDProcessor's sentence-level cross-document batching interface, and requires special handling.
        Essentially, this method concatenates the text of multiple documents with "\n\n", tokenizes it with the neural tokenizer,
        then splits the result into the original Documents and recovers the original character offsets.
        """
        if hasattr(self, '_variant'):
            return self._variant.bulk_process(docs)

        if self.config.get('pretokenized'):
            res = []
            for document in docs:
                raw_text, document = self.process_pre_tokenized_text(document.text)
                res.append(doc.Document(document, raw_text))
            return res

        combined_text = '\n\n'.join([thisdoc.text for thisdoc in docs])
        processed_combined = self.process(doc.Document([], text=combined_text))

        # postprocess sentences and tokens to reset back pointers and char offsets
        charoffset = 0
        sentst = senten = 0
        for thisdoc in docs:
            while senten < len(processed_combined.sentences) and processed_combined.sentences[senten].tokens[-1].end_char - charoffset <= len(thisdoc.text):
                senten += 1

            sentences = processed_combined.sentences[sentst:senten]
            thisdoc.sentences = sentences
            for sent in sentences:
                # fix doc back pointers for sentences
                sent._doc = thisdoc

                # fix char offsets for tokens and words
                for token in sent.tokens:
                    token._start_char -= charoffset
                    token._end_char -= charoffset
                    if token.words: # not-yet-processed MWT can leave empty tokens
                        for word in token.words:
                            word._start_char -= charoffset
                            word._end_char -= charoffset

            # Here we need to fix up the SpacesAfter for the very last token
            # and the SpacesBefore for the first token of the next doc
            # After all, we had connected the text with \n\n
            # Need to be careful about this - in a case such as
            #  " -text one- "
            #  " -text two- "
            # We want the SpacesBefore for the second document to reflect
            # the extra space at the start of its text
            # and the SpacesAfter for the first document to reflect
            # the whitespace after its text
            if len(sentences) > 0:
                last_token = sentences[-1].tokens[-1]
                last_whitespace = thisdoc.text[last_token.end_char:]
                last_token.spaces_after = last_whitespace

                first_token = sentences[0].tokens[0]
                first_whitespace = thisdoc.text[:first_token.start_char]
                first_token.spaces_before = first_whitespace

            thisdoc.num_tokens = sum(len(sent.tokens) for sent in sentences)
            thisdoc.num_words = sum(len(sent.words) for sent in sentences)
            sentst = senten

            charoffset += len(thisdoc.text) + 2

        return docs
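A sketch of the two pretokenized input formats accepted by process_pre_tokenized_text, driven through a pipeline (assumes English models are downloaded; the tokenize_pretokenized option maps to the 'pretokenized' config key read above):

import stanza

nlp = stanza.Pipeline(lang='en', processors='tokenize', tokenize_pretokenized=True)

# format 1: whitespace-tokenized string, one sentence per line
doc = nlp('This is a test .\nAnother sentence .')

# format 2: list of token lists, one list per sentence
doc = nlp([['This', 'is', 'a', 'test', '.'], ['Another', 'sentence', '.']])
print(doc.sentences[0].tokens[0].misc)  # start_char/end_char reconstructed from token lengths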
stanza/stanza/protobuf/CoreNLP_pb2.py
ADDED
@@ -0,0 +1,686 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: CoreNLP.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()




DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rCoreNLP.proto\x12\x19\x65\x64u.stanford.nlp.pipeline\"\xe1\x05\n\x08\x44ocument\x12\x0c\n\x04text\x18\x01 \x02(\t\x12\x35\n\x08sentence\x18\x02 \x03(\x0b\x32#.edu.stanford.nlp.pipeline.Sentence\x12\x39\n\ncorefChain\x18\x03 \x03(\x0b\x32%.edu.stanford.nlp.pipeline.CorefChain\x12\r\n\x05\x64ocID\x18\x04 \x01(\t\x12\x0f\n\x07\x64ocDate\x18\x07 \x01(\t\x12\x10\n\x08\x63\x61lendar\x18\x08 \x01(\x04\x12;\n\x11sentencelessToken\x18\x05 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x33\n\tcharacter\x18\n \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12/\n\x05quote\x18\x06 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Quote\x12\x37\n\x08mentions\x18\t \x03(\x0b\x32%.edu.stanford.nlp.pipeline.NERMention\x12#\n\x1bhasEntityMentionsAnnotation\x18\r \x01(\x08\x12\x0e\n\x06xmlDoc\x18\x0b \x01(\x08\x12\x34\n\x08sections\x18\x0c \x03(\x0b\x32\".edu.stanford.nlp.pipeline.Section\x12<\n\x10mentionsForCoref\x18\x0e \x03(\x0b\x32\".edu.stanford.nlp.pipeline.Mention\x12!\n\x19hasCorefMentionAnnotation\x18\x0f \x01(\x08\x12\x1a\n\x12hasCorefAnnotation\x18\x10 \x01(\x08\x12+\n#corefMentionToEntityMentionMappings\x18\x11 \x03(\x05\x12+\n#entityMentionToCorefMentionMappings\x18\x12 \x03(\x05*\x05\x08\x64\x10\x80\x02\"\xf3\x0f\n\x08Sentence\x12/\n\x05token\x18\x01 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x18\n\x10tokenOffsetBegin\x18\x02 \x02(\r\x12\x16\n\x0etokenOffsetEnd\x18\x03 \x02(\r\x12\x15\n\rsentenceIndex\x18\x04 \x01(\r\x12\x1c\n\x14\x63haracterOffsetBegin\x18\x05 \x01(\r\x12\x1a\n\x12\x63haracterOffsetEnd\x18\x06 \x01(\r\x12\x37\n\tparseTree\x18\x07 \x01(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12@\n\x12\x62inarizedParseTree\x18\x1f \x01(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12@\n\x12\x61nnotatedParseTree\x18 \x01(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12\x11\n\tsentiment\x18! 
\x01(\t\x12=\n\x0fkBestParseTrees\x18\" \x03(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12\x45\n\x11\x62\x61sicDependencies\x18\x08 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12I\n\x15\x63ollapsedDependencies\x18\t \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12T\n collapsedCCProcessedDependencies\x18\n \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12K\n\x17\x61lternativeDependencies\x18\r \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12?\n\x0copenieTriple\x18\x0e \x03(\x0b\x32).edu.stanford.nlp.pipeline.RelationTriple\x12<\n\tkbpTriple\x18\x10 \x03(\x0b\x32).edu.stanford.nlp.pipeline.RelationTriple\x12\x45\n\x10\x65ntailedSentence\x18\x0f \x03(\x0b\x32+.edu.stanford.nlp.pipeline.SentenceFragment\x12\x43\n\x0e\x65ntailedClause\x18# \x03(\x0b\x32+.edu.stanford.nlp.pipeline.SentenceFragment\x12H\n\x14\x65nhancedDependencies\x18\x11 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12P\n\x1c\x65nhancedPlusPlusDependencies\x18\x12 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12\x33\n\tcharacter\x18\x13 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x11\n\tparagraph\x18\x0b \x01(\r\x12\x0c\n\x04text\x18\x0c \x01(\t\x12\x12\n\nlineNumber\x18\x14 \x01(\r\x12\x1e\n\x16hasRelationAnnotations\x18\x33 \x01(\x08\x12\x31\n\x06\x65ntity\x18\x34 \x03(\x0b\x32!.edu.stanford.nlp.pipeline.Entity\x12\x35\n\x08relation\x18\x35 \x03(\x0b\x32#.edu.stanford.nlp.pipeline.Relation\x12$\n\x1chasNumerizedTokensAnnotation\x18\x36 \x01(\x08\x12\x37\n\x08mentions\x18\x37 \x03(\x0b\x32%.edu.stanford.nlp.pipeline.NERMention\x12<\n\x10mentionsForCoref\x18\x38 \x03(\x0b\x32\".edu.stanford.nlp.pipeline.Mention\x12\"\n\x1ahasCorefMentionsAnnotation\x18\x39 \x01(\x08\x12\x12\n\nsentenceID\x18: \x01(\t\x12\x13\n\x0bsectionDate\x18; \x01(\t\x12\x14\n\x0csectionIndex\x18< \x01(\r\x12\x13\n\x0bsectionName\x18= \x01(\t\x12\x15\n\rsectionAuthor\x18> \x01(\t\x12\r\n\x05\x64ocID\x18? \x01(\t\x12\x15\n\rsectionQuoted\x18@ \x01(\x08\x12#\n\x1bhasEntityMentionsAnnotation\x18\x41 \x01(\x08\x12\x1f\n\x17hasKBPTriplesAnnotation\x18\x44 \x01(\x08\x12\"\n\x1ahasOpenieTriplesAnnotation\x18\x45 \x01(\x08\x12\x14\n\x0c\x63hapterIndex\x18\x42 \x01(\r\x12\x16\n\x0eparagraphIndex\x18\x43 \x01(\r\x12=\n\x10\x65nhancedSentence\x18\x46 \x01(\x0b\x32#.edu.stanford.nlp.pipeline.Sentence\x12\x0f\n\x07speaker\x18G \x01(\t\x12\x13\n\x0bspeakerType\x18H \x01(\t*\x05\x08\x64\x10\x80\x02\"\xf6\x0c\n\x05Token\x12\x0c\n\x04word\x18\x01 \x01(\t\x12\x0b\n\x03pos\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\t\x12\x10\n\x08\x63\x61tegory\x18\x04 \x01(\t\x12\x0e\n\x06\x62\x65\x66ore\x18\x05 \x01(\t\x12\r\n\x05\x61\x66ter\x18\x06 \x01(\t\x12\x14\n\x0coriginalText\x18\x07 \x01(\t\x12\x0b\n\x03ner\x18\x08 \x01(\t\x12\x11\n\tcoarseNER\x18> \x01(\t\x12\x16\n\x0e\x66ineGrainedNER\x18? 
\x01(\t\x12\x15\n\rnerLabelProbs\x18\x42 \x03(\t\x12\x15\n\rnormalizedNER\x18\t \x01(\t\x12\r\n\x05lemma\x18\n \x01(\t\x12\x11\n\tbeginChar\x18\x0b \x01(\r\x12\x0f\n\x07\x65ndChar\x18\x0c \x01(\r\x12\x11\n\tutterance\x18\r \x01(\r\x12\x0f\n\x07speaker\x18\x0e \x01(\t\x12\x13\n\x0bspeakerType\x18M \x01(\t\x12\x12\n\nbeginIndex\x18\x0f \x01(\r\x12\x10\n\x08\x65ndIndex\x18\x10 \x01(\r\x12\x17\n\x0ftokenBeginIndex\x18\x11 \x01(\r\x12\x15\n\rtokenEndIndex\x18\x12 \x01(\r\x12\x34\n\ntimexValue\x18\x13 \x01(\x0b\x32 .edu.stanford.nlp.pipeline.Timex\x12\x15\n\rhasXmlContext\x18\x15 \x01(\x08\x12\x12\n\nxmlContext\x18\x16 \x03(\t\x12\x16\n\x0e\x63orefClusterID\x18\x17 \x01(\r\x12\x0e\n\x06\x61nswer\x18\x18 \x01(\t\x12\x15\n\rheadWordIndex\x18\x1a \x01(\r\x12\x35\n\x08operator\x18\x1b \x01(\x0b\x32#.edu.stanford.nlp.pipeline.Operator\x12\x35\n\x08polarity\x18\x1c \x01(\x0b\x32#.edu.stanford.nlp.pipeline.Polarity\x12\x14\n\x0cpolarity_dir\x18\' \x01(\t\x12-\n\x04span\x18\x1d \x01(\x0b\x32\x1f.edu.stanford.nlp.pipeline.Span\x12\x11\n\tsentiment\x18\x1e \x01(\t\x12\x16\n\x0equotationIndex\x18\x1f \x01(\x05\x12\x42\n\x0e\x63onllUFeatures\x18 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.MapStringString\x12\x11\n\tcoarseTag\x18! \x01(\t\x12\x38\n\x0f\x63onllUTokenSpan\x18\" \x01(\x0b\x32\x1f.edu.stanford.nlp.pipeline.Span\x12\x12\n\nconllUMisc\x18# \x01(\t\x12G\n\x13\x63onllUSecondaryDeps\x18$ \x01(\x0b\x32*.edu.stanford.nlp.pipeline.MapStringString\x12\x17\n\x0fwikipediaEntity\x18% \x01(\t\x12\x11\n\tisNewline\x18& \x01(\x08\x12\x0e\n\x06gender\x18\x33 \x01(\t\x12\x10\n\x08trueCase\x18\x34 \x01(\t\x12\x14\n\x0ctrueCaseText\x18\x35 \x01(\t\x12\x13\n\x0b\x63hineseChar\x18\x36 \x01(\t\x12\x12\n\nchineseSeg\x18\x37 \x01(\t\x12\x16\n\x0e\x63hineseXMLChar\x18< \x01(\t\x12\x11\n\tarabicSeg\x18L \x01(\t\x12\x13\n\x0bsectionName\x18\x38 \x01(\t\x12\x15\n\rsectionAuthor\x18\x39 \x01(\t\x12\x13\n\x0bsectionDate\x18: \x01(\t\x12\x17\n\x0fsectionEndLabel\x18; \x01(\t\x12\x0e\n\x06parent\x18= \x01(\t\x12\x19\n\x11\x63orefMentionIndex\x18@ \x03(\r\x12\x1a\n\x12\x65ntityMentionIndex\x18\x41 \x01(\r\x12\r\n\x05isMWT\x18\x43 \x01(\x08\x12\x12\n\nisFirstMWT\x18\x44 \x01(\x08\x12\x0f\n\x07mwtText\x18\x45 \x01(\t\x12\x0f\n\x07mwtMisc\x18N \x01(\t\x12\x14\n\x0cnumericValue\x18\x46 \x01(\x04\x12\x13\n\x0bnumericType\x18G \x01(\t\x12\x1d\n\x15numericCompositeValue\x18H \x01(\x04\x12\x1c\n\x14numericCompositeType\x18I \x01(\t\x12\x1c\n\x14\x63odepointOffsetBegin\x18J \x01(\r\x12\x1a\n\x12\x63odepointOffsetEnd\x18K \x01(\r\x12\r\n\x05index\x18O \x01(\r\x12\x12\n\nemptyIndex\x18P \x01(\r*\x05\x08\x64\x10\x80\x02\"\xe4\x03\n\x05Quote\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05\x62\x65gin\x18\x02 \x01(\r\x12\x0b\n\x03\x65nd\x18\x03 \x01(\r\x12\x15\n\rsentenceBegin\x18\x05 \x01(\r\x12\x13\n\x0bsentenceEnd\x18\x06 \x01(\r\x12\x12\n\ntokenBegin\x18\x07 \x01(\r\x12\x10\n\x08tokenEnd\x18\x08 \x01(\r\x12\r\n\x05\x64ocid\x18\t \x01(\t\x12\r\n\x05index\x18\n \x01(\r\x12\x0e\n\x06\x61uthor\x18\x0b \x01(\t\x12\x0f\n\x07mention\x18\x0c \x01(\t\x12\x14\n\x0cmentionBegin\x18\r \x01(\r\x12\x12\n\nmentionEnd\x18\x0e \x01(\r\x12\x13\n\x0bmentionType\x18\x0f \x01(\t\x12\x14\n\x0cmentionSieve\x18\x10 \x01(\t\x12\x0f\n\x07speaker\x18\x11 \x01(\t\x12\x14\n\x0cspeakerSieve\x18\x12 \x01(\t\x12\x18\n\x10\x63\x61nonicalMention\x18\x13 \x01(\t\x12\x1d\n\x15\x63\x61nonicalMentionBegin\x18\x14 \x01(\r\x12\x1b\n\x13\x63\x61nonicalMentionEnd\x18\x15 \x01(\r\x12N\n\x1a\x61ttributionDependencyGraph\x18\x16 
\x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\"\xc7\x01\n\tParseTree\x12\x33\n\x05\x63hild\x18\x01 \x03(\x0b\x32$.edu.stanford.nlp.pipeline.ParseTree\x12\r\n\x05value\x18\x02 \x01(\t\x12\x17\n\x0fyieldBeginIndex\x18\x03 \x01(\r\x12\x15\n\ryieldEndIndex\x18\x04 \x01(\r\x12\r\n\x05score\x18\x05 \x01(\x01\x12\x37\n\tsentiment\x18\x06 \x01(\x0e\x32$.edu.stanford.nlp.pipeline.Sentiment\"\x9b\x04\n\x0f\x44\x65pendencyGraph\x12=\n\x04node\x18\x01 \x03(\x0b\x32/.edu.stanford.nlp.pipeline.DependencyGraph.Node\x12=\n\x04\x65\x64ge\x18\x02 \x03(\x0b\x32/.edu.stanford.nlp.pipeline.DependencyGraph.Edge\x12\x10\n\x04root\x18\x03 \x03(\rB\x02\x10\x01\x12/\n\x05token\x18\x04 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x14\n\x08rootNode\x18\x05 \x03(\rB\x02\x10\x01\x1aX\n\x04Node\x12\x15\n\rsentenceIndex\x18\x01 \x02(\r\x12\r\n\x05index\x18\x02 \x02(\r\x12\x16\n\x0e\x63opyAnnotation\x18\x03 \x01(\r\x12\x12\n\nemptyIndex\x18\x04 \x01(\r\x1a\xd6\x01\n\x04\x45\x64ge\x12\x0e\n\x06source\x18\x01 \x02(\r\x12\x0e\n\x06target\x18\x02 \x02(\r\x12\x0b\n\x03\x64\x65p\x18\x03 \x01(\t\x12\x0f\n\x07isExtra\x18\x04 \x01(\x08\x12\x12\n\nsourceCopy\x18\x05 \x01(\r\x12\x12\n\ntargetCopy\x18\x06 \x01(\r\x12\x13\n\x0bsourceEmpty\x18\x08 \x01(\r\x12\x13\n\x0btargetEmpty\x18\t \x01(\r\x12>\n\x08language\x18\x07 \x01(\x0e\x32#.edu.stanford.nlp.pipeline.Language:\x07Unknown\"\xc6\x02\n\nCorefChain\x12\x0f\n\x07\x63hainID\x18\x01 \x02(\x05\x12\x43\n\x07mention\x18\x02 \x03(\x0b\x32\x32.edu.stanford.nlp.pipeline.CorefChain.CorefMention\x12\x16\n\x0erepresentative\x18\x03 \x02(\r\x1a\xc9\x01\n\x0c\x43orefMention\x12\x11\n\tmentionID\x18\x01 \x01(\x05\x12\x13\n\x0bmentionType\x18\x02 \x01(\t\x12\x0e\n\x06number\x18\x03 \x01(\t\x12\x0e\n\x06gender\x18\x04 \x01(\t\x12\x0f\n\x07\x61nimacy\x18\x05 \x01(\t\x12\x12\n\nbeginIndex\x18\x06 \x01(\r\x12\x10\n\x08\x65ndIndex\x18\x07 \x01(\r\x12\x11\n\theadIndex\x18\t \x01(\r\x12\x15\n\rsentenceIndex\x18\n \x01(\r\x12\x10\n\x08position\x18\x0b \x01(\r\"\xef\x08\n\x07Mention\x12\x11\n\tmentionID\x18\x01 \x01(\x05\x12\x13\n\x0bmentionType\x18\x02 \x01(\t\x12\x0e\n\x06number\x18\x03 \x01(\t\x12\x0e\n\x06gender\x18\x04 \x01(\t\x12\x0f\n\x07\x61nimacy\x18\x05 \x01(\t\x12\x0e\n\x06person\x18\x06 \x01(\t\x12\x12\n\nstartIndex\x18\x07 \x01(\r\x12\x10\n\x08\x65ndIndex\x18\t \x01(\r\x12\x11\n\theadIndex\x18\n \x01(\x05\x12\x12\n\nheadString\x18\x0b \x01(\t\x12\x11\n\tnerString\x18\x0c \x01(\t\x12\x13\n\x0boriginalRef\x18\r \x01(\x05\x12\x1a\n\x12goldCorefClusterID\x18\x0e \x01(\x05\x12\x16\n\x0e\x63orefClusterID\x18\x0f \x01(\x05\x12\x12\n\nmentionNum\x18\x10 \x01(\x05\x12\x0f\n\x07sentNum\x18\x11 \x01(\x05\x12\r\n\x05utter\x18\x12 \x01(\x05\x12\x11\n\tparagraph\x18\x13 \x01(\x05\x12\x11\n\tisSubject\x18\x14 \x01(\x08\x12\x16\n\x0eisDirectObject\x18\x15 \x01(\x08\x12\x18\n\x10isIndirectObject\x18\x16 \x01(\x08\x12\x1b\n\x13isPrepositionObject\x18\x17 \x01(\x08\x12\x0f\n\x07hasTwin\x18\x18 \x01(\x08\x12\x0f\n\x07generic\x18\x19 \x01(\x08\x12\x13\n\x0bisSingleton\x18\x1a \x01(\x08\x12\x1a\n\x12hasBasicDependency\x18\x1b \x01(\x08\x12\x1d\n\x15hasEnhancedDependency\x18\x1c \x01(\x08\x12\x1b\n\x13hasContextParseTree\x18\x1d \x01(\x08\x12?\n\x0fheadIndexedWord\x18\x1e \x01(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12=\n\rdependingVerb\x18\x1f \x01(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12\x38\n\x08headWord\x18 \x01(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12;\n\x0bspeakerInfo\x18! 
\x01(\x0b\x32&.edu.stanford.nlp.pipeline.SpeakerInfo\x12=\n\rsentenceWords\x18\x32 \x03(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12<\n\x0coriginalSpan\x18\x33 \x03(\x0b\x32&.edu.stanford.nlp.pipeline.IndexedWord\x12\x12\n\ndependents\x18\x34 \x03(\t\x12\x19\n\x11preprocessedTerms\x18\x35 \x03(\t\x12\x13\n\x0b\x61ppositions\x18\x36 \x03(\x05\x12\x1c\n\x14predicateNominatives\x18\x37 \x03(\x05\x12\x18\n\x10relativePronouns\x18\x38 \x03(\x05\x12\x13\n\x0blistMembers\x18\x39 \x03(\x05\x12\x15\n\rbelongToLists\x18: \x03(\x05\"X\n\x0bIndexedWord\x12\x13\n\x0bsentenceNum\x18\x01 \x01(\x05\x12\x12\n\ntokenIndex\x18\x02 \x01(\x05\x12\r\n\x05\x64ocID\x18\x03 \x01(\x05\x12\x11\n\tcopyCount\x18\x04 \x01(\r\"4\n\x0bSpeakerInfo\x12\x13\n\x0bspeakerName\x18\x01 \x01(\t\x12\x10\n\x08mentions\x18\x02 \x03(\x05\"\"\n\x04Span\x12\r\n\x05\x62\x65gin\x18\x01 \x02(\r\x12\x0b\n\x03\x65nd\x18\x02 \x02(\r\"w\n\x05Timex\x12\r\n\x05value\x18\x01 \x01(\t\x12\x10\n\x08\x61ltValue\x18\x02 \x01(\t\x12\x0c\n\x04text\x18\x03 \x01(\t\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x0b\n\x03tid\x18\x05 \x01(\t\x12\x12\n\nbeginPoint\x18\x06 \x01(\r\x12\x10\n\x08\x65ndPoint\x18\x07 \x01(\r\"\xdb\x01\n\x06\x45ntity\x12\x11\n\theadStart\x18\x06 \x01(\r\x12\x0f\n\x07headEnd\x18\x07 \x01(\r\x12\x13\n\x0bmentionType\x18\x08 \x01(\t\x12\x16\n\x0enormalizedName\x18\t \x01(\t\x12\x16\n\x0eheadTokenIndex\x18\n \x01(\r\x12\x0f\n\x07\x63orefID\x18\x0b \x01(\t\x12\x10\n\x08objectID\x18\x01 \x01(\t\x12\x13\n\x0b\x65xtentStart\x18\x02 \x01(\r\x12\x11\n\textentEnd\x18\x03 \x01(\r\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x0f\n\x07subtype\x18\x05 \x01(\t\"\xb7\x01\n\x08Relation\x12\x0f\n\x07\x61rgName\x18\x06 \x03(\t\x12.\n\x03\x61rg\x18\x07 \x03(\x0b\x32!.edu.stanford.nlp.pipeline.Entity\x12\x11\n\tsignature\x18\x08 \x01(\t\x12\x10\n\x08objectID\x18\x01 \x01(\t\x12\x13\n\x0b\x65xtentStart\x18\x02 \x01(\r\x12\x11\n\textentEnd\x18\x03 \x01(\r\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x0f\n\x07subtype\x18\x05 \x01(\t\"\xb2\x01\n\x08Operator\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x1b\n\x13quantifierSpanBegin\x18\x02 \x02(\x05\x12\x19\n\x11quantifierSpanEnd\x18\x03 \x02(\x05\x12\x18\n\x10subjectSpanBegin\x18\x04 \x02(\x05\x12\x16\n\x0esubjectSpanEnd\x18\x05 \x02(\x05\x12\x17\n\x0fobjectSpanBegin\x18\x06 \x02(\x05\x12\x15\n\robjectSpanEnd\x18\x07 \x02(\x05\"\xa9\x04\n\x08Polarity\x12K\n\x12projectEquivalence\x18\x01 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12Q\n\x18projectForwardEntailment\x18\x02 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12Q\n\x18projectReverseEntailment\x18\x03 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12H\n\x0fprojectNegation\x18\x04 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12K\n\x12projectAlternation\x18\x05 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12\x45\n\x0cprojectCover\x18\x06 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\x12L\n\x13projectIndependence\x18\x07 \x02(\x0e\x32/.edu.stanford.nlp.pipeline.NaturalLogicRelation\"\xdd\x02\n\nNERMention\x12\x15\n\rsentenceIndex\x18\x01 \x01(\r\x12%\n\x1dtokenStartInSentenceInclusive\x18\x02 \x02(\r\x12#\n\x1btokenEndInSentenceExclusive\x18\x03 \x02(\r\x12\x0b\n\x03ner\x18\x04 \x02(\t\x12\x15\n\rnormalizedNER\x18\x05 \x01(\t\x12\x12\n\nentityType\x18\x06 \x01(\t\x12/\n\x05timex\x18\x07 \x01(\x0b\x32 .edu.stanford.nlp.pipeline.Timex\x12\x17\n\x0fwikipediaEntity\x18\x08 \x01(\t\x12\x0e\n\x06gender\x18\t \x01(\t\x12\x1a\n\x12\x65ntityMentionIndex\x18\n 
\x01(\r\x12#\n\x1b\x63\x61nonicalEntityMentionIndex\x18\x0b \x01(\r\x12\x19\n\x11\x65ntityMentionText\x18\x0c \x01(\t\"Y\n\x10SentenceFragment\x12\x12\n\ntokenIndex\x18\x01 \x03(\r\x12\x0c\n\x04root\x18\x02 \x01(\r\x12\x14\n\x0c\x61ssumedTruth\x18\x03 \x01(\x08\x12\r\n\x05score\x18\x04 \x01(\x01\":\n\rTokenLocation\x12\x15\n\rsentenceIndex\x18\x01 \x01(\r\x12\x12\n\ntokenIndex\x18\x02 \x01(\r\"\x9a\x03\n\x0eRelationTriple\x12\x0f\n\x07subject\x18\x01 \x01(\t\x12\x10\n\x08relation\x18\x02 \x01(\t\x12\x0e\n\x06object\x18\x03 \x01(\t\x12\x12\n\nconfidence\x18\x04 \x01(\x01\x12?\n\rsubjectTokens\x18\r \x03(\x0b\x32(.edu.stanford.nlp.pipeline.TokenLocation\x12@\n\x0erelationTokens\x18\x0e \x03(\x0b\x32(.edu.stanford.nlp.pipeline.TokenLocation\x12>\n\x0cobjectTokens\x18\x0f \x03(\x0b\x32(.edu.stanford.nlp.pipeline.TokenLocation\x12\x38\n\x04tree\x18\x08 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12\x0e\n\x06istmod\x18\t \x01(\x08\x12\x10\n\x08prefixBe\x18\n \x01(\x08\x12\x10\n\x08suffixBe\x18\x0b \x01(\x08\x12\x10\n\x08suffixOf\x18\x0c \x01(\x08\"-\n\x0fMapStringString\x12\x0b\n\x03key\x18\x01 \x03(\t\x12\r\n\x05value\x18\x02 \x03(\t\"*\n\x0cMapIntString\x12\x0b\n\x03key\x18\x01 \x03(\r\x12\r\n\x05value\x18\x02 \x03(\t\"\xfc\x01\n\x07Section\x12\x11\n\tcharBegin\x18\x01 \x02(\r\x12\x0f\n\x07\x63harEnd\x18\x02 \x02(\r\x12\x0e\n\x06\x61uthor\x18\x03 \x01(\t\x12\x17\n\x0fsentenceIndexes\x18\x04 \x03(\r\x12\x10\n\x08\x64\x61tetime\x18\x05 \x01(\t\x12\x30\n\x06quotes\x18\x06 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Quote\x12\x17\n\x0f\x61uthorCharBegin\x18\x07 \x01(\r\x12\x15\n\rauthorCharEnd\x18\x08 \x01(\r\x12\x30\n\x06xmlTag\x18\t \x02(\x0b\x32 .edu.stanford.nlp.pipeline.Token\"\xe4\x01\n\x0eSemgrexRequest\x12\x0f\n\x07semgrex\x18\x01 \x03(\t\x12\x45\n\x05query\x18\x02 \x03(\x0b\x32\x36.edu.stanford.nlp.pipeline.SemgrexRequest.Dependencies\x1az\n\x0c\x44\x65pendencies\x12/\n\x05token\x18\x01 \x03(\x0b\x32 .edu.stanford.nlp.pipeline.Token\x12\x39\n\x05graph\x18\x02 \x02(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\"\xfb\x05\n\x0fSemgrexResponse\x12\x46\n\x06result\x18\x01 \x03(\x0b\x32\x36.edu.stanford.nlp.pipeline.SemgrexResponse.GraphResult\x1a-\n\tNamedNode\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x12\n\nmatchIndex\x18\x02 \x02(\x05\x1a+\n\rNamedRelation\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x0c\n\x04reln\x18\x02 \x02(\t\x1a\x80\x01\n\tNamedEdge\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x0e\n\x06source\x18\x02 \x02(\x05\x12\x0e\n\x06target\x18\x03 \x02(\x05\x12\x0c\n\x04reln\x18\x04 \x01(\t\x12\x0f\n\x07isExtra\x18\x05 \x01(\x08\x12\x12\n\nsourceCopy\x18\x06 \x01(\r\x12\x12\n\ntargetCopy\x18\x07 \x01(\r\x1a\x95\x02\n\x05Match\x12\x12\n\nmatchIndex\x18\x01 \x02(\x05\x12\x42\n\x04node\x18\x02 \x03(\x0b\x32\x34.edu.stanford.nlp.pipeline.SemgrexResponse.NamedNode\x12\x46\n\x04reln\x18\x03 \x03(\x0b\x32\x38.edu.stanford.nlp.pipeline.SemgrexResponse.NamedRelation\x12\x42\n\x04\x65\x64ge\x18\x06 \x03(\x0b\x32\x34.edu.stanford.nlp.pipeline.SemgrexResponse.NamedEdge\x12\x12\n\ngraphIndex\x18\x04 \x01(\x05\x12\x14\n\x0csemgrexIndex\x18\x05 \x01(\x05\x1aP\n\rSemgrexResult\x12?\n\x05match\x18\x01 \x03(\x0b\x32\x30.edu.stanford.nlp.pipeline.SemgrexResponse.Match\x1aW\n\x0bGraphResult\x12H\n\x06result\x18\x01 \x03(\x0b\x32\x38.edu.stanford.nlp.pipeline.SemgrexResponse.SemgrexResult\"\xf0\x01\n\x0fSsurgeonRequest\x12\x45\n\x08ssurgeon\x18\x01 \x03(\x0b\x32\x33.edu.stanford.nlp.pipeline.SsurgeonRequest.Ssurgeon\x12\x39\n\x05graph\x18\x02 
\x03(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x1a[\n\x08Ssurgeon\x12\x0f\n\x07semgrex\x18\x01 \x01(\t\x12\x11\n\toperation\x18\x02 \x03(\t\x12\n\n\x02id\x18\x03 \x01(\t\x12\r\n\x05notes\x18\x04 \x01(\t\x12\x10\n\x08language\x18\x05 \x01(\t\"\xbc\x01\n\x10SsurgeonResponse\x12J\n\x06result\x18\x01 \x03(\x0b\x32:.edu.stanford.nlp.pipeline.SsurgeonResponse.SsurgeonResult\x1a\\\n\x0eSsurgeonResult\x12\x39\n\x05graph\x18\x01 \x01(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12\x0f\n\x07\x63hanged\x18\x02 \x01(\x08\"W\n\x12TokensRegexRequest\x12\x30\n\x03\x64oc\x18\x01 \x02(\x0b\x32#.edu.stanford.nlp.pipeline.Document\x12\x0f\n\x07pattern\x18\x02 \x03(\t\"\xa7\x03\n\x13TokensRegexResponse\x12J\n\x05match\x18\x01 \x03(\x0b\x32;.edu.stanford.nlp.pipeline.TokensRegexResponse.PatternMatch\x1a\x39\n\rMatchLocation\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05\x62\x65gin\x18\x02 \x01(\x05\x12\x0b\n\x03\x65nd\x18\x03 \x01(\x05\x1a\xb3\x01\n\x05Match\x12\x10\n\x08sentence\x18\x01 \x02(\x05\x12K\n\x05match\x18\x02 \x02(\x0b\x32<.edu.stanford.nlp.pipeline.TokensRegexResponse.MatchLocation\x12K\n\x05group\x18\x03 \x03(\x0b\x32<.edu.stanford.nlp.pipeline.TokensRegexResponse.MatchLocation\x1aS\n\x0cPatternMatch\x12\x43\n\x05match\x18\x01 \x03(\x0b\x32\x34.edu.stanford.nlp.pipeline.TokensRegexResponse.Match\"\xae\x01\n\x19\x44\x65pendencyEnhancerRequest\x12\x35\n\x08\x64ocument\x18\x01 \x02(\x0b\x32#.edu.stanford.nlp.pipeline.Document\x12\x37\n\x08language\x18\x02 \x01(\x0e\x32#.edu.stanford.nlp.pipeline.LanguageH\x00\x12\x1a\n\x10relativePronouns\x18\x03 \x01(\tH\x00\x42\x05\n\x03ref\"\xb4\x01\n\x12\x46lattenedParseTree\x12\x41\n\x05nodes\x18\x01 \x03(\x0b\x32\x32.edu.stanford.nlp.pipeline.FlattenedParseTree.Node\x1a[\n\x04Node\x12\x12\n\x08openNode\x18\x01 \x01(\x08H\x00\x12\x13\n\tcloseNode\x18\x02 \x01(\x08H\x00\x12\x0f\n\x05value\x18\x03 \x01(\tH\x00\x12\r\n\x05score\x18\x04 \x01(\x01\x42\n\n\x08\x63ontents\"\xf6\x01\n\x15\x45valuateParserRequest\x12N\n\x08treebank\x18\x01 \x03(\x0b\x32<.edu.stanford.nlp.pipeline.EvaluateParserRequest.ParseResult\x1a\x8c\x01\n\x0bParseResult\x12;\n\x04gold\x18\x01 \x02(\x0b\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\x12@\n\tpredicted\x18\x02 \x03(\x0b\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\"E\n\x16\x45valuateParserResponse\x12\n\n\x02\x66\x31\x18\x01 \x02(\x01\x12\x0f\n\x07kbestF1\x18\x02 \x01(\x01\x12\x0e\n\x06treeF1\x18\x03 \x03(\x01\"\xc8\x01\n\x0fTsurgeonRequest\x12H\n\noperations\x18\x01 \x03(\x0b\x32\x34.edu.stanford.nlp.pipeline.TsurgeonRequest.Operation\x12<\n\x05trees\x18\x02 \x03(\x0b\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\x1a-\n\tOperation\x12\x0e\n\x06tregex\x18\x01 \x02(\t\x12\x10\n\x08tsurgeon\x18\x02 \x03(\t\"P\n\x10TsurgeonResponse\x12<\n\x05trees\x18\x01 \x03(\x0b\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\"\x85\x01\n\x11MorphologyRequest\x12\x46\n\x05words\x18\x01 \x03(\x0b\x32\x37.edu.stanford.nlp.pipeline.MorphologyRequest.TaggedWord\x1a(\n\nTaggedWord\x12\x0c\n\x04word\x18\x01 \x02(\t\x12\x0c\n\x04xpos\x18\x02 \x01(\t\"\x9a\x01\n\x12MorphologyResponse\x12I\n\x05words\x18\x01 \x03(\x0b\x32:.edu.stanford.nlp.pipeline.MorphologyResponse.WordTagLemma\x1a\x39\n\x0cWordTagLemma\x12\x0c\n\x04word\x18\x01 \x02(\t\x12\x0c\n\x04xpos\x18\x02 \x01(\t\x12\r\n\x05lemma\x18\x03 \x02(\t\"Z\n\x1a\x44\x65pendencyConverterRequest\x12<\n\x05trees\x18\x01 \x03(\x0b\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree\"\x90\x02\n\x1b\x44\x65pendencyConverterResponse\x12`\n\x0b\x63onversions\x18\x01 
\x03(\x0b\x32K.edu.stanford.nlp.pipeline.DependencyConverterResponse.DependencyConversion\x1a\x8e\x01\n\x14\x44\x65pendencyConversion\x12\x39\n\x05graph\x18\x01 \x02(\x0b\x32*.edu.stanford.nlp.pipeline.DependencyGraph\x12;\n\x04tree\x18\x02 \x01(\x0b\x32-.edu.stanford.nlp.pipeline.FlattenedParseTree*\xa3\x01\n\x08Language\x12\x0b\n\x07Unknown\x10\x00\x12\x07\n\x03\x41ny\x10\x01\x12\n\n\x06\x41rabic\x10\x02\x12\x0b\n\x07\x43hinese\x10\x03\x12\x0b\n\x07\x45nglish\x10\x04\x12\n\n\x06German\x10\x05\x12\n\n\x06\x46rench\x10\x06\x12\n\n\x06Hebrew\x10\x07\x12\x0b\n\x07Spanish\x10\x08\x12\x14\n\x10UniversalEnglish\x10\t\x12\x14\n\x10UniversalChinese\x10\n*h\n\tSentiment\x12\x13\n\x0fSTRONG_NEGATIVE\x10\x00\x12\x11\n\rWEAK_NEGATIVE\x10\x01\x12\x0b\n\x07NEUTRAL\x10\x02\x12\x11\n\rWEAK_POSITIVE\x10\x03\x12\x13\n\x0fSTRONG_POSITIVE\x10\x04*\x93\x01\n\x14NaturalLogicRelation\x12\x0f\n\x0b\x45QUIVALENCE\x10\x00\x12\x16\n\x12\x46ORWARD_ENTAILMENT\x10\x01\x12\x16\n\x12REVERSE_ENTAILMENT\x10\x02\x12\x0c\n\x08NEGATION\x10\x03\x12\x0f\n\x0b\x41LTERNATION\x10\x04\x12\t\n\x05\x43OVER\x10\x05\x12\x10\n\x0cINDEPENDENCE\x10\x06\x42*\n\x19\x65\x64u.stanford.nlp.pipelineB\rCoreNLPProtos')

_LANGUAGE = DESCRIPTOR.enum_types_by_name['Language']
Language = enum_type_wrapper.EnumTypeWrapper(_LANGUAGE)
_SENTIMENT = DESCRIPTOR.enum_types_by_name['Sentiment']
Sentiment = enum_type_wrapper.EnumTypeWrapper(_SENTIMENT)
_NATURALLOGICRELATION = DESCRIPTOR.enum_types_by_name['NaturalLogicRelation']
NaturalLogicRelation = enum_type_wrapper.EnumTypeWrapper(_NATURALLOGICRELATION)
Unknown = 0
Any = 1
Arabic = 2
Chinese = 3
English = 4
German = 5
French = 6
Hebrew = 7
Spanish = 8
UniversalEnglish = 9
UniversalChinese = 10
STRONG_NEGATIVE = 0
WEAK_NEGATIVE = 1
NEUTRAL = 2
WEAK_POSITIVE = 3
STRONG_POSITIVE = 4
EQUIVALENCE = 0
FORWARD_ENTAILMENT = 1
REVERSE_ENTAILMENT = 2
NEGATION = 3
ALTERNATION = 4
COVER = 5
INDEPENDENCE = 6


_DOCUMENT = DESCRIPTOR.message_types_by_name['Document']
_SENTENCE = DESCRIPTOR.message_types_by_name['Sentence']
_TOKEN = DESCRIPTOR.message_types_by_name['Token']
_QUOTE = DESCRIPTOR.message_types_by_name['Quote']
_PARSETREE = DESCRIPTOR.message_types_by_name['ParseTree']
_DEPENDENCYGRAPH = DESCRIPTOR.message_types_by_name['DependencyGraph']
_DEPENDENCYGRAPH_NODE = _DEPENDENCYGRAPH.nested_types_by_name['Node']
_DEPENDENCYGRAPH_EDGE = _DEPENDENCYGRAPH.nested_types_by_name['Edge']
_COREFCHAIN = DESCRIPTOR.message_types_by_name['CorefChain']
_COREFCHAIN_COREFMENTION = _COREFCHAIN.nested_types_by_name['CorefMention']
_MENTION = DESCRIPTOR.message_types_by_name['Mention']
_INDEXEDWORD = DESCRIPTOR.message_types_by_name['IndexedWord']
_SPEAKERINFO = DESCRIPTOR.message_types_by_name['SpeakerInfo']
_SPAN = DESCRIPTOR.message_types_by_name['Span']
_TIMEX = DESCRIPTOR.message_types_by_name['Timex']
_ENTITY = DESCRIPTOR.message_types_by_name['Entity']
_RELATION = DESCRIPTOR.message_types_by_name['Relation']
_OPERATOR = DESCRIPTOR.message_types_by_name['Operator']
_POLARITY = DESCRIPTOR.message_types_by_name['Polarity']
_NERMENTION = DESCRIPTOR.message_types_by_name['NERMention']
_SENTENCEFRAGMENT = DESCRIPTOR.message_types_by_name['SentenceFragment']
_TOKENLOCATION = DESCRIPTOR.message_types_by_name['TokenLocation']
_RELATIONTRIPLE = DESCRIPTOR.message_types_by_name['RelationTriple']
_MAPSTRINGSTRING = DESCRIPTOR.message_types_by_name['MapStringString']
_MAPINTSTRING = DESCRIPTOR.message_types_by_name['MapIntString']
_SECTION = DESCRIPTOR.message_types_by_name['Section']
_SEMGREXREQUEST = DESCRIPTOR.message_types_by_name['SemgrexRequest']
_SEMGREXREQUEST_DEPENDENCIES = _SEMGREXREQUEST.nested_types_by_name['Dependencies']
_SEMGREXRESPONSE = DESCRIPTOR.message_types_by_name['SemgrexResponse']
_SEMGREXRESPONSE_NAMEDNODE = _SEMGREXRESPONSE.nested_types_by_name['NamedNode']
_SEMGREXRESPONSE_NAMEDRELATION = _SEMGREXRESPONSE.nested_types_by_name['NamedRelation']
_SEMGREXRESPONSE_NAMEDEDGE = _SEMGREXRESPONSE.nested_types_by_name['NamedEdge']
_SEMGREXRESPONSE_MATCH = _SEMGREXRESPONSE.nested_types_by_name['Match']
_SEMGREXRESPONSE_SEMGREXRESULT = _SEMGREXRESPONSE.nested_types_by_name['SemgrexResult']
_SEMGREXRESPONSE_GRAPHRESULT = _SEMGREXRESPONSE.nested_types_by_name['GraphResult']
_SSURGEONREQUEST = DESCRIPTOR.message_types_by_name['SsurgeonRequest']
_SSURGEONREQUEST_SSURGEON = _SSURGEONREQUEST.nested_types_by_name['Ssurgeon']
_SSURGEONRESPONSE = DESCRIPTOR.message_types_by_name['SsurgeonResponse']
_SSURGEONRESPONSE_SSURGEONRESULT = _SSURGEONRESPONSE.nested_types_by_name['SsurgeonResult']
_TOKENSREGEXREQUEST = DESCRIPTOR.message_types_by_name['TokensRegexRequest']
_TOKENSREGEXRESPONSE = DESCRIPTOR.message_types_by_name['TokensRegexResponse']
_TOKENSREGEXRESPONSE_MATCHLOCATION = _TOKENSREGEXRESPONSE.nested_types_by_name['MatchLocation']
_TOKENSREGEXRESPONSE_MATCH = _TOKENSREGEXRESPONSE.nested_types_by_name['Match']
_TOKENSREGEXRESPONSE_PATTERNMATCH = _TOKENSREGEXRESPONSE.nested_types_by_name['PatternMatch']
_DEPENDENCYENHANCERREQUEST = DESCRIPTOR.message_types_by_name['DependencyEnhancerRequest']
_FLATTENEDPARSETREE = DESCRIPTOR.message_types_by_name['FlattenedParseTree']
_FLATTENEDPARSETREE_NODE = _FLATTENEDPARSETREE.nested_types_by_name['Node']
_EVALUATEPARSERREQUEST = DESCRIPTOR.message_types_by_name['EvaluateParserRequest']
_EVALUATEPARSERREQUEST_PARSERESULT = _EVALUATEPARSERREQUEST.nested_types_by_name['ParseResult']
_EVALUATEPARSERRESPONSE = DESCRIPTOR.message_types_by_name['EvaluateParserResponse']
+
_EVALUATEPARSERRESPONSE = DESCRIPTOR.message_types_by_name['EvaluateParserResponse']
|
| 101 |
+
_TSURGEONREQUEST = DESCRIPTOR.message_types_by_name['TsurgeonRequest']
|
| 102 |
+
_TSURGEONREQUEST_OPERATION = _TSURGEONREQUEST.nested_types_by_name['Operation']
|
| 103 |
+
_TSURGEONRESPONSE = DESCRIPTOR.message_types_by_name['TsurgeonResponse']
|
| 104 |
+
_MORPHOLOGYREQUEST = DESCRIPTOR.message_types_by_name['MorphologyRequest']
|
| 105 |
+
_MORPHOLOGYREQUEST_TAGGEDWORD = _MORPHOLOGYREQUEST.nested_types_by_name['TaggedWord']
|
| 106 |
+
_MORPHOLOGYRESPONSE = DESCRIPTOR.message_types_by_name['MorphologyResponse']
|
| 107 |
+
_MORPHOLOGYRESPONSE_WORDTAGLEMMA = _MORPHOLOGYRESPONSE.nested_types_by_name['WordTagLemma']
|
| 108 |
+
_DEPENDENCYCONVERTERREQUEST = DESCRIPTOR.message_types_by_name['DependencyConverterRequest']
|
| 109 |
+
_DEPENDENCYCONVERTERRESPONSE = DESCRIPTOR.message_types_by_name['DependencyConverterResponse']
|
| 110 |
+
_DEPENDENCYCONVERTERRESPONSE_DEPENDENCYCONVERSION = _DEPENDENCYCONVERTERRESPONSE.nested_types_by_name['DependencyConversion']
|
| 111 |
+
Document = _reflection.GeneratedProtocolMessageType('Document', (_message.Message,), {
|
| 112 |
+
'DESCRIPTOR' : _DOCUMENT,
|
| 113 |
+
'__module__' : 'CoreNLP_pb2'
|
| 114 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.Document)
|
| 115 |
+
})
|
| 116 |
+
_sym_db.RegisterMessage(Document)
|
| 117 |
+
|
| 118 |
+
Sentence = _reflection.GeneratedProtocolMessageType('Sentence', (_message.Message,), {
|
| 119 |
+
'DESCRIPTOR' : _SENTENCE,
|
| 120 |
+
'__module__' : 'CoreNLP_pb2'
|
| 121 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.Sentence)
|
| 122 |
+
})
|
| 123 |
+
_sym_db.RegisterMessage(Sentence)
|
| 124 |
+
|
| 125 |
+
Token = _reflection.GeneratedProtocolMessageType('Token', (_message.Message,), {
|
| 126 |
+
'DESCRIPTOR' : _TOKEN,
|
| 127 |
+
'__module__' : 'CoreNLP_pb2'
|
| 128 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.Token)
|
| 129 |
+
})
|
| 130 |
+
_sym_db.RegisterMessage(Token)
|
| 131 |
+
|
| 132 |
+
Quote = _reflection.GeneratedProtocolMessageType('Quote', (_message.Message,), {
|
| 133 |
+
'DESCRIPTOR' : _QUOTE,
|
| 134 |
+
'__module__' : 'CoreNLP_pb2'
|
| 135 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.Quote)
|
| 136 |
+
})
|
| 137 |
+
_sym_db.RegisterMessage(Quote)
|
| 138 |
+
|
| 139 |
+
ParseTree = _reflection.GeneratedProtocolMessageType('ParseTree', (_message.Message,), {
|
| 140 |
+
'DESCRIPTOR' : _PARSETREE,
|
| 141 |
+
'__module__' : 'CoreNLP_pb2'
|
| 142 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.ParseTree)
|
| 143 |
+
})
|
| 144 |
+
_sym_db.RegisterMessage(ParseTree)
|
| 145 |
+
|
| 146 |
+
DependencyGraph = _reflection.GeneratedProtocolMessageType('DependencyGraph', (_message.Message,), {
|
| 147 |
+
|
| 148 |
+
'Node' : _reflection.GeneratedProtocolMessageType('Node', (_message.Message,), {
|
| 149 |
+
'DESCRIPTOR' : _DEPENDENCYGRAPH_NODE,
|
| 150 |
+
'__module__' : 'CoreNLP_pb2'
|
| 151 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.DependencyGraph.Node)
|
| 152 |
+
})
|
| 153 |
+
,
|
| 154 |
+
|
| 155 |
+
'Edge' : _reflection.GeneratedProtocolMessageType('Edge', (_message.Message,), {
|
| 156 |
+
'DESCRIPTOR' : _DEPENDENCYGRAPH_EDGE,
|
| 157 |
+
'__module__' : 'CoreNLP_pb2'
|
| 158 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.DependencyGraph.Edge)
|
| 159 |
+
})
|
| 160 |
+
,
|
| 161 |
+
'DESCRIPTOR' : _DEPENDENCYGRAPH,
|
| 162 |
+
'__module__' : 'CoreNLP_pb2'
|
| 163 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.DependencyGraph)
|
| 164 |
+
})
|
| 165 |
+
_sym_db.RegisterMessage(DependencyGraph)
|
| 166 |
+
_sym_db.RegisterMessage(DependencyGraph.Node)
|
| 167 |
+
_sym_db.RegisterMessage(DependencyGraph.Edge)
|
| 168 |
+
|
| 169 |
+
CorefChain = _reflection.GeneratedProtocolMessageType('CorefChain', (_message.Message,), {
|
| 170 |
+
|
| 171 |
+
'CorefMention' : _reflection.GeneratedProtocolMessageType('CorefMention', (_message.Message,), {
|
| 172 |
+
'DESCRIPTOR' : _COREFCHAIN_COREFMENTION,
|
| 173 |
+
'__module__' : 'CoreNLP_pb2'
|
| 174 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.CorefChain.CorefMention)
|
| 175 |
+
})
|
| 176 |
+
,
|
| 177 |
+
'DESCRIPTOR' : _COREFCHAIN,
|
| 178 |
+
'__module__' : 'CoreNLP_pb2'
|
| 179 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.CorefChain)
|
| 180 |
+
})
|
| 181 |
+
_sym_db.RegisterMessage(CorefChain)
|
| 182 |
+
_sym_db.RegisterMessage(CorefChain.CorefMention)
|
| 183 |
+
|
| 184 |
+
Mention = _reflection.GeneratedProtocolMessageType('Mention', (_message.Message,), {
|
| 185 |
+
'DESCRIPTOR' : _MENTION,
|
| 186 |
+
'__module__' : 'CoreNLP_pb2'
|
| 187 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.Mention)
|
| 188 |
+
})
|
| 189 |
+
_sym_db.RegisterMessage(Mention)
|
| 190 |
+
|
| 191 |
+
IndexedWord = _reflection.GeneratedProtocolMessageType('IndexedWord', (_message.Message,), {
|
| 192 |
+
'DESCRIPTOR' : _INDEXEDWORD,
|
| 193 |
+
'__module__' : 'CoreNLP_pb2'
|
| 194 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.IndexedWord)
|
| 195 |
+
})
|
| 196 |
+
_sym_db.RegisterMessage(IndexedWord)
|
| 197 |
+
|
| 198 |
+
SpeakerInfo = _reflection.GeneratedProtocolMessageType('SpeakerInfo', (_message.Message,), {
|
| 199 |
+
'DESCRIPTOR' : _SPEAKERINFO,
|
| 200 |
+
'__module__' : 'CoreNLP_pb2'
|
| 201 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SpeakerInfo)
|
| 202 |
+
})
|
| 203 |
+
_sym_db.RegisterMessage(SpeakerInfo)
|
| 204 |
+
|
| 205 |
+
Span = _reflection.GeneratedProtocolMessageType('Span', (_message.Message,), {
|
| 206 |
+
'DESCRIPTOR' : _SPAN,
|
| 207 |
+
'__module__' : 'CoreNLP_pb2'
|
| 208 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.Span)
|
| 209 |
+
})
|
| 210 |
+
_sym_db.RegisterMessage(Span)
|
| 211 |
+
|
| 212 |
+
Timex = _reflection.GeneratedProtocolMessageType('Timex', (_message.Message,), {
|
| 213 |
+
'DESCRIPTOR' : _TIMEX,
|
| 214 |
+
'__module__' : 'CoreNLP_pb2'
|
| 215 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.Timex)
|
| 216 |
+
})
|
| 217 |
+
_sym_db.RegisterMessage(Timex)
|
| 218 |
+
|
| 219 |
+
Entity = _reflection.GeneratedProtocolMessageType('Entity', (_message.Message,), {
|
| 220 |
+
'DESCRIPTOR' : _ENTITY,
|
| 221 |
+
'__module__' : 'CoreNLP_pb2'
|
| 222 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.Entity)
|
| 223 |
+
})
|
| 224 |
+
_sym_db.RegisterMessage(Entity)
|
| 225 |
+
|
| 226 |
+
Relation = _reflection.GeneratedProtocolMessageType('Relation', (_message.Message,), {
|
| 227 |
+
'DESCRIPTOR' : _RELATION,
|
| 228 |
+
'__module__' : 'CoreNLP_pb2'
|
| 229 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.Relation)
|
| 230 |
+
})
|
| 231 |
+
_sym_db.RegisterMessage(Relation)
|
| 232 |
+
|
| 233 |
+
Operator = _reflection.GeneratedProtocolMessageType('Operator', (_message.Message,), {
|
| 234 |
+
'DESCRIPTOR' : _OPERATOR,
|
| 235 |
+
'__module__' : 'CoreNLP_pb2'
|
| 236 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.Operator)
|
| 237 |
+
})
|
| 238 |
+
_sym_db.RegisterMessage(Operator)
|
| 239 |
+
|
| 240 |
+
Polarity = _reflection.GeneratedProtocolMessageType('Polarity', (_message.Message,), {
|
| 241 |
+
'DESCRIPTOR' : _POLARITY,
|
| 242 |
+
'__module__' : 'CoreNLP_pb2'
|
| 243 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.Polarity)
|
| 244 |
+
})
|
| 245 |
+
_sym_db.RegisterMessage(Polarity)
|
| 246 |
+
|
| 247 |
+
NERMention = _reflection.GeneratedProtocolMessageType('NERMention', (_message.Message,), {
|
| 248 |
+
'DESCRIPTOR' : _NERMENTION,
|
| 249 |
+
'__module__' : 'CoreNLP_pb2'
|
| 250 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.NERMention)
|
| 251 |
+
})
|
| 252 |
+
_sym_db.RegisterMessage(NERMention)
|
| 253 |
+
|
| 254 |
+
SentenceFragment = _reflection.GeneratedProtocolMessageType('SentenceFragment', (_message.Message,), {
|
| 255 |
+
'DESCRIPTOR' : _SENTENCEFRAGMENT,
|
| 256 |
+
'__module__' : 'CoreNLP_pb2'
|
| 257 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SentenceFragment)
|
| 258 |
+
})
|
| 259 |
+
_sym_db.RegisterMessage(SentenceFragment)
|
| 260 |
+
|
| 261 |
+
TokenLocation = _reflection.GeneratedProtocolMessageType('TokenLocation', (_message.Message,), {
|
| 262 |
+
'DESCRIPTOR' : _TOKENLOCATION,
|
| 263 |
+
'__module__' : 'CoreNLP_pb2'
|
| 264 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.TokenLocation)
|
| 265 |
+
})
|
| 266 |
+
_sym_db.RegisterMessage(TokenLocation)
|
| 267 |
+
|
| 268 |
+
RelationTriple = _reflection.GeneratedProtocolMessageType('RelationTriple', (_message.Message,), {
|
| 269 |
+
'DESCRIPTOR' : _RELATIONTRIPLE,
|
| 270 |
+
'__module__' : 'CoreNLP_pb2'
|
| 271 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.RelationTriple)
|
| 272 |
+
})
|
| 273 |
+
_sym_db.RegisterMessage(RelationTriple)
|
| 274 |
+
|
| 275 |
+
MapStringString = _reflection.GeneratedProtocolMessageType('MapStringString', (_message.Message,), {
|
| 276 |
+
'DESCRIPTOR' : _MAPSTRINGSTRING,
|
| 277 |
+
'__module__' : 'CoreNLP_pb2'
|
| 278 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.MapStringString)
|
| 279 |
+
})
|
| 280 |
+
_sym_db.RegisterMessage(MapStringString)
|
| 281 |
+
|
| 282 |
+
MapIntString = _reflection.GeneratedProtocolMessageType('MapIntString', (_message.Message,), {
|
| 283 |
+
'DESCRIPTOR' : _MAPINTSTRING,
|
| 284 |
+
'__module__' : 'CoreNLP_pb2'
|
| 285 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.MapIntString)
|
| 286 |
+
})
|
| 287 |
+
_sym_db.RegisterMessage(MapIntString)
|
| 288 |
+
|
| 289 |
+
Section = _reflection.GeneratedProtocolMessageType('Section', (_message.Message,), {
|
| 290 |
+
'DESCRIPTOR' : _SECTION,
|
| 291 |
+
'__module__' : 'CoreNLP_pb2'
|
| 292 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.Section)
|
| 293 |
+
})
|
| 294 |
+
_sym_db.RegisterMessage(Section)
|
| 295 |
+
|
| 296 |
+
SemgrexRequest = _reflection.GeneratedProtocolMessageType('SemgrexRequest', (_message.Message,), {
|
| 297 |
+
|
| 298 |
+
'Dependencies' : _reflection.GeneratedProtocolMessageType('Dependencies', (_message.Message,), {
|
| 299 |
+
'DESCRIPTOR' : _SEMGREXREQUEST_DEPENDENCIES,
|
| 300 |
+
'__module__' : 'CoreNLP_pb2'
|
| 301 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SemgrexRequest.Dependencies)
|
| 302 |
+
})
|
| 303 |
+
,
|
| 304 |
+
'DESCRIPTOR' : _SEMGREXREQUEST,
|
| 305 |
+
'__module__' : 'CoreNLP_pb2'
|
| 306 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SemgrexRequest)
|
| 307 |
+
})
|
| 308 |
+
_sym_db.RegisterMessage(SemgrexRequest)
|
| 309 |
+
_sym_db.RegisterMessage(SemgrexRequest.Dependencies)
|
| 310 |
+
|
| 311 |
+
SemgrexResponse = _reflection.GeneratedProtocolMessageType('SemgrexResponse', (_message.Message,), {
|
| 312 |
+
|
| 313 |
+
'NamedNode' : _reflection.GeneratedProtocolMessageType('NamedNode', (_message.Message,), {
|
| 314 |
+
'DESCRIPTOR' : _SEMGREXRESPONSE_NAMEDNODE,
|
| 315 |
+
'__module__' : 'CoreNLP_pb2'
|
| 316 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SemgrexResponse.NamedNode)
|
| 317 |
+
})
|
| 318 |
+
,
|
| 319 |
+
|
| 320 |
+
'NamedRelation' : _reflection.GeneratedProtocolMessageType('NamedRelation', (_message.Message,), {
|
| 321 |
+
'DESCRIPTOR' : _SEMGREXRESPONSE_NAMEDRELATION,
|
| 322 |
+
'__module__' : 'CoreNLP_pb2'
|
| 323 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SemgrexResponse.NamedRelation)
|
| 324 |
+
})
|
| 325 |
+
,
|
| 326 |
+
|
| 327 |
+
'NamedEdge' : _reflection.GeneratedProtocolMessageType('NamedEdge', (_message.Message,), {
|
| 328 |
+
'DESCRIPTOR' : _SEMGREXRESPONSE_NAMEDEDGE,
|
| 329 |
+
'__module__' : 'CoreNLP_pb2'
|
| 330 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SemgrexResponse.NamedEdge)
|
| 331 |
+
})
|
| 332 |
+
,
|
| 333 |
+
|
| 334 |
+
'Match' : _reflection.GeneratedProtocolMessageType('Match', (_message.Message,), {
|
| 335 |
+
'DESCRIPTOR' : _SEMGREXRESPONSE_MATCH,
|
| 336 |
+
'__module__' : 'CoreNLP_pb2'
|
| 337 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SemgrexResponse.Match)
|
| 338 |
+
})
|
| 339 |
+
,
|
| 340 |
+
|
| 341 |
+
'SemgrexResult' : _reflection.GeneratedProtocolMessageType('SemgrexResult', (_message.Message,), {
|
| 342 |
+
'DESCRIPTOR' : _SEMGREXRESPONSE_SEMGREXRESULT,
|
| 343 |
+
'__module__' : 'CoreNLP_pb2'
|
| 344 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SemgrexResponse.SemgrexResult)
|
| 345 |
+
})
|
| 346 |
+
,
|
| 347 |
+
|
| 348 |
+
'GraphResult' : _reflection.GeneratedProtocolMessageType('GraphResult', (_message.Message,), {
|
| 349 |
+
'DESCRIPTOR' : _SEMGREXRESPONSE_GRAPHRESULT,
|
| 350 |
+
'__module__' : 'CoreNLP_pb2'
|
| 351 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SemgrexResponse.GraphResult)
|
| 352 |
+
})
|
| 353 |
+
,
|
| 354 |
+
'DESCRIPTOR' : _SEMGREXRESPONSE,
|
| 355 |
+
'__module__' : 'CoreNLP_pb2'
|
| 356 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SemgrexResponse)
|
| 357 |
+
})
|
| 358 |
+
_sym_db.RegisterMessage(SemgrexResponse)
|
| 359 |
+
_sym_db.RegisterMessage(SemgrexResponse.NamedNode)
|
| 360 |
+
_sym_db.RegisterMessage(SemgrexResponse.NamedRelation)
|
| 361 |
+
_sym_db.RegisterMessage(SemgrexResponse.NamedEdge)
|
| 362 |
+
_sym_db.RegisterMessage(SemgrexResponse.Match)
|
| 363 |
+
_sym_db.RegisterMessage(SemgrexResponse.SemgrexResult)
|
| 364 |
+
_sym_db.RegisterMessage(SemgrexResponse.GraphResult)
|
| 365 |
+
|
| 366 |
+
SsurgeonRequest = _reflection.GeneratedProtocolMessageType('SsurgeonRequest', (_message.Message,), {
|
| 367 |
+
|
| 368 |
+
'Ssurgeon' : _reflection.GeneratedProtocolMessageType('Ssurgeon', (_message.Message,), {
|
| 369 |
+
'DESCRIPTOR' : _SSURGEONREQUEST_SSURGEON,
|
| 370 |
+
'__module__' : 'CoreNLP_pb2'
|
| 371 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SsurgeonRequest.Ssurgeon)
|
| 372 |
+
})
|
| 373 |
+
,
|
| 374 |
+
'DESCRIPTOR' : _SSURGEONREQUEST,
|
| 375 |
+
'__module__' : 'CoreNLP_pb2'
|
| 376 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SsurgeonRequest)
|
| 377 |
+
})
|
| 378 |
+
_sym_db.RegisterMessage(SsurgeonRequest)
|
| 379 |
+
_sym_db.RegisterMessage(SsurgeonRequest.Ssurgeon)
|
| 380 |
+
|
| 381 |
+
SsurgeonResponse = _reflection.GeneratedProtocolMessageType('SsurgeonResponse', (_message.Message,), {
|
| 382 |
+
|
| 383 |
+
'SsurgeonResult' : _reflection.GeneratedProtocolMessageType('SsurgeonResult', (_message.Message,), {
|
| 384 |
+
'DESCRIPTOR' : _SSURGEONRESPONSE_SSURGEONRESULT,
|
| 385 |
+
'__module__' : 'CoreNLP_pb2'
|
| 386 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SsurgeonResponse.SsurgeonResult)
|
| 387 |
+
})
|
| 388 |
+
,
|
| 389 |
+
'DESCRIPTOR' : _SSURGEONRESPONSE,
|
| 390 |
+
'__module__' : 'CoreNLP_pb2'
|
| 391 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.SsurgeonResponse)
|
| 392 |
+
})
|
| 393 |
+
_sym_db.RegisterMessage(SsurgeonResponse)
|
| 394 |
+
_sym_db.RegisterMessage(SsurgeonResponse.SsurgeonResult)
|
| 395 |
+
|
| 396 |
+
TokensRegexRequest = _reflection.GeneratedProtocolMessageType('TokensRegexRequest', (_message.Message,), {
|
| 397 |
+
'DESCRIPTOR' : _TOKENSREGEXREQUEST,
|
| 398 |
+
'__module__' : 'CoreNLP_pb2'
|
| 399 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.TokensRegexRequest)
|
| 400 |
+
})
|
| 401 |
+
_sym_db.RegisterMessage(TokensRegexRequest)
|
| 402 |
+
|
| 403 |
+
TokensRegexResponse = _reflection.GeneratedProtocolMessageType('TokensRegexResponse', (_message.Message,), {
|
| 404 |
+
|
| 405 |
+
'MatchLocation' : _reflection.GeneratedProtocolMessageType('MatchLocation', (_message.Message,), {
|
| 406 |
+
'DESCRIPTOR' : _TOKENSREGEXRESPONSE_MATCHLOCATION,
|
| 407 |
+
'__module__' : 'CoreNLP_pb2'
|
| 408 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.TokensRegexResponse.MatchLocation)
|
| 409 |
+
})
|
| 410 |
+
,
|
| 411 |
+
|
| 412 |
+
'Match' : _reflection.GeneratedProtocolMessageType('Match', (_message.Message,), {
|
| 413 |
+
'DESCRIPTOR' : _TOKENSREGEXRESPONSE_MATCH,
|
| 414 |
+
'__module__' : 'CoreNLP_pb2'
|
| 415 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.TokensRegexResponse.Match)
|
| 416 |
+
})
|
| 417 |
+
,
|
| 418 |
+
|
| 419 |
+
'PatternMatch' : _reflection.GeneratedProtocolMessageType('PatternMatch', (_message.Message,), {
|
| 420 |
+
'DESCRIPTOR' : _TOKENSREGEXRESPONSE_PATTERNMATCH,
|
| 421 |
+
'__module__' : 'CoreNLP_pb2'
|
| 422 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.TokensRegexResponse.PatternMatch)
|
| 423 |
+
})
|
| 424 |
+
,
|
| 425 |
+
'DESCRIPTOR' : _TOKENSREGEXRESPONSE,
|
| 426 |
+
'__module__' : 'CoreNLP_pb2'
|
| 427 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.TokensRegexResponse)
|
| 428 |
+
})
|
| 429 |
+
_sym_db.RegisterMessage(TokensRegexResponse)
|
| 430 |
+
_sym_db.RegisterMessage(TokensRegexResponse.MatchLocation)
|
| 431 |
+
_sym_db.RegisterMessage(TokensRegexResponse.Match)
|
| 432 |
+
_sym_db.RegisterMessage(TokensRegexResponse.PatternMatch)
|
| 433 |
+
|
| 434 |
+
DependencyEnhancerRequest = _reflection.GeneratedProtocolMessageType('DependencyEnhancerRequest', (_message.Message,), {
|
| 435 |
+
'DESCRIPTOR' : _DEPENDENCYENHANCERREQUEST,
|
| 436 |
+
'__module__' : 'CoreNLP_pb2'
|
| 437 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.DependencyEnhancerRequest)
|
| 438 |
+
})
|
| 439 |
+
_sym_db.RegisterMessage(DependencyEnhancerRequest)
|
| 440 |
+
|
| 441 |
+
FlattenedParseTree = _reflection.GeneratedProtocolMessageType('FlattenedParseTree', (_message.Message,), {
|
| 442 |
+
|
| 443 |
+
'Node' : _reflection.GeneratedProtocolMessageType('Node', (_message.Message,), {
|
| 444 |
+
'DESCRIPTOR' : _FLATTENEDPARSETREE_NODE,
|
| 445 |
+
'__module__' : 'CoreNLP_pb2'
|
| 446 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.FlattenedParseTree.Node)
|
| 447 |
+
})
|
| 448 |
+
,
|
| 449 |
+
'DESCRIPTOR' : _FLATTENEDPARSETREE,
|
| 450 |
+
'__module__' : 'CoreNLP_pb2'
|
| 451 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.FlattenedParseTree)
|
| 452 |
+
})
|
| 453 |
+
_sym_db.RegisterMessage(FlattenedParseTree)
|
| 454 |
+
_sym_db.RegisterMessage(FlattenedParseTree.Node)
|
| 455 |
+
|
| 456 |
+
EvaluateParserRequest = _reflection.GeneratedProtocolMessageType('EvaluateParserRequest', (_message.Message,), {
|
| 457 |
+
|
| 458 |
+
'ParseResult' : _reflection.GeneratedProtocolMessageType('ParseResult', (_message.Message,), {
|
| 459 |
+
'DESCRIPTOR' : _EVALUATEPARSERREQUEST_PARSERESULT,
|
| 460 |
+
'__module__' : 'CoreNLP_pb2'
|
| 461 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.EvaluateParserRequest.ParseResult)
|
| 462 |
+
})
|
| 463 |
+
,
|
| 464 |
+
'DESCRIPTOR' : _EVALUATEPARSERREQUEST,
|
| 465 |
+
'__module__' : 'CoreNLP_pb2'
|
| 466 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.EvaluateParserRequest)
|
| 467 |
+
})
|
| 468 |
+
_sym_db.RegisterMessage(EvaluateParserRequest)
|
| 469 |
+
_sym_db.RegisterMessage(EvaluateParserRequest.ParseResult)
|
| 470 |
+
|
| 471 |
+
EvaluateParserResponse = _reflection.GeneratedProtocolMessageType('EvaluateParserResponse', (_message.Message,), {
|
| 472 |
+
'DESCRIPTOR' : _EVALUATEPARSERRESPONSE,
|
| 473 |
+
'__module__' : 'CoreNLP_pb2'
|
| 474 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.EvaluateParserResponse)
|
| 475 |
+
})
|
| 476 |
+
_sym_db.RegisterMessage(EvaluateParserResponse)
|
| 477 |
+
|
| 478 |
+
TsurgeonRequest = _reflection.GeneratedProtocolMessageType('TsurgeonRequest', (_message.Message,), {
|
| 479 |
+
|
| 480 |
+
'Operation' : _reflection.GeneratedProtocolMessageType('Operation', (_message.Message,), {
|
| 481 |
+
'DESCRIPTOR' : _TSURGEONREQUEST_OPERATION,
|
| 482 |
+
'__module__' : 'CoreNLP_pb2'
|
| 483 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.TsurgeonRequest.Operation)
|
| 484 |
+
})
|
| 485 |
+
,
|
| 486 |
+
'DESCRIPTOR' : _TSURGEONREQUEST,
|
| 487 |
+
'__module__' : 'CoreNLP_pb2'
|
| 488 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.TsurgeonRequest)
|
| 489 |
+
})
|
| 490 |
+
_sym_db.RegisterMessage(TsurgeonRequest)
|
| 491 |
+
_sym_db.RegisterMessage(TsurgeonRequest.Operation)
|
| 492 |
+
|
| 493 |
+
TsurgeonResponse = _reflection.GeneratedProtocolMessageType('TsurgeonResponse', (_message.Message,), {
|
| 494 |
+
'DESCRIPTOR' : _TSURGEONRESPONSE,
|
| 495 |
+
'__module__' : 'CoreNLP_pb2'
|
| 496 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.TsurgeonResponse)
|
| 497 |
+
})
|
| 498 |
+
_sym_db.RegisterMessage(TsurgeonResponse)
|
| 499 |
+
|
| 500 |
+
MorphologyRequest = _reflection.GeneratedProtocolMessageType('MorphologyRequest', (_message.Message,), {
|
| 501 |
+
|
| 502 |
+
'TaggedWord' : _reflection.GeneratedProtocolMessageType('TaggedWord', (_message.Message,), {
|
| 503 |
+
'DESCRIPTOR' : _MORPHOLOGYREQUEST_TAGGEDWORD,
|
| 504 |
+
'__module__' : 'CoreNLP_pb2'
|
| 505 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.MorphologyRequest.TaggedWord)
|
| 506 |
+
})
|
| 507 |
+
,
|
| 508 |
+
'DESCRIPTOR' : _MORPHOLOGYREQUEST,
|
| 509 |
+
'__module__' : 'CoreNLP_pb2'
|
| 510 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.MorphologyRequest)
|
| 511 |
+
})
|
| 512 |
+
_sym_db.RegisterMessage(MorphologyRequest)
|
| 513 |
+
_sym_db.RegisterMessage(MorphologyRequest.TaggedWord)
|
| 514 |
+
|
| 515 |
+
MorphologyResponse = _reflection.GeneratedProtocolMessageType('MorphologyResponse', (_message.Message,), {
|
| 516 |
+
|
| 517 |
+
'WordTagLemma' : _reflection.GeneratedProtocolMessageType('WordTagLemma', (_message.Message,), {
|
| 518 |
+
'DESCRIPTOR' : _MORPHOLOGYRESPONSE_WORDTAGLEMMA,
|
| 519 |
+
'__module__' : 'CoreNLP_pb2'
|
| 520 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.MorphologyResponse.WordTagLemma)
|
| 521 |
+
})
|
| 522 |
+
,
|
| 523 |
+
'DESCRIPTOR' : _MORPHOLOGYRESPONSE,
|
| 524 |
+
'__module__' : 'CoreNLP_pb2'
|
| 525 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.MorphologyResponse)
|
| 526 |
+
})
|
| 527 |
+
_sym_db.RegisterMessage(MorphologyResponse)
|
| 528 |
+
_sym_db.RegisterMessage(MorphologyResponse.WordTagLemma)
|
| 529 |
+
|
| 530 |
+
DependencyConverterRequest = _reflection.GeneratedProtocolMessageType('DependencyConverterRequest', (_message.Message,), {
|
| 531 |
+
'DESCRIPTOR' : _DEPENDENCYCONVERTERREQUEST,
|
| 532 |
+
'__module__' : 'CoreNLP_pb2'
|
| 533 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.DependencyConverterRequest)
|
| 534 |
+
})
|
| 535 |
+
_sym_db.RegisterMessage(DependencyConverterRequest)
|
| 536 |
+
|
| 537 |
+
DependencyConverterResponse = _reflection.GeneratedProtocolMessageType('DependencyConverterResponse', (_message.Message,), {
|
| 538 |
+
|
| 539 |
+
'DependencyConversion' : _reflection.GeneratedProtocolMessageType('DependencyConversion', (_message.Message,), {
|
| 540 |
+
'DESCRIPTOR' : _DEPENDENCYCONVERTERRESPONSE_DEPENDENCYCONVERSION,
|
| 541 |
+
'__module__' : 'CoreNLP_pb2'
|
| 542 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.DependencyConverterResponse.DependencyConversion)
|
| 543 |
+
})
|
| 544 |
+
,
|
| 545 |
+
'DESCRIPTOR' : _DEPENDENCYCONVERTERRESPONSE,
|
| 546 |
+
'__module__' : 'CoreNLP_pb2'
|
| 547 |
+
# @@protoc_insertion_point(class_scope:edu.stanford.nlp.pipeline.DependencyConverterResponse)
|
| 548 |
+
})
|
| 549 |
+
_sym_db.RegisterMessage(DependencyConverterResponse)
|
| 550 |
+
_sym_db.RegisterMessage(DependencyConverterResponse.DependencyConversion)
|
| 551 |
+
|
| 552 |
+
if _descriptor._USE_C_DESCRIPTORS == False:
|
| 553 |
+
|
| 554 |
+
DESCRIPTOR._options = None
|
| 555 |
+
DESCRIPTOR._serialized_options = b'\n\031edu.stanford.nlp.pipelineB\rCoreNLPProtos'
|
| 556 |
+
_DEPENDENCYGRAPH.fields_by_name['root']._options = None
|
| 557 |
+
_DEPENDENCYGRAPH.fields_by_name['root']._serialized_options = b'\020\001'
|
| 558 |
+
_DEPENDENCYGRAPH.fields_by_name['rootNode']._options = None
|
| 559 |
+
_DEPENDENCYGRAPH.fields_by_name['rootNode']._serialized_options = b'\020\001'
|
| 560 |
+
_LANGUAGE._serialized_start=13457
|
| 561 |
+
_LANGUAGE._serialized_end=13620
|
| 562 |
+
_SENTIMENT._serialized_start=13622
|
| 563 |
+
_SENTIMENT._serialized_end=13726
|
| 564 |
+
_NATURALLOGICRELATION._serialized_start=13729
|
| 565 |
+
_NATURALLOGICRELATION._serialized_end=13876
|
| 566 |
+
_DOCUMENT._serialized_start=45
|
| 567 |
+
_DOCUMENT._serialized_end=782
|
| 568 |
+
_SENTENCE._serialized_start=785
|
| 569 |
+
_SENTENCE._serialized_end=2820
|
| 570 |
+
_TOKEN._serialized_start=2823
|
| 571 |
+
_TOKEN._serialized_end=4477
|
| 572 |
+
_QUOTE._serialized_start=4480
|
| 573 |
+
_QUOTE._serialized_end=4964
|
| 574 |
+
_PARSETREE._serialized_start=4967
|
| 575 |
+
_PARSETREE._serialized_end=5166
|
| 576 |
+
_DEPENDENCYGRAPH._serialized_start=5169
|
| 577 |
+
_DEPENDENCYGRAPH._serialized_end=5708
|
| 578 |
+
_DEPENDENCYGRAPH_NODE._serialized_start=5403
|
| 579 |
+
_DEPENDENCYGRAPH_NODE._serialized_end=5491
|
| 580 |
+
_DEPENDENCYGRAPH_EDGE._serialized_start=5494
|
| 581 |
+
_DEPENDENCYGRAPH_EDGE._serialized_end=5708
|
| 582 |
+
_COREFCHAIN._serialized_start=5711
|
| 583 |
+
_COREFCHAIN._serialized_end=6037
|
| 584 |
+
_COREFCHAIN_COREFMENTION._serialized_start=5836
|
| 585 |
+
_COREFCHAIN_COREFMENTION._serialized_end=6037
|
| 586 |
+
_MENTION._serialized_start=6040
|
| 587 |
+
_MENTION._serialized_end=7175
|
| 588 |
+
_INDEXEDWORD._serialized_start=7177
|
| 589 |
+
_INDEXEDWORD._serialized_end=7265
|
| 590 |
+
_SPEAKERINFO._serialized_start=7267
|
| 591 |
+
_SPEAKERINFO._serialized_end=7319
|
| 592 |
+
_SPAN._serialized_start=7321
|
| 593 |
+
_SPAN._serialized_end=7355
|
| 594 |
+
_TIMEX._serialized_start=7357
|
| 595 |
+
_TIMEX._serialized_end=7476
|
| 596 |
+
_ENTITY._serialized_start=7479
|
| 597 |
+
_ENTITY._serialized_end=7698
|
| 598 |
+
_RELATION._serialized_start=7701
|
| 599 |
+
_RELATION._serialized_end=7884
|
| 600 |
+
_OPERATOR._serialized_start=7887
|
| 601 |
+
_OPERATOR._serialized_end=8065
|
| 602 |
+
_POLARITY._serialized_start=8068
|
| 603 |
+
_POLARITY._serialized_end=8621
|
| 604 |
+
_NERMENTION._serialized_start=8624
|
| 605 |
+
_NERMENTION._serialized_end=8973
|
| 606 |
+
_SENTENCEFRAGMENT._serialized_start=8975
|
| 607 |
+
_SENTENCEFRAGMENT._serialized_end=9064
|
| 608 |
+
_TOKENLOCATION._serialized_start=9066
|
| 609 |
+
_TOKENLOCATION._serialized_end=9124
|
| 610 |
+
_RELATIONTRIPLE._serialized_start=9127
|
| 611 |
+
_RELATIONTRIPLE._serialized_end=9537
|
| 612 |
+
_MAPSTRINGSTRING._serialized_start=9539
|
| 613 |
+
_MAPSTRINGSTRING._serialized_end=9584
|
| 614 |
+
_MAPINTSTRING._serialized_start=9586
|
| 615 |
+
_MAPINTSTRING._serialized_end=9628
|
| 616 |
+
_SECTION._serialized_start=9631
|
| 617 |
+
_SECTION._serialized_end=9883
|
| 618 |
+
_SEMGREXREQUEST._serialized_start=9886
|
| 619 |
+
_SEMGREXREQUEST._serialized_end=10114
|
| 620 |
+
_SEMGREXREQUEST_DEPENDENCIES._serialized_start=9992
|
| 621 |
+
_SEMGREXREQUEST_DEPENDENCIES._serialized_end=10114
|
| 622 |
+
_SEMGREXRESPONSE._serialized_start=10117
|
| 623 |
+
_SEMGREXRESPONSE._serialized_end=10880
|
| 624 |
+
_SEMGREXRESPONSE_NAMEDNODE._serialized_start=10208
|
| 625 |
+
_SEMGREXRESPONSE_NAMEDNODE._serialized_end=10253
|
| 626 |
+
_SEMGREXRESPONSE_NAMEDRELATION._serialized_start=10255
|
| 627 |
+
_SEMGREXRESPONSE_NAMEDRELATION._serialized_end=10298
|
| 628 |
+
_SEMGREXRESPONSE_NAMEDEDGE._serialized_start=10301
|
| 629 |
+
_SEMGREXRESPONSE_NAMEDEDGE._serialized_end=10429
|
| 630 |
+
_SEMGREXRESPONSE_MATCH._serialized_start=10432
|
| 631 |
+
_SEMGREXRESPONSE_MATCH._serialized_end=10709
|
| 632 |
+
_SEMGREXRESPONSE_SEMGREXRESULT._serialized_start=10711
|
| 633 |
+
_SEMGREXRESPONSE_SEMGREXRESULT._serialized_end=10791
|
| 634 |
+
_SEMGREXRESPONSE_GRAPHRESULT._serialized_start=10793
|
| 635 |
+
_SEMGREXRESPONSE_GRAPHRESULT._serialized_end=10880
|
| 636 |
+
_SSURGEONREQUEST._serialized_start=10883
|
| 637 |
+
_SSURGEONREQUEST._serialized_end=11123
|
| 638 |
+
_SSURGEONREQUEST_SSURGEON._serialized_start=11032
|
| 639 |
+
_SSURGEONREQUEST_SSURGEON._serialized_end=11123
|
| 640 |
+
_SSURGEONRESPONSE._serialized_start=11126
|
| 641 |
+
_SSURGEONRESPONSE._serialized_end=11314
|
| 642 |
+
_SSURGEONRESPONSE_SSURGEONRESULT._serialized_start=11222
|
| 643 |
+
_SSURGEONRESPONSE_SSURGEONRESULT._serialized_end=11314
|
| 644 |
+
_TOKENSREGEXREQUEST._serialized_start=11316
|
| 645 |
+
_TOKENSREGEXREQUEST._serialized_end=11403
|
| 646 |
+
_TOKENSREGEXRESPONSE._serialized_start=11406
|
| 647 |
+
_TOKENSREGEXRESPONSE._serialized_end=11829
|
| 648 |
+
_TOKENSREGEXRESPONSE_MATCHLOCATION._serialized_start=11505
|
| 649 |
+
_TOKENSREGEXRESPONSE_MATCHLOCATION._serialized_end=11562
|
| 650 |
+
_TOKENSREGEXRESPONSE_MATCH._serialized_start=11565
|
| 651 |
+
_TOKENSREGEXRESPONSE_MATCH._serialized_end=11744
|
| 652 |
+
_TOKENSREGEXRESPONSE_PATTERNMATCH._serialized_start=11746
|
| 653 |
+
_TOKENSREGEXRESPONSE_PATTERNMATCH._serialized_end=11829
|
| 654 |
+
_DEPENDENCYENHANCERREQUEST._serialized_start=11832
|
| 655 |
+
_DEPENDENCYENHANCERREQUEST._serialized_end=12006
|
| 656 |
+
_FLATTENEDPARSETREE._serialized_start=12009
|
| 657 |
+
_FLATTENEDPARSETREE._serialized_end=12189
|
| 658 |
+
_FLATTENEDPARSETREE_NODE._serialized_start=12098
|
| 659 |
+
_FLATTENEDPARSETREE_NODE._serialized_end=12189
|
| 660 |
+
_EVALUATEPARSERREQUEST._serialized_start=12192
|
| 661 |
+
_EVALUATEPARSERREQUEST._serialized_end=12438
|
| 662 |
+
_EVALUATEPARSERREQUEST_PARSERESULT._serialized_start=12298
|
| 663 |
+
_EVALUATEPARSERREQUEST_PARSERESULT._serialized_end=12438
|
| 664 |
+
_EVALUATEPARSERRESPONSE._serialized_start=12440
|
| 665 |
+
_EVALUATEPARSERRESPONSE._serialized_end=12509
|
| 666 |
+
_TSURGEONREQUEST._serialized_start=12512
|
| 667 |
+
_TSURGEONREQUEST._serialized_end=12712
|
| 668 |
+
_TSURGEONREQUEST_OPERATION._serialized_start=12667
|
| 669 |
+
_TSURGEONREQUEST_OPERATION._serialized_end=12712
|
| 670 |
+
_TSURGEONRESPONSE._serialized_start=12714
|
| 671 |
+
_TSURGEONRESPONSE._serialized_end=12794
|
| 672 |
+
_MORPHOLOGYREQUEST._serialized_start=12797
|
| 673 |
+
_MORPHOLOGYREQUEST._serialized_end=12930
|
| 674 |
+
_MORPHOLOGYREQUEST_TAGGEDWORD._serialized_start=12890
|
| 675 |
+
_MORPHOLOGYREQUEST_TAGGEDWORD._serialized_end=12930
|
| 676 |
+
_MORPHOLOGYRESPONSE._serialized_start=12933
|
| 677 |
+
_MORPHOLOGYRESPONSE._serialized_end=13087
|
| 678 |
+
_MORPHOLOGYRESPONSE_WORDTAGLEMMA._serialized_start=13030
|
| 679 |
+
_MORPHOLOGYRESPONSE_WORDTAGLEMMA._serialized_end=13087
|
| 680 |
+
_DEPENDENCYCONVERTERREQUEST._serialized_start=13089
|
| 681 |
+
_DEPENDENCYCONVERTERREQUEST._serialized_end=13179
|
| 682 |
+
_DEPENDENCYCONVERTERRESPONSE._serialized_start=13182
|
| 683 |
+
_DEPENDENCYCONVERTERRESPONSE._serialized_end=13454
|
| 684 |
+
_DEPENDENCYCONVERTERRESPONSE_DEPENDENCYCONVERSION._serialized_start=13312
|
| 685 |
+
_DEPENDENCYCONVERTERRESPONSE_DEPENDENCYCONVERSION._serialized_end=13454
|
| 686 |
+
# @@protoc_insertion_point(module_scope)
|
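Aside: the generated module exposes `Language`, `Sentiment`, and `NaturalLogicRelation` both as `EnumTypeWrapper` objects and as bare module-level constants, alongside the message classes registered above. A minimal illustrative sketch of how they can be used (not part of the uploaded files; it only assumes `stanza` and `protobuf` are importable):

    from stanza.protobuf import CoreNLP_pb2

    # EnumTypeWrapper supports name <-> value lookups in both directions.
    assert CoreNLP_pb2.Language.Value('English') == CoreNLP_pb2.English  # == 4
    assert CoreNLP_pb2.Sentiment.Name(CoreNLP_pb2.NEUTRAL) == 'NEUTRAL'

    # The registered classes behave like ordinary protobuf messages.
    doc = CoreNLP_pb2.Document()
    doc.text = "Stanford NLP"
    print(doc.ByteSize())  # serialized size in bytes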
stanza/stanza/protobuf/__init__.py
ADDED
@@ -0,0 +1,52 @@
from __future__ import absolute_import

from io import BytesIO
import warnings

from google.protobuf.internal.encoder import _EncodeVarint
from google.protobuf.internal.decoder import _DecodeVarint
from google.protobuf.message import DecodeError
from .CoreNLP_pb2 import *

def parseFromDelimitedString(obj, buf, offset=0):
    """
    Stanford CoreNLP uses the Java "writeDelimitedTo" function, which
    writes the size (and offset) of the buffer before writing the object.
    This function parses such a message from @buf, starting at @offset.

    @returns how many bytes of @buf were consumed.
    """
    size, pos = _DecodeVarint(buf, offset)
    try:
        obj.ParseFromString(buf[offset+pos:offset+pos+size])
    except DecodeError:
        warnings.warn("Failed to decode a serialized output from CoreNLP server. An incomplete or empty object will be returned.",
                      RuntimeWarning)
    return pos+size

def writeToDelimitedString(obj, stream=None):
    """
    Stanford CoreNLP uses the Java "writeDelimitedTo" function, which
    writes the size (and offset) of the buffer before writing the object.
    This function writes @obj in the same format: a varint size prefix
    followed by the serialized message.

    @returns the stream that was written to.
    """
    if stream is None:
        stream = BytesIO()

    _EncodeVarint(stream.write, obj.ByteSize(), True)
    stream.write(obj.SerializeToString())
    return stream

def to_text(sentence):
    """
    Helper routine that converts a Sentence protobuf to a string from
    its tokens.
    """
    text = ""
    for i, tok in enumerate(sentence.token):
        if i != 0:
            text += tok.before
        text += tok.word
    return text
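Since writeToDelimitedString() and parseFromDelimitedString() are inverses, a quick round trip shows the framing format. A minimal sketch (not part of the uploaded files; `Document` comes from the star import of `CoreNLP_pb2` above):

    from stanza.protobuf import Document, parseFromDelimitedString, writeToDelimitedString

    doc = Document()
    doc.text = "Hello world."

    # writeToDelimitedString frames the message the way Java's
    # writeDelimitedTo does: a varint length prefix, then the bytes.
    buf = writeToDelimitedString(doc).getvalue()

    # parseFromDelimitedString returns the number of bytes consumed.
    round_trip = Document()
    consumed = parseFromDelimitedString(round_trip, buf)
    assert consumed == len(buf)
    assert round_trip.text == doc.text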
stanza/stanza/resources/common.py
ADDED
|
@@ -0,0 +1,619 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Common utilities for Stanza resources.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from collections import defaultdict, namedtuple
|
| 6 |
+
import errno
|
| 7 |
+
import hashlib
|
| 8 |
+
import json
|
| 9 |
+
import logging
|
| 10 |
+
import os
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
import requests
|
| 13 |
+
import shutil
|
| 14 |
+
import tempfile
|
| 15 |
+
import zipfile
|
| 16 |
+
|
| 17 |
+
from tqdm.auto import tqdm
|
| 18 |
+
|
| 19 |
+
from stanza.utils.helper_func import make_table
|
| 20 |
+
from stanza.pipeline._constants import TOKENIZE, MWT, POS, LEMMA, DEPPARSE, NER, SENTIMENT
|
| 21 |
+
from stanza.pipeline.registry import PIPELINE_NAMES, PROCESSOR_VARIANTS
|
| 22 |
+
from stanza.resources.default_packages import PACKAGES
|
| 23 |
+
from stanza._version import __resources_version__
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger('stanza')
|
| 26 |
+
|
| 27 |
+
# set home dir for default
|
| 28 |
+
HOME_DIR = str(Path.home())
|
| 29 |
+
STANFORDNLP_RESOURCES_URL = 'https://nlp.stanford.edu/software/stanza/stanza-resources/'
|
| 30 |
+
STANZA_RESOURCES_GITHUB = 'https://raw.githubusercontent.com/stanfordnlp/stanza-resources/'
|
| 31 |
+
DEFAULT_RESOURCES_URL = os.getenv('STANZA_RESOURCES_URL', STANZA_RESOURCES_GITHUB + 'main')
|
| 32 |
+
DEFAULT_RESOURCES_VERSION = os.getenv(
|
| 33 |
+
'STANZA_RESOURCES_VERSION',
|
| 34 |
+
__resources_version__
|
| 35 |
+
)
|
| 36 |
+
DEFAULT_MODEL_URL = os.getenv('STANZA_MODEL_URL', 'default')
|
| 37 |
+
DEFAULT_MODEL_DIR = os.getenv(
|
| 38 |
+
'STANZA_RESOURCES_DIR',
|
| 39 |
+
os.path.join(HOME_DIR, 'stanza_resources')
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
PRETRAIN_NAMES = ("pretrain", "forward_charlm", "backward_charlm")
|
| 43 |
+
|
| 44 |
+
class ResourcesFileNotFoundError(FileNotFoundError):
|
| 45 |
+
def __init__(self, resources_filepath):
|
| 46 |
+
super().__init__(f"Resources file not found at: {resources_filepath} Try to download the model again.")
|
| 47 |
+
self.resources_filepath = resources_filepath
|
| 48 |
+
|
| 49 |
+
class UnknownLanguageError(ValueError):
|
| 50 |
+
def __init__(self, unknown):
|
| 51 |
+
super().__init__(f"Unknown language requested: {unknown}")
|
| 52 |
+
self.unknown_language = unknown
|
| 53 |
+
|
| 54 |
+
class UnknownProcessorError(ValueError):
|
| 55 |
+
def __init__(self, unknown):
|
| 56 |
+
super().__init__(f"Unknown processor type requested: {unknown}")
|
| 57 |
+
self.unknown_processor = unknown
|
| 58 |
+
|
| 59 |
+
ModelSpecification = namedtuple('ModelSpecification', ['processor', 'package', 'dependencies'])
|
| 60 |
+
|
| 61 |
+
def ensure_dir(path):
|
| 62 |
+
"""
|
| 63 |
+
Create dir in case it does not exist.
|
| 64 |
+
"""
|
| 65 |
+
Path(path).mkdir(parents=True, exist_ok=True)
|
| 66 |
+
|
| 67 |
+
def get_md5(path):
|
| 68 |
+
"""
|
| 69 |
+
Get the MD5 value of a path.
|
| 70 |
+
"""
|
| 71 |
+
try:
|
| 72 |
+
with open(path, 'rb') as fin:
|
| 73 |
+
data = fin.read()
|
| 74 |
+
except OSError as e:
|
| 75 |
+
if not e.filename:
|
| 76 |
+
e.filename = path
|
| 77 |
+
raise
|
| 78 |
+
return hashlib.md5(data).hexdigest()
|
| 79 |
+
|
| 80 |
+
def unzip(path, filename):
|
| 81 |
+
"""
|
| 82 |
+
Fully unzip a file `filename` that's in a directory `dir`.
|
| 83 |
+
"""
|
| 84 |
+
logger.debug(f'Unzip: {path}/{filename}...')
|
| 85 |
+
with zipfile.ZipFile(os.path.join(path, filename)) as f:
|
| 86 |
+
f.extractall(path)
|
| 87 |
+
|
| 88 |
+
def get_root_from_zipfile(filename):
|
| 89 |
+
"""
|
| 90 |
+
Get the root directory from a archived zip file.
|
| 91 |
+
"""
|
| 92 |
+
zf = zipfile.ZipFile(filename, "r")
|
| 93 |
+
assert len(zf.filelist) > 0, \
|
| 94 |
+
f"Zip file at f{filename} seems to be corrupted. Please check it."
|
| 95 |
+
return os.path.dirname(zf.filelist[0].filename)
|
| 96 |
+
|
| 97 |
+
def file_exists(path, md5):
|
| 98 |
+
"""
|
| 99 |
+
Check if the file at `path` exists and match the provided md5 value.
|
| 100 |
+
"""
|
| 101 |
+
return os.path.exists(path) and get_md5(path) == md5
|
| 102 |
+
|
| 103 |
+
def assert_file_exists(path, md5=None, alternate_md5=None):
|
| 104 |
+
if not os.path.exists(path):
|
| 105 |
+
raise FileNotFoundError(errno.ENOENT, "Cannot find expected file", path)
|
| 106 |
+
if md5:
|
| 107 |
+
file_md5 = get_md5(path)
|
| 108 |
+
if file_md5 != md5:
|
| 109 |
+
if file_md5 == alternate_md5:
|
| 110 |
+
logger.debug("Found a possibly older version of file %s, md5 %s instead of %s", path, alternate_md5, md5)
|
| 111 |
+
else:
|
| 112 |
+
raise ValueError("md5 for %s is %s, expected %s" % (path, file_md5, md5))
|
| 113 |
+
|
| 114 |
+
def download_file(url, path, proxies, raise_for_status=False):
|
| 115 |
+
"""
|
| 116 |
+
Download a URL into a file as specified by `path`.
|
| 117 |
+
"""
|
| 118 |
+
verbose = logger.level in [0, 10, 20]
|
| 119 |
+
r = requests.get(url, stream=True, proxies=proxies)
|
| 120 |
+
if raise_for_status:
|
| 121 |
+
r.raise_for_status()
|
| 122 |
+
with open(path, 'wb') as f:
|
| 123 |
+
file_size = int(r.headers.get('content-length'))
|
| 124 |
+
default_chunk_size = 131072
|
| 125 |
+
desc = 'Downloading ' + url
|
| 126 |
+
with tqdm(total=file_size, unit='B', unit_scale=True, \
|
| 127 |
+
disable=not verbose, desc=desc) as pbar:
|
| 128 |
+
for chunk in r.iter_content(chunk_size=default_chunk_size):
|
| 129 |
+
if chunk:
|
| 130 |
+
f.write(chunk)
|
| 131 |
+
f.flush()
|
| 132 |
+
pbar.update(len(chunk))
|
| 133 |
+
return r.status_code
|
| 134 |
+
|
| 135 |
+
def request_file(url, path, proxies=None, md5=None, raise_for_status=False, log_info=True, alternate_md5=None):
|
| 136 |
+
"""
|
| 137 |
+
A complete wrapper over download_file() that also make sure the directory of
|
| 138 |
+
`path` exists, and that a file matching the md5 value does not exist.
|
| 139 |
+
|
| 140 |
+
alternate_md5 allows for an alternate md5 that is acceptable (such as if an older version of a file is okay)
|
| 141 |
+
"""
|
| 142 |
+
basedir = Path(path).parent
|
| 143 |
+
ensure_dir(basedir)
|
| 144 |
+
if file_exists(path, md5):
|
| 145 |
+
if log_info:
|
| 146 |
+
logger.info(f'File exists: {path}')
|
| 147 |
+
else:
|
| 148 |
+
logger.debug(f'File exists: {path}')
|
| 149 |
+
return
|
| 150 |
+
# We write data first to a temporary directory,
|
| 151 |
+
# then use os.replace() so that multiple processes
|
| 152 |
+
# running at the same time don't clobber each other
|
| 153 |
+
# with partially downloaded files
|
| 154 |
+
# This was especially common with resources.json
|
| 155 |
+
with tempfile.TemporaryDirectory(dir=basedir) as temp:
|
| 156 |
+
temppath = os.path.join(temp, os.path.split(path)[-1])
|
| 157 |
+
download_file(url, temppath, proxies, raise_for_status)
|
| 158 |
+
os.replace(temppath, path)
|
| 159 |
+
assert_file_exists(path, md5, alternate_md5)
|
| 160 |
+
if log_info:
|
| 161 |
+
logger.info(f'Downloaded file to {path}')
|
| 162 |
+
else:
|
| 163 |
+
logger.debug(f'Downloaded file to {path}')
|
| 164 |
+
|
| 165 |
+
def sort_processors(processor_list):
|
| 166 |
+
sorted_list = []
|
| 167 |
+
for processor in PIPELINE_NAMES:
|
| 168 |
+
for item in processor_list:
|
| 169 |
+
if item[0] == processor:
|
| 170 |
+
sorted_list.append(item)
|
| 171 |
+
# going just by processors in PIPELINE_NAMES, this drops any names
|
| 172 |
+
# which are not an official processor but might still be useful
|
| 173 |
+
# check the list and append them to the end
|
| 174 |
+
# this is especially useful when downloading pretrain or charlm models
|
| 175 |
+
for processor in processor_list:
|
| 176 |
+
for item in sorted_list:
|
| 177 |
+
if processor[0] == item[0]:
|
| 178 |
+
break
|
| 179 |
+
else:
|
| 180 |
+
sorted_list.append(item)
|
| 181 |
+
return sorted_list
|
| 182 |
+
|
| 183 |
+
def add_mwt(processors, resources, lang):
|
| 184 |
+
"""Add mwt if tokenize is passed without mwt.
|
| 185 |
+
|
| 186 |
+
If tokenize is in the list, but mwt is not, and there is a corresponding
|
| 187 |
+
tokenize and mwt pair in the resources file, mwt is added so no missing
|
| 188 |
+
mwt errors are raised.
|
| 189 |
+
|
| 190 |
+
TODO: how does this handle EWT in English?
|
| 191 |
+
"""
|
| 192 |
+
value = processors[TOKENIZE]
|
| 193 |
+
if value in resources[lang][PACKAGES] and MWT in resources[lang][PACKAGES][value]:
|
| 194 |
+
logger.warning("Language %s package %s expects mwt, which has been added", lang, value)
|
| 195 |
+
processors[MWT] = value
|
| 196 |
+
elif (value in resources[lang][TOKENIZE] and MWT in resources[lang] and value in resources[lang][MWT]):
|
| 197 |
+
logger.warning("Language %s package %s expects mwt, which has been added", lang, value)
|
| 198 |
+
processors[MWT] = value
|
| 199 |
+
|
| 200 |
+
def maintain_processor_list(resources, lang, package, processors, allow_pretrain=False, maybe_add_mwt=True):
|
| 201 |
+
"""
|
| 202 |
+
Given a parsed resources file, language, and possible package
|
| 203 |
+
and/or processors, expands the package to the list of processors
|
| 204 |
+
|
| 205 |
+
Returns a list of processors
|
| 206 |
+
Each item in the list of processors is a pair:
|
| 207 |
+
name, then a list of ModelSpecification
|
| 208 |
+
so, for example:
|
| 209 |
+
[['pos', [ModelSpecification(processor='pos', package='gsd', dependencies=None)]],
|
| 210 |
+
['depparse', [ModelSpecification(processor='depparse', package='gsd', dependencies=None)]]]
|
| 211 |
+
"""
|
| 212 |
+
processor_list = defaultdict(list)
|
| 213 |
+
# resolve processor models
|
| 214 |
+
if processors:
|
| 215 |
+
logger.debug(f'Processing parameter "processors"...')
|
| 216 |
+
if maybe_add_mwt and TOKENIZE in processors and MWT not in processors:
|
| 217 |
+
add_mwt(processors, resources, lang)
|
| 218 |
+
for key, plist in processors.items():
|
| 219 |
+
if not isinstance(key, str):
|
| 220 |
+
raise ValueError("Processor names must be strings")
|
| 221 |
+
if not isinstance(plist, (tuple, list, str)):
|
| 222 |
+
raise ValueError("Processor values must be strings")
|
| 223 |
+
if isinstance(plist, str):
|
| 224 |
+
plist = [plist]
|
| 225 |
+
if key not in PIPELINE_NAMES:
|
| 226 |
+
if not allow_pretrain or key not in PRETRAIN_NAMES:
|
| 227 |
+
raise UnknownProcessorError(key)
|
| 228 |
+
for value in plist:
|
| 229 |
+
# check if keys and values can be found
|
| 230 |
+
if key in resources[lang] and value in resources[lang][key]:
|
| 231 |
+
logger.debug(f'Found {key}: {value}.')
|
| 232 |
+
processor_list[key].append(value)
|
| 233 |
+
# allow values to be default in some cases
|
| 234 |
+
elif value in resources[lang][PACKAGES] and key in resources[lang][PACKAGES][value]:
|
| 235 |
+
logger.debug(
|
| 236 |
+
f'Found {key}: {resources[lang][PACKAGES][value][key]}.'
|
| 237 |
+
)
|
| 238 |
+
processor_list[key].append(resources[lang][PACKAGES][value][key])
|
| 239 |
+
# optional defaults will be activated if specifically turned on
|
| 240 |
+
elif value in resources[lang][PACKAGES] and 'optional' in resources[lang][PACKAGES][value] and key in resources[lang][PACKAGES][value]['optional']:
|
| 241 |
+
logger.debug(
|
| 242 |
+
f"Found {key}: {resources[lang][PACKAGES][value]['optional'][key]}."
|
| 243 |
+
)
|
| 244 |
+
processor_list[key].append(resources[lang][PACKAGES][value]['optional'][key])
|
| 245 |
+
# allow processors to be set to variants that we didn't implement
|
| 246 |
+
elif value in PROCESSOR_VARIANTS[key]:
|
| 247 |
+
logger.debug(
|
| 248 |
+
f'Found {key}: {value}. '
|
| 249 |
+
f'Using external {value} variant for the {key} processor.'
|
| 250 |
+
)
|
| 251 |
+
processor_list[key].append(value)
|
| 252 |
+
# allow lemma to be set to "identity"
|
| 253 |
+
elif key == LEMMA and value == 'identity':
|
| 254 |
+
logger.debug(
|
| 255 |
+
f'Found {key}: {value}. Using identity lemmatizer.'
|
| 256 |
+
)
|
| 257 |
+
processor_list[key].append(value)
|
| 258 |
+
# not a processor in the officially supported processor list
|
| 259 |
+
elif key not in resources[lang]:
|
| 260 |
+
logger.debug(
|
| 261 |
+
f'{key}: {value} is not officially supported by Stanza, '
|
| 262 |
+
f'loading it anyway.'
|
| 263 |
+
)
|
| 264 |
+
processor_list[key].append(value)
|
| 265 |
+
# cannot find the package for a processor and warn user
|
| 266 |
+
else:
|
| 267 |
+
logger.warning(
|
| 268 |
+
f'Can not find {key}: {value} from official model list. '
|
| 269 |
+
f'Ignoring it.'
|
| 270 |
+
)
|
| 271 |
+
# resolve package
|
| 272 |
+
if package:
|
| 273 |
+
logger.debug(f'Processing parameter "package"...')
|
| 274 |
+
if PACKAGES in resources[lang] and package in resources[lang][PACKAGES]:
|
| 275 |
+
for key, value in resources[lang][PACKAGES][package].items():
|
| 276 |
+
if key != 'optional' and key not in processor_list:
|
| 277 |
+
logger.debug(f'Found {key}: {value}.')
|
| 278 |
+
processor_list[key].append(value)
|
| 279 |
+
else:
|
| 280 |
+
flag = False
|
| 281 |
+
for key in PIPELINE_NAMES:
|
| 282 |
+
if key not in resources[lang]: continue
|
| 283 |
+
if package in resources[lang][key]:
|
| 284 |
+
flag = True
|
| 285 |
+
if key not in processor_list:
|
| 286 |
+
logger.debug(f'Found {key}: {package}.')
|
| 287 |
+
processor_list[key].append(package)
|
| 288 |
+
else:
|
| 289 |
+
logger.debug(
|
| 290 |
+
f'{key}: {package} is overwritten by '
|
| 291 |
+
f'{key}: {processors[key]}.'
|
| 292 |
+
)
|
| 293 |
+
if not flag: logger.warning((f'Can not find package: {package}.'))
|
| 294 |
+
processor_list = [[key, [ModelSpecification(processor=key, package=value, dependencies=None) for value in plist]] for key, plist in processor_list.items()]
|
| 295 |
+
processor_list = sort_processors(processor_list)
|
| 296 |
+
return processor_list
|
| 297 |
+
|
| 298 |
+
def add_dependencies(resources, lang, processor_list):
|
| 299 |
+
"""
|
| 300 |
+
Expand the processor_list as given in maintain_processor_list to have the dependencies
|
| 301 |
+
|
| 302 |
+
Still a list of model types to ModelSpecifications
|
| 303 |
+
the dependencies are tuples: name and package
|
| 304 |
+
for example:
|
| 305 |
+
[['pos', (ModelSpecification(processor='pos', package='gsd', dependencies=(('pretrain', 'gsd'),)),)],
|
| 306 |
+
['depparse', (ModelSpecification(processor='depparse', package='gsd', dependencies=(('pretrain', 'gsd'),)),)]]
|
| 307 |
+
"""
|
| 308 |
+
lang_resources = resources[lang]
|
| 309 |
+
for item in processor_list:
|
| 310 |
+
processor, model_specs = item
|
| 311 |
+
new_model_specs = []
|
| 312 |
+
for model_spec in model_specs:
|
| 313 |
+
# skip dependency checking for external variants of processors and identity lemmatizer
|
| 314 |
+
if not any([
|
| 315 |
+
model_spec.package in PROCESSOR_VARIANTS[processor],
|
| 316 |
+
processor == LEMMA and model_spec.package == 'identity'
|
| 317 |
+
]):
|
| 318 |
+
dependencies = lang_resources.get(processor, {}).get(model_spec.package, {}).get('dependencies', [])
|
| 319 |
+
dependencies = [(dependency['model'], dependency['package']) for dependency in dependencies]
|
| 320 |
+
model_spec = model_spec._replace(dependencies=tuple(dependencies))
|
| 321 |
+
logger.debug("Found dependencies %s for processor %s model %s", dependencies, processor, model_spec.package)
|
| 322 |
+
new_model_specs.append(model_spec)
|
| 323 |
+
item[1] = tuple(new_model_specs)
|
| 324 |
+
return processor_list
|
| 325 |
+
|
| 326 |
+
def flatten_processor_list(processor_list):
    """
    The flattened processor list is just a list of types & packages

    For example:
    [['pos', 'gsd'], ['depparse', 'gsd'], ['pretrain', 'gsd']]
    """
    flattened_processor_list = []
    dependencies_list = []
    for item in processor_list:
        processor, model_specs = item
        for model_spec in model_specs:
            package = model_spec.package
            dependencies = model_spec.dependencies
            flattened_processor_list.append([processor, package])
            if dependencies:
                dependencies_list += [tuple(dependency) for dependency in dependencies]
    dependencies_list = [list(item) for item in set(dependencies_list)]
    for processor, package in dependencies_list:
        logger.debug(f'Found dependency {processor}: {package}.')
    flattened_processor_list += dependencies_list
    return flattened_processor_list

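# Editor's note: a minimal sketch, not part of the original file, showing how
# the two helpers above compose. Given a processor_list shaped like the
# docstring example from add_dependencies:
#
#     specs = [['pos', (ModelSpecification(processor='pos', package='gsd',
#                                          dependencies=(('pretrain', 'gsd'),)),)],
#              ['depparse', (ModelSpecification(processor='depparse', package='gsd',
#                                               dependencies=(('pretrain', 'gsd'),)),)]]
#     flatten_processor_list(specs)
#     # -> [['pos', 'gsd'], ['depparse', 'gsd'], ['pretrain', 'gsd']]
#     # the shared 'pretrain'/'gsd' dependency is deduplicated via set()
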
def set_logging_level(logging_level, verbose):
    # Check verbose for easy logging control
    if verbose is False:
        logging_level = 'ERROR'
    elif verbose is True:
        logging_level = 'INFO'

    if logging_level is None:
        # default logging level of INFO is set in stanza.__init__
        # but the user may have set it via the logging API
        # it should NOT be 0, but let's check to be sure...
        if logger.level == 0:
            logger.setLevel('INFO')
        return logger.level

    # Set logging level
    logging_level = logging_level.upper()
    all_levels = ['DEBUG', 'INFO', 'WARNING', 'WARN', 'ERROR', 'CRITICAL', 'FATAL']
    if logging_level not in all_levels:
        raise ValueError(
            "Unrecognized logging level for pipeline: "
            f"{logging_level}. Must be one of {', '.join(all_levels)}."
        )
    logger.setLevel(logging_level)
    return logger.level

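# Editor's note: illustrative only, not in the original source. The precedence
# implemented above is: an explicit verbose flag overrides logging_level, and a
# bare call leaves the logger at its current (or INFO) level:
#
#     set_logging_level(None, None)      # keeps the current level, INFO if unset
#     set_logging_level('debug', None)   # -> DEBUG (case-insensitive)
#     set_logging_level('debug', False)  # -> ERROR: verbose=False wins
#     set_logging_level('bogus', None)   # raises ValueError
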
def process_pipeline_parameters(lang, model_dir, package, processors):
    # Check parameter types and convert values to lower case
    if isinstance(lang, str):
        lang = lang.strip().lower()
    elif lang is not None:
        raise TypeError(
            f"The parameter 'lang' should be str, "
            f"but got {type(lang).__name__} instead."
        )

    if isinstance(model_dir, str):
        model_dir = model_dir.strip()
    elif model_dir is not None:
        raise TypeError(
            f"The parameter 'model_dir' should be str, "
            f"but got {type(model_dir).__name__} instead."
        )

    if isinstance(processors, (str, list, tuple)):
        # Special case: processors is str, compatible with older version
        # also allow for setting alternate packages for these processors
        # via the package argument
        if package is None:
            # each processor will be 'default' for this language
            package = defaultdict(lambda: 'default')
        elif isinstance(package, str):
            # same, but now the named package will be the default instead
            default = package
            package = defaultdict(lambda: default)
        elif isinstance(package, dict):
            # the dictionary of packages will be used to build the processors dict
            # any processor not specified in package will be 'default'
            package = defaultdict(lambda: 'default', package)
        else:
            raise TypeError(
                f"The parameter 'package' should be None, str, or dict, "
                f"but got {type(package).__name__} instead."
            )
        if isinstance(processors, str):
            processors = [x.strip().lower() for x in processors.split(",")]
        processors = {
            processor: package[processor] for processor in processors
        }
        package = None
    elif isinstance(processors, dict):
        processors = {
            k.strip().lower(): ([v_i.strip().lower() for v_i in v] if isinstance(v, (tuple, list)) else v.strip().lower())
            for k, v in processors.items()
        }
    elif processors is not None:
        raise TypeError(
            f"The parameter 'processors' should be dict or str, "
            f"but got {type(processors).__name__} instead."
        )

    if isinstance(package, str):
        package = package.strip().lower()
    elif package is not None:
        raise TypeError(
            f"The parameter 'package' should be str, or a dict if 'processors' is a str, "
            f"but got {type(package).__name__} instead."
        )

    return lang, model_dir, package, processors

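# Editor's note: a hypothetical example, not in the original file, of the
# normalization performed above: the legacy string form of 'processors'
# combined with a per-processor 'package' dict:
#
#     process_pipeline_parameters(
#         'EN', '~/stanza_resources', {'ner': 'conll03'}, 'tokenize,pos,NER')
#     # -> ('en', '~/stanza_resources', None,
#     #     {'tokenize': 'default', 'pos': 'default', 'ner': 'conll03'})
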
def download_resources_json(model_dir=DEFAULT_MODEL_DIR,
                            resources_url=DEFAULT_RESOURCES_URL,
                            resources_branch=None,
                            resources_version=DEFAULT_RESOURCES_VERSION,
                            resources_filepath=None,
                            proxies=None):
    """
    Downloads resources.json to obtain latest packages.
    """
    if resources_url == DEFAULT_RESOURCES_URL and resources_branch is not None:
        resources_url = STANZA_RESOURCES_GITHUB + resources_branch
    # handle short name for resources urls; otherwise treat it as url
    if resources_url.lower() in ('stanford', 'stanfordnlp'):
        resources_url = STANFORDNLP_RESOURCES_URL
    resources_url = f'{resources_url}/resources_{resources_version}.json'
    logger.debug('Downloading resource file from %s', resources_url)
    if resources_filepath is None:
        resources_filepath = os.path.join(model_dir, 'resources.json')
    # make request
    request_file(
        resources_url,
        resources_filepath,
        proxies,
        raise_for_status=True
    )


def load_resources_json(model_dir=DEFAULT_MODEL_DIR, resources_filepath=None):
    """
    Unpack the resources json file from the given model_dir
    """
    if resources_filepath is None:
        resources_filepath = os.path.join(model_dir, 'resources.json')
    if not os.path.exists(resources_filepath):
        raise ResourcesFileNotFoundError(resources_filepath)
    with open(resources_filepath, encoding="utf-8") as fin:
        resources = json.load(fin)
    return resources

def get_language_resources(resources, lang):
    """
    Get the resources for a lang from an already loaded resources json, following 'alias' if needed
    """
    if lang not in resources:
        return None

    lang_resources = resources[lang]
    while 'alias' in lang_resources:
        lang = lang_resources['alias']
        lang_resources = resources[lang]

    return lang_resources

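# Editor's note: illustrative sketch, not in the original file. Alias entries
# in resources.json point at the canonical language code, possibly through
# several hops, e.g. with resources shaped like:
#
#     resources = {'german': {'alias': 'de'}, 'de': {'lang_name': 'German'}}
#     get_language_resources(resources, 'german')   # -> {'lang_name': 'German'}
#     get_language_resources(resources, 'xx')       # -> None
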
def list_available_languages(model_dir=DEFAULT_MODEL_DIR,
                             resources_url=DEFAULT_RESOURCES_URL,
                             resources_branch=None,
                             resources_version=DEFAULT_RESOURCES_VERSION,
                             proxies=None):
    """
    List the non-alias languages in the resources file
    """
    download_resources_json(model_dir, resources_url, resources_branch, resources_version, resources_filepath=None, proxies=proxies)
    resources = load_resources_json(model_dir)
    # isinstance(str) is because of fields such as "url"
    # 'alias' is because we want to skip German, alias of de, for example
    languages = [lang for lang in resources
                 if not isinstance(resources[lang], str) and 'alias' not in resources[lang]]
    languages = sorted(languages)
    return languages

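# Editor's note: usage sketch, not part of the original file. This refreshes
# resources.json and returns canonical codes only, so aliases such as "german"
# for "de" are filtered out:
#
#     import stanza.resources.common as common
#     langs = common.list_available_languages()
#     # e.g. ['af', 'ang', 'ar', 'be', ...]
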
def expand_model_url(resources, model_url):
    """
    Returns the url in the resources dict if model_url is default, or returns the model_url
    """
    return resources['url'] if model_url.lower() == 'default' else model_url

def download_models(download_list,
                    resources,
                    lang,
                    model_dir=DEFAULT_MODEL_DIR,
                    resources_version=DEFAULT_RESOURCES_VERSION,
                    model_url=DEFAULT_MODEL_URL,
                    proxies=None,
                    log_info=True):
    lang_name = resources.get(lang, {}).get('lang_name', lang)
    download_table = make_table(['Processor', 'Package'], download_list)
    if log_info:
        log_msg = logger.info
    else:
        log_msg = logger.debug
    log_msg(
        f'Downloading these customized packages for language: '
        f'{lang} ({lang_name})...\n{download_table}'
    )

    url = expand_model_url(resources, model_url)

    # Download packages
    for key, value in download_list:
        try:
            request_file(
                url.format(resources_version=resources_version, lang=lang, filename=f"{key}/{value}.pt"),
                os.path.join(model_dir, lang, key, f'{value}.pt'),
                proxies,
                md5=resources[lang][key][value]['md5'],
                log_info=log_info,
                alternate_md5=resources[lang][key][value].get('alternate_md5', None)
            )
        except KeyError as e:
            raise ValueError(
                f'Cannot find the following processor and model name combination: '
                f'{key}, {value}. Please check if you have provided the correct model name.'
            ) from e

# main download function
def download(
    lang='en',
    model_dir=DEFAULT_MODEL_DIR,
    package='default',
    processors={},
    logging_level=None,
    verbose=None,
    resources_url=DEFAULT_RESOURCES_URL,
    resources_branch=None,
    resources_version=DEFAULT_RESOURCES_VERSION,
    model_url=DEFAULT_MODEL_URL,
    proxies=None,
    download_json=True
):
    # set global logging level
    set_logging_level(logging_level, verbose)
    # process different pipeline parameters
    lang, model_dir, package, processors = process_pipeline_parameters(
        lang, model_dir, package, processors
    )

    if download_json or not os.path.exists(os.path.join(model_dir, 'resources.json')):
        if not download_json:
            logger.warning("Asked to skip downloading resources.json, but the file does not exist. Downloading anyway")
        download_resources_json(model_dir, resources_url, resources_branch, resources_version, resources_filepath=None, proxies=proxies)

    resources = load_resources_json(model_dir)
    if lang not in resources:
        raise UnknownLanguageError(lang)
    if 'alias' in resources[lang]:
        logger.info(f'"{lang}" is an alias for "{resources[lang]["alias"]}"')
        lang = resources[lang]['alias']
    lang_name = resources.get(lang, {}).get('lang_name', lang)
    url = expand_model_url(resources, model_url)

    # Default: download zipfile and unzip
    if package == 'default' and (processors is None or len(processors) == 0):
        logger.info(
            f'Downloading default packages for language: {lang} ({lang_name}) ...'
        )
        # want the URL to become, for example:
        # https://huggingface.co/stanfordnlp/stanza-af/resolve/v1.3.0/models/default.zip
        # so we hopefully start from
        # https://huggingface.co/stanfordnlp/stanza-{lang}/resolve/v{resources_version}/models/{filename}
        request_file(
            url.format(resources_version=resources_version, lang=lang, filename="default.zip"),
            os.path.join(model_dir, lang, 'default.zip'),
            proxies,
            md5=resources[lang]['default_md5'],
        )
        unzip(os.path.join(model_dir, lang), 'default.zip')
    # Customize: maintain download list
    else:
        download_list = maintain_processor_list(resources, lang, package, processors, allow_pretrain=True)
        download_list = add_dependencies(resources, lang, download_list)
        download_list = flatten_processor_list(download_list)
        download_models(download_list=download_list,
                        resources=resources,
                        lang=lang,
                        model_dir=model_dir,
                        resources_version=resources_version,
                        model_url=model_url,
                        proxies=proxies,
                        log_info=True)
    logger.info(f'Finished downloading models and saved to {model_dir}')

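# Editor's note: usage sketch, not part of the original file. `download` is
# re-exported at the package level, so the two branches above are reachable as:
#
#     import stanza
#     stanza.download('en')                         # default.zip bundle
#     stanza.download('en', processors='tokenize,ner',
#                     package={'ner': 'conll03'})   # per-model download list
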
stanza/stanza/resources/default_packages.py
ADDED
@@ -0,0 +1,909 @@
"""
|
| 2 |
+
Constants for default packages, default pretrains, charlms, etc
|
| 3 |
+
|
| 4 |
+
Separated from prepare_resources.py so that other modules can use the
|
| 5 |
+
same lists / maps without importing the resources script and possibly
|
| 6 |
+
causing a circular import
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import copy
|
| 10 |
+
|
| 11 |
+
# all languages will have a map which represents the available packages
|
| 12 |
+
PACKAGES = "packages"
|
| 13 |
+
|
| 14 |
+
# default treebank for languages
|
| 15 |
+
default_treebanks = {
|
| 16 |
+
"af": "afribooms",
|
| 17 |
+
# currently not publicly released! sent to us from the group developing this resource
|
| 18 |
+
"ang": "nerthus",
|
| 19 |
+
"ar": "padt",
|
| 20 |
+
"be": "hse",
|
| 21 |
+
"bg": "btb",
|
| 22 |
+
"bxr": "bdt",
|
| 23 |
+
"ca": "ancora",
|
| 24 |
+
"cop": "scriptorium",
|
| 25 |
+
"cs": "pdt",
|
| 26 |
+
"cu": "proiel",
|
| 27 |
+
"cy": "ccg",
|
| 28 |
+
"da": "ddt",
|
| 29 |
+
"de": "gsd",
|
| 30 |
+
"el": "gdt",
|
| 31 |
+
"en": "combined",
|
| 32 |
+
"es": "combined",
|
| 33 |
+
"et": "edt",
|
| 34 |
+
"eu": "bdt",
|
| 35 |
+
"fa": "perdt",
|
| 36 |
+
"fi": "tdt",
|
| 37 |
+
"fo": "farpahc",
|
| 38 |
+
"fr": "combined",
|
| 39 |
+
"fro": "profiterole",
|
| 40 |
+
"ga": "idt",
|
| 41 |
+
"gd": "arcosg",
|
| 42 |
+
"gl": "ctg",
|
| 43 |
+
"got": "proiel",
|
| 44 |
+
"grc": "perseus",
|
| 45 |
+
"gv": "cadhan",
|
| 46 |
+
"hbo": "ptnk",
|
| 47 |
+
"he": "combined",
|
| 48 |
+
"hi": "hdtb",
|
| 49 |
+
"hr": "set",
|
| 50 |
+
"hsb": "ufal",
|
| 51 |
+
"hu": "szeged",
|
| 52 |
+
"hy": "armtdp",
|
| 53 |
+
"hyw": "armtdp",
|
| 54 |
+
"id": "gsd",
|
| 55 |
+
"is": "icepahc",
|
| 56 |
+
"it": "combined",
|
| 57 |
+
"ja": "gsd",
|
| 58 |
+
"ka": "glc",
|
| 59 |
+
"kk": "ktb",
|
| 60 |
+
"kmr": "mg",
|
| 61 |
+
"ko": "kaist",
|
| 62 |
+
"kpv": "lattice",
|
| 63 |
+
"ky": "ktmu",
|
| 64 |
+
"la": "ittb",
|
| 65 |
+
"lij": "glt",
|
| 66 |
+
"lt": "alksnis",
|
| 67 |
+
"lv": "lvtb",
|
| 68 |
+
"lzh": "kyoto",
|
| 69 |
+
"mr": "ufal",
|
| 70 |
+
"mt": "mudt",
|
| 71 |
+
"my": "ucsy",
|
| 72 |
+
"myv": "jr",
|
| 73 |
+
"nb": "bokmaal",
|
| 74 |
+
"nds": "lsdc",
|
| 75 |
+
"nl": "alpino",
|
| 76 |
+
"nn": "nynorsk",
|
| 77 |
+
"olo": "kkpp",
|
| 78 |
+
"orv": "torot",
|
| 79 |
+
"ota": "boun",
|
| 80 |
+
"pcm": "nsc",
|
| 81 |
+
"pl": "pdb",
|
| 82 |
+
"pt": "bosque",
|
| 83 |
+
"qaf": "arabizi",
|
| 84 |
+
"qpm": "philotis",
|
| 85 |
+
"qtd": "sagt",
|
| 86 |
+
"ro": "rrt",
|
| 87 |
+
"ru": "syntagrus",
|
| 88 |
+
"sa": "vedic",
|
| 89 |
+
"sd": "isra",
|
| 90 |
+
"sk": "snk",
|
| 91 |
+
"sl": "ssj",
|
| 92 |
+
"sme": "giella",
|
| 93 |
+
"sq": "combined",
|
| 94 |
+
"sr": "set",
|
| 95 |
+
"sv": "talbanken",
|
| 96 |
+
"swl": "sslc",
|
| 97 |
+
"ta": "ttb",
|
| 98 |
+
"te": "mtg",
|
| 99 |
+
"th": "orchid",
|
| 100 |
+
"tr": "imst",
|
| 101 |
+
"ug": "udt",
|
| 102 |
+
"uk": "iu",
|
| 103 |
+
"ur": "udtb",
|
| 104 |
+
"vi": "vtb",
|
| 105 |
+
"wo": "wtb",
|
| 106 |
+
"xcl": "caval",
|
| 107 |
+
"zh-hans": "gsdsimp",
|
| 108 |
+
"zh-hant": "gsd",
|
| 109 |
+
"multilingual": "ud"
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
no_pretrain_languages = set([
|
| 113 |
+
"cop",
|
| 114 |
+
"olo",
|
| 115 |
+
"orv",
|
| 116 |
+
"pcm",
|
| 117 |
+
"qaf", # the QAF treebank is code switched and Romanized, so not easy to reuse existing resources
|
| 118 |
+
"qpm", # have talked about deriving this from a language neighborinig to Pomak, but that hasn't happened yet
|
| 119 |
+
"qtd",
|
| 120 |
+
"swl",
|
| 121 |
+
|
| 122 |
+
"multilingual", # special case so that all languages with a default treebank are represented somewhere
|
| 123 |
+
])
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# in some cases, we give the pretrain a name other than the original
|
| 127 |
+
# name for the UD dataset
|
| 128 |
+
# we will eventually do this for all of the pretrains
|
| 129 |
+
specific_default_pretrains = {
|
| 130 |
+
"af": "fasttextwiki",
|
| 131 |
+
"ang": "nerthus",
|
| 132 |
+
"ar": "conll17",
|
| 133 |
+
"be": "fasttextwiki",
|
| 134 |
+
"bg": "conll17",
|
| 135 |
+
"bxr": "fasttextwiki",
|
| 136 |
+
"ca": "conll17",
|
| 137 |
+
"cs": "conll17",
|
| 138 |
+
"cu": "conll17",
|
| 139 |
+
"cy": "fasttext157",
|
| 140 |
+
"da": "conll17",
|
| 141 |
+
"de": "conll17",
|
| 142 |
+
"el": "conll17",
|
| 143 |
+
"en": "conll17",
|
| 144 |
+
"es": "conll17",
|
| 145 |
+
"et": "conll17",
|
| 146 |
+
"eu": "conll17",
|
| 147 |
+
"fa": "conll17",
|
| 148 |
+
"fi": "conll17",
|
| 149 |
+
"fo": "fasttextwiki",
|
| 150 |
+
"fr": "conll17",
|
| 151 |
+
"fro": "conll17",
|
| 152 |
+
"ga": "conll17",
|
| 153 |
+
"gd": "fasttextwiki",
|
| 154 |
+
"gl": "conll17",
|
| 155 |
+
"got": "fasttextwiki",
|
| 156 |
+
"grc": "conll17",
|
| 157 |
+
"gv": "fasttext157",
|
| 158 |
+
"hbo": "utah",
|
| 159 |
+
"he": "conll17",
|
| 160 |
+
"hi": "conll17",
|
| 161 |
+
"hr": "conll17",
|
| 162 |
+
"hsb": "fasttextwiki",
|
| 163 |
+
"hu": "conll17",
|
| 164 |
+
"hy": "isprasglove",
|
| 165 |
+
"hyw": "isprasglove",
|
| 166 |
+
"id": "conll17",
|
| 167 |
+
"is": "fasttext157",
|
| 168 |
+
"it": "conll17",
|
| 169 |
+
"ja": "conll17",
|
| 170 |
+
"ka": "fasttext157",
|
| 171 |
+
"kk": "fasttext157",
|
| 172 |
+
"kmr": "fasttextwiki",
|
| 173 |
+
"ko": "conll17",
|
| 174 |
+
"kpv": "fasttextwiki",
|
| 175 |
+
"ky": "fasttext157",
|
| 176 |
+
"la": "conll17",
|
| 177 |
+
"lij": "fasttextwiki",
|
| 178 |
+
"lt": "fasttextwiki",
|
| 179 |
+
"lv": "conll17",
|
| 180 |
+
"lzh": "fasttextwiki",
|
| 181 |
+
"mr": "fasttextwiki",
|
| 182 |
+
"mt": "fasttextwiki",
|
| 183 |
+
"my": "ucsy",
|
| 184 |
+
"myv": "mokha",
|
| 185 |
+
"nb": "conll17",
|
| 186 |
+
"nds": "fasttext157",
|
| 187 |
+
"nl": "conll17",
|
| 188 |
+
"nn": "conll17",
|
| 189 |
+
"ota": "conll17",
|
| 190 |
+
"pl": "conll17",
|
| 191 |
+
"pt": "conll17",
|
| 192 |
+
"ro": "conll17",
|
| 193 |
+
"ru": "conll17",
|
| 194 |
+
"sa": "fasttext157",
|
| 195 |
+
"sd": "isra",
|
| 196 |
+
"sk": "conll17",
|
| 197 |
+
"sl": "conll17",
|
| 198 |
+
"sme": "fasttextwiki",
|
| 199 |
+
"sq": "fasttext157",
|
| 200 |
+
"sr": "fasttextwiki",
|
| 201 |
+
"sv": "conll17",
|
| 202 |
+
"ta": "fasttextwiki",
|
| 203 |
+
"te": "fasttextwiki",
|
| 204 |
+
"th": "fasttext157",
|
| 205 |
+
"tr": "conll17",
|
| 206 |
+
"ug": "conll17",
|
| 207 |
+
"uk": "conll17",
|
| 208 |
+
"ur": "conll17",
|
| 209 |
+
"vi": "conll17",
|
| 210 |
+
"wo": "fasttextwiki",
|
| 211 |
+
"xcl": "caval",
|
| 212 |
+
"zh-hans": "fasttext157",
|
| 213 |
+
"zh-hant": "conll17",
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def build_default_pretrains(default_treebanks):
|
| 218 |
+
default_pretrains = dict(default_treebanks)
|
| 219 |
+
for lang in no_pretrain_languages:
|
| 220 |
+
default_pretrains.pop(lang, None)
|
| 221 |
+
for lang in specific_default_pretrains.keys():
|
| 222 |
+
default_pretrains[lang] = specific_default_pretrains[lang]
|
| 223 |
+
return default_pretrains
|
| 224 |
+
|
| 225 |
+
default_pretrains = build_default_pretrains(default_treebanks)
|
| 226 |
+
|
| 227 |
+
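# Editor's note: illustrative check, not in the original file. The map built
# above starts from default_treebanks, drops the no-pretrain languages, then
# applies the explicit overrides:
#
#     default_pretrains['en']     # -> 'conll17' (override, not 'combined')
#     default_pretrains['hbo']    # -> 'utah'    (override, not 'ptnk')
#     'cop' in default_pretrains  # -> False     (listed in no_pretrain_languages)
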
pos_pretrains = {
    "en": {
        "craft": "biomed",
        "genia": "biomed",
        "mimic": "mimic",
    },
}

depparse_pretrains = pos_pretrains

ner_pretrains = {
    "ar": {
        "aqmar": "fasttextwiki",
    },
    "de": {
        "conll03": "fasttextwiki",
        # the bert version of germeval uses the smaller vector file
        "germeval2014": "fasttextwiki",
    },
    "en": {
        "anatem": "biomed",
        "bc4chemd": "biomed",
        "bc5cdr": "biomed",
        "bionlp13cg": "biomed",
        "jnlpba": "biomed",
        "linnaeus": "biomed",
        "ncbi_disease": "biomed",
        "s800": "biomed",

        "ontonotes": "fasttextcrawl",
        # the stanza-train sample NER model should use the default NER pretrain
        # for English, that is the same as ontonotes
        "sample": "fasttextcrawl",

        "conll03": "glove",

        "i2b2": "mimic",
        "radiology": "mimic",
    },
    "es": {
        "ancora": "fasttextwiki",
        "conll02": "fasttextwiki",
    },
    "nl": {
        "conll02": "fasttextwiki",
        "wikiner": "fasttextwiki",
    },
    "ru": {
        "wikiner": "fasttextwiki",
    },
    "th": {
        "lst20": "fasttext157",
    },
}


# default charlms for languages
default_charlms = {
    "af": "oscar",
    "ang": "nerthus1024",
    "ar": "ccwiki",
    "bg": "conll17",
    "da": "oscar",
    "de": "newswiki",
    "en": "1billion",
    "es": "newswiki",
    "fa": "conll17",
    "fi": "conll17",
    "fr": "newswiki",
    "he": "oscar",
    "hi": "oscar",
    "id": "oscar2023",
    "it": "conll17",
    "ja": "conll17",
    "kk": "oscar",
    "mr": "l3cube",
    "my": "oscar",
    "nb": "conll17",
    "nl": "ccwiki",
    "pl": "oscar",
    "pt": "oscar2023",
    "ru": "newswiki",
    "sd": "isra",
    "sv": "conll17",
    "te": "oscar2022",
    "th": "oscar",
    "tr": "conll17",
    "uk": "conll17",
    "vi": "conll17",
    "zh-hans": "gigaword"
}

pos_charlms = {
    "en": {
        # none of the English charlms help with craft or genia
        "craft": None,
        "genia": None,
        "mimic": "mimic",
    },
    "tr": {  # no idea why, but this particular one goes down in dev score
        "boun": None,
    },
}

depparse_charlms = copy.deepcopy(pos_charlms)

lemma_charlms = copy.deepcopy(pos_charlms)

ner_charlms = {
    "en": {
        "conll03": "1billion",
        "ontonotes": "1billion",
        "anatem": "pubmed",
        "bc4chemd": "pubmed",
        "bc5cdr": "pubmed",
        "bionlp13cg": "pubmed",
        "i2b2": "mimic",
        "jnlpba": "pubmed",
        "linnaeus": "pubmed",
        "ncbi_disease": "pubmed",
        "radiology": "mimic",
        "s800": "pubmed",
    },
    "hu": {
        "combined": None,
    },
    "nn": {
        "norne": None,
    },
}

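# Editor's note: a hypothetical helper, not in the original file, sketching how
# these per-task charlm maps are typically consulted: a task-specific entry
# (which may be an explicit None, meaning "no charlm") wins over the language
# default.
def lookup_charlm_sketch(task_charlms, lang, package):
    # explicit per-package setting, including None to disable the charlm
    if lang in task_charlms and package in task_charlms[lang]:
        return task_charlms[lang][package]
    # otherwise fall back to the language-wide default, if any
    return default_charlms.get(lang)
# e.g. lookup_charlm_sketch(pos_charlms, 'en', 'craft') -> None (explicitly off)
#      lookup_charlm_sketch(pos_charlms, 'en', 'ewt')   -> '1billion'
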
# default ner for languages
default_ners = {
    "af": "nchlt",
    "ar": "aqmar_charlm",
    "bg": "bsnlp19",
    "da": "ddt",
    "de": "germeval2014",
    "en": "ontonotes-ww-multi_charlm",
    "es": "conll02",
    "fa": "arman",
    "fi": "turku",
    "fr": "wikinergold_charlm",
    "he": "iahlt_charlm",
    "hu": "combined",
    "hy": "armtdp",
    "it": "fbk",
    "ja": "gsd",
    "kk": "kazNERD",
    "mr": "l3cube",
    "my": "ucsy",
    "nb": "norne",
    "nl": "conll02",
    "nn": "norne",
    "pl": "nkjp",
    "ru": "wikiner",
    "sd": "siner",
    "sv": "suc3shuffle",
    "th": "lst20",
    "tr": "starlang",
    "uk": "languk",
    "vi": "vlsp",
    "zh-hans": "ontonotes",
}

# a few languages have sentiment classifier models
default_sentiment = {
    "en": "sstplus_charlm",
    "de": "sb10k_charlm",
    "es": "tass2020_charlm",
    "mr": "l3cube_charlm",
    "vi": "vsfc_charlm",
    "zh-hans": "ren_charlm",
}

# also, a few languages (very few, currently) have constituency parser models
default_constituency = {
    "da": "arboretum_charlm",
    "de": "spmrl_charlm",
    "en": "ptb3-revised_charlm",
    "es": "combined_charlm",
    "id": "icon_charlm",
    "it": "vit_charlm",
    "ja": "alt_charlm",
    "pt": "cintil_charlm",
    #"tr": "starlang_charlm",
    "vi": "vlsp22_charlm",
    "zh-hans": "ctb-51_charlm",
}

optional_constituency = {
    "tr": "starlang_charlm",
}

# an alternate tokenizer for languages which aren't trained from a base UD source
default_tokenizer = {
    "my": "alt",
}

# ideally we would have a less expensive model as the base model
#default_coref = {
#    "en": "ontonotes_roberta-large_finetuned",
#}

optional_coref = {
    "ca": "udcoref_xlm-roberta-lora",
    "cs": "udcoref_xlm-roberta-lora",
    "de": "udcoref_xlm-roberta-lora",
    "en": "udcoref_xlm-roberta-lora",
    "es": "udcoref_xlm-roberta-lora",
    "fr": "udcoref_xlm-roberta-lora",
    "hi": "deeph_muril-large-cased-lora",
    # UD Coref has both nb and nn datasets for Norwegian
    "nb": "udcoref_xlm-roberta-lora",
    "nn": "udcoref_xlm-roberta-lora",
    "pl": "udcoref_xlm-roberta-lora",
    "ru": "udcoref_xlm-roberta-lora",
}

"""
default transformers to use for various languages

we try to document why we choose a particular model in each case
"""
TRANSFORMERS = {
    # We tested three candidate AR models on POS, Depparse, and NER
    #
    # POS: padt dev set scores, AllTags
    # depparse: padt dev set scores, LAS
    # NER: dev scores on a random split of AQMAR, entity scores
    #
    #                                              pos    depparse   ner
    #   none (pt & charlm only)                   94.08    83.49    84.19
    #   asafaya/bert-base-arabic                  95.10    84.96    85.98
    #   aubmindlab/bert-base-arabertv2            95.33    85.28    84.93
    #   aubmindlab/araelectra-base-discriminator  95.66    85.83    86.10
    "ar": "aubmindlab/araelectra-base-discriminator",

    # https://huggingface.co/Maltehb/danish-bert-botxo
    # contrary to normal expectations, this hurts F1
    # on a dev split by about 1 F1
    # "da": "Maltehb/danish-bert-botxo",
    #
    # the multilingual bert is a marginal improvement for conparse
    #
    # December 2022 update:
    # there are quite a few Danish transformers available on HuggingFace
    # here are the results of training a constituency parser with adadelta/adamw
    # on each of them:
    #
    #   no bert                              0.8245  0.8230
    #   alexanderfalk/danbert-small-cased    0.8236  0.8286
    #   Geotrend/distilbert-base-da-cased    0.8268  0.8306
    #   sarnikowski/convbert-small-da-cased  0.8322  0.8341
    #   bert-base-multilingual-cased         0.8341  0.8342
    #   vesteinn/ScandiBERT-no-faroese       0.8373  0.8408
    #   Maltehb/danish-bert-botxo            0.8383  0.8408
    #   vesteinn/ScandiBERT                  0.8421  0.8475
    #
    # Also, two models have token windows too short for use with the
    # Danish dataset:
    #   jonfd/electra-small-nordic
    #   Maltehb/aelaectra-danish-electra-small-cased
    #
    "da": "vesteinn/ScandiBERT",

    # As of April 2022, the bert models available have a weird
    # tokenizer issue where soft hyphen causes it to crash.
    # We attempt to compensate for that in the dev branch
    #
    # NER scores
    #   model                                          dev    test
    #   xlm-roberta-large                             86.56   85.23
    #   bert-base-german-cased                        87.59   86.95
    #   dbmdz/bert-base-german-cased                  88.27   87.47
    #   german-nlp-group/electra-base-german-uncased  88.60   87.09
    #
    # constituency scores w/ peft, March 2024 model, in-order
    #   model               dev     test
    #   xlm-roberta-base   95.17   93.34
    #   xlm-roberta-large  95.86   94.46  (!!!)
    #   bert-base          95.24   93.24
    #   dbmdz/bert         95.32   93.33
    #   german/electra     95.72   94.05
    #
    # POS scores
    #   model               dev     test
    #   None               88.65   87.28
    #   xlm-roberta-large  89.21   88.11
    #   bert-base          89.52   88.42
    #   dbmdz/bert         89.67   88.54
    #   german/electra     89.98   88.66
    #
    # depparse scores, LAS
    #   model               dev     test
    #   None               87.76   84.37
    #   xlm-roberta-large  89.00   85.79
    #   bert-base          88.72   85.40
    #   dbmdz/bert         88.70   85.14
    #   german/electra     89.21   86.06
    "de": "german-nlp-group/electra-base-german-uncased",

    # experiments on various forms of roberta & electra
    # https://huggingface.co/roberta-base
    # https://huggingface.co/roberta-large
    # https://huggingface.co/google/electra-small-discriminator
    # https://huggingface.co/google/electra-base-discriminator
    # https://huggingface.co/google/electra-large-discriminator
    #
    # experiments using the different models for POS tagging,
    # dev set, including WV and charlm, AllTags score:
    #   roberta-base:  95.67
    #   roberta-large: 95.98
    #   electra-small: 95.31
    #   electra-base:  95.90
    #   electra-large: 96.01
    #
    # depparse scores, dev set, no finetuning, with WV and charlm
    #                    UAS     LAS    CLAS    MLAS    BLEX
    #   roberta-base:   93.16   91.20   89.87   89.38   89.87
    #   roberta-large:  93.47   91.56   90.13   89.71   90.13
    #   electra-small:  92.17   90.02   88.25   87.66   88.25
    #   electra-base:   93.42   91.44   90.10   89.67   90.10
    #   electra-large:  94.07   92.17   90.99   90.53   90.99
    #
    # conparse scores, dev & test set, with WV and charlm
    #   roberta_base:   96.05   95.60
    #   roberta_large:  95.95   95.60
    #   electra-small:  95.33   95.04
    #   electra-base:   96.09   95.98
    #   electra-large:  96.25   96.14
    #
    # conparse scores w/ finetune, dev & test set, with WV and charlm
    #   roberta_base:   96.07   95.81
    #   roberta_large:  96.37   96.41  (!!!)
    #   electra-small:  95.62   95.36
    #   electra-base:   96.21   95.94
    #   electra-large:  96.40   96.32
    #
    "en": "google/electra-large-discriminator",

    # TODO need to test, possibly compare with others
    "es": "bertin-project/bertin-roberta-base-spanish",

    # NER scores for a couple Persian options:
    # none:
    #   dev:  2022-04-23 01:44:53 INFO: fa_arman 79.46
    #   test: 2022-04-23 01:45:03 INFO: fa_arman 80.06
    #
    # HooshvareLab/bert-fa-zwnj-base
    #   dev:  2022-04-23 02:43:44 INFO: fa_arman 80.87
    #   test: 2022-04-23 02:44:07 INFO: fa_arman 80.81
    #
    # HooshvareLab/roberta-fa-zwnj-base
    #   dev:  2022-04-23 16:23:25 INFO: fa_arman 81.23
    #   test: 2022-04-23 16:23:48 INFO: fa_arman 81.11
    #
    # HooshvareLab/bert-base-parsbert-uncased
    #   dev:  2022-04-26 10:42:09 INFO: fa_arman 82.49
    #   test: 2022-04-26 10:42:31 INFO: fa_arman 83.16
    "fa": 'HooshvareLab/bert-base-parsbert-uncased',

    # NER scores for a couple options:
    # none:
    #   dev:  2022-03-04 INFO: fi_turku 83.45
    #   test: 2022-03-04 INFO: fi_turku 86.25
    #
    # bert-base-multilingual-cased
    #   dev:  2022-03-04 INFO: fi_turku 85.23
    #   test: 2022-03-04 INFO: fi_turku 89.00
    #
    # TurkuNLP/bert-base-finnish-cased-v1:
    #   dev:  2022-03-04 INFO: fi_turku 88.41
    #   test: 2022-03-04 INFO: fi_turku 91.36
    "fi": "TurkuNLP/bert-base-finnish-cased-v1",

    # POS dev set tagging results for French:
    # No bert:
    #   98.60  100.00   98.55   98.04
    # dbmdz/electra-base-french-europeana-cased-discriminator
    #   98.70  100.00   98.69   98.24
    # benjamin/roberta-base-wechsel-french
    #   98.71  100.00   98.75   98.26
    # camembert/camembert-large
    #   98.75  100.00   98.75   98.30
    # camembert-base
    #   98.78  100.00   98.77   98.33
    #
    # GSD depparse dev set results for French:
    # No bert:
    #   95.83   94.52   91.34   91.10   91.34
    # camembert/camembert-large
    #   96.80   95.71   93.37   93.13   93.37
    # TODO: the rest of the chart
    "fr": "camembert/camembert-large",

    # Ancient Greek has a surprising number of transformers, considering
    #   Model             POS      Depparse LAS
    #   None              0.8812   0.7684
    #   Microbert M       0.8883   0.7706
    #   Microbert MX      0.8910   0.7755
    #   Microbert MXP     0.8916   0.7742
    #   Pranaydeeps Bert  0.9139   0.7987
    "grc": "pranaydeeps/Ancient-Greek-BERT",

    # a couple possibilities to experiment with for Hebrew
    # dev scores for POS and depparse
    # https://huggingface.co/imvladikon/alephbertgimmel-base-512
    #   UPOS    XPOS   UFeats  AllTags
    #   97.25   97.25   92.84   91.81
    #   UAS     LAS    CLAS    MLAS    BLEX
    #   94.42   92.47   89.49   88.82   89.49
    #
    # https://huggingface.co/onlplab/alephbert-base
    #   UPOS    XPOS   UFeats  AllTags
    #   97.37   97.37   92.50   91.55
    #   UAS     LAS    CLAS    MLAS    BLEX
    #   94.06   92.12   88.80   88.13   88.80
    #
    # https://huggingface.co/avichr/heBERT
    #   UPOS    XPOS   UFeats  AllTags
    #   97.09   97.09   92.36   91.28
    #   UAS     LAS    CLAS    MLAS    BLEX
    #   94.29   92.30   88.99   88.38   88.99
    "he": "imvladikon/alephbertgimmel-base-512",

    # can also experiment with xlm-roberta
    # on a coref dataset from IITH, span F1:
    #                        dev       test
    #   xlm-roberta-large  0.63635   0.66579
    #   muril-large        0.65369   0.68290
    "hi": "google/muril-large-cased",

    # https://huggingface.co/xlm-roberta-base
    # Scores by entity for armtdp NER on 18 labels:
    #   no bert          : 86.68
    #   xlm-roberta-base : 89.31
    "hy": "xlm-roberta-base",

    # Indonesian POS experiments: dev set of GSD
    # python3 stanza/utils/training/run_pos.py id_gsd --no_bert
    # python3 stanza/utils/training/run_pos.py id_gsd --bert_model ...
    # also ran on the ICON constituency dataset
    #   model                                     POS     CON
    #   no_bert                                  89.95   84.74
    #   flax-community/indonesian-roberta-large  89.78 (!) xxx
    #   flax-community/indonesian-roberta-base   90.14   xxx
    #   indobenchmark/indobert-base-p2           90.09
    #   indobenchmark/indobert-base-p1           90.14
    #   indobenchmark/indobert-large-p1          90.19
    #   indolem/indobert-base-uncased            90.21   88.60
    #   cahya/bert-base-indonesian-1.5G          90.32   88.15
    #   cahya/roberta-base-indonesian-1.5G       90.40   87.27
    "id": "indolem/indobert-base-uncased",

    # from https://github.com/idb-ita/GilBERTo
    # annoyingly, it doesn't handle cased text
    # supposedly there is an argument "do_lower_case"
    # but that still leaves a lot of unk tokens
    # "it": "idb-ita/gilberto-uncased-from-camembert",
    #
    # from https://github.com/musixmatchresearch/umberto
    # on NER, this gets 88.37 dev and 91.02 test
    # another option is dbmdz/bert-base-italian-cased,
    # which gets 87.27 dev and 90.32 test
    #
    # in-order constituency parser on the VIT dev set:
    #   dbmdz/bert-base-italian-cased                         0.8079
    #   dbmdz/bert-base-italian-xxl-cased:                    0.8195
    #   Musixmatch/umberto-commoncrawl-cased-v1:              0.8256
    #   dbmdz/electra-base-italian-xxl-cased-discriminator:   0.8314
    #
    # FBK NER dev set:
    #   dbmdz/bert-base-italian-cased:                        87.76
    #   Musixmatch/umberto-commoncrawl-cased-v1:              88.62
    #   dbmdz/bert-base-italian-xxl-cased:                    88.84
    #   dbmdz/electra-base-italian-xxl-cased-discriminator:   89.91
    #
    # combined UD POS dev set:                                UPOS   XPOS  UFeats AllTags
    #   dbmdz/bert-base-italian-cased:                        98.62  98.53  98.06  97.49
    #   dbmdz/bert-base-italian-xxl-cased:                    98.61  98.54  98.07  97.58
    #   dbmdz/electra-base-italian-xxl-cased-discriminator:   98.64  98.54  98.14  97.61
    #   Musixmatch/umberto-commoncrawl-cased-v1:              98.56  98.45  98.13  97.62
    "it": "dbmdz/electra-base-italian-xxl-cased-discriminator",

    # for Japanese
    # there are others that would also work,
    # but they require different tokenizers instead of being
    # plug & play
    #
    # Constituency scores on ALT (in-order)
    #   no bert: 90.68 dev, 91.40 test
    #   rinna:   91.54 dev, 91.89 test
    "ja": "rinna/japanese-roberta-base",

    # could also try:
    # l3cube-pune/marathi-bert-v2
    # or
    # https://huggingface.co/l3cube-pune/hindi-marathi-dev-roberta
    # l3cube-pune/hindi-marathi-dev-roberta
    #
    # depparse ufal dev scores:
    #   no transformer               74.89  63.70  57.43  53.01  57.43
    #   l3cube-pune/marathi-roberta  76.48  66.21  61.20  57.60  61.20
    "mr": "l3cube-pune/marathi-roberta",

    # https://huggingface.co/allegro/herbert-base-cased
    # Scores by entity on the NKJP NER task:
    #   no bert (dev/test):                         88.64/88.75
    #   herbert-base-cased (dev/test):              91.48/91.02
    #   herbert-large-cased (dev/test):             92.25/91.62
    #   sdadas/polish-roberta-large-v2 (dev/test):  92.66/91.22
    "pl": "allegro/herbert-base-cased",

    # experiments on the cintil conparse dataset
    # ran a variety of transformer settings
    # found the following dev set scores after 400 iterations:
    #   Geotrend/distilbert-base-pt-cased : not plug & play
    #   no bert:                                      0.9082
    #   xlm-roberta-base:                             0.9109
    #   xlm-roberta-large:                            0.9254
    #   adalbertojunior/distilbert-portuguese-cased:  0.9300
    #   neuralmind/bert-base-portuguese-cased:        0.9307
    #   neuralmind/bert-large-portuguese-cased:       0.9343
    "pt": "neuralmind/bert-large-portuguese-cased",

    # hope is actually to build our own using a large text collection
    "sd": "google/muril-large-cased",

    # Tamil options: quite a few, need to run a bunch of experiments
    #                                  dev pos   dev depparse las
    #   no transformer                  82.82       69.12
    #   ai4bharat/indic-bert            82.98       70.47
    #   lgessler/microbert-tamil-mxp    83.21       69.28
    #   monsoon-nlp/tamillion           83.37       69.28
    #   l3cube-pune/tamil-bert          85.27       72.53
    #   d42kw01f/Tamil-RoBERTa          85.59       70.55
    #   google/muril-base-cased         85.67       72.68
    #   google/muril-large-cased        86.30       72.45
    "ta": "google/muril-large-cased",


    # https://huggingface.co/dbmdz/bert-base-turkish-128k-cased
    # helps the Turkish model quite a bit
    "tr": "dbmdz/bert-base-turkish-128k-cased",

    # from https://github.com/VinAIResearch/PhoBERT
    # "vi": "vinai/phobert-base",
    # using 6 or 7 layers of phobert-large is slightly
    # more effective for constituency parsing than
    # using 4 layers of phobert-base
    # ... going beyond 4 layers of phobert-base
    # does not help the scores
    "vi": "vinai/phobert-large",

    # https://github.com/ymcui/Chinese-BERT-wwm
    # there's also hfl/chinese-roberta-wwm-ext-large
    # or hfl/chinese-electra-base-discriminator
    # or hfl/chinese-electra-180g-large-discriminator,
    # which works better than the below roberta on constituency
    # "zh-hans": "hfl/chinese-roberta-wwm-ext",
    "zh-hans": "hfl/chinese-electra-180g-large-discriminator",
}

TRANSFORMER_LAYERS = {
    # not clear what the best number is without more experiments,
    # but more than 4 is working better than just 4
    "vi": 7,
}

TRANSFORMER_NICKNAMES = {
    # ar
    "asafaya/bert-base-arabic": "asafaya-bert",
    "aubmindlab/araelectra-base-discriminator": "aubmind-electra",
    "aubmindlab/bert-base-arabertv2": "aubmind-bert",

    # da
    "vesteinn/ScandiBERT": "scandibert",

    # de
    "bert-base-german-cased": "bert-base-german-cased",
    "dbmdz/bert-base-german-cased": "dbmdz-bert-german-cased",
    "german-nlp-group/electra-base-german-uncased": "german-nlp-electra",

    # en
    "bert-base-multilingual-cased": "mbert",
    "xlm-roberta-large": "xlm-roberta-large",
    "google/electra-large-discriminator": "electra-large",
    "microsoft/deberta-v3-large": "deberta-v3-large",

    # es
    "bertin-project/bertin-roberta-base-spanish": "bertin-roberta",

    # fa
    "HooshvareLab/bert-base-parsbert-uncased": "parsbert",

    # fi
    "TurkuNLP/bert-base-finnish-cased-v1": "bert",

    # fr
    "benjamin/roberta-base-wechsel-french": "wechsel-roberta",
    "camembert-base": "camembert-base",
    "camembert/camembert-large": "camembert-large",
    "dbmdz/electra-base-french-europeana-cased-discriminator": "dbmdz-electra",

    # grc
    "pranaydeeps/Ancient-Greek-BERT": "grc-pranaydeeps",
    "lgessler/microbert-ancient-greek-m": "grc-microbert-m",
    "lgessler/microbert-ancient-greek-mx": "grc-microbert-mx",
    "lgessler/microbert-ancient-greek-mxp": "grc-microbert-mxp",
    "altsoph/bert-base-ancientgreek-uncased": "grc-altsoph",

    # he
    "imvladikon/alephbertgimmel-base-512": "alephbertgimmel",

    # hy
    "xlm-roberta-base": "xlm-roberta-base",

    # id
    "indolem/indobert-base-uncased": "indobert",
    "indobenchmark/indobert-large-p1": "indobenchmark-large-p1",
    "indobenchmark/indobert-base-p1": "indobenchmark-base-p1",
    "indobenchmark/indobert-lite-large-p1": "indobenchmark-lite-large-p1",
    "indobenchmark/indobert-lite-base-p1": "indobenchmark-lite-base-p1",
    "indobenchmark/indobert-large-p2": "indobenchmark-large-p2",
    "indobenchmark/indobert-base-p2": "indobenchmark-base-p2",
    "indobenchmark/indobert-lite-large-p2": "indobenchmark-lite-large-p2",
    "indobenchmark/indobert-lite-base-p2": "indobenchmark-lite-base-p2",

    # it
    "dbmdz/electra-base-italian-xxl-cased-discriminator": "electra",

    # ja
    "rinna/japanese-roberta-base": "rinna-roberta",

    # mr
    "l3cube-pune/marathi-roberta": "l3cube-marathi-roberta",

    # pl
    "allegro/herbert-base-cased": "herbert",

    # pt
    "neuralmind/bert-large-portuguese-cased": "bertimbau",

    # ta: tamil
    "monsoon-nlp/tamillion": "tamillion",
    "lgessler/microbert-tamil-m": "ta-microbert-m",
    "lgessler/microbert-tamil-mxp": "ta-microbert-mxp",
    "l3cube-pune/tamil-bert": "l3cube-tamil-bert",
    "d42kw01f/Tamil-RoBERTa": "ta-d42kw01f-roberta",

    # tr
    "dbmdz/bert-base-turkish-128k-cased": "bert",

    # vi
    "vinai/phobert-base": "phobert-base",
    "vinai/phobert-large": "phobert-large",

    # zh
    "google-bert/bert-base-chinese": "google-bert-chinese",
    "hfl/chinese-roberta-wwm-ext": "hfl-roberta-chinese",
    "hfl/chinese-electra-180g-large-discriminator": "electra-large",

    # multi-lingual Indic
    "ai4bharat/indic-bert": "indic-bert",
    "google/muril-base-cased": "muril-base-cased",
    "google/muril-large-cased": "muril-large-cased",
}

def known_nicknames():
    """
    Return a list of all the transformer nicknames

    We return a list so that we can sort them in decreasing key length
    """
    nicknames = list(value for key, value in TRANSFORMER_NICKNAMES.items())

    # previously unspecific transformers get "transformer" as the nickname
    nicknames.append("transformer")

    nicknames = sorted(nicknames, key=lambda x: -len(x))

    return nicknames

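# Editor's note: illustrative only, not in the original file. Sorting by
# decreasing length lets a caller match the longest nickname first when parsing
# a model filename that embeds one:
#
#     nicknames = known_nicknames()
#     # 'electra-large' sorts before 'electra', so a filename such as
#     # 'en_ontonotes_electra-large.pt' is matched to the longer nickname first
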
stanza/stanza/resources/installation.py
ADDED
@@ -0,0 +1,148 @@
"""
|
| 2 |
+
Functions for setting up the environments.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import logging
|
| 7 |
+
import zipfile
|
| 8 |
+
import shutil
|
| 9 |
+
|
| 10 |
+
from stanza.resources.common import HOME_DIR, request_file, unzip, \
|
| 11 |
+
get_root_from_zipfile, set_logging_level
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger('stanza')
|
| 14 |
+
|
| 15 |
+
DEFAULT_CORENLP_MODEL_URL = os.getenv(
|
| 16 |
+
'CORENLP_MODEL_URL',
|
| 17 |
+
'https://huggingface.co/stanfordnlp/corenlp-{model}/resolve/{tag}/stanford-corenlp-models-{model}.jar'
|
| 18 |
+
)
|
| 19 |
+
BACKUP_CORENLP_MODEL_URL = "http://nlp.stanford.edu/software/stanford-corenlp-{version}-models-{model}.jar"
|
| 20 |
+
|
| 21 |
+
DEFAULT_CORENLP_URL = os.getenv(
|
| 22 |
+
'CORENLP_MODEL_URL',
|
| 23 |
+
'https://huggingface.co/stanfordnlp/CoreNLP/resolve/{tag}/stanford-corenlp-latest.zip'
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
DEFAULT_CORENLP_DIR = os.getenv(
|
| 27 |
+
'CORENLP_HOME',
|
| 28 |
+
os.path.join(HOME_DIR, 'stanza_corenlp')
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
AVAILABLE_MODELS = set(['arabic', 'chinese', 'english-extra', 'english-kbp', 'french', 'german', 'hungarian', 'italian', 'spanish'])
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def download_corenlp_models(model, version, dir=DEFAULT_CORENLP_DIR, url=DEFAULT_CORENLP_MODEL_URL, logging_level='INFO', proxies=None, force=True):
|
| 35 |
+
"""
|
| 36 |
+
A automatic way to download the CoreNLP models.
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
model: the name of the model, can be one of 'arabic', 'chinese', 'english',
|
| 40 |
+
'english-kbp', 'french', 'german', 'hungarian', 'italian', 'spanish'
|
| 41 |
+
version: the version of the model
|
| 42 |
+
dir: the directory to download CoreNLP model into; alternatively can be
|
| 43 |
+
set up with environment variable $CORENLP_HOME
|
| 44 |
+
url: The link to download CoreNLP models.
|
| 45 |
+
It will need {model} and either {version} or {tag} to properly format the URL
|
| 46 |
+
logging_level: logging level to use during installation
|
| 47 |
+
force: Download model anyway, no matter model file exists or not
|
| 48 |
+
"""
|
| 49 |
+
dir = os.path.expanduser(dir)
|
| 50 |
+
if not model or not version:
|
| 51 |
+
raise ValueError(
|
| 52 |
+
"Both model and model version should be specified."
|
| 53 |
+
)
|
| 54 |
+
logger.info(f"Downloading {model} models (version {version}) into directory {dir}")
|
| 55 |
+
model = model.strip().lower()
|
| 56 |
+
if model not in AVAILABLE_MODELS:
|
| 57 |
+
raise KeyError(
|
| 58 |
+
f'{model} is currently not supported. '
|
| 59 |
+
f'Must be one of: {list(AVAILABLE_MODELS)}.'
|
| 60 |
+
)
|
| 61 |
+
# for example:
|
| 62 |
+
# https://huggingface.co/stanfordnlp/CoreNLP/resolve/v4.2.2/stanford-corenlp-models-french.jar
|
| 63 |
+
tag = version if version == 'main' else 'v' + version
|
| 64 |
+
download_url = url.format(tag=tag, model=model, version=version)
|
| 65 |
+
model_path = os.path.join(dir, f'stanford-corenlp-{version}-models-{model}.jar')
|
| 66 |
+
|
| 67 |
+
if os.path.exists(model_path) and not force:
|
| 68 |
+
logger.warn(
|
| 69 |
+
f"Model file {model_path} already exists. "
|
| 70 |
+
f"Please download this model to a new directory.")
|
| 71 |
+
return
|
| 72 |
+
|
| 73 |
+
try:
|
| 74 |
+
request_file(
|
| 75 |
+
download_url,
|
| 76 |
+
model_path,
|
| 77 |
+
proxies
|
| 78 |
+
)
|
| 79 |
+
except (KeyboardInterrupt, SystemExit):
|
| 80 |
+
raise
|
| 81 |
+
except Exception as e:
|
| 82 |
+
raise RuntimeError(
|
| 83 |
+
"Downloading CoreNLP model file failed. "
|
| 84 |
+
"Please try manual downloading at: https://stanfordnlp.github.io/CoreNLP/."
|
| 85 |
+
) from e
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def install_corenlp(dir=DEFAULT_CORENLP_DIR, url=DEFAULT_CORENLP_URL, logging_level=None, proxies=None, version="main"):
|
| 89 |
+
"""
|
| 90 |
+
A fully automatic way to install and setting up the CoreNLP library
|
| 91 |
+
to use the client functionality.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
dir: the directory to download CoreNLP model into; alternatively can be
|
| 95 |
+
set up with environment variable $CORENLP_HOME
|
| 96 |
+
url: The link to download CoreNLP models
|
| 97 |
+
Needs a {version} or {tag} parameter to specify the version
|
| 98 |
+
logging_level: logging level to use during installation
|
| 99 |
+
"""
|
| 100 |
+
dir = os.path.expanduser(dir)
|
| 101 |
+
set_logging_level(logging_level=logging_level, verbose=None)
|
| 102 |
+
if os.path.exists(dir) and len(os.listdir(dir)) > 0:
|
| 103 |
+
logger.warn(
|
| 104 |
+
f"Directory {dir} already exists. "
|
| 105 |
+
f"Please install CoreNLP to a new directory.")
|
| 106 |
+
return
|
| 107 |
+
|
| 108 |
+
logger.info(f"Installing CoreNLP package into {dir}")
|
| 109 |
+
# First download the URL package
|
| 110 |
+
logger.debug(f"Download to destination file: {os.path.join(dir, 'corenlp.zip')}")
|
| 111 |
+
tag = version if version == 'main' else 'v' + version
|
| 112 |
+
url = url.format(version=version, tag=tag)
|
| 113 |
+
try:
|
| 114 |
+
request_file(url, os.path.join(dir, 'corenlp.zip'), proxies)
|
| 115 |
+
|
| 116 |
+
except (KeyboardInterrupt, SystemExit):
|
| 117 |
+
raise
|
| 118 |
+
except Exception as e:
|
| 119 |
+
raise RuntimeError(
|
| 120 |
+
"Downloading CoreNLP zip file failed. "
|
| 121 |
+
"Please try manual installation: https://stanfordnlp.github.io/CoreNLP/."
|
| 122 |
+
) from e
|
| 123 |
+
|
| 124 |
+
# Unzip corenlp into dir
|
| 125 |
+
logger.debug("Unzipping downloaded zip file...")
|
| 126 |
+
unzip(dir, 'corenlp.zip')
|
| 127 |
+
|
| 128 |
+
# By default CoreNLP will be unzipped into a version-dependent folder,
|
| 129 |
+
# e.g., stanford-corenlp-4.0.0. We need some hack around that and move
|
| 130 |
+
# files back into our designated folder
|
| 131 |
+
logger.debug(f"Moving files into the designated folder at: {dir}")
|
| 132 |
+
corenlp_dirname = get_root_from_zipfile(os.path.join(dir, 'corenlp.zip'))
|
| 133 |
+
corenlp_dirname = os.path.join(dir, corenlp_dirname)
|
| 134 |
+
for f in os.listdir(corenlp_dirname):
|
| 135 |
+
shutil.move(os.path.join(corenlp_dirname, f), dir)
|
| 136 |
+
|
| 137 |
+
# Remove original zip and folder
|
| 138 |
+
logger.debug("Removing downloaded zip file...")
|
| 139 |
+
os.remove(os.path.join(dir, 'corenlp.zip'))
|
| 140 |
+
shutil.rmtree(corenlp_dirname)
|
| 141 |
+
|
| 142 |
+
# Warn user to set up env
|
| 143 |
+
if dir != DEFAULT_CORENLP_DIR:
|
| 144 |
+
logger.warning(
|
| 145 |
+
f"For customized installation location, please set the `CORENLP_HOME` "
|
| 146 |
+
f"environment variable to the location of the installation. "
|
| 147 |
+
f"In Unix, this is done with `export CORENLP_HOME={dir}`.")
|
| 148 |
+
|
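The two entry points above are enough to bootstrap a working CoreNLP setup. A minimal usage sketch (not part of the diff; the version string "4.5.6" and the target directory are illustrative values only):

# Hypothetical end-to-end setup; "4.5.6" and "~/corenlp" are example values,
# not prescribed by this module.
from stanza.resources.installation import install_corenlp, download_corenlp_models

install_corenlp(dir="~/corenlp", version="4.5.6")
download_corenlp_models(model="french", version="4.5.6", dir="~/corenlp")
# Afterwards, point CORENLP_HOME at the installation directory so the
# client code can locate the jars.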
stanza/stanza/resources/prepare_resources.py
ADDED
@@ -0,0 +1,670 @@
"""
Converts a directory of models organized by type into a directory organized by language.

Also produces the resources.json file.

For example, on the cluster, you can do this:

python3 -m stanza.resources.prepare_resources --input_dir /u/nlp/software/stanza/models/current-models-1.5.0 --output_dir /u/nlp/software/stanza/models/1.5.0 > resources.out 2>&1
nlprun -a stanza-1.2 -q john "python3 -m stanza.resources.prepare_resources --input_dir /u/nlp/software/stanza/models/current-models-1.5.0 --output_dir /u/nlp/software/stanza/models/1.5.0" -o resources.out
"""

import argparse
from collections import defaultdict
import json
import os
from pathlib import Path
import hashlib
import shutil
import zipfile

from stanza import __resources_version__
from stanza.models.common.constant import lcode2lang, two_to_three_letters, three_to_two_letters
from stanza.resources.default_packages import PACKAGES, TRANSFORMERS, TRANSFORMER_NICKNAMES
from stanza.resources.default_packages import *
from stanza.utils.datasets.prepare_lemma_classifier import DATASET_MAPPING as LEMMA_CLASSIFIER_DATASETS
from stanza.utils.get_tqdm import get_tqdm

tqdm = get_tqdm()

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_dir', type=str, default="/u/nlp/software/stanza/models/current-models-%s" % __resources_version__, help='Input dir for various models. Defaults to the recommended home on the nlp cluster')
    parser.add_argument('--output_dir', type=str, default="/u/nlp/software/stanza/models/%s" % __resources_version__, help='Output dir for various models.')
    parser.add_argument('--packages_only', action='store_true', default=False, help='Only build the package maps instead of rebuilding everything')
    parser.add_argument('--lang', type=str, default=None, help='Only process this language or a comma-separated list of languages. If left blank, will prepare all languages. To use this argument, a previously prepared resources file with all of the languages is necessary.')
    args = parser.parse_args()
    args.input_dir = os.path.abspath(args.input_dir)
    args.output_dir = os.path.abspath(args.output_dir)
    if args.lang is not None:
        args.lang = ",".join(args.lang.strip().split())
    return args


allowed_empty_languages = [
    # we don't have a lot of Thai support yet
    "th",
    # only tokenize and NER for Myanmar right now (soon...)
    "my",
]

# map processor name to file ending
# the order of this dict determines the order in which default.zip files are built
# changing it will necessitate rebuilding all of the default.zip files
# not a disaster, but it would involve a bunch of uploading
processor_to_ending = {
    "tokenize": "tokenizer",
    "mwt": "mwt_expander",
    "lemma": "lemmatizer",
    "pos": "tagger",
    "depparse": "parser",
    "pretrain": "pretrain",
    "ner": "nertagger",
    "forward_charlm": "forward_charlm",
    "backward_charlm": "backward_charlm",
    "sentiment": "sentiment",
    "constituency": "constituency",
    "coref": "coref",
    "langid": "langid",
}
ending_to_processor = {j: i for i, j in processor_to_ending.items()}
PROCESSORS = list(processor_to_ending.keys())

def ensure_dir(dir):
    Path(dir).mkdir(parents=True, exist_ok=True)


def copy_file(src, dst):
    ensure_dir(Path(dst).parent)
    shutil.copy2(src, dst)


def get_md5(path):
    data = open(path, 'rb').read()
    return hashlib.md5(data).hexdigest()


def split_model_name(model):
    """
    Split model names by _

    Takes into account packages with _ and processor types with _
    """
    model = model[:-3].replace('.', '_')
    # sort by key length so that nertagger is checked before tagger, for example
    for processor in sorted(ending_to_processor.keys(), key=lambda x: -len(x)):
        if model.endswith(processor):
            model = model[:-(len(processor)+1)]
            processor = ending_to_processor[processor]
            break
    else:
        raise AssertionError(f"Could not find a processor type in {model}")
    lang, package = model.split('_', 1)
    return lang, package, processor

def split_package(package):
    if package.endswith("_finetuned"):
        package = package[:-10]

    if package.endswith("_nopretrain"):
        package = package[:-11]
        return package, False, False
    if package.endswith("_nocharlm"):
        package = package[:-9]
        return package, True, False
    if package.endswith("_charlm"):
        package = package[:-7]
        return package, True, True
    underscore = package.rfind("_")
    if underscore >= 0:
        # +1 to skip the underscore
        nickname = package[underscore+1:]
        if nickname in known_nicknames():
            return package[:underscore], True, True

    # guess it was a model which wasn't built with the new naming convention of putting the pretrain type at the end
    # assume WV and charlm... if the language / package doesn't allow for one, that should be caught later
    return package, True, True

def get_pretrain_package(lang, package, model_pretrains, default_pretrains):
    package, uses_pretrain, _ = split_package(package)

    if not uses_pretrain or lang in no_pretrain_languages:
        return None
    elif model_pretrains is not None and lang in model_pretrains and package in model_pretrains[lang]:
        return model_pretrains[lang][package]
    elif lang in default_pretrains:
        return default_pretrains[lang]

    raise RuntimeError("pretrain not specified for lang %s package %s" % (lang, package))

def get_charlm_package(lang, package, model_charlms, default_charlms):
    package, _, uses_charlm = split_package(package)

    if not uses_charlm:
        return None

    if model_charlms is not None and lang in model_charlms and package in model_charlms[lang]:
        return model_charlms[lang][package]
    else:
        return default_charlms.get(lang, None)

def get_con_dependencies(lang, package):
    # so far, this invariant is true:
    # constituency models use the default pretrain and charlm for the language
    # sometimes there is no charlm for a language that has constituency, though
    pretrain_package = get_pretrain_package(lang, package, None, default_pretrains)
    dependencies = [{'model': 'pretrain', 'package': pretrain_package}]

    charlm_package = default_charlms.get(lang, None)
    if charlm_package is not None:
        dependencies.append({'model': 'forward_charlm', 'package': charlm_package})
        dependencies.append({'model': 'backward_charlm', 'package': charlm_package})

    return dependencies

def get_pos_charlm_package(lang, package):
    return get_charlm_package(lang, package, pos_charlms, default_charlms)

def get_pos_dependencies(lang, package):
    dependencies = []

    pretrain_package = get_pretrain_package(lang, package, pos_pretrains, default_pretrains)
    if pretrain_package is not None:
        dependencies.append({'model': 'pretrain', 'package': pretrain_package})

    charlm_package = get_pos_charlm_package(lang, package)
    if charlm_package is not None:
        dependencies.append({'model': 'forward_charlm', 'package': charlm_package})
        dependencies.append({'model': 'backward_charlm', 'package': charlm_package})

    return dependencies

def get_lemma_pretrain_package(lang, package):
    package, uses_pretrain, uses_charlm = split_package(package)
    if not uses_pretrain:
        return None
    if not uses_charlm:
        # currently the contextual lemma classifier is only active
        # for the charlm lemmatizers
        return None
    if "%s_%s" % (lang, package) not in LEMMA_CLASSIFIER_DATASETS:
        return None
    return get_pretrain_package(lang, package, {}, default_pretrains)

def get_lemma_charlm_package(lang, package):
    return get_charlm_package(lang, package, lemma_charlms, default_charlms)

def get_lemma_dependencies(lang, package):
    dependencies = []

    pretrain_package = get_lemma_pretrain_package(lang, package)
    if pretrain_package is not None:
        dependencies.append({'model': 'pretrain', 'package': pretrain_package})

    charlm_package = get_lemma_charlm_package(lang, package)
    if charlm_package is not None:
        dependencies.append({'model': 'forward_charlm', 'package': charlm_package})
        dependencies.append({'model': 'backward_charlm', 'package': charlm_package})

    return dependencies


def get_depparse_charlm_package(lang, package):
    return get_charlm_package(lang, package, depparse_charlms, default_charlms)

def get_depparse_dependencies(lang, package):
    dependencies = []

    pretrain_package = get_pretrain_package(lang, package, depparse_pretrains, default_pretrains)
    if pretrain_package is not None:
        dependencies.append({'model': 'pretrain', 'package': pretrain_package})

    charlm_package = get_depparse_charlm_package(lang, package)
    if charlm_package is not None:
        dependencies.append({'model': 'forward_charlm', 'package': charlm_package})
        dependencies.append({'model': 'backward_charlm', 'package': charlm_package})

    return dependencies

def get_ner_charlm_package(lang, package):
    return get_charlm_package(lang, package, ner_charlms, default_charlms)

def get_ner_pretrain_package(lang, package):
    return get_pretrain_package(lang, package, ner_pretrains, default_pretrains)

def get_ner_dependencies(lang, package):
    dependencies = []

    pretrain_package = get_ner_pretrain_package(lang, package)
    if pretrain_package is not None:
        dependencies.append({'model': 'pretrain', 'package': pretrain_package})

    charlm_package = get_ner_charlm_package(lang, package)
    if charlm_package is not None:
        dependencies.append({'model': 'forward_charlm', 'package': charlm_package})
        dependencies.append({'model': 'backward_charlm', 'package': charlm_package})

    return dependencies

def get_sentiment_dependencies(lang, package):
    """
    Return a list of dependencies for the sentiment model

    Generally this will be pretrain, forward & backward charlm
    So far, this invariant is true:
      sentiment models use the default pretrain for the language
      also, they all use the default charlm for a language
    """
    pretrain_package = get_pretrain_package(lang, package, None, default_pretrains)
    dependencies = [{'model': 'pretrain', 'package': pretrain_package}]

    charlm_package = default_charlms.get(lang, None)
    if charlm_package is not None:
        dependencies.append({'model': 'forward_charlm', 'package': charlm_package})
        dependencies.append({'model': 'backward_charlm', 'package': charlm_package})

    return dependencies

def get_dependencies(processor, lang, package):
    """
    Get the dependencies for a particular lang/package based on the package name

    The package can include descriptors such as _nopretrain, _nocharlm, _charlm
    which inform whether or not this particular model uses charlm or pretrain
    """
    if processor == 'depparse':
        return get_depparse_dependencies(lang, package)
    elif processor == 'lemma':
        return get_lemma_dependencies(lang, package)
    elif processor == 'pos':
        return get_pos_dependencies(lang, package)
    elif processor == 'ner':
        return get_ner_dependencies(lang, package)
    elif processor == 'sentiment':
        return get_sentiment_dependencies(lang, package)
    elif processor == 'constituency':
        return get_con_dependencies(lang, package)
    return {}

def process_dirs(args):
    dirs = sorted(os.listdir(args.input_dir))
    resources = {}
    if args.lang:
        resources = json.load(open(os.path.join(args.output_dir, 'resources.json')))
        # this one language gets overridden
        # if this is not done, and we reuse the old resources,
        # any models which were deleted will still be in the resources
        for lang in args.lang.split(","):
            resources[lang] = {}

    for model_dir in dirs:
        print(f"Processing models in {model_dir}")
        models = sorted(os.listdir(os.path.join(args.input_dir, model_dir)))
        for model in tqdm(models):
            if not model.endswith('.pt'): continue
            # get processor
            lang, package, processor = split_model_name(model)
            if args.lang and lang not in args.lang.split(","):
                continue

            # copy file
            input_path = os.path.join(args.input_dir, model_dir, model)
            output_path = os.path.join(args.output_dir, lang, "models", processor, package + '.pt')
            copy_file(input_path, output_path)
            # maintain md5
            md5 = get_md5(output_path)
            # maintain dependencies
            dependencies = get_dependencies(processor, lang, package)
            # maintain resources
            if lang not in resources: resources[lang] = {}
            if processor not in resources[lang]: resources[lang][processor] = {}
            if dependencies:
                resources[lang][processor][package] = {'md5': md5, 'dependencies': dependencies}
            else:
                resources[lang][processor][package] = {'md5': md5}
    print("Processed initial model directories. Writing preliminary resources.json")
    json.dump(resources, open(os.path.join(args.output_dir, 'resources.json'), 'w'), indent=2)

def get_default_pos_package(lang, ud_package):
    charlm_package = get_pos_charlm_package(lang, ud_package)
    if charlm_package is not None:
        return ud_package + "_charlm"
    if lang in no_pretrain_languages:
        return ud_package + "_nopretrain"
    return ud_package + "_nocharlm"

def get_default_depparse_package(lang, ud_package):
    charlm_package = get_depparse_charlm_package(lang, ud_package)
    if charlm_package is not None:
        return ud_package + "_charlm"
    if lang in no_pretrain_languages:
        return ud_package + "_nopretrain"
    return ud_package + "_nocharlm"

def process_default_zips(args):
    resources = json.load(open(os.path.join(args.output_dir, 'resources.json')))
    for lang in resources:
        # check url, alias, and lang_name in case we are rerunning this step on an already built resources.json
        if lang == 'url':
            continue
        if 'alias' in resources[lang]:
            continue
        if all(k in ("backward_charlm", "forward_charlm", "pretrain", "lang_name") for k in resources[lang].keys()):
            continue
        if lang not in default_treebanks:
            raise AssertionError(f'{lang} not in default treebanks!!!')

        if args.lang and lang not in args.lang.split(","):
            continue

        print(f'Preparing default models for language {lang}')

        models_needed = defaultdict(set)

        packages = resources[lang][PACKAGES]["default"]
        for processor, package in packages.items():
            if processor == 'lemma' and package == 'identity':
                continue
            if processor == 'optional':
                continue
            models_needed[processor].add(package)
            dependencies = get_dependencies(processor, lang, package)
            for dependency in dependencies:
                models_needed[dependency['model']].add(dependency['package'])

        model_files = []
        for processor in PROCESSORS:
            if processor in models_needed:
                for package in sorted(models_needed[processor]):
                    filename = os.path.join(args.output_dir, lang, "models", processor, package + '.pt')
                    if os.path.exists(filename):
                        print("  Model {} package {}: file {}".format(processor, package, filename))
                        model_files.append((filename, processor, package))
                    else:
                        raise FileNotFoundError(f"Processor {processor} package {package} needed for {lang} but cannot be found at {filename}")

        with zipfile.ZipFile(os.path.join(args.output_dir, lang, 'models', 'default.zip'), 'w', zipfile.ZIP_DEFLATED) as zipf:
            for filename, processor, package in model_files:
                zipf.write(filename=filename, arcname=os.path.join(processor, package + '.pt'))

        default_md5 = get_md5(os.path.join(args.output_dir, lang, 'models', 'default.zip'))
        resources[lang]['default_md5'] = default_md5

    print("Processed default model zips. Writing resources.json")
    json.dump(resources, open(os.path.join(args.output_dir, 'resources.json'), 'w'), indent=2)

def get_default_processors(resources, lang):
    """
    Build a default package for this language

    Will add each of pos, lemma, depparse, etc if those are available
    Uses the existing models scraped from the language directories into resources.json, as relevant
    """
    if lang == "multilingual":
        return {"langid": "ud"}

    default_package = default_treebanks[lang]
    default_processors = {}
    if lang in default_tokenizer:
        default_processors['tokenize'] = default_tokenizer[lang]
    else:
        default_processors['tokenize'] = default_package

    if 'mwt' in resources[lang] and default_processors['tokenize'] in resources[lang]['mwt']:
        # if this doesn't happen, we just skip MWT
        default_processors['mwt'] = default_package

    if 'lemma' in resources[lang]:
        expected_lemma = default_package + "_nocharlm"
        if expected_lemma in resources[lang]['lemma']:
            default_processors['lemma'] = expected_lemma
        elif lang not in allowed_empty_languages:
            default_processors['lemma'] = 'identity'

    if 'pos' in resources[lang]:
        default_processors['pos'] = get_default_pos_package(lang, default_package)
        if default_processors['pos'] not in resources[lang]['pos']:
            raise AssertionError("Expected POS model not in resources: %s" % default_processors['pos'])
    elif lang not in allowed_empty_languages:
        raise AssertionError("Expected to find POS models for language %s" % lang)

    if 'depparse' in resources[lang]:
        default_processors['depparse'] = get_default_depparse_package(lang, default_package)
        if default_processors['depparse'] not in resources[lang]['depparse']:
            raise AssertionError("Expected depparse model not in resources: %s" % default_processors['depparse'])
    elif lang not in allowed_empty_languages:
        raise AssertionError("Expected to find depparse models for language %s" % lang)

    if lang in default_ners:
        default_processors['ner'] = default_ners[lang]

    if lang in default_sentiment:
        default_processors['sentiment'] = default_sentiment[lang]

    if lang in default_constituency:
        default_processors['constituency'] = default_constituency[lang]

    optional = get_default_optional_processors(resources, lang)
    if optional:
        default_processors['optional'] = optional

    return default_processors

def get_default_optional_processors(resources, lang):
    optional_processors = {}
    if lang in optional_constituency:
        optional_processors['constituency'] = optional_constituency[lang]

    if lang in optional_coref:
        optional_processors['coref'] = optional_coref[lang]

    return optional_processors

def update_processor_add_transformer(resources, lang, current_processors, processor, transformer):
    if processor not in current_processors:
        return

    new_model = current_processors[processor].replace('_charlm', "_" + transformer).replace('_nocharlm', "_" + transformer)
    if new_model in resources[lang][processor]:
        current_processors[processor] = new_model
    else:
        print("WARNING: wanted to use %s for %s accurate %s, but that model does not exist" % (new_model, lang, processor))

def get_default_accurate(resources, lang):
    """
    A package that, if available, uses charlm and transformer models for each processor
    """
    default_processors = get_default_processors(resources, lang)

    if 'lemma' in default_processors and default_processors['lemma'] != 'identity':
        lemma_model = default_processors['lemma']
        lemma_model = lemma_model.replace('_nocharlm', '_charlm')
        charlm_package = get_lemma_charlm_package(lang, lemma_model)
        if charlm_package is not None:
            if lemma_model in resources[lang]['lemma']:
                default_processors['lemma'] = lemma_model
            else:
                print("WARNING: wanted to use %s for %s default_accurate lemma, but that model does not exist" % (lemma_model, lang))

    transformer = TRANSFORMER_NICKNAMES.get(TRANSFORMERS.get(lang, None), None)
    if transformer is not None:
        for processor in ('pos', 'depparse', 'constituency', 'sentiment'):
            update_processor_add_transformer(resources, lang, default_processors, processor, transformer)
        if 'ner' in default_processors and (default_processors['ner'].endswith("_charlm") or default_processors['ner'].endswith("_nocharlm")):
            update_processor_add_transformer(resources, lang, default_processors, "ner", transformer)

    optional = get_optional_accurate(resources, lang)
    if optional:
        default_processors['optional'] = optional

    return default_processors

def get_optional_accurate(resources, lang):
    optional_processors = get_default_optional_processors(resources, lang)

    transformer = TRANSFORMER_NICKNAMES.get(TRANSFORMERS.get(lang, None), None)
    if transformer is not None:
        for processor in ('pos', 'depparse', 'constituency', 'sentiment'):
            update_processor_add_transformer(resources, lang, optional_processors, processor, transformer)

    if lang in optional_coref:
        optional_processors['coref'] = optional_coref[lang]

    return optional_processors


def get_default_fast(resources, lang):
    """
    Build a packages entry which only has the nocharlm models

    Will make it easy for people to use the lower tier of models

    We do this by building the same default package as normal,
    then switching everything out for the lower tier model when possible.
    We also remove constituency, as it is super slow.
    Note that in the case of a language which doesn't have a charlm,
    that means we wind up building the same for default and default_nocharlm
    """
    default_processors = get_default_processors(resources, lang)

    # this is a slow model and we don't have non-charlm versions of it yet
    if 'constituency' in default_processors:
        default_processors.pop('constituency')

    for processor, model in default_processors.items():
        if "_charlm" in model:
            nocharlm = model.replace("_charlm", "_nocharlm")
            if nocharlm not in resources[lang][processor]:
                print("WARNING: wanted to use %s for %s default_fast processor %s, but that model does not exist" % (nocharlm, lang, processor))
            else:
                default_processors[processor] = nocharlm

    return default_processors

def process_packages(args):
    """
    Build a package for a language's default processors and all of the treebanks specifically used for that language
    """
    resources = json.load(open(os.path.join(args.output_dir, 'resources.json')))

    for lang in resources:
        # check url, alias, and lang_name in case we are rerunning this step on an already built resources.json
        if lang == 'url':
            continue
        if 'alias' in resources[lang]:
            continue
        if all(k in ("backward_charlm", "forward_charlm", "pretrain", "lang_name") for k in resources[lang].keys()):
            continue
        if lang not in default_treebanks:
            raise AssertionError(f'{lang} not in default treebanks!!!')

        if args.lang and lang not in args.lang.split(","):
            continue

        default_processors = get_default_processors(resources, lang)

        # TODO: eventually we can remove default_processors
        # For now, we want to keep this so that v1.5.1 is compatible
        # with the next iteration of resources files
        resources[lang]['default_processors'] = default_processors
        resources[lang][PACKAGES] = {}
        resources[lang][PACKAGES]['default'] = default_processors

        if lang not in no_pretrain_languages and lang != "multilingual":
            default_fast = get_default_fast(resources, lang)
            resources[lang][PACKAGES]['default_fast'] = default_fast

            default_accurate = get_default_accurate(resources, lang)
            resources[lang][PACKAGES]['default_accurate'] = default_accurate

        # Now we loop over each of the tokenizers for this language
        # ... we use this as a proxy for the available UD treebanks
        # This loop also catches things such as "craft" which are
        # included treebanks that aren't UD
        # We then create a package in the packages dict for each of those treebanks
        if 'tokenize' in resources[lang]:
            for package in resources[lang]['tokenize']:
                processors = {"tokenize": package}
                if "mwt" in resources[lang] and package in resources[lang]["mwt"]:
                    processors["mwt"] = package

                if "pos" in resources[lang]:
                    if package + "_charlm" in resources[lang]["pos"]:
                        processors["pos"] = package + "_charlm"
                    elif package + "_nocharlm" in resources[lang]["pos"]:
                        processors["pos"] = package + "_nocharlm"

                if "lemma" in resources[lang] and "pos" in processors:
                    lemma_package = package + "_nocharlm"
                    if lemma_package in resources[lang]["lemma"]:
                        processors["lemma"] = lemma_package

                if "depparse" in resources[lang] and "pos" in processors:
                    depparse_package = None
                    if package + "_charlm" in resources[lang]["depparse"]:
                        depparse_package = package + "_charlm"
                    elif package + "_nocharlm" in resources[lang]["depparse"]:
                        depparse_package = package + "_nocharlm"
                    # we want to set the lemma first if it's identity
                    # THEN set the depparse
                    if depparse_package is not None:
                        if "lemma" not in processors:
                            processors["lemma"] = "identity"
                        processors["depparse"] = depparse_package

                resources[lang][PACKAGES][package] = processors

    print("Processed packages. Writing resources.json")
    json.dump(resources, open(os.path.join(args.output_dir, 'resources.json'), 'w'), indent=2)

def process_lcode(args):
    resources = json.load(open(os.path.join(args.output_dir, 'resources.json')))
    resources_new = {}
    resources_new["multilingual"] = resources["multilingual"]
    for lang in resources:
        if lang == 'multilingual':
            continue
        if 'alias' in resources[lang]:
            continue
        if lang not in lcode2lang:
            print(lang + ' not found in lcode2lang!')
            continue
        lang_name = lcode2lang[lang]
        resources[lang]['lang_name'] = lang_name
        resources_new[lang.lower()] = resources[lang.lower()]
        resources_new[lang_name.lower()] = {'alias': lang.lower()}
        if lang.lower() in two_to_three_letters:
            resources_new[two_to_three_letters[lang.lower()]] = {'alias': lang.lower()}
        elif lang.lower() in three_to_two_letters:
            resources_new[three_to_two_letters[lang.lower()]] = {'alias': lang.lower()}
    print("Processed lcode aliases. Writing resources.json")
    json.dump(resources_new, open(os.path.join(args.output_dir, 'resources.json'), 'w'), indent=2)


def process_misc(args):
    resources = json.load(open(os.path.join(args.output_dir, 'resources.json')))
    resources['no'] = {'alias': 'nb'}
    resources['zh'] = {'alias': 'zh-hans'}
    # This is intended to be unformatted. expand_model_url in common.py will fill in the raw string
    # with the appropriate values in order to find the needed model file on huggingface
    resources['url'] = 'https://huggingface.co/stanfordnlp/stanza-{lang}/resolve/v{resources_version}/models/{filename}'
    print("Finalized misc attributes. Writing resources.json")
    json.dump(resources, open(os.path.join(args.output_dir, 'resources.json'), 'w'), indent=2)


def main():
    args = parse_args()
    print("Converting models from %s to %s" % (args.input_dir, args.output_dir))
    if not args.packages_only:
        process_dirs(args)
    process_packages(args)
    if not args.packages_only:
        process_default_zips(args)
    process_lcode(args)
    process_misc(args)


if __name__ == '__main__':
    main()
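As a quick illustration of the filename convention the functions above rely on, here is a sketch (not part of the diff; the example filenames are hypothetical) of what split_model_name returns for typical model files:

from stanza.resources.prepare_resources import split_model_name

# endings are matched longest-first, so "nertagger" wins over "tagger"
print(split_model_name("en_ontonotes_charlm_nertagger.pt"))  # ('en', 'ontonotes_charlm', 'ner')
print(split_model_name("de_gsd_nocharlm_parser.pt"))         # ('de', 'gsd_nocharlm', 'depparse')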
stanza/stanza/server/__init__.py
ADDED
@@ -0,0 +1,10 @@
from stanza.protobuf import to_text
from stanza.protobuf import Document, Sentence, Token, IndexedWord, Span
from stanza.protobuf import ParseTree, DependencyGraph, CorefChain
from stanza.protobuf import Mention, NERMention, Entity, Relation, RelationTriple, Timex
from stanza.protobuf import Quote, SpeakerInfo
from stanza.protobuf import Operator, Polarity
from stanza.protobuf import SentenceFragment, TokenLocation
from stanza.protobuf import MapStringString, MapIntString
from .client import CoreNLPClient, AnnotationException, TimeoutException, PermanentlyFailedException, StartServer
from .annotator import Annotator
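With these exports in place, the typical entry point is CoreNLPClient. A brief sketch, assuming CoreNLP is installed, $CORENLP_HOME points at it, and a Java runtime is available; the annotator list and memory setting here are example values:

from stanza.server import CoreNLPClient

with CoreNLPClient(annotators=['tokenize', 'ssplit', 'pos'], memory='6G') as client:
    ann = client.annotate("Stanford University is located in California.")
    # the default output is a serialized protobuf Document
    print(ann.sentence[0].token[0].word)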
stanza/stanza/server/annotator.py
ADDED
@@ -0,0 +1,138 @@
"""
Defines a base class that can be used to annotate.
"""
import io
from multiprocessing import Process
from http.server import BaseHTTPRequestHandler, HTTPServer
# http.client provides the integer status-code constants (OK, BAD_REQUEST, ...)
from http import client as HTTPStatus

from stanza.protobuf import Document, parseFromDelimitedString, writeToDelimitedString

class Annotator(Process):
    """
    This annotator base class hosts a lightweight server that accepts
    annotation requests from CoreNLP.
    Each annotator simply defines 3 functions: requires, provides and annotate.

    This class takes care of defining appropriate endpoints to interface
    with CoreNLP.
    """
    @property
    def name(self):
        """
        Name of the annotator (used by CoreNLP)
        """
        raise NotImplementedError()

    @property
    def requires(self):
        """
        Requires has to specify all the annotations required before we
        are called.
        """
        raise NotImplementedError()

    @property
    def provides(self):
        """
        The set of annotations guaranteed to be provided when we are done.
        NOTE: these annotations are either fully qualified Java
        class names or refer to nested classes of
        edu.stanford.nlp.ling.CoreAnnotations (as is the case below).
        """
        raise NotImplementedError()

    def annotate(self, ann):
        """
        @ann: is a protobuf annotation object.
        Actually populate @ann with tokens.
        """
        raise NotImplementedError()

    @property
    def properties(self):
        """
        Defines a Java property to define this annotator to CoreNLP.
        """
        return {
            "customAnnotatorClass.{}".format(self.name): "edu.stanford.nlp.pipeline.GenericWebServiceAnnotator",
            "generic.endpoint": "http://{}:{}".format(self.host, self.port),
            "generic.requires": ",".join(self.requires),
            "generic.provides": ",".join(self.provides),
        }

    class _Handler(BaseHTTPRequestHandler):
        annotator = None

        def __init__(self, request, client_address, server):
            BaseHTTPRequestHandler.__init__(self, request, client_address, server)

        def do_GET(self):
            """
            Handle a ping request
            """
            if not self.path.endswith("/"): self.path += "/"
            if self.path == "/ping/":
                msg = "pong".encode("UTF-8")

                self.send_response(HTTPStatus.OK)
                self.send_header("Content-Type", "text/application")
                self.send_header("Content-Length", len(msg))
                self.end_headers()
                self.wfile.write(msg)
            else:
                self.send_response(HTTPStatus.BAD_REQUEST)
                self.end_headers()

        def do_POST(self):
            """
            Handle an annotate request
            """
            if not self.path.endswith("/"): self.path += "/"
            if self.path == "/annotate/":
                # Read message
                length = int(self.headers.get('content-length'))
                msg = self.rfile.read(length)

                # Do the annotation
                doc = Document()
                parseFromDelimitedString(doc, msg)
                self.annotator.annotate(doc)

                with io.BytesIO() as stream:
                    writeToDelimitedString(doc, stream)
                    msg = stream.getvalue()

                # write message
                self.send_response(HTTPStatus.OK)
                self.send_header("Content-Type", "application/x-protobuf")
                self.send_header("Content-Length", len(msg))
                self.end_headers()
                self.wfile.write(msg)

            else:
                self.send_response(HTTPStatus.BAD_REQUEST)
                self.end_headers()

    def __init__(self, host="", port=8432):
        """
        Launches a server endpoint to communicate with CoreNLP
        """
        Process.__init__(self)
        self.host, self.port = host, port
        self._Handler.annotator = self

    def run(self):
        """
        Runs the server using Python's simple HTTPServer.
        TODO: make this multithreaded.
        """
        httpd = HTTPServer((self.host, self.port), self._Handler)
        sa = httpd.socket.getsockname()
        serve_message = "Serving HTTP on {host} port {port} (http://{host}:{port}/) ..."
        print(serve_message.format(host=sa[0], port=sa[1]))
        try:
            httpd.serve_forever()
        except KeyboardInterrupt:
            print("\nKeyboard interrupt received, exiting.")
            httpd.shutdown()
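For reference, a custom annotator subclasses the base class above and fills in the three hooks. A minimal sketch, assuming the protobuf Document exposes text and sentencelessToken fields as in CoreNLP.proto; the annotator name and the whitespace tokenization are illustrative only:

from stanza.server.annotator import Annotator

class WhitespaceTokenizer(Annotator):
    @property
    def name(self):
        return "whitespace"  # illustrative name

    @property
    def requires(self):
        return []

    @property
    def provides(self):
        return ["TokensAnnotation"]

    def annotate(self, ann):
        # naive whitespace tokenization into the sentence-less token list
        for word in ann.text.split():
            token = ann.sentencelessToken.add()
            token.word = word
            token.value = word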
stanza/stanza/server/client.py
ADDED
@@ -0,0 +1,779 @@
| 1 |
+
"""
|
| 2 |
+
Client for accessing Stanford CoreNLP in Python
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import atexit
|
| 6 |
+
import contextlib
|
| 7 |
+
import enum
|
| 8 |
+
import io
|
| 9 |
+
import os
|
| 10 |
+
import re
|
| 11 |
+
import requests
|
| 12 |
+
import logging
|
| 13 |
+
import json
|
| 14 |
+
import shlex
|
| 15 |
+
import socket
|
| 16 |
+
import subprocess
|
| 17 |
+
import time
|
| 18 |
+
import sys
|
| 19 |
+
import uuid
|
| 20 |
+
|
| 21 |
+
from datetime import datetime
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from urllib.parse import urlparse
|
| 24 |
+
|
| 25 |
+
from stanza.protobuf import Document, parseFromDelimitedString, writeToDelimitedString, to_text
|
| 26 |
+
__author__ = 'arunchaganty, kelvinguu, vzhong, wmonroe4'
|
| 27 |
+
|
| 28 |
+
logger = logging.getLogger('stanza')
|
| 29 |
+
|
| 30 |
+
# pattern tmp props file should follow
|
| 31 |
+
SERVER_PROPS_TMP_FILE_PATTERN = re.compile('corenlp_server-(.*).props')
|
| 32 |
+
|
| 33 |
+
# Check if str is CoreNLP supported language
|
| 34 |
+
CORENLP_LANGS = ['ar', 'arabic', 'chinese', 'zh', 'english', 'en', 'french', 'fr', 'de', 'german', 'hu', 'hungarian',
|
| 35 |
+
'it', 'italian', 'es', 'spanish']
|
| 36 |
+
|
| 37 |
+
# map shorthands to full language names
|
| 38 |
+
LANGUAGE_SHORTHANDS_TO_FULL = {
|
| 39 |
+
"ar": "arabic",
|
| 40 |
+
"zh": "chinese",
|
| 41 |
+
"en": "english",
|
| 42 |
+
"fr": "french",
|
| 43 |
+
"de": "german",
|
| 44 |
+
"hu": "hungarian",
|
| 45 |
+
"it": "italian",
|
| 46 |
+
"es": "spanish"
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def is_corenlp_lang(props_str):
|
| 51 |
+
""" Check if a string references a CoreNLP language """
|
| 52 |
+
return props_str.lower() in CORENLP_LANGS
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# Validate CoreNLP properties
|
| 56 |
+
CORENLP_OUTPUT_VALS = ["conll", "conllu", "json", "serialized", "text", "xml", "inlinexml"]
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def validate_corenlp_props(properties=None, annotators=None, output_format=None):
|
| 60 |
+
""" Do basic checks to validate CoreNLP properties """
|
| 61 |
+
if output_format and output_format.lower() not in CORENLP_OUTPUT_VALS:
|
| 62 |
+
raise ValueError(f"{output_format} not a valid CoreNLP outputFormat value! Choose from: {CORENLP_OUTPUT_VALS}")
|
| 63 |
+
if type(properties) == dict:
|
| 64 |
+
if "outputFormat" in properties and properties["outputFormat"].lower() not in CORENLP_OUTPUT_VALS:
|
| 65 |
+
raise ValueError(f"{properties['outputFormat']} not a valid CoreNLP outputFormat value! Choose from: "
|
| 66 |
+
f"{CORENLP_OUTPUT_VALS}")
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class AnnotationException(Exception):
|
| 70 |
+
""" Exception raised when there was an error communicating with the CoreNLP server. """
|
| 71 |
+
pass
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class TimeoutException(AnnotationException):
|
| 75 |
+
""" Exception raised when the CoreNLP server timed out. """
|
| 76 |
+
pass
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class ShouldRetryException(Exception):
|
| 80 |
+
""" Exception raised if the service should retry the request. """
|
| 81 |
+
pass
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class PermanentlyFailedException(Exception):
|
| 85 |
+
""" Exception raised if the service should NOT retry the request. """
|
| 86 |
+
pass
|
| 87 |
+
|
| 88 |
+
class StartServer(enum.Enum):
|
| 89 |
+
DONT_START = 0
|
| 90 |
+
FORCE_START = 1
|
| 91 |
+
TRY_START = 2
|
+
+
+def clean_props_file(props_file):
+    # check if there is a temp server props file to remove and remove it
+    if props_file:
+        if os.path.isfile(props_file) and SERVER_PROPS_TMP_FILE_PATTERN.match(os.path.basename(props_file)):
+            os.remove(props_file)
+
+
+class RobustService(object):
+    """ Service that resuscitates itself if it is not available. """
+    CHECK_ALIVE_TIMEOUT = 120
+
+    def __init__(self, start_cmd, stop_cmd, endpoint, stdout=None,
+                 stderr=None, be_quiet=False, host=None, port=None, ignore_binding_error=False):
+        self.start_cmd = start_cmd and shlex.split(start_cmd)
+        self.stop_cmd = stop_cmd and shlex.split(stop_cmd)
+        self.endpoint = endpoint
+        self.stdout = stdout
+        self.stderr = stderr
+
+        self.server = None
+        self.is_active = False
+        self.be_quiet = be_quiet
+        self.host = host
+        self.port = port
+        self.ignore_binding_error = ignore_binding_error
+        atexit.register(self.atexit_kill)
+
+    def is_alive(self):
+        try:
+            if not self.ignore_binding_error and self.server is not None and self.server.poll() is not None:
+                return False
+            return requests.get(self.endpoint + "/ping").ok
+        except requests.exceptions.ConnectionError as e:
+            raise ShouldRetryException(e)
+
+    def start(self):
+        if self.start_cmd:
+            if self.host and self.port:
+                with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
+                    try:
+                        sock.bind((self.host, self.port))
+                    except socket.error as e:
+                        if self.ignore_binding_error:
+                            logger.info(f"Connecting to existing CoreNLP server at {self.host}:{self.port}")
+                            self.server = None
+                            return
+                        else:
+                            raise PermanentlyFailedException("Error: unable to start the CoreNLP server on port %d "
+                                                             "(possibly something is already running there)" % self.port) from e
+            if self.be_quiet:
+                # Issue #26: subprocess.DEVNULL isn't supported in python 2.7.
+                if hasattr(subprocess, 'DEVNULL'):
+                    stderr = subprocess.DEVNULL
+                else:
+                    stderr = open(os.devnull, 'w')
+                stdout = stderr
+            else:
+                stdout = self.stdout
+                stderr = self.stderr
+            logger.info(f"Starting server with command: {' '.join(self.start_cmd)}")
+            try:
+                self.server = subprocess.Popen(self.start_cmd,
+                                               stderr=stderr,
+                                               stdout=stdout)
+            except FileNotFoundError as e:
+                raise FileNotFoundError("When trying to run CoreNLP, a FileNotFoundError occurred, which frequently means Java was not installed or was not in the classpath.") from e
+
+    def atexit_kill(self):
+        # make some kind of effort to stop the service (such as a
+        # CoreNLP server) at the end of the program.  not waiting so
+        # that the python script exiting isn't delayed
+        if self.server and self.server.poll() is None:
+            self.server.terminate()
+
+    def stop(self):
+        if self.server:
+            self.server.terminate()
+            try:
+                self.server.wait(5)
+            except subprocess.TimeoutExpired:
+                # Resorting to more aggressive measures...
+                self.server.kill()
+                try:
+                    self.server.wait(5)
+                except subprocess.TimeoutExpired:
+                    # oh well
+                    pass
+            self.server = None
+        if self.stop_cmd:
+            subprocess.run(self.stop_cmd, check=True)
+        self.is_active = False
+
+    def __enter__(self):
+        self.start()
+        return self
+
+    def __exit__(self, _, __, ___):
+        self.stop()
+
+    def ensure_alive(self):
+        # Check if the service is active and alive
+        if self.is_active:
+            try:
+                if self.is_alive():
+                    return
+                else:
+                    self.stop()
+            except ShouldRetryException:
+                pass
+
+        # If not, try to start up the service.
+        if self.server is None:
+            self.start()
+
+        # Wait for the service to start up.
+        start_time = time.time()
+        while True:
+            try:
+                if self.is_alive():
+                    break
+            except ShouldRetryException:
+                pass
+
+            if time.time() - start_time < self.CHECK_ALIVE_TIMEOUT:
+                time.sleep(1)
+            else:
+                raise PermanentlyFailedException("Timed out waiting for service to come alive.")
+
+        # At this point we are guaranteed that the service is alive.
+        self.is_active = True
+
+
+def resolve_classpath(classpath=None):
+    """
+    Returns the classpath to use for corenlp.
+
+    Prefers to use the given classpath parameter, if available.  If
+    not, uses the CORENLP_HOME environment variable.  Resolves $CLASSPATH
+    (the exact string) in either the classpath parameter or $CORENLP_HOME.
+    """
+    if classpath == '$CLASSPATH' or (classpath is None and os.getenv("CORENLP_HOME", None) == '$CLASSPATH'):
+        classpath = os.getenv("CLASSPATH")
+    elif classpath is None:
+        classpath = os.getenv("CORENLP_HOME", os.path.join(str(Path.home()), 'stanza_corenlp'))
+
+        if not os.path.exists(classpath):
+            raise FileNotFoundError("Please install CoreNLP by running `stanza.install_corenlp()`. If you have installed it, please define "
+                                    "$CORENLP_HOME to be the location of your CoreNLP distribution or pass in a classpath parameter. "
+                                    "$CORENLP_HOME={}".format(os.getenv("CORENLP_HOME")))
+    classpath = os.path.join(classpath, "*")
+    return classpath
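+# Resolution sketch for the function above (the paths are hypothetical):
+#   resolve_classpath("/opt/corenlp") -> "/opt/corenlp/*"
+#   resolve_classpath(None) uses $CORENLP_HOME if set, otherwise ~/stanza_corenlp,
+#   and the literal string "$CLASSPATH" defers to the CLASSPATH env var.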
+
+
+class CoreNLPClient(RobustService):
+    """ A client to the Stanford CoreNLP server. """
+
+    DEFAULT_ENDPOINT = "http://localhost:9000"
+    DEFAULT_TIMEOUT = 60000
+    DEFAULT_THREADS = 5
+    DEFAULT_OUTPUT_FORMAT = "serialized"
+    DEFAULT_MEMORY = "5G"
+    DEFAULT_MAX_CHAR_LENGTH = 100000
+
+    def __init__(self, start_server=StartServer.FORCE_START,
+                 endpoint=DEFAULT_ENDPOINT,
+                 timeout=DEFAULT_TIMEOUT,
+                 threads=DEFAULT_THREADS,
+                 annotators=None,
+                 pretokenized=False,
+                 output_format=None,
+                 properties=None,
+                 stdout=None,
+                 stderr=None,
+                 memory=DEFAULT_MEMORY,
+                 be_quiet=False,
+                 max_char_length=DEFAULT_MAX_CHAR_LENGTH,
+                 preload=True,
+                 classpath=None,
+                 **kwargs):
+
+        # whether or not server should be started by client
+        self.start_server = start_server
+        self.server_props_path = None
+        self.server_start_time = None
+        self.server_host = None
+        self.server_port = None
+        self.server_classpath = None
+        # validate properties
+        validate_corenlp_props(properties=properties, annotators=annotators, output_format=output_format)
+        # set up client defaults
+        self.properties = properties
+        self.annotators = annotators
+        self.pretokenized = pretokenized
+        self.output_format = output_format
+        self._setup_client_defaults()
+        # start the server
+        if isinstance(start_server, bool):
+            warning_msg = f"Setting 'start_server' to a boolean value when constructing {self.__class__.__name__} is deprecated and will stop" + \
+                          " functioning in a future version of stanza. Please consider switching to using a value from stanza.server.StartServer."
+            logger.warning(warning_msg)
+            start_server = StartServer.FORCE_START if start_server is True else StartServer.DONT_START
+
+        # start the server
+        if start_server is StartServer.FORCE_START or start_server is StartServer.TRY_START:
+            # record info for server start
+            self.server_start_time = datetime.now()
+            # set up default properties for server
+            self._setup_server_defaults()
+            host, port = urlparse(endpoint).netloc.split(":")
+            port = int(port)
+            assert host == "localhost", "If starting a server, endpoint must be localhost"
+            classpath = resolve_classpath(classpath)
+            start_cmd = f"java -Xmx{memory} -cp '{classpath}' edu.stanford.nlp.pipeline.StanfordCoreNLPServer " \
+                        f"-port {port} -timeout {timeout} -threads {threads} -maxCharLength {max_char_length} " \
+                        f"-quiet {be_quiet} "
+
+            self.server_classpath = classpath
+            self.server_host = host
+            self.server_port = port
+
+            # set up server defaults
+            if self.server_props_path is not None:
+                start_cmd += f" -serverProperties {self.server_props_path}"
+
+            # possibly set pretokenized
+            if self.pretokenized:
+                start_cmd += f" -preTokenized"
+
+            # set annotators for server default
+            if self.annotators is not None:
+                annotators_str = self.annotators if type(annotators) == str else ",".join(annotators)
+                start_cmd += f" -annotators {annotators_str}"
+
+            # specify what to preload, if anything
+            if preload:
+                if type(preload) == bool:
+                    # -preload flag means to preload all default annotators
+                    start_cmd += " -preload"
+                elif type(preload) == list:
+                    # turn list into comma separated list string, only preload these annotators
+                    start_cmd += f" -preload {','.join(preload)}"
+                elif type(preload) == str:
+                    # comma separated list of annotators
+                    start_cmd += f" -preload {preload}"
+
+            # set outputFormat for server default
+            # if no output format requested by user, set to serialized
+            start_cmd += f" -outputFormat {self.output_format}"
+
+            # additional options for server:
+            # - server_id
+            # - ssl
+            # - status_port
+            # - uriContext
+            # - strict
+            # - key
+            # - username
+            # - password
+            # - blockList
+            for kw in ['ssl', 'strict']:
+                if kwargs.get(kw) is not None:
+                    start_cmd += f" -{kw}"
+            for kw in ['status_port', 'uriContext', 'key', 'username', 'password', 'blockList', 'server_id']:
+                if kwargs.get(kw) is not None:
+                    start_cmd += f" -{kw} {kwargs.get(kw)}"
+            stop_cmd = None
+        else:
+            start_cmd = stop_cmd = None
+            host = port = None
+
+        super(CoreNLPClient, self).__init__(start_cmd, stop_cmd, endpoint,
+                                            stdout, stderr, be_quiet, host=host, port=port, ignore_binding_error=(start_server == StartServer.TRY_START))
+
+        self.timeout = timeout
+
+    def _setup_client_defaults(self):
+        """
+        Do some processing of annotators and output_format specified for the client.
+        If interacting with an externally started server, these will be defaults for annotate() calls.
+        :return: None
+        """
+        # normalize annotators to str
+        if self.annotators is not None:
+            self.annotators = self.annotators if type(self.annotators) == str else ",".join(self.annotators)
+
+        # handle case where no output format is specified
+        if self.output_format is None:
+            if type(self.properties) == dict and 'outputFormat' in self.properties:
+                self.output_format = self.properties['outputFormat']
+            else:
+                self.output_format = CoreNLPClient.DEFAULT_OUTPUT_FORMAT
+
+    def _setup_server_defaults(self):
+        """
+        Set up the default properties for the server.
+
+        The properties argument can take on one of 3 value types
+
+        1. File path on system or in CLASSPATH (e.g. /path/to/server.props or StanfordCoreNLP-french.properties)
+        2. Name of a Stanford CoreNLP supported language (e.g. french or fr)
+        3. Python dictionary (properties written to tmp file for Java server, erased at end)
+
+        In addition, an annotators list and output_format can be specified directly with arguments.  These
+        will overwrite any settings in the specified properties.
+
+        If no properties are specified, the standard Stanford CoreNLP English server will be launched.  The outputFormat
+        will be set to 'serialized' and use the ProtobufAnnotationSerializer.
+        """
+
+        # ensure properties is str or dict
+        if self.properties is None or (not isinstance(self.properties, str) and not isinstance(self.properties, dict)):
+            if self.properties is not None:
+                logger.warning('properties passed invalid value (not a str or dict), setting properties = {}')
+            self.properties = {}
+        # check if properties is a string, pass on to server which can handle
+        if isinstance(self.properties, str):
+            # try to translate to Stanford CoreNLP language name, or assume properties is a path
+            if is_corenlp_lang(self.properties):
+                if self.properties.lower() in LANGUAGE_SHORTHANDS_TO_FULL:
+                    self.properties = LANGUAGE_SHORTHANDS_TO_FULL[self.properties]
+                logger.info(
+                    f"Using CoreNLP default properties for: {self.properties}.  Make sure to have "
+                    f"{self.properties} models jar (available for download here: "
+                    f"https://stanfordnlp.github.io/CoreNLP/) in CLASSPATH")
+            else:
+                if not os.path.isfile(self.properties):
+                    logger.warning(f"{self.properties} does not correspond to a file path. Make sure this file is in "
+                                   f"your CLASSPATH.")
+            self.server_props_path = self.properties
+        elif isinstance(self.properties, dict):
+            # make a copy
+            server_start_properties = dict(self.properties)
+            if self.annotators is not None:
+                server_start_properties['annotators'] = self.annotators
+            if self.output_format is not None and isinstance(self.output_format, str):
+                server_start_properties['outputFormat'] = self.output_format
+            # write desired server start properties to tmp file
+            # set up to erase on exit
+            tmp_path = write_corenlp_props(server_start_properties)
+            logger.info(f"Writing properties to tmp file: {tmp_path}")
+            atexit.register(clean_props_file, tmp_path)
+            self.server_props_path = tmp_path
+
+    def _request(self, buf, properties, reset_default=False, **kwargs):
+        """
+        Send a request to the CoreNLP server.
+
+        :param (str | bytes) buf: data to be sent with the request
+        :param (dict) properties: properties that the server expects
+        :return: request result
+        """
+        if self.start_server is not StartServer.DONT_START:
+            self.ensure_alive()
+
+        try:
+            input_format = properties.get("inputFormat", "text")
+            if input_format == "text":
+                ctype = "text/plain; charset=utf-8"
+            elif input_format == "serialized":
+                ctype = "application/x-protobuf"
+            else:
+                raise ValueError("Unrecognized inputFormat " + input_format)
+            # handle auth
+            if 'username' in kwargs and 'password' in kwargs:
+                kwargs['auth'] = requests.auth.HTTPBasicAuth(kwargs['username'], kwargs['password'])
+                kwargs.pop('username')
+                kwargs.pop('password')
+            r = requests.post(self.endpoint,
+                              params={'properties': str(properties), 'resetDefault': str(reset_default).lower()},
+                              data=buf, headers={'content-type': ctype},
+                              timeout=(self.timeout*2)/1000, **kwargs)
+            r.raise_for_status()
+            return r
+        except requests.exceptions.Timeout as e:
+            raise TimeoutException("Timed out requesting the CoreNLP server. Maybe the server is unavailable or your document is too long")
+        except requests.exceptions.RequestException as e:
+            if e.response is not None and e.response.text is not None:
+                raise AnnotationException(e.response.text) from e
+            elif e.args:
+                raise AnnotationException(e.args[0]) from e
+            raise AnnotationException() from e
+
+    def annotate(self, text, annotators=None, output_format=None, properties=None, reset_default=None, **kwargs):
+        """
+        Send a request to the CoreNLP server.
+
+        :param (str | unicode) text: raw text for the CoreNLPServer to parse
+        :param (list | string) annotators: list of annotators to use
+        :param (str) output_format: output type from server: serialized, json, text, conll, conllu, or xml
+        :param (dict) properties: additional request properties (written on top of defaults)
+        :param (bool) reset_default: don't use server defaults
+
+        Precedence for settings:
+
+        1. annotators and output_format args
+        2. Values from properties dict
+        3. Client defaults self.annotators and self.output_format (set during client construction)
+        4. Server defaults
+
+        Additional request parameters (apart from CoreNLP pipeline properties) such as 'username' and 'password'
+        can be specified with the kwargs.
+
+        :return: request result
+        """
+
+        # validate request properties
+        validate_corenlp_props(properties=properties, annotators=annotators, output_format=output_format)
+        # set request properties
+        request_properties = {}
+
+        # start with client defaults
+        if self.annotators is not None:
+            request_properties['annotators'] = self.annotators
+        if self.output_format is not None:
+            request_properties['outputFormat'] = self.output_format
+
+        # add values from properties arg
+        # handle str case
+        if type(properties) == str:
+            if is_corenlp_lang(properties):
+                properties = {'pipelineLanguage': properties.lower()}
+                if reset_default is None:
+                    reset_default = True
+            else:
+                raise ValueError(f"Unrecognized properties keyword {properties}")
+
+        if type(properties) == dict:
+            request_properties.update(properties)
+
+        # if annotators list is specified, override with that
+        # also can use the annotators field the object was created with
+        if annotators is not None and (type(annotators) == str or type(annotators) == list):
+            request_properties['annotators'] = annotators if type(annotators) == str else ",".join(annotators)
+
+        # if output format is specified, override with that
+        if output_format is not None and type(output_format) == str:
+            request_properties['outputFormat'] = output_format
+
+        # make the request
+        # if reset_default was not explicitly set (and was not already defaulted
+        # above for the pipelineLanguage case), it should be False
+        if reset_default is None:
+            reset_default = False
+        r = self._request(text.encode('utf-8'), request_properties, reset_default, **kwargs)
+        if request_properties["outputFormat"] == "json":
+            return r.json()
+        elif request_properties["outputFormat"] == "serialized":
+            doc = Document()
+            parseFromDelimitedString(doc, r.content)
+            return doc
+        elif request_properties["outputFormat"] in ["text", "conllu", "conll", "xml"]:
+            return r.text
+        else:
+            return r
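+    # Settings-precedence sketch for annotate() above (values are hypothetical):
+    # with a client built as CoreNLPClient(annotators="tokenize,ssplit",
+    # output_format="json"), calling annotate(text, output_format="text") sends
+    # annotators=tokenize,ssplit (client default, rule 3) but outputFormat=text
+    # (per-call argument, rule 1 beats the client default).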
+
+    def update(self, doc, annotators=None, properties=None):
+        if properties is None:
+            properties = {}
+        properties.update({
+            'inputFormat': 'serialized',
+            'outputFormat': 'serialized',
+            'serializer': 'edu.stanford.nlp.pipeline.ProtobufAnnotationSerializer'
+        })
+        if annotators:
+            properties['annotators'] = annotators if type(annotators) == str else ",".join(annotators)
+        with io.BytesIO() as stream:
+            writeToDelimitedString(doc, stream)
+            msg = stream.getvalue()
+
+        r = self._request(msg, properties)
+        doc = Document()
+        parseFromDelimitedString(doc, r.content)
+        return doc
+
+    def tokensregex(self, text, pattern, filter=False, to_words=False, annotators=None, properties=None):
+        # this is required for some reason
+        matches = self.__regex('/tokensregex', text, pattern, filter, annotators, properties)
+        if to_words:
+            matches = regex_matches_to_indexed_words(matches)
+        return matches
+
+    def semgrex(self, text, pattern, filter=False, to_words=False, annotators=None, properties=None):
+        matches = self.__regex('/semgrex', text, pattern, filter, annotators, properties)
+        if to_words:
+            matches = regex_matches_to_indexed_words(matches)
+        return matches
+
+    def fill_tree_proto(self, tree, proto_tree):
+        if tree.label:
+            proto_tree.value = tree.label
+        for child in tree.children:
+            proto_child = proto_tree.child.add()
+            self.fill_tree_proto(child, proto_child)
+
+    def tregex(self, text=None, pattern=None, filter=False, annotators=None, properties=None, trees=None):
+        # parse is not included by default in some of the pipelines,
+        # so we may need to manually override the annotators
+        # to include parse in order for tregex to do anything
+        if annotators is None and self.annotators is not None:
+            assert isinstance(self.annotators, str)
+            pieces = self.annotators.split(",")
+            if "parse" not in pieces:
+                annotators = self.annotators + ",parse"
+        else:
+            annotators = "tokenize,ssplit,pos,parse"
+        if pattern is None:
+            raise ValueError("Cannot have None as a pattern for tregex")
+
+        # TODO: we could also allow for passing in a complete document,
+        # along with the original text, so that the spans returned are more accurate
+        if trees is not None:
+            if properties is None:
+                properties = {}
+            properties['inputFormat'] = 'serialized'
+            if 'serializer' not in properties:
+                properties['serializer'] = 'edu.stanford.nlp.pipeline.ProtobufAnnotationSerializer'
+            doc = Document()
+            full_text = []
+            for tree_idx, tree in enumerate(trees):
+                sentence = doc.sentence.add()
+                sentence.sentenceIndex = tree_idx
+                sentence.tokenOffsetBegin = len(full_text)
+                leaves = tree.leaf_labels()
+                full_text.extend(leaves)
+                sentence.tokenOffsetEnd = len(full_text)
+                self.fill_tree_proto(tree, sentence.parseTree)
+                for word in leaves:
+                    token = sentence.token.add()
+                    # the other side uses both value and word, weirdly enough
+                    token.value = word
+                    token.word = word
+                    # without the actual tokenization, at least we can
+                    # stop the words from running together
+                    token.after = " "
+            doc.text = " ".join(full_text)
+            with io.BytesIO() as stream:
+                writeToDelimitedString(doc, stream)
+                text = stream.getvalue()
+
+        return self.__regex('/tregex', text, pattern, filter, annotators, properties)
+
+    def __regex(self, path, text, pattern, filter, annotators=None, properties=None):
+        """
+        Send a regex-related request to the CoreNLP server.
+
+        :param (str | unicode) path: the path for the regex endpoint
+        :param text: raw text for the CoreNLPServer to apply the regex
+        :param (str | unicode) pattern: regex pattern
+        :param (bool) filter: option to filter sentences that contain matches, if false returns matches
+        :param properties: additional request properties (written on top of defaults)
+        :return: request result
+        """
+        if self.start_server is not StartServer.DONT_START:
+            self.ensure_alive()
+        if properties is None:
+            properties = {}
+        properties.update({
+            'inputFormat': 'text',
+            'serializer': 'edu.stanford.nlp.pipeline.ProtobufAnnotationSerializer'
+        })
+        if annotators:
+            properties['annotators'] = ",".join(annotators) if isinstance(annotators, list) else annotators
+
+        # force output for regex requests to be json
+        properties['outputFormat'] = 'json'
+        # if the server is trying to send back character offsets, it
+        # should send back codepoint counts as well in case the text
+        # has extra wide characters
+        properties['tokenize.codepoint'] = 'true'
+
+        try:
+            # an error occurs unless the properties are put in params
+            input_format = properties.get("inputFormat", "text")
+            if input_format == "text":
+                ctype = "text/plain; charset=utf-8"
+            elif input_format == "serialized":
+                ctype = "application/x-protobuf"
+            else:
+                raise ValueError("Unrecognized inputFormat " + input_format)
+            # change request method from `get` to `post` as required by CoreNLP
+            r = requests.post(
+                self.endpoint + path, params={
+                    'pattern': pattern,
+                    'filter': filter,
+                    'properties': str(properties)
+                },
+                data=text.encode('utf-8') if isinstance(text, str) else text,
+                headers={'content-type': ctype},
+                timeout=(self.timeout*2)/1000,
+            )
+            r.raise_for_status()
+            if r.encoding is None:
+                r.encoding = "utf-8"
+            return json.loads(r.text)
+        except requests.HTTPError as e:
+            if r.text.startswith("Timeout"):
+                raise TimeoutException(r.text)
+            else:
+                raise AnnotationException(r.text)
+        except json.JSONDecodeError:
+            raise AnnotationException(r.text)
+
+
+    def scenegraph(self, text, properties=None):
+        """
+        Send a request to the server which processes the text using SceneGraph
+
+        This will require a new CoreNLP release, 4.5.5 or later
+        """
+        # since we're using requests ourselves,
+        # check if the server has started or not
+        if self.start_server is not StartServer.DONT_START:
+            self.ensure_alive()
+
+        if properties is None:
+            properties = {}
+        # the only thing the scenegraph knows how to use is text
+        properties['inputFormat'] = 'text'
+        ctype = "text/plain; charset=utf-8"
+        # the json output format is much more useful
+        properties['outputFormat'] = 'json'
+        try:
+            r = requests.post(
+                self.endpoint + "/scenegraph",
+                params={
+                    'properties': str(properties)
+                },
+                data=text.encode('utf-8') if isinstance(text, str) else text,
+                headers={'content-type': ctype},
+                timeout=(self.timeout*2)/1000,
+            )
+            r.raise_for_status()
+            if r.encoding is None:
+                r.encoding = "utf-8"
+            return json.loads(r.text)
+        except requests.HTTPError as e:
+            if r.text.startswith("Timeout"):
+                raise TimeoutException(r.text)
+            else:
+                raise AnnotationException(r.text)
+        except json.JSONDecodeError:
+            raise AnnotationException(r.text)
+
+
+def read_corenlp_props(props_path):
+    """ Read a Stanford CoreNLP properties file into a dict """
+    props_dict = {}
+    with open(props_path) as props_file:
+        entry_lines = [entry_line for entry_line in props_file.read().split('\n')
+                       if entry_line.strip() and not entry_line.startswith('#')]
+        for entry_line in entry_lines:
+            k = entry_line.split('=')[0]
+            k_len = len(k+"=")
+            v = entry_line[k_len:]
+            props_dict[k.strip()] = v
+    return props_dict
+
+
+def write_corenlp_props(props_dict, file_path=None):
+    """ Write a Stanford CoreNLP properties dict to a file """
+    if file_path is None:
+        file_path = f"corenlp_server-{uuid.uuid4().hex[:16]}.props"
+        # confirm tmp file path matches pattern
+        assert SERVER_PROPS_TMP_FILE_PATTERN.match(file_path)
+    with open(file_path, 'w') as props_file:
+        for k, v in props_dict.items():
+            if isinstance(v, list):
+                writeable_v = ",".join(v)
+            else:
+                writeable_v = v
+            props_file.write(f'{k} = {writeable_v}\n\n')
+    return file_path
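+# Round-trip sketch for the two helpers above (the file name is hypothetical):
+# write_corenlp_props({"annotators": ["tokenize", "ssplit"]}) writes the line
+# "annotators = tokenize,ssplit" to a corenlp_server-*.props file, and
+# read_corenlp_props() on that file yields the annotators key again (note that
+# values keep any surrounding whitespace, since only the keys are stripped).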
+
|
| 766 |
+
|
| 767 |
+
def regex_matches_to_indexed_words(matches):
|
| 768 |
+
"""
|
| 769 |
+
Transforms tokensregex and semgrex matches to indexed words.
|
| 770 |
+
:param matches: unprocessed regex matches
|
| 771 |
+
:return: flat array of indexed words
|
| 772 |
+
"""
|
| 773 |
+
words = [dict(v, **dict([('sentence', i)]))
|
| 774 |
+
for i, s in enumerate(matches['sentences'])
|
| 775 |
+
for k, v in s.items() if k != 'length']
|
| 776 |
+
return words
|
| 777 |
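+# Shape sketch (a hypothetical single-sentence, single-match response):
+# {"sentences": [{"0": {"text": "fox", ...}, "length": 1}]} flattens to
+# [{"text": "fox", ..., "sentence": 0}] - one entry per match, tagged with
+# the index of the sentence it came from.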
+
|
| 778 |
+
|
| 779 |
+
__all__ = ["CoreNLPClient", "AnnotationException", "TimeoutException", "to_text"]
|
stanza/stanza/server/semgrex.py
ADDED
@@ -0,0 +1,170 @@
+"""Invokes the Java semgrex on a document
+
+The server client has a method "semgrex" which sends text to Java
+CoreNLP for processing with a semgrex (SEMantic GRaph regEX) query:
+
+https://nlp.stanford.edu/nlp/javadoc/javanlp/edu/stanford/nlp/semgraph/semgrex/SemgrexPattern.html
+
+However, this operates on text using the CoreNLP tools, which means
+the dependency graphs may not align with stanza's depparse module, and
+this also limits the languages for which it can be used.  This module
+allows for running semgrex commands on the graphs produced by
+depparse.
+
+To use, first process text into a doc using stanza.Pipeline
+
+Next, pass the processed doc and a list of semgrex patterns to
+process_doc in this module.  It will run the java semgrex module as a
+subprocess and return the result in the form of a SemgrexResponse,
+whose description is in the proto file included with stanza.
+
+A minimal example is the main method of this module.
+
+Note that launching the subprocess is potentially quite expensive
+relative to the search if used many times on small documents.  Ideally
+larger texts would be processed, and all of the desired semgrex
+patterns would be run at once.  The worst thing to do would be to call
+this multiple times on a large document, one invocation per semgrex
+pattern, as that would serialize the document each time.
+Included here is a context manager which allows for keeping the same
+java process open for multiple requests.  This saves on the subprocess
+launching time.  It is still important not to wastefully serialize the
+same document over and over, though.
+"""
+
+import argparse
+import copy
+
+import stanza
+from stanza.protobuf import SemgrexRequest, SemgrexResponse
+from stanza.server.java_protobuf_requests import send_request, add_token, add_word_to_graph, JavaProtobufContext, convert_networkx_graph
+from stanza.utils.conll import CoNLL
+
+SEMGREX_JAVA = "edu.stanford.nlp.semgraph.semgrex.ProcessSemgrexRequest"
+
+def send_semgrex_request(request):
+    return send_request(request, SemgrexResponse, SEMGREX_JAVA)
+
+def build_request(doc, semgrex_patterns, enhanced=False):
+    request = SemgrexRequest()
+    if isinstance(semgrex_patterns, str):
+        semgrex_patterns = [semgrex_patterns]
+    for semgrex in semgrex_patterns:
+        request.semgrex.append(semgrex)
+
+    for sent_idx, sentence in enumerate(doc.sentences):
+        query = request.query.add()
+        if enhanced:
+            # tokens will be added on to the graph object
+            convert_networkx_graph(query.graph, sentence, sent_idx)
+        else:
+            word_idx = 0
+            for token in sentence.tokens:
+                for word in token.words:
+                    add_token(query.token, word, token)
+                    add_word_to_graph(query.graph, word, sent_idx, word_idx)
+
+                    word_idx = word_idx + 1
+
+    return request
+
+def process_doc(doc, *semgrex_patterns, enhanced=False):
+    """
+    Returns the result of processing the given semgrex expression on the stanza doc.
+
+    Currently the return is a SemgrexResponse from CoreNLP.proto
+    """
+    request = build_request(doc, semgrex_patterns, enhanced=enhanced)
+
+    return send_semgrex_request(request)
+
+class Semgrex(JavaProtobufContext):
+    """
+    Semgrex context window
+
+    This is a context window which keeps a process open.  Should allow
+    for multiple requests without launching new java processes each time.
+    """
+    def __init__(self, classpath=None):
+        super(Semgrex, self).__init__(classpath, SemgrexResponse, SEMGREX_JAVA)
+
+    def process(self, doc, *semgrex_patterns):
+        """
+        Apply each of the semgrex patterns to each of the dependency trees in doc
+        """
+        request = build_request(doc, semgrex_patterns)
+        return self.process_request(request)
+
+def annotate_doc(doc, semgrex_result, semgrex_patterns, matches_only):
+    """
+    Put comments on the sentences which describe the matching semgrex patterns
+    """
+    doc = copy.deepcopy(doc)
+    if isinstance(semgrex_patterns, str):
+        semgrex_patterns = [semgrex_patterns]
+    matching_sentences = []
+    for sentence, graph_result in zip(doc.sentences, semgrex_result.result):
+        sentence_matched = False
+        for semgrex_pattern, pattern_result in zip(semgrex_patterns, graph_result.result):
+            semgrex_pattern = semgrex_pattern.replace("\n", " ")
+            if len(pattern_result.match) == 0:
+                sentence.add_comment("# semgrex pattern |%s| did not match!" % semgrex_pattern)
+            else:
+                sentence_matched = True
+                for match in pattern_result.match:
+                    match_word = "%d:%s" % (match.matchIndex, sentence.words[match.matchIndex-1].text)
+                    if len(match.node) == 0:
+                        node_matches = ""
+                    else:
+                        node_matches = ["%s=%d:%s" % (node.name, node.matchIndex, sentence.words[node.matchIndex-1].text)
+                                        for node in match.node]
+                        node_matches = " " + " ".join(node_matches)
+                    sentence.add_comment("# semgrex pattern |%s| matched at %s%s" % (semgrex_pattern, match_word, node_matches))
+        if sentence_matched:
+            matching_sentences.append(sentence)
+    if matches_only:
+        doc.sentences = matching_sentences
+    return doc
+
+
+def main():
+    """
+    Runs a toy example, or can run a given semgrex expression on the given input file.
+
+    For example:
+      python3 -m stanza.server.semgrex --input_file demo/semgrex_sample.conllu
+
+    --matches_only to only print sentences that match the semgrex pattern
+    --no_print_input to not print the input
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input_file', type=str, default=None, help="Input file to process (otherwise will process a sample text)")
+    parser.add_argument('semgrex', type=str, nargs="*", default=["{}=source >obj=zzz {}=target"], help="Semgrex to apply to the text.  The default looks for sentences with objects")
+    parser.add_argument('--semgrex_file', type=str, default=None, help="File to read semgrex patterns from - relevant in case the pattern you want to use doesn't work well on the command line, for example")
+    parser.add_argument('--print_input', dest='print_input', action='store_true', default=False, help="Print the input alongside the output - gets kind of noisy")
+    parser.add_argument('--no_print_input', dest='print_input', action='store_false', help="Don't print the input alongside the output - gets kind of noisy")
+    parser.add_argument('--matches_only', action='store_true', default=False, help="Only print the matching sentences")
+    parser.add_argument('--enhanced', action='store_true', default=False, help='Use the enhanced dependencies instead of the basic')
+    args = parser.parse_args()
+
+    if args.semgrex_file:
+        with open(args.semgrex_file) as fin:
+            args.semgrex = [x.strip() for x in fin.readlines() if x.strip()]
+
+    if args.input_file:
+        doc = CoNLL.conll2doc(input_file=args.input_file, ignore_gapping=False)
+    else:
+        nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma,depparse')
+        doc = nlp('Uro ruined modern.  Fortunately, Wotc banned him.')
+
+    if args.print_input:
+        print("{:C}".format(doc))
+        print()
+        print("-" * 75)
+        print()
+    semgrex_result = process_doc(doc, *args.semgrex, enhanced=args.enhanced)
+    doc = annotate_doc(doc, semgrex_result, args.semgrex, args.matches_only)
+    print("{:C}".format(doc))
+
+if __name__ == '__main__':
+    main()
stanza/stanza/server/ssurgeon.py
ADDED
@@ -0,0 +1,310 @@
+"""Invokes the Java ssurgeon on a document
+
+"ssurgeon" sends text to Java CoreNLP for processing with a ssurgeon
+(Semantic graph SURGEON) query
+
+The main program in this file gives a very short intro to how to use it.
+"""
+
+
+import argparse
+from collections import namedtuple
+import copy
+import os
+import re
+import sys
+
+from stanza.models.common.utils import misc_to_space_after, space_after_to_misc
+from stanza.protobuf import SsurgeonRequest, SsurgeonResponse
+from stanza.server import java_protobuf_requests
+from stanza.utils.conll import CoNLL
+
+from stanza.models.common.doc import ID, TEXT, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC, START_CHAR, END_CHAR, NER, Word, Token, Sentence
+
+SSURGEON_JAVA = "edu.stanford.nlp.semgraph.semgrex.ssurgeon.ProcessSsurgeonRequest"
+
+SsurgeonEdit = namedtuple("SsurgeonEdit",
+                          "semgrex_pattern ssurgeon_edits ssurgeon_id notes language",
+                          defaults=[None, None, "UniversalEnglish"])
+
+def parse_ssurgeon_edits(ssurgeon_text):
+    ssurgeon_text = ssurgeon_text.strip()
+    ssurgeon_blocks = re.split("\n\n+", ssurgeon_text)
+    ssurgeon_edits = []
+    for idx, block in enumerate(ssurgeon_blocks):
+        lines = block.split("\n")
+        comments = [line[1:].strip() for line in lines if line.startswith("#")]
+        notes = " ".join(comments)
+        lines = [x.strip() for x in lines if x.strip() and not x.startswith("#")]
+        if len(lines) == 0:
+            # was a block of entirely comments
+            continue
+        semgrex = lines[0]
+        ssurgeon = lines[1:]
+        ssurgeon_edits.append(SsurgeonEdit(semgrex, ssurgeon, "%d" % (idx + 1), notes))
+    return ssurgeon_edits
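+# Edit-file format sketch for the parser above: blocks are separated by blank
+# lines; "#" lines become the notes, the first remaining line is the semgrex
+# pattern, and every later line is one ssurgeon operation, e.g.:
+#
+#   # relabel an unwanted csubj to advcl
+#   {}=source >nsubj {} >csubj=bad {}
+#   relabelNamedEdge -edge bad -reln advcl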
| 46 |
+
|
| 47 |
+
def read_ssurgeon_edits(edit_file):
|
| 48 |
+
with open(edit_file, encoding="utf-8") as fin:
|
| 49 |
+
return parse_ssurgeon_edits(fin.read())
|
| 50 |
+
|
| 51 |
+
def send_ssurgeon_request(request):
|
| 52 |
+
return java_protobuf_requests.send_request(request, SsurgeonResponse, SSURGEON_JAVA)
|
| 53 |
+
|
| 54 |
+
def build_request(doc, ssurgeon_edits):
|
| 55 |
+
request = SsurgeonRequest()
|
| 56 |
+
|
| 57 |
+
for ssurgeon in ssurgeon_edits:
|
| 58 |
+
ssurgeon_proto = request.ssurgeon.add()
|
| 59 |
+
ssurgeon_proto.semgrex = ssurgeon.semgrex_pattern
|
| 60 |
+
for operation in ssurgeon.ssurgeon_edits:
|
| 61 |
+
ssurgeon_proto.operation.append(operation)
|
| 62 |
+
if ssurgeon.ssurgeon_id is not None:
|
| 63 |
+
ssurgeon_proto.id = ssurgeon.ssurgeon_id
|
| 64 |
+
if ssurgeon.notes is not None:
|
| 65 |
+
ssurgeon_proto.notes = ssurgeon.notes
|
| 66 |
+
if ssurgeon.language is not None:
|
| 67 |
+
ssurgeon_proto.language = ssurgeon.language
|
| 68 |
+
|
| 69 |
+
try:
|
| 70 |
+
for sent_idx, sentence in enumerate(doc.sentences):
|
| 71 |
+
graph = request.graph.add()
|
| 72 |
+
word_idx = 0
|
| 73 |
+
for token in sentence.tokens:
|
| 74 |
+
for word in token.words:
|
| 75 |
+
java_protobuf_requests.add_token(graph.token, word, token)
|
| 76 |
+
java_protobuf_requests.add_word_to_graph(graph, word, sent_idx, word_idx)
|
| 77 |
+
|
| 78 |
+
word_idx = word_idx + 1
|
| 79 |
+
except Exception as e:
|
| 80 |
+
raise RuntimeError("Failed to process sentence {}:\n{:C}".format(sent_idx, sentence)) from e
|
| 81 |
+
|
| 82 |
+
return request
|
| 83 |
+
|
| 84 |
+
def build_request_one_operation(doc, semgrex_pattern, ssurgeon_edits, ssurgeon_id=None, notes=None):
|
| 85 |
+
ssurgeon_edit = SsurgeonEdit(semgrex_pattern, ssurgeon_edits, ssurgeon_id, notes)
|
| 86 |
+
return build_request(doc, [ssurgeon_edit])
|
| 87 |
+
|
| 88 |
+
def process_doc(doc, ssurgeon_edits):
|
| 89 |
+
"""
|
| 90 |
+
Returns the result of processing the given semgrex expression and ssurgeon edits on the stanza doc.
|
| 91 |
+
|
| 92 |
+
Currently the return is a SsurgeonResponse from CoreNLP.proto
|
| 93 |
+
"""
|
| 94 |
+
request = build_request(doc, ssurgeon_edits)
|
| 95 |
+
|
| 96 |
+
return send_ssurgeon_request(request)
|
| 97 |
+
|
| 98 |
+
def process_doc_one_operation(doc, semgrex_pattern, ssurgeon_edits, ssurgeon_id=None, notes=None):
|
| 99 |
+
request = build_request_one_operation(doc, semgrex_pattern, ssurgeon_edits, ssurgeon_id, notes)
|
| 100 |
+
|
| 101 |
+
return send_ssurgeon_request(request)
|
| 102 |
+
|
| 103 |
+
def build_word_entry(word_index, graph_word):
|
| 104 |
+
word_entry = {
|
| 105 |
+
ID: word_index,
|
| 106 |
+
TEXT: graph_word.word if graph_word.word else None,
|
| 107 |
+
LEMMA: graph_word.lemma if graph_word.lemma else None,
|
| 108 |
+
UPOS: graph_word.coarseTag if graph_word.coarseTag else None,
|
| 109 |
+
XPOS: graph_word.pos if graph_word.pos else None,
|
| 110 |
+
FEATS: java_protobuf_requests.features_to_string(graph_word.conllUFeatures),
|
| 111 |
+
DEPS: None,
|
| 112 |
+
NER: graph_word.ner if graph_word.ner else None,
|
| 113 |
+
MISC: None,
|
| 114 |
+
START_CHAR: None, # TODO: fix this? one problem is the text positions
|
| 115 |
+
END_CHAR: None, # might change across all of the sentences
|
| 116 |
+
# presumably python will complain if this conflicts
|
| 117 |
+
# with one of the constants above
|
| 118 |
+
"is_mwt": graph_word.isMWT,
|
| 119 |
+
"is_first_mwt": graph_word.isFirstMWT,
|
| 120 |
+
"mwt_text": graph_word.mwtText,
|
| 121 |
+
"mwt_misc": graph_word.mwtMisc,
|
| 122 |
+
}
|
| 123 |
+
# TODO: do "before" as well
|
| 124 |
+
word_entry[MISC] = space_after_to_misc(graph_word.after)
|
| 125 |
+
if graph_word.conllUMisc:
|
| 126 |
+
word_entry[MISC] = java_protobuf_requests.substitute_space_misc(graph_word.conllUMisc, word_entry[MISC])
|
| 127 |
+
return word_entry
|
| 128 |
+
|
| 129 |
+
def convert_response_to_doc(doc, semgrex_response):
|
| 130 |
+
doc = copy.deepcopy(doc)
|
| 131 |
+
try:
|
| 132 |
+
for sent_idx, (sentence, ssurgeon_result) in enumerate(zip(doc.sentences, semgrex_response.result)):
|
| 133 |
+
# EditNode is currently bugged... :/
|
| 134 |
+
# TODO: change this after next CoreNLP release (after 4.5.3)
|
| 135 |
+
#if not ssurgeon_result.changed:
|
| 136 |
+
# continue
|
| 137 |
+
|
| 138 |
+
ssurgeon_graph = ssurgeon_result.graph
|
| 139 |
+
tokens = []
|
| 140 |
+
for graph_node, graph_word in zip(ssurgeon_graph.node, ssurgeon_graph.token):
|
| 141 |
+
word_entry = build_word_entry(graph_node.index, graph_word)
|
| 142 |
+
tokens.append(word_entry)
|
| 143 |
+
tokens.sort(key=lambda x: x[ID])
|
| 144 |
+
for root in ssurgeon_graph.root:
|
| 145 |
+
tokens[root-1][HEAD] = 0
|
| 146 |
+
tokens[root-1][DEPREL] = "root"
|
| 147 |
+
for edge in ssurgeon_graph.edge:
|
| 148 |
+
# can't do anything about the extra dependencies for now
|
| 149 |
+
# TODO: put them all in .deps
|
| 150 |
+
if edge.isExtra:
|
| 151 |
+
continue
|
| 152 |
+
tokens[edge.target-1][HEAD] = edge.source
|
| 153 |
+
tokens[edge.target-1][DEPREL] = edge.dep
|
| 154 |
+
|
| 155 |
+
# for any MWT, produce a token_entry which represents the word range
|
| 156 |
+
mwt_tokens = []
|
| 157 |
+
for word_start_idx, word in enumerate(tokens):
|
| 158 |
+
if not word["is_first_mwt"]:
|
| 159 |
+
if word["is_mwt"]:
|
| 160 |
+
word[MISC] = java_protobuf_requests.remove_space_misc(word[MISC])
|
| 161 |
+
mwt_tokens.append(word)
|
| 162 |
+
continue
|
| 163 |
+
word_end_idx = word_start_idx + 1
|
| 164 |
+
while word_end_idx < len(tokens) and tokens[word_end_idx]["is_mwt"] and not tokens[word_end_idx]["is_first_mwt"]:
|
| 165 |
+
word_end_idx += 1
|
| 166 |
+
mwt_token_entry = {
|
| 167 |
+
# the tokens don't fencepost the way lists do
|
| 168 |
+
ID: (tokens[word_start_idx][ID], tokens[word_end_idx-1][ID]),
|
| 169 |
+
TEXT: word["mwt_text"],
|
| 170 |
+
NER: word[NER],
|
| 171 |
+
# use the SpaceAfter=No (or not) from the last word in the token
|
| 172 |
+
MISC: None,
|
| 173 |
+
}
|
| 174 |
+
mwt_token_entry[MISC] = java_protobuf_requests.misc_space_pieces(tokens[word_end_idx-1][MISC])
|
| 175 |
+
if tokens[word_end_idx-1]["mwt_misc"]:
|
| 176 |
+
mwt_token_entry[MISC] = java_protobuf_requests.substitute_space_misc(tokens[word_end_idx-1]["mwt_misc"], mwt_token_entry[MISC])
|
| 177 |
+
word[MISC] = java_protobuf_requests.remove_space_misc(word[MISC])
|
| 178 |
+
mwt_tokens.append(mwt_token_entry)
|
| 179 |
+
mwt_tokens.append(word)
|
| 180 |
+
|
| 181 |
+
old_comments = list(sentence.comments)
|
| 182 |
+
sentence = Sentence(mwt_tokens, doc)
|
| 183 |
+
|
| 184 |
+
token_text = []
|
| 185 |
+
for token_idx, token in enumerate(sentence.tokens):
|
| 186 |
+
token_text.append(token.text)
|
| 187 |
+
if token_idx == len(sentence.tokens) - 1:
|
| 188 |
+
break
|
| 189 |
+
token_text.append(token.spaces_after)
|
| 190 |
+
|
| 191 |
+
sentence_text = "".join(token_text)
|
| 192 |
+
|
| 193 |
+
for comment in old_comments:
|
| 194 |
+
if comment.startswith("# text ") or comment.startswith("#text ") or comment.startswith("# text=") or comment.startswith("#text="):
|
| 195 |
+
sentence.add_comment("# text = " + sentence_text)
|
| 196 |
+
else:
|
| 197 |
+
sentence.add_comment(comment)
|
| 198 |
+
|
| 199 |
+
doc.sentences[sent_idx] = sentence
|
| 200 |
+
|
| 201 |
+
sentence.rebuild_dependencies()
|
| 202 |
+
except Exception as e:
|
| 203 |
+
raise RuntimeError("Ssurgeon could not process sentence {}\nSsurgeon result:\n{}\nOriginal sentence:\n{:C}".format(sent_idx, ssurgeon_result, sentence)) from e
|
| 204 |
+
return doc
|
| 205 |
+
|
| 206 |
+
class Ssurgeon(java_protobuf_requests.JavaProtobufContext):
|
| 207 |
+
"""
|
| 208 |
+
Ssurgeon context window
|
| 209 |
+
|
| 210 |
+
This is a context window which keeps a process open. Should allow
|
| 211 |
+
for multiple requests without launching new java processes each time.
|
| 212 |
+
"""
|
| 213 |
+
def __init__(self, classpath=None):
|
| 214 |
+
super(Ssurgeon, self).__init__(classpath, SsurgeonResponse, SSURGEON_JAVA)
|
| 215 |
+
|
| 216 |
+
def process(self, doc, ssurgeon_edits):
|
| 217 |
+
"""
|
| 218 |
+
Apply each of the ssurgeon patterns to each of the dependency trees in doc
|
| 219 |
+
"""
|
| 220 |
+
request = build_request(doc, ssurgeon_edits)
|
| 221 |
+
return self.process_request(request)
|
| 222 |
+
|
| 223 |
+
def process_one_operation(self, doc, semgrex_pattern, ssurgeon_edits, ssurgeon_id=None, notes=None):
|
| 224 |
+
"""
|
| 225 |
+
Convenience method - build one operation, then apply it
|
| 226 |
+
"""
|
| 227 |
+
request = build_request_one_operation(doc, semgrex_pattern, ssurgeon_edits, ssurgeon_id, notes)
|
| 228 |
+
return self.process_request(request)
|
| 229 |
+
|
| 230 |
+
SAMPLE_DOC = """
|
| 231 |
+
# sent_id = 271
|
| 232 |
+
# text = Hers is easy to clean.
|
| 233 |
+
# previous = What did the dealer like about Alex's car?
|
| 234 |
+
# comment = extraction/raising via "tough extraction" and clausal subject
|
| 235 |
+
1 Hers hers PRON PRP Gender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs 3 nsubj _ _
|
| 236 |
+
2 is be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 3 cop _ _
|
| 237 |
+
3 easy easy ADJ JJ Degree=Pos 0 root _ _
|
| 238 |
+
4 to to PART TO _ 5 mark _ _
|
| 239 |
+
5 clean clean VERB VB VerbForm=Inf 3 csubj _ SpaceAfter=No
|
| 240 |
+
6 . . PUNCT . _ 5 punct _ _
|
| 241 |
+
"""
|
| 242 |
+
|
| 243 |
+
def main():
|
| 244 |
+
# for Windows, so that we aren't randomly printing garbage (or just failing to print)
|
| 245 |
+
try:
|
| 246 |
+
sys.stdout.reconfigure(encoding='utf-8')
|
| 247 |
+
except AttributeError:
|
| 248 |
+
# TODO: deprecate 3.6 support after the next release
|
| 249 |
+
pass
|
| 250 |
+
+    # The default semgrex detects sentences in the UD_English-Pronouns dataset which have both nsubj and csubj on the same word.
+    # The default ssurgeon transforms the unwanted csubj to advcl
+    # See https://github.com/UniversalDependencies/docs/issues/923
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input_file', type=str, default=None, help="Input file to process (otherwise will process a sample text)")
+    parser.add_argument('--output_file', type=str, default=None, help="Output file (otherwise will write to stdout)")
+    parser.add_argument('--input_dir', type=str, default=None, help="Input dir to process instead of a single file. Allows for reusing the Java program")
+    parser.add_argument('--input_filter', type=str, default=".*[.]conllu", help="Only process files from the input_dir that match this filter - regex, not shell filter. Default: %(default)s")
+    parser.add_argument('--no_input_filter', dest='input_filter', action='store_const', const=None, help="Remove the default input filename filter")
+    parser.add_argument('--output_dir', type=str, default=None, help="Output dir for writing files, necessary if using --input_dir")
+    parser.add_argument('--edit_file', type=str, default=None, help="File to get semgrex and ssurgeon rules from")
+    parser.add_argument('--semgrex', type=str, default="{}=source >nsubj {} >csubj=bad {}", help="Semgrex to apply to the text. A default detects words which have both an nsubj and a csubj. Default: %(default)s")
+    parser.add_argument('ssurgeon', type=str, default=["relabelNamedEdge -edge bad -reln advcl"], nargs="*", help="Ssurgeon edits to apply based on the Semgrex. Can have multiple edits in a row. A default exists to transform csubj into advcl. Default: %(default)s")
+    parser.add_argument('--print_input', dest='print_input', action='store_true', default=False, help="Print the input alongside the output - gets kind of noisy. Default: %(default)s")
+    parser.add_argument('--no_print_input', dest='print_input', action='store_false', help="Don't print the input alongside the output")
+    args = parser.parse_args()
+
+    if args.edit_file:
+        ssurgeon_edits = read_ssurgeon_edits(args.edit_file)
+    else:
+        ssurgeon_edits = [SsurgeonEdit(args.semgrex, args.ssurgeon)]
+
+    if args.input_file:
+        docs = [CoNLL.conll2doc(input_file=args.input_file)]
+        outputs = [args.output_file]
+        input_output = zip(docs, outputs)
+    elif args.input_dir:
+        if not args.output_dir:
+            raise ValueError("Cannot process multiple files without knowing where to send them - please set --output_dir in order to use --input_dir")
+        if not os.path.exists(args.output_dir):
+            os.makedirs(args.output_dir)
+        def read_docs():
+            for doc_filename in os.listdir(args.input_dir):
+                if args.input_filter:
+                    if not re.match(args.input_filter, doc_filename):
+                        continue
+                doc_path = os.path.join(args.input_dir, doc_filename)
+                output_path = os.path.join(args.output_dir, doc_filename)
+                print("Processing %s to %s" % (doc_path, output_path))
+                yield CoNLL.conll2doc(input_file=doc_path), output_path
+        input_output = read_docs()
+    else:
+        docs = [CoNLL.conll2doc(input_str=SAMPLE_DOC)]
+        outputs = [None]
+        input_output = zip(docs, outputs)
+
+    for doc, output in input_output:
+        if args.print_input:
+            print("{:C}".format(doc))
+        ssurgeon_request = build_request(doc, ssurgeon_edits)
+        ssurgeon_response = send_ssurgeon_request(ssurgeon_request)
+        updated_doc = convert_response_to_doc(doc, ssurgeon_response)
+        if output is not None:
+            with open(output, "w", encoding="utf-8") as fout:
+                fout.write("{:C}\n\n".format(updated_doc))
+        else:
+            print("{:C}\n".format(updated_doc))
+
+if __name__ == '__main__':
+    main()
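
For orientation, the flow in main() can also be driven programmatically with the helpers defined earlier in this file. A hedged sketch, not part of the diff; the input path is a placeholder, and the actual processing still happens in the CoreNLP Java process this module launches:

    # apply the default csubj -> advcl rewrite to a single document
    edits = [SsurgeonEdit("{}=source >nsubj {} >csubj=bad {}",
                          ["relabelNamedEdge -edge bad -reln advcl"])]
    doc = CoNLL.conll2doc(input_file="my_input.conllu")  # hypothetical path
    request = build_request(doc, edits)
    response = send_ssurgeon_request(request)
    updated_doc = convert_response_to_doc(doc, response)
    print("{:C}".format(updated_doc))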
stanza/stanza/tests/__init__.py
ADDED
@@ -0,0 +1,111 @@
+"""
+Utilities for testing
+"""
+
+import os
+import re
+
+# Environment Variables
+# set this to specify working directory of tests
+TEST_HOME_VAR = 'STANZA_TEST_HOME'
+
+# Global Variables
+TEST_DIR_BASE_NAME = 'stanza_test'
+
+TEST_WORKING_DIR = os.getenv(TEST_HOME_VAR, None)
+if not TEST_WORKING_DIR:
+    TEST_WORKING_DIR = os.path.join(os.getcwd(), TEST_DIR_BASE_NAME)
+
+TEST_MODELS_DIR = f'{TEST_WORKING_DIR}/models'
+TEST_CORENLP_DIR = f'{TEST_WORKING_DIR}/corenlp_dir'
+
+# server resources
+SERVER_TEST_PROPS = f'{TEST_WORKING_DIR}/scripts/external_server.properties'
+
+# language resources
+LANGUAGE_RESOURCES = {}
+
+TOKENIZE_MODEL = 'tokenizer.pt'
+MWT_MODEL = 'mwt_expander.pt'
+POS_MODEL = 'tagger.pt'
+POS_PRETRAIN = 'pretrain.pt'
+LEMMA_MODEL = 'lemmatizer.pt'
+DEPPARSE_MODEL = 'parser.pt'
+DEPPARSE_PRETRAIN = 'pretrain.pt'
+
+MODEL_FILES = [TOKENIZE_MODEL, MWT_MODEL, POS_MODEL, POS_PRETRAIN, LEMMA_MODEL, DEPPARSE_MODEL, DEPPARSE_PRETRAIN]
+
+# English resources
+EN_KEY = 'en'
+EN_SHORTHAND = 'en_ewt'
+# models
+EN_MODELS_DIR = f'{TEST_WORKING_DIR}/models/{EN_SHORTHAND}_models'
+EN_MODEL_FILES = [f'{EN_MODELS_DIR}/{EN_SHORTHAND}_{model_fname}' for model_fname in MODEL_FILES]
+
+# French resources
+FR_KEY = 'fr'
+FR_SHORTHAND = 'fr_gsd'
+# regression file paths
+FR_TEST_IN = f'{TEST_WORKING_DIR}/in/fr_gsd.test.txt'
+FR_TEST_OUT = f'{TEST_WORKING_DIR}/out/fr_gsd.test.txt.out'
+FR_TEST_GOLD_OUT = f'{TEST_WORKING_DIR}/out/fr_gsd.test.txt.out.gold'
+# models
+FR_MODELS_DIR = f'{TEST_WORKING_DIR}/models/{FR_SHORTHAND}_models'
+FR_MODEL_FILES = [f'{FR_MODELS_DIR}/{FR_SHORTHAND}_{model_fname}' for model_fname in MODEL_FILES]
+
+# Other language resources
+AR_SHORTHAND = 'ar_padt'
+DE_SHORTHAND = 'de_gsd'
+KK_SHORTHAND = 'kk_ktb'
+KO_SHORTHAND = 'ko_gsd'
+
+
+# utils for clean up
+# only allow removal of dirs/files in this approved list
+REMOVABLE_PATHS = ['en_ewt_models', 'en_ewt_tokenizer.pt', 'en_ewt_mwt_expander.pt', 'en_ewt_tagger.pt',
+                   'en_ewt.pretrain.pt', 'en_ewt_lemmatizer.pt', 'en_ewt_parser.pt', 'fr_gsd_models',
+                   'fr_gsd_tokenizer.pt', 'fr_gsd_mwt_expander.pt', 'fr_gsd_tagger.pt', 'fr_gsd.pretrain.pt',
+                   'fr_gsd_lemmatizer.pt', 'fr_gsd_parser.pt', 'ar_padt_models', 'ar_padt_tokenizer.pt',
+                   'ar_padt_mwt_expander.pt', 'ar_padt_tagger.pt', 'ar_padt.pretrain.pt', 'ar_padt_lemmatizer.pt',
+                   'ar_padt_parser.pt', 'de_gsd_models', 'de_gsd_tokenizer.pt', 'de_gsd_mwt_expander.pt',
+                   'de_gsd_tagger.pt', 'de_gsd.pretrain.pt', 'de_gsd_lemmatizer.pt', 'de_gsd_parser.pt',
+                   'kk_ktb_models', 'kk_ktb_tokenizer.pt', 'kk_ktb_mwt_expander.pt', 'kk_ktb_tagger.pt',
+                   'kk_ktb.pretrain.pt', 'kk_ktb_lemmatizer.pt', 'kk_ktb_parser.pt', 'ko_gsd_models',
+                   'ko_gsd_tokenizer.pt', 'ko_gsd_mwt_expander.pt', 'ko_gsd_tagger.pt', 'ko_gsd.pretrain.pt',
+                   'ko_gsd_lemmatizer.pt', 'ko_gsd_parser.pt']
+
+
+def safe_rm(path_to_rm):
+    """
+    Safely remove a directory of files or a file
+    1.) check path exists, files are files, dirs are dirs
+    2.) only remove things on approved list REMOVABLE_PATHS
+    3.) assert no longer exists
+    """
+    # just return if path doesn't exist
+    if not os.path.exists(path_to_rm):
+        return
+    # handle directory
+    if os.path.isdir(path_to_rm):
+        files_to_rm = [f'{path_to_rm}/{fname}' for fname in os.listdir(path_to_rm)]
+        dir_to_rm = path_to_rm
+    else:
+        files_to_rm = [path_to_rm]
+        dir_to_rm = None
+    # clear out files
+    for file_to_rm in files_to_rm:
+        if os.path.isfile(file_to_rm) and os.path.basename(file_to_rm) in REMOVABLE_PATHS:
+            os.remove(file_to_rm)
+            assert not os.path.exists(file_to_rm), f'Error removing: {file_to_rm}'
+    # clear out directory
+    if dir_to_rm is not None and os.path.isdir(dir_to_rm):
+        os.rmdir(dir_to_rm)
+        assert not os.path.exists(dir_to_rm), f'Error removing: {dir_to_rm}'
+
+def compare_ignoring_whitespace(predicted, expected):
+    predicted = re.sub('[ \t]+', ' ', predicted.strip())
+    predicted = re.sub('\r\n', '\n', predicted)
+    expected = re.sub('[ \t]+', ' ', expected.strip())
+    expected = re.sub('\r\n', '\n', expected)
+    assert predicted == expected
+
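
Two quick illustrations of the helpers above (sketches, not part of the diff):

    # compare_ignoring_whitespace collapses runs of spaces/tabs and normalizes
    # \r\n to \n on both sides before asserting equality, so this passes:
    compare_ignoring_whitespace("a  b\r\nc", "a b\nc")

    # safe_rm refuses to touch anything not named in REMOVABLE_PATHS,
    # so this call is a no-op rather than a deletion:
    safe_rm("precious.txt")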
stanza/stanza/tests/data/tiny_emb.txt
ADDED
@@ -0,0 +1,4 @@
+3 4
+unban 1 2 3 4
+mox 5 6 7 8
+opal 9 10 11 12
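
The header line `3 4` follows the word2vec text convention: 3 vectors of dimension 4, one word per following row. A minimal reader, as a sketch (the path is assumed relative to the repo root):

    with open("stanza/stanza/tests/data/tiny_emb.txt", encoding="utf-8") as fin:
        n_words, dim = map(int, fin.readline().split())
        vectors = {parts[0]: [float(x) for x in parts[1:]]
                   for parts in (line.split() for line in fin)}
    assert len(vectors) == n_words                            # unban, mox, opal
    assert all(len(vec) == dim for vec in vectors.values())   # each of length 4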
stanza/stanza/tests/datasets/test_common.py
ADDED
@@ -0,0 +1,76 @@
+"""
+Test conllu manipulating routines in stanza/utils/datasets/common.py
+"""
+
+import pytest
+
+
+from stanza.utils.datasets.common import maybe_add_fake_dependencies
+# from stanza.tests import *
+
+pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
+
+DEPS_EXAMPLE="""
+# text = Sh'reyan's antennae are hella thicc
+1 Sh'reyan Sh'reyan PROPN NNP Number=Sing 3 nmod:poss 3:nmod:poss SpaceAfter=No
+2 's 's PART POS _ 1 case 1:case _
+3 antennae antenna NOUN NNS Number=Plur 6 nsubj 6:nsubj _
+4 are be VERB VBP Mood=Ind|Tense=Pres|VerbForm=Fin 6 cop 6:cop _
+5 hella hella ADV RB _ 6 advmod 6:advmod _
+6 thicc thicc ADJ JJ Degree=Pos 0 root 0:root _
+""".strip().split("\n")
+
+
+ONLY_ROOT_EXAMPLE="""
+# text = Sh'reyan's antennae are hella thicc
+1 Sh'reyan Sh'reyan PROPN NNP Number=Sing _ _ _ SpaceAfter=No
+2 's 's PART POS _ _ _ _ _
+3 antennae antenna NOUN NNS Number=Plur _ _ _ _
+4 are be VERB VBP Mood=Ind|Tense=Pres|VerbForm=Fin _ _ _ _
+5 hella hella ADV RB _ _ _ _ _
+6 thicc thicc ADJ JJ Degree=Pos 0 root 0:root _
+""".strip().split("\n")
+
+ONLY_ROOT_EXPECTED="""
+# text = Sh'reyan's antennae are hella thicc
+1 Sh'reyan Sh'reyan PROPN NNP Number=Sing 6 dep _ SpaceAfter=No
+2 's 's PART POS _ 1 dep _ _
+3 antennae antenna NOUN NNS Number=Plur 1 dep _ _
+4 are be VERB VBP Mood=Ind|Tense=Pres|VerbForm=Fin 1 dep _ _
+5 hella hella ADV RB _ 1 dep _ _
+6 thicc thicc ADJ JJ Degree=Pos 0 root 0:root _
+""".strip().split("\n")
+
+NO_DEPS_EXAMPLE="""
+# text = Sh'reyan's antennae are hella thicc
+1 Sh'reyan Sh'reyan PROPN NNP Number=Sing _ _ _ SpaceAfter=No
+2 's 's PART POS _ _ _ _ _
+3 antennae antenna NOUN NNS Number=Plur _ _ _ _
+4 are be VERB VBP Mood=Ind|Tense=Pres|VerbForm=Fin _ _ _ _
+5 hella hella ADV RB _ _ _ _ _
+6 thicc thicc ADJ JJ Degree=Pos _ _ _ _
+""".strip().split("\n")
+
+NO_DEPS_EXPECTED="""
+# text = Sh'reyan's antennae are hella thicc
+1 Sh'reyan Sh'reyan PROPN NNP Number=Sing 0 root _ SpaceAfter=No
+2 's 's PART POS _ 1 dep _ _
+3 antennae antenna NOUN NNS Number=Plur 1 dep _ _
+4 are be VERB VBP Mood=Ind|Tense=Pres|VerbForm=Fin 1 dep _ _
+5 hella hella ADV RB _ 1 dep _ _
+6 thicc thicc ADJ JJ Degree=Pos 1 dep _ _
+""".strip().split("\n")
+
+
+def test_fake_deps_no_change():
+    result = maybe_add_fake_dependencies(DEPS_EXAMPLE)
+    assert result == DEPS_EXAMPLE
+
+def test_fake_deps_all_tokens():
+    result = maybe_add_fake_dependencies(NO_DEPS_EXAMPLE)
+    assert result == NO_DEPS_EXPECTED
+
+
+def test_fake_deps_only_root():
+    result = maybe_add_fake_dependencies(ONLY_ROOT_EXAMPLE)
+    assert result == ONLY_ROOT_EXPECTED
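
Reading the convention off the expected fixtures, since the function itself is defined elsewhere (an observed summary, not documentation):

    # CoNLL-U columns 7/8 are HEAD/DEPREL:
    #   NO_DEPS (no arcs at all):  word 1 -> "0 root"; words 2..n -> "1 dep"
    #   ONLY_ROOT (root arc only): word 1 -> the existing root, as "dep";
    #                              remaining words -> "1 dep"
    # i.e. just enough fake structure to yield a connected tree.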
stanza/stanza/tests/datasets/test_vietnamese_renormalization.py
ADDED
@@ -0,0 +1,35 @@
+import pytest
+import os
+
+from stanza.utils.datasets.vietnamese import renormalize
+
+pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
+
+def test_replace_all():
+    text = "SỌAmple tụy test file"
+    expected = "SOẠmple tuỵ test file"
+
+    assert renormalize.replace_all(text) == expected
+
+def test_replace_file(tmp_path):
+    text = "SỌAmple tụy test file"
+    expected = "SOẠmple tuỵ test file"
+
+    orig = tmp_path / "orig.txt"
+    converted = tmp_path / "converted.txt"
+
+    with open(orig, "w", encoding="utf-8") as fout:
+        for i in range(10):
+            fout.write(text)
+            fout.write("\n")
+
+    renormalize.convert_file(orig, converted)
+
+    assert os.path.exists(converted)
+    with open(converted, encoding="utf-8") as fin:
+        lines = fin.readlines()
+
+    assert len(lines) == 10
+    for i in lines:
+        assert i.strip() == expected
+
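
For context: the renormalization moves the tone mark onto the conventional vowel of the syllable, as the fixtures show (SỌAmple becomes SOẠmple, tụy becomes tuỵ); replace_all transforms a string, and convert_file applies the same transformation to a file line by line. An illustrative one-liner on the same data:

    assert renormalize.replace_all("tụy") == "tuỵ"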
stanza/stanza/tests/depparse/__init__.py
ADDED
File without changes
stanza/stanza/tests/depparse/test_depparse_data.py
ADDED
@@ -0,0 +1,62 @@
+"""
+Test some pieces of the depparse dataloader
+"""
+import pytest
+from stanza.models.depparse.data import data_to_batches
+
+pytestmark = [pytest.mark.travis, pytest.mark.pipeline]
+
+def make_fake_data(*lengths):
+    data = []
+    for i, length in enumerate(lengths):
+        word = chr(ord('A') + i)
+        chunk = [[word] * length]
+        data.append(chunk)
+    return data
+
+def check_batches(batched_data, expected_sizes, expected_order):
+    for chunk, size in zip(batched_data, expected_sizes):
+        assert sum(len(x[0]) for x in chunk) == size
+    word_order = []
+    for chunk in batched_data:
+        for sentence in chunk:
+            word_order.append(sentence[0][0])
+    assert word_order == expected_order
+
+def test_data_to_batches_eval_mode():
+    """
+    Tests the chunking of batches in eval_mode
+
+    A few options are tested, such as whether or not to sort and the maximum sentence size
+    """
+    data = make_fake_data(1, 2, 3)
+    batched_data = data_to_batches(data, batch_size=5, eval_mode=True, sort_during_eval=True, min_length_to_batch_separately=None)
+    check_batches(batched_data[0], [5, 1], ['C', 'B', 'A'])
+
+    data = make_fake_data(1, 2, 6)
+    batched_data = data_to_batches(data, batch_size=5, eval_mode=True, sort_during_eval=True, min_length_to_batch_separately=None)
+    check_batches(batched_data[0], [6, 3], ['C', 'B', 'A'])
+
+    data = make_fake_data(3, 2, 1)
+    batched_data = data_to_batches(data, batch_size=5, eval_mode=True, sort_during_eval=True, min_length_to_batch_separately=None)
+    check_batches(batched_data[0], [5, 1], ['A', 'B', 'C'])
+
+    data = make_fake_data(3, 5, 2)
+    batched_data = data_to_batches(data, batch_size=5, eval_mode=True, sort_during_eval=True, min_length_to_batch_separately=None)
+    check_batches(batched_data[0], [5, 5], ['B', 'A', 'C'])
+
+    data = make_fake_data(3, 5, 2)
+    batched_data = data_to_batches(data, batch_size=5, eval_mode=True, sort_during_eval=False, min_length_to_batch_separately=3)
+    check_batches(batched_data[0], [3, 5, 2], ['A', 'B', 'C'])
+
+    data = make_fake_data(4, 1, 1)
+    batched_data = data_to_batches(data, batch_size=5, eval_mode=True, sort_during_eval=False, min_length_to_batch_separately=3)
+    check_batches(batched_data[0], [4, 2], ['A', 'B', 'C'])
+
+    data = make_fake_data(1, 4, 1)
+    batched_data = data_to_batches(data, batch_size=5, eval_mode=True, sort_during_eval=False, min_length_to_batch_separately=3)
+    check_batches(batched_data[0], [1, 4, 1], ['A', 'B', 'C'])
+
+if __name__ == '__main__':
+    test_data_to_batches_eval_mode()
+
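
To make the first check concrete: with sentences of lengths 1 ('A'), 2 ('B') and 3 ('C') and batch_size=5, sorting for eval puts the longest sentences first, so 'C' and 'B' fill one batch of exactly 5 words and 'A' spills into a second batch of size 1. The batched_data[0] indexing suggests data_to_batches also returns bookkeeping for restoring the original sentence order, though that is an inference from the tests rather than from the function's documentation.

    # Worked instance of the first case:
    #   lengths (1, 2, 3) -> sorted descending: C(3), B(2), A(1)
    #   batch 1 = C + B = 5 words (== batch_size); batch 2 = A = 1 word
    #   hence expected sizes [5, 1] and word order ['C', 'B', 'A']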
stanza/stanza/tests/depparse/test_parser.py
ADDED
@@ -0,0 +1,211 @@
+"""
+Run the tagger for a couple iterations on some fake data
+
+Uses a couple sentences of UD_English-EWT as training/dev data
+"""
+
+import os
+import pytest
+
+import torch
+
+from stanza.models import parser
+from stanza.models.common import pretrain
+from stanza.models.depparse.trainer import Trainer
+from stanza.tests import TEST_WORKING_DIR
+
+pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+TRAIN_DATA = """
+# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0003
+# text = DPA: Iraqi authorities announced that they had busted up 3 terrorist cells operating in Baghdad.
+1 DPA DPA PROPN NNP Number=Sing 0 root 0:root SpaceAfter=No
+2 : : PUNCT : _ 1 punct 1:punct _
+3 Iraqi Iraqi ADJ JJ Degree=Pos 4 amod 4:amod _
+4 authorities authority NOUN NNS Number=Plur 5 nsubj 5:nsubj _
+5 announced announce VERB VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 1 parataxis 1:parataxis _
+6 that that SCONJ IN _ 9 mark 9:mark _
+7 they they PRON PRP Case=Nom|Number=Plur|Person=3|PronType=Prs 9 nsubj 9:nsubj _
+8 had have AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 9 aux 9:aux _
+9 busted bust VERB VBN Tense=Past|VerbForm=Part 5 ccomp 5:ccomp _
+10 up up ADP RP _ 9 compound:prt 9:compound:prt _
+11 3 3 NUM CD NumForm=Digit|NumType=Card 13 nummod 13:nummod _
+12 terrorist terrorist ADJ JJ Degree=Pos 13 amod 13:amod _
+13 cells cell NOUN NNS Number=Plur 9 obj 9:obj _
+14 operating operate VERB VBG VerbForm=Ger 13 acl 13:acl _
+15 in in ADP IN _ 16 case 16:case _
+16 Baghdad Baghdad PROPN NNP Number=Sing 14 obl 14:obl:in SpaceAfter=No
+17 . . PUNCT . _ 1 punct 1:punct _
+
+# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0004
+# text = Two of them were being run by 2 officials of the Ministry of the Interior!
+1 Two two NUM CD NumForm=Word|NumType=Card 6 nsubj:pass 6:nsubj:pass _
+2 of of ADP IN _ 3 case 3:case _
+3 them they PRON PRP Case=Acc|Number=Plur|Person=3|PronType=Prs 1 nmod 1:nmod:of _
+4 were be AUX VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 6 aux 6:aux _
+5 being be AUX VBG VerbForm=Ger 6 aux:pass 6:aux:pass _
+6 run run VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root 0:root _
+7 by by ADP IN _ 9 case 9:case _
+8 2 2 NUM CD NumForm=Digit|NumType=Card 9 nummod 9:nummod _
+9 officials official NOUN NNS Number=Plur 6 obl 6:obl:by _
+10 of of ADP IN _ 12 case 12:case _
+11 the the DET DT Definite=Def|PronType=Art 12 det 12:det _
+12 Ministry Ministry PROPN NNP Number=Sing 9 nmod 9:nmod:of _
+13 of of ADP IN _ 15 case 15:case _
+14 the the DET DT Definite=Def|PronType=Art 15 det 15:det _
+15 Interior Interior PROPN NNP Number=Sing 12 nmod 12:nmod:of SpaceAfter=No
+16 ! ! PUNCT . _ 6 punct 6:punct _
+
+""".lstrip()
+
+
+DEV_DATA = """
+1 From from ADP IN _ 3 case 3:case _
+2 the the DET DT Definite=Def|PronType=Art 3 det 3:det _
+3 AP AP PROPN NNP Number=Sing 4 obl 4:obl:from _
+4 comes come VERB VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 0 root 0:root _
+5 this this DET DT Number=Sing|PronType=Dem 6 det 6:det _
+6 story story NOUN NN Number=Sing 4 nsubj 4:nsubj _
+7 : : PUNCT : _ 4 punct 4:punct _
+
+""".lstrip()
+
+
+
+class TestParser:
+    @pytest.fixture(scope="class")
+    def wordvec_pretrain_file(self):
+        return f'{TEST_WORKING_DIR}/in/tiny_emb.pt'
+
+    def run_training(self, tmp_path, wordvec_pretrain_file, train_text, dev_text, augment_nopunct=False, extra_args=None):
+        """
+        Run the training for a few iterations, load & return the model
+        """
+        train_file = str(tmp_path / "train.conllu")
+        dev_file = str(tmp_path / "dev.conllu")
+        pred_file = str(tmp_path / "pred.conllu")
+
+        save_name = "test_parser.pt"
+        save_file = str(tmp_path / save_name)
+
+        with open(train_file, "w", encoding="utf-8") as fout:
+            fout.write(train_text)
+
+        with open(dev_file, "w", encoding="utf-8") as fout:
+            fout.write(dev_text)
+
+        args = ["--wordvec_pretrain_file", wordvec_pretrain_file,
+                "--train_file", train_file,
+                "--eval_file", dev_file,
+                "--output_file", pred_file,
+                "--log_step", "10",
+                "--eval_interval", "20",
+                "--max_steps", "100",
+                "--shorthand", "en_test",
+                "--save_dir", str(tmp_path),
+                "--save_name", save_name,
+                # in case we are doing a bert test
+                "--bert_start_finetuning", "10",
+                "--bert_warmup_steps", "10",
+                "--lang", "en"]
+        if not augment_nopunct:
+            args.extend(["--augment_nopunct", "0.0"])
+        if extra_args is not None:
+            args = args + extra_args
+        trainer = parser.main(args)
+
+        assert os.path.exists(save_file)
+        pt = pretrain.Pretrain(wordvec_pretrain_file)
+        # test loading the saved model
+        saved_model = Trainer(pretrain=pt, model_file=save_file)
+        return trainer
+
+    def test_train(self, tmp_path, wordvec_pretrain_file):
+        """
+        Simple test of a few 'epochs' of tagger training
+        """
+        self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA)
+
+    def test_with_bert_nlayers(self, tmp_path, wordvec_pretrain_file):
+        self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_hidden_layers', '2'])
+
+    def test_with_bert_finetuning(self, tmp_path, wordvec_pretrain_file):
+        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_finetune', '--bert_hidden_layers', '2'])
+        assert 'bert_optimizer' in trainer.optimizer.keys()
+        assert 'bert_scheduler' in trainer.scheduler.keys()
+
+    def test_with_bert_finetuning_resaved(self, tmp_path, wordvec_pretrain_file):
+        """
+        Check that if we save, then load, then save a model with a finetuned bert, that bert isn't lost
+        """
+        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_finetune', '--bert_hidden_layers', '2'])
+        assert 'bert_optimizer' in trainer.optimizer.keys()
+        assert 'bert_scheduler' in trainer.scheduler.keys()
+
+        save_name = trainer.args['save_name']
+        filename = tmp_path / save_name
+        assert os.path.exists(filename)
+        checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
+        assert any(x.startswith("bert_model") for x in checkpoint['model'].keys())
+
+        # Test loading the saved model, saving it, and still having bert in it
+        # even if we have set bert_finetune to False for this incarnation
+        pt = pretrain.Pretrain(wordvec_pretrain_file)
+        args = {"bert_finetune": False}
+        saved_model = Trainer(pretrain=pt, model_file=filename, args=args)
+
+        saved_model.save(filename)
+
+        # This is the part that would fail if the force_bert_saved option did not exist
+        checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
+        assert any(x.startswith("bert_model") for x in checkpoint['model'].keys())
+
+    def test_with_peft(self, tmp_path, wordvec_pretrain_file):
+        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_finetune', '--bert_hidden_layers', '2', '--use_peft'])
+        assert 'bert_optimizer' in trainer.optimizer.keys()
+        assert 'bert_scheduler' in trainer.scheduler.keys()
+
+    def test_single_optimizer_checkpoint(self, tmp_path, wordvec_pretrain_file):
+        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--optim', 'adam'])
+
+        save_dir = trainer.args['save_dir']
+        save_name = trainer.args['save_name']
+        checkpoint_name = trainer.args["checkpoint_save_name"]
+
+        assert os.path.exists(os.path.join(save_dir, save_name))
+        assert checkpoint_name is not None
+        assert os.path.exists(checkpoint_name)
+
+        assert len(trainer.optimizer) == 1
+        for opt in trainer.optimizer.values():
+            assert isinstance(opt, torch.optim.Adam)
+
+        pt = pretrain.Pretrain(wordvec_pretrain_file)
+        checkpoint = Trainer(args=trainer.args, pretrain=pt, model_file=checkpoint_name)
+        assert checkpoint.optimizer is not None
+        assert len(checkpoint.optimizer) == 1
+        for opt in checkpoint.optimizer.values():
+            assert isinstance(opt, torch.optim.Adam)
+
+    def test_two_optimizers_checkpoint(self, tmp_path, wordvec_pretrain_file):
+        trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--optim', 'adam', '--second_optim', 'sgd', '--second_optim_start_step', '40'])
+
+        save_dir = trainer.args['save_dir']
+        save_name = trainer.args['save_name']
+        checkpoint_name = trainer.args["checkpoint_save_name"]
+
+        assert os.path.exists(os.path.join(save_dir, save_name))
+        assert checkpoint_name is not None
+        assert os.path.exists(checkpoint_name)
+
+        assert len(trainer.optimizer) == 1
+        for opt in trainer.optimizer.values():
+            assert isinstance(opt, torch.optim.SGD)
+
+        pt = pretrain.Pretrain(wordvec_pretrain_file)
+        checkpoint = Trainer(args=trainer.args, pretrain=pt, model_file=checkpoint_name)
+        assert checkpoint.optimizer is not None
+        assert len(checkpoint.optimizer) == 1
+        for opt in checkpoint.optimizer.values():
+            assert isinstance(opt, torch.optim.SGD)
+
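
As run_training shows, parser.main accepts an argv-style list, so the same training run can be scripted outside of pytest. A hedged sketch with placeholder paths (the remaining flags fall back to their defaults):

    from stanza.models import parser

    trainer = parser.main([
        "--wordvec_pretrain_file", "tiny_emb.pt",   # placeholder paths
        "--train_file", "train.conllu",
        "--eval_file", "dev.conllu",
        "--output_file", "pred.conllu",
        "--max_steps", "100",
        "--shorthand", "en_test",
        "--lang", "en",
    ])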
stanza/stanza/tests/langid/__init__.py
ADDED
File without changes
stanza/stanza/tests/langid/test_multilingual.py
ADDED
@@ -0,0 +1,104 @@
+"""
+Tests specifically for the MultilingualPipeline
+"""
+
+from collections import defaultdict
+
+import pytest
+
+from stanza.pipeline.multilingual import MultilingualPipeline
+
+from stanza.tests import TEST_MODELS_DIR
+
+pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+def run_multilingual_pipeline(en_has_dependencies=True, fr_has_dependencies=True, **kwargs):
+    english_text = "This is an English sentence."
+    english_words = ["This", "is", "an", "English", "sentence", "."]
+    english_deps_gold = "\n".join((
+        "('This', 5, 'nsubj')",
+        "('is', 5, 'cop')",
+        "('an', 5, 'det')",
+        "('English', 5, 'amod')",
+        "('sentence', 0, 'root')",
+        "('.', 5, 'punct')"
+    ))
+    if not en_has_dependencies:
+        english_deps_gold = ""
+
+    french_text = "C'est une phrase française."
+    french_words = ["C'", "est", "une", "phrase", "française", "."]
+    french_deps_gold = "\n".join((
+        "(\"C'\", 4, 'nsubj')",
+        "('est', 4, 'cop')",
+        "('une', 4, 'det')",
+        "('phrase', 0, 'root')",
+        "('française', 4, 'amod')",
+        "('.', 4, 'punct')"
+    ))
+    if not fr_has_dependencies:
+        french_deps_gold = ""
+
+    if 'lang_configs' in kwargs:
+        nlp = MultilingualPipeline(model_dir=TEST_MODELS_DIR, download_method=None, **kwargs)
+    else:
+        lang_configs = {"en": {"processors": "tokenize,pos,lemma,depparse"},
+                        "fr": {"processors": "tokenize,pos,lemma,depparse"}}
+        nlp = MultilingualPipeline(model_dir=TEST_MODELS_DIR, download_method=None, lang_configs=lang_configs, **kwargs)
+    docs = [english_text, french_text]
+    docs = nlp(docs)
+
+    assert docs[0].lang == "en"
+    assert len(docs[0].sentences) == 1
+    assert [x.text for x in docs[0].sentences[0].words] == english_words
+    assert docs[0].sentences[0].dependencies_string() == english_deps_gold
+
+    assert len(docs[1].sentences) == 1
+    assert docs[1].lang == "fr"
+    assert [x.text for x in docs[1].sentences[0].words] == french_words
+    assert docs[1].sentences[0].dependencies_string() == french_deps_gold
+
+
+def test_multilingual_pipeline():
+    """
+    Basic test of multilingual pipeline
+    """
+    run_multilingual_pipeline()
+
+def test_multilingual_pipeline_small_cache():
+    """
+    Test with the cache size 1
+    """
+    run_multilingual_pipeline(max_cache_size=1)
+
+
+def test_multilingual_config():
+    """
+    Test with only tokenize for the EN pipeline
+    """
+    lang_configs = {
+        "en": {"processors": "tokenize"}
+    }
+
+    run_multilingual_pipeline(en_has_dependencies=False, lang_configs=lang_configs)
+
+def test_multilingual_processors_limited():
+    """
+    Test loading an available subset of processors
+    """
+    run_multilingual_pipeline(en_has_dependencies=False, fr_has_dependencies=False, lang_configs={}, processors="tokenize")
+    run_multilingual_pipeline(en_has_dependencies=True, fr_has_dependencies=False, lang_configs={"en": {"processors": "tokenize,pos,lemma,depparse"}}, processors="tokenize")
+    # this should not fail, as it will drop the zzzzzzzzzz processor for the languages which don't have it
+    run_multilingual_pipeline(en_has_dependencies=False, fr_has_dependencies=False, lang_configs={}, processors="tokenize,zzzzzzzzzz")
+
+
+def test_defaultdict_config():
+    """
+    Test that you can pass in a defaultdict for the lang_configs argument
+    """
+    lang_configs = defaultdict(lambda: dict(processors="tokenize"))
+    run_multilingual_pipeline(en_has_dependencies=False, fr_has_dependencies=False, lang_configs=lang_configs)
+
+    lang_configs = defaultdict(lambda: dict(processors="tokenize"))
+    lang_configs["en"] = {"processors": "tokenize,pos,lemma,depparse"}
+    run_multilingual_pipeline(en_has_dependencies=True, fr_has_dependencies=False, lang_configs=lang_configs)
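
A minimal standalone sketch of the API these tests exercise, assuming the language identifier and the per-language models have already been downloaded:

    from stanza.pipeline.multilingual import MultilingualPipeline

    nlp = MultilingualPipeline(processors="tokenize")
    docs = nlp(["This is an English sentence.", "C'est une phrase française."])
    for doc in docs:
        print(doc.lang, [word.text for word in doc.sentences[0].words])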
stanza/stanza/tests/lemma_classifier/__init__.py
ADDED
File without changes
stanza/stanza/tests/lemma_classifier/test_training.py
ADDED
@@ -0,0 +1,53 @@
+import glob
+import os
+
+import pytest
+
+pytestmark = [pytest.mark.pipeline, pytest.mark.travis]
+
+from stanza.models.lemma_classifier import train_lstm_model
+from stanza.models.lemma_classifier import train_transformer_model
+from stanza.models.lemma_classifier.base_model import LemmaClassifier
+from stanza.models.lemma_classifier.evaluate_models import evaluate_model
+
+from stanza.tests import TEST_WORKING_DIR
+from stanza.tests.lemma_classifier.test_data_preparation import convert_english_dataset
+
+@pytest.fixture(scope="module")
+def pretrain_file():
+    return f'{TEST_WORKING_DIR}/in/tiny_emb.pt'
+
+def test_train_lstm(tmp_path, pretrain_file):
+    converted_files = convert_english_dataset(tmp_path)
+
+    save_name = str(tmp_path / 'lemma.pt')
+
+    train_file = converted_files[0]
+    eval_file = converted_files[1]
+    train_args = ['--wordvec_pretrain_file', pretrain_file,
+                  '--save_name', save_name,
+                  '--train_file', train_file,
+                  '--eval_file', eval_file]
+    trainer = train_lstm_model.main(train_args)
+
+    evaluate_model(trainer.model, eval_file)
+    # test that loading the model works
+    model = LemmaClassifier.load(save_name, None)
+
+def test_train_transformer(tmp_path, pretrain_file):
+    converted_files = convert_english_dataset(tmp_path)
+
+    save_name = str(tmp_path / 'lemma.pt')
+
+    train_file = converted_files[0]
+    eval_file = converted_files[1]
+    train_args = ['--bert_model', 'hf-internal-testing/tiny-bert',
+                  '--save_name', save_name,
+                  '--train_file', train_file,
+                  '--eval_file', eval_file]
+    trainer = train_transformer_model.main(train_args)
+
+    evaluate_model(trainer.model, eval_file)
+
+    # test that loading the model works
+    model = LemmaClassifier.load(save_name, None)
stanza/stanza/tests/mwt/__init__.py
ADDED
File without changes