# dataset-builder / data1/analysis.py
# Author: SunDou — "Upload data1/analysis.py with huggingface_hub"
# Commit: 880e02b (verified)
import csv
import re
import tokenize
from io import StringIO
import os
from tqdm import tqdm
import json
import sys
from functools import lru_cache
def _raise_csv_field_size_limit():
    """Raise the csv module's per-field size limit as high as the platform allows.

    Whole source files are stored in single CSV cells, so the default limit is
    far too small. ``sys.maxsize`` overflows the C ``long`` used by the csv
    module on platforms where ``long`` is 32-bit (e.g. Windows), so the
    candidate is halved until it is accepted.

    Returns:
        The limit that was installed.
    """
    limit = sys.maxsize
    while True:
        try:
            csv.field_size_limit(limit)
            return limit
        except OverflowError:
            limit //= 2


_raise_csv_field_size_limit()
# ============== Precompiled regular expressions (built once for speed) ==============
# Single-line comment rules. Every pattern is "<marker>(.*)$": group 1 holds
# the comment text that follows the marker, up to the end of the line.
_LINE_COMMENT_PATTERNS = {
    language: re.compile(marker + r"(.*)$")
    for language, marker in (
        ("python", "#"),
        ("shell", "#"),
        ("r", "#"),
        ("matlab", "%"),
        ("fortran", "!"),
        ("c/c++", "//"),
        ("java", "//"),
        ("go", "//"),
        ("rust", "//"),
    )
}
# Block-comment rules (precompiled). Python has no block comments, so its
# entry matches triple-quoted strings instead; the C-family languages share
# one /* ... */ pattern. In every pattern group 1 is the captured text.
_BLOCK_COMMENT_PATTERNS = {
    "python": re.compile(r'("""[\s\S]*?"""|\'\'\'[\s\S]*?\'\'\')'),
    **dict.fromkeys(("c/c++", "java", "rust", "go"), re.compile(r"/\*([\s\S]*?)\*/")),
    "matlab": re.compile(r"%\{([\s\S]*?)%\}"),
}
# Function-definition rules (precompiled). Captured groups per language:
#   python / java / c/c++ / go / rust / r : (name, parameter list)
#   matlab  : (name, name, parameter list)  -- two alternative name groups
#   fortran : (keyword, name, parameter list)
#   shell   : name only
# The Java and C/C++ rules reject control-flow keywords as the "name", since
# `if (x) {`, `for (...) {`, `while (...) {` etc. otherwise have the exact
# shape of a function definition and were counted as functions.
_FUNCTION_PATTERNS = {
    "python": re.compile(r"^[ \t]*def\s+(\w+)\s*\(([^)]*)\)", re.MULTILINE),
    "java": re.compile(r"""
        (?:public|protected|private|static|final|native|synchronized|abstract|\s)*
        \s*
        (?:[\w\<\>\[\],\s]+)            # return type (and any leftover modifiers)
        \s+
        (?!(?:if|for|while|switch|catch|try|synchronized)\b)   # not a control statement
        (\w+)                           # method name
        \s*\(([^)]*)\)                  # parameter list
        (?:\s*throws\s+[\w,\s]+)?
        \s*\{
    """, re.MULTILINE | re.VERBOSE),
    "c/c++": re.compile(r"""
        ^[ \t]*
        (?!.*typedef)                   # skip lines mentioning typedef
        (?!.*\#)                        # skip lines containing a hash (e.g. preprocessor)
        (?:[\w\*\s&]+)                  # return type / qualifiers
        \b
        (?!(?:if|for|while|switch|catch|try|synchronized)\b)   # not a control statement
        (\w+)\s*                        # function name
        \(([^)]*)\)                     # parameter list
        \s*(?:const)?
        \s*(?:override)?
        \s*(?:noexcept)?
        \s*\{
    """, re.MULTILINE | re.VERBOSE),
    "go": re.compile(r"\bfunc\s+(?:\([^)]+\)\s*)?(\w+)\s*\(([^)]*)\)", re.MULTILINE),
    "rust": re.compile(r"\b(?:pub\s+)?(?:async\s+)?fn\s+(\w+)\s*(?:<[^>]*>)?\s*\(([^)]*)\)", re.MULTILINE),
    "r": re.compile(r"(\w+)\s*(?:<-|=)\s*function\s*\(([^)]*)\)", re.MULTILINE),
    "matlab": re.compile(r"^[ \t]*function\s+(?:(?:\[?[\w,\s]*\]?\s*=\s*)?(\w+)|(\w+))\s*\(([^)]*)\)", re.MULTILINE),
    "shell": re.compile(r"^[ \t]*(?:function\s+)?(\w+)\s*\(\)\s*\{", re.MULTILINE),
    # Fortran keywords are case-insensitive; IGNORECASE is passed as a flag
    # rather than an inline "(?i)" so the pattern text can start anywhere.
    "fortran": re.compile(r"""
        ^[ \t]*
        (?:recursive\s+)?
        (?:pure\s+)?
        (?:elemental\s+)?
        (?:[\w\*]+(?:\s*\([^)]*\))?\s+)?   # optional type prefix, e.g. real(8)
        (function|subroutine)\s+
        (\w+)\s*
        \(([^)]*)\)
    """, re.MULTILINE | re.VERBOSE | re.IGNORECASE),
}
# Comment-stripping regexes (precompiled). The "*_line" entries run in
# MULTILINE mode so "$" stops at each line end; the block / triple-quote
# entries match lazily across newlines via [\s\S].
_REMOVE_COMMENT_PATTERNS = {
    key: re.compile(source, flags)
    for key, source, flags in (
        ("python_line", r'#.*$', re.MULTILINE),
        ("python_triple_dq", r'"""[\s\S]*?"""', 0),
        ("python_triple_sq", r"'''[\s\S]*?'''", 0),
        ("c_line", r'//.*$', re.MULTILINE),
        ("c_block", r'/\*[\s\S]*?\*/', 0),
        ("shell_line", r'#.*$', re.MULTILINE),
        ("matlab_line", r'%.*$', re.MULTILINE),
        ("matlab_block", r'%\{[\s\S]*?%\}', 0),
        ("fortran_line", r'!.*$', re.MULTILINE),
    )
}
# File-extension -> language table. Matching is case-insensitive (the
# extension is lower-cased first), so ".F" and ".R" hit the lowercase keys.
_EXT_LANGUAGE_MAP = {
    ".py": "python",
    ".java": "java",
    ".c": "c/c++",
    ".h": "c/c++",
    ".hh": "c/c++",
    ".hpp": "c/c++",
    ".hxx": "c/c++",
    ".cpp": "c/c++",
    ".cc": "c/c++",
    ".cxx": "c/c++",
    ".c++": "c/c++",
    ".f": "fortran",
    ".for": "fortran",
    ".f77": "fortran",
    ".f90": "fortran",
    ".f95": "fortran",
    ".f03": "fortran",
    ".f08": "fortran",
    ".r": "r",
    ".m": "matlab",  # MATLAB / Octave
    ".sh": "shell",
    ".bash": "shell",
    ".rs": "rust",
    ".go": "go",
}


