15.5 Handwritten Digits
This example recognizes the digit shown in a picture; the samples are mostly handwritten digits (the MNIST dataset).
Training data download (the MNIST IDX files used below: train-images-idx3-ubyte, train-labels-idx1-ubyte, t10k-images-idx3-ubyte, t10k-labels-idx1-ubyte)
Source code:
import struct
import numpy as np
import tensorflow as tf
# from matplotlib import pyplot as plt


def get_img_datas(file_path):
    """Read the images from an MNIST IDX3 file into a list of 28x28 arrays."""
    with open(file_path, 'rb') as file:
        image_file = file.read()
    image_datas = []
    file_index = 0
    # Header: magic number, image count, rows, cols (big-endian 32-bit integers).
    magic_num, data_num, rows, cols = struct.unpack_from('>IIII', image_file, file_index)
    file_index += struct.calcsize('>IIII')
    read_formate = ">{}B".format(rows * cols)
    for image_index in range(0, data_num):
        data = struct.unpack_from(read_formate, image_file, file_index)
        data = np.reshape(data, (28, 28))
        image_datas.append(data)
        file_index += struct.calcsize(read_formate)
    print('read image datas {}/{}'.format(len(image_datas), data_num))
    return image_datas


def get_label_datas(file_path):
    """Read the labels from an MNIST IDX1 file."""
    with open(file_path, 'rb') as file:
        label_file = file.read()
    label_datas = []
    file_index = 0
    # Header: magic number and label count (big-endian 32-bit integers).
    magic_num, data_num = struct.unpack_from('>II', label_file, file_index)
    file_index += struct.calcsize(">II")
    for label_index in range(0, data_num):
        data = struct.unpack_from(">B", label_file, file_index)
        label_datas.append(data)
        file_index += struct.calcsize(">B")
    print('read label datas {}/{}'.format(len(label_datas), data_num))
    return label_datas


def pre_process(image_data, label_data):
    """Scale pixel values to [0, 1] and flatten the label list to shape (N,)."""
    image_data = np.array(image_data)
    image_data = image_data / 255.
    label_data = np.reshape(label_data, len(label_data))  # (60000, 1) ===> (60000,)
    return image_data, label_data


# Build a simple fully connected classifier: 28x28 input -> two hidden layers -> 10 classes.
models = tf.keras.models
layers = tf.keras.layers
model = models.Sequential()
model.add(layers.Flatten(input_shape=(28, 28)))
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.summary()

# Load and preprocess the training set, then train for 15 epochs.
train_image = get_img_datas("D:/images-idx3-ubyte/train-images-idx3-ubyte")
train_label = get_label_datas("D:/images-idx3-ubyte/train-labels-idx1-ubyte")
train_image, train_label = pre_process(train_image, train_label)

history = model.fit(x=train_image, y=train_label, epochs=15)
loss = history.history['loss']
acc = history.history['accuracy']
# val_loss = history.history['val_loss']

# Optional: plot the training curves (requires uncommenting the matplotlib import above).
'''
plt.figure(figsize=(15, 5), frameon=True)
plt.subplot(1, 2, 1)
# plt.plot(loss,label="Loss")
plt.plot(acc, label='Training Accuracy')
# plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.grid()
plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
# plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
'''

# Load and preprocess the test set, then evaluate the trained model on it.
test_image = get_img_datas("D:/images-idx3-ubyte/t10k-images-idx3-ubyte")
test_label = get_label_datas("D:/images-idx3-ubyte/t10k-labels-idx1-ubyte")
test_image, test_label = pre_process(test_image, test_label)
print("===========================")
print(model.evaluate(test_image, test_label, verbose=2))
print("===========================")
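Besides reporting the overall loss and accuracy, the trained model can also be used to classify a single image. The following is a minimal sketch, not part of the original listing, that assumes the model, test_image, and test_label variables prepared above are still in scope; model.predict returns the ten softmax probabilities and np.argmax picks the most likely digit.

# Sketch: predict the digit of the first test image.
# model.predict expects a batch dimension, so the single 28x28 image is wrapped first.
sample = np.expand_dims(test_image[0], axis=0)     # shape (1, 28, 28)
probabilities = model.predict(sample)              # shape (1, 10), softmax outputs
predicted_digit = int(np.argmax(probabilities[0]))
print('predicted: {}, actual: {}'.format(predicted_digit, int(test_label[0])))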