🌐 AI搜索 & 代理 主页
Skip to content

Commit c63d653

Browse files
committed
(fix class inheritance tests)
1 parent 3460d8f commit c63d653

File tree

2 files changed

+38
-5
lines changed

2 files changed

+38
-5
lines changed
Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,44 @@
11
from py_codegen.test_fixtures.classes_with_inheritance import ChildClass
22

3-
from py_codegen.type_extractor.nodes.BaseNodeType import BaseOption
43
from py_codegen.type_extractor.nodes.ClassFound import ClassFound
5-
from py_codegen.type_extractor.nodes.TypeOR import TypeOR
64
from py_codegen.type_extractor.__tests__.utils import traverse, cleanup
75
from py_codegen.type_extractor.type_extractor import TypeExtractor
86

97

108
def test_classes_with_inheritance():
11-
type_collector = TypeExtractor()
9+
type_extractor = TypeExtractor()
1210

13-
type_collector.add(None)(ChildClass)
11+
type_extractor.add(None)(ChildClass)
1412

15-
print(type_collector)
13+
classes = {
14+
key: traverse(value, cleanup)
15+
for (key, value) in type_extractor.collected_types.items()
16+
if isinstance(value, ClassFound)
17+
}
18+
parent_class_a = ClassFound(
19+
name='ParentClassA',
20+
fields={
21+
'from_parent_a': int,
22+
},
23+
)
24+
parent_class_b = ClassFound(
25+
name='ParentClassB',
26+
fields={
27+
'from_parent_b': str,
28+
},
29+
)
30+
assert classes == {
31+
'ParentClassA': parent_class_a,
32+
'ParentClassB': parent_class_b,
33+
'ChildClass': ClassFound(
34+
name='ChildClass',
35+
fields={
36+
'b': str,
37+
},
38+
base_classes=[
39+
parent_class_a,
40+
parent_class_b,
41+
],
42+
)
43+
}
44+
print(type_extractor)

py_codegen/type_extractor/__tests__/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ def traverse(node: NodeType, func: traverse_func_type):
2222
key: traverse(value, func)
2323
for (key, value) in node.fields.items()
2424
}
25+
class_found_node.base_classes = [
26+
traverse(base_class, func)
27+
for base_class in node.base_classes
28+
]
2529
return func(class_found_node)
2630
if isinstance(node, FunctionFound):
2731
function_found_node = copy(node)

0 commit comments

Comments
 (0)