8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130 | class BaseEditor(cst.CSTTransformer):
'''
A tool for code text Transformer new code text
'''
def __init__(self, gen_docstring: Callable, pattern: str, module):
self.gen_docstring = gen_docstring
self.pattern = pattern
self.module = module
self.module_dict = self.create_module_dict(module)
self.current_class = None
def create_module_dict(self, module):
module_dict = {}
seen_objects = set()
for name, obj in inspect.getmembers(module):
module_dict[name] = obj
if inspect.isclass(obj):
self.add_class_members_to_dict(
module_dict, obj, name, seen_objects)
return module_dict
def add_class_members_to_dict(self, module_dict, cls, parent_name, seen_objects):
if cls in seen_objects:
return
seen_objects.add(cls)
for name, obj in inspect.getmembers(cls):
full_name = f"{parent_name}.{name}"
module_dict[full_name] = obj
if inspect.isclass(obj):
self.add_class_members_to_dict(
module_dict, obj, full_name, seen_objects)
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
full_name = f"{self.current_class}.{original_node.name.value}" if self.current_class else original_node.name.value
obj = self._get_obj_by_name(full_name)
docstring = obj.__doc__ if obj else None
return self._update_node_with_new_docstring(original_node, updated_node, docstring)
def visit_ClassDef(self, node: cst.ClassDef) -> None:
self.current_class = node.name.value
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
self.current_class = None
obj = self._get_obj_by_name(original_node.name.value)
docstring = obj.__doc__ if obj else None
return self._update_node_with_new_docstring(original_node, updated_node, docstring)
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
return self._update_node_with_new_docstring(original_node, updated_node, self.module.__doc__)
def _get_obj_by_name(self, name: str):
return self.module_dict.get(name, None)
def _update_node_with_new_docstring(self, original_node: cst.CSTNode, updated_node: cst.CSTNode, docstring: Optional[str]) -> cst.CSTNode:
node_code = cst.Module([]).code_for_node(original_node)
old_docstring = docstring
new_body = []
if isinstance(updated_node.body, tuple):
body = updated_node.body
else:
body = getattr(updated_node.body, 'body', [])
# Extract existing docstring if present and build new body without it
for stmt in body:
if m.matches(stmt, m.SimpleStatementLine(body=[m.Expr(m.SimpleString())])):
old_docstring = cst.ensure_type(
stmt.body[0].value, cst.SimpleString).value.strip('\"\'')
else:
new_body.append(stmt)
new_docstring = self.gen_docstring(
old_docstring, node_code)
# Create a new docstring node if new_docstring is provided
new_docstring_node = (
cst.SimpleStatementLine([cst.Expr(cst.SimpleString(
f'"""{new_docstring}"""'))]) if new_docstring else None
)
if new_docstring_node:
if isinstance(updated_node.body, cst.SimpleStatementSuite): # 检查函数体是否为 SimpleStatementSuite 如果为单行函数
# 创建一个新的 IndentedBlock,包含原函数体的语句
new_body = cst.IndentedBlock(
body=[
new_docstring_node,
cst.SimpleStatementLine(
body=[
cst.Expr(
value=updated_node.body.body[0]
)
]
)
]
)
# 用新的 IndentedBlock 替换原函数体
return updated_node.with_changes(body=new_body)
else:
new_body.insert(0, new_docstring_node)
# Update the body with the new list of statements
try:
if isinstance(updated_node.body, tuple):
updated_body = tuple(new_body)
else:
updated_body = updated_node.body.with_changes(body=new_body)
except Exception as e:
print(f"Error updating body with new statements: {new_body}")
print(f"Error message: {e}")
raise
return updated_node.with_changes(body=updated_body)
|