File size: 5,483 Bytes
360d784 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
from __future__ import annotations
from typing import Union
import libcst as cst
from libcst._nodes.module import Module
# Node types whose body may start with a docstring.
DocstringNode = Union[cst.Module, cst.ClassDef, cst.FunctionDef]


def get_docstring_statement(body: DocstringNode) -> cst.SimpleStatementLine | None:
    """Return the statement that holds the docstring of a module, class, or function.

    A docstring is recognised only when both of these hold:

    * The first statement of the node's body is a ``SimpleStatementLine``.
    * That line's first small statement is an expression consisting of a
      non-bytes string literal, or an implicit concatenation of such literals
      that libcst can evaluate to ``str``.

    A concatenation involving an f-string evaluates to ``None`` and is rejected,
    as Python does not treat it as a docstring. One-line suites such as
    ``def f(): ...`` contain no ``SimpleStatementLine``, so they never yield a
    docstring statement.

    Args:
        body: The module, class, or function node to inspect. This is the node
            itself, not its body; the parameter name is kept for compatibility.

    Returns:
        The ``SimpleStatementLine`` containing the docstring if one exists,
        ``None`` otherwise.
    """
    if isinstance(body, cst.Module):
        statements = body.body
    else:
        # ClassDef / FunctionDef -> suite (IndentedBlock or SimpleStatementSuite) -> statements.
        statements = body.body.body
    if not statements:
        return None

    statement = statements[0]
    if not isinstance(statement, cst.SimpleStatementLine) or not statement.body:
        return None

    expr = statement.body[0]
    if not isinstance(expr, cst.Expr):
        return None

    value = expr.value
    if not isinstance(value, (cst.SimpleString, cst.ConcatenatedString)):
        return None

    # bytes literals (b"...") and unevaluable concatenations (None) are not docstrings.
    if not isinstance(value.evaluated_value, str):
        return None
    return statement
class DocstringCollector(cst.CSTVisitor):
    """A visitor that collects the docstring statements of a CST.

    Each docstring is keyed by the path of enclosing definition names, which
    starts with ``""`` for the module, e.g. ``("", "MyClass", "method")``.
    Definitions decorated with ``@overload`` (or ``@<module>.overload``) are
    skipped.

    Attributes:
        stack: Names of the definitions currently being visited (the path so far).
        docstrings: Maps each definition path to its docstring statement.
    """

    def __init__(self):
        super().__init__()
        self.stack: list[str] = []
        self.docstrings: dict[tuple[str, ...], cst.SimpleStatementLine] = {}

    def visit_Module(self, node: cst.Module) -> bool | None:
        self.stack.append("")

    def leave_Module(self, node: cst.Module) -> None:
        return self._leave(node)

    def visit_ClassDef(self, node: cst.ClassDef) -> bool | None:
        self.stack.append(node.name.value)

    def leave_ClassDef(self, node: cst.ClassDef) -> None:
        return self._leave(node)

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None:
        self.stack.append(node.name.value)

    def leave_FunctionDef(self, node: cst.FunctionDef) -> None:
        return self._leave(node)

    def _leave(self, node: DocstringNode) -> None:
        """Record *node*'s docstring under the current path, then pop the path."""
        key = tuple(self.stack)
        self.stack.pop()
        # Overload stubs share their name, and hence their key, with the
        # implementation; skip them so only the implementation is recorded.
        if self._is_overload(node):
            return
        statement = get_docstring_statement(node)
        if statement is not None:
            self.docstrings[key] = statement

    @staticmethod
    def _is_overload(node: DocstringNode) -> bool:
        """Return True if *node* has an ``@overload`` or ``@<module>.overload`` decorator.

        Decorators may be arbitrary expressions (``Call``, ``Subscript``, ...),
        so only ``Name`` and ``Attribute`` nodes are inspected. Any other
        decorator is simply not an overload marker.
        """
        for decorator in getattr(node, "decorators", ()):
            expr = decorator.decorator
            if isinstance(expr, cst.Name) and expr.value == "overload":
                return True
            if isinstance(expr, cst.Attribute) and expr.attr.value == "overload":
                return True
        return False
