TF: TensorFlow untuk dataset CIFAR-10
Jump to navigation
Jump to search
Berikut adalah contoh kode Python yang membangun model Convolutional Neural Network (CNN) menggunakan Keras untuk mengklasifikasikan dataset CIFAR-10, melakukan prediksi, dan memvisualisasikan hasilnya:
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt
import numpy as np
# Memuat dataset CIFAR-10
(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()
# Normalisasi nilai piksel ke rentang 0 hingga 1
x_train, x_test = x_train / 255.0, x_test / 255.0
# Mendefinisikan nama kelas
class_names = ['Pesawat', 'Mobil', 'Burung', 'Kucing', 'Rusa', 'Anjing', 'Katak', 'Kuda', 'Kapal', 'Truk']
# Membangun model CNN
model = models.Sequential([
layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(64, (3, 3), activation='relu'),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(64, (3, 3), activation='relu'),
layers.Flatten(),
layers.Dense(64, activation='relu'),
layers.Dense(10)
])
# Kompilasi model
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
# Melatih model
history = model.fit(x_train, y_train, epochs=10,
validation_data=(x_test, y_test))
# Mengevaluasi model
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print(f'\nAkurasi pada data uji: {test_acc:.2f}')
# Membuat prediksi
probability_model = models.Sequential([model,
layers.Softmax()])
predictions = probability_model.predict(x_test)
# Fungsi untuk memplot gambar dengan prediksi
def plot_image(i, predictions_array, true_label, img):
true_label, img = true_label[i][0], img[i]
plt.grid(False)
plt.xticks([])
plt.yticks([])
plt.imshow(img, cmap=plt.cm.binary)
predicted_label = np.argmax(predictions_array)
if predicted_label == true_label:
color = 'blue'
else:
color = 'red'
plt.xlabel(f"{class_names[predicted_label]} ({class_names[true_label]})", color=color)
# Fungsi untuk memplot nilai prediksi
def plot_value_array(i, predictions_array, true_label):
true_label = true_label[i][0]
plt.grid(False)
plt.xticks(range(10))
plt.yticks([])
thisplot = plt.bar(range(10), predictions_array, color="#777777")
plt.ylim([0, 1])
predicted_label = np.argmax(predictions_array)
thisplot[predicted_label].set_color('red')
thisplot[true_label].set_color('blue')
# Memvisualisasikan beberapa gambar beserta prediksinya
num_rows = 5
num_cols = 3
num_images = num_rows * num_cols
plt.figure(figsize=(2*2*num_cols, 2*num_rows))
for i in range(num_images):
plt.subplot(num_rows, 2*num_cols, 2*i+1)
plot_image(i, predictions[i], y_test, x_test)
plt.subplot(num_rows, 2*num_cols, 2*i+2)
plot_value_array(i, predictions[i], y_test)
plt.tight_layout()
plt.show()
Penjelasan Kode:
1. Memuat dan Mempersiapkan Data:
- Dataset CIFAR-10 dimuat menggunakan `tensorflow.keras.datasets.cifar10.load_data()`.
- Data gambar dinormalisasi ke rentang 0 hingga 1 untuk mempercepat konvergensi model.
2. Membangun Model CNN:
- Model Sequential Keras digunakan untuk menumpuk lapisan-lapisan jaringan.
- Tiga lapisan konvolusi (`Conv2D`) dengan fungsi aktivasi ReLU dan lapisan pooling (`MaxPooling2D`) ditambahkan untuk ekstraksi fitur.
- Lapisan flatten digunakan untuk meratakan output sebelum memasuki lapisan dense.
- Dua lapisan dense digunakan, yang terakhir memiliki 10 neuron sesuai dengan jumlah kelas.
3. Kompilasi dan Pelatihan Model:
- Model dikompilasi dengan optimizer Adam dan fungsi loss `SparseCategoricalCrossentropy`.
- Model dilatih selama 10 epoch dengan data pelatihan dan divalidasi menggunakan data uji.
4. Evaluasi Model:
- Model dievaluasi pada data uji untuk mendapatkan akurasi.
5. Prediksi dan Visualisasi:
- Prediksi dibuat untuk data uji.
- Fungsi `plot_image` dan `plot_value_array` digunakan untuk menampilkan gambar beserta prediksi dan probabilitasnya.
- Beberapa gambar beserta prediksinya ditampilkan dalam grid.
Catatan:
- Untuk mencapai akurasi yang lebih tinggi, Anda dapat menambahkan teknik augmentasi data atau menggunakan arsitektur model yang lebih kompleks.
- Pastikan Anda memiliki semua pustaka yang diperlukan terinstal, seperti TensorFlow dan Matplotlib.