Image Classification

Develop a classifier to recognize 102 types of flowers

Project Title and Overview:

Title:

Flower Classification using Deep Learning and PyTorch

Overview:

This project involves building an image classifier to recognize different species of flowers. Using a dataset of 102 flower categories, the goal is to train a model capable of identifying flower species based on an image input. The classifier can be deployed in various applications, such as a mobile app for flower recognition.

Sample flowers from the Oxford Flowers 102 dataset

This project covers:

  • Loading and preprocessing the dataset.
  • Training a deep learning model.
  • Testing and evaluating the model's performance.
  • Using the model to predict flower species in new images.

Project Objective

Objective:

The primary goal is to develop an image classification model that can accurately recognize 102 categories of flowers from the dataset. The trained model will be used to predict flower species and could be extended for real-world applications such as smartphone apps.

Key Use Cases:

  • Flower Species Recognition: Automatically identify flower species based on an image input.
  • Educational Tool: Help students and botanists learn about various flower species.
  • Mobile App Integration: A potential extension where users can photograph a flower, and the app identifies it.

Technology Stack

Libraries:

  • PyTorch: For building and training the deep learning model.
  • Torchvision: For dataset transformations and loading.
  • NumPy: For numerical operations.
  • Matplotlib: For visualizing model performance and images.
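
The training code later in this write-up checks a train_on_gpu flag before moving tensors to the GPU. The setup code for that flag is not shown, so the following is only a minimal sketch of one common way to define it, assuming it is derived from torch.cuda.is_available(). The device variable is an extra convenience name that does not appear in the original code.

```python
import torch

# Assumption: the flag used by the training loop is set like this.
train_on_gpu = torch.cuda.is_available()

# Hypothetical convenience handle; the original code calls .cuda() directly.
device = torch.device("cuda" if train_on_gpu else "cpu")
print(f"Training on {device}")
```

With this in place, the model itself also has to be moved to the same device (for example with model.to(device)) before training starts.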

Dataset Preparation

This project uses the Oxford Flowers 102 dataset, which contains images from 102 flower categories.

Data Augmentation and Preprocessing:

  • Transforms for Training Data:
    • Random Horizontal Flip
    • Random Vertical Flip
    • Random Rotation up to 180 degrees
    • Random Resized Crop to 224x224 pixels
    • Normalization (mean and standard deviation)
  • Transforms for Validation Data:
    • Resize to 225 pixels
    • Center Crop to 224x224 pixels
    • Normalization (mean and standard deviation)

Here’s the code snippet for preprocessing:

from torchvision import transforms

train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(180),
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

valid_transforms = transforms.Compose([
    transforms.Resize(225),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
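
To make the normalization step concrete, here is a small plain-Python sketch of the per-channel arithmetic that Normalize applies after ToTensor has scaled pixel values to the range 0 to 1. The helper name normalize_pixel is hypothetical and exists only for illustration. The mean and standard deviation values are the ImageNet channel statistics used in the transforms above.

```python
# ImageNet channel statistics, as passed to transforms.Normalize.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Apply (value - mean) / std to one RGB pixel scaled to [0, 1].

    Hypothetical helper for illustration; torchvision applies the same
    arithmetic to every pixel of the image tensor.
    """
    return tuple((v - m) / s for v, m, s in zip(rgb, MEAN, STD))


# A pixel equal to the channel means maps to zero in every channel.
print(normalize_pixel(MEAN))

# A pure white pixel maps to roughly 2.2 to 2.7, depending on the channel.
print(normalize_pixel((1.0, 1.0, 1.0)))
```

Using the same statistics as the pre-trained network keeps the inputs in the value range the network saw during its original training.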
Data Augmentation Result

Model Architecture

The image classification model is built on a pre-trained convolutional neural network, ResNet-50 (resnet50 in torchvision). Its final classification layer is replaced with a new classifier, which is then trained on the flower dataset so that the model can recognize 102 flower species.

Key Layers:

  • Pre-trained Convolutional Base: Used for feature extraction.
  • Fully Connected Layers: Added to adapt the model to the specific flower dataset.
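from collections import OrderedDict

from torch import nn
from torchvision import models

# Load ResNet-50 with ImageNet weights (newer torchvision releases
# prefer the weights= argument over pretrained=True)
model = models.resnet50(pretrained=True)

# Replace the final fully connected layer; ResNet-50 exposes it as
# model.fc, and it receives 2048 features from the pooling layer
model.fc = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(model.fc.in_features, 4096)),
    ('relu', nn.ReLU()),
    ('dropout', nn.Dropout(0.5)),
    ('fc2', nn.Linear(4096, 102)),
    ('output', nn.LogSoftmax(dim=1))
]))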

Optimizer and Loss Function:

  • Optimizer: Adam.
  • Loss Function: Cross-entropy loss for multi-class classification. Because the model ends in a LogSoftmax layer, nn.NLLLoss applied to its output computes this loss.
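
The exact optimizer and scheduler settings are not listed here, so the following is a hedged sketch rather than the project's configuration. It first checks that NLLLoss on log-probabilities matches CrossEntropyLoss on raw logits. It then sets up Adam with ReduceLROnPlateau, which is one scheduler that accepts a validation loss in its step call, as the training loop does. The learning rate, factor, and patience values are assumptions, and the linear layer is a stand-in for the trainable classifier head.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)

# NLLLoss on log-probabilities equals CrossEntropyLoss on raw logits.
logits = torch.randn(4, 102)
labels = torch.tensor([0, 5, 17, 101])
nll = nn.NLLLoss()(torch.log_softmax(logits, dim=1), labels)
ce = nn.CrossEntropyLoss()(logits, labels)
print(torch.allclose(nll, ce))

# Stand-in for the trainable head; lr, factor and patience are assumed values.
head = nn.Linear(2048, 102)
optimizer = optim.Adam(head.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=2
)

# Feed a validation loss that never improves; the learning rate is cut
# once the number of stagnant epochs exceeds the patience.
for _ in range(4):
    scheduler.step(1.0)
print(optimizer.param_groups[0]["lr"])
```

Stepping the scheduler on the validation loss lowers the learning rate only when progress stalls, which fits a long run of 70 to 80 epochs.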

Training and Evaluation

Training Process:

  • Epochs: The model is trained over multiple epochs (typically 70-80).
  • Learning Rate: Adjusted using learning rate schedulers for better convergence.
  • Validation: Model performance is evaluated on a validation set after each epoch, and the model with the lowest validation loss is saved during training.
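
One detail of the training loop that follows is worth spelling out: it accumulates loss.item() * len(images) and divides by the number of samples at the end of the epoch, instead of averaging the batch losses directly. The plain-Python sketch below uses made-up batch losses to show why this matters when the last batch is smaller than the others.

```python
# Hypothetical per-batch results: (mean loss of the batch, batch size).
# The final batch is smaller because the dataset size is not a multiple of 32.
batches = [(0.90, 32), (0.70, 32), (0.10, 8)]

total_samples = sum(size for _, size in batches)

# Weight each batch mean by its size, then divide by the sample count.
weighted = sum(loss * size for loss, size in batches) / total_samples

# Averaging the batch means directly over-weights the small last batch.
naive = sum(loss for loss, _ in batches) / len(batches)

print(f"per-sample mean: {weighted:.4f}, mean of batch means: {naive:.4f}")
```

The weighted version is the true per-sample average, so epoch losses stay comparable even when the batch sizes differ.
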
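def train(model, epochs = 1, train_loader = train_loader, valid_loader = valid_loader, optimizer = optimizer, graph = False):
    # Also uses module-level names such as criterion, scheduler, train_on_gpu,
    # and the lists train_losses, valid_losses and accuracy

    import time

    epochs_end = 80  # Number of epochs to run in this call

    valid_loss_min = np.inf  # Lowest validation loss seen so far

    elapsed_time = 0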

    for e in range(epochs, epochs_end + epochs):

        train_loss = 0

        ###################
        # train the model #
        ###################
        start = time.time()

        # Put the model in training mode (enables dropout)
        model.train()
        for images, labels in train_loader:

            # Move the batch to the GPU when one is available
            if train_on_gpu:
                images, labels = images.cuda(), labels.cuda()

            # Clear gradients left over from the previous batch
            optimizer.zero_grad()

            # Forward pass and loss computation
            output = model(images)
            loss = criterion(output, labels)

            # Backpropagate the loss
            loss.backward()

            # Update the weights
            optimizer.step()

            # Accumulate the batch mean loss, weighted by batch size
            train_loss += loss.item() * len(images)

        end = time.time()

        elapsed_time += end - start

        ###################
        # Validation #
        ###################
        valid_loss = 0
        valid_accuracy = 0

        # Evaluation mode: disables dropout and uses running batch-norm statistics
        model.eval()

        # No gradient calculation
        with torch.no_grad():

            for images, labels in valid_loader:

                if train_on_gpu:   
                    images, labels = images.cuda(), labels.cuda()

                output = model(images)

                # Calculate loss for this validation batch
                loss = criterion(output, labels)
                # Accumulate the batch mean loss, weighted by batch size
                valid_loss += loss.item()*len(images)

                # Calculate accuracy: the model outputs log-probabilities,
                # so exponentiate before picking the top class
                ps = torch.exp(output)
                top_ps, top_class = ps.topk(1, dim=1)
                equals = top_class == labels.view(*top_class.shape)
                valid_accuracy += torch.mean(equals.type(torch.FloatTensor)).item()*len(images)


        # Average the accumulated loss and accuracy over the number of samples
        train_loss = train_loss/len(train_loader.sampler)
        valid_loss = valid_loss/len(valid_loader.sampler)
        valid_accuracy = valid_accuracy/len(valid_loader.sampler)
        
        # Update the learning rate based on the validation loss
        scheduler.step(valid_loss)

        ##############
        # Save Model #
        ##############

        # Save a checkpoint with the relevant state whenever the validation loss improves
        if valid_loss <= valid_loss_min:
            
            print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
                 valid_loss_min, valid_loss))
            
            checkpoint_info = {
                                'model_state_dict': model.state_dict(),
                                'optimizer_state_dict': optimizer.state_dict(),
                                'scheduler_state_dict': scheduler.state_dict()
                            }
            torch.save(checkpoint_info, 'classifier.pt')
            valid_loss_min = valid_loss
                
            
        
        # Appending losses into the list for analysis
        train_losses.append(train_loss)
        valid_losses.append(valid_loss)
        accuracy.append(valid_accuracy)
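
The checkpoint written by the loop above can later be loaded back for evaluation or further training. The following is a minimal sketch of that round trip, using the same dictionary keys. A small linear layer stands in for the fine-tuned network, and the file is written to a temporary directory. Both choices are assumptions made so that the example runs on its own.

```python
import os
import tempfile

import torch
from torch import nn, optim

# Stand-in model and optimizer; in the project these are the fine-tuned
# network and its Adam optimizer.
model = nn.Linear(8, 102)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

path = os.path.join(tempfile.mkdtemp(), "classifier.pt")
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    path,
)

# Rebuild the same architecture, then restore the saved state into it.
restored = nn.Linear(8, 102)
restored_optimizer = optim.Adam(restored.parameters(), lr=1e-3)

checkpoint = torch.load(path, map_location="cpu")
restored.load_state_dict(checkpoint["model_state_dict"])
restored_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
restored.eval()  # evaluation mode before running predictions

print(torch.equal(model.weight, restored.weight))
```

A scheduler state saved under scheduler_state_dict can be restored in the same way with the scheduler's load_state_dict method.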

Result:

Link of the project code on GitHub

Figure 1: Accuracy on training and validation sets
Figure 2: Flower classification results (green indicates a correct classification)
Figure 3 (a) and (b): Individual classification results
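
A per-image prediction like those in Figure 3 can be sketched as follows, assuming the model returns log-probabilities as in the validation code. The tensor below is random and only stands in for the output of the trained model on a single image. The choice of five classes is also an assumption.

```python
import torch

# Hypothetical log-probabilities for one image over 102 classes,
# standing in for the output of the trained model.
torch.manual_seed(0)
log_probs = torch.log_softmax(torch.randn(1, 102), dim=1)

# Convert back to probabilities and keep the five most likely classes.
probs = torch.exp(log_probs)
top_p, top_class = probs.topk(5, dim=1)

for p, c in zip(top_p[0].tolist(), top_class[0].tolist()):
    print(f"class index {c}: probability {p:.3f}")
```

Turning these class indices into flower names requires the dataset's label mapping, which is not shown in this write-up.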

Achievements and Future Work

Achievements:

  • Successfully trained a model to classify 102 flower species.
  • Achieved 95% accuracy on the test set.
  • Implemented a data pipeline with transformations to improve model generalization.

Future Work:

  • Mobile Integration: Convert the trained model to TorchScript or ONNX format for deployment in a mobile app.
  • Model Improvement: Experiment with other architectures (e.g., deeper ResNet variants or DenseNet) to improve accuracy.
  • Transfer Learning: Further fine-tuning of pre-trained models for other related classification tasks.
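
As a rough illustration of the TorchScript route mentioned above, the sketch below traces a small stand-in network and reloads it from an in-memory buffer. The network and input size are assumptions chosen to keep the example fast. The project would trace the trained classifier with a 224x224 example image instead.

```python
import io

import torch
from torch import nn

# Small stand-in network; the project would trace the trained classifier.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(3 * 8 * 8, 102),
    nn.LogSoftmax(dim=1),
)
model.eval()

# Tracing records the operations executed for this example input.
example = torch.randn(1, 3, 8, 8)
traced = torch.jit.trace(model, example)

# Save to an in-memory buffer here; a file path works the same way.
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
reloaded = torch.jit.load(buffer)

with torch.no_grad():
    same = torch.allclose(model(example), reloaded(example))
print(same)
```

The reloaded module no longer depends on the Python class definitions of the model, which is what makes this format convenient for deployment.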

Conclusion

The flower classification project demonstrates the effectiveness of deep learning models in image classification tasks. Using a pre-trained model and applying transfer learning, I developed a model that can identify 102 flower species. This project highlights the potential of integrating AI into everyday applications like mobile apps, educational tools, and more.

References

  • Oxford Flowers 102 Dataset: Link
  • Udacity AI Programming with PyTorch course materials

Thank you for your interest!