def detect_language(file_path: str):
    """Guess a file's language from its extension only.

    Args:
        file_path: Path or file name. ``None`` or ``""`` is treated as a path
            without an extension.

    Returns:
        A language key such as ``"python"`` or ``"c/c++"``. An unknown
        extension is returned as-is, lower-cased (e.g. ``".txt"``); a path
        without an extension yields ``""``.
    """
    if not file_path:
        return ""
    ext = os.path.splitext(file_path)[1].lower().strip()
    return _EXT_LANGUAGE_MAP.get(ext, ext)
def count_comments(code: str, lang: str):
    """Count comment lines and comment tokens in one source file.

    Supports Python/Java/C/C++/Fortran/MATLAB/R/Shell/Rust/Go; ``"jupyter"``
    is handled with the Python rules. Block constructs (``/* */``,
    ``%{ %}``, and Python triple-quoted strings that look like docstrings)
    are processed first, then single-line comments.

    Every physical line is counted at most once, and blank lines inside a
    block comment are not counted, so the result does not overlap with a
    blank-line count of the same text.

    Args:
        code: Source text.
        lang: Language key as returned by ``detect_language``.

    Returns:
        ``(comment_lines, comment_token_count)``; tokens are the
        whitespace-separated words of the comment text.
    """
    if lang == "jupyter":
        lang = "python"
    # Split on "\n" only so list indices agree with the "\n" counts used to
    # locate block matches (str.splitlines also breaks on "\r", "\x0c", ...).
    lines = code.split("\n")
    comment_line_indices = set()
    token_count = 0

    # ---------- 1. Block comments ----------
    # Every block match is masked (non-newline characters -> spaces) so the
    # line-comment pass below cannot see comment markers inside it; this also
    # covers triple-quoted strings that are values rather than docstrings.
    masked = code
    block_patt = _BLOCK_COMMENT_PATTERNS.get(lang)
    if block_patt is not None:
        pieces = []
        last_end = 0
        line_no = 0  # line index of position last_end
        for match in block_patt.finditer(code):
            start, end = match.span()
            # Incremental line tracking keeps this linear in len(code).
            start_line = line_no + code.count("\n", last_end, start)
            end_line = start_line + code.count("\n", start, end)
            pieces.append(code[last_end:start])
            pieces.append(re.sub(r"[^\n]", " ", match.group(0)))
            last_end, line_no = end, end_line

            if lang == "python":
                # A triple-quoted string right after "=", "(", ",", "[", "{"
                # or "+" is a value (assignment, argument, element), not a
                # docstring.
                prefix = code[max(0, start - 20):start].strip()
                if prefix.endswith(("=", "(", ",", "[", "{", "+")):
                    continue
                body = match.group(1)[3:-3]  # drop the surrounding quotes
            else:
                body = match.group(1) if match.lastindex else match.group(0)

            comment_line_indices.update(
                idx for idx in range(start_line, end_line + 1) if lines[idx].strip()
            )
            token_count += len(body.split())
        pieces.append(code[last_end:])
        masked = "".join(pieces)

    # ---------- 2. Line comments ----------
    line_patt = _LINE_COMMENT_PATTERNS.get(lang)
    if line_patt is not None:
        for idx, line in enumerate(masked.split("\n")):
            pos = 0
            while True:
                match = line_patt.search(line, pos)
                if match is None:
                    break
                # Heuristic: an even number of unescaped quotes before the
                # marker means it is not inside a string literal.
                prefix = line[:match.start()]
                singles = prefix.count("'") - prefix.count("\\'")
                doubles = prefix.count('"') - prefix.count('\\"')
                if singles % 2 == 0 and doubles % 2 == 0:
                    comment_line_indices.add(idx)
                    token_count += len(match.group(1).split())
                    break
                # Marker sits inside a string; a later one may be a real comment.
                pos = match.start() + 1

    return len(comment_line_indices), token_count
