我想對輸入引數數量不同的不同問題使用相同的神經網路演算法。現在我使用這個功能:
# main function to be called
def call(self,x0,x1=None,x2=None,x3=None,x4=None,x5=None,x6=None,x7=None,x8=None,x9=None):
# define input vector as time-space pairs
if x1 == None:
X = x0
elif x2 == None:
X = tf.concat([x0,x1],1)
elif x3 == None:
X = tf.concat([x0,x1,x2],1)
elif x4 == None:
X = tf.concat([x0,x1,x2,x3],1)
elif x5 == None:
X = tf.concat([x0,x1,x2,x3,x4],1)
elif x6 == None:
X = tf.concat([x0,x1,x2,x3,x4,x5],1)
elif x7 == None:
X = tf.concat([x0,x1,x2,x3,x4,x5,x6],1)
elif x8 == None:
X = tf.concat([x0,x1,x2,x3,x4,x5,x6,x7],1)
elif x9 == None:
X = tf.concat([x0,x1,x2,x3,x4,x5,x6,x7,x8],1)
else:
X = tf.concat([x0,x1,x2,x3,x4,x5,x6,x7,x8,x9],1)
它有效,但有沒有更好(更短/更快)的方法來做到這一點?
uj5u.com熱心網友回復:
您可以使用以下*args語法:
def f(*args):
print([*args])
>>> f("test")
['test']
>>> f("foo", "bar", 42)
['foo', 'bar', 42]
>>> f()
[]
對你來說看起來像這樣
def call(self, *args):
X = tf.concat([*args], 1)
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/511227.html
標籤:Python张量流级联
上一篇:ValueError:層conv2d的輸入0與層不兼容:預期ndim=4,發現ndim=3。收到的完整形狀:[None,30,30]
