|
from __future__ import annotations |
|
|
|
from typing import Union |
|
|
|
import libcst as cst |
|
from libcst._nodes.module import Module |
|
|
|
# Node types that can carry a docstring and are tracked by the visitor and
# transformer classes in this module.
DocstringNode = Union[
    cst.Module,
    cst.ClassDef,
    cst.FunctionDef,
]
|
|
|
|
|
def get_docstring_statement(body: DocstringNode) -> cst.SimpleStatementLine | None:
    """Return the statement line holding the docstring of a module, class or function.

    A docstring is recognised the way ``ast.get_docstring`` does it: the first
    statement of the node's body must be a bare expression whose value is a
    (possibly implicitly concatenated) ``str`` literal. Bytes literals,
    f-strings and concatenations that involve an f-string or mix ``str`` with
    ``bytes`` are not docstrings.

    Args:
        body: The module, class or function node to inspect. Despite the
            name, this is the node itself, not its body.

    Returns:
        The ``SimpleStatementLine`` containing the docstring expression, or
        ``None`` if the node has no docstring. One-line bodies such as
        ``def f(): "doc"`` are not recognised, because their body holds small
        statements rather than a ``SimpleStatementLine``.
    """
    if isinstance(body, cst.Module):
        statements = body.body
    else:
        # ClassDef/FunctionDef wrap their statements in a suite
        # (IndentedBlock or SimpleStatementSuite).
        statements = body.body.body

    if not statements:
        return None

    statement = statements[0]
    if not isinstance(statement, cst.SimpleStatementLine) or not statement.body:
        return None

    expr = statement.body[0]
    if not isinstance(expr, cst.Expr):
        return None

    value = expr.value
    if not isinstance(value, (cst.SimpleString, cst.ConcatenatedString)):
        return None

    # evaluated_value is bytes for b"..." literals, and None when a
    # ConcatenatedString contains an f-string or mixes str and bytes.
    # Only a plain str is a docstring.
    if not isinstance(value.evaluated_value, str):
        return None

    return statement
|
|
|
|
|
class DocstringCollector(cst.CSTVisitor):
    """A visitor that collects docstring statements from a CST.

    Each module, class and function is identified by the tuple of names on the
    path from the module down to it, with ``""`` standing for the module
    itself. For example, method ``bar`` of class ``Foo`` is keyed by
    ``("", "Foo", "bar")``.

    Attributes:
        stack: The names of the module/class/function nodes currently being
            visited, outermost first.
        docstrings: Maps each node path to the ``SimpleStatementLine`` holding
            that node's docstring. Nodes without a docstring, and classes or
            functions decorated with ``@overload`` / ``@<module>.overload``,
            get no entry. If two definitions share a path, the one left last
            overwrites the earlier entry.
    """

    def __init__(self):
        # Let the libcst base class run its own initialisation first.
        super().__init__()
        self.stack: list[str] = []
        self.docstrings: dict[tuple[str, ...], cst.SimpleStatementLine] = {}

    # visit_* pushes the node's path component and returns None so that its
    # children are visited too; leave_* records the docstring and pops it.

    def visit_Module(self, node: cst.Module) -> bool | None:
        self.stack.append("")

    def leave_Module(self, node: cst.Module) -> None:
        return self._leave(node)

    def visit_ClassDef(self, node: cst.ClassDef) -> bool | None:
        self.stack.append(node.name.value)

    def leave_ClassDef(self, node: cst.ClassDef) -> None:
        return self._leave(node)

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None:
        self.stack.append(node.name.value)

    def leave_FunctionDef(self, node: cst.FunctionDef) -> None:
        return self._leave(node)

    @staticmethod
    def _is_overload(node: DocstringNode) -> bool:
        """Return True if ``node`` has an ``@overload`` or ``@<module>.overload`` decorator.

        Modules have no ``decorators`` attribute and are never overloads.
        Decorators of any other shape (calls, subscripts, ...) are ignored
        rather than inspected for a ``.value`` they may not have.
        """
        for decorator in getattr(node, "decorators", ()):
            target = decorator.decorator
            if isinstance(target, cst.Name) and target.value == "overload":
                return True
            if isinstance(target, cst.Attribute) and target.attr.value == "overload":
                return True
        return False

    def _leave(self, node: DocstringNode) -> None:
        """Pop ``node`` off the path stack and record its docstring, if any."""
        key = tuple(self.stack)
        self.stack.pop()

        # Overload stubs share their path with the implementation; skipping
        # them keeps a stub from supplying the docstring for that path.
        if self._is_overload(node):
            return

        statement = get_docstring_statement(node)
        if statement is not None:
            self.docstrings[key] = statement
