
AttributeError: 'torch._C.Node' object has no attribute 'ival' #1237

Closed
MHGL opened this issue Jul 5, 2021 · 16 comments
Labels
bug: Unexpected behaviour that should be corrected (type) · triaged: Reviewed and examined, release has been assigned if applicable (status)

Comments

@MHGL commented Jul 5, 2021

🐞Describe the bug

I get this error when converting a model to Core ML after torch.jit.freeze.

Trace

Traceback (most recent call last):
  File "mini_code.py", line 15, in <module>
    model = ct.convert(
  File "/home/liyang/.local/lib/python3.8/site-packages/coremltools/converters/_converters_entry.py", line 175, in convert
    mlmodel = mil_convert(
  File "/home/liyang/.local/lib/python3.8/site-packages/coremltools/converters/mil/converter.py", line 128, in mil_convert
    proto = mil_convert_to_proto(model, convert_from, convert_to,
  File "/home/liyang/.local/lib/python3.8/site-packages/coremltools/converters/mil/converter.py", line 171, in mil_convert_to_proto
    prog = frontend_converter(model, **kwargs)
  File "/home/liyang/.local/lib/python3.8/site-packages/coremltools/converters/mil/converter.py", line 85, in __call__
    return load(*args, **kwargs)
  File "/home/liyang/.local/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 70, in load
    converter = TorchConverter(torchscript, inputs, outputs, cut_at_symbols)
  File "/home/liyang/.local/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 146, in __init__
    self.graph = InternalTorchIRGraph(
  File "/home/liyang/.local/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/internal_graph.py", line 241, in __init__
    new_node = InternalTorchIRNode(raw_node, parent=self)
  File "/home/liyang/.local/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/internal_graph.py", line 140, in __init__
    self.attr = {
  File "/home/liyang/.local/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/internal_graph.py", line 141, in <dictcomp>
    name: getattr(node, node.kindOf(name))(name)
AttributeError: 'torch._C.Node' object has no attribute 'ival'

To Reproduce

import torch
import coremltools as ct

# init conv module
torch_model = torch.nn.Conv2d(3, 3, 1, 1)

# Trace with random data
example_input = torch.rand(1, 3, 224, 224) 
trace_model = torch.jit.trace(torch_model, example_input).eval()
freeze_model = torch.jit.freeze(trace_model)

# Convert to Core ML using the Unified Conversion API
model = ct.convert(
    freeze_model,
    inputs=[ct.ImageType(name="input", shape=example_input.shape)], 
)

System environment (please complete the following information):

  • coremltools version (e.g., 3.0b5): 4.1
  • OS (e.g., MacOS, Linux): Ubuntu 20.04 LTS
  • How you install python (anaconda, virtualenv, system): miniconda
  • python version (e.g. 3.7): 3.8.5
  • any other relevant information:
    • pytorch version: 1.9.0
    • gpu: GeForce GTX 1650
    • driver: Driver Version: 460.80
    • CUDA: CUDA Version: 11.2
@MHGL added the bug label Jul 5, 2021
@TobyRoseman added the triaged label Jul 7, 2021
@SaulAryehKohn

I'm getting the same error on coremltools 5.1, python 3.8.5.

@SaulAryehKohn

FWIW, this is the structure that is not handled in /mil/frontend/torch/internal_graph.py: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#ivalue
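
A minimal diagnostic sketch (assuming PyTorch < 1.12) that reproduces the lookup coremltools performs and shows which attribute kinds a frozen graph carries; the toy Conv2d mirrors the original report:

import torch

# Freeze a trivial traced model, as in the original report.
traced = torch.jit.trace(torch.nn.Conv2d(3, 3, 1, 1), torch.rand(1, 3, 224, 224)).eval()
frozen = torch.jit.freeze(traced)

# coremltools reads each attribute via getattr(node, node.kindOf(name))(name).
# Frozen graphs can contain constants whose kind is "ival"; before PyTorch 1.12,
# torch._C.Node exposed no "ival" accessor, hence the AttributeError.
for node in frozen.graph.nodes():
    for name in node.attributeNames():
        kind = node.kindOf(name)
        print(node.kind(), name, kind, hasattr(node, kind))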

@ludwigfriborg

Are there any updates on this? I'm getting the same error with torch 1.10 and coremltools 5.1.

@ludwigfriborg

I'm able to replicate the same error using the mobile optimizer as well:

import torch
from torch.utils.mobile_optimizer import MobileOptimizerType

# `net` and `dummy_input` are the model and example input defined elsewhere.
trace = torch.jit.trace(net, dummy_input).eval()
# trace = torch.jit.freeze(trace)
trace = torch.utils.mobile_optimizer.optimize_for_mobile(
    trace,
    {
        MobileOptimizerType.CONV_BN_FUSION,
        MobileOptimizerType.INSERT_FOLD_PREPACK_OPS,
        MobileOptimizerType.REMOVE_DROPOUT,
    },
)

Running it all on python 3.8.10, torch 1.9 and coremltools 5.1.

@borijang

I've got the same problem on python 3.8, torch 1.9 and coremltools 5.1, when trying to convert a quantized and optimized ResNet18 to CoreML. On the other hand, I am able to convert the non-quantized model without any issues.

@undefdev

I'm having the same problem when trying to convert a model optimized for mobile. Unoptimized conversion works fine, but is not really an option due to iOS memory limitations.

@jral3s commented Jun 8, 2022

Was this issue ever resolved? I am running into it while trying to convert a model that is optimized for mobile (which converts fine when it is not optimized).

@aseemw (Collaborator) commented Jun 8, 2022

@jral3s, can you please expand on what you mean by "model that is optimized for mobile"? Which torch APIs in particular are you using to "optimize" the model, and what kinds of optimizations does it apply?

@jral3s commented Jun 8, 2022

Sure. I am currently trying to convert a DeepLabV3 model to Core ML on Google Colab, and if I just trace it as in the code below, I have no issues:

import torch
import coremltools as ct
from torch.utils.mobile_optimizer import optimize_for_mobile

class ModelWrapper(torch.nn.Module):
  # The output of the model is a dictionary, so it must be wrapped to unpack it to a tensor
  
  def __init__(self, model):
    super(ModelWrapper, self).__init__()
    self.model = model
  
  def forward(self, x):
    return self.model(x)["out"]

model = torch.hub.load('pytorch/vision:v0.11.0', 'deeplabv3_resnet50', pretrained=True).eval()
wrapped_model = ModelWrapper(model)
traced_model = torch.jit.trace(wrapped_model, torch.rand((1, 3, 640, 640)))
mlmodel = ct.convert(traced_model,  inputs=[ct.ImageType(name="input", shape=(1, 3, 640, 640))])

But it is too large and slow for my application, so I am trying to use torch's mobile optimizer like in the code below, which returns the following error:

import torch
import coremltools as ct
from torch.utils.mobile_optimizer import optimize_for_mobile

class ModelWrapper(torch.nn.Module):
  # The output of a model is a dictionary, so it must be wrapped to unpack it
  # to a tensor
  def __init__(self, model):
    super(ModelWrapper, self).__init__()
    self.model = model
  
  def forward(self, x):
    return self.model(x)["out"]

model = torch.hub.load('pytorch/vision:v0.11.0', 'deeplabv3_resnet50', pretrained=True).eval()
wrapped_model = ModelWrapper(model)
scripted_model = torch.jit.script(wrapped_model)
optimized_model = optimize_for_mobile(scripted_model)
mlmodel = ct.convert(optimized_model,  inputs=[ct.ImageType(name="input", shape=(1, 3, 640, 640))])
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-22-e5182728cd74> in <module>()
     17 scripted_model = torch.jit.script(wrapped_model)
     18 optimized_model = optimize_for_mobile(scripted_model)
---> 19 mlmodel = ct.convert(optimized_model,  inputs=[ct.ImageType(name="input", shape=(1, 3, 640, 640))])

9 frames
/usr/local/lib/python3.7/dist-packages/coremltools/converters/_converters_entry.py in convert(model, source, inputs, outputs, classifier_config, minimum_deployment_target, convert_to, compute_precision, skip_model_load, compute_units, useCPUOnly, package_dir, debug)
    361         compute_units=compute_units,
    362         package_dir=package_dir,
--> 363         debug=debug,
    364     )
    365 

/usr/local/lib/python3.7/dist-packages/coremltools/converters/mil/converter.py in mil_convert(model, convert_from, convert_to, compute_units, **kwargs)
    181         See `coremltools.converters.convert`
    182     """
--> 183     return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
    184 
    185 

/usr/local/lib/python3.7/dist-packages/coremltools/converters/mil/converter.py in _mil_convert(model, convert_from, convert_to, registry, modelClass, compute_units, **kwargs)
    213                             convert_to,
    214                             registry,
--> 215                             **kwargs
    216                          )
    217 

/usr/local/lib/python3.7/dist-packages/coremltools/converters/mil/converter.py in mil_convert_to_proto(model, convert_from, convert_to, converter_registry, **kwargs)
    271     frontend_converter = frontend_converter_type()
    272 
--> 273     prog = frontend_converter(model, **kwargs)
    274 
    275     if convert_to.lower() != "neuralnetwork":

/usr/local/lib/python3.7/dist-packages/coremltools/converters/mil/converter.py in __call__(self, *args, **kwargs)
    103         from .frontend.torch import load
    104 
--> 105         return load(*args, **kwargs)
    106 
    107 

/usr/local/lib/python3.7/dist-packages/coremltools/converters/mil/frontend/torch/load.py in load(model_spec, debug, **kwargs)
     44     outputs = kwargs.get("outputs", None)
     45     cut_at_symbols = kwargs.get("cut_at_symbols", None)
---> 46     converter = TorchConverter(torchscript, inputs, outputs, cut_at_symbols)
     47     return _perform_torch_convert(converter, debug)
     48 

/usr/local/lib/python3.7/dist-packages/coremltools/converters/mil/frontend/torch/converter.py in __init__(self, torchscript, inputs, outputs, cut_at_symbols)
    157         self.params_dict = params_dict
    158         self.graph = InternalTorchIRGraph(
--> 159             raw_graph, params_dict, self.inputs, cut_at_symbols
    160         )
    161         passes = [

/usr/local/lib/python3.7/dist-packages/coremltools/converters/mil/frontend/torch/internal_graph.py in __init__(self, raw_graph, params_dict, input_values, cut_at_symbols, nodes, params, inputs, outputs)
    252             # Add nodes
    253             for raw_node in raw_graph.nodes():
--> 254                 new_node = InternalTorchIRNode(raw_node, parent=self)
    255                 if new_node.name == new_node.kind:
    256                     new_node.name = _find_new_name(new_node.name, node_names)

/usr/local/lib/python3.7/dist-packages/coremltools/converters/mil/frontend/torch/internal_graph.py in __init__(self, node, parent, attr, inputs, outputs, kind, blocks)
    149             self.attr = {
    150                 name: getattr(node, node.kindOf(name))(name)
--> 151                 for name in node.attributeNames()
    152             }
    153             if "value" not in self.attr:

/usr/local/lib/python3.7/dist-packages/coremltools/converters/mil/frontend/torch/internal_graph.py in <dictcomp>(.0)
    149             self.attr = {
    150                 name: getattr(node, node.kindOf(name))(name)
--> 151                 for name in node.attributeNames()
    152             }
    153             if "value" not in self.attr:

AttributeError: 'torch._C.Node' object has no attribute 'ival'

@aseemw (Collaborator) commented Jun 8, 2022

The optimizations specified here are specific to PyTorch Mobile and are not necessarily relevant, applicable, or even supported for Core ML.
The relevant Core ML optimizations are performed during conversion, and later at model load time in the Core ML framework.
So if you are deploying with Core ML, you should not need to call torch's mobile optimizer; those passes are meant for a different deployment environment.
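
For illustration, a minimal sketch of that Core ML-native route, assuming coremltools >= 5 (where convert_to="mlprogram", compute_precision, and compute_units are available); the toy Conv2d stands in for a real model:

import torch
import coremltools as ct

# Plain trace, no mobile optimizer.
traced_model = torch.jit.trace(torch.nn.Conv2d(3, 3, 1, 1).eval(), torch.rand(1, 3, 224, 224))

# coremltools applies its own graph passes (op fusion, dead-code
# elimination, precision casting, ...) during conversion.
mlmodel = ct.convert(
    traced_model,
    inputs=[ct.TensorType(name="input", shape=(1, 3, 224, 224))],
    convert_to="mlprogram",                  # ML program backend
    compute_precision=ct.precision.FLOAT16,  # weight/activation precision
    compute_units=ct.ComputeUnit.ALL,        # let Core ML pick CPU/GPU/ANE
)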

But it is too large and slow for my application

Can you please expand on that? How do the optimizations help in reducing the size of the model? What memory and latency targets are you planning to hit, and what are you currently getting with the Core ML model deployment?
It might be useful to file a feedback request with the particular model you are using, so that we can verify that it's being optimized correctly by coremltools and Core ML.

@jral3s commented Jun 8, 2022

We are trying to run the torch mobile optimizer because when we export the model purely using the trace (exactly as in the first code block), we run into the following error when trying to run it configured for the Neural Engine in Xcode.

---------------------------------------------------------------------------------------------
Error: Convolution configuration cannot fit in KMEM (Given: 6881280b, Max: 65536b)
---------------------------------------------------------------------------------------------

The workaround was to configure it to run with CPU and GPU only, but the resulting model takes 800 ms to run in Xcode (whereas the Core ML model for DeepLabV3 available from Apple works fine on the Neural Engine and runs with a 30 ms latency).
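
That CPU+GPU restriction can be expressed at load time with coremltools' compute_units flag (coremltools >= 5); the model path here is a hypothetical placeholder:

import coremltools as ct

# Restrict execution to CPU and GPU so Core ML never schedules
# the oversized convolution onto the Neural Engine.
mlmodel = ct.models.MLModel(
    "DeepLabV3.mlpackage",  # hypothetical path to the converted model
    compute_units=ct.ComputeUnit.CPU_AND_GPU,
)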

How can I file a feedback request for this model?

@aseemw (Collaborator) commented Jun 8, 2022

How can I file a feedback request for this model?

https://developer.apple.com/bug-reporting/

@H4dr1en commented Jun 21, 2022

In my case, trace = mobile_optimizer.optimize_for_mobile(traced_model) works, but coremltools.convert(traced_model) fails with the same error described above: AttributeError: 'torch._C.Node' object has no attribute 'ival'

python 3.7, pytorch 1.11 and coremltools 6.0b1

@alealv (Contributor) commented Aug 5, 2022

Same error here with: python 3.9, pytorch 1.11 and coremltools master branch

script_mdl = torch.jit.script(generator, **kwargs)
script_mdl = torch.jit.freeze(
    script_mdl,
    preserved_attrs=["reset", "get_sample_rate"] if stream else [],
)
coreml_mdl = ct.convert(script_mdl, inputs=[ct.TensorType(shape=(1, 80, 100))], debug=True)

  File "/home/aalvarez/.virtualenvs/tts-train-XZ1ykfT_-py3.9/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/internal_graph.py", line 150, in <dictcomp>
    name: getattr(node, node.kindOf(name))(name)
AttributeError: 'torch._C.Node' object has no attribute 'ival'

@alealv (Contributor) commented Aug 5, 2022

I think the problem is solved with PyTorch 1.12 because of this definition.

However, I'm now getting a different error about conv1d not being defined, so I can't confirm it.
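
A quick way to check which behavior an installed torch has, assuming the fix is the ival accessor that PyTorch 1.12 added to torch._C.Node:

import torch

# PyTorch >= 1.12 exposes the "ival" accessor that coremltools'
# internal_graph.py looks up via getattr(node, node.kindOf(name)).
print(torch.__version__, hasattr(torch._C.Node, "ival"))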

@TobyRoseman (Collaborator)

@alealv - you are correct. This has been fixed in PyTorch. The original code now runs without error. Thanks for the information.
