import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
class KerasSupervisedModelWrapper(keras.Model):
    """Base ``keras.Model`` that stores a batch size and a shape-aware summary.

    Subclasses are expected to provide ``call``; ``summary`` invokes it.
    """

    def __init__(self, batch_size, **kwargs):
        """Store ``batch_size`` and initialise the underlying ``keras.Model``.

        Args:
            batch_size: Value kept on the instance as ``self.batch_size``.
            **kwargs: Forwarded to ``keras.Model.__init__`` (e.g. ``name``).
        """
        # Forward kwargs so options such as ``name`` are not silently dropped.
        super().__init__(**kwargs)
        self.batch_size = batch_size

    def summary(self, input_shape):  # temporary fix for a bug
        """Print a summary of this model for inputs of ``input_shape``.

        Args:
            input_shape: Per-sample input shape passed to ``layers.Input``,
                e.g. ``[28, 28, 1]``.

        Returns:
            Whatever ``keras.Model.summary`` returns for the temporary model.
        """
        # Run ``call`` on a symbolic input and wrap the result in a functional
        # model; the summary is then produced from that temporary model.
        x = layers.Input(shape=input_shape)
        model = keras.Model(inputs=[x], outputs=self.call(x))
        return model.summary()
class ExampleModel(KerasSupervisedModelWrapper):
    """Minimal wrapper subclass made of a single 3x3 convolution layer."""

    def __init__(self, batch_size):
        """Create the model and its only layer, ``conv1``."""
        super().__init__(batch_size)
        # 32 filters, 3x3 kernel, ReLU activation.
        self.conv1 = layers.Conv2D(
            filters=32,
            kernel_size=(3, 3),
            activation='relu',
        )

    def call(self, x):
        """Apply ``conv1`` to ``x`` and return the result."""
        return self.conv1(x)
# Instantiate the ExampleModel subclass with batch_size=15 and print its
# summary for 28x28 single-channel inputs. (The original referenced an
# undefined name, ``MyModel``.)
model = ExampleModel(15)
model.summary([28, 28, 1])
輸出:
Model: "model_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) [(None, 28, 28, 1)] 0
conv2d_2 (Conv2D) (None, 26, 26, 32) 320
=================================================================
Total params: 320
Trainable params: 320
Non-trainable params: 0
_________________________________________________________________
我正在為 keras 模型撰寫一個包裝器,以預先定義一些有用的方法和變數,如上所示。
另外,我想修改這個包裝器,讓它能像 keras.Sequential 一樣接收一組層來組成模型。
因此,我添加了一個 Sequential 方法,用來為模型指派新的 call 方法,如下所示。
# Question's attempt: add a ``Sequential`` factory to the wrapper class
# (the methods shown earlier are elided by the placeholder line below).
class KerasSupervisedModelWrapper(keras.Model):
...(continue)...
@staticmethod
def Sequential(layers, **kwargs):
# Build a wrapper instance from ``kwargs`` and a keras.Sequential from ``layers``.
model = KerasSupervisedModelWrapper(**kwargs)
pipe = keras.Sequential(layers)
# NOTE: this function is assigned to the *instance* below, so Python does not
# bind ``self`` to it. ``self.call(x)`` in ``summary`` therefore passes ``x``
# as ``self``, which raises
# "call() missing 1 required positional argument: 'x'".
def call(self, x):
return pipe(x)
model.call = call
return model
但是,它似乎沒有按我的預期運作,反而顯示了下面的錯誤訊息。
# Build the wrapper from a one-layer stack, then request its summary for
# 28x28 single-channel inputs.
conv_stack = [layers.Conv2D(32, kernel_size=(3, 3), activation="relu")]
model = KerasSupervisedModelWrapper.Sequential(conv_stack, batch_size=15)
model.summary((28, 28, 1))
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/tmp/ipykernel_91471/2826773946.py in <module>
1 # model.build((None, 28, 28, 1))
2 # model.compile('adam', loss=keras.losses.SparseCategoricalCrossentropy(), metrics=['accuracy'])
----> 3 model.summary((28, 28, 1))
/tmp/ipykernel_91471/3696340317.py in summary(self, input_shape)
10 def summary(self, input_shape): # temporary fix for a bug
11 x = layers.Input(shape=input_shape)
---> 12 model = keras.Model(inputs=[x], outputs=self.call(x))
13 return model.summary()
14
TypeError: call() missing 1 required positional argument: 'x'
我該如何修改包裝器,才能在保留其他屬性的同時得到 keras.Sequential 模型?
uj5u.com熱心網友回復:
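你可以嘗試這樣做:直接把 pipe 指派給 model.call。指派給實例的普通函式不會自動綁定 self,所以原本的寫法在呼叫 self.call(x) 時會少一個引數;而 pipe 本身是可呼叫的物件,不需要 self。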
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
class KerasSupervisedModelWrapper(keras.Model):
    """Base ``keras.Model`` with a stored batch size and a ``Sequential`` factory.

    Subclasses are expected to provide ``call``; instances built through
    ``Sequential`` get ``call`` assigned on the instance instead.
    """

    def __init__(self, batch_size, **kwargs):
        """Store ``batch_size`` and initialise the underlying ``keras.Model``.

        Args:
            batch_size: Value kept on the instance as ``self.batch_size``.
            **kwargs: Forwarded to ``keras.Model.__init__`` (e.g. ``name``).
        """
        # Forward kwargs so options such as ``name`` are not silently dropped.
        super().__init__(**kwargs)
        self.batch_size = batch_size

    def summary(self, input_shape):  # temporary fix for a bug
        """Print a summary of this model for inputs of ``input_shape``.

        Args:
            input_shape: Per-sample input shape passed to ``layers.Input``,
                e.g. ``(28, 28, 1)``.

        Returns:
            Whatever ``keras.Model.summary`` returns for the temporary model.
        """
        # Run ``call`` on a symbolic input and wrap the result in a functional
        # model; the summary is then produced from that temporary model.
        x = layers.Input(shape=input_shape)
        model = keras.Model(inputs=[x], outputs=self.call(x))
        return model.summary()

    @staticmethod
    def Sequential(layers, **kwargs):
        """Build a wrapper whose forward pass is a ``keras.Sequential`` stack.

        Args:
            layers: Layers handed to ``keras.Sequential``. Inside this
                function the name shadows the imported ``layers`` module.
            **kwargs: Constructor arguments for the wrapper; must include
                ``batch_size``.

        Returns:
            A ``KerasSupervisedModelWrapper`` whose ``call`` is the stack.
        """
        model = KerasSupervisedModelWrapper(**kwargs)
        pipe = keras.Sequential(layers)
        # Assign the callable Sequential object itself: unlike a plain
        # function stored on an instance, it needs no ``self`` argument.
        model.call = pipe
        return model
class ExampleModel(KerasSupervisedModelWrapper):
    """Hand-written wrapper subclass with one 3x3 convolution layer."""

    def __init__(self, batch_size):
        """Initialise the wrapper, then create the ``conv1`` layer."""
        super().__init__(batch_size)
        # 32 filters, 3x3 kernel, ReLU activation.
        self.conv1 = layers.Conv2D(
            filters=32,
            kernel_size=(3, 3),
            activation='relu',
        )

    def call(self, x):
        """Return ``conv1`` applied to ``x``."""
        return self.conv1(x)
# 1) Subclassed model: summary for 28x28 single-channel inputs.
model = ExampleModel(15)
model.summary([28, 28, 1])

# 2) Same single-layer architecture built through the Sequential factory.
conv_stack = [layers.Conv2D(32, kernel_size=(3, 3), activation="relu")]
model = KerasSupervisedModelWrapper.Sequential(conv_stack, batch_size=15)
model.summary((28, 28, 1))

# 3) Forward one random sample through the model and print the output shape.
sample = tf.random.normal((1, 28, 28, 1))
print(model(sample).shape)
Model: "model_9"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_14 (InputLayer) [(None, 28, 28, 1)] 0
conv2d_17 (Conv2D) (None, 26, 26, 32) 320
=================================================================
Total params: 320
Trainable params: 320
Non-trainable params: 0
_________________________________________________________________
Model: "model_10"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_15 (InputLayer) [(None, 28, 28, 1)] 0
sequential_8 (Sequential) (None, 26, 26, 32) 320
=================================================================
Total params: 320
Trainable params: 320
Non-trainable params: 0
_________________________________________________________________
(1, 26, 26, 32)
轉載請註明出處,本文鏈接:https://www.uj5u.com/qiye/429745.html
上一篇:機器學習準確度太低