def _split_parameters(params: str, lang: str):
    """Split a raw parameter list on top-level commas only.

    Commas nested in brackets (``Dict[str, int]``, ``std::map<K, V>``,
    ``x=(1, 2)``) or inside string literals (``sep=","``) do not start a new
    item. Angle brackets are only treated as brackets for C/C++, Java and
    Rust, and Rust uses ``'`` for lifetimes, so only ``"`` delimits strings
    there.

    Returns:
        The non-empty, stripped parameter strings.
    """
    if lang in ("c/c++", "java", "rust"):
        openers, closers = "([{<", ")]}>"
    else:
        openers, closers = "([{", ")]}"
    quote_chars = '"' if lang == "rust" else "\"'"

    items, current = [], []
    depth = 0
    quote = None
    escaped = False
    for ch in params:
        if quote:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == quote:
                quote = None
        elif ch in quote_chars:
            quote = ch
        elif ch in openers:
            depth += 1
        elif ch in closers:
            depth = max(0, depth - 1)
        elif ch == "," and depth == 0:
            items.append("".join(current))
            current = []
            continue
        current.append(ch)
    items.append("".join(current))
    return [item.strip() for item in items if item.strip()]


def count_functions_and_parameters(code: str, lang: str):
    """Count function definitions and their declared parameters.

    Supports the languages in ``_FUNCTION_PATTERNS`` (including Fortran
    ``subroutine``/``function``); ``"jupyter"`` uses the Python rule.
    Comments are stripped first so definitions inside comments are ignored.

    Args:
        code: Source text.
        lang: Language key as returned by ``detect_language``.

    Returns:
        ``(function_count, parameter_count)``; ``(0, 0)`` for languages
        without a rule.
    """
    if lang == "jupyter":
        lang = "python"
    patt = _FUNCTION_PATTERNS.get(lang)
    if patt is None:
        return 0, 0

    matches = patt.findall(_remove_comments(code, lang))
    parameter_count = 0
    for match in matches:
        if isinstance(match, str):
            # Single-group rule (shell): only the name is captured.
            continue
        # In every multi-group rule the parameter list is the last group.
        items = _split_parameters(match[-1] or "", lang)
        if lang == "python":
            # Bare "*" and "/" are keyword-only / positional-only markers.
            items = [p for p in items if p not in ("*", "/")]
        elif lang == "c/c++" and items == ["void"]:
            items = []  # f(void) declares no parameters
        parameter_count += len(items)
    return len(matches), parameter_count
