From 491800e30994f5e9a876e1fa630850c6b3dd5836 Mon Sep 17 00:00:00 2001 From: Maharshi Basu Date: Sat, 5 Apr 2025 23:24:03 +0530 Subject: [PATCH 01/16] feat: introduce scope analysis for python This commit introduces scope analysis for Python language for better data flow analysis. Support for: - variable declarations - simple declarations - tuple unpack, list unpack, pattern list - import - simple import - import with alias - import with `from` --- analysis/py_scope.go | 221 ++++++++++++++++++++++++++++++++++++++ analysis/py_scope_test.go | 97 +++++++++++++++++ analysis/scope.go | 6 +- 3 files changed, 323 insertions(+), 1 deletion(-) create mode 100644 analysis/py_scope.go create mode 100644 analysis/py_scope_test.go diff --git a/analysis/py_scope.go b/analysis/py_scope.go new file mode 100644 index 00000000..6d33c33a --- /dev/null +++ b/analysis/py_scope.go @@ -0,0 +1,221 @@ +package analysis + +import ( + "slices" + + sitter "github.com/smacker/go-tree-sitter" +) + +// NOTE: should this struct type be moved to another file? +/* +type UnresolvedRef struct { + id *sitter.Node + surroundingScope *Scope +} +*/ + +type PyScopeBuilder struct { + ast *sitter.Node + source []byte + // list of references that could not be resolved thus far + unresolvedRefs []UnresolvedRef +} + +func (py *PyScopeBuilder) GetLanguage() Language { + return LangPy +} + +var PyScopeNodes = []string{ + "module", + "function_definition", + "class_definition", + "for_statement", + "while_statement", + "if_statement", + "elif_clause", + "else_clause", + "with_statement", + "try_statement", + "except_clause", + "list_comprehension", + "dictionary_comprehension", + "lambda", +} + +func (py *PyScopeBuilder) NodeCreatesScope(node *sitter.Node) bool { + return slices.Contains(PyScopeNodes, node.Type()) +} + +func (py *PyScopeBuilder) DeclaresVariable(node *sitter.Node) bool { + typ := node.Type() + return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_item" || typ == "parameters" +} + +func (py *PyScopeBuilder) scanDecl(idOrPattern, declarator *sitter.Node, decls []*Variable) []*Variable { + switch idOrPattern.Type() { + case "identifier": + // TODO: implement for = = ... + // = ... + nameStr := idOrPattern.Content(py.source) + decls = append(decls, &Variable{ + Kind: VarKindVariable, + Name: nameStr, + DeclNode: declarator, + }) + + case "pattern_list", "tuple_pattern", "list_pattern": + // , = ..., ... + // (, ) = ..., ... + // [, ] = ..., ... + ids := ChildrenOfType(idOrPattern, "identifier") + for _, id := range ids { + decls = append(decls, &Variable{ + Kind: VarKindVariable, + Name: id.Content(py.source), + DeclNode: declarator, + }) + } + + // , * = ..., ..., ... + // also applicable to tuple_pattern & list_pattern + splats := ChildrenOfType(idOrPattern, "list_splat_pattern") + for _, splat := range splats { + splatIdNode := splat.Child(0) + if splatIdNode.Type() == "identifier" { + decls = append(decls, &Variable{ + Kind: VarKindVariable, + Name: splatIdNode.Content(py.source), + DeclNode: declarator, + }) + } + } + } + + return decls +} + +func (py *PyScopeBuilder) CollectVariables(node *sitter.Node) []*Variable { + var declaredVars []*Variable + switch node.Type() { + case "assignment": + lhs := node.ChildByFieldName("left") + return py.scanDecl(lhs, node, declaredVars) + + case "aliased_import": + // import as + aliasName := node.ChildByFieldName("name") + if aliasName != nil { + declaredVars = append(declaredVars, &Variable{ + Kind: VarKindImport, + Name: aliasName.Content(py.source), + DeclNode: aliasName, + }) + } + + case "dotted_name": + // import + defaultImport := FirstChildOfType(node, "identifier") + if defaultImport != nil { + declaredVars = append(declaredVars, &Variable{ + Kind: VarKindImport, + Name: defaultImport.Content(py.source), + DeclNode: defaultImport, + }) + } + + } + + return declaredVars +} + +func (py *PyScopeBuilder) OnNodeEnter(node *sitter.Node, scope *Scope) { + // collected identifier references if found + if node.Type() == "identifier" || node.Type() == "list_splat_pattern" { + parent := node.Parent() + if parent == nil { + return + } + + parentType := parent.Type() + + if parentType == "assignment" && parent.ChildByFieldName("left") == node { + return + } + + if parentType == "parameters" { + return + } + + if parentType == "default_parameter" && parent.ChildByFieldName("name") == node { + return + } + + if parentType == "pattern_list" || parentType == "tuple_pattern" || parentType == "list_pattern" { + return + } + + // module names in from import ... are not references + // names in import as are not references + if parentType == "dotted_name" && !isModuleName(parent) && parent.Parent().Type() != "aliased_import" { + return + } + + if parentType == "aliased_import" { + return + } + + // resolve this reference + variable := scope.Lookup(node.Content(py.source)) + if variable == nil { + unresolved := UnresolvedRef{ + id: node, + surroundingScope: scope, + } + + py.unresolvedRefs = append(py.unresolvedRefs, unresolved) + return + } + + ref := &Reference{ + Variable: variable, + Node: node, + } + + variable.Refs = append(variable.Refs, ref) + + } +} + +func (py *PyScopeBuilder) OnNodeExit(node *sitter.Node, scope *Scope) { + if node.Type() == "module" { + for _, unresolved := range py.unresolvedRefs { + variable := unresolved.surroundingScope.Lookup(unresolved.id.Content(py.source)) + + if variable == nil { + continue + } + + ref := &Reference{ + Variable: variable, + Node: unresolved.id, + } + + variable.Refs = append(variable.Refs, ref) + } + } +} + +func isModuleName(dottedNameNode *sitter.Node) bool { + if dottedNameNode.Type() != "dotted_name" { + return false + } + + importNode := dottedNameNode.Parent() + if importNode.Type() != "import_from_statement" || importNode == nil { + return false + } + + moduleNameChildren := ChildrenWithFieldName(importNode, "module_name") + + return slices.Contains(moduleNameChildren, dottedNameNode) +} diff --git a/analysis/py_scope_test.go b/analysis/py_scope_test.go new file mode 100644 index 00000000..5efc3422 --- /dev/null +++ b/analysis/py_scope_test.go @@ -0,0 +1,97 @@ +package analysis + +import ( + // "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func parsePyFile(t *testing.T, source string) *ParseResult { + parsed, err := Parse("file.py", []byte(source), LangPy, LangPy.Grammar()) + require.NoError(t, err) + require.NotNil(t, parsed) + return parsed +} + +func Test_PyBuildScopeTree(t *testing.T) { + t.Run("is able to resolve references", func(t *testing.T) { + source := ` + x = 1 + if True: + y = x + z = x` + parsed := parsePyFile(t, source) + + scopeTree := MakeScopeTree(parsed.Language, parsed.Ast, parsed.Source) + require.NotNil(t, scopeTree) + + globalScope := scopeTree.Root.Children[0] + require.NotNil(t, globalScope) + + varX, exists := globalScope.Variables["x"] + require.True(t, exists) + require.NotNil(t, varX) + + varY, exists := globalScope.Children[0].Variables["y"] + require.True(t, exists) + require.NotNil(t, varY) + require.Equal(t, VarKindVariable, varY.Kind) + + assert.Equal(t, 2, len(varX.Refs)) + xRef := varX.Refs[0] + assert.Equal(t, "x", xRef.Variable.Name) + require.Equal(t, VarKindVariable, varY.Kind) + + }) + + t.Run("supports import statements", func(t *testing.T) { + source := ` + import os + + os.system("cat file.txt") + + from csv import read + + if True: + f = read(file.csv) + ` + parsed := parsePyFile(t, source) + + scopeTree := MakeScopeTree(parsed.Language, parsed.Ast, parsed.Source) + require.NotNil(t, scopeTree) + + globalScope := scopeTree.Root.Children[0] + require.NotNil(t, globalScope) + + { + varOs, exists := globalScope.Variables["os"] + require.NotNil(t, varOs) + require.True(t, exists) + + assert.Equal(t, VarKindImport, varOs.Kind) + + osRefs := varOs.Refs + require.Equal(t, 1, len(osRefs)) + assert.Equal(t, "attribute", osRefs[0].Node.Parent().Type()) + } + + { + varRead, exists := globalScope.Variables["read"] + require.True(t, exists) + require.NotNil(t, varRead) + assert.Equal(t, VarKindImport, varRead.Kind) + + varF, exists := globalScope.Children[0].Variables["f"] + require.True(t, exists) + require.NotNil(t, varF) + assert.Equal(t, VarKindVariable, varF.Kind) + + readRefs := varRead.Refs + require.Equal(t, 1, len(readRefs)) + assert.Equal(t, "call", readRefs[0].Node.Parent().Type()) + } + + }) +} diff --git a/analysis/scope.go b/analysis/scope.go index 91c971a4..02b6b2b1 100644 --- a/analysis/scope.go +++ b/analysis/scope.go @@ -179,7 +179,11 @@ func (st *ScopeTree) GetScope(node *sitter.Node) *Scope { func MakeScopeTree(lang Language, ast *sitter.Node, source []byte) *ScopeTree { switch lang { case LangPy: - return nil + builder := &PyScopeBuilder{ + ast: ast, + source: source, + } + return BuildScopeTree(builder, ast, source) case LangTs, LangJs, LangTsx: builder := &TsScopeBuilder{ ast: ast, From f56300a3f1533bba27f17c8f7aaa2f88b22b0547 Mon Sep 17 00:00:00 2001 From: Maharshi Basu Date: Sun, 6 Apr 2025 17:05:58 +0530 Subject: [PATCH 02/16] enh: add support for function name and function params Signed-off-by: Maharshi Basu --- analysis/py_scope.go | 115 +++++++++++++++++++++++++++++++++++++- analysis/py_scope_test.go | 49 ++++++++++++++++ 2 files changed, 162 insertions(+), 2 deletions(-) diff --git a/analysis/py_scope.go b/analysis/py_scope.go index 6d33c33a..3b814be8 100644 --- a/analysis/py_scope.go +++ b/analysis/py_scope.go @@ -48,7 +48,7 @@ func (py *PyScopeBuilder) NodeCreatesScope(node *sitter.Node) bool { func (py *PyScopeBuilder) DeclaresVariable(node *sitter.Node) bool { typ := node.Type() - return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_item" || typ == "parameters" + return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_item" || typ == "parameters" || typ == "function_definition" } func (py *PyScopeBuilder) scanDecl(idOrPattern, declarator *sitter.Node, decls []*Variable) []*Variable { @@ -101,6 +101,22 @@ func (py *PyScopeBuilder) CollectVariables(node *sitter.Node) []*Variable { lhs := node.ChildByFieldName("left") return py.scanDecl(lhs, node, declaredVars) + case "function_definition": + name := node.ChildByFieldName("name") + // skipcq: TCV-001 + if name == nil { + break + } + + declaredVars = append(declaredVars, &Variable{ + Kind: VarKindFunction, + Name: name.Content(py.source), + DeclNode: node, + }) + + case "parameters": + declaredVars = py.variableFromFunctionParams(node, declaredVars) + case "aliased_import": // import as aliasName := node.ChildByFieldName("name") @@ -138,7 +154,7 @@ func (py *PyScopeBuilder) OnNodeEnter(node *sitter.Node, scope *Scope) { parentType := parent.Type() - if parentType == "assignment" && parent.ChildByFieldName("left") == node { + if parentType == "assignment" && parent.ChildByFieldName("left") == node { return } @@ -164,6 +180,14 @@ func (py *PyScopeBuilder) OnNodeEnter(node *sitter.Node, scope *Scope) { return } + if parentType == "function_definition" { + return + } + + if parentType == "paramaters" || parentType == "default_parameter" || parentType == "typed_default_parameter" { + return + } + // resolve this reference variable := scope.Lookup(node.Content(py.source)) if variable == nil { @@ -219,3 +243,90 @@ func isModuleName(dottedNameNode *sitter.Node) bool { return slices.Contains(moduleNameChildren, dottedNameNode) } + +func (py *PyScopeBuilder) variableFromFunctionParams(node *sitter.Node, decls []*Variable) []*Variable { + childrenCount := node.NamedChildCount() + for i := 0; i < int(childrenCount); i++ { + param := node.NamedChild(i) + + if param == nil { + continue + } + + // handle the parameter types: + // identifier, typed_parameter, default_parameter, typed_default_parameter + if param.Type() == "identifier" { + decls = append(decls, &Variable{ + Kind: VarKindParameter, + Name: param.Content(py.source), + DeclNode: param, + }) + } else if param.Type() == "typed_parameter" || param.Type() == "list_splat_pattern" || param.Type() == "dictionary_splat_pattern" { + idNode := FirstChildOfType(param, "identifier") + if idNode != nil { + decls = append(decls, &Variable{ + Kind: VarKindParameter, + Name: idNode.Content(py.source), + DeclNode: param, + }) + } + } else if param.Type() == "default_parameter" || param.Type() == "typed_default_parameter" { + name := ChildWithFieldName(param, "name") + if name != nil { + if name.Type() == "identifier" { + decls = append(decls, &Variable{ + Kind: VarKindParameter, + Name: name.Content(py.source), + DeclNode: param, + }) + } else if name.Type() == "tuple_pattern" { + childrenIds := ChildrenOfType(name, "identifier") + childrenListSplat := ChildrenOfType(name, "list_splat_pattern") + + for _, id := range childrenIds { + decls = append(decls, &Variable{ + Kind: VarKindParameter, + Name: id.Content(py.source), + DeclNode: param, + }) + } + + for _, listSplatPat := range childrenListSplat { + splatId := FirstChildOfType(listSplatPat, "identifier") + if splatId != nil { + decls = append(decls, &Variable{ + Kind: VarKindParameter, + Name: listSplatPat.Content(py.source), + DeclNode: param, + }) + } + } + } + } + } else if param.Type() == "tuple_pattern" { + childrenIds := ChildrenOfType(param, "identifier") + childrenListSplat := ChildrenOfType(param, "list_splat_pattern") + + for _, id := range childrenIds { + decls = append(decls, &Variable{ + Kind: VarKindParameter, + Name: id.Content(py.source), + DeclNode: param, + }) + } + + for _, listSplatPat := range childrenListSplat { + splatId := FirstChildOfType(listSplatPat, "identifier") + if splatId != nil { + decls = append(decls, &Variable{ + Kind: VarKindParameter, + Name: listSplatPat.Content(py.source), + DeclNode: param, + }) + } + } + } + } + + return decls +} diff --git a/analysis/py_scope_test.go b/analysis/py_scope_test.go index 5efc3422..17c5b2b2 100644 --- a/analysis/py_scope_test.go +++ b/analysis/py_scope_test.go @@ -94,4 +94,53 @@ func Test_PyBuildScopeTree(t *testing.T) { } }) + + t.Run("supports function parameters", func(t *testing.T) { + source := ` + def myFunc(a, b=2, c:int, d:str="Hello"): + A = otherFunc(a) + C = b + c + print(d) + ` + parsed := parsePyFile(t, source) + + scopeTree := MakeScopeTree(parsed.Language, parsed.Ast, parsed.Source) + require.NotNil(t, scopeTree) + + globalScope := scopeTree.Root.Children[0] + require.NotNil(t, globalScope) + + { + varMyFunc, exists := globalScope.Variables["myFunc"] + require.NotNil(t, varMyFunc) + require.True(t, exists) + + assert.Equal(t, VarKindFunction, varMyFunc.Kind) + myFuncRefs := varMyFunc.Refs + require.Equal(t, 0, len(myFuncRefs)) + } + + { + varA, exists := globalScope.Children[0].Variables["a"] + require.NotNil(t, varA) + require.True(t, exists) + assert.Equal(t, VarKindParameter, varA.Kind) + + aRefs := varA.Refs + require.Equal(t, 1, len(aRefs)) + assert.Equal(t, "argument_list", aRefs[0].Node.Parent().Type()) + } + + { + varB, exists := globalScope.Children[0].Variables["b"] + require.NotNil(t, varB) + require.True(t, exists) + assert.Equal(t, VarKindParameter, varB.Kind) + + bRefs := varB.Refs + require.Equal(t, 1, len(bRefs)) + assert.Equal(t, "binary_operator", bRefs[0].Node.Parent().Type()) + } + }) + } From d57081241b2d2f8a9bb3f43e7eae8d9d8af943a9 Mon Sep 17 00:00:00 2001 From: Maharshi Basu Date: Sun, 6 Apr 2025 19:06:35 +0530 Subject: [PATCH 03/16] feat: scope analysis for python in yaml checkers --- pkg/analysis/scope.go | 6 +- pkg/analysis/scope_py.go | 332 ++++++++++++++++++++++++++++++++++ pkg/analysis/scope_py_test.go | 146 +++++++++++++++ 3 files changed, 483 insertions(+), 1 deletion(-) create mode 100644 pkg/analysis/scope_py.go create mode 100644 pkg/analysis/scope_py_test.go diff --git a/pkg/analysis/scope.go b/pkg/analysis/scope.go index 2aefe4ad..1346bb0b 100644 --- a/pkg/analysis/scope.go +++ b/pkg/analysis/scope.go @@ -177,7 +177,11 @@ func (st *ScopeTree) GetScope(node *sitter.Node) *Scope { func MakeScopeTree(lang Language, ast *sitter.Node, source []byte) *ScopeTree { switch lang { case LangPy: - return nil + builder := &PyScopeBuilder{ + ast: ast, + source: source, + } + return BuildScopeTree(builder, ast, source) case LangTs, LangJs, LangTsx: builder := &TsScopeBuilder{ ast: ast, diff --git a/pkg/analysis/scope_py.go b/pkg/analysis/scope_py.go new file mode 100644 index 00000000..3b814be8 --- /dev/null +++ b/pkg/analysis/scope_py.go @@ -0,0 +1,332 @@ +package analysis + +import ( + "slices" + + sitter "github.com/smacker/go-tree-sitter" +) + +// NOTE: should this struct type be moved to another file? +/* +type UnresolvedRef struct { + id *sitter.Node + surroundingScope *Scope +} +*/ + +type PyScopeBuilder struct { + ast *sitter.Node + source []byte + // list of references that could not be resolved thus far + unresolvedRefs []UnresolvedRef +} + +func (py *PyScopeBuilder) GetLanguage() Language { + return LangPy +} + +var PyScopeNodes = []string{ + "module", + "function_definition", + "class_definition", + "for_statement", + "while_statement", + "if_statement", + "elif_clause", + "else_clause", + "with_statement", + "try_statement", + "except_clause", + "list_comprehension", + "dictionary_comprehension", + "lambda", +} + +func (py *PyScopeBuilder) NodeCreatesScope(node *sitter.Node) bool { + return slices.Contains(PyScopeNodes, node.Type()) +} + +func (py *PyScopeBuilder) DeclaresVariable(node *sitter.Node) bool { + typ := node.Type() + return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_item" || typ == "parameters" || typ == "function_definition" +} + +func (py *PyScopeBuilder) scanDecl(idOrPattern, declarator *sitter.Node, decls []*Variable) []*Variable { + switch idOrPattern.Type() { + case "identifier": + // TODO: implement for = = ... + // = ... + nameStr := idOrPattern.Content(py.source) + decls = append(decls, &Variable{ + Kind: VarKindVariable, + Name: nameStr, + DeclNode: declarator, + }) + + case "pattern_list", "tuple_pattern", "list_pattern": + // , = ..., ... + // (, ) = ..., ... + // [, ] = ..., ... + ids := ChildrenOfType(idOrPattern, "identifier") + for _, id := range ids { + decls = append(decls, &Variable{ + Kind: VarKindVariable, + Name: id.Content(py.source), + DeclNode: declarator, + }) + } + + // , * = ..., ..., ... + // also applicable to tuple_pattern & list_pattern + splats := ChildrenOfType(idOrPattern, "list_splat_pattern") + for _, splat := range splats { + splatIdNode := splat.Child(0) + if splatIdNode.Type() == "identifier" { + decls = append(decls, &Variable{ + Kind: VarKindVariable, + Name: splatIdNode.Content(py.source), + DeclNode: declarator, + }) + } + } + } + + return decls +} + +func (py *PyScopeBuilder) CollectVariables(node *sitter.Node) []*Variable { + var declaredVars []*Variable + switch node.Type() { + case "assignment": + lhs := node.ChildByFieldName("left") + return py.scanDecl(lhs, node, declaredVars) + + case "function_definition": + name := node.ChildByFieldName("name") + // skipcq: TCV-001 + if name == nil { + break + } + + declaredVars = append(declaredVars, &Variable{ + Kind: VarKindFunction, + Name: name.Content(py.source), + DeclNode: node, + }) + + case "parameters": + declaredVars = py.variableFromFunctionParams(node, declaredVars) + + case "aliased_import": + // import as + aliasName := node.ChildByFieldName("name") + if aliasName != nil { + declaredVars = append(declaredVars, &Variable{ + Kind: VarKindImport, + Name: aliasName.Content(py.source), + DeclNode: aliasName, + }) + } + + case "dotted_name": + // import + defaultImport := FirstChildOfType(node, "identifier") + if defaultImport != nil { + declaredVars = append(declaredVars, &Variable{ + Kind: VarKindImport, + Name: defaultImport.Content(py.source), + DeclNode: defaultImport, + }) + } + + } + + return declaredVars +} + +func (py *PyScopeBuilder) OnNodeEnter(node *sitter.Node, scope *Scope) { + // collected identifier references if found + if node.Type() == "identifier" || node.Type() == "list_splat_pattern" { + parent := node.Parent() + if parent == nil { + return + } + + parentType := parent.Type() + + if parentType == "assignment" && parent.ChildByFieldName("left") == node { + return + } + + if parentType == "parameters" { + return + } + + if parentType == "default_parameter" && parent.ChildByFieldName("name") == node { + return + } + + if parentType == "pattern_list" || parentType == "tuple_pattern" || parentType == "list_pattern" { + return + } + + // module names in from import ... are not references + // names in import as are not references + if parentType == "dotted_name" && !isModuleName(parent) && parent.Parent().Type() != "aliased_import" { + return + } + + if parentType == "aliased_import" { + return + } + + if parentType == "function_definition" { + return + } + + if parentType == "paramaters" || parentType == "default_parameter" || parentType == "typed_default_parameter" { + return + } + + // resolve this reference + variable := scope.Lookup(node.Content(py.source)) + if variable == nil { + unresolved := UnresolvedRef{ + id: node, + surroundingScope: scope, + } + + py.unresolvedRefs = append(py.unresolvedRefs, unresolved) + return + } + + ref := &Reference{ + Variable: variable, + Node: node, + } + + variable.Refs = append(variable.Refs, ref) + + } +} + +func (py *PyScopeBuilder) OnNodeExit(node *sitter.Node, scope *Scope) { + if node.Type() == "module" { + for _, unresolved := range py.unresolvedRefs { + variable := unresolved.surroundingScope.Lookup(unresolved.id.Content(py.source)) + + if variable == nil { + continue + } + + ref := &Reference{ + Variable: variable, + Node: unresolved.id, + } + + variable.Refs = append(variable.Refs, ref) + } + } +} + +func isModuleName(dottedNameNode *sitter.Node) bool { + if dottedNameNode.Type() != "dotted_name" { + return false + } + + importNode := dottedNameNode.Parent() + if importNode.Type() != "import_from_statement" || importNode == nil { + return false + } + + moduleNameChildren := ChildrenWithFieldName(importNode, "module_name") + + return slices.Contains(moduleNameChildren, dottedNameNode) +} + +func (py *PyScopeBuilder) variableFromFunctionParams(node *sitter.Node, decls []*Variable) []*Variable { + childrenCount := node.NamedChildCount() + for i := 0; i < int(childrenCount); i++ { + param := node.NamedChild(i) + + if param == nil { + continue + } + + // handle the parameter types: + // identifier, typed_parameter, default_parameter, typed_default_parameter + if param.Type() == "identifier" { + decls = append(decls, &Variable{ + Kind: VarKindParameter, + Name: param.Content(py.source), + DeclNode: param, + }) + } else if param.Type() == "typed_parameter" || param.Type() == "list_splat_pattern" || param.Type() == "dictionary_splat_pattern" { + idNode := FirstChildOfType(param, "identifier") + if idNode != nil { + decls = append(decls, &Variable{ + Kind: VarKindParameter, + Name: idNode.Content(py.source), + DeclNode: param, + }) + } + } else if param.Type() == "default_parameter" || param.Type() == "typed_default_parameter" { + name := ChildWithFieldName(param, "name") + if name != nil { + if name.Type() == "identifier" { + decls = append(decls, &Variable{ + Kind: VarKindParameter, + Name: name.Content(py.source), + DeclNode: param, + }) + } else if name.Type() == "tuple_pattern" { + childrenIds := ChildrenOfType(name, "identifier") + childrenListSplat := ChildrenOfType(name, "list_splat_pattern") + + for _, id := range childrenIds { + decls = append(decls, &Variable{ + Kind: VarKindParameter, + Name: id.Content(py.source), + DeclNode: param, + }) + } + + for _, listSplatPat := range childrenListSplat { + splatId := FirstChildOfType(listSplatPat, "identifier") + if splatId != nil { + decls = append(decls, &Variable{ + Kind: VarKindParameter, + Name: listSplatPat.Content(py.source), + DeclNode: param, + }) + } + } + } + } + } else if param.Type() == "tuple_pattern" { + childrenIds := ChildrenOfType(param, "identifier") + childrenListSplat := ChildrenOfType(param, "list_splat_pattern") + + for _, id := range childrenIds { + decls = append(decls, &Variable{ + Kind: VarKindParameter, + Name: id.Content(py.source), + DeclNode: param, + }) + } + + for _, listSplatPat := range childrenListSplat { + splatId := FirstChildOfType(listSplatPat, "identifier") + if splatId != nil { + decls = append(decls, &Variable{ + Kind: VarKindParameter, + Name: listSplatPat.Content(py.source), + DeclNode: param, + }) + } + } + } + } + + return decls +} diff --git a/pkg/analysis/scope_py_test.go b/pkg/analysis/scope_py_test.go new file mode 100644 index 00000000..17c5b2b2 --- /dev/null +++ b/pkg/analysis/scope_py_test.go @@ -0,0 +1,146 @@ +package analysis + +import ( + // "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func parsePyFile(t *testing.T, source string) *ParseResult { + parsed, err := Parse("file.py", []byte(source), LangPy, LangPy.Grammar()) + require.NoError(t, err) + require.NotNil(t, parsed) + return parsed +} + +func Test_PyBuildScopeTree(t *testing.T) { + t.Run("is able to resolve references", func(t *testing.T) { + source := ` + x = 1 + if True: + y = x + z = x` + parsed := parsePyFile(t, source) + + scopeTree := MakeScopeTree(parsed.Language, parsed.Ast, parsed.Source) + require.NotNil(t, scopeTree) + + globalScope := scopeTree.Root.Children[0] + require.NotNil(t, globalScope) + + varX, exists := globalScope.Variables["x"] + require.True(t, exists) + require.NotNil(t, varX) + + varY, exists := globalScope.Children[0].Variables["y"] + require.True(t, exists) + require.NotNil(t, varY) + require.Equal(t, VarKindVariable, varY.Kind) + + assert.Equal(t, 2, len(varX.Refs)) + xRef := varX.Refs[0] + assert.Equal(t, "x", xRef.Variable.Name) + require.Equal(t, VarKindVariable, varY.Kind) + + }) + + t.Run("supports import statements", func(t *testing.T) { + source := ` + import os + + os.system("cat file.txt") + + from csv import read + + if True: + f = read(file.csv) + ` + parsed := parsePyFile(t, source) + + scopeTree := MakeScopeTree(parsed.Language, parsed.Ast, parsed.Source) + require.NotNil(t, scopeTree) + + globalScope := scopeTree.Root.Children[0] + require.NotNil(t, globalScope) + + { + varOs, exists := globalScope.Variables["os"] + require.NotNil(t, varOs) + require.True(t, exists) + + assert.Equal(t, VarKindImport, varOs.Kind) + + osRefs := varOs.Refs + require.Equal(t, 1, len(osRefs)) + assert.Equal(t, "attribute", osRefs[0].Node.Parent().Type()) + } + + { + varRead, exists := globalScope.Variables["read"] + require.True(t, exists) + require.NotNil(t, varRead) + assert.Equal(t, VarKindImport, varRead.Kind) + + varF, exists := globalScope.Children[0].Variables["f"] + require.True(t, exists) + require.NotNil(t, varF) + assert.Equal(t, VarKindVariable, varF.Kind) + + readRefs := varRead.Refs + require.Equal(t, 1, len(readRefs)) + assert.Equal(t, "call", readRefs[0].Node.Parent().Type()) + } + + }) + + t.Run("supports function parameters", func(t *testing.T) { + source := ` + def myFunc(a, b=2, c:int, d:str="Hello"): + A = otherFunc(a) + C = b + c + print(d) + ` + parsed := parsePyFile(t, source) + + scopeTree := MakeScopeTree(parsed.Language, parsed.Ast, parsed.Source) + require.NotNil(t, scopeTree) + + globalScope := scopeTree.Root.Children[0] + require.NotNil(t, globalScope) + + { + varMyFunc, exists := globalScope.Variables["myFunc"] + require.NotNil(t, varMyFunc) + require.True(t, exists) + + assert.Equal(t, VarKindFunction, varMyFunc.Kind) + myFuncRefs := varMyFunc.Refs + require.Equal(t, 0, len(myFuncRefs)) + } + + { + varA, exists := globalScope.Children[0].Variables["a"] + require.NotNil(t, varA) + require.True(t, exists) + assert.Equal(t, VarKindParameter, varA.Kind) + + aRefs := varA.Refs + require.Equal(t, 1, len(aRefs)) + assert.Equal(t, "argument_list", aRefs[0].Node.Parent().Type()) + } + + { + varB, exists := globalScope.Children[0].Variables["b"] + require.NotNil(t, varB) + require.True(t, exists) + assert.Equal(t, VarKindParameter, varB.Kind) + + bRefs := varB.Refs + require.Equal(t, 1, len(bRefs)) + assert.Equal(t, "binary_operator", bRefs[0].Node.Parent().Type()) + } + }) + +} From 7a3a9608a05f46bad1f20bc2a99ab58e561d2c35 Mon Sep 17 00:00:00 2001 From: Maharshi Basu Date: Mon, 7 Apr 2025 12:10:22 +0530 Subject: [PATCH 04/16] enh: add support for with statements Signed-off-by: Maharshi Basu --- analysis/py_scope.go | 24 ++++++++++++++++++++++-- analysis/py_scope_test.go | 26 +++++++++++++++++++++++++- pkg/analysis/scope_py.go | 24 ++++++++++++++++++++++-- pkg/analysis/scope_py_test.go | 26 +++++++++++++++++++++++++- 4 files changed, 94 insertions(+), 6 deletions(-) diff --git a/analysis/py_scope.go b/analysis/py_scope.go index 3b814be8..d172e7e7 100644 --- a/analysis/py_scope.go +++ b/analysis/py_scope.go @@ -48,7 +48,7 @@ func (py *PyScopeBuilder) NodeCreatesScope(node *sitter.Node) bool { func (py *PyScopeBuilder) DeclaresVariable(node *sitter.Node) bool { typ := node.Type() - return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_item" || typ == "parameters" || typ == "function_definition" + return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_statement" || typ == "parameters" || typ == "function_definition" } func (py *PyScopeBuilder) scanDecl(idOrPattern, declarator *sitter.Node, decls []*Variable) []*Variable { @@ -139,6 +139,22 @@ func (py *PyScopeBuilder) CollectVariables(node *sitter.Node) []*Variable { }) } + case "with_statement": + clause := FirstChildOfType(node, "with_clause") + item := FirstChildOfType(clause, "with_item") + + value := item.ChildByFieldName("value") + alias := value.ChildByFieldName("alias") + if alias != nil { + id := FirstChildOfType(alias, "identifier") + if id != nil { + declaredVars = append(declaredVars, &Variable{ + Kind: VarKindVariable, + Name: id.Content(py.source), + DeclNode: node, + }) + } + } } return declaredVars @@ -184,7 +200,11 @@ func (py *PyScopeBuilder) OnNodeEnter(node *sitter.Node, scope *Scope) { return } - if parentType == "paramaters" || parentType == "default_parameter" || parentType == "typed_default_parameter" { + if parentType == "parameters" || parentType == "default_parameter" || parentType == "typed_default_parameter" { + return + } + + if parentType == "as_pattern_target" { return } diff --git a/analysis/py_scope_test.go b/analysis/py_scope_test.go index 17c5b2b2..b94c24c8 100644 --- a/analysis/py_scope_test.go +++ b/analysis/py_scope_test.go @@ -1,7 +1,6 @@ package analysis import ( - // "fmt" "testing" "github.com/stretchr/testify/assert" @@ -143,4 +142,29 @@ func Test_PyBuildScopeTree(t *testing.T) { } }) + t.Run("supports with statements", func(t *testing.T) { + source := ` + with open("file.txt", 'r') as f: + print(f.read(5)) + ` + parsed := parsePyFile(t, source) + + scopeTree := MakeScopeTree(parsed.Language, parsed.Ast, parsed.Source) + require.NotNil(t, scopeTree) + + globalScope := scopeTree.Root.Children[0] + require.NotNil(t, globalScope) + + { + varF, exists := globalScope.Variables["f"] + require.NotNil(t, varF) + require.True(t, exists) + + assert.Equal(t, VarKindVariable, varF.Kind) + fRefs := varF.Refs + require.Equal(t, 1, len(fRefs)) + assert.Equal(t, "call", fRefs[0].Node.Parent().Parent().Type()) + } + }) + } diff --git a/pkg/analysis/scope_py.go b/pkg/analysis/scope_py.go index 3b814be8..d172e7e7 100644 --- a/pkg/analysis/scope_py.go +++ b/pkg/analysis/scope_py.go @@ -48,7 +48,7 @@ func (py *PyScopeBuilder) NodeCreatesScope(node *sitter.Node) bool { func (py *PyScopeBuilder) DeclaresVariable(node *sitter.Node) bool { typ := node.Type() - return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_item" || typ == "parameters" || typ == "function_definition" + return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_statement" || typ == "parameters" || typ == "function_definition" } func (py *PyScopeBuilder) scanDecl(idOrPattern, declarator *sitter.Node, decls []*Variable) []*Variable { @@ -139,6 +139,22 @@ func (py *PyScopeBuilder) CollectVariables(node *sitter.Node) []*Variable { }) } + case "with_statement": + clause := FirstChildOfType(node, "with_clause") + item := FirstChildOfType(clause, "with_item") + + value := item.ChildByFieldName("value") + alias := value.ChildByFieldName("alias") + if alias != nil { + id := FirstChildOfType(alias, "identifier") + if id != nil { + declaredVars = append(declaredVars, &Variable{ + Kind: VarKindVariable, + Name: id.Content(py.source), + DeclNode: node, + }) + } + } } return declaredVars @@ -184,7 +200,11 @@ func (py *PyScopeBuilder) OnNodeEnter(node *sitter.Node, scope *Scope) { return } - if parentType == "paramaters" || parentType == "default_parameter" || parentType == "typed_default_parameter" { + if parentType == "parameters" || parentType == "default_parameter" || parentType == "typed_default_parameter" { + return + } + + if parentType == "as_pattern_target" { return } diff --git a/pkg/analysis/scope_py_test.go b/pkg/analysis/scope_py_test.go index 17c5b2b2..b94c24c8 100644 --- a/pkg/analysis/scope_py_test.go +++ b/pkg/analysis/scope_py_test.go @@ -1,7 +1,6 @@ package analysis import ( - // "fmt" "testing" "github.com/stretchr/testify/assert" @@ -143,4 +142,29 @@ func Test_PyBuildScopeTree(t *testing.T) { } }) + t.Run("supports with statements", func(t *testing.T) { + source := ` + with open("file.txt", 'r') as f: + print(f.read(5)) + ` + parsed := parsePyFile(t, source) + + scopeTree := MakeScopeTree(parsed.Language, parsed.Ast, parsed.Source) + require.NotNil(t, scopeTree) + + globalScope := scopeTree.Root.Children[0] + require.NotNil(t, globalScope) + + { + varF, exists := globalScope.Variables["f"] + require.NotNil(t, varF) + require.True(t, exists) + + assert.Equal(t, VarKindVariable, varF.Kind) + fRefs := varF.Refs + require.Equal(t, 1, len(fRefs)) + assert.Equal(t, "call", fRefs[0].Node.Parent().Parent().Type()) + } + }) + } From 0de1d53b5c52ba5b9dadfd7c3910d56dff8863fa Mon Sep 17 00:00:00 2001 From: Maharshi Basu Date: Mon, 7 Apr 2025 19:22:56 +0530 Subject: [PATCH 05/16] enh: add support for exception clauses --- analysis/py_scope.go | 21 ++++++++++++++++++++- analysis/py_scope_test.go | 25 +++++++++++++++++++++++++ pkg/analysis/scope_py.go | 21 ++++++++++++++++++++- pkg/analysis/scope_py_test.go | 25 +++++++++++++++++++++++++ 4 files changed, 90 insertions(+), 2 deletions(-) diff --git a/analysis/py_scope.go b/analysis/py_scope.go index d172e7e7..65b8c863 100644 --- a/analysis/py_scope.go +++ b/analysis/py_scope.go @@ -48,7 +48,7 @@ func (py *PyScopeBuilder) NodeCreatesScope(node *sitter.Node) bool { func (py *PyScopeBuilder) DeclaresVariable(node *sitter.Node) bool { typ := node.Type() - return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_statement" || typ == "parameters" || typ == "function_definition" + return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_statement" || typ == "parameters" || typ == "function_definition" || typ == "try_statement" } func (py *PyScopeBuilder) scanDecl(idOrPattern, declarator *sitter.Node, decls []*Variable) []*Variable { @@ -155,6 +155,23 @@ func (py *PyScopeBuilder) CollectVariables(node *sitter.Node) []*Variable { }) } } + + case "try_statement": + exceptClause := FirstChildOfType(node, "except_clause") + asPattern := FirstChildOfType(exceptClause, "as_pattern") + + if asPattern != nil { + alias := asPattern.ChildByFieldName("alias") + id := FirstChildOfType(alias, "identifier") + + if id != nil { + declaredVars = append(declaredVars, &Variable{ + Kind: VarKindError, + Name: id.Content(py.source), + DeclNode: node, + }) + } + } } return declaredVars @@ -204,6 +221,8 @@ func (py *PyScopeBuilder) OnNodeEnter(node *sitter.Node, scope *Scope) { return } + // works for patterns with an `as` clause + // eg. with as , except as if parentType == "as_pattern_target" { return } diff --git a/analysis/py_scope_test.go b/analysis/py_scope_test.go index b94c24c8..50f9d9c3 100644 --- a/analysis/py_scope_test.go +++ b/analysis/py_scope_test.go @@ -167,4 +167,29 @@ func Test_PyBuildScopeTree(t *testing.T) { } }) + t.Run("supports exception statements", func(t *testing.T) { + source := ` + try: + result = 10 / 2 + except ZeroDivisionError as e: + print(e) + ` + parsed := parsePyFile(t, source) + + scopeTree := MakeScopeTree(parsed.Language, parsed.Ast, parsed.Source) + require.NotNil(t, scopeTree) + + globalScope := scopeTree.Root.Children[0] + require.NotNil(t, globalScope) + + varE, exists := globalScope.Variables["e"] + require.NotNil(t, varE) + require.True(t, exists) + + assert.Equal(t, VarKindError, varE.Kind) + eRefs := varE.Refs + require.Equal(t, 1, len(eRefs)) + assert.Equal(t, "call", eRefs[0].Node.Parent().Parent().Type()) + }) + } diff --git a/pkg/analysis/scope_py.go b/pkg/analysis/scope_py.go index d172e7e7..65b8c863 100644 --- a/pkg/analysis/scope_py.go +++ b/pkg/analysis/scope_py.go @@ -48,7 +48,7 @@ func (py *PyScopeBuilder) NodeCreatesScope(node *sitter.Node) bool { func (py *PyScopeBuilder) DeclaresVariable(node *sitter.Node) bool { typ := node.Type() - return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_statement" || typ == "parameters" || typ == "function_definition" + return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_statement" || typ == "parameters" || typ == "function_definition" || typ == "try_statement" } func (py *PyScopeBuilder) scanDecl(idOrPattern, declarator *sitter.Node, decls []*Variable) []*Variable { @@ -155,6 +155,23 @@ func (py *PyScopeBuilder) CollectVariables(node *sitter.Node) []*Variable { }) } } + + case "try_statement": + exceptClause := FirstChildOfType(node, "except_clause") + asPattern := FirstChildOfType(exceptClause, "as_pattern") + + if asPattern != nil { + alias := asPattern.ChildByFieldName("alias") + id := FirstChildOfType(alias, "identifier") + + if id != nil { + declaredVars = append(declaredVars, &Variable{ + Kind: VarKindError, + Name: id.Content(py.source), + DeclNode: node, + }) + } + } } return declaredVars @@ -204,6 +221,8 @@ func (py *PyScopeBuilder) OnNodeEnter(node *sitter.Node, scope *Scope) { return } + // works for patterns with an `as` clause + // eg. with as , except as if parentType == "as_pattern_target" { return } diff --git a/pkg/analysis/scope_py_test.go b/pkg/analysis/scope_py_test.go index b94c24c8..50f9d9c3 100644 --- a/pkg/analysis/scope_py_test.go +++ b/pkg/analysis/scope_py_test.go @@ -167,4 +167,29 @@ func Test_PyBuildScopeTree(t *testing.T) { } }) + t.Run("supports exception statements", func(t *testing.T) { + source := ` + try: + result = 10 / 2 + except ZeroDivisionError as e: + print(e) + ` + parsed := parsePyFile(t, source) + + scopeTree := MakeScopeTree(parsed.Language, parsed.Ast, parsed.Source) + require.NotNil(t, scopeTree) + + globalScope := scopeTree.Root.Children[0] + require.NotNil(t, globalScope) + + varE, exists := globalScope.Variables["e"] + require.NotNil(t, varE) + require.True(t, exists) + + assert.Equal(t, VarKindError, varE.Kind) + eRefs := varE.Refs + require.Equal(t, 1, len(eRefs)) + assert.Equal(t, "call", eRefs[0].Node.Parent().Parent().Type()) + }) + } From 499e1fe4e7691b3320921552b320ab0f4b820450 Mon Sep 17 00:00:00 2001 From: Maharshi Basu Date: Mon, 7 Apr 2025 22:51:50 +0530 Subject: [PATCH 06/16] enh: add support for classes --- analysis/py_scope.go | 16 +++++++++++++++- analysis/py_scope_test.go | 27 +++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/analysis/py_scope.go b/analysis/py_scope.go index 65b8c863..87cb06e7 100644 --- a/analysis/py_scope.go +++ b/analysis/py_scope.go @@ -48,7 +48,7 @@ func (py *PyScopeBuilder) NodeCreatesScope(node *sitter.Node) bool { func (py *PyScopeBuilder) DeclaresVariable(node *sitter.Node) bool { typ := node.Type() - return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_statement" || typ == "parameters" || typ == "function_definition" || typ == "try_statement" + return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_statement" || typ == "parameters" || typ == "function_definition" || typ == "try_statement" || typ == "class_definition" } func (py *PyScopeBuilder) scanDecl(idOrPattern, declarator *sitter.Node, decls []*Variable) []*Variable { @@ -172,6 +172,16 @@ func (py *PyScopeBuilder) CollectVariables(node *sitter.Node) []*Variable { }) } } + + case "class_definition": + name := node.ChildByFieldName("name") + if name != nil && name.Type() == "identifier" { + declaredVars = append(declaredVars, &Variable{ + Kind: VarKindClass, + Name: name.Content(py.source), + DeclNode: node, + }) + } } return declaredVars @@ -227,6 +237,10 @@ func (py *PyScopeBuilder) OnNodeEnter(node *sitter.Node, scope *Scope) { return } + if parentType == "class_definition" { + return + } + // resolve this reference variable := scope.Lookup(node.Content(py.source)) if variable == nil { diff --git a/analysis/py_scope_test.go b/analysis/py_scope_test.go index 50f9d9c3..43948a69 100644 --- a/analysis/py_scope_test.go +++ b/analysis/py_scope_test.go @@ -1,6 +1,7 @@ package analysis import ( + // "fmt" "testing" "github.com/stretchr/testify/assert" @@ -192,4 +193,30 @@ func Test_PyBuildScopeTree(t *testing.T) { assert.Equal(t, "call", eRefs[0].Node.Parent().Parent().Type()) }) + t.Run("supports classes", func(t *testing.T) { + source := ` + class MyClass: + def __init__(self, name): + self.name = name + + def print_name(self): + print(self.name) + ` + parsed := parsePyFile(t, source) + + scopeTree := MakeScopeTree(parsed.Language, parsed.Ast, parsed.Source) + require.NotNil(t, scopeTree) + + globalScope := scopeTree.Root.Children[0] + require.NotNil(t, globalScope) + + { + varClass, exists := globalScope.Variables["MyClass"] + require.NotNil(t, varClass) + require.True(t, exists) + assert.Equal(t, VarKindClass, varClass.Kind) + } + + }) + } From 0aab1667aecfc26a413d3dc77900abf5ca76445c Mon Sep 17 00:00:00 2001 From: Maharshi Basu Date: Mon, 7 Apr 2025 22:55:10 +0530 Subject: [PATCH 07/16] enh: add support for classes for yaml --- pkg/analysis/scope.go | 1 + pkg/analysis/scope_py.go | 16 +++++++++++++++- pkg/analysis/scope_py_test.go | 27 +++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 1 deletion(-) diff --git a/pkg/analysis/scope.go b/pkg/analysis/scope.go index 1346bb0b..93dacc55 100644 --- a/pkg/analysis/scope.go +++ b/pkg/analysis/scope.go @@ -29,6 +29,7 @@ const ( VarKindFunction VarKindVariable VarKindParameter + VarKindClass ) type Variable struct { diff --git a/pkg/analysis/scope_py.go b/pkg/analysis/scope_py.go index 65b8c863..87cb06e7 100644 --- a/pkg/analysis/scope_py.go +++ b/pkg/analysis/scope_py.go @@ -48,7 +48,7 @@ func (py *PyScopeBuilder) NodeCreatesScope(node *sitter.Node) bool { func (py *PyScopeBuilder) DeclaresVariable(node *sitter.Node) bool { typ := node.Type() - return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_statement" || typ == "parameters" || typ == "function_definition" || typ == "try_statement" + return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_statement" || typ == "parameters" || typ == "function_definition" || typ == "try_statement" || typ == "class_definition" } func (py *PyScopeBuilder) scanDecl(idOrPattern, declarator *sitter.Node, decls []*Variable) []*Variable { @@ -172,6 +172,16 @@ func (py *PyScopeBuilder) CollectVariables(node *sitter.Node) []*Variable { }) } } + + case "class_definition": + name := node.ChildByFieldName("name") + if name != nil && name.Type() == "identifier" { + declaredVars = append(declaredVars, &Variable{ + Kind: VarKindClass, + Name: name.Content(py.source), + DeclNode: node, + }) + } } return declaredVars @@ -227,6 +237,10 @@ func (py *PyScopeBuilder) OnNodeEnter(node *sitter.Node, scope *Scope) { return } + if parentType == "class_definition" { + return + } + // resolve this reference variable := scope.Lookup(node.Content(py.source)) if variable == nil { diff --git a/pkg/analysis/scope_py_test.go b/pkg/analysis/scope_py_test.go index 50f9d9c3..43948a69 100644 --- a/pkg/analysis/scope_py_test.go +++ b/pkg/analysis/scope_py_test.go @@ -1,6 +1,7 @@ package analysis import ( + // "fmt" "testing" "github.com/stretchr/testify/assert" @@ -192,4 +193,30 @@ func Test_PyBuildScopeTree(t *testing.T) { assert.Equal(t, "call", eRefs[0].Node.Parent().Parent().Type()) }) + t.Run("supports classes", func(t *testing.T) { + source := ` + class MyClass: + def __init__(self, name): + self.name = name + + def print_name(self): + print(self.name) + ` + parsed := parsePyFile(t, source) + + scopeTree := MakeScopeTree(parsed.Language, parsed.Ast, parsed.Source) + require.NotNil(t, scopeTree) + + globalScope := scopeTree.Root.Children[0] + require.NotNil(t, globalScope) + + { + varClass, exists := globalScope.Variables["MyClass"] + require.NotNil(t, varClass) + require.True(t, exists) + assert.Equal(t, VarKindClass, varClass.Kind) + } + + }) + } From 95d51466380aeaedc9233fb0735881d356fa4f5d Mon Sep 17 00:00:00 2001 From: Maharshi Basu Date: Tue, 15 Apr 2025 10:36:32 +0530 Subject: [PATCH 08/16] enh: support walrus operator --- analysis/py_scope.go | 16 +++++++++++++++- analysis/py_scope_test.go | 19 ++++++++++++++++++- analysis/scope.go | 1 + 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/analysis/py_scope.go b/analysis/py_scope.go index 87cb06e7..edd23746 100644 --- a/analysis/py_scope.go +++ b/analysis/py_scope.go @@ -48,7 +48,7 @@ func (py *PyScopeBuilder) NodeCreatesScope(node *sitter.Node) bool { func (py *PyScopeBuilder) DeclaresVariable(node *sitter.Node) bool { typ := node.Type() - return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_statement" || typ == "parameters" || typ == "function_definition" || typ == "try_statement" || typ == "class_definition" + return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_statement" || typ == "parameters" || typ == "function_definition" || typ == "try_statement" || typ == "class_definition" || typ == "named_expression" } func (py *PyScopeBuilder) scanDecl(idOrPattern, declarator *sitter.Node, decls []*Variable) []*Variable { @@ -173,6 +173,16 @@ func (py *PyScopeBuilder) CollectVariables(node *sitter.Node) []*Variable { } } + case "named_expression": + name := node.ChildByFieldName("name") + if name != nil && name.Type() == "identifier" { + declaredVars = append(declaredVars, &Variable{ + Kind: VarKindVariable, + Name: name.Content(py.source), + DeclNode: node, + }) + } + case "class_definition": name := node.ChildByFieldName("name") if name != nil && name.Type() == "identifier" { @@ -241,6 +251,10 @@ func (py *PyScopeBuilder) OnNodeEnter(node *sitter.Node, scope *Scope) { return } + if parentType == "named_expression" && parent.ChildByFieldName("name") == node { + return + } + // resolve this reference variable := scope.Lookup(node.Content(py.source)) if variable == nil { diff --git a/analysis/py_scope_test.go b/analysis/py_scope_test.go index 43948a69..328e57ef 100644 --- a/analysis/py_scope_test.go +++ b/analysis/py_scope_test.go @@ -1,7 +1,6 @@ package analysis import ( - // "fmt" "testing" "github.com/stretchr/testify/assert" @@ -168,6 +167,24 @@ func Test_PyBuildScopeTree(t *testing.T) { } }) + t.Run("supports walrus operator", func(t *testing.T) { + source := ` +if (n := random.randint(1, 100)) > 50: + print("Greater than 50") + ` + parsed := parsePyFile(t, source) + + scopeTree := MakeScopeTree(parsed.Language, parsed.Ast, parsed.Source) + require.NotNil(t, scopeTree) + + globalScope := scopeTree.Root.Children[0] + require.NotNil(t, globalScope) + + varN, exists := globalScope.Children[0].Variables["n"] + require.NotNil(t, varN) + require.True(t, exists) + }) + t.Run("supports exception statements", func(t *testing.T) { source := ` try: diff --git a/analysis/scope.go b/analysis/scope.go index 02b6b2b1..b67893f5 100644 --- a/analysis/scope.go +++ b/analysis/scope.go @@ -146,6 +146,7 @@ func buildScopeTree( nextScope := scope if builder.NodeCreatesScope(node) { nextScope = NewScope(scope) + nextScope.AstNode = node scopeOfNode[node] = nextScope scope.AstNode = node if scope != nil { From 0ada7e2472bf0dfbd669cfac0238e307c2c0255a Mon Sep 17 00:00:00 2001 From: Maharshi Basu Date: Tue, 15 Apr 2025 17:47:30 +0530 Subject: [PATCH 09/16] add support for comprehension pattern --- analysis/py_scope.go | 14 +++++++-- analysis/py_scope_test.go | 38 +++++++++++++++++++++++ pkg/analysis/scope_py.go | 28 +++++++++++++++-- pkg/analysis/scope_py_test.go | 57 ++++++++++++++++++++++++++++++++++- 4 files changed, 132 insertions(+), 5 deletions(-) diff --git a/analysis/py_scope.go b/analysis/py_scope.go index edd23746..48a9a306 100644 --- a/analysis/py_scope.go +++ b/analysis/py_scope.go @@ -48,7 +48,7 @@ func (py *PyScopeBuilder) NodeCreatesScope(node *sitter.Node) bool { func (py *PyScopeBuilder) DeclaresVariable(node *sitter.Node) bool { typ := node.Type() - return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_statement" || typ == "parameters" || typ == "function_definition" || typ == "try_statement" || typ == "class_definition" || typ == "named_expression" + return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_statement" || typ == "parameters" || typ == "function_definition" || typ == "try_statement" || typ == "class_definition" || typ == "named_expression" || typ == "for_in_clause" } func (py *PyScopeBuilder) scanDecl(idOrPattern, declarator *sitter.Node, decls []*Variable) []*Variable { @@ -192,7 +192,13 @@ func (py *PyScopeBuilder) CollectVariables(node *sitter.Node) []*Variable { DeclNode: node, }) } - } + + // used in `list_comprehension`, `dictionary_comprehension`, `generator_comprehension` + // `set_comprehension` + case "for_in_clause": + left := node.ChildByFieldName("left") + declaredVars = py.scanDecl(left, node, declaredVars) + } return declaredVars } @@ -255,6 +261,10 @@ func (py *PyScopeBuilder) OnNodeEnter(node *sitter.Node, scope *Scope) { return } + if parentType == "for_in_clause" && parent.ChildByFieldName("left") == node { + return + } + // resolve this reference variable := scope.Lookup(node.Content(py.source)) if variable == nil { diff --git a/analysis/py_scope_test.go b/analysis/py_scope_test.go index 328e57ef..4a25cd8c 100644 --- a/analysis/py_scope_test.go +++ b/analysis/py_scope_test.go @@ -185,6 +185,44 @@ if (n := random.randint(1, 100)) > 50: require.True(t, exists) }) + // for `list_comprehension`, `dictionary_comprehension`, `generator_comprehension`, `set_comprehension` + t.Run("supports comprehension statements", func(t *testing.T) { + source := ` +a = [x for x in range(10) if x % 2 == 0] + +b = {x: x**2 for x in myList if x == 10} + ` + parsed := parsePyFile(t, source) + + scopeTree := MakeScopeTree(parsed.Language, parsed.Ast, parsed.Source) + require.NotNil(t, scopeTree) + + globalScope := scopeTree.Root.Children[0] + require.NotNil(t, globalScope) + + { + lcScope := globalScope.Children[0] + require.NotNil(t, lcScope) + varX, exists := lcScope.Variables["x"] + require.NotNil(t, varX) + require.True(t, exists) + + xRefs := varX.Refs + assert.Equal(t, 2, len(xRefs)) // first in the expression, second in the if-else statement + } + + { + dcScope := globalScope.Children[1] + require.NotNil(t, dcScope) + varX, exists := dcScope.Variables["x"] + require.NotNil(t, varX) + require.True(t, exists) + + xRefs := varX.Refs + assert.Equal(t, 3, len(xRefs)) + } + }) + t.Run("supports exception statements", func(t *testing.T) { source := ` try: diff --git a/pkg/analysis/scope_py.go b/pkg/analysis/scope_py.go index 87cb06e7..48a9a306 100644 --- a/pkg/analysis/scope_py.go +++ b/pkg/analysis/scope_py.go @@ -48,7 +48,7 @@ func (py *PyScopeBuilder) NodeCreatesScope(node *sitter.Node) bool { func (py *PyScopeBuilder) DeclaresVariable(node *sitter.Node) bool { typ := node.Type() - return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_statement" || typ == "parameters" || typ == "function_definition" || typ == "try_statement" || typ == "class_definition" + return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_statement" || typ == "parameters" || typ == "function_definition" || typ == "try_statement" || typ == "class_definition" || typ == "named_expression" || typ == "for_in_clause" } func (py *PyScopeBuilder) scanDecl(idOrPattern, declarator *sitter.Node, decls []*Variable) []*Variable { @@ -173,6 +173,16 @@ func (py *PyScopeBuilder) CollectVariables(node *sitter.Node) []*Variable { } } + case "named_expression": + name := node.ChildByFieldName("name") + if name != nil && name.Type() == "identifier" { + declaredVars = append(declaredVars, &Variable{ + Kind: VarKindVariable, + Name: name.Content(py.source), + DeclNode: node, + }) + } + case "class_definition": name := node.ChildByFieldName("name") if name != nil && name.Type() == "identifier" { @@ -182,7 +192,13 @@ func (py *PyScopeBuilder) CollectVariables(node *sitter.Node) []*Variable { DeclNode: node, }) } - } + + // used in `list_comprehension`, `dictionary_comprehension`, `generator_comprehension` + // `set_comprehension` + case "for_in_clause": + left := node.ChildByFieldName("left") + declaredVars = py.scanDecl(left, node, declaredVars) + } return declaredVars } @@ -241,6 +257,14 @@ func (py *PyScopeBuilder) OnNodeEnter(node *sitter.Node, scope *Scope) { return } + if parentType == "named_expression" && parent.ChildByFieldName("name") == node { + return + } + + if parentType == "for_in_clause" && parent.ChildByFieldName("left") == node { + return + } + // resolve this reference variable := scope.Lookup(node.Content(py.source)) if variable == nil { diff --git a/pkg/analysis/scope_py_test.go b/pkg/analysis/scope_py_test.go index 43948a69..4a25cd8c 100644 --- a/pkg/analysis/scope_py_test.go +++ b/pkg/analysis/scope_py_test.go @@ -1,7 +1,6 @@ package analysis import ( - // "fmt" "testing" "github.com/stretchr/testify/assert" @@ -168,6 +167,62 @@ func Test_PyBuildScopeTree(t *testing.T) { } }) + t.Run("supports walrus operator", func(t *testing.T) { + source := ` +if (n := random.randint(1, 100)) > 50: + print("Greater than 50") + ` + parsed := parsePyFile(t, source) + + scopeTree := MakeScopeTree(parsed.Language, parsed.Ast, parsed.Source) + require.NotNil(t, scopeTree) + + globalScope := scopeTree.Root.Children[0] + require.NotNil(t, globalScope) + + varN, exists := globalScope.Children[0].Variables["n"] + require.NotNil(t, varN) + require.True(t, exists) + }) + + // for `list_comprehension`, `dictionary_comprehension`, `generator_comprehension`, `set_comprehension` + t.Run("supports comprehension statements", func(t *testing.T) { + source := ` +a = [x for x in range(10) if x % 2 == 0] + +b = {x: x**2 for x in myList if x == 10} + ` + parsed := parsePyFile(t, source) + + scopeTree := MakeScopeTree(parsed.Language, parsed.Ast, parsed.Source) + require.NotNil(t, scopeTree) + + globalScope := scopeTree.Root.Children[0] + require.NotNil(t, globalScope) + + { + lcScope := globalScope.Children[0] + require.NotNil(t, lcScope) + varX, exists := lcScope.Variables["x"] + require.NotNil(t, varX) + require.True(t, exists) + + xRefs := varX.Refs + assert.Equal(t, 2, len(xRefs)) // first in the expression, second in the if-else statement + } + + { + dcScope := globalScope.Children[1] + require.NotNil(t, dcScope) + varX, exists := dcScope.Variables["x"] + require.NotNil(t, varX) + require.True(t, exists) + + xRefs := varX.Refs + assert.Equal(t, 3, len(xRefs)) + } + }) + t.Run("supports exception statements", func(t *testing.T) { source := ` try: From 19f72785286f0996226deadd1f615584123e879f Mon Sep 17 00:00:00 2001 From: Maharshi Basu Date: Tue, 15 Apr 2025 19:15:47 +0530 Subject: [PATCH 10/16] add support for loops --- analysis/py_scope.go | 14 +++++++++++--- analysis/py_scope_test.go | 35 +++++++++++++++++++++++++++++++++++ pkg/analysis/scope_py.go | 14 +++++++++++--- pkg/analysis/scope_py_test.go | 35 +++++++++++++++++++++++++++++++++++ 4 files changed, 92 insertions(+), 6 deletions(-) diff --git a/analysis/py_scope.go b/analysis/py_scope.go index 48a9a306..bf7c4dde 100644 --- a/analysis/py_scope.go +++ b/analysis/py_scope.go @@ -48,7 +48,7 @@ func (py *PyScopeBuilder) NodeCreatesScope(node *sitter.Node) bool { func (py *PyScopeBuilder) DeclaresVariable(node *sitter.Node) bool { typ := node.Type() - return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_statement" || typ == "parameters" || typ == "function_definition" || typ == "try_statement" || typ == "class_definition" || typ == "named_expression" || typ == "for_in_clause" + return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_statement" || typ == "parameters" || typ == "function_definition" || typ == "try_statement" || typ == "class_definition" || typ == "named_expression" || typ == "for_in_clause" || typ == "for_statement" } func (py *PyScopeBuilder) scanDecl(idOrPattern, declarator *sitter.Node, decls []*Variable) []*Variable { @@ -173,6 +173,10 @@ func (py *PyScopeBuilder) CollectVariables(node *sitter.Node) []*Variable { } } + case "for_statement": + left := node.ChildByFieldName("left") + declaredVars = py.scanDecl(left, node, declaredVars) + case "named_expression": name := node.ChildByFieldName("name") if name != nil && name.Type() == "identifier" { @@ -192,13 +196,13 @@ func (py *PyScopeBuilder) CollectVariables(node *sitter.Node) []*Variable { DeclNode: node, }) } - + // used in `list_comprehension`, `dictionary_comprehension`, `generator_comprehension` // `set_comprehension` case "for_in_clause": left := node.ChildByFieldName("left") declaredVars = py.scanDecl(left, node, declaredVars) - } + } return declaredVars } @@ -265,6 +269,10 @@ func (py *PyScopeBuilder) OnNodeEnter(node *sitter.Node, scope *Scope) { return } + if parentType == "for_statement" { + return + } + // resolve this reference variable := scope.Lookup(node.Content(py.source)) if variable == nil { diff --git a/analysis/py_scope_test.go b/analysis/py_scope_test.go index 4a25cd8c..80034efa 100644 --- a/analysis/py_scope_test.go +++ b/analysis/py_scope_test.go @@ -223,6 +223,41 @@ b = {x: x**2 for x in myList if x == 10} } }) + t.Run("supports loop statements", func(t *testing.T) { + source := ` +for id, value in enumerate(someList): + print(id, value) + ` + parsed := parsePyFile(t, source) + + scopeTree := MakeScopeTree(parsed.Language, parsed.Ast, parsed.Source) + require.NotNil(t, scopeTree) + + globalScope := scopeTree.Root.Children[0] + require.NotNil(t, globalScope) + + forLoopScope := globalScope.Children[0] + require.NotNil(t, forLoopScope) + { + varId, exists := globalScope.Variables["id"] + require.NotNil(t, varId) + require.True(t, exists) + + idRefs := varId.Refs + assert.Equal(t, 1, len(idRefs)) + } + + { + varValue, exists := globalScope.Variables["value"] + require.NotNil(t, varValue) + require.True(t, exists) + + valueRefs := varValue.Refs + assert.Equal(t, 1, len(valueRefs)) + } + + }) + t.Run("supports exception statements", func(t *testing.T) { source := ` try: diff --git a/pkg/analysis/scope_py.go b/pkg/analysis/scope_py.go index 48a9a306..bf7c4dde 100644 --- a/pkg/analysis/scope_py.go +++ b/pkg/analysis/scope_py.go @@ -48,7 +48,7 @@ func (py *PyScopeBuilder) NodeCreatesScope(node *sitter.Node) bool { func (py *PyScopeBuilder) DeclaresVariable(node *sitter.Node) bool { typ := node.Type() - return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_statement" || typ == "parameters" || typ == "function_definition" || typ == "try_statement" || typ == "class_definition" || typ == "named_expression" || typ == "for_in_clause" + return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_statement" || typ == "parameters" || typ == "function_definition" || typ == "try_statement" || typ == "class_definition" || typ == "named_expression" || typ == "for_in_clause" || typ == "for_statement" } func (py *PyScopeBuilder) scanDecl(idOrPattern, declarator *sitter.Node, decls []*Variable) []*Variable { @@ -173,6 +173,10 @@ func (py *PyScopeBuilder) CollectVariables(node *sitter.Node) []*Variable { } } + case "for_statement": + left := node.ChildByFieldName("left") + declaredVars = py.scanDecl(left, node, declaredVars) + case "named_expression": name := node.ChildByFieldName("name") if name != nil && name.Type() == "identifier" { @@ -192,13 +196,13 @@ func (py *PyScopeBuilder) CollectVariables(node *sitter.Node) []*Variable { DeclNode: node, }) } - + // used in `list_comprehension`, `dictionary_comprehension`, `generator_comprehension` // `set_comprehension` case "for_in_clause": left := node.ChildByFieldName("left") declaredVars = py.scanDecl(left, node, declaredVars) - } + } return declaredVars } @@ -265,6 +269,10 @@ func (py *PyScopeBuilder) OnNodeEnter(node *sitter.Node, scope *Scope) { return } + if parentType == "for_statement" { + return + } + // resolve this reference variable := scope.Lookup(node.Content(py.source)) if variable == nil { diff --git a/pkg/analysis/scope_py_test.go b/pkg/analysis/scope_py_test.go index 4a25cd8c..80034efa 100644 --- a/pkg/analysis/scope_py_test.go +++ b/pkg/analysis/scope_py_test.go @@ -223,6 +223,41 @@ b = {x: x**2 for x in myList if x == 10} } }) + t.Run("supports loop statements", func(t *testing.T) { + source := ` +for id, value in enumerate(someList): + print(id, value) + ` + parsed := parsePyFile(t, source) + + scopeTree := MakeScopeTree(parsed.Language, parsed.Ast, parsed.Source) + require.NotNil(t, scopeTree) + + globalScope := scopeTree.Root.Children[0] + require.NotNil(t, globalScope) + + forLoopScope := globalScope.Children[0] + require.NotNil(t, forLoopScope) + { + varId, exists := globalScope.Variables["id"] + require.NotNil(t, varId) + require.True(t, exists) + + idRefs := varId.Refs + assert.Equal(t, 1, len(idRefs)) + } + + { + varValue, exists := globalScope.Variables["value"] + require.NotNil(t, varValue) + require.True(t, exists) + + valueRefs := varValue.Refs + assert.Equal(t, 1, len(valueRefs)) + } + + }) + t.Run("supports exception statements", func(t *testing.T) { source := ` try: From e77b1e16e836b17f6d88cc033be4781389ec87eb Mon Sep 17 00:00:00 2001 From: Maharshi Basu Date: Tue, 15 Apr 2025 21:08:57 +0530 Subject: [PATCH 11/16] add support for lambda expressions --- analysis/py_scope.go | 9 ++++++++- analysis/py_scope_test.go | 20 ++++++++++++++++++++ pkg/analysis/scope_py.go | 9 ++++++++- pkg/analysis/scope_py_test.go | 20 ++++++++++++++++++++ 4 files changed, 56 insertions(+), 2 deletions(-) diff --git a/analysis/py_scope.go b/analysis/py_scope.go index bf7c4dde..17475e4d 100644 --- a/analysis/py_scope.go +++ b/analysis/py_scope.go @@ -48,7 +48,7 @@ func (py *PyScopeBuilder) NodeCreatesScope(node *sitter.Node) bool { func (py *PyScopeBuilder) DeclaresVariable(node *sitter.Node) bool { typ := node.Type() - return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_statement" || typ == "parameters" || typ == "function_definition" || typ == "try_statement" || typ == "class_definition" || typ == "named_expression" || typ == "for_in_clause" || typ == "for_statement" + return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_statement" || typ == "parameters" || typ == "function_definition" || typ == "try_statement" || typ == "class_definition" || typ == "named_expression" || typ == "for_in_clause" || typ == "for_statement" || typ == "lambda_parameters" } func (py *PyScopeBuilder) scanDecl(idOrPattern, declarator *sitter.Node, decls []*Variable) []*Variable { @@ -117,6 +117,9 @@ func (py *PyScopeBuilder) CollectVariables(node *sitter.Node) []*Variable { case "parameters": declaredVars = py.variableFromFunctionParams(node, declaredVars) + case "lambda_parameters": + declaredVars = py.variableFromFunctionParams(node, declaredVars) + case "aliased_import": // import as aliasName := node.ChildByFieldName("name") @@ -273,6 +276,10 @@ func (py *PyScopeBuilder) OnNodeEnter(node *sitter.Node, scope *Scope) { return } + if parentType == "lambda_parameters" { + return + } + // resolve this reference variable := scope.Lookup(node.Content(py.source)) if variable == nil { diff --git a/analysis/py_scope_test.go b/analysis/py_scope_test.go index 80034efa..eae8d8f8 100644 --- a/analysis/py_scope_test.go +++ b/analysis/py_scope_test.go @@ -258,6 +258,26 @@ for id, value in enumerate(someList): }) + t.Run("supports lambda expressions", func(t *testing.T) { + source := ` +a = lambda x: x**2 + ` + parsed := parsePyFile(t, source) + + scopeTree := MakeScopeTree(parsed.Language, parsed.Ast, parsed.Source) + require.NotNil(t, scopeTree) + + globalScope := scopeTree.Root.Children[0] + require.NotNil(t, globalScope) + + lambdaScope := globalScope.Children[0] + varX, exists := lambdaScope.Variables["x"] + require.NotNil(t, varX) + require.True(t, exists) + + assert.Equal(t, 1, len(varX.Refs)) + }) + t.Run("supports exception statements", func(t *testing.T) { source := ` try: diff --git a/pkg/analysis/scope_py.go b/pkg/analysis/scope_py.go index bf7c4dde..17475e4d 100644 --- a/pkg/analysis/scope_py.go +++ b/pkg/analysis/scope_py.go @@ -48,7 +48,7 @@ func (py *PyScopeBuilder) NodeCreatesScope(node *sitter.Node) bool { func (py *PyScopeBuilder) DeclaresVariable(node *sitter.Node) bool { typ := node.Type() - return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_statement" || typ == "parameters" || typ == "function_definition" || typ == "try_statement" || typ == "class_definition" || typ == "named_expression" || typ == "for_in_clause" || typ == "for_statement" + return typ == "assignment" || typ == "dotted_name" || typ == "aliased_import" || typ == "with_statement" || typ == "parameters" || typ == "function_definition" || typ == "try_statement" || typ == "class_definition" || typ == "named_expression" || typ == "for_in_clause" || typ == "for_statement" || typ == "lambda_parameters" } func (py *PyScopeBuilder) scanDecl(idOrPattern, declarator *sitter.Node, decls []*Variable) []*Variable { @@ -117,6 +117,9 @@ func (py *PyScopeBuilder) CollectVariables(node *sitter.Node) []*Variable { case "parameters": declaredVars = py.variableFromFunctionParams(node, declaredVars) + case "lambda_parameters": + declaredVars = py.variableFromFunctionParams(node, declaredVars) + case "aliased_import": // import as aliasName := node.ChildByFieldName("name") @@ -273,6 +276,10 @@ func (py *PyScopeBuilder) OnNodeEnter(node *sitter.Node, scope *Scope) { return } + if parentType == "lambda_parameters" { + return + } + // resolve this reference variable := scope.Lookup(node.Content(py.source)) if variable == nil { diff --git a/pkg/analysis/scope_py_test.go b/pkg/analysis/scope_py_test.go index 80034efa..eae8d8f8 100644 --- a/pkg/analysis/scope_py_test.go +++ b/pkg/analysis/scope_py_test.go @@ -258,6 +258,26 @@ for id, value in enumerate(someList): }) + t.Run("supports lambda expressions", func(t *testing.T) { + source := ` +a = lambda x: x**2 + ` + parsed := parsePyFile(t, source) + + scopeTree := MakeScopeTree(parsed.Language, parsed.Ast, parsed.Source) + require.NotNil(t, scopeTree) + + globalScope := scopeTree.Root.Children[0] + require.NotNil(t, globalScope) + + lambdaScope := globalScope.Children[0] + varX, exists := lambdaScope.Variables["x"] + require.NotNil(t, varX) + require.True(t, exists) + + assert.Equal(t, 1, len(varX.Refs)) + }) + t.Run("supports exception statements", func(t *testing.T) { source := ` try: From e26b8514c2b934fa542c43975216de1fbf31f3cf Mon Sep 17 00:00:00 2001 From: Maharshi Basu Date: Thu, 17 Apr 2025 10:24:26 +0530 Subject: [PATCH 12/16] chore: indent test examples + add more nodes for scope creation --- analysis/py_scope.go | 13 +++++++++++++ analysis/py_scope_test.go | 38 +++++++++++++++++++------------------- 2 files changed, 32 insertions(+), 19 deletions(-) diff --git a/analysis/py_scope.go b/analysis/py_scope.go index 17475e4d..98e02c2b 100644 --- a/analysis/py_scope.go +++ b/analysis/py_scope.go @@ -39,6 +39,8 @@ var PyScopeNodes = []string{ "except_clause", "list_comprehension", "dictionary_comprehension", + "set_comprehension", + "generator_expression", "lambda", } @@ -89,6 +91,17 @@ func (py *PyScopeBuilder) scanDecl(idOrPattern, declarator *sitter.Node, decls [ }) } } + + case "list_splat_pattern": + splatNode := ChildrenOfType(idOrPattern, "identifier") + splatId := splatNode[0].Child(0) + if splatId.Type() == "identifier" { + decls = append(decls, &Variable{ + Kind: VarKindVariable, + Name: splatId.Content(py.source), + DeclNode: declarator, + }) + } } return decls diff --git a/analysis/py_scope_test.go b/analysis/py_scope_test.go index eae8d8f8..f6e8826b 100644 --- a/analysis/py_scope_test.go +++ b/analysis/py_scope_test.go @@ -17,10 +17,10 @@ func parsePyFile(t *testing.T, source string) *ParseResult { func Test_PyBuildScopeTree(t *testing.T) { t.Run("is able to resolve references", func(t *testing.T) { source := ` - x = 1 - if True: - y = x - z = x` +x = 1 +if True: + y = x +z = x` parsed := parsePyFile(t, source) scopeTree := MakeScopeTree(parsed.Language, parsed.Ast, parsed.Source) @@ -47,14 +47,14 @@ func Test_PyBuildScopeTree(t *testing.T) { t.Run("supports import statements", func(t *testing.T) { source := ` - import os +import os - os.system("cat file.txt") +os.system("cat file.txt") - from csv import read +from csv import read - if True: - f = read(file.csv) +if True: + f = read(file.csv) ` parsed := parsePyFile(t, source) @@ -96,10 +96,10 @@ func Test_PyBuildScopeTree(t *testing.T) { t.Run("supports function parameters", func(t *testing.T) { source := ` - def myFunc(a, b=2, c:int, d:str="Hello"): - A = otherFunc(a) - C = b + c - print(d) +def myFunc(a, b=2, c:int, d:str="Hello"): + A = otherFunc(a) + C = b + c + print(d) ` parsed := parsePyFile(t, source) @@ -144,8 +144,8 @@ func Test_PyBuildScopeTree(t *testing.T) { t.Run("supports with statements", func(t *testing.T) { source := ` - with open("file.txt", 'r') as f: - print(f.read(5)) +with open("file.txt", 'r') as f: + print(f.read(5)) ` parsed := parsePyFile(t, source) @@ -280,10 +280,10 @@ a = lambda x: x**2 t.Run("supports exception statements", func(t *testing.T) { source := ` - try: - result = 10 / 2 - except ZeroDivisionError as e: - print(e) +try: + result = 10 / 2 +except ZeroDivisionError as e: + print(e) ` parsed := parsePyFile(t, source) From d71667ae4b7fa4546d21b8334b0a44b73c589568 Mon Sep 17 00:00:00 2001 From: Maharshi Basu Date: Mon, 21 Apr 2025 17:14:49 +0530 Subject: [PATCH 13/16] data flow analysis for python --- checkers/python/py_dataflow.go | 350 +++++++++++++++++++++++++++++++++ checkers/python/scope.go | 22 +++ 2 files changed, 372 insertions(+) create mode 100644 checkers/python/py_dataflow.go create mode 100644 checkers/python/scope.go diff --git a/checkers/python/py_dataflow.go b/checkers/python/py_dataflow.go new file mode 100644 index 00000000..cf2b2505 --- /dev/null +++ b/checkers/python/py_dataflow.go @@ -0,0 +1,350 @@ +//global:registry-exclude + +package python + +import ( + "fmt" + "reflect" + + sitter "github.com/smacker/go-tree-sitter" + "globstar.dev/analysis" +) + +var DataFlowAnalyzer = &analysis.Analyzer{ + Name: "py-dataflow-analyzer", + Language: analysis.LangPy, + Description: "Create a data flow graph for Python", + Category: analysis.CategorySecurity, + Severity: analysis.SeverityWarning, + Run: createPyDFG, + ResultType: reflect.TypeOf(&DataFlowGraph{}), + Requires: []*analysis.Analyzer{ScopeAnalyzer}, +} + +type DataFlowNode struct { + Node *sitter.Node + Sources []*DataFlowNode + Scope *analysis.Scope + Variable *analysis.Variable + FuncDef *FunctionDefinition +} + +type FunctionDefinition struct { + Node *sitter.Node + Parameters []*analysis.Variable + Body *sitter.Node + Scope *analysis.Scope +} + +type ClassDefinition struct { + Node *sitter.Node + Properties []*analysis.Variable + Methods []*FunctionDefinition + Scope *analysis.Scope +} + +type DataFlowGraph struct { + Graph map[*analysis.Variable]*DataFlowNode + ScopeTree *analysis.ScopeTree + FunDefs map[string]*FunctionDefinition + ClassDefs map[*analysis.Variable]*ClassDefinition +} + +var functionDefinitions = make(map[string]*FunctionDefinition) +var classDefinitions = make(map[*analysis.Variable]*ClassDefinition) + +func createPyDFG(pass *analysis.Pass) (interface{}, error) { + scopeResult, err := buildScopeTree(pass) + if err != nil { + return nil, fmt.Errorf("failed to build the source tree") + } + + scopeTree := scopeResult.(*analysis.ScopeTree) + + dfg := &DataFlowGraph{ + Graph: make(map[*analysis.Variable]*DataFlowNode), + ScopeTree: scopeTree, + FunDefs: make(map[string]*FunctionDefinition), + } + + analysis.Preorder(pass, func(node *sitter.Node) { + if node == nil { + return + } + + currentScope := scopeTree.GetScope(node) + if currentScope == nil { + return + } + + // track variable declarations and assignments + if node.Type() == "assignment" { + var nameNode, valueNode *sitter.Node + + nameNode = node.ChildByFieldName("left") + valueNode = node.ChildByFieldName("right") + + if nameNode != nil && nameNode.Type() == "identifier" && valueNode != nil { + var dfNode *DataFlowNode + varName := nameNode.Content(pass.FileContext.Source) + variable := currentScope.Lookup(varName) + + if variable == nil { + dfNode = &DataFlowNode{ + Node: nameNode, + Sources: []*DataFlowNode{}, + Scope: currentScope, + Variable: variable, + } + } + + switch valueNode.Type() { + case "identifier": + // if value is another variable, link to its data flow node + sourceVarName := valueNode.Content(pass.FileContext.Source) + currVar := currentScope.Lookup(sourceVarName) + if sourceNode, exists := dfg.Graph[currVar]; exists { + dfNode.Sources = append(dfNode.Sources, sourceNode) + } + + case "call": + handleFunctionCallDataFlow(valueNode, dfNode, dfg.Graph, pass.FileContext.Source, currentScope) + + case "binary_operator": + handleBinaryExprDataFlow(valueNode, dfNode, dfg.Graph, pass.FileContext.Source, currentScope) + + // analyze the variables in an f-string + case "string": + if valueNode.Content(pass.FileContext.Source)[0] == 'f' { + handleFStringDataFlow(valueNode, dfNode, dfg.Graph, pass.FileContext.Source, currentScope) + } + + // lambda expressions are also functions + case "lambda": + lambdaScope := scopeTree.GetScope(valueNode) + lambdaBody := valueNode.ChildByFieldName("body") + if lambdaBody == nil { + return + } + + funcDef := &FunctionDefinition{ + Node: valueNode, + Body: lambdaBody, + Scope: lambdaScope, + } + + for _, param := range lambdaScope.Variables { + funcDef.Parameters = append(funcDef.Parameters, param) + } + + functionDefinitions[varName] = funcDef + dfNode.FuncDef = funcDef + } + dfg.Graph[variable] = dfNode + + } + } + + if node.Type() == "function_definition" { + funcNameNode := node.ChildByFieldName("name") + if funcNameNode == nil { + return + } + + funcName := funcNameNode.Content(pass.FileContext.Source) + funcDef := &FunctionDefinition{ + Node: node, + Body: node.ChildByFieldName("body"), + Scope: currentScope, + } + + funcVar := currentScope.Lookup(funcName) + if funcVar == nil { + return + } + + for _, param := range currentScope.Variables { + funcDef.Parameters = append(funcDef.Parameters, param) + } + + functionDefinitions[funcName] = funcDef + dfg.Graph[funcVar] = &DataFlowNode{ + Node: funcNameNode, + Sources: []*DataFlowNode{}, + Scope: currentScope, + Variable: funcVar, + FuncDef: funcDef, + } + } + + if node.Type() == "class_definition" { + var dfNode *DataFlowNode + className := node.ChildByFieldName("name") + if className == nil { + return + } + + varClassName := className.Content(pass.FileContext.Source) + classNameVar := currentScope.Lookup(varClassName) + classScope := scopeTree.GetScope(classNameVar.DeclNode) + if classScope == nil { + return + } + + classBody := node.ChildByFieldName("body") + if classBody == nil { + return + } + + var classMethods []*FunctionDefinition + var classProperties []*analysis.Variable + + dfNode = &DataFlowNode{ + Node: classNameVar.DeclNode, + Scope: classScope, + Variable: classNameVar, + } + + dfg.Graph[dfNode.Variable] = dfNode + + for i := range int(classBody.NamedChildCount()) { + classChild := classBody.NamedChild(i) + if classChild == nil { + return + } + + if classChild.Type() == "function_definition" { + classMethodNameNode := classChild.ChildByFieldName("name") + if classMethodNameNode != nil && classMethodNameNode.Type() == "identifier" { + methodDef := &FunctionDefinition{ + Node: classChild, + Body: classChild.ChildByFieldName("body"), + Parameters: []*analysis.Variable{}, + Scope: classScope, + } + + params := node.ChildByFieldName("parameters") + if params != nil { + for i := range int(params.NamedChildCount()) { + param := params.NamedChild(i) + if param.Type() == "identifier" { + paramName := param.Content(pass.FileContext.Source) + paramVar := currentScope.Lookup(paramName) + if paramVar != nil { + methodDef.Parameters = append(methodDef.Parameters, paramVar) + } + } + } + } + classMethods = append(classMethods, methodDef) + } + } else if classChild.Type() == "assignment" { + classVarNameNode := classChild.ChildByFieldName("left") + if classVarNameNode != nil && classVarNameNode.Type() == "identifier" { + classVarName := classVarNameNode.Content(pass.FileContext.Source) + classVar := classScope.Children[0].Lookup(classVarName) + if classVar != nil { + classProperties = append(classProperties, classVar) + } + } + } + } + + classDef := &ClassDefinition{ + Node: node, + Properties: classProperties, + Methods: classMethods, + Scope: classScope, + } + + classDefinitions[classNameVar] = classDef + } + }) + + dfg.FunDefs = functionDefinitions + dfg.ClassDefs = classDefinitions + + return dfg, nil +} + +func handleFStringDataFlow(node *sitter.Node, dfNode *DataFlowNode, dfg map[*analysis.Variable]*DataFlowNode, source []byte, scope *analysis.Scope) { + if node == nil || node.Type() != "string" { + return + } + + interpolations := analysis.ChildrenWithFieldName(node, "interpolation") + for _, interpNode := range interpolations { + exprNode := interpNode.ChildByFieldName("expression") + if exprNode != nil && exprNode.Type() == "identifier" { + varName := exprNode.Content(source) + if variable := scope.Lookup(varName); variable != nil { + if sourceNode, exists := dfg[variable]; exists { + dfNode.Sources = append(dfNode.Sources, sourceNode) + } + } + } + } +} + +func handleBinaryExprDataFlow(node *sitter.Node, dfNode *DataFlowNode, dfg map[*analysis.Variable]*DataFlowNode, source []byte, scope *analysis.Scope) { + if node == nil || node.Type() != "binary_operator" { + return + } + + left := node.ChildByFieldName("left") + right := node.ChildByFieldName("right") + + if left != nil && left.Type() == "identifier" { + leftVar := left.Content(source) + if variable := scope.Lookup(leftVar); variable != nil { + if sourceNode, exists := dfg[variable]; exists { + dfNode.Sources = append(dfNode.Sources, sourceNode) + } + } + } + + if right != nil && right.Type() == "identifier" { + rightVar := right.Content(source) + if variable := scope.Lookup(rightVar); variable != nil { + if sourceNode, exists := dfg[variable]; exists { + dfNode.Sources = append(dfNode.Sources, sourceNode) + } + } + } + + // process nested binary expression + if left != nil && left.Type() == "binary_operator" { + handleBinaryExprDataFlow(left, dfNode, dfg, source, scope) + } + + if right != nil && right.Type() == "binary_operator" { + handleBinaryExprDataFlow(right, dfNode, dfg, source, scope) + } +} + +func handleFunctionCallDataFlow(node *sitter.Node, dfNode *DataFlowNode, dfg map[*analysis.Variable]*DataFlowNode, source []byte, scope *analysis.Scope) { + if node == nil || node.Type() != "call" { + return + } + + args := node.ChildByFieldName("arguments") + if args == nil || args.Type() != "argument_list" { + return + } + + for i := range int(args.NamedChildCount()) { + arg := args.NamedChild(i) + if arg == nil { + continue + } + + if arg.Type() == "identifier" { + argName := arg.Content(source) + if variable := scope.Lookup(argName); variable != nil { + if sourceNode, exists := dfg[variable]; exists { + dfNode.Sources = append(dfNode.Sources, sourceNode) + } + } + } + } +} diff --git a/checkers/python/scope.go b/checkers/python/scope.go new file mode 100644 index 00000000..968bf926 --- /dev/null +++ b/checkers/python/scope.go @@ -0,0 +1,22 @@ +//globstar:registry-exclude +// scope resolution for Python files + +package python + +import ( + "globstar.dev/analysis" + "reflect" +) + +var ScopeAnalyzer = &analysis.Analyzer{ + Name: "py-scope", + ResultType: reflect.TypeOf(&analysis.ScopeTree{}), + Run: buildScopeTree, + Language: analysis.LangPy, +} + +func buildScopeTree(pass *analysis.Pass) (any, error) { + // creates scope builder for python + scope := analysis.MakeScopeTree(pass.Analyzer.Language, pass.FileContext.Ast, pass.FileContext.Source) + return scope, nil +} From 17c5f0abd0026651197825bd23f75e1fb727a49d Mon Sep 17 00:00:00 2001 From: Maharshi Basu Date: Mon, 21 Apr 2025 17:22:50 +0530 Subject: [PATCH 14/16] fix registry exclude directive --- checkers/python/py_dataflow.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/checkers/python/py_dataflow.go b/checkers/python/py_dataflow.go index cf2b2505..f971f6de 100644 --- a/checkers/python/py_dataflow.go +++ b/checkers/python/py_dataflow.go @@ -1,4 +1,4 @@ -//global:registry-exclude +//globstar:registry-exclude package python From de70b83a443cdc17fa9f63e28d301ec0af033264 Mon Sep 17 00:00:00 2001 From: Maharshi Basu Date: Mon, 21 Apr 2025 17:49:46 +0530 Subject: [PATCH 15/16] add unit tests for python data flow analysis --- checkers/python/py_dataflow.go | 9 ++-- checkers/python/py_dataflow_test.go | 73 +++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 4 deletions(-) create mode 100644 checkers/python/py_dataflow_test.go diff --git a/checkers/python/py_dataflow.go b/checkers/python/py_dataflow.go index f971f6de..27412026 100644 --- a/checkers/python/py_dataflow.go +++ b/checkers/python/py_dataflow.go @@ -16,7 +16,7 @@ var DataFlowAnalyzer = &analysis.Analyzer{ Description: "Create a data flow graph for Python", Category: analysis.CategorySecurity, Severity: analysis.SeverityWarning, - Run: createPyDFG, + Run: createDataFlowGraph, ResultType: reflect.TypeOf(&DataFlowGraph{}), Requires: []*analysis.Analyzer{ScopeAnalyzer}, } @@ -53,7 +53,7 @@ type DataFlowGraph struct { var functionDefinitions = make(map[string]*FunctionDefinition) var classDefinitions = make(map[*analysis.Variable]*ClassDefinition) -func createPyDFG(pass *analysis.Pass) (interface{}, error) { +func createDataFlowGraph(pass *analysis.Pass) (interface{}, error) { scopeResult, err := buildScopeTree(pass) if err != nil { return nil, fmt.Errorf("failed to build the source tree") @@ -238,8 +238,9 @@ func createPyDFG(pass *analysis.Pass) (interface{}, error) { } classMethods = append(classMethods, methodDef) } - } else if classChild.Type() == "assignment" { - classVarNameNode := classChild.ChildByFieldName("left") + } else if classChild.Type() == "expression_statement" { + assignNode := analysis.FirstChildOfType(classChild, "assignment") + classVarNameNode := assignNode.ChildByFieldName("left") if classVarNameNode != nil && classVarNameNode.Type() == "identifier" { classVarName := classVarNameNode.Content(pass.FileContext.Source) classVar := classScope.Children[0].Lookup(classVarName) diff --git a/checkers/python/py_dataflow_test.go b/checkers/python/py_dataflow_test.go new file mode 100644 index 00000000..bf679a25 --- /dev/null +++ b/checkers/python/py_dataflow_test.go @@ -0,0 +1,73 @@ +package python + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "globstar.dev/analysis" +) + +func parsePyCode(t *testing.T, source []byte) *analysis.ParseResult { + pass, err := analysis.Parse("file.py", source, analysis.LangPy, analysis.LangPy.Grammar()) + require.NoError(t, err) + + return pass +} + +func TestDataFlowAnalysis(t *testing.T) { + t.Run("variable assignnment data flow", func(t *testing.T) { + source := ` +x = 10 +def f(x): + pass +f(x) + ` + parseResult := parsePyCode(t, []byte(source)) + pass := &analysis.Pass{ + Analyzer: + DataFlowAnalyzer, + FileContext: parseResult, + } + + dfgStruct, err := createDataFlowGraph(pass) + assert.NoError(t, err) + dfg := dfgStruct.(*DataFlowGraph) + flowGraph := dfg.Graph + assert.NotNil(t, flowGraph) + scopeTree := dfg.ScopeTree + assert.NotNil(t, scopeTree) + }) +} + +func TestClassDataFlow(t *testing.T) { + source := ` +class A: + a = 10 + def __init__(self, a): + self.a = a + def f(self): + return self.a + ` + parseResult := parsePyCode(t, []byte(source)) + pass := &analysis.Pass{ + Analyzer: DataFlowAnalyzer, + FileContext: parseResult, + } + dfgStruct, err := createDataFlowGraph(pass) + assert.NoError(t, err) + dfg := dfgStruct.(*DataFlowGraph) + scopeTree := dfg.ScopeTree + graph := dfg.Graph + assert.NotNil(t, scopeTree) + classVar := scopeTree.Root.Children[0].Lookup("A") + assert.NotNil(t, classVar) + dfgClassNode := graph[classVar] + assert.NotNil(t, dfgClassNode) + + classDef := dfg.ClassDefs + assert.NotNil(t, classDef) + assert.NotNil(t, classDef[classVar]) + assert.Greater(t, len(classDef[classVar].Methods), 0) + assert.Greater(t, len(classDef[classVar].Properties), 0) +} \ No newline at end of file From 9a1791cdf685a6057dcec559f78c1d51df22e279 Mon Sep 17 00:00:00 2001 From: Maharshi Basu Date: Tue, 22 Apr 2025 15:12:01 +0530 Subject: [PATCH 16/16] fix runtime bugs --- checkers/python/py_dataflow.go | 11 ++++++----- checkers/python/py_dataflow_test.go | 19 +++++++++++++++++++ 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/checkers/python/py_dataflow.go b/checkers/python/py_dataflow.go index 27412026..6fe0658e 100644 --- a/checkers/python/py_dataflow.go +++ b/checkers/python/py_dataflow.go @@ -46,7 +46,7 @@ type ClassDefinition struct { type DataFlowGraph struct { Graph map[*analysis.Variable]*DataFlowNode ScopeTree *analysis.ScopeTree - FunDefs map[string]*FunctionDefinition + FuncDefs map[string]*FunctionDefinition ClassDefs map[*analysis.Variable]*ClassDefinition } @@ -64,7 +64,7 @@ func createDataFlowGraph(pass *analysis.Pass) (interface{}, error) { dfg := &DataFlowGraph{ Graph: make(map[*analysis.Variable]*DataFlowNode), ScopeTree: scopeTree, - FunDefs: make(map[string]*FunctionDefinition), + FuncDefs: make(map[string]*FunctionDefinition), } analysis.Preorder(pass, func(node *sitter.Node) { @@ -89,7 +89,7 @@ func createDataFlowGraph(pass *analysis.Pass) (interface{}, error) { varName := nameNode.Content(pass.FileContext.Source) variable := currentScope.Lookup(varName) - if variable == nil { + if variable != nil { dfNode = &DataFlowNode{ Node: nameNode, Sources: []*DataFlowNode{}, @@ -243,7 +243,8 @@ func createDataFlowGraph(pass *analysis.Pass) (interface{}, error) { classVarNameNode := assignNode.ChildByFieldName("left") if classVarNameNode != nil && classVarNameNode.Type() == "identifier" { classVarName := classVarNameNode.Content(pass.FileContext.Source) - classVar := classScope.Children[0].Lookup(classVarName) + fmt.Println(assignNode.Content(pass.FileContext.Source)) + classVar := classScope.Lookup(classVarName) if classVar != nil { classProperties = append(classProperties, classVar) } @@ -262,7 +263,7 @@ func createDataFlowGraph(pass *analysis.Pass) (interface{}, error) { } }) - dfg.FunDefs = functionDefinitions + dfg.FuncDefs = functionDefinitions dfg.ClassDefs = classDefinitions return dfg, nil diff --git a/checkers/python/py_dataflow_test.go b/checkers/python/py_dataflow_test.go index bf679a25..e6852f54 100644 --- a/checkers/python/py_dataflow_test.go +++ b/checkers/python/py_dataflow_test.go @@ -1,6 +1,7 @@ package python import ( + // "fmt" "testing" "github.com/stretchr/testify/assert" @@ -44,6 +45,7 @@ func TestClassDataFlow(t *testing.T) { source := ` class A: a = 10 + b = 10 def __init__(self, a): self.a = a def f(self): @@ -70,4 +72,21 @@ class A: assert.NotNil(t, classDef[classVar]) assert.Greater(t, len(classDef[classVar].Methods), 0) assert.Greater(t, len(classDef[classVar].Properties), 0) +} + +func TestBinaryExpressionDataFlow(t *testing.T) { + source := ` +class A: + a: int + c: int + ` + parseResult := parsePyCode(t, []byte(source)) + pass := &analysis.Pass{ + Analyzer: DataFlowAnalyzer, + FileContext: parseResult, + } + dfgStruct, err := createDataFlowGraph(pass) + // fmt.Println(dfgStruct.sco) + assert.NoError(t, err) + assert.NotNil(t, dfgStruct) } \ No newline at end of file