# liushaojie: Add application file (360d784)
# raw | history blame | 5.48 kB
from __future__ import annotations
from typing import Union
import libcst as cst
from libcst._nodes.module import Module
# Node types whose docstrings are collected and merged.
DocstringNode = Union[cst.Module, cst.ClassDef, cst.FunctionDef]


def get_docstring_statement(body: DocstringNode) -> cst.SimpleStatementLine | None:
    """Return the statement line holding the docstring of a module, class or function.

    The first statement of the node's body counts as a docstring when it is an
    expression statement whose first expression is a string literal (simple or
    implicitly concatenated) that evaluates to ``str``. Bytes literals and
    concatenations containing f-strings are rejected.

    Args:
        body: The module, class or function node to inspect. Despite the name,
            this is the node itself, not its body.

    Returns:
        The ``SimpleStatementLine`` containing the docstring, or ``None`` if the
        node has no docstring. The whole line is returned, so any statements
        after a ``;`` on that line come with it. One-line bodies such as
        ``def f(): "doc"`` are a ``SimpleStatementSuite`` and never match.
    """
    if isinstance(body, cst.Module):
        statements = body.body
    else:
        # ClassDef / FunctionDef: ``.body`` is the suite; its ``.body`` holds the statements.
        statements = body.body.body
    if not statements:
        return None
    statement = statements[0]
    if not isinstance(statement, cst.SimpleStatementLine) or not statement.body:
        return None
    expr = statement.body[0]
    if not isinstance(expr, cst.Expr):
        return None
    value = expr.value
    if not isinstance(value, (cst.SimpleString, cst.ConcatenatedString)):
        return None
    # ``evaluated_value`` is bytes for b"..." literals and None for concatenations
    # that include an f-string; neither is a docstring.
    if not isinstance(value.evaluated_value, str):
        return None
    return statement
class DocstringCollector(cst.CSTVisitor):
    """A visitor that collects the docstring statements of a CST.

    Every module, class and function is identified by the path of names leading
    to it. The module itself is represented by ``""``. Definitions decorated
    with ``@overload`` or ``@<module>.overload`` are skipped.

    Attributes:
        stack: The names of the module/class/function nodes enclosing the
            current position in the traversal.
        docstrings: Maps each name path to the docstring statement found there.
            A later definition with a docstring at the same path overwrites an
            earlier one.
    """

    def __init__(self):
        super().__init__()
        self.stack: list[str] = []
        self.docstrings: dict[tuple[str, ...], cst.SimpleStatementLine] = {}

    # The visit_* methods return None, so children are still visited.
    def visit_Module(self, node: cst.Module) -> bool | None:
        self.stack.append("")

    def leave_Module(self, node: cst.Module) -> None:
        return self._leave(node)

    def visit_ClassDef(self, node: cst.ClassDef) -> bool | None:
        self.stack.append(node.name.value)

    def leave_ClassDef(self, node: cst.ClassDef) -> None:
        return self._leave(node)

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None:
        self.stack.append(node.name.value)

    def leave_FunctionDef(self, node: cst.FunctionDef) -> None:
        return self._leave(node)

    @staticmethod
    def _is_overload(node: DocstringNode) -> bool:
        """Return True if ``node`` is decorated with ``@overload`` or ``@<module>.overload``.

        Only bare names and attributes are inspected. Other decorator forms,
        such as calls like ``@dataclass(frozen=True)``, have no ``.value`` to
        compare and are simply not overloads.
        """
        for decorator in getattr(node, "decorators", ()):  # Module has no decorators.
            target = decorator.decorator
            if isinstance(target, cst.Name) and target.value == "overload":
                return True
            if isinstance(target, cst.Attribute) and target.attr.value == "overload":
                return True
        return False

    def _leave(self, node: DocstringNode) -> None:
        """Record ``node``'s docstring under its name path, then pop the path."""
        key = tuple(self.stack)
        self.stack.pop()
        # Overload stubs would otherwise shadow the implementation's docstring.
        if self._is_overload(node):
            return
        statement = get_docstring_statement(node)
        if statement is not None:
            self.docstrings[key] = statement
