假設我在 Pytorch 中有以下用法,我想遷移到 Tensorflow
done_mask = torch.BoolTensor(dones.values).to(device)
next_state_values[done_mask] = 0.0
uj5u.com熱心網友回復:
是什么dones?假設它是 0/1 張量,您可以將其轉換為 Bool 張量,如下所示:
tf.cast(dones,tf.bool)
但是,如果您想為張量分配值,則不能那樣做。
我推薦的一種方法是乘以 1/0 的矩陣:
next_state_values *= tf.cast(dones!=1,next_state_values.dtype)
我不推薦的另一種方法是使用 tf.tensor_scatter_nd_update,因為它在使用漸變時會出現一些問題。對于您的情況,這將是:
indices = tf.where(dones==1)
next_state_values = tf.tensor_scatter_nd_update(next_state_values ,indices,2*tf.zeros(len(indices)))
轉載請註明出處,本文鏈接:https://www.uj5u.com/houduan/411603.html
標籤:
上一篇:在TPU上訓練RNN時獲得
