An ML Beginner's Guide to Building an Agricultural Crop Classifier Using CNNs

Figure: CNN Crop Classifier


Advances in agricultural technology are changing traditional farming practices, and one important part of this shift is machine learning with Convolutional Neural Networks (CNNs). In this guide, we show how to build an automated agricultural crop classifier using CNNs. By walking through the code step by step, it aims to show how CNNs are applied to crop classification and how such technology can transform the agricultural industry.

Understanding the Dataset

1. Data Loading and Preprocessing

First, import the necessary libraries: PyTorch, torchvision, PyTorch Lightning, scikit-learn and matplotlib. Using torchvision's ImageFolder class, we load and preprocess the images with PyTorch's data-handling utilities. The CNN model is trained on the agricultural-crops dataset, which is organised into one folder per class (crop type).

import os
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset, random_split, SubsetRandomSampler
from torchvision import datasets, transforms, models
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
import matplotlib.pyplot as plt
# Jupyter magic: render plots inline in the notebook
%matplotlib inline
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from PIL import Image

The following transformations form an image-preparation pipeline of the kind commonly used when training convolutional neural networks (CNNs) on image datasets. Random rotations and horizontal flips augment the data, resizing and center-cropping standardise every input to 224×224 pixels, and normalisation with per-channel means and standard deviations helps training converge.

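transform = transforms.Compose([
        transforms.RandomRotation(10),      # rotate +/- 10 degrees
        transforms.RandomHorizontalFlip(),  # flip 50% of images horizontally
        transforms.Resize(224),             # resize the shorter side to 224 pixels
        transforms.CenterCrop(224),         # crop the center to 224x224 pixels
        transforms.ToTensor(),              # PIL image -> float tensor in [0, 1]; Normalize needs a tensor
        transforms.Normalize([0.485, 0.456, 0.406],   # per-channel means (ImageNet statistics)
                             [0.229, 0.224, 0.225])   # per-channel standard deviations
])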
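
transforms.Normalize operates on tensors, so a PIL image must first pass through transforms.ToTensor(), which scales 0-255 pixel values to [0, 1] and moves the channel axis first. The minimal numpy sketch below makes both steps concrete for an 8-bit RGB image. The helpers to_tensor_like and normalize are hypothetical stand-ins written for illustration, not torchvision functions; the mean and standard deviation values are the ones passed to Normalize in the pipeline.

```python
import numpy as np

# Per-channel statistics, the same values passed to transforms.Normalize
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_tensor_like(image_hwc_uint8):
    """Mimic ToTensor: HxWxC uint8 in [0, 255] -> CxHxW float32 in [0, 1]."""
    return np.transpose(image_hwc_uint8.astype(np.float32) / 255.0, (2, 0, 1))


def normalize(tensor_chw, mean=MEAN, std=STD):
    """Mimic Normalize: subtract each channel's mean, divide by its std."""
    return (tensor_chw - mean[:, None, None]) / std[:, None, None]


# A 2x2 RGB image in which every pixel is pure white (255, 255, 255)
white = np.full((2, 2, 3), 255, dtype=np.uint8)
x = to_tensor_like(white)
print(x.shape)                 # (3, 2, 2)
print(normalize(x)[:, 0, 0])   # (1 - mean) / std for each channel
```

A pixel equal to the channel mean maps to 0 and one a standard deviation above it maps to 1, which keeps the inputs to the first convolution on a similar scale across channels.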

2. Exploring the Dataset

The next step builds a PyTorch ImageFolder dataset from the dataset's directory structure, in which each sub-folder holds the images of one crop class, and prints the class names and the total number of classes.
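
dataset = datasets.ImageFolder(root="/kaggle/input/agricultural-crops-image-classification/Agricultural-crops", transform=transform)
class_names = dataset.classes  # sorted sub-folder names; used later to size the model's output layer
print(class_names)
print(len(class_names))
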
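ImageFolder expects one sub-folder per class; it sorts the folder names alphabetically and numbers them 0, 1, 2 and so on, exposing the result as dataset.classes and dataset.class_to_idx. The stdlib sketch below mimics that convention on a temporary directory so you can see how labels are derived from folders. The helpers discover_classes and list_samples are hypothetical, they accept only a subset of the image extensions ImageFolder recognises, and the crop names are placeholders rather than the dataset's actual folder names.

```python
import os
import tempfile


def discover_classes(root):
    """One sub-folder per class; names sorted alphabetically and numbered from 0."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx


def list_samples(root, class_to_idx, extensions=(".jpg", ".jpeg", ".png")):
    """Pair every image file with the integer label of the folder it sits in."""
    samples = []
    for name, idx in class_to_idx.items():
        folder = os.path.join(root, name)
        for fname in sorted(os.listdir(folder)):
            if fname.lower().endswith(extensions):
                samples.append((os.path.join(folder, fname), idx))
    return samples


# Build a throwaway directory tree with three placeholder crop folders
with tempfile.TemporaryDirectory() as root:
    for crop, n_images in [("rice", 2), ("maize", 1), ("banana", 3)]:
        os.makedirs(os.path.join(root, crop))
        for i in range(n_images):
            open(os.path.join(root, crop, f"img{i}.jpg"), "wb").close()  # empty stand-in file
    classes, class_to_idx = discover_classes(root)
    samples = list_samples(root, class_to_idx)

print(classes)       # ['banana', 'maize', 'rice']
print(len(classes))  # 3
print(len(samples))  # 6
```

Because the indices follow sorted folder names, renaming or adding a folder changes every label after it, which is why class_names should always be read from the dataset rather than hard-coded.
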
The DataModule class below wraps the preprocessing and data-loading logic for this image-classification problem in a PyTorch Lightning LightningDataModule. It splits the dataset into 50% training data, 20% validation data and the remaining ~30% test data, and exposes a dataloader for each split, so it plugs directly into a Lightning training loop.

class DataModule(pl.LightningDataModule):
    def __init__(self, transform=transform, batch_size=32):
        super().__init__()  # initialise LightningDataModule internals
        self.root_dir = "/kaggle/input/agricultural-crops-image-classification/Agricultural-crops"
        self.transform = transform
        self.batch_size = batch_size

    def setup(self, stage=None):
        dataset = datasets.ImageFolder(root=self.root_dir, transform=self.transform)
        n_data = len(dataset)
        n_train = int(0.5 * n_data)
        n_valid = int(0.2 * n_data)
        n_test = n_data - n_train - n_valid

        # A fixed seed gives the same split every time Lightning calls setup()
        # (once for fit, again for test), so test images never leak into training
        train_dataset, valid_dataset, test_dataset = random_split(
            dataset, [n_train, n_valid, n_test], generator=torch.Generator().manual_seed(42))

        self.trainset = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        self.validset = DataLoader(valid_dataset, batch_size=self.batch_size)
        self.testset = DataLoader(test_dataset, batch_size=self.batch_size)

    def train_dataloader(self):
        return self.trainset

    def val_dataloader(self):  # Lightning looks for val_dataloader, not valid_dataloader
        return self.validset

    def test_dataloader(self):
        return self.testset
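
The split arithmetic in setup can be checked without PyTorch. The sketch below reproduces the 50% / 20% / remainder sizes and, as a conceptual stand-in for random_split, shuffles indices with a seeded random.Random before slicing them. Treat it as an illustration under assumptions: split_sizes and seeded_split are hypothetical helpers, and random_split draws from a torch generator rather than Python's random module. What it shows is that a fixed seed yields the same disjoint partition on every call, which matters because Lightning calls setup() separately for fitting and for testing; with random_split, passing generator=torch.Generator().manual_seed(...) provides the same guarantee.

```python
import random


def split_sizes(n_data, train_frac=0.5, valid_frac=0.2):
    """Same arithmetic as DataModule.setup: the test split gets the remainder."""
    n_train = int(train_frac * n_data)
    n_valid = int(valid_frac * n_data)
    n_test = n_data - n_train - n_valid
    return n_train, n_valid, n_test


def seeded_split(n_data, seed=42):
    """Hypothetical stand-in for random_split: shuffle indices with a fixed
    seed, then slice them into disjoint train/valid/test index lists."""
    n_train, n_valid, _ = split_sizes(n_data)
    indices = list(range(n_data))
    random.Random(seed).shuffle(indices)  # local RNG, global random state untouched
    train = indices[:n_train]
    valid = indices[n_train:n_train + n_valid]
    test = indices[n_train + n_valid:]
    return train, valid, test


print(split_sizes(1000))                     # (500, 200, 300)
print(seeded_split(10) == seeded_split(10))  # True: same seed, same partition
```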

3. Constructing the CNN Model

The ConvolutionalNetwork class below implements a convolutional neural network for the crop-classification task using PyTorch Lightning. It follows a standard architecture: two convolutional layers, each followed by ReLU and 2×2 max pooling, then four fully connected layers ending in one output per crop class. The Lightning-specific methods configure_optimizers, training_step, validation_step and test_step define the optimiser and what happens on each training, validation and test batch, while Lightning runs the loops themselves.

class ConvolutionalNetwork(LightningModule):
    def __init__(self):
        super(ConvolutionalNetwork, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 3, 1)          # 3 RGB channels -> 6 feature maps, 3x3 kernel, stride 1
        self.conv2 = nn.Conv2d(6, 16, 3, 1)         # 6 -> 16 feature maps
        self.fc1 = nn.Linear(16 * 54 * 54, 120)     # a 224x224 input ends up as 16 maps of 54x54
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 20)
        self.fc4 = nn.Linear(20, len(class_names))  # one output per crop class

    def forward(self, X):
        X = F.relu(self.conv1(X))
        X = F.max_pool2d(X, 2, 2)
        X = F.relu(self.conv2(X))
        X = F.max_pool2d(X, 2, 2)
        X = X.view(-1, 16 * 54 * 54)                # flatten for the fully connected layers
        X = F.relu(self.fc1(X))
        X = F.relu(self.fc2(X))
        X = F.relu(self.fc3(X))
        X = self.fc4(X)
        return F.log_softmax(X, dim=1)              # log-probabilities, one per class

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

    def training_step(self, train_batch, batch_idx):
        X, y = train_batch
        y_hat = self(X)
        loss = F.cross_entropy(y_hat, y)
        pred = y_hat.argmax(dim=1, keepdim=True)
        acc = pred.eq(y.view_as(pred)).sum().item() / y.shape[0]
        self.log("train_loss", loss)
        self.log("train_acc", acc)
        return loss

    def validation_step(self, val_batch, batch_idx):
        X, y = val_batch
        y_hat = self(X)
        loss = F.cross_entropy(y_hat, y)
        pred = y_hat.argmax(dim=1, keepdim=True)
        acc = pred.eq(y.view_as(pred)).sum().item() / y.shape[0]
        self.log("val_loss", loss)
        self.log("val_acc", acc)

    def test_step(self, test_batch, batch_idx):
        X, y = test_batch
        y_hat = self(X)
        loss = F.cross_entropy(y_hat, y)
        pred = y_hat.argmax(dim=1, keepdim=True)
        acc = pred.eq(y.view_as(pred)).sum().item() / y.shape[0]
        self.log("test_loss", loss)
        self.log("test_acc", acc)
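
The value 16 * 54 * 54 in fc1 and in forward follows from the 224×224 input produced by Resize and CenterCrop. Below is a small sketch that traces one side of a square image through the layers, assuming the output-size formula from the PyTorch Conv2d documentation with dilation 1, floor((H + 2·padding − kernel) / stride) + 1, which also applies to max pooling without padding. The helpers conv2d_out, maxpool_out and flattened_features are hypothetical names used only for this illustration.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Output side length of a square Conv2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def maxpool_out(size, kernel=2, stride=2):
    """Output side length of max_pool2d with no padding."""
    return (size - kernel) // stride + 1


def flattened_features(input_size=224):
    """Trace a square image through conv1 -> pool -> conv2 -> pool."""
    s = conv2d_out(input_size, kernel=3)  # conv1, 3x3, stride 1: 224 -> 222
    s = maxpool_out(s)                    # 2x2 pool: 222 -> 111
    s = conv2d_out(s, kernel=3)           # conv2, 3x3, stride 1: 111 -> 109
    s = maxpool_out(s)                    # 2x2 pool: 109 -> 54 (floor of 54.5)
    return 16 * s * s                     # conv2 produces 16 channels


print(flattened_features(224))  # 46656
```

Any other input resolution changes this number, so if you alter Resize or CenterCrop you would need to recompute it (or use nn.Flatten followed by nn.LazyLinear, which infers its input size on the first forward pass).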

4. Training the Model

PyTorch Lightning trains the model on the training and validation splits, after which its performance is assessed on the test split. The whole process sits behind an if __name__ == '__main__': guard, so it runs when the script is executed directly but not when the file is imported as a module.

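if __name__ == '__main__':
    datamodule = DataModule()
    model = ConvolutionalNetwork()
    trainer = pl.Trainer(max_epochs=50)
    trainer.fit(model, datamodule=datamodule)   # runs the training and validation loops
    test_loader = datamodule.test_dataloader()  # held-out split for the evaluation step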

Figure: Training the model
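
During training, each step passes the model's output to F.cross_entropy, which expects raw logits and applies log_softmax internally, even though forward already returns F.log_softmax(X, dim=1). The loss is still correct because log_softmax is idempotent: applied to log-probabilities it changes nothing, so the result matches F.nll_loss on the model's output. The numpy sketch below checks that claim on random numbers and mirrors the accuracy computation used in training_step; log_softmax, cross_entropy and accuracy here are hypothetical re-implementations, not the PyTorch functions.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def cross_entropy(logits, targets):
    """Like F.cross_entropy with mean reduction: log_softmax, then mean negative log-likelihood."""
    log_probs = log_softmax(logits)
    return -log_probs[np.arange(len(targets)), targets].mean()


def accuracy(outputs, targets):
    """Fraction of rows whose argmax equals the target, as in training_step."""
    return (outputs.argmax(axis=1) == targets).mean()


rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 5))   # 4 samples, 5 hypothetical classes
targets = np.array([0, 3, 1, 4])
log_probs = log_softmax(logits)    # what forward returns

# Re-applying log_softmax leaves log-probabilities unchanged...
print(np.allclose(log_softmax(log_probs), log_probs))                                 # True
# ...so the loss on log-probabilities equals the loss on raw logits
print(np.isclose(cross_entropy(log_probs, targets), cross_entropy(logits, targets)))  # True
print(accuracy(log_probs, targets) == accuracy(logits, targets))                      # True
```

Returning raw logits from forward, or switching the loss to F.nll_loss, would make the intent more explicit without changing the numbers.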

5. Data Visualization

The visualisation step takes one batch of images from the training dataloader, arranges them into a grid with make_grid and displays them with matplotlib. Because the images have been normalised, their colours look distorted, so the step then inverts the normalisation (multiplying each channel by its standard deviation and adding back its mean) and displays the restored images.

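for images, labels in datamodule.train_dataloader():
    break  # keep only the first batch
im = make_grid(images[:12], nrow=6)  # 12 images in a 2x6 grid, shape (3, H, W)
# Normalize(mean=-m/s, std=1/s) computes x * s + m, which undoes the earlier Normalize
inv_normalize = transforms.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
                                     std=[1/0.229, 1/0.224, 1/0.225])
for grid in (im, inv_normalize(im)):  # first the normalised grid, then the restored one
    plt.figure(figsize=(12, 4))
    plt.imshow(np.clip(np.transpose(grid.numpy(), (1, 2, 0)), 0, 1))  # channels last for imshow
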
Figure: Data visualization
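
The inverse-normalisation trick works because Normalize(mean=-m/s, std=1/s) computes (y + m/s) · s = y·s + m, which exactly undoes y = (x − m)/s channel by channel. Here is a minimal numpy check of that identity, using a random array in place of a real image batch; normalize is a hypothetical re-implementation of the per-channel formula. Clipping to [0, 1] is only needed for display, because plt.imshow expects float images in that range.

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])


def normalize(x, mean, std):
    """Per-channel (x - mean) / std on a CxHxW array, as transforms.Normalize does."""
    return (x - mean[:, None, None]) / std[:, None, None]


# Parameters that turn Normalize into its own inverse: (y + m/s) * s = y*s + m
INV_MEAN = -MEAN / STD
INV_STD = 1.0 / STD

rng = np.random.default_rng(0)
image = rng.uniform(0.0, 1.0, size=(3, 4, 4))  # a random "image" already in [0, 1]
restored = normalize(normalize(image, MEAN, STD), INV_MEAN, INV_STD)
print(np.allclose(restored, image))  # True

# For display, move channels last and clip to the [0, 1] range plt.imshow expects
display = np.clip(np.transpose(restored, (1, 2, 0)), 0.0, 1.0)
print(display.shape)  # (4, 4, 3)
```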

6. Evaluating Model Performance

Finally, the model's performance on the test dataset is assessed with a classification report that lists the precision, recall, F1-score and support for each crop class.

Figure: Model performance
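
In this workflow, y_true and y_pred would be collected by running the trained model over test_loader and taking the argmax of each output, then passed to classification_report(y_true, y_pred, target_names=class_names). The minimal numpy sketch below shows what the per-class numbers in such a report mean, using made-up labels rather than results from this model; per_class_report is a hypothetical helper, and it returns 0.0 whenever a denominator is zero.

```python
import numpy as np


def per_class_report(y_true, y_pred, n_classes):
    """Per-class (precision, recall, F1, support) tuples, the quantities
    that sklearn's classification_report prints for each class."""
    rows = []
    for c in range(n_classes):
        tp = np.sum((y_pred == c) & (y_true == c))  # predicted c, truly c
        fp = np.sum((y_pred == c) & (y_true != c))  # predicted c, truly another class
        fn = np.sum((y_pred != c) & (y_true == c))  # truly c, predicted another class
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        rows.append((precision, recall, f1, int(np.sum(y_true == c))))
    return rows


# Made-up labels for a 3-class problem (not results from this model)
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0])
for c, (p, r, f, s) in enumerate(per_class_report(y_true, y_pred, 3)):
    print(f"class {c}: precision={p:.2f} recall={r:.2f} f1={f:.2f} support={s}")
```

Precision answers "when the model predicts this crop, how often is it right?", recall answers "how many images of this crop did it find?", and F1 is their harmonic mean.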


This guide has explained how to build an agricultural crop classifier with a convolutional neural network. By walking through the code stage by stage, from loading and augmenting the data to training and evaluating the model, it has shown how CNNs can be applied to crop classification. Projects like this point towards a future in which automation streamlines agricultural processes, increases productivity and supports more sustainable farming.

Frequently Asked Questions

1. What is a CNN model for classification?

A CNN classifier is a model that assigns each input image to one of a set of predefined classes. It learns to extract relevant features from the input images and to map those features to the correct class.

2. How does a CNN work in plant disease detection?

A deep convolutional neural network can be trained on photographs of both healthy and diseased leaves, so that it learns to distinguish between them and to diagnose the disease affecting a plant. Once trained, the model classifies a new leaf image, and its prediction for that image is the final result.

3. How is AI revolutionizing agriculture?

Artificial intelligence (AI) facilitates more accurate and efficient data-driven decision-making by processing massive datasets, automating laborious operations, and identifying patterns. These days, farmers and researchers can profit greatly from AI-driven technology in a variety of ways, from crop monitoring and harvesting to pest control, irrigation, and seeding.

4. How does a CNN work, step by step?

The pixels of the image are passed to a convolutional layer, which performs the convolution operation and produces a convolved feature map. Applying a ReLU function to the convolved map produces a rectified feature map. The image passes through several convolution and ReLU layers to locate its features; pooling layers then downsample the feature maps, and fully connected layers map the resulting features to class scores.
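
Those steps can be sketched in numpy on a tiny single-channel image. This is a minimal illustration, not how PyTorch implements convolution: the helpers conv2d, relu and max_pool are hypothetical, and the image and edge-detecting kernel are made up for the example. As in PyTorch's convolution layers, the operation is a cross-correlation (the kernel is not flipped).

```python
import numpy as np


def conv2d(image, kernel):
    """'Valid' 2D cross-correlation with stride 1, as CNN layers compute it."""
    kh, kw = kernel.shape
    out_h, out_w = image.shape[0] - kh + 1, image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out


def relu(x):
    return np.maximum(x, 0.0)


def max_pool(x, size=2):
    """Non-overlapping max pooling; leftover rows/columns are dropped."""
    h, w = x.shape[0] // size, x.shape[1] // size
    return x[:h * size, :w * size].reshape(h, size, w, size).max(axis=(1, 3))


# A 6x8 grayscale "image": a bright vertical stripe in columns 2-4
image = np.zeros((6, 8))
image[:, 2:5] = 1.0
kernel = np.array([[-1.0, 0.0, 1.0]] * 3)  # positive on dark-to-bright edges (left to right)

feature_map = conv2d(image, kernel)  # 4x6 convolved feature map
rectified = relu(feature_map)        # negative (bright-to-dark) responses set to 0
pooled = max_pool(rectified)         # 2x3 summary of where the edge was found
print(pooled)
```

The convolved map responds positively at the stripe's left edge and negatively at its right edge; ReLU keeps only the positive response, and pooling shrinks the map while preserving the fact that the edge sits in the left part of the image.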
