I want to predict a target label from a sequence of documents:
['some text here', 'some more text here'] --> label
For now my text sequences have a fixed length; I want to get this working before I try padding to variable lengths. The architecture is:
Input -> HubLayer -> LSTM -> Dense
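Padding is mentioned above only as a later step. As a hypothetical sketch (not part of the original post), variable-length lists of sentences could be truncated or padded to a fixed count before being turned into a string tensor, so that the timestep dimension seen by the LSTM stays constant. The helper name `pad_documents` and the choice of an empty string as the padding token are assumptions made for illustration.

```python
from typing import List

PAD = ""  # hypothetical padding token: an empty string


def pad_documents(docs: List[List[str]], max_len: int, pad: str = PAD) -> List[List[str]]:
    """Truncate or pad every list of sentences to exactly max_len entries."""
    padded = []
    for sentences in docs:
        trimmed = sentences[:max_len]  # drop sentences beyond max_len
        padded.append(trimmed + [pad] * (max_len - len(trimmed)))
    return padded


docs = [
    ["some text here", "some more text here"],
    ["only one sentence"],
]
print(pad_documents(docs, max_len=2))
# [['some text here', 'some more text here'], ['only one sentence', '']]
```

The padded positions still pass through the embedding like any other string; whether and how to mask them is a separate decision that this sketch does not address.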
The following code starts running and then fails:
```python
import tensorflow as tf
import tensorflow_hub as hub

# BATCH_SIZE and y (the label array) are defined elsewhere in the original notebook (not shown).
hub_model = 'https://tfhub.dev/google/nnlm-en-dim50/2'
hub_layer = hub.KerasLayer(hub_model, input_shape=(), dtype='string', trainable=False)

def build_model():
    inputs = tf.keras.Input(shape=(), dtype='string')
    inputs_1d = tf.reshape(inputs, [-1])
    x = hub_layer(inputs_1d)
    x = tf.reshape(x, [BATCH_SIZE,2, 50])
    x = tf.keras.layers.LSTM(32, activation='relu')(x)
    outputs = tf.keras.layers.Dense(y.shape[1], activation='sigmoid')(x)
    return tf.keras.Model(inputs, outputs)
```
I believe the problem is in how the sequence is passed to the Keras hub layer.
Error:
```
2021-11-02 19:34:34.360697: W tensorflow/core/framework/op_kernel.cc:1680] Invalid argument: required broadcastable shapes
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
/tmp/ipykernel_9371/20784351.py in <module>
----> 1 history = model.fit(train, epochs=2, validation_data=test)
/opt/conda/lib/python3.7/site-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
1182 _r=1):
1183 callbacks.on_train_batch_begin(step)
-> 1184 tmp_logs = self.train_function(iterator)
1185 if data_handler.should_sync:
1186 context.async_wait()
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
883
884 with OptionalXlaContext(self._jit_compile):
--> 885 result = self._call(*args, **kwds)
886
887 new_tracing_count = self.experimental_get_tracing_count()
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
915 # In this case we have created variables on the first call, so we run the
916 # defunned version which is guaranteed to never create variables.
--> 917 return self._stateless_fn(*args, **kwds) # pylint: disable=not-callable
918 elif self._stateful_fn is not None:
919 # Release the lock early so that multiple threads can perform the call
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
3038 filtered_flat_args) = self._maybe_define_function(args, kwargs)
3039 return graph_function._call_flat(
-> 3040 filtered_flat_args, captured_inputs=graph_function.captured_inputs) # pylint: disable=protected-access
3041
3042 @property
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
1962 # No tape is watching; skip to running the function.
1963 return self._build_call_outputs(self._inference_function.call(
-> 1964 ctx, args, cancellation_manager=cancellation_manager))
1965 forward_backward = self._select_forward_and_backward_functions(
1966 args,
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in call(self, ctx, args, cancellation_manager)
594 inputs=args,
595 attrs=attrs,
--> 596 ctx=ctx)
597 else:
598 outputs = execute.execute_with_cancellation(
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
58 ctx.ensure_initialized()
59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60 inputs, attrs, num_outputs)
61 except core._NotOkStatusException as e:
62 if name is not None:
InvalidArgumentError: 2 root error(s) found.
(0) Invalid argument: required broadcastable shapes
[[node gradient_tape/binary_crossentropy/logistic_loss/mul/Mul (defined at tmp/ipykernel_9371/484917154.py:1) ]]
(1) Invalid argument: required broadcastable shapes
[[node gradient_tape/binary_crossentropy/logistic_loss/mul/Mul (defined at tmp/ipykernel_9371/484917154.py:1) ]]
[[model_1/keras_layer_1/StatefulPartitionedCall/StatefulPartitionedCall/StatefulPartitionedCall/tokenize/StringSplit/StringSplit/_23]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_4634]
Function call stack:
train_function -> train_function
```
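Answer:
You just need to make sure that you feed both the sentences and the labels during training, and that the input and output shapes are correct. Here is a simple working example in which each input consists of two sentences and one corresponding label: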
```python
import tensorflow_hub as hub
import tensorflow as tf

hub_model = 'https://tfhub.dev/google/nnlm-en-dim50/2'
hub_layer = hub.KerasLayer(hub_model, input_shape=(), dtype='string', trainable=False)

def build_model():
    # Each sample is a pair of sentences: shape (batch, 2), dtype string.
    inputs = tf.keras.Input(shape=(2,), dtype='string')
    # Flatten to a 1-D batch of sentences, which is what hub_layer expects.
    inputs_1d = tf.reshape(inputs, [-1])
    x = hub_layer(inputs_1d)  # (batch * 2, 50)
    # Regroup the embeddings into (batch, timesteps=2, features=50) for the LSTM.
    x = tf.reshape(x, [BATCH_SIZE, 2, 50])
    x = tf.keras.layers.LSTM(32, activation='relu')(x)
    outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)
    return tf.keras.Model(inputs, outputs)

BATCH_SIZE = 2
sentences = tf.constant([
    [
        "Improve the physical fitness of your goldfish by getting him a bicycle",
        "You are unsure whether or not to trust him but very thankful that you wore a turtle neck"],
    ["Not all people who wander are lost",
     "There is a reason that roses have thorns"],
    ["Charles ate the french fries knowing they would be his last meal",
     "He hated that he loved what she hated about hate"],
    ["Charles ate the french fries knowing they would be his last meal",
     "He hated that he loved what she hated about hate"],
    ["Charles ate the french fries knowing they would be his last meal",
     "He hated that he loved what she hated about hate"],
    ["Charles ate the french fries knowing they would be his last meal",
     "He hated that he loved what she hated about hate"]
])
# Random 0/1 labels, one per pair of sentences, just for demonstration.
labels = tf.random.uniform((6, ), minval=0, maxval=2, dtype=tf.dtypes.int32)

model = build_model()
model.compile(optimizer='adam',
              loss=tf.losses.BinaryCrossentropy())

train_dataset = tf.data.Dataset.from_tensor_slices(
    (sentences, labels)).shuffle(
    sentences.shape[0]).batch(
    BATCH_SIZE)
model.fit(x=train_dataset, epochs=2)
```
```
Epoch 1/2
3/3 [==============================] - 1s 8ms/step - loss: 0.6965
Epoch 2/2
3/3 [==============================] - 0s 6ms/step - loss: 0.6916
<keras.callbacks.History at 0x7fe851c4a090>
```
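The example hard-codes `BATCH_SIZE` in the second `tf.reshape`. That works here because the 6 samples split evenly into batches of 2. If the sample count were not a multiple of `BATCH_SIZE`, `.batch()` would emit a smaller final batch (unless `drop_remainder=True` is passed), and a reshape with a fixed first dimension would fail on it. Writing `-1` for that dimension, as in `tf.reshape(x, [-1, 2, 50])`, lets the batch size be inferred. The sketch below does not run the TensorFlow model; it mirrors the reshape in NumPy, which uses the same row-major semantics, and the helper `to_lstm_input` is hypothetical.

```python
import numpy as np

TIMESTEPS, EMB_DIM = 2, 50
BATCH_SIZE = 2


def to_lstm_input(embeddings: np.ndarray, hard_coded: bool) -> np.ndarray:
    """Reshape (num_sentences, EMB_DIM) embeddings into (batch, TIMESTEPS, EMB_DIM)."""
    batch_dim = BATCH_SIZE if hard_coded else -1  # -1 lets the batch size be inferred
    return embeddings.reshape(batch_dim, TIMESTEPS, EMB_DIM)


# A final, partial batch holding 1 sample (2 sentences) instead of BATCH_SIZE samples.
last_batch = np.zeros((1 * TIMESTEPS, EMB_DIM), dtype=np.float32)

print(to_lstm_input(last_batch, hard_coded=False).shape)  # (1, 2, 50)
try:
    to_lstm_input(last_batch, hard_coded=True)  # 100 values cannot fill (2, 2, 50)
except ValueError as err:
    print("hard-coded reshape failed:", err)
```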
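Your LSTM expects each sample to have shape (timesteps, features), i.e. the layer input is (batch_size, timesteps, features), so you need to reshape x into that shape; here it is (BATCH_SIZE, 2, 50). Your first reshape is also necessary, because hub_layer takes a batch of sentences as a 1-D string tensor.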
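The two reshapes above only work if flattening and regrouping keep every sentence attached to its own sample. A minimal NumPy sketch of that round trip: `fake_embed` is a hypothetical stand-in for `hub_layer` that fills each 50-dimensional row with the sentence's position in the flattened batch, so it is easy to see where each embedding lands after the second reshape.

```python
import numpy as np

TIMESTEPS, EMB_DIM = 2, 50


def fake_embed(flat_sentences: np.ndarray) -> np.ndarray:
    """Stand-in for hub_layer: maps N strings to an (N, EMB_DIM) array.
    Row i is filled with the value i so its final position can be tracked."""
    n = flat_sentences.shape[0]
    return np.repeat(np.arange(n, dtype=np.float32)[:, None], EMB_DIM, axis=1)


batch = np.array([["s00", "s01"], ["s10", "s11"]])  # (batch=2, timesteps=2)
flat = batch.reshape(-1)                             # (4,), row-major: s00, s01, s10, s11
emb = fake_embed(flat)                               # (4, 50)
seq = emb.reshape(-1, TIMESTEPS, EMB_DIM)            # (2, 2, 50), the LSTM input

print(flat.tolist())  # ['s00', 's01', 's10', 's11']
print(seq.shape)      # (2, 2, 50)
print(seq[1, 0, 0])   # 2.0 (sentence s10 is sample 1, timestep 0)
```

Because both reshapes use row-major order, sentence j of sample i sits at flat index i * TIMESTEPS + j and returns to position [i, j] afterwards.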