class DocstringTransformer(cst.CSTTransformer):
    """A transformer that installs given docstrings into the matching definitions of a CST.

    Definitions are identified by the path of enclosing definition names,
    starting with ``""`` for the module, e.g. ``("", "MyClass", "method")``.
    When a path has an entry in ``docstrings``, that definition's existing
    docstring statement is replaced. If it had none, the new statement is
    inserted as the first statement of its body. Definitions decorated with
    ``@overload`` (or ``@<module>.overload``) are left unchanged.

    Attributes:
        stack: Names of the definitions currently being visited (the path so far).
        docstrings: Maps definition paths to the docstring statements to install.
    """

    def __init__(
        self,
        docstrings: dict[tuple[str, ...], cst.SimpleStatementLine],
    ):
        super().__init__()
        self.stack: list[str] = []
        self.docstrings = docstrings

    def visit_Module(self, node: cst.Module) -> bool | None:
        self.stack.append("")

    def leave_Module(self, original_node: Module, updated_node: Module) -> Module:
        return self._leave(original_node, updated_node)

    def visit_ClassDef(self, node: cst.ClassDef) -> bool | None:
        self.stack.append(node.name.value)

    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.CSTNode:
        return self._leave(original_node, updated_node)

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None:
        self.stack.append(node.name.value)

    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.CSTNode:
        return self._leave(original_node, updated_node)

    def _leave(self, original_node: DocstringNode, updated_node: DocstringNode) -> DocstringNode:
        """Pop the current path and install its docstring into *updated_node*, if any."""
        key = tuple(self.stack)
        self.stack.pop()
        # Overload stubs share the implementation's name (and key); leave them untouched.
        if self._is_overload(updated_node):
            return updated_node
        statement = self.docstrings.get(key)
        if statement is None:
            return updated_node
        has_docstring = get_docstring_statement(original_node) is not None
        if isinstance(updated_node, cst.Module):
            return self._with_module_docstring(updated_node, statement, has_docstring)
        return self._with_suite_docstring(updated_node, statement, has_docstring)

    @staticmethod
    def _with_module_docstring(
        module: cst.Module, statement: cst.SimpleStatementLine, has_docstring: bool
    ) -> cst.Module:
        """Return *module* with *statement* as its first (docstring) statement."""
        body = list(module.body)
        if has_docstring:
            return module.with_changes(body=(statement, *body[1:]))
        if body:
            # Keep one blank line between the new docstring and the existing code.
            # The blank line goes in the next statement's leading_lines, because
            # Module.body should contain only statements.
            first = body[0]
            body[0] = first.with_changes(leading_lines=(cst.EmptyLine(), *first.leading_lines))
        return module.with_changes(body=(statement, *body))

    @staticmethod
    def _with_suite_docstring(
        node: DocstringNode, statement: cst.SimpleStatementLine, has_docstring: bool
    ) -> DocstringNode:
        """Return class/function *node* with *statement* as the first statement of its body."""
        suite = node.body
        if isinstance(suite, cst.SimpleStatementSuite):
            # One-line body (``def f(): pass``). A SimpleStatementLine cannot be
            # placed inside a SimpleStatementSuite, so rewrite the body as an
            # indented block. The docstring comes first, followed by the original
            # small statements on one line. The suite's trailing whitespace
            # (including any comment) becomes the block header.
            block = cst.IndentedBlock(
                header=suite.trailing_whitespace,
                body=(statement, cst.SimpleStatementLine(body=suite.body)),
            )
            return node.with_changes(body=block)
        rest = suite.body[1:] if has_docstring else suite.body
        return node.with_changes(body=suite.with_changes(body=(statement, *rest)))

    @staticmethod
    def _is_overload(node: DocstringNode) -> bool:
        """Return True if *node* has an ``@overload`` or ``@<module>.overload`` decorator.

        Decorators may be arbitrary expressions (``Call``, ``Subscript``, ...),
        so only ``Name`` and ``Attribute`` nodes are inspected. Any other
        decorator is simply not an overload marker.
        """
        for decorator in getattr(node, "decorators", ()):
            expr = decorator.decorator
            if isinstance(expr, cst.Name) and expr.value == "overload":
                return True
            if isinstance(expr, cst.Attribute) and expr.attr.value == "overload":
                return True
        return False
def merge_docstring(code: str, documented_code: str) -> str:
    """Copy the docstrings of ``documented_code`` into ``code``.

    Both sources are parsed with libcst. A ``DocstringCollector`` gathers the
    docstrings of ``documented_code``, keyed by definition path. A
    ``DocstringTransformer`` then applies them to the tree parsed from
    ``code``, and the regenerated source is returned.

    Args:
        code: The source whose docstrings should be updated.
        documented_code: The source to take docstrings from.

    Returns:
        The regenerated source of ``code`` with the collected docstrings applied.
    """
    target_tree = cst.parse_module(code)
    source_tree = cst.parse_module(documented_code)

    collector = DocstringCollector()
    source_tree.visit(collector)

    merged_tree = target_tree.visit(DocstringTransformer(collector.docstrings))
    return merged_tree.code
|