tf2中metrics的配置及报错categorical_accuracy() missing 2 required positional arguments: ‘y_true‘ and ‘y_pr
问题
其中,categorical_accuracy是function,Precision为类。程序可以正常运行:
当categorical_accuracy修改为categorical_accuracy()时,报错:
TypeError: categorical_accuracy() missing 2 required positional arguments: ‘y_true’ and ‘y_pred’
理解
Tensorflow 2 keras.Model.compile()中metrics参数的说明:
可以看到是传入内建函数的名称,自定义函数和Metric类的实例。若传入func,tensorflow在模型训练时会自动调用func(y_true, y_pred)。
以上错误的原因就是传入了func(),使得程序在modelcompile()步骤时直接调用了categorical_accuracy(),而非在训练时由tf调用并传入y_true及y_pred。
另外,tensorflow2推荐通过子类化tf.keras.metrics.Metric的方法实现自定义metric。具体方法参考:
Metric类
custom metrics
推荐一个教程:
Keras Metrics教程
欢迎讨论,另外,想知道model.compile的具体实现(教程、源码等都可)。