How to resolve types in Python code with Tree-sitter?


I'm using Tree-sitter to parse Python code into ASTs, and I'm trying to traverse them manually to infer types based on assignments and function definitions.

But I'm struggling with accurately resolving types (variables, functions, classes) due to Python's dynamic typing. Specifically, challenges arise with:

  • Inferring types in dynamically typed contexts.
  • Handling types from external modules/packages.
  • Leveraging Python's type annotations for improved type resolution.

I just need to resolve types that I have defined in my repository.

For example:

car.py

class Car:
    # ...

Now, in a different file:

from car import Car

car = Car()
# ...
bmw = car
# ...

I need to know that bmw is a Car.

How can I successfully navigate type resolution in Python with Tree-sitter?

What approaches or algorithms can I use to accurately resolve types defined in external modules without executing the Python code?

I also need to tackle def-use (definition-use) chains to track variable assignments and their types across the codebase.

  • Resolve identifiers to their definitions.
  • Determine the type of each identifier (class, module, etc.).
  • Extract the name of the class or module the identifier refers to.

I also think it needs to handle control flow.

Do we need to implement some form of scope graph to achieve this?

It also needs to implement Python's LEGB (Local, Enclosing, Global, Built-in) scope resolution somehow.

There is 1 answer below.

Given the dynamic nature of Python, direct type inference can be complex. You can, however, leverage Python's type annotations and combine them with a mix of static analysis techniques to improve the accuracy of type resolution.

[ Python Code ] --> [ Tree-sitter Parser ] --> [ AST ]
                                                  |
                                   Manual Traversal & Analysis
                                                  |
                                    Enhanced Type Resolution
                                                  |
                                  [ Leverage Type Annotations ]
                                  /               |               \
                      Use Heuristics      Handle External      Infer from
                     for Dynamic Types     Modules/Types   Function/Method Calls

Python's type annotations provide valuable information for type inference. When traversing the AST, specifically look for nodes related to function definitions (function_definition) and variable assignments (assignment) that contain annotations.
For the setup with tree-sitter/py-tree-sitter, either clone https://github.com/tree-sitter/tree-sitter-python into a vendor subfolder and build it with Language.build_library, or install the prebuilt binding (pip install tree-sitter tree-sitter-python) and try:

import tree_sitter_python as tspython
from tree_sitter import Language, Parser

# Load the Python grammar from the prebuilt tree-sitter-python binding
# (py-tree-sitter 0.21 API; newer releases use Language(tspython.language())
# and Parser(PY_LANGUAGE) instead)
PY_LANGUAGE = Language(tspython.language(), 'python')
parser = Parser()
parser.set_language(PY_LANGUAGE)

code = """
def add(a: int, b: int) -> int:
    return a + b

x: int = 10
"""

tree = parser.parse(bytes(code, "utf8"))
root_node = tree.root_node

def extract_type_annotations(node):
    if node.type in ["function_definition", "assignment"]:
        for child in node.children:
            if child.type == "type":
                print(f"Found type annotation: {child.text.decode('utf8')}")
    # Recurse over all children, not only those of annotated nodes,
    # so the walk actually reaches nested definitions from the module root
    for child in node.children:
        extract_type_annotations(child)

extract_type_annotations(root_node)
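With the sample snippet above, this should print "Found type annotation: int" twice: once for the return annotation of add and once for the annotation on x, since tree-sitter-python exposes both as a direct "type" child of the function_definition and assignment nodes.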

For dynamically typed variables or those imported from external modules, consider implementing a heuristic-based approach: maintain a map of known types for standard library modules and popular third-party libraries.
When encountering an import statement, check whether the module and its members are in your map, and apply these types wherever those members are used; that is, for any variable assignment, function call, or other expression involving these imported members, you infer the type from the information previously mapped.
That approach aids in resolving not only standard Python types but also types that come from external libraries, assuming you have included them in your known_types map.

known_types = {
    'numpy': {
        'array': 'numpy.ndarray'
    },
    'pandas': {
        'DataFrame': 'pandas.core.frame.DataFrame'
    }
    # Add more known types as needed
}

def infer_type_from_imports(node, known_types):
    if node.type in ("import_statement", "import_from_statement"):
        # In tree-sitter-python, both the module and the imported names are
        # dotted_name nodes, so use the grammar's named fields to tell them apart
        module_node = node.child_by_field_name("module_name")
        module_name = module_node.text.decode('utf8') if module_node else ""
        imported_items = []
        for name_node in node.children_by_field_name("name"):
            if name_node.type == "aliased_import":  # "import ... as ..."
                name_node = name_node.child_by_field_name("name")
            imported_items.append(name_node.text.decode('utf8'))
        # Infer types based on known imports
        for item in imported_items:
            if module_name in known_types and item in known_types[module_name]:
                print(f"Inferred type for {item}: {known_types[module_name][item]}")

def traverse_and_infer(node):
    infer_type_from_imports(node, known_types)
    for child in node.children:
        traverse_and_infer(child)

traverse_and_infer(root_node)
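For example, parsing from numpy import array with this map should print: Inferred type for array: numpy.ndarray.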

For function and method calls, you could analyze the return types based on the annotations in their definitions. That would require building a context or a scope map as you traverse the AST, linking functions and methods to their return types when defined. Then, when these functions or methods are called, you can infer the type based on this context:

function_return_types = {}

def extract_function_return_types(node):
    if node.type == "function_definition":
        # Use the grammar's named fields rather than scanning children
        name_node = node.child_by_field_name("name")
        return_node = node.child_by_field_name("return_type")
        if name_node and return_node:
            function_name = name_node.text.decode('utf8')
            function_return_types[function_name] = return_node.text.decode('utf8')

    for child in node.children:
        extract_function_return_types(child)

extract_function_return_types(root_node)

def infer_type_from_function_call(node):
    if node.type == "call":
        # The callee is the call node's `function` field (an identifier for
        # simple calls, an attribute for method calls)
        func_node = node.child_by_field_name("function")
        if func_node is not None and func_node.type == "identifier":
            function_name = func_node.text.decode('utf8')
            if function_name in function_return_types:
                print(f"Inferred return type for {function_name}: {function_return_types[function_name]}")

def traverse_and_infer_calls(node):
    infer_type_from_function_call(node)
    for child in node.children:
        traverse_and_infer_calls(child)

traverse_and_infer_calls(root_node)
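With the earlier sample snippet, function_return_types ends up as {'add': 'int'}; that snippet contains no calls, but parsing one such as add(1, 2) should then print: Inferred return type for add: int.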

We also need to do some form of def-use analysis, don't we?

Def-use (definition-use) analysis is indeed an important aspect of type resolution, especially in dynamically typed languages like Python. It involves tracking where variables (and more generally, symbols) are defined and where they are used throughout the code. That process is essential for understanding variable scopes, lifetimes, and types, which, in turn, helps in resolving types more accurately and detecting potential errors.

For incorporating def-use analysis into your approach with Tree-sitter for Python code, you would enhance your existing strategy by adding:

  • Tracking definitions: every time a variable, function, or class is declared (i.e., "defined"), you record its location (node in the AST), its scope, and any type information available through annotations or inference.
  • Tracking uses: for every use of a variable, function, or class, you identify the corresponding definition(s) it could be associated with. That step might involve analyzing the scope to resolve which definition a particular use refers to, especially with shadowing or nested scopes.

To implement def-use analysis for type resolution with Tree-sitter, you can extend your existing traversal functions to maintain a context or environment that maps variable names to their definitions and types. The general idea would be:

class Scope:
    def __init__(self, parent=None):
        self.parent = parent
        self.definitions = {}

    def define(self, name, type_annotation, node):
        self.definitions[name] = (type_annotation, node)

    def find(self, name):
        if name in self.definitions:
            return self.definitions[name]
        if self.parent:
            return self.parent.find(name)
        return None, None

def traverse_with_def_use_analysis(node, scope=None):
    if scope is None:  # avoid the mutable-default-argument pitfall
        scope = Scope()

    # Handle definition
    if node.type in ["function_definition", "assignment"]:
        # Extract name and type if available, then define in current scope
        name = extract_name(node)
        type_annotation = extract_type_annotation(node)
        if name:
            scope.define(name, type_annotation, node)

    # Handle use
    elif node.type == "identifier":
        name = node.text.decode('utf8')
        definition, _ = scope.find(name)
        if definition:
            print(f"Use of '{name}' with inferred type {definition}")

    # Adjust scope for nested structures (e.g., functions)
    new_scope = scope
    if node.type in ["function_definition", "class_definition"]:
        new_scope = Scope(parent=scope)

    # Recursively traverse children
    for child in node.children:
        traverse_with_def_use_analysis(child, new_scope)

def extract_name(node):
    # Function definitions expose a `name` field; simple assignments a `left` field
    field = "name" if node.type == "function_definition" else "left"
    name_node = node.child_by_field_name(field)
    if name_node is not None and name_node.type == "identifier":
        return name_node.text.decode('utf8')
    return None  # e.g. tuple targets or attribute assignments

def extract_type_annotation(node):
    # Both node types expose the annotation via a `type`/`return_type` field
    type_node = (node.child_by_field_name("type")
                 or node.child_by_field_name("return_type"))
    return type_node.text.decode('utf8') if type_node else None
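To answer the question's bmw = car example, the assignment handler also needs to propagate types through the right-hand side. Here is a minimal sketch (infer_assignment_type is a hypothetical helper; it assumes the traversal above is extended to also record class_definition nodes via scope.define(name, "class", node)):

def infer_assignment_type(node, scope):
    # Prefer an explicit annotation, as in `x: int = 10`
    annotated = extract_type_annotation(node)
    if annotated:
        return annotated
    right = node.child_by_field_name("right")
    # `bmw = car`: copy the recorded type of `car`
    if right is not None and right.type == "identifier":
        type_annotation, _ = scope.find(right.text.decode('utf8'))
        return type_annotation
    # `car = Car()`: if `Car` resolves to a recorded class, instances get type `Car`
    if right is not None and right.type == "call":
        func = right.child_by_field_name("function")
        if func is not None and func.type == "identifier":
            name = func.text.decode('utf8')
            recorded, _ = scope.find(name)
            if recorded == "class":  # assumes classes were recorded as "class"
                return name
    return None

Used in place of extract_type_annotation inside the assignment branch, car = Car() records car as a Car, and bmw = car then copies that type, which is exactly the resolution the question asks for.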

Without using control flow, how can you resolve the name?
I agree we need both scope and control flow.

To resolve names in Python while considering both scope and control flow, you would need more than Tree-sitter alone; you would need a tree climber such as tree-climber. That would help not only with tracking where variables are declared and used within different scopes, but also with understanding the flow of the program to predict how variables might change over time.

Scope tracking involves keeping a record of variable definitions within their respective scopes. In Python, each function, class, and comprehension introduces a new scope (note that, unlike many languages, plain blocks such as loops and conditionals do not); resolving a name then means walking these scopes in LEGB order: Local, Enclosing, Global, Built-in.
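As a rough illustration (reusing the Scope class from the def-use section, with made-up entries), LEGB resolution falls out of chaining lookups from the innermost scope outwards:

# Hypothetical illustration of LEGB-style lookup with the Scope chain above
builtins_scope = Scope()                      # B: built-ins
builtins_scope.define("print", "builtins.print", None)

global_scope = Scope(parent=builtins_scope)   # G: module level
global_scope.define("Car", "car.Car", None)

enclosing = Scope(parent=global_scope)        # E: outer function
enclosing.define("factory", "function", None)

local = Scope(parent=enclosing)               # L: inner function
local.define("bmw", "car.Car", None)

print(local.find("bmw"))    # found locally             -> ('car.Car', None)
print(local.find("Car"))    # found in the global scope -> ('car.Car', None)
print(local.find("print"))  # found in built-ins        -> ('builtins.print', None)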

A control flow analysis would require analyzing the program's paths (e.g., loops, conditionals) to understand how variables might be re-assigned or modified. Control flow graphs (CFGs) are a common tool used for this purpose, where nodes represent statements and edges represent the flow of execution.

You would need to extend the Tree-sitter AST traversal to build a CFG by marking nodes that introduce control flow changes (e.g., if statements, loops) and tracking the flow between them.
bstee615/tree-climber is an example of such an extension, but it only supports C.

The general idea would be:

class CFGNode:
    def __init__(self, ast_node):
        self.ast_node = ast_node
        self.next = []  # Next nodes in the flow

def build_cfg(node):
    # Simplified: one CFGNode per AST node, with parent -> child edges.
    # A real CFG would instead add branch/merge edges for if/while/for
    # and back-edges for loops.
    cfg_node = CFGNode(node)
    for child in node.children:
        cfg_node.next.append(build_cfg(child))
    return cfg_node

def analyze_cfg(cfg_node, scope):
    # Perform analysis on each CFGNode, considering scope and control flow
    pass

# Simplified example of how you might start this process
global_scope = Scope()  # reuse the Scope class from the def-use section
root_cfg_node = build_cfg(root_node)
analyze_cfg(root_cfg_node, global_scope)
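To make the branching part concrete, here is a hedged sketch of how an if_statement could produce branch edges, using tree-sitter-python's condition, consequence, and alternative fields (a full CFG builder would link each branch's exit, not its entry, to the join node, and would also handle elif chains and loop back-edges):

def build_cfg_for_if(if_node):
    # The condition evaluates first
    cond = CFGNode(if_node.child_by_field_name("condition"))
    join = CFGNode(if_node)  # merge point after the conditional

    # "then" branch
    then_entry = build_cfg(if_node.child_by_field_name("consequence"))
    cond.next.append(then_entry)
    then_entry.next.append(join)  # simplified: links the entry, not the exit

    # optional "else"/"elif" branch
    alternative = if_node.child_by_field_name("alternative")
    if alternative is not None:
        else_entry = build_cfg(alternative)
        cond.next.append(else_entry)
        else_entry.next.append(join)
    else:
        cond.next.append(join)  # condition false: fall through to the join

    return cond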

These sketches are just a conceptual outline. A full implementation would involve detailed handling of Python's syntax and semantics, potentially including type inference algorithms and more comprehensive static analysis tools.

Once you have constructed the CFG using CFGNode objects, traversing and resolving names means navigating this graph and performing analysis based on the control flow it represents: you start from entry points (like the start of a function), and follow the edges that represent control flow between nodes. During this traversal, you maintain and update a scope context to resolve names based on where you are in the flow of the program.
As you traverse the CFG, you manage scope similarly to how you would in an AST traversal, but with additional considerations for the control flow. For example, entering a new block (e.g., a loop or conditional block) might create a new scope, and returning from a function call would revert to the previous scope.

def analyze_cfg(cfg_node, scope, visited=None):
    # visited=None avoids the mutable-default-argument pitfall; the set
    # prevents re-visiting nodes and infinite loops in cyclic graphs
    if visited is None:
        visited = set()
    if cfg_node in visited:
        return
    visited.add(cfg_node)

    # Scope management based on the CFG node type
    # E.g., entering a function or a new block might create a new scope
    new_scope = scope
    if cfg_node.ast_node.type in ["function_definition", "block"]:
        new_scope = Scope(parent=scope)

    # Example of handling a definition within a CFG node
    if is_definition(cfg_node.ast_node):
        name, type_annotation = extract_definition_details(cfg_node.ast_node)
        if name:
            new_scope.define(name, type_annotation, cfg_node)

    # Example of resolving a name use within a CFG node
    if is_use(cfg_node.ast_node):
        name = cfg_node.ast_node.text.decode('utf8')
        definition, _ = new_scope.find(name)
        if definition:
            print(f"Resolved name '{name}' with type {definition} at CFG node {cfg_node}")

    # Recursively analyze connected nodes
    for next_node in cfg_node.next:
        analyze_cfg(next_node, new_scope, visited)

def is_definition(ast_node):
    return ast_node.type in ("function_definition", "assignment")

def is_use(ast_node):
    return ast_node.type == "identifier"

def extract_definition_details(ast_node):
    # Reuse the helpers from the def-use section
    return extract_name(ast_node), extract_type_annotation(ast_node)

You might also have to do path-sensitive analysis, analyzing different paths through the CFG separately, to resolve names and types accurately in the presence of conditional logic.
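A minimal way to sketch that (with hypothetical helpers building on the Scope and analyze_cfg pieces above): fork the scope at a branch, analyze each arm with its own copy, and merge the results into sets of candidate types:

def fork_scope(scope):
    # Copy the current scope so assignments on one path
    # do not leak into the other path
    forked = Scope(parent=scope.parent)
    forked.definitions = dict(scope.definitions)
    return forked

def analyze_conditional(branch_entries, scope, visited):
    # Hypothetical: branch_entries are the CFG entry nodes of each arm
    arm_scopes = []
    for entry in branch_entries:
        arm_scope = fork_scope(scope)
        analyze_cfg(entry, arm_scope, set(visited))
        arm_scopes.append(arm_scope)

    # Merge: a name assigned different types on different paths ends up
    # with a set of candidate types, e.g. {'Car', 'int'}
    merged = {}
    for arm_scope in arm_scopes:
        for name, (type_annotation, node) in arm_scope.definitions.items():
            merged.setdefault(name, set()).add(type_annotation)
    return merged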