|
1 | 1 | using System; |
2 | 2 | using System.Collections; |
3 | 3 | using System.Collections.Generic; |
| 4 | +using System.Linq; |
| 5 | +using System.Reflection; |
4 | 6 | using System.Runtime.InteropServices; |
5 | 7 |
|
6 | 8 | namespace Python.Runtime |
@@ -557,5 +559,91 @@ public static int mp_ass_subscript(IntPtr ob, IntPtr idx, IntPtr v) |
557 | 559 |
|
558 | 560 | return 0; |
559 | 561 | } |
| 562 | + |
| 563 | + static IntPtr tp_call_impl(IntPtr ob, IntPtr args, IntPtr kw) |
| 564 | + { |
| 565 | + IntPtr tp = Runtime.PyObject_TYPE(ob); |
| 566 | + var self = (ClassBase)GetManagedObject(tp); |
| 567 | + |
| 568 | + if (!self.type.Valid) |
| 569 | + { |
| 570 | + return Exceptions.RaiseTypeError(self.type.DeletedMessage); |
| 571 | + } |
| 572 | + |
| 573 | + Type type = self.type.Value; |
| 574 | + |
| 575 | + var calls = GetCallImplementations(type).ToList(); |
| 576 | + if (calls.Count > 0) |
| 577 | + { |
| 578 | + var callBinder = new MethodBinder(); |
| 579 | + foreach (MethodInfo call in calls) |
| 580 | + { |
| 581 | + callBinder.AddMethod(call); |
| 582 | + } |
| 583 | + return callBinder.Invoke(ob, args, kw); |
| 584 | + } |
| 585 | + |
| 586 | + return InvokeCallInheritedFromPython(new BorrowedReference(ob), args, kw); |
| 587 | + } |
| 588 | + |
| 589 | + static IEnumerable<MethodInfo> GetCallImplementations(Type type) |
| 590 | + => type.GetMethods(BindingFlags.Public | BindingFlags.Instance) |
| 591 | + .Where(m => m.Name == "__call__"); |
| 592 | + |
| 593 | + /// <summary> |
| 594 | + /// Find bases defined in Python and use their __call__ if any |
| 595 | + /// </summary> |
| 596 | + static IntPtr InvokeCallInheritedFromPython(BorrowedReference ob, IntPtr args, IntPtr kw) |
| 597 | + { |
| 598 | + BorrowedReference tp = Runtime.PyObject_TYPE(ob); |
| 599 | + using var super = new PyObject(new BorrowedReference(Runtime.PySuper_Type)); |
| 600 | + using var pyInst = new PyObject(ob); |
| 601 | + |
| 602 | + BorrowedReference mro = PyType.GetMRO(tp); |
| 603 | + nint mroLen = Runtime.PyTuple_Size(mro); |
| 604 | + for (int baseIndex = 0; baseIndex < mroLen - 1; baseIndex++) |
| 605 | + { |
| 606 | + BorrowedReference @base = Runtime.PyTuple_GetItem(mro, baseIndex); |
| 607 | + if (!IsManagedType(@base)) continue; |
| 608 | + |
| 609 | + BorrowedReference nextBase = Runtime.PyTuple_GetItem(mro, baseIndex + 1); |
| 610 | + if (ManagedType.IsManagedType(nextBase)) continue; |
| 611 | + |
| 612 | + // call via super |
| 613 | + using var managedBase = new PyObject(@base); |
| 614 | + using var superInstance = super.Invoke(managedBase, pyInst); |
| 615 | + using var call = Runtime.PyObject_GetAttrString(superInstance.Reference, "__call__"); |
| 616 | + if (call.IsNull()) |
| 617 | + { |
| 618 | + if (Exceptions.ExceptionMatches(Exceptions.AttributeError)) |
| 619 | + { |
| 620 | + Runtime.PyErr_Clear(); |
| 621 | + continue; |
| 622 | + } |
| 623 | + else |
| 624 | + { |
| 625 | + return IntPtr.Zero; |
| 626 | + } |
| 627 | + } |
| 628 | + |
| 629 | + return Runtime.PyObject_Call(call.DangerousGetAddress(), args, kw); |
| 630 | + } |
| 631 | + |
| 632 | + Exceptions.SetError(Exceptions.TypeError, "object is not callable"); |
| 633 | + return IntPtr.Zero; |
| 634 | + } |
| 635 | + |
| 636 | + static readonly Interop.TernaryFunc tp_call_delegate = tp_call_impl; |
| 637 | + |
| 638 | + public virtual void InitializeSlots(SlotsHolder slotsHolder) |
| 639 | + { |
| 640 | + if (!this.type.Valid) return; |
| 641 | + |
| 642 | + if (GetCallImplementations(this.type.Value).Any() |
| 643 | + && !slotsHolder.IsHolding(TypeOffset.tp_call)) |
| 644 | + { |
| 645 | + TypeManager.InitializeSlot(ObjectReference, TypeOffset.tp_call, tp_call_delegate, slotsHolder); |
| 646 | + } |
| 647 | + } |
560 | 648 | } |
561 | 649 | } |
0 commit comments