解決策。explainer.explain_graphでエラー。
エラーの内容
Traceback (most recent call last): File "chapter14.py", line 146, in <module> feature_mask, edge_mask = explainer.explain_graph(data.x, data.edge_index) File "C:\Users\XYZZZ\AppData\Local\Programs\Python\Python37\lib\site-packages\torch_geometric\explain\algorithm\gnn_explainer.py", line 315, in explain_graph target=self.get_initial_prediction(x, edge_index, **kwargs), File "C:\Users\XYZZZ\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\autograd\grad_mode.py", line 27, in decorate_context return func(*args, **kwargs) File "C:\Users\XYZZZ\AppData\Local\Programs\Python\Python37\lib\site-packages\torch_geometric\explain\algorithm\gnn_explainer.py", line 295, in get_initial_prediction out = self.model(*args, **kwargs) File "C:\Users\XYZZZ\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl return forward_call(*input, **kwargs) TypeError: forward() missing 1 required positional argument: 'batch'
環境
torch-geometric 2.2.0
windows10 python3.7 CPU
このエラーに出会ったのは。。。。
以下のを上記の環境(CPU)で動かそうとしました。 https://github.com/PacktPublishing/Hands-On-Graph-Neural-Networks-Using-Python/tree/main/Chapter14
↑これをみて2.2.0にしましたが。。。。
解決策
python -m pip install "torch-geometric<2.2"
にて、バージョンを直前に下げた
torch-geometric 2.1.0.post1