/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Adapted from:
// https://github.com/tensorflow/tfjs-models/blob/master/universal-sentence-encoder/src/tokenizer/trie.ts
// A single node in the wordpiece trie. `key` is one character; terminal nodes
// (`end === true`) also carry the wordpiece's score and vocab index.
class TrieNode {
  constructor(key) {
    this.key = key;
    this.parent = null;
    this.children = {};
    this.end = false;
  }

  // Climbs parent links back to the root and returns the accumulated
  // characters along with the score and vocab index of this terminal node.
  getWord() {
    const output = [];
    let node = this;
    while (node !== null) {
      if (node.key !== null) {
        output.unshift(node.key);
      }
      node = node.parent;
    }
    return [output, this.score, this.index];
  }
}
class Trie {
  constructor() {
    this.root = new TrieNode(null);
  }

  // Inserts `word` one Unicode character at a time, marking the final node
  // as a complete wordpiece with the given score and vocab index.
  insert(word, score, index) {
    let node = this.root;
    const symbols = Array.from(word);
    for (let i = 0; i < symbols.length; i++) {
      if (!node.children[symbols[i]]) {
        node.children[symbols[i]] = new TrieNode(symbols[i]);
        node.children[symbols[i]].parent = node;
      }
      node = node.children[symbols[i]];
      if (i === symbols.length - 1) {
        node.end = true;
        node.score = score;
        node.index = index;
      }
    }
  }

  // Walks the trie along `ss` and returns the node reached, or undefined if
  // the path falls off the trie. Callers check `.end` to see whether the
  // node marks a complete wordpiece.
  find(ss) {
    let node = this.root;
    let iter = 0;
    while (iter < ss.length && node != null) {
      node = node.children[ss[iter]];
      iter++;
    }
    return node;
  }
}
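
// Illustrative sketch (not part of the original module) of how the tokenizer
// below uses the trie; the entry and index 42 are made up for the example:
//
//   const trie = new Trie();
//   trie.insert('\u2581hi', 1, 42);
//   trie.find('\u2581hi').end;       // true: complete wordpiece
//   trie.find('\u2581hi').getWord(); // [['▁', 'h', 'i'], 1, 42]
//   trie.find('\u2581h');            // node exists, but .end is falsy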
const bert = {
  loadTokenizer: async () => {
    const tokenizer = new BertTokenizer();
    await tokenizer.load();
    return tokenizer;
  }
};
class BertTokenizer {
  constructor() {
    // SentencePiece-style word-boundary marker (U+2581, '▁').
    this.separator = '\u2581';
    this.UNK_INDEX = 100;
  }

  async load() {
    this.vocab = await this.loadVocab();
    this.trie = new Trie();
    // Actual tokens start at 999; earlier slots hold special/unused tokens.
    for (let i = 999; i < this.vocab.length; i++) {
      const word = this.vocab[i];
      this.trie.insert(word, 1, i);
    }
    this.token2Id = {};
    this.vocab.forEach((d, i) => {
      this.token2Id[d] = i;
    });
    // Maps ids back to text, turning the '▁' boundary marker into a space.
    this.decode = a => a.map(d => this.vocab[d].replace(this.separator, ' ')).join('');
    // Adds [CLS] (id 101) and [SEP] (id 102) around the token ids.
    this.tokenizeCLS = str => [101, ...this.tokenize(str), 102];
  }
  async loadVocab() {
    // Cache the vocab on window so repeated loads share a single fetch.
    if (!window.bertProcessedVocab) {
      window.bertProcessedVocab = await (await fetch('data/processed_vocab.json')).json();
    }
    return window.bertProcessedVocab;
  }
  // Lowercases and NFKC-normalizes each space-separated word and prefixes it
  // with the '▁' boundary marker; literal [CLS]/[SEP] tokens pass through.
  processInput(text) {
    const words = text.split(' ');
    return words.map(word => {
      if (word !== '[CLS]' && word !== '[SEP]') {
        return this.separator + word.toLowerCase().normalize('NFKC');
      }
      return word;
    });
  }
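
  // For example (illustrative):
  //   processInput('[CLS] Hello world [SEP]')
  //   // -> ['[CLS]', '▁hello', '▁world', '[SEP]']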
  tokenize(text) {
    // Greedy longest-match-first wordpiece tokenization. Source:
    // https://github.com/google-research/bert/blob/88a817c37f788702a363ff935fd173b6dc6ac0d6/tokenization.py#L311
    let outputTokens = [];
    const words = this.processInput(text);
    for (let i = 0; i < words.length; i++) {
      const chars = Array.from(words[i]);
      let isUnknown = false;
      let start = 0;
      const subTokens = [];
      const charsLength = chars.length;
      while (start < charsLength) {
        let end = charsLength;
        let currIndex;
        // Shrink the candidate substring from the right until it matches a
        // complete wordpiece in the trie.
        while (start < end) {
          const substr = chars.slice(start, end).join('');
          const match = this.trie.find(substr);
          if (match != null && match.end) {
            currIndex = match.getWord()[2];
            break;
          }
          end = end - 1;
        }
        // No prefix of the remaining characters is in the vocab: the whole
        // word becomes a single [UNK] token.
        if (currIndex == null) {
          isUnknown = true;
          break;
        }
        subTokens.push(currIndex);
        start = end;
      }
      if (isUnknown) {
        outputTokens.push(this.UNK_INDEX);
      } else {
        outputTokens = outputTokens.concat(subTokens);
      }
    }
    return outputTokens;
  }
}