我正在嘗試為 Keras 定義自定義 rmse 損失函式。我撰寫了下面的函式來懲罰資料值小于 0.15 時的損失,否則。
import keras.backend as K
def custom_rmse(y_true, y_pred):
loss = K.square(y_pred - y_true)
for i in range(len(y_true)):
for j in range(y_true.shape[1]):
tmp = float(y_true[i][j])
if (tmp < 0.15):
loss[i][j] *= 0.2
else:
loss[i][j] *=0.8
loss = K.sqrt(K.sum(loss, axis=1))
return loss
但是當我運行模型并嘗試修復它時,我不斷收到此錯誤
/usr/local/lib/python3.7/dist-packages/keras/engine/training.py:853 train_function *
return step_function(self, iterator)
<ipython-input-95-efab27dd2563>:8 custom_rmse *
if (tmp < 0.15):
/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/operators/control_flow.py:1172 if_stmt
_tf_if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/operators/control_flow.py:1219 _tf_if_stmt
cond, aug_body, aug_orelse, strict=True)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/dispatch.py:206 wrapper
return target(*args, **kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/deprecation.py:549 new_func
return func(*args, **kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/control_flow_ops.py:1254 cond
return cond_v2.cond_v2(pred, true_fn, false_fn, name)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/cond_v2.py:88 cond_v2
op_return_value=pred)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py:1007 func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/operators/control_flow.py:1197 aug_body
set_state(init_vars)
/tmp/tmp_3e6lmrw.py:35 set_state
(loss[i][j],) = vars_
TypeError: 'Tensor' object does not support item assignment
我將感謝有關如何解決此問題的建議。謝謝。
uj5u.com熱心網友回復:
If-Else 陳述句通常不是損失函式的方法。大多數情況下,最好采用“軟”方式來實作您的目標。這可以通過(例如)通過以下方式對您的損失值使用陡峭的邏輯函式來完成:
def custom_rmse(y_true, y_pred):
loss = K.square(y_pred - y_true)
logistic_values = tf.sigmoid(1000 * (y_true - 0.15))
loss = logistic_values * loss * 0.8 (1-logistic_values * loss * 0.2)
loss = K.sqrt(K.sum(loss, axis=1))
return loss
此代碼將執行以下操作:
- 我們從您的 y_true 中減去 0.15(您的閾值),以便新值的閾值現在為 0。
- 我們將結果乘以一個很大的數字(我在這里選擇了 1000,數字越大,“軟閾值”就越陡峭。這意味著,所有高于閾值的值現在都是非常高的正值,所有低于您的閾值的值閾值現在將是高負值。
- 我們將 sigmoid 函式應用于結果值。對于所有高正值,此函式將為 1,對于所有高負值(中間有軟過渡),此函式將為 -0。
- 現在,我們可以將我們的損失乘以logistic_values或1-logistic_values,它基本上充當掩碼,分別屏蔽所有 0 或 1 的值。所有未被屏蔽的值現在都可以乘以它們各自的因子 0.8 或 0.2。
轉載請註明出處,本文鏈接:https://www.uj5u.com/qianduan/318295.html
上一篇:numpy陣列的反向整形
