Skip to content

Add support for code stemming with tree-sitter #19

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ dev:

isort:
@echo "-> Apply isort changes to ensure proper imports ordering"
${VENV}/bin/isort --sl -l 100 src tests setup.py
${VENV}/bin/isort --sl -l 100 src tests setup.py --skip="tests/testfiles/"

black:
@echo "-> Apply black code formatter"
${VENV}/bin/black -l 100 src tests setup.py
${VENV}/bin/black -l 100 src tests setup.py --exclude="tests/testfiles/"

doc8:
@echo "-> Run doc8 validation"
Expand All @@ -33,11 +33,11 @@ valid: isort black

check:
@echo "-> Run pycodestyle (PEP8) validation"
@${ACTIVATE} pycodestyle --max-line-length=100 --exclude=.eggs,venv,lib,thirdparty,docs,migrations,settings.py,.cache .
@${ACTIVATE} pycodestyle --max-line-length=100 --exclude=.eggs,venv,lib,thirdparty,docs,migrations,settings.py,.cache,tests/testfiles/stemming/ .
@echo "-> Run isort imports ordering validation"
@${ACTIVATE} isort --sl --check-only -l 100 setup.py src tests .
@${ACTIVATE} isort --sl --check-only -l 100 setup.py src tests . --skip="tests/testfiles/"
@echo "-> Run black validation"
@${ACTIVATE} black --check --check -l 100 src tests setup.py
	@${ACTIVATE} black --check -l 100 src tests setup.py --exclude="tests/testfiles/"

clean:
@echo "-> Clean the Python env"
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ norecursedirs = [
"tests/data",
".eggs",
"src/*/data",
"tests/*/data"
"tests/testfiles/*"
]

python_files = "*.py"
Expand Down
8 changes: 8 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,11 @@ soupsieve==2.6
text-unidecode==1.3
urllib3==2.2.3
wheel==0.45.1
tree-sitter==0.23.0
tree-sitter-c==0.21.1
tree-sitter-cpp==0.22.0
tree-sitter-go==0.21.0
tree-sitter-java==0.21.0
tree-sitter-javascript==0.21.2
tree-sitter-python==0.21.0
tree-sitter-rust==0.21.2
9 changes: 9 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,15 @@ install_requires =
commoncode
plugincode
samecode
typecode
tree-sitter
tree-sitter-c
tree-sitter-cpp
tree-sitter-go
tree-sitter-java
tree-sitter-javascript
tree-sitter-python
tree-sitter-rust


[options.packages.find]
Expand Down
4 changes: 3 additions & 1 deletion src/matchcode_toolkit/fingerprinting.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,9 @@ def tokenizer(text):
return _tokenizer(text.lower())


