TF: TensorFlow untuk dataset Fashion-MNIST

From OnnoWiki
Jump to navigation Jump to search

Berikut adalah contoh kode Python yang membangun model Convolutional Neural Network (CNN) menggunakan Keras untuk dataset Fashion-MNIST, melakukan prediksi, dan menampilkan visualisasi hasilnya:


import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.utils import to_categorical

# Memuat dataset Fashion-MNIST
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Mengubah dimensi data agar sesuai dengan input CNN
train_images = train_images.reshape((train_images.shape[0], 28, 28, 1))
test_images = test_images.reshape((test_images.shape[0], 28, 28, 1))

# Normalisasi nilai piksel ke rentang 0-1
train_images, test_images = train_images / 255.0, test_images / 255.0

# One-hot encoding untuk label
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)

# Membuat model CNN
model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    MaxPooling2D((2, 2)),
    Flatten(),
    Dense(128, activation='relu'),
    Dense(10, activation='softmax')
])

# Kompilasi model
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Melatih model
model.fit(train_images, train_labels, epochs=5, batch_size=64, validation_split=0.2)

# Evaluasi model
test_loss, test_acc = model.evaluate(test_images, test_labels)
print(f'Akurasi pada data uji: {test_acc:.2f}')

# Melakukan prediksi pada data uji
predictions = model.predict(test_images) 

# Fungsi untuk menampilkan gambar dengan prediksi dan label sebenarnya
def plot_image(i, predictions_array, true_label, img):
    predictions_array, true_label, img = predictions_array[i], np.argmax(true_label[i]), img[i].reshape(28, 28)
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(img, cmap=plt.cm.binary)
    predicted_label = np.argmax(predictions_array)
    if predicted_label == true_label:
        color = 'blue'
    else:
        color = 'red'
    plt.xlabel(f"{class_names[predicted_label]} ({class_names[true_label]})", color=color)

# Fungsi untuk menampilkan grafik bar dari prediksi
def plot_value_array(i, predictions_array, true_label):
    predictions_array, true_label = predictions_array[i], np.argmax(true_label[i])
    plt.grid(False)
    plt.xticks(range(10))
    plt.yticks([])
    thisplot = plt.bar(range(10), predictions_array, color="#777777")
    plt.ylim([0, 1])
    predicted_label = np.argmax(predictions_array)
    thisplot[predicted_label].set_color('red')
    thisplot[true_label].set_color('blue')

# Daftar nama kelas
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# Menampilkan beberapa gambar dengan prediksi dan label sebenarnya
num_rows, num_cols = 5, 3
num_images = num_rows * num_cols
plt.figure(figsize=(2 * 2 * num_cols, 2 * num_rows))
for i in range(num_images):
    plt.subplot(num_rows, 2 * num_cols, 2 * i + 1)
    plot_image(i, predictions, test_labels, test_images)
    plt.subplot(num_rows, 2 * num_cols, 2 * i + 2)
    plot_value_array(i, predictions, test_labels)
plt.tight_layout()
plt.show()


Kode di atas melakukan langkah-langkah berikut:

  1. Memuat dataset Fashion-MNIST dan membagi menjadi data latih dan uji.
  2. Mengubah dimensi data gambar agar sesuai dengan input CNN dan melakukan normalisasi nilai piksel ke rentang 0-1.
  3. Melakukan one-hot encoding pada label.
  4. Membangun model CNN sederhana dengan satu lapisan konvolusi, satu lapisan pooling, dan dua lapisan dense.
  5. Mengompilasi dan melatih model pada data latih.
  6. Mengevaluasi model pada data uji dan mencetak akurasi.
  7. Melakukan prediksi pada data uji.
  8. Menampilkan beberapa gambar dari data uji dengan prediksi model dan label sebenarnya, serta grafik bar yang menunjukkan distribusi probabilitas prediksi untuk setiap kelas.

Pastikan Anda telah menginstal pustaka yang diperlukan seperti TensorFlow dan Matplotlib sebelum menjalankan kode ini.


Pranala Menarik