import tensorflow as tf
# 載入資料集
mnist = tf.keras.datasets.mnist
# 載入資料,資料載入的時候就已經劃分好訓練集和測驗集
# 訓練集資料 x_train 的資料形狀為(60000,28,28)
# 訓練集標簽 y_train 的資料形狀為(60000)
# 測驗集資料 x_test 的資料形狀為(10000,28,28)
# 測驗集標簽 y_test 的資料形狀為(10000)
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 對訓練集和測驗集的資料進行歸一化處理,有助于提升模型訓練速度
x_train, x_test = x_train / 255.0, x_test / 255.0
# 把訓練集和測驗集的標簽轉為獨熱編碼
y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)
# 創建 dataset 物件,使用 dataset 物件來管理資料
mnist_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
# 訓練周期設定為 1(把所有訓練集資料訓練一次稱為訓練一個周期)
mnist_train = mnist_train.repeat(1)
# 批次大小設定為 32(每次訓練模型傳入 32 個資料進行訓練)
mnist_train = mnist_train.batch(32)
# 創建 dataset 物件,使用 dataset 物件來管理資料
mnist_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
# 訓練周期設定為 1(把所有訓練集資料訓練一次稱為訓練一個周期)
mnist_test = mnist_test.repeat(1)
# 批次大小設定為 32(每次訓練模型傳入 32 個資料進行訓練)
mnist_test = mnist_test.batch(32)
# 模型定義
# 先用 Flatten 把資料從 3 維變成 2 維,(60000,28,28)->(60000,784)
# 設定輸入資料形狀 input_shape 不需要包含資料的數量,(28,28)即可
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(10, activation='softmax')
])
# 優化器定義
optimizer = tf.keras.optimizers.SGD(0.1)
# 計算平均值
train_loss = tf.keras.metrics.Mean(name='train_loss')
# 訓練準確率計算
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')
# 計算平均值
test_loss = tf.keras.metrics.Mean(name='test_loss')
# 測驗準確率計算
test_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_accuracy')
# 我們可以用@tf.function 裝飾器來將 python 代碼轉成 tensorflow 的圖表示代碼,用于加速代碼運行速度
# 定義一個訓練模型的函式
@tf.function
def train_step(data, label):
# 固定寫法,使用 tf.GradientTape()來計算梯度
with tf.GradientTape() as tape:
# 傳入資料獲得模型預測結果
predictions = model(data)
# 對比 label 和 predictions 計算 loss
loss = tf.keras.losses.MSE(label, predictions)
# 傳入 loss 和模型引數,計算權值調整
gradients = tape.gradient(loss, model.trainable_variables)
# 進行權值調整
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
# 計算平均 loss
train_loss(loss)
# 計算平均準確率
train_accuracy(label, predictions)
# 我們可以用@tf.function 裝飾器來將 python 代碼轉成 tensorflow 的圖表示代碼,用于加速代碼運行速度
# 定義一個模型測驗的函式
@tf.function
def test_step(data, label):
# 傳入資料獲得模型預測結果
predictions = model(data)
# 對比 label 和 predictions 計算 loss
t_loss = tf.keras.losses.MSE(label, predictions)
# 計算平均 loss
test_loss(t_loss)
# 計算平均準確率
test_accuracy(label, predictions)
# 訓練 10 個周期(把所有訓練集資料訓練一次稱為訓練一個周期)
EPOCHS = 10
for epoch in range(EPOCHS):
# 訓練集回圈 60000/32=1875 次
for image, label in mnist_train:
# 每次回圈傳入一個批次的資料和標簽訓練模型
train_step(image, label)
# 測驗集回圈 10000/32=312.5->313 次
for test_image, test_label in mnist_test:
# 每次回圈傳入一個批次的資料和標簽進行測驗
test_step(test_image, test_label)
# 列印結果
template = 'Epoch {}, Loss: {:.3}, Accuracy: {:.3}, Test Loss: {:.3}, Test Accuracy: {:.3}'
print(template.format(epoch + 1,
train_loss.result(),
train_accuracy.result(),
test_loss.result(),
test_accuracy.result()))
請問大佬們這段代碼出現這種問題該怎么解決?
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
Traceback (most recent call last):
File "C:\Users\lenovo\Anaconda3\lib\urllib\request.py", line 1254, in do_open
h.request(req.get_method(), req.selector, req.data, headers)
File "C:\Users\lenovo\Anaconda3\lib\http\client.py", line 1107, in request
self._send_request(method, url, body, headers)
File "C:\Users\lenovo\Anaconda3\lib\http\client.py", line 1152, in _send_request
self.endheaders(body)
File "C:\Users\lenovo\Anaconda3\lib\http\client.py", line 1103, in endheaders
self._send_output(message_body)
File "C:\Users\lenovo\Anaconda3\lib\http\client.py", line 934, in _send_output
self.send(msg)
File "C:\Users\lenovo\Anaconda3\lib\http\client.py", line 877, in send
self.connect()
File "C:\Users\lenovo\Anaconda3\lib\http\client.py", line 1253, in connect
super().connect()
File "C:\Users\lenovo\Anaconda3\lib\http\client.py", line 849, in connect
(self.host,self.port), self.timeout, self.source_address)
File "C:\Users\lenovo\Anaconda3\lib\socket.py", line 694, in create_connection
for res in getaddrinfo(host, port, 0, SOCK_STREAM):
File "C:\Users\lenovo\Anaconda3\lib\socket.py", line 733, in getaddrinfo
for res in _socket.getaddrinfo(host, port, family, type, proto, flags):
socket.gaierror: [Errno 11002] getaddrinfo failed
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "C:\Users\lenovo\Anaconda3\lib\site-packages\tensorflow\python\keras\utils\data_utils.py", line 278, in get_file
urlretrieve(origin, fpath, dl_progress)
File "C:\Users\lenovo\Anaconda3\lib\urllib\request.py", line 188, in urlretrieve
with contextlib.closing(urlopen(url, data)) as fp:
File "C:\Users\lenovo\Anaconda3\lib\urllib\request.py", line 163, in urlopen
return opener.open(url, data, timeout)
File "C:\Users\lenovo\Anaconda3\lib\urllib\request.py", line 466, in open
response = self._open(req, data)
File "C:\Users\lenovo\Anaconda3\lib\urllib\request.py", line 484, in _open
'_open', req)
File "C:\Users\lenovo\Anaconda3\lib\urllib\request.py", line 444, in _call_chain
result = func(*args)
File "C:\Users\lenovo\Anaconda3\lib\urllib\request.py", line 1297, in https_open
context=self._context, check_hostname=self._check_hostname)
File "C:\Users\lenovo\Anaconda3\lib\urllib\request.py", line 1256, in do_open
raise URLError(err)
urllib.error.URLError: <urlopen error [Errno 11002] getaddrinfo failed>
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "C:/Users/lenovo/Documents/Tencent Files/2863543455/FileRecv/main.py", line 10, in <module>
(x_train, y_train), (x_test, y_test) = mnist.load_data()
File "C:\Users\lenovo\Anaconda3\lib\site-packages\tensorflow\python\keras\datasets\mnist.py", line 62, in load_data
'731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1')
File "C:\Users\lenovo\Anaconda3\lib\site-packages\tensorflow\python\keras\utils\data_utils.py", line 282, in get_file
raise Exception(error_msg.format(origin, e.errno, e.reason))
Exception: URL fetch failure on https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz: None -- [Errno 11002] getaddrinfo failed
Process finished with exit code 1
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/230023.html
上一篇:初步接觸Axure
下一篇:怎樣才能兩列同時關聯一張表?
