分別用 PyTorch 和 Keras 實現了同一個以 LSTM 計算文本相似度的模型。
Keras 在測試集上的準確率為 88%,PyTorch 最好的一次為 83%。但 Keras 訓練特別慢,是因為兩者的 LSTM 模塊有差別嗎?
keras模型
# Build the model: one shared LSTM sentence encoder is applied to both inputs,
# and the two encodings are concatenated and scored by a small dense head.
input_left = Input(shape=(Max_Seq_Num, ))
input_right = Input(shape=(Max_Seq_Num, ))
# Word embedding: initialised from the pretrained vectors and kept frozen.
seq_input = Input(shape=(Max_Seq_Num, ))
embedding_layer = Embedding(Words_vec.words_num, Embed_Size, input_length=Max_Seq_Num,
                            weights=[Words_vec.vectors],
                            trainable=False, name='embed_layer')(seq_input)
# LSTM encoder; only the final output vector (128 units) is returned.
# NOTE(review): a non-zero recurrent_dropout prevents tf.keras from using the
# fused cuDNN kernel, which is a likely cause of slow training on GPU --
# confirm against the Keras version and hardware in use before changing it.
lstm_layer = LSTM(128, dropout=0.1, recurrent_dropout=0.1, name='lstm')(embedding_layer)
# Wrapping the encoder in its own Model lets both branches share the same weights.
model_encode = Model(seq_input, lstm_layer, name='encode_model')
model_encode.summary()
left_encode = model_encode(input_left)
right_encode = model_encode(input_right)
# Classification head: concat -> dropout -> BN -> dense(relu) -> dropout -> BN -> sigmoid.
merge_vec = concatenate([left_encode, right_encode])
drop1 = Dropout(0.1, name='drop1')(merge_vec)
BN = BatchNormalization(name='bn1')(drop1)
dence_1 = Dense(128, activation='relu', name='dence1')(BN)
drop2 = Dropout(0.1, name='drop2')(dence_1)
BN2 = BatchNormalization(name='bn2')(drop2)
predictions = Dense(1, activation='sigmoid', name='pre')(BN2)
# Use the current `inputs=` keyword; the legacy `input=` spelling is rejected
# by tf.keras and was inconsistent with the `outputs=` keyword already used here.
model = Model(inputs=[input_left, input_right], outputs=predictions)
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc'])
pytorch模型
class bilstm_att(nn.Module):
    """Siamese LSTM encoder with an MLP head that scores a pair of sequences.

    Both token-id inputs go through the same frozen embedding and the same
    unidirectional LSTM (hidden size 128). The two final-step encodings are
    concatenated (256 features) and mapped to a single sigmoid probability.
    """

    def __init__(self, embed_shape, vectors):
        """Build the encoder and the classification head.

        Args:
            embed_shape: ``(vocab_size, embed_dim)`` of the embedding table.
            vectors: pretrained embedding matrix as a numpy array of shape
                ``embed_shape``; it is cast to float32.

        Raises:
            ValueError: if ``vectors`` does not have shape ``embed_shape``.
        """
        super(bilstm_att, self).__init__()
        weights = torch.from_numpy(vectors).float()
        expected_shape = (embed_shape[0], embed_shape[1])
        if tuple(weights.shape) != expected_shape:
            raise ValueError(
                "vectors has shape {}, expected {}".format(tuple(weights.shape), expected_shape)
            )
        # from_pretrained is a classmethod that RETURNS a new layer. Calling it
        # on an existing instance and discarding the result (as before) left
        # the embedding randomly initialised and trainable.
        self.word_embeding = nn.Embedding.from_pretrained(weights, freeze=True)
        # The LSTM's own `dropout` argument only acts between stacked layers,
        # so it is omitted for this single-layer LSTM (it only raised a warning).
        # The Sequential wrapper is kept so parameter names stay unchanged.
        self.bilstm_word = nn.Sequential(
            nn.Dropout(0.1),
            nn.LSTM(embed_shape[1], 128, num_layers=1,
                    bidirectional=False, batch_first=True),
        )
        self.pre = nn.Sequential(
            nn.Dropout(0.1), nn.BatchNorm1d(256), nn.Linear(256, 128), nn.ReLU(True),
            nn.BatchNorm1d(128), nn.Dropout(p=0.1), nn.Linear(128, 1), nn.Sigmoid(),
        )

    def encoder(self, x):
        """Encode a batch of token-id sequences into one 128-d vector each.

        Args:
            x: integer tensor of token ids, ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, 128)``: the LSTM output at the last step.
        """
        x = self.word_embeding(x)
        # Sequential returns the LSTM's (output, (h_n, c_n)) tuple unchanged.
        x, _ = self.bilstm_word(x)
        # x: (batch, seq_len, 128) -- unidirectional LSTM with hidden size 128.
        # NOTE(review): if inputs are right-padded, the last step is a padding
        # position; confirm the padding scheme used by the data pipeline.
        x = x[:, -1, :]
        return x

    def forward(self, x1, x2):
        """Score the similarity of two batches of token-id sequences.

        Args:
            x1: integer tensor ``(batch, seq_len)`` for the left sentences.
            x2: integer tensor ``(batch, seq_len)`` for the right sentences.

        Returns:
            Tensor of shape ``(batch,)`` with probabilities in ``[0, 1]``.
        """
        X1 = self.encoder(x1)
        X2 = self.encoder(x2)
        x = torch.cat((X1, X2), dim=1)
        # Squeeze only the feature axis so a batch of one stays 1-D instead of
        # collapsing to a 0-d scalar that would not match a (1,)-shaped target.
        y = self.pre(x).squeeze(-1)
        return y
# Adam over all parameters exposed by the model, with a step size of 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
# Binary cross-entropy on probabilities (the model's head ends in a Sigmoid).
loss = nn.BCELoss()
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/44917.html
標籤:人工智能技術
