keep-loving-pythonのブログ

Pythonを愛し続けたいです(Pythonが流行っている限りですが。。。)

解決策。explainer.explain_graphでエラー。

エラーの内容

Traceback (most recent call last):
  File "chapter14.py", line 146, in <module>
    feature_mask, edge_mask = explainer.explain_graph(data.x, data.edge_index)
  File "C:\Users\XYZZZ\AppData\Local\Programs\Python\Python37\lib\site-packages\torch_geometric\explain\algorithm\gnn_explainer.py", line 315, in explain_graph
    target=self.get_initial_prediction(x, edge_index, **kwargs),
  File "C:\Users\XYZZZ\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\autograd\grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "C:\Users\XYZZZ\AppData\Local\Programs\Python\Python37\lib\site-packages\torch_geometric\explain\algorithm\gnn_explainer.py", line 295, in get_initial_prediction
    out = self.model(*args, **kwargs)
  File "C:\Users\XYZZZ\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
TypeError: forward() missing 1 required positional argument: 'batch'

環境

torch-geometric                   2.2.0

windows10 python3.7 CPU

このエラーに出会ったのは。。。。

以下のを上記の環境(CPU)で動かそうとしました。 https://github.com/PacktPublishing/Hands-On-Graph-Neural-Networks-Using-Python/tree/main/Chapter14

↑これをみて2.2.0にしましたが。。。。

解決策

python -m pip install "torch-geometric<2.2"

にて、バージョンを直前に下げた

torch-geometric                   2.1.0.post1