| | --- |
| | language: |
| | - en |
| | - code |
| | tags: |
| | - code completion |
| | - code generation |
| | license: "apache-2.0" |
| | --- |
| | |
| | # NLGP docstring model |
| |
|
| | The NLGP docstring model was introduced in the paper [Natural Language-Guided Programming](https://arxiv.org/abs/2108.05198). The model was trained on a collection of Jupyter notebooks and can be used to synthesize Python code that addresses a natural-language **intent** in a given code **context** (see the example below). |
| | Also see the [NLGP natural](https://huggingface.co/Nokia/nlgp-natural) model. |
| |
|
| | This work was carried out by a research team at Nokia Bell Labs. |
| |
|
| | **Context** |
| | ```py |
| | import matplotlib.pyplot as plt |
| | |
| | values = [1, 2, 3, 4] |
| | labels = ["a", "b", "c", "d"] |
| | ``` |
| |
|
| | **Intent** |
| | ```py |
| | # plot a bar chart |
| | ``` |
| |
|
| | **Prediction** |
| | ```py |
| | plt.bar(labels, values) |
| | plt.show() |
| | ``` |
| |
|
| | ## Usage |
| |
|
| | ```py |
import re

from transformers import GPT2LMHeadModel, GPT2TokenizerFast

# load the tokenizer and the model weights by their Hugging Face Hub id
checkpoint = "Nokia/nlgp-docstring"
tok = GPT2TokenizerFast.from_pretrained(checkpoint)
model = GPT2LMHeadModel.from_pretrained(checkpoint)

# preprocessing functions
num_spaces = [2, 4, 6, 8, 10, 12, 14, 16, 18]  # run lengths that have a dedicated <|Nspace|> token


def preprocess(context, query):
    """
    Encode context + query as a single model prompt.

    The query is appended on its own line after the context and terminated
    with the ``<|endofcomment|>`` marker plus a newline. The leading
    indentation of every line is then replaced by the special tokens
    <|2space|>, <|4space|>, ..., <|18space|>. The widest token that fits is
    used first, so 20 spaces become ``<|18space|><|2space|>``. A leftover odd
    space is kept as a literal space.

    Args:
        context: code that precedes the intent (may span several lines).
        query: the natural-language intent, written as a ``#`` comment.

    Returns:
        The encoded prompt string.
    """
    input_str = f"{context}\n{query} <|endofcomment|>\n"
    indentation_symbols = {n: f"<|{n}space|>" for n in num_spaces}
    # Widest run first: in ascending order "  " would be replaced before
    # "    " could ever match, so only <|2space|> tokens would be emitted.
    widest_first = sorted(num_spaces, reverse=True)

    def encode_indent(match):
        indent = match.group(0)
        for n in widest_first:
            indent = indent.replace(n * " ", indentation_symbols[n])
        return indent

    # re.MULTILINE makes "^" match at the start of every line, so indentation
    # inside the context/query is encoded too, not just the prompt's first line.
    return re.sub(r"^ +", encode_indent, input_str, flags=re.MULTILINE)

detokenize_pattern = re.compile(r"<\|(\d+)space\|>")


def postprocess(output):
    """
    Turn raw decoded model output back into plain Python source.

    Everything from the first ``<|cell|>`` marker onward is dropped, and every
    ``<|Nspace|>`` token is expanded into N literal spaces.
    """
    first_cell, _, _ = output.partition("<|cell|>")
    return detokenize_pattern.sub(lambda match: " " * int(match.group(1)), first_cell)

# inference
code_context = """
import matplotlib.pyplot as plt

values = [1, 2, 3, 4]
labels = ["a", "b", "c", "d"]
"""
query = "# plot a bar chart"

input_str = preprocess(code_context, query)
input_ids = tok(input_str, return_tensors="pt").input_ids

max_length = 150  # don't generate output longer than this length
context_window = model.config.n_positions  # max tokens (prompt + output) the model supports

# Keep only the most recent prompt tokens so that `max_length` new tokens
# always fit inside the context window.
input_ids = input_ids[:, -(context_window - max_length):]
# `generate`'s max_length counts prompt + output, so add the prompt length.
total_max_length = min(context_window, input_ids.shape[-1] + max_length)  # total = input + output

input_and_output = model.generate(
    input_ids=input_ids,
    max_length=total_max_length,
    min_length=10,
    do_sample=False,
    num_beams=4,
    early_stopping=True,
    eos_token_id=tok.encode("<|cell|>")[0],  # stop at the end of the generated cell
)

output = input_and_output[:, input_ids.shape[-1]:]  # remove the tokens that correspond to the input_str
output_str = tok.decode(output[0])
print(postprocess(output_str))
| | ``` |
| |
|
| | ## License and copyright |
| |
|
| | Copyright 2021 Nokia |
| |
|
| | Licensed under the Apache License 2.0. |
| |
|
| | SPDX-License-Identifier: Apache-2.0 |