|
|
|
|
|
class DocstringTransformer(cst.CSTTransformer):
    """A transformer that writes docstring statements into a CST.

    Each module, class and function is identified by the tuple of names on the
    path from the module down to it, with ``""`` standing for the module
    itself (e.g. ``("", "Foo", "bar")`` for method ``bar`` of class ``Foo``).

    Attributes:
        stack: The names of the module/class/function nodes currently being
            visited, outermost first.
        docstrings: Maps node paths to the docstring statement to install. For
            a node whose path is present, an existing docstring is replaced
            and a missing one is inserted as the first statement of its body.
            Classes and functions decorated with ``@overload`` /
            ``@<module>.overload`` are left unchanged.
    """

    def __init__(
        self,
        docstrings: dict[tuple[str, ...], cst.SimpleStatementLine],
    ):
        # Let the libcst base class run its own initialisation first.
        super().__init__()
        self.stack: list[str] = []
        self.docstrings = docstrings

    # visit_* pushes the node's path component and returns None so that its
    # children are visited too; leave_* installs the docstring and pops it.

    def visit_Module(self, node: cst.Module) -> bool | None:
        self.stack.append("")

    def leave_Module(self, original_node: Module, updated_node: Module) -> Module:
        return self._leave(original_node, updated_node)

    def visit_ClassDef(self, node: cst.ClassDef) -> bool | None:
        self.stack.append(node.name.value)

    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.CSTNode:
        return self._leave(original_node, updated_node)

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None:
        self.stack.append(node.name.value)

    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.CSTNode:
        return self._leave(original_node, updated_node)

    @staticmethod
    def _is_overload(node: DocstringNode) -> bool:
        """Return True if ``node`` has an ``@overload`` or ``@<module>.overload`` decorator.

        Modules have no ``decorators`` attribute and are never overloads.
        Decorators of any other shape (calls, subscripts, ...) are ignored
        rather than inspected for a ``.value`` they may not have.
        """
        for decorator in getattr(node, "decorators", ()):
            target = decorator.decorator
            if isinstance(target, cst.Name) and target.value == "overload":
                return True
            if isinstance(target, cst.Attribute) and target.attr.value == "overload":
                return True
        return False

    @staticmethod
    def _module_with_docstring(
        module: cst.Module,
        statement: cst.SimpleStatementLine,
        replace: bool,
    ) -> cst.Module:
        """Return ``module`` with ``statement`` as its first (docstring) statement.

        If ``replace`` is true, the module's current first statement is the
        old docstring and is dropped.
        """
        body = module.body
        if replace:
            return module.with_changes(body=(statement, *body[1:]))
        if not body:
            return module.with_changes(body=(statement,))
        # Module.body may only hold statements. The blank line separating the
        # new docstring from the existing code is therefore attached to the
        # first statement's leading lines rather than inserted as a node.
        first = body[0]
        first = first.with_changes(leading_lines=(cst.EmptyLine(), *first.leading_lines))
        return module.with_changes(body=(statement, first, *body[1:]))

    @staticmethod
    def _suite_with_docstring(
        suite: cst.BaseSuite,
        statement: cst.SimpleStatementLine,
        replace: bool,
    ) -> cst.BaseSuite:
        """Return ``suite`` with ``statement`` as its first (docstring) statement.

        If ``replace`` is true, the suite's current first statement is the old
        docstring and is dropped.
        """
        if isinstance(suite, cst.SimpleStatementSuite):
            # One-line body such as ``def f(): ...``. A statement line cannot
            # live inside it, so expand it into an indented block: the
            # docstring first, then the original small statements on one line.
            # (A one-line body never has a docstring to replace.)
            return cst.IndentedBlock(
                body=(
                    statement,
                    cst.SimpleStatementLine(
                        body=suite.body,
                        trailing_whitespace=suite.trailing_whitespace,
                    ),
                )
            )
        rest = suite.body[1:] if replace else suite.body
        return suite.with_changes(body=(statement, *rest))

    def _leave(self, original_node: DocstringNode, updated_node: DocstringNode) -> DocstringNode:
        """Pop the node off the path stack and install its collected docstring, if any."""
        key = tuple(self.stack)
        self.stack.pop()

        if self._is_overload(updated_node):
            return updated_node

        statement = self.docstrings.get(key)
        if statement is None:
            return updated_node

        # Only Module/ClassDef/FunctionDef are rewritten here, so a docstring
        # line found on the original node is still the first statement of the
        # updated node.
        replace = get_docstring_statement(original_node) is not None

        if isinstance(updated_node, cst.Module):
            return self._module_with_docstring(updated_node, statement, replace)

        new_suite = self._suite_with_docstring(updated_node.body, statement, replace)
        return updated_node.with_changes(body=new_suite)
|
|
|
|
|
def merge_docstring(code: str, documented_code: str) -> str:
    """Copy the docstrings of ``documented_code`` into ``code``.

    Both sources are parsed with libcst. Docstring statements are collected
    from the documented tree, keyed by the path of enclosing
    module/class/function names. They are then written into the definitions
    of the original tree that have the same path; the rest of ``code`` is kept.

    Args:
        code: The original source code.
        documented_code: A version of the same code carrying docstrings.

    Returns:
        The source of ``code`` with the docstrings from ``documented_code``
        merged in.
    """
    original_tree = cst.parse_module(code)
    documented_tree = cst.parse_module(documented_code)

    collector = DocstringCollector()
    documented_tree.visit(collector)

    merged_tree = original_tree.visit(DocstringTransformer(collector.docstrings))
    return merged_tree.code
|
|