# Left-to-right scanners used by _remove_comments. Each alternation lists a
# language's comment forms AND its string-literal forms. The regex engine
# consumes whichever construct starts first, so a comment marker inside a
# string ("http://x", "#fff") is never taken for a comment, and a quote inside
# a comment never opens a string.
_DQ_STRING = r'"(?:\\[\s\S]|[^"\\\n])*"'
_SQ_STRING = r"'(?:\\[\s\S]|[^'\\\n])*'"
_C_COMMENTS = (r"/\*[\s\S]*?\*/", r"//[^\n]*")
_COMMENT_SCANNERS = {
    "python": re.compile("|".join((
        r'"""[\s\S]*?"""', r"'''[\s\S]*?'''", _DQ_STRING, _SQ_STRING, r"#[^\n]*",
    ))),
    "c/c++": re.compile("|".join(_C_COMMENTS + (_DQ_STRING, _SQ_STRING))),
    "java": re.compile("|".join(_C_COMMENTS + (_DQ_STRING, _SQ_STRING))),
    "go": re.compile("|".join(_C_COMMENTS + (_DQ_STRING, _SQ_STRING, r"`[^`]*`"))),
    # Rust uses ' for lifetimes ('a), so only "..." is treated as a string.
    "rust": re.compile("|".join(_C_COMMENTS + (_DQ_STRING,))),
    "r": re.compile("|".join((_DQ_STRING, _SQ_STRING, r"#[^\n]*"))),
    "shell": re.compile(r"#[^\n]*"),
    # "%{ ... %}" is tried before the "%" line form, which would otherwise eat
    # the opening "%{" and leave the block body behind.
    "matlab": re.compile(r"%\{[\s\S]*?%\}|%[^\n]*"),
    # Fortran strings have no backslash escapes ('it''s' is two literals).
    "fortran": re.compile("|".join((r'"[^"\n]*"', r"'[^'\n]*'", r"![^\n]*"))),
}


def _strip_comment_match(match):
    """Replacement callback: blank comments, keep ordinary string literals."""
    text = match.group(0)
    if text[0] in "\"'`" and not text.startswith(('"""', "'''")):
        return text
    # Comments and Python triple-quoted strings are removed, keeping only
    # their newlines so line numbers in the result still match the input.
    return "\n" * text.count("\n")