def get_file_fingerprint_hashes(location, ngram_length=5, window_length=16, include_ngrams=False, **kwargs):
def get_file_fingerprint_hashes(
location, ngram_length=5, window_length=16, include_ngrams=False, **kwargs
):
"""
Return a mapping of fingerprint hashes for the file at `location`

Expand Down
1 change: 0 additions & 1 deletion src/matchcode_toolkit/plugin_fingerprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#

import attr

from commoncode.cliutils import SCAN_GROUP
from commoncode.cliutils import PluggableCommandLineOption
from plugincode.scan import ScanPlugin
Expand Down
166 changes: 166 additions & 0 deletions src/matchcode_toolkit/stemming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
#
# Copyright (c) nexB Inc. and others. All rights reserved.
# ScanCode is a trademark of nexB Inc.
# SPDX-License-Identifier: Apache-2.0
# See http://www.apache.org/licenses/LICENSE-2.0 for the license text.
# See https://github.com/aboutcode-org/matchcode-toolkit for support or download.
# See https://aboutcode.org for more information about nexB OSS projects.
#

import importlib

from tree_sitter import Language
from tree_sitter import Parser
from typecode.contenttype import Type


class TreeSitterWheelNotInstalled(Exception):
    """Raised when the tree-sitter grammar wheel for a detected language is not installed."""

    pass


# Per-language tree-sitter configuration, keyed by the programming language
# name reported by typecode's ``Type.programming_language``:
# - "wheel": importable module name of the tree-sitter grammar wheel
# - "identifiers": parse-tree node types replaced with a stem token
# - "comments": parse-tree node types removed during stemming
TS_LANGUAGE_CONF = {
    "C": {
        "wheel": "tree_sitter_c",
        "identifiers": ["identifier"],
        "comments": ["comment"],
    },
    "C++": {
        "wheel": "tree_sitter_cpp",
        "identifiers": ["identifier"],
        "comments": ["comment"],
    },
    "Go": {
        "wheel": "tree_sitter_go",
        "identifiers": ["identifier"],
        "comments": ["comment"],
    },
    "Java": {
        "wheel": "tree_sitter_java",
        "identifiers": ["identifier"],
        "comments": ["comment", "block_comment", "line_comment"],
    },
    "JavaScript": {
        "wheel": "tree_sitter_javascript",
        "identifiers": ["identifier"],
        "comments": ["comment"],
    },
    "Python": {
        "wheel": "tree_sitter_python",
        "identifiers": ["identifier"],
        "comments": ["comment"],
    },
    "Rust": {
        "wheel": "tree_sitter_rust",
        "identifiers": ["identifier"],
        "comments": ["comment", "block_comment", "line_comment"],
    },
}


def get_parser(location):
    """
    Return a ``(parser, language_info)`` tuple for the code file at
    ``location``, or None when its programming language is not supported.

    ``parser`` is a tree-sitter Parser loaded with the grammar for the
    detected language; ``language_info`` is the matching entry from
    ``TS_LANGUAGE_CONF``.

    Raise TreeSitterWheelNotInstalled when the grammar wheel for the
    detected language is not importable.
    """
    detected_language = Type(location).programming_language
    if not detected_language or detected_language not in TS_LANGUAGE_CONF:
        return

    conf = TS_LANGUAGE_CONF[detected_language]
    wheel_name = conf["wheel"]
    try:
        grammar_module = importlib.import_module(wheel_name)
    except ModuleNotFoundError:
        raise TreeSitterWheelNotInstalled(f"{wheel_name} package is not installed")

    ts_language = Language(grammar_module.language())
    return Parser(language=ts_language), conf


def add_to_mutation_index(node, mutation_index):
    """
    Record the parse-tree ``node`` in ``mutation_index``, keyed by its
    ``(end_row, end_column)`` coordinates.

    Nodes whose decoded text is empty are skipped.
    """
    content = node.text.decode()
    if not content:
        return

    start = node.start_point
    end = node.end_point
    mutation_index[(end.row, end.column)] = {
        "type": node.type,
        "content": content,
        "start_point": (start.row, start.column),
        "end_point": (end.row, end.column),
    }


def traverse(node, language_info, mutation_index):
    """
    Walk the parse tree rooted at ``node`` depth-first and populate
    ``mutation_index`` with every identifier and comment node found.

    The index maps each node's end coordinates to a mapping with the node
    type, its text content, and its start/end ``(row, column)`` points —
    everything needed to later mutate the source text.
    """
    is_identifier = node.type in language_info.get("identifiers")
    is_comment = node.type in language_info.get("comments")
    if is_identifier or is_comment:
        add_to_mutation_index(node=node, mutation_index=mutation_index)

    for descendant in node.children:
        traverse(descendant, language_info, mutation_index)


def apply_mutation(text, start_point, end_point, replacement, successive_line_count):
    """
    Return ``text`` with the span between ``start_point`` and ``end_point``
    (both ``(row, column)`` tuples) replaced by the ``replacement`` string.

    ``successive_line_count`` maps a row number to the offset of that row's
    first character in ``text`` (cumulative line lengths with a leading 0);
    it is used to translate 2D coordinates into 1D string indices.

    When ``replacement`` is empty (comment removal) and the mutation leaves
    the start row blank, that line is dropped entirely.
    """
    start_row, start_col = start_point
    end_row, end_col = end_point

    # Compute 1D mutation position from 2D coordinates
    start_index = successive_line_count[start_row] + start_col
    end_index = successive_line_count[end_row] + end_col

    modified_text = text[:start_index] + replacement + text[end_index:]
    modified_lines = modified_text.splitlines(keepends=True)

    # Remove empty comment lines. Guard the index: when the removed span was
    # the last line of the file with no trailing newline, removing it leaves
    # no line at ``start_row`` at all and indexing would raise IndexError.
    if (
        not replacement
        and start_row < len(modified_lines)
        and modified_lines[start_row].strip() == ""
    ):
        del modified_lines[start_row]

    return "".join(modified_lines)


def get_stem_code(location):
    """
    Return the stemmed code for the code file at the specified ``location``,
    or None when the file's language is not supported.

    The file is parsed with tree-sitter, a mutation index is built for the
    identifier and comment tokens, and mutations are applied bottom-up
    (so earlier offsets stay valid): identifiers are replaced with "idf"
    and comments are removed.
    """
    parsed = get_parser(location)
    if not parsed:
        return
    parser, language_info = parsed

    with open(location, "rb") as code_file:
        source = code_file.read()

    mutation_index = {}
    traverse(parser.parse(source).root_node, language_info, mutation_index)

    text = source.decode()

    # Offset of the first character of each row; entry 0 is for row 0.
    line_offsets = [0]
    running_total = 0
    for line in text.splitlines(keepends=True):
        running_total += len(line)
        line_offsets.append(running_total)

    # Apply mutations in descending (row, column) order so that offsets of
    # still-pending (earlier) mutations remain valid after each edit.
    for end_point in sorted(mutation_index, reverse=True):
        mutation = mutation_index[end_point]
        replacement = "idf" if mutation["type"] == "identifier" else ""
        text = apply_mutation(
            text=text,
            start_point=mutation["start_point"],
            end_point=mutation["end_point"],
            replacement=replacement,
            successive_line_count=line_offsets,
        )
    return text
16 changes: 11 additions & 5 deletions tests/test_fingerprinting.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from commoncode.resource import VirtualCodebase
from commoncode.testcase import FileBasedTesting
from commoncode.testcase import check_against_expected_json_file
from samecode.halohash import byte_hamming_distance

from matchcode_toolkit.fingerprinting import _create_directory_fingerprint
from matchcode_toolkit.fingerprinting import _get_resource_subpath
Expand All @@ -22,7 +23,6 @@
from matchcode_toolkit.fingerprinting import create_structure_fingerprint
from matchcode_toolkit.fingerprinting import get_file_fingerprint_hashes
from matchcode_toolkit.fingerprinting import split_fingerprint
from samecode.halohash import byte_hamming_distance


class Resource:
Expand Down Expand Up @@ -193,10 +193,13 @@ def test_snippets_similarity(self, regen=False):
results1_snippet_mappings_by_snippets = self._create_snippet_mappings_by_snippets(
results1_snippets
)
results2_snippet_mappings_by_snippets = self._create_snippet_mappings_by_snippets(results2_snippets)
results2_snippet_mappings_by_snippets = self._create_snippet_mappings_by_snippets(
results2_snippets
)

matching_snippets = (
results1_snippet_mappings_by_snippets.keys() & results2_snippet_mappings_by_snippets.keys()
results1_snippet_mappings_by_snippets.keys()
& results2_snippet_mappings_by_snippets.keys()
)
expected_matching_snippets = {
"33b1d50de7e1701bd4beb706bf25970e",
Expand Down Expand Up @@ -247,10 +250,13 @@ def test_snippets_similarity_2(self, regen=False):
results1_snippet_mappings_by_snippets = self._create_snippet_mappings_by_snippets(
results1_snippets
)
results2_snippet_mappings_by_snippets = self._create_snippet_mappings_by_snippets(results2_snippets)
results2_snippet_mappings_by_snippets = self._create_snippet_mappings_by_snippets(
results2_snippets
)

matching_snippets = (
results1_snippet_mappings_by_snippets.keys() & results2_snippet_mappings_by_snippets.keys()
results1_snippet_mappings_by_snippets.keys()
& results2_snippet_mappings_by_snippets.keys()
)

# jaccard coefficient
Expand Down
80 changes: 80 additions & 0 deletions tests/test_stemming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#
# Copyright (c) nexB Inc. and others. All rights reserved.
# ScanCode is a trademark of nexB Inc.
# SPDX-License-Identifier: Apache-2.0
# See http://www.apache.org/licenses/LICENSE-2.0 for the license text.
# See https://github.com/aboutcode-org/matchcode-toolkit for support or download.
# See https://aboutcode.org for more information about nexB OSS projects.
#


from pathlib import Path

from commoncode.testcase import FileBasedTesting

from matchcode_toolkit import stemming


def check_against_expected_code_file(results, expected_file, regen=False):
    """
    Assert that the ``results`` string equals the contents of
    ``expected_file``.

    If ``regen`` is True, ``expected_file`` will first be overwritten with
    ``results``. This is convenient for updating test expectations, but use
    with caution: the comparison then always succeeds.
    """
    expected_path = Path(expected_file)
    if regen:
        expected_path.write_text(results)
        expected = results
    else:
        expected = expected_path.read_text()

    assert results == expected


class TestFingerprintingFunctions(FileBasedTesting):
    """End-to-end checks of stemming.get_stem_code against expected fixture files."""

    test_data_dir = Path(__file__).parent / "testfiles/stemming"

    def _check_stemming(self, test_file, expected_file):
        # Stem ``test_file`` and compare against ``expected_file``; both are
        # paths relative to ``test_data_dir``.
        file_location = self.test_data_dir / test_file
        expected_file_location = self.test_data_dir / expected_file
        results = stemming.get_stem_code(location=str(file_location))
        check_against_expected_code_file(results, expected_file_location)

    def test_java_code_stemming(self):
        self._check_stemming("java/contenttype.java", "java/contenttype-stemmed.java")

    def test_cpp_code_stemming(self):
        self._check_stemming("cpp/string.cpp", "cpp/string-stemmed.cpp")

    def test_c_code_stemming(self):
        self._check_stemming("c/main.c", "c/main-stemmed.c")

    def test_golang_code_stemming(self):
        self._check_stemming("golang/utils.go", "golang/utils-stemmed.go")

    def test_python_code_stemming(self):
        self._check_stemming(
            "python/sync_scancode_scans.py", "python/sync_scancode_scans-stemmed.py"
        )

    def test_javascript_code_stemming(self):
        self._check_stemming("javascript/utils.js", "javascript/utils-stemmed.js")

    def test_rust_code_stemming(self):
        # NOTE(review): "metrics-stemmeds.rs" looks like a typo for
        # "metrics-stemmed.rs" — confirm against the committed fixture name
        # before renaming either side.
        self._check_stemming("rust/metrics.rs", "rust/metrics-stemmeds.rs")
Loading