我正在嘗試制作一個預測模型,用于根據 word2vec 向量識別治療性肽。該資料集有 100 個正例和 100 個負例。我已經用 Word2Vec 嵌入了肽序列,并且正在嘗試訓練我的神經網路。但是,準確度保持不變,為 51.88%。
我嘗試過的:更改損失函式(到二元交叉熵),每層中的節點數
這是我的代碼:
import sklearn
a = sklearn.utils.shuffle(arrayvectors, random_state=1)
b = sklearn.utils.shuffle(labels, random_state=1)
dfa = pd.DataFrame(a, columns=None)
dfb = pd.DataFrame(b, columns=None)
X = dfa.iloc[:]
y = dfb.iloc[:]
X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=0.2, random_state=300)
X_train = np.asarray(X_train)
X_test = np.asarray(X_test)
y_train = np.asarray(y_train)
y_test = np.asarray(y_test)
y_train = y_train.astype(np.float32)
y_test = y_test.astype(np.float32)
## train data
class trainData(Dataset):
def __init__(self, X_data, y_data):
self.X_data = X_data
self.y_data = y_data
def __getitem__(self, index):
return self.X_data[index], self.y_data[index]
def __len__ (self):
return len(self.X_data)
train_data = trainData(torch.FloatTensor(X_train),
torch.FloatTensor(y_train))
## test data
class testData(Dataset):
def __init__(self, X_data):
self.X_data = X_data
def __getitem__(self, index):
return self.X_data[index]
def __len__ (self):
return len(self.X_data)
test_data = testData(torch.FloatTensor(X_test))
EPOCHS = 100
BATCH_SIZE = 2
LEARNING_RATE = 0.0001
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1)
# make mode
model = Sequential()
model.add(Dense(64, activation='relu', input_shape=(4,)))
model.add(Dropout(0.5))
model.add(Dense(16, input_dim=1, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(12,activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(1,activation='sigmoid'))
model.summary()
model.compile(loss='binary_crossentropy',optimizer='adam',metrics=['accuracy'])
history = model.fit(X_train, y_train, epochs=1000, batch_size=64)
如果您有任何想法,請告訴我!
uj5u.com熱心網友回復:
嘗試將批量大小從 2 增加到 16,并將 dropout 減少到 0.2 或更小。輟學太多了
轉載請註明出處,本文鏈接:https://www.uj5u.com/yidong/350057.html
上一篇:如何將從image_dataset_from_directory獲得的資料集拆分為資料和標簽?
下一篇:根據另一個和列的索引生成新的張量