def _remove_comments(code: str, lang: str) -> str:
    """Remove comments (and Python triple-quoted strings) from ``code``.

    Used to keep function definitions written inside comments from being
    matched. Newlines inside removed spans are preserved; single-line string
    literals are left untouched. Unknown languages are returned unchanged.
    """
    if lang == "jupyter":
        lang = "python"
    scanner = _COMMENT_SCANNERS.get(lang)
    if scanner is None:
        return code
    return scanner.sub(_strip_comment_match, code)
def count_tokens(code: str):
    """Count tokens with Python's tokenizer, falling back to a whitespace split.

    The Python tokenizer is applied regardless of the file's language; if it
    raises (e.g. unterminated string, inconsistent indentation), the number
    of whitespace-separated words is returned instead.
    """
    try:
        return sum(1 for _ in tokenize.generate_tokens(StringIO(code).readline))
    except Exception:
        # Deliberate best-effort fallback. Catching Exception (not a bare
        # except) lets KeyboardInterrupt / SystemExit still stop the job.
        return len(code.split())
def analyze_code(code_str, code_path):
    """Compute line, comment, token and function statistics for one file.

    Args:
        code_str: File contents; ``None`` is treated as empty.
        code_path: Path used only to detect the language from its extension;
            ``None`` is treated as a path without an extension.

    Returns:
        A dict of counts. ``"idx"`` is ``None`` for the caller to fill in.
        NOTE: the key ``"comment_tokenst"`` is misspelled but kept verbatim
        because it is part of the emitted output schema.
    """
    code_str = code_str or ""
    lang = detect_language(code_path or "")
    # Physical lines: split on "\n" and drop the empty piece after a final
    # newline, so "x = 1\n" is one line. (count("\n") + 1 reported two, and
    # the phantom line was then counted as code.)
    physical_lines = code_str.split("\n")
    if physical_lines[-1] == "":
        physical_lines.pop()
    total_lines = len(physical_lines)
    empty_lines = sum(1 for line in physical_lines if not line.strip())
    comment_lines, comment_token_count = count_comments(code_str, lang)
    functions, parameters = count_functions_and_parameters(code_str, lang)
    tokens = count_tokens(code_str)
    return {
        "idx": None,
        "language": lang,
        "total_lines": total_lines,
        "comment_lines": comment_lines,
        "comment_tokenst": comment_token_count,
        "empty_lines": empty_lines,
        # Comment and blank counts come from separate passes and can overlap
        # (e.g. blank lines inside a block comment), so never go below zero.
        "code_lines": max(0, total_lines - empty_lines - comment_lines),
        "tokens": tokens,
        "functions": functions,
        "parameters": parameters,
    }
def main(first=110, last=120,
         input_dir="/home/weifengsun/tangou1/domain_code/src/datasets/data_merged",
         output_dir="/home/weifengsun/tangou1/domain_code/src/datasets/analysis2"):
    """Analyze CSV shards ``first`` .. ``last - 1`` into JSONL files.

    Reads ``<input_dir>/<NNN>.csv`` (columns: ``text`` with the source code,
    plus ``repo_path`` or ``path`` for language detection) and writes one
    JSON object per row to ``<output_dir>/<NNN>.jsonl``.
    """
    os.makedirs(output_dir, exist_ok=True)
    for i in range(first, last):
        shard = f"{i:03}"
        input_path = os.path.join(input_dir, f"{shard}.csv")
        output_path = os.path.join(output_dir, f"{shard}.jsonl")
        results = []
        # newline="" is what the csv module requires; without it, newlines
        # embedded in quoted fields (the code text) get translated.
        with open(input_path, "r", encoding="utf-8", errors="replace", newline="") as f:
            filtered = (line.replace("\0", "") for line in f)  # strip NUL characters
            reader = csv.DictReader(filtered)  # columns addressed by header name
            for idx, row in tqdm(enumerate(reader)):
                code_str = row.get("text") or ""
                # Fall back to "path" when "repo_path" is missing or empty.
                code_path = row.get("repo_path") or row.get("path") or ""
                result = analyze_code(code_str, code_path)
                result["idx"] = f"{shard}-{idx}"
                results.append(result)
        with open(output_path, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(r) + "\n" for r in results)


if __name__ == "__main__":
    main()