class DocstringTransformer(cst.CSTTransformer):
    """A transformer that writes docstring statements into a CST.

    Every module, class and function is identified by the path of names leading
    to it. The module itself is represented by ``""``. When ``docstrings`` holds
    a statement for that path, the statement replaces the node's existing
    docstring, or becomes its first statement if it had none. Definitions
    decorated with ``@overload`` or ``@<module>.overload`` are left untouched.

    Attributes:
        stack: The names of the module/class/function nodes enclosing the
            current position in the traversal.
        docstrings: Maps name paths to the docstring statement to insert.
    """

    def __init__(
        self,
        docstrings: dict[tuple[str, ...], cst.SimpleStatementLine],
    ):
        super().__init__()
        self.stack: list[str] = []
        self.docstrings = docstrings

    # The visit_* methods return None, so children are still visited.
    def visit_Module(self, node: cst.Module) -> bool | None:
        self.stack.append("")

    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
        return self._leave(original_node, updated_node)

    def visit_ClassDef(self, node: cst.ClassDef) -> bool | None:
        self.stack.append(node.name.value)

    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.CSTNode:
        return self._leave(original_node, updated_node)

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None:
        self.stack.append(node.name.value)

    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.CSTNode:
        return self._leave(original_node, updated_node)

    @staticmethod
    def _is_overload(node: DocstringNode) -> bool:
        """Return True if ``node`` is decorated with ``@overload`` or ``@<module>.overload``.

        Only bare names and attributes are inspected. Other decorator forms,
        such as calls like ``@dataclass(frozen=True)``, have no ``.value`` to
        compare and are simply not overloads.
        """
        for decorator in getattr(node, "decorators", ()):  # Module has no decorators.
            target = decorator.decorator
            if isinstance(target, cst.Name) and target.value == "overload":
                return True
            if isinstance(target, cst.Attribute) and target.attr.value == "overload":
                return True
        return False

    @staticmethod
    def _with_module_docstring(
        module: cst.Module, statement: cst.SimpleStatementLine, has_docstring: bool
    ) -> cst.Module:
        """Return ``module`` with ``statement`` as its docstring (replacing or inserting)."""
        body = module.body
        if has_docstring:
            return module.with_changes(body=(statement, *body[1:]))
        if not body:
            return module.with_changes(body=(statement,))
        # Separate the new docstring from the existing code with a blank line.
        # EmptyLine is not a statement, so it goes into the next statement's
        # leading_lines instead of directly into Module.body.
        first = body[0].with_changes(leading_lines=(cst.EmptyLine(), *body[0].leading_lines))
        return module.with_changes(body=(statement, first, *body[1:]))

    @staticmethod
    def _expand_suite(
        suite: cst.SimpleStatementSuite, statement: cst.SimpleStatementLine
    ) -> cst.IndentedBlock:
        """Turn a one-line body (``def f(): ...``) into an indented block led by ``statement``.

        A SimpleStatementSuite may only contain small statements, so a docstring
        line cannot be prepended to it directly without producing an invalid tree.
        """
        small_statements = list(suite.body)
        first = small_statements[0] if small_statements else None
        # Drop an existing inline string (``def f(): "old doc"``) so it is not duplicated.
        if isinstance(first, cst.Expr) and isinstance(
            first.value, (cst.SimpleString, cst.ConcatenatedString)
        ):
            small_statements = small_statements[1:]
        rest = (cst.SimpleStatementLine(body=tuple(small_statements)),) if small_statements else ()
        # The suite's trailing whitespace carries any end-of-line comment and the newline.
        return cst.IndentedBlock(header=suite.trailing_whitespace, body=(statement, *rest))

    def _leave(self, original_node: DocstringNode, updated_node: DocstringNode) -> DocstringNode:
        """Pop the name path and apply the docstring registered for it, if any."""
        key = tuple(self.stack)
        self.stack.pop()
        if self._is_overload(updated_node):
            return updated_node
        statement = self.docstrings.get(key)
        if statement is None:
            return updated_node
        # The docstring is detected on the original node. Only nested compound
        # statements are rewritten, so the first statement is unchanged in updated_node.
        has_docstring = get_docstring_statement(original_node) is not None
        if isinstance(updated_node, cst.Module):
            return self._with_module_docstring(updated_node, statement, has_docstring)
        suite = updated_node.body
        if isinstance(suite, cst.SimpleStatementSuite):
            return updated_node.with_changes(body=self._expand_suite(suite, statement))
        rest = suite.body[1:] if has_docstring else suite.body
        return updated_node.with_changes(body=suite.with_changes(body=(statement, *rest)))
def merge_docstring(code: str, documented_code: str) -> str:
    """Copy the docstrings of ``documented_code`` into ``code``.

    Both inputs are parsed with libcst. ``DocstringCollector`` gathers the
    docstring statements of ``documented_code``. ``DocstringTransformer`` then
    writes them into the module/class/function of ``code`` that has the same
    name path.

    Args:
        code: The original source code.
        documented_code: A version of the source that carries docstrings.

    Returns:
        The source text of ``code`` with the collected docstrings applied.

    Raises:
        libcst.ParserSyntaxError: If either input cannot be parsed (``code`` is
            parsed first).
    """
    target_tree = cst.parse_module(code)
    reference_tree = cst.parse_module(documented_code)

    collector = DocstringCollector()
    reference_tree.visit(collector)

    merged_tree = target_tree.visit(DocstringTransformer(collector.docstrings))
    return merged_tree.code