Commit f31b8bb

pytorchbot and malfet authored
[MPS] Fix sliced cast (#138535)
[MPS] Fix sliced cast (#138314)

This fixes an internal crash caused by an invalid buffer size computation when the sliced API is used.

Not sure what the purpose was of

```c++
IntArrayRef baseShape;
if (src.is_view()) {
  baseShape = src._base().sizes();
} else {
  baseShape = getIMPSAllocator()->getBufferShape(src.storage().data());
}
int flattenedShaped = 1;
for (const auto i : c10::irange(baseShape.size())) {
  flattenedShaped *= baseShape[i];
}
```

as `flattenedShaped` could be computed much more easily as `[srcBuf length]/src.element_size()`, and even if `srcBuf` is padded it's a safe thing to do.

When someone allocated a buffer to hold, say, uint8 and then view-casted it to float16, the attempt to compute `baseShape` returned the sizes of the original tensor in its original data type, rather than the size in the new dtype.

Fixes #137800

Pull Request resolved: #138314
Approved by: https://github.com/albanD, https://github.com/DenisVieriu97

(cherry picked from commit de16159)

Co-authored-by: Nikita Shulga <nikita.shulga@gmail.com>
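The size mismatch described in the commit message can be checked with plain arithmetic. The sketch below is not PyTorch code and needs no MPS device. It borrows the dtypes from the regression test (a 16x16 float16 tensor view-cast to float32) and assumes the old code picked up the base tensor's shape `(16, 16)`. The helper names `numel_from_base_shape` and `numel_from_storage` are hypothetical and exist only for illustration.

```python
# Pure-Python illustration; helper names are hypothetical, not PyTorch APIs.
FLOAT16_SIZE = 2  # bytes per float16 element
FLOAT32_SIZE = 4  # bytes per float32 element

base_shape = (16, 16)                     # shape of the original float16 tensor
storage_nbytes = 16 * 16 * FLOAT16_SIZE   # bytes actually backing the tensor


def numel_from_base_shape(shape):
    """Old approach: multiply the base tensor's dimensions together."""
    n = 1
    for dim in shape:
        n *= dim
    return n


def numel_from_storage(nbytes, element_size):
    """New approach: divide the storage size by the current element size."""
    return nbytes // element_size


# After the view-cast the tensor's element size is that of float32.
old_numel = numel_from_base_shape(base_shape)
new_numel = numel_from_storage(storage_nbytes, FLOAT32_SIZE)

# Bytes a 1-D descriptor of that many float32 elements would describe.
old_described_bytes = old_numel * FLOAT32_SIZE
new_described_bytes = new_numel * FLOAT32_SIZE

print(f"storage: {storage_nbytes} bytes")
print(f"old descriptor: {old_numel} elements, {old_described_bytes} bytes")
print(f"new descriptor: {new_numel} elements, {new_described_bytes} bytes")
```

Under these assumptions the old count describes twice as many bytes as the storage holds, while the count derived from the storage size matches it exactly.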
1 parent 848e7ac commit f31b8bb


2 files changed: +9 -12 lines changed


aten/src/ATen/native/mps/OperationUtils.mm

Lines changed: 3 additions & 12 deletions
@@ -542,18 +542,9 @@ void printTensorNDArray(const Tensor& t) {
   MPSShape* mpsShape = getMPSShape(_tensor);
   MPSShape* mpsStrides = getMPSShape(_tensor.strides());
 
-  IntArrayRef baseShape;
-  if (src.is_view()) {
-    baseShape = src._base().sizes();
-  } else {
-    baseShape = getIMPSAllocator()->getBufferShape(src.storage().data());
-  }
-  int flattenedShaped = 1;
-  for (const auto i : c10::irange(baseShape.size())) {
-    flattenedShaped *= baseShape[i];
-  }
-  MPSShape* mpsBaseShape = @[ @(flattenedShaped) ];
-  MPSNDArrayDescriptor* srcTensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:dataType shape:mpsBaseShape];
+  auto storage_numel = src.storage().nbytes() / src.element_size();
+  MPSNDArrayDescriptor* srcTensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:dataType
+                                                                               shape:@[ @(storage_numel) ]];
   srcTensorDesc.preferPackedRows = YES;
   MPSNDArray* srcNDArray = [[[MPSNDArray alloc] initWithBuffer:srcBuf
                                                         offset:src.storage_offset() * src.element_size()

test/test_mps.py

Lines changed: 6 additions & 0 deletions
@@ -10964,6 +10964,12 @@ def test_nonzero_multi_threading(self):
         t1.start()
         t2.start()
 
+    def test_sliced_view_cast(self):
+        # This used to crash on MacOS Sequoia
+        # See https://github.com/pytorch/pytorch/issues/137800
+        x = torch.rand(16, 16, device='mps', dtype=torch.float16)
+        y = x[:, 0:2].view(torch.float32) + 1
+
     def test_masked_select(self):
         x = torch.randn(3, 4)
         x_mps = x.to("mps")
