def get_shape(torch_node):
    """Return the output shape of the given PyTorch node.

    The shape is read from the type of the node's first output via
    ``type().sizes()``. Background discussion:
    https://discuss.pytorch.org/t/node-output-shape-from-trace-graph/24351/2

    Args:
        torch_node: A node of a traced PyTorch graph. It must provide an
            ``outputs()`` method yielding values that have a ``type()``
            method.

    Returns:
        Whatever ``sizes()`` reports for the first output's type, or
        ``None`` when the node has no outputs or the output type cannot
        provide sizes.
    """
    # ``output()`` is only valid for nodes with exactly one output, so walk
    # ``outputs()`` instead and use the first entry.
    # TODO: only the first output is reported for multi-output nodes.
    first_output = next(iter(torch_node.outputs()), None)
    if first_output is None:
        return None

    output_type = first_output.type()
    try:
        return output_type.sizes()
    except (AttributeError, RuntimeError):
        # NOTE(review): depending on the PyTorch version, a non-tensor output
        # type is expected to either lack ``sizes`` (AttributeError) or
        # reject the call (RuntimeError) -- confirm against the version in
        # use. Either way there is no shape to report.
        return None