模型鏈接:https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
Head(分類頭 / classification head)
構建(build):分類頭各層的定義
# Activation function: the first truthy option wins -- the per-block
# activation_fn, then global_params.relu_fn, falling back to tf.nn.swish.
# NOTE(review): `global_params` below is a bare name while the rest of this
# excerpt uses self._global_params; it must be bound in the enclosing scope
# (not shown here) -- confirm.
self._relu_fn = (self._block_args.activation_fn
or global_params.relu_fn or tf.nn.swish)
# Head part.
# 1x1 "head" convolution. The filter count starts from a base of 1280 and is
# passed through round_filters (helper defined elsewhere; presumably scales
# it by the model's width settings -- confirm).
self._conv_head = utils.Conv2D(
filters=round_filters(1280, self._global_params, self._fix_head_stem),
kernel_size=[1, 1],  # 1x1 kernel
strides=[1, 1],
kernel_initializer=conv_kernel_initializer,
padding='same',
data_format=self._global_params.data_format,
# No bias: in the forward pass this conv feeds straight into self._bn1,
# which is presumably why a bias term is omitted.
use_bias=False)
# Batch normalization applied to the head conv output.
# channel_axis, batch_norm_momentum and batch_norm_epsilon are defined
# outside this excerpt.
self._bn1 = self._batch_norm(
axis=channel_axis,
momentum=batch_norm_momentum,
epsilon=batch_norm_epsilon)
# Global average pooling layer (used in the forward pass when
# global_params.local_pooling is falsy).
self._avg_pooling = tf.keras.layers.GlobalAveragePooling2D(
data_format=self._global_params.data_format)
# Fully connected classifier layer, created only when num_classes is truthy;
# otherwise self._fc is None and no classification layer is applied.
# No activation is passed, so its outputs are raw (linear) logits.
# NOTE(review): this uses the TF1-style tf.layers.Dense while the pooling and
# dropout layers come from tf.keras.layers -- confirm the mix is intentional.
if self._global_params.num_classes:
self._fc = tf.layers.Dense(
self._global_params.num_classes,
kernel_initializer=dense_kernel_initializer)
else:
self._fc = None
# Dropout layer, created only for a positive dropout_rate; otherwise None.
if self._global_params.dropout_rate > 0:
self._dropout = tf.keras.layers.Dropout(self._global_params.dropout_rate)
else:
self._dropout = None
前向傳播(forward pass):
# Classification head: runs only when features_only is falsy, i.e. the
# caller wants more than backbone features (classification follows).
if not features_only:
# Calls final layers and returns logits.
with tf.variable_scope('head'):
# 1x1 conv -> BN (mode set by `training`) -> activation (relu/swish).
outputs = self._relu_fn(
self._bn1(self._conv_head(outputs), training=training))
self.endpoints['head_1x1'] = outputs
# Pooling -> dropout -> fully connected layer.
# Branch 1: explicit pooling with tf.nn.avg_pool when local_pooling is set.
if self._global_params.local_pooling:
# Kernel spans the whole spatial extent, with stride 1 and VALID padding:
# a global average that keeps the spatial dims as size 1.
# NOTE(review): get_shape().as_list() needs a statically known height and
# width; the ksize layout [1, h, w, 1] and tf.nn.avg_pool's default
# data_format assume NHWC -- confirm behavior for channels_first.
shape = outputs.get_shape().as_list()
kernel_size = [
1, shape[self._spatial_dims[0]], shape[self._spatial_dims[1]], 1]
outputs = tf.nn.avg_pool(
outputs, ksize=kernel_size, strides=[1, 1, 1, 1], padding='VALID')
self.endpoints['pooled_features'] = outputs
# Dropout and fc are skipped when only pooled features are requested.
if not pooled_features_only:
# Dropout (only present when dropout_rate > 0).
if self._dropout:
outputs = self._dropout(outputs, training=training)
self.endpoints['global_pool'] = outputs
# Fully connected layer producing per-class logits; the size-1 spatial
# dims left by avg_pool are squeezed away first.
if self._fc:
outputs = tf.squeeze(outputs, self._spatial_dims)
outputs = self._fc(outputs)
self.endpoints['head'] = outputs
# Branch 2: use self._avg_pooling (a GlobalAveragePooling2D layer), which
# drops the spatial dims itself, so no squeeze is needed before fc.
else:
outputs = self._avg_pooling(outputs)
self.endpoints['pooled_features'] = outputs
if not pooled_features_only:
# Dropout (only present when dropout_rate > 0).
if self._dropout:
outputs = self._dropout(outputs, training=training)
self.endpoints['global_pool'] = outputs
# Fully connected layer producing per-class logits.
if self._fc:
outputs = self._fc(outputs)
self.endpoints['head'] = outputs
# Returns the last tensor computed. When features_only and
# pooled_features_only are falsy and self._fc exists, these are the logits.
# Otherwise they are the pooled (possibly dropped-out) features. What
# `outputs` holds when features_only is set depends on code outside this
# excerpt.
return outputs
轉載請註明出處,本文鏈接:https://www.uj5u.com/ruanti/237214.html
標籤:其他
