我想使用梯度做一些線性代數(例如 tf.matmul)。默認情況下,梯度以張量串列的形式回傳,其中張量可能具有不同的形狀。我的解決方案是將漸變重塑為單個向量。這在渴望模式下作業,但現在我想使用 tf.function 編譯我的代碼。似乎沒有辦法撰寫一個可以在圖形模式(tf.function)中“拉平”梯度的函式。
grad = [tf.ones((2,10)), tf.ones((3,))] # an example of what a gradient from tape.gradient can look like
# this works for flattening the gradient in eager mode only
def flatten_grad(grad):
return tf.concat([tf.reshape(grad[i], tf.math.reduce_prod(tf.shape(grad[i]))) for i in range(len(grad))], 0)
我嘗試像這樣轉換它,但它也不適用于 tf.function 。
@tf.function
def flatten_grad1(grad):
temp = [None]*len(grad)
for i in tf.range(len(grad)):
i = tf.cast(i, tf.int32)
temp[i] = tf.reshape(grad[i], tf.math.reduce_prod(tf.shape(grad[i])))
return tf.concat(temp, 0)
我嘗試了 TensorArrays,但它也不起作用。
@tf.function
def flatten_grad2(grad):
temp = tf.TensorArray(tf.float32, size=len(grad), infer_shape=False)
for i in tf.range(len(grad)):
i = tf.cast(i, tf.int32)
temp = temp.write(i, tf.reshape(grad[i], tf.math.reduce_prod(tf.shape(grad[i]))))
return temp.concat()
uj5u.com熱心網友回復:
也許您可以嘗試直接迭代您list的張量,而不是通過索引獲取單個張量:
import tensorflow as tf
grad = [tf.ones((2,10)), tf.ones((3,))] # an example of what a gradient from tape.gradient can look like
@tf.function
def flatten_grad1(grad):
temp = [None]*len(grad)
for i, g in enumerate(grad):
temp[i] = tf.reshape(g, (tf.math.reduce_prod(tf.shape(g)), ))
return tf.concat(temp, axis=0)
print(flatten_grad1(grad))
tf.Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(23,), dtype=float32)
與tf.TensorArray:
@tf.function
def flatten_grad2(grad):
temp = tf.TensorArray(tf.float32, size=0, dynamic_size=True, infer_shape=False)
for g in grad:
temp = temp.write(temp.size(), tf.reshape(g, (tf.math.reduce_prod(tf.shape(g)), )))
return temp.concat()
print(flatten_grad2(grad))
uj5u.com熱心網友回復:
嗨,我認為最大的問題是不鼓勵在 python 計算回圈中的回圈。
這是一個如何使用 tf 函式對梯度變數進行展平的示例,看起來有點奇怪,通常應該是與批次一致的形狀
import tensorflow as tf
import numpy as np
@tf.function
def flatten(arr):
dim = tf.math.reduce_prod(tf.shape(arr)[1:])
return tf.reshape(arr, [-1, dim])
grad = tf.Variable(np.random.randn(100, 10, 10, 3))
flatten_grad = flatten(grad)
轉載請註明出處,本文鏈接:https://www.uj5u.com/net/415558.html
標籤:
