Using the Segment Anything Model and GEOBIA to classify vineyard area¶
Segment Anything is a new project by Meta to build two important components:
- A large dataset for image segmentation
- The Segment Anything Model (SAM) as a foundation model for image segmentation
It was introduced in the Segment Anything paper by Alexander Kirillov et al.
This takes inspiration from the field of NLP, where foundation models and large datasets (worth billions of tokens) have become commonplace.
The project leads to the creation of a large dataset, a segmentation model, and a data engine, all in a loop.
Since image segmentation is one of the main tasks in computer vision, the authors chose it as a starting point for such large models and datasets. In both science and artificial intelligence, image segmentation has several potential uses. These include biomedical image analysis, photo editing, and autonomous driving, among others.
Traditionally, solving any of these problems meant training a specialized model that can perform only one task. This requires extensive domain knowledge and time to collect specific data, not to mention the hours of training that deep learning models need.
However, the Segment Anything project aims to democratize the world of image segmentation. By open-sourcing both the dataset and the model (SAM for short), it opens up a huge range of possibilities.
Segment Anything Model¶
The Segment Anything model is an approach to building a fully automatic image segmentation model with minimal human intervention.
Previous deep learning approaches required specialized training data collection, manual annotation, and hours of training. These approaches work well, but they also require a substantial amount of model retraining when changing the dataset.
With SAM, we now have a generalizable, promptable image segmentation model that can cut out almost any object from an image.
SAM is a deep learning model (based on transformers). As with any deep learning model, it was trained on a huge number of images and masks: more than a billion masks across 11 million images, to be precise. That's a considerable number. The dataset is called the Segment Anything dataset, which we will delve into later in this article.
Still, how does SAM know which objects to target in an image? The fact is that it doesn't always know. That's where SAM's promptable design comes in.
SAM as a promptable image segmentation model¶
SAM can receive instructions from users on which area to target precisely. As of the current version, we can provide three different prompts to SAM:
- Clicking on a point
- Drawing a bounding box
- Drawing a rough mask on an object
The authors are also working on a version of SAM that accepts text inputs as prompts, just like in the case of language models.
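To make the prompt types concrete, here is a minimal sketch that packs point and box prompts into a keyword dictionary. The keys `point_coords`, `point_labels`, and `box` follow the argument names of `SamPredictor.predict` in the open-source segment-anything package; the helper `build_prompt` and the sample coordinates are hypothetical, and plain lists stand in for the NumPy arrays the real predictor expects.

```python
from typing import Optional, Sequence


def build_prompt(
    points: Optional[Sequence[tuple]] = None,
    labels: Optional[Sequence[int]] = None,
    box: Optional[tuple] = None,
) -> dict:
    """Pack point and box prompts into a keyword dictionary (hypothetical helper)."""
    prompt = {}
    if points is not None:
        # Every point needs a label: 1 = foreground, 0 = background.
        if labels is None or len(labels) != len(points):
            raise ValueError("each point needs a label: 1 = foreground, 0 = background")
        prompt["point_coords"] = [list(p) for p in points]
        prompt["point_labels"] = list(labels)
    if box is not None:
        # Boxes are given as (x_min, y_min, x_max, y_max) in pixel coordinates.
        x_min, y_min, x_max, y_max = box
        if x_min >= x_max or y_min >= y_max:
            raise ValueError("box must be (x_min, y_min, x_max, y_max)")
        prompt["box"] = [x_min, y_min, x_max, y_max]
    return prompt


# A single foreground click.
click = build_prompt(points=[(120, 85)], labels=[1])
# A box refined by one background click.
box_prompt = build_prompt(points=[(40, 40)], labels=[0], box=(100, 60, 220, 180))
print(click)
print(box_prompt)
```

A rough mask prompt would be passed separately (as `mask_input` in that package) and is left out of this sketch.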
Because SAM was trained on more than 1 billion masks and lets users provide prompts, it is designed to return a valid mask even when a prompt is ambiguous. What exactly does this mean?
Take a click prompt as an example: a user clicks on a person. If SAM is not confident that the click refers to the whole person, it may instead return a mask for the clothing under the click. In the same way, a single prompt can yield masks for a whole object or for just a part of it.
How does the Segment Anything model work?¶
The creators of the Segment Anything Model take inspiration from chat-based Large Language Models, where the prompt is an integral part of the pipeline.
SAM likewise has three important model components:
- An image encoder.
- A prompt encoder.
- A mask decoder.
When we provide an image as input to the Segment Anything model, it first passes through an image encoder and produces a single embedding for the entire image.
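Because the image is embedded once, the same embedding can serve many prompts. The sketch below illustrates that idea with a small cache; `EmbeddingCache` and `toy_encoder` are hypothetical stand-ins, not SAM's code (in the segment-anything package, the comparable split is `SamPredictor.set_image` followed by repeated `predict` calls).

```python
class EmbeddingCache:
    """Run an expensive image encoder once per image and reuse the result."""

    def __init__(self, encoder):
        self.encoder = encoder
        self._cache = {}
        self.encoder_calls = 0

    def embedding_for(self, image_id, image):
        # Only call the encoder the first time an image is seen.
        if image_id not in self._cache:
            self._cache[image_id] = self.encoder(image)
            self.encoder_calls += 1
        return self._cache[image_id]


def toy_encoder(image):
    # Stand-in for the Vision Transformer: the mean pixel value of each row.
    return [sum(row) / len(row) for row in image]


cache = EmbeddingCache(toy_encoder)
image = [[0, 10, 20], [30, 40, 50]]
for prompt in [(0, 0), (1, 2), (1, 1)]:
    # Every prompt reuses the same cached embedding.
    embedding = cache.embedding_for("tile_001", image)

print(embedding)            # [10.0, 40.0]
print(cache.encoder_calls)  # 1
```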
There is also a prompt encoder for points, boxes, or text as prompts. For points, the x and y coordinates, along with foreground and background information, become input to the encoder. For boxes, the bounding box coordinates become the input to the encoder, and for text (not released at the time of writing), tokens become the input.
If we provide a mask as input, it goes directly through a downsampling stage built from 2D convolutional layers. The model then adds the result element-wise to the image embedding.
The image embedding and the prompt embeddings then pass through a lightweight mask decoder that creates the final segmentation mask.
We obtain possible valid masks along with a confidence score as a result.
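Since the output is a set of candidate masks with scores, a common next step is to keep the highest-scoring one. A minimal sketch, with made-up masks and scores (the function name `best_mask` is hypothetical):

```python
def best_mask(masks, scores):
    """Return (index, mask, score) for the candidate with the highest score."""
    if not masks or len(masks) != len(scores):
        raise ValueError("masks and scores must be non-empty and of equal length")
    index = max(range(len(scores)), key=lambda i: scores[i])
    return index, masks[index], scores[index]


# Three hypothetical candidates for one click: a part, the object, the whole scene.
candidate_masks = [
    [[0, 1], [0, 0]],
    [[0, 1], [0, 1]],
    [[1, 1], [1, 1]],
]
candidate_scores = [0.71, 0.93, 0.64]

index, mask, score = best_mask(candidate_masks, candidate_scores)
print(index, score)  # 1 0.93
```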
The image encoder is one of the most powerful and essential components of SAM. It is built on a Vision Transformer (ViT) pre-trained with MAE (masked autoencoder). Because this encoder is heavy, it runs once per image, while the prompt encoder and mask decoder are kept lightweight so that prompts can be answered in real time, even in a browser.
For the prompt encoder, points, boxes, and text act as sparse inputs, and masks act as dense inputs. The creators of SAM represent points and bounding boxes using positional encodings and sum them with learned embeddings. For text prompts, SAM uses the CLIP text encoder. For mask prompts, after downsampling through convolutional layers, the embedding is added element-wise to the image embedding.
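The idea of summing a positional encoding with a learned embedding can be sketched in a few lines. This is a simplified illustration using a plain sinusoidal encoding and made-up embedding values; SAM's own implementation differs in detail, and every name below is hypothetical.

```python
import math


def positional_encoding(x, y, width, height, num_frequencies=4):
    """Sinusoidal encoding of a pixel coordinate normalized to [0, 1]."""
    u, v = x / width, y / height
    features = []
    for k in range(num_frequencies):
        freq = (2 ** k) * math.pi
        features += [math.sin(freq * u), math.cos(freq * u),
                     math.sin(freq * v), math.cos(freq * v)]
    return features


def encode_point(x, y, label, width, height, type_embeddings):
    """Sum the positional encoding with the embedding for the point type."""
    pos = positional_encoding(x, y, width, height)
    learned = type_embeddings[label]
    return [p + e for p, e in zip(pos, learned)]


# Made-up "learned" embeddings: one vector per prompt type (assumption).
dim = 16
type_embeddings = {
    1: [0.1] * dim,   # foreground point
    0: [-0.1] * dim,  # background point
}

token = encode_point(512, 256, 1, width=1024, height=1024,
                     type_embeddings=type_embeddings)
print(len(token))  # 16
```

The same location encoded as a background point would differ only by the learned part, which is how the encoder can tell the two kinds of click apart.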
The Segment Anything dataset¶
The foundation of any innovative deep learning model is the dataset on which it is trained, and the Segment Anything Model is no exception.
The Segment Anything dataset contains over 11 million images and 1.1 billion masks. The final dataset is called the SA-1B dataset.
A dataset of this scale is needed to train the Segment Anything model. However, no such dataset existed beforehand, and annotating that many images entirely by hand is impractical, which is why the project pairs the model with the data engine mentioned earlier.
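A quick back-of-the-envelope calculation shows the scale involved. The image and mask counts come from the figures above; the 30 seconds per mask and the 2,000 working hours per year are assumptions chosen only for illustration.

```python
num_images = 11_000_000
num_masks = 1_100_000_000

masks_per_image = num_masks / num_images
print(masks_per_image)  # 100.0

# Assumptions for illustration only: 30 s per hand-drawn mask, 2,000 working hours per year.
seconds_per_mask = 30
hours_per_year = 2_000
person_years = num_masks * seconds_per_mask / 3600 / hours_per_year
print(f"about {person_years:,.0f} person-years of manual annotation")
```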
SAM pre-trained weights are available from the segment-anything open source project.
As of now, model weights are available for three different scales of Vision Transformer:
- ViT-B SAM
- ViT-L SAM
- ViT-H SAM
The project can also be installed with pip.
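As a sketch of how the three scales map to code: the registry keys (`vit_b`, `vit_l`, `vit_h`) and checkpoint file names below are the ones published in the segment-anything README, which also gives the install command `pip install git+https://github.com/facebookresearch/segment-anything.git`. The helpers `checkpoint_path` and `load_sam` and the `weights` folder are hypothetical, and `load_sam` only works once the package is installed and the checkpoint has been downloaded.

```python
import os

# Registry key -> checkpoint file name, as listed in the segment-anything README.
CHECKPOINTS = {
    "vit_b": "sam_vit_b_01ec64.pth",
    "vit_l": "sam_vit_l_0b3195.pth",
    "vit_h": "sam_vit_h_4b8939.pth",
}


def checkpoint_path(model_type, checkpoint_dir="."):
    """Build the local path where a downloaded checkpoint is expected (hypothetical layout)."""
    if model_type not in CHECKPOINTS:
        raise ValueError(f"unknown model type: {model_type!r}")
    return os.path.join(checkpoint_dir, CHECKPOINTS[model_type])


def load_sam(model_type="vit_b", checkpoint_dir="."):
    """Load a SAM model; needs the segment-anything package and a local checkpoint."""
    from segment_anything import sam_model_registry  # imported lazily on purpose
    return sam_model_registry[model_type](checkpoint=checkpoint_path(model_type, checkpoint_dir))


print(checkpoint_path("vit_b", "weights"))
```

ViT-B is the smallest of the three and ViT-H the largest, so the choice is a trade-off between speed, memory, and mask quality.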
Let's now use SAM to segment a high-resolution NAIP aerial image downloaded from Google Earth Engine and classify each generated segment using GEOBIA (geographic object-based image analysis).
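The object-based idea behind GEOBIA is that we classify whole segments rather than individual pixels. Here is a toy sketch of that step: it averages one band per segment and applies a threshold rule. The function names, the tiny arrays, and the threshold are all hypothetical; a real workflow would use several features per segment and a trained classifier.

```python
from collections import defaultdict


def segment_means(segment_ids, band):
    """Mean band value per segment; both inputs are 2-D lists of equal shape."""
    totals = defaultdict(float)
    counts = defaultdict(int)
    for id_row, value_row in zip(segment_ids, band):
        for seg_id, value in zip(id_row, value_row):
            totals[seg_id] += value
            counts[seg_id] += 1
    return {seg_id: totals[seg_id] / counts[seg_id] for seg_id in totals}


def classify_segments(means, threshold):
    """Toy rule: label a segment 'vineyard' when its mean exceeds the threshold."""
    return {seg_id: ("vineyard" if mean > threshold else "other")
            for seg_id, mean in means.items()}


# Segment ids as a segmentation step might produce them, plus one image band.
segments = [[1, 1, 2],
            [1, 2, 2]]
green = [[200, 180, 60],
         [190, 50, 40]]

means = segment_means(segments, green)
labels = classify_segments(means, threshold=120)
print(means)   # {1: 190.0, 2: 50.0}
print(labels)  # {1: 'vineyard', 2: 'other'}
```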
First let's download the NAIP image of the area of interest:
import geopandas as gpd
import json
import folium
from folium import plugins
from IPython.display import Image
import numpy as np
from matplotlib import pyplot as plt
import cv2
import ee
ee.Authenticate()
ee.Initialize(project='talvaradol')
We select the coordinates of the area of interest:
AOI = ee.Geometry.Polygon([[[-101.903489, 20.736243],
                            [-101.803922, 20.737896],
                            [-101.802756, 20.674544],
                            [-101.902282, 20.672896]]])
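Before exporting, it helps to estimate how large the area is. The sketch below approximates the width and height of the AOI's bounding box from its corner coordinates; the 111,320 m per degree figure is a rough spherical-Earth value, and the function name is hypothetical.

```python
import math

# Corner coordinates (lon, lat) of the AOI defined above.
AOI_COORDS = [(-101.903489, 20.736243), (-101.803922, 20.737896),
              (-101.802756, 20.674544), (-101.902282, 20.672896)]


def approx_extent_m(coords):
    """Approximate width and height in meters of the bounding box of lon/lat points."""
    lons = [c[0] for c in coords]
    lats = [c[1] for c in coords]
    mean_lat = math.radians(sum(lats) / len(lats))
    meters_per_deg_lat = 111_320.0  # rough spherical-Earth value (assumption)
    meters_per_deg_lon = meters_per_deg_lat * math.cos(mean_lat)
    width = (max(lons) - min(lons)) * meters_per_deg_lon
    height = (max(lats) - min(lats)) * meters_per_deg_lat
    return width, height


width_m, height_m = approx_extent_m(AOI_COORDS)
pixels = round(width_m) * round(height_m)  # pixel count at 1 m per pixel
print(f"{width_m / 1000:.1f} km x {height_m / 1000:.1f} km, "
      f"about {pixels / 1e6:.0f} megapixels at 1 m")
```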
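We filter by date and bounds, then take the median to obtain a single mosaic:
# Note: NAIP covers the United States only, so the AOI must fall inside that footprint
startDateviz = ee.Date.fromYMD(2017, 1, 1)
endDateviz = ee.Date.fromYMD(2018, 12, 29)
collectionviz2 = ee.ImageCollection("USDA/NAIP/DOQQ").filterDate(startDateviz, endDateviz).filterBounds(AOI)
# Per-pixel median of all images in the period, clipped to the AOI and cast to 8-bit
Naip = collectionviz2.median().clip(AOI).uint8()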
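To see what the median reduction does, here is a small pure-Python sketch of a per-pixel median over three tiny made-up acquisitions. It only illustrates the idea and is not how Earth Engine computes it internally.

```python
from statistics import median


def median_composite(images):
    """Per-pixel median across a list of equally sized 2-D images."""
    rows, cols = len(images[0]), len(images[0][0])
    return [[median(img[r][c] for img in images) for c in range(cols)]
            for r in range(rows)]


acquisitions = [
    [[100, 102], [98, 250]],   # the last pixel is a bright outlier
    [[101, 100], [99, 101]],
    [[99, 101], [97, 100]],
]

composite = median_composite(acquisitions)
print(composite)  # [[100, 101], [98, 101]]
```

The outlier value of 250 does not survive, which is why a median is a convenient way to merge overlapping acquisitions into one mosaic.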
Now we can export the image to Google Drive:
image_naip = Naip.select(['R','G','B']).reproject('EPSG:4326', None, 1)
task = ee.batch.Export.image.toDrive(image=image_naip,
                                     crs='EPSG:4326',
                                     scale=1,
                                     fileFormat='GeoTIFF',
                                     description='Gto',
                                     maxPixels=1e13,
                                     folder='Eagave',
                                     region=AOI)
task.start()
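Exports run asynchronously, so the file is not in Drive as soon as `task.start()` returns. A minimal polling sketch, assuming the task object exposes `status()` returning a dictionary with a `state` key (as `ee.batch.Task` does); the helper name `wait_for_task` and the polling interval are hypothetical.

```python
import time


def wait_for_task(task, poll_seconds=30, sleep=time.sleep):
    """Poll a batch task until it leaves the READY/RUNNING states, then return its status."""
    while True:
        status = task.status()
        state = status.get("state")
        if state not in ("READY", "RUNNING"):
            return status
        # Still queued or running: wait before asking again.
        sleep(poll_seconds)


# Usage with the export task created above (left commented so the sketch runs on its own):
# final_status = wait_for_task(task)
# print(final_status["state"])
```

A final state of `COMPLETED` means the GeoTIFF has been written; `FAILED` or `CANCELLED` means it has not.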
To assist in collecting samples, we will download a preliminary mapping of viticulture areas:
collectionUsda = ee.ImageCollection('USDA/NASS/CDL').filter(ee.Filter.date('2020-01-01', '2021-12-31')).first().clip(AOI)
# Keep only pixels whose cropland class equals 69 (grapes); all other pixels are masked
cropWinesUsda = collectionUsda.select('cropland').eq(69).selfMask()
cropWinesUsda = cropWinesUsda.select(['cropland']).reproject('EPSG:4326', None, 1)
task = ee.batch.Export.image.toDrive(image=cropWinesUsda,
                                     crs='EPSG:4326',
                                     scale=1,
                                     fileFormat='GeoTIFF',
                                     description='Gto',
                                     maxPixels=1e13,
                                     folder='Eagave',
                                     region=AOI)
task.start()
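The `.eq(69).selfMask()` step can be pictured on a tiny grid. The sketch below mimics it in plain Python, using `None` for masked pixels; the function name and the sample values are made up for illustration.

```python
GRAPES = 69  # crop class selected by the `.eq(69)` call above


def eq_self_mask(band, value):
    """Mimic `.eq(value).selfMask()`: 1 where the pixel matches, None (masked) elsewhere."""
    return [[1 if pixel == value else None for pixel in row] for row in band]


cropland = [[69, 1, 69],
            [5, 69, 176]]

vineyard_mask = eq_self_mask(cropland, GRAPES)
print(vineyard_mask)  # [[1, None, 1], [None, 1, None]]
```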
With the images downloaded to Google Drive, let's perform the segmentation with SAM:
!pip install segment-geospatial leafmap localtileserver
!pip install rasterio
We connect to Drive to access the image and plot it with Matplotlib:
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
path = '/content/drive/MyDrive/Eagave/azul.tif'
import os
import leafmap
import torch
from samgeo import SamGeo, tms_to_geotiff, get_basemaps
import rasterio
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
from rasterio.plot import show
src = rasterio.open(path)
img = src.read()
img.shape
(4, 14039, 20758)
img = img.transpose([1,2,0])
img.max()
255
plt.figure(figsize=[16,16])
plt.imshow(img)
plt.axis('off')
(-0.5, 20757.5, 14038.5, -0.5)
Because the image is very large (14,039 × 20,758 pixels), we halve its resolution to speed up processing:
from rasterio.enums import Resampling

downscale_factor = 2
with rasterio.open(path) as dataset:
    # Read every band at a reduced height and width, using bilinear resampling
    data = dataset.read(
        out_shape=(dataset.count, int(dataset.height / downscale_factor), int(dataset.width / downscale_factor)),
        resampling=Resampling.bilinear,
    )
    # Rescale the affine transform so the smaller raster still covers the same extent
    transform = dataset.transform * dataset.transform.scale(
        (dataset.width / data.shape[-1]), (dataset.height / data.shape[-2])
    )
    profile = dataset.profile
profile.update(transform=transform, driver='GTiff', height=data.shape[-2], width=data.shape[-1], crs=dataset.crs, compress='lzw', dtype='uint8')
with rasterio.open('/content/RGB_resampled.tif', 'w', **profile) as dst:
    dst.write(data)
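The transform update above scales by the ratio of old to new dimensions instead of by the downscale factor itself. The sketch below illustrates why, using the image dimensions reported earlier. The helper name `downscaled_grid` and the pixel size of 1.0 are placeholders for illustration, not part of rasterio or of the original notebook.

```python
def downscaled_grid(height, width, pixel_size, factor):
    """Return (new_height, new_width, new_pixel_w, new_pixel_h) after downscaling."""
    new_height = int(height / factor)
    new_width = int(width / factor)
    # Scale by the ratio of old to new dimensions, not by `factor` itself,
    # so the smaller grid keeps covering the same ground extent.
    scale_x = width / new_width
    scale_y = height / new_height
    return new_height, new_width, pixel_size * scale_x, pixel_size * scale_y


# Height and width come from img.shape above; pixel_size=1.0 is a placeholder unit.
new_h, new_w, px_w, px_h = downscaled_grid(14039, 20758, 1.0, 2)
print(new_h, new_w)  # 7019 10379
print(px_w, px_h)
```

Because the height is odd, `int()` truncates 7019.5 to 7019, so the vertical pixel size ends up slightly larger than 2. Using the ratio absorbs that rounding.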
Then we apply SAM to the resulting image:
image = '/content/RGB_resampled.tif'
sam = SamGeo(
model_type="vit_h",
checkpoint="sam_vit_h_4b8939.pth",
sam_kwargs=None,
)
Model checkpoint for vit_h not found.
Downloading... From: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth To: /root/.cache/torch/hub/checkpoints/sam_vit_h_4b8939.pth 100%|██████████| 2.56G/2.56G [00:19<00:00, 135MB/s]
mask = "segment.tif"
sam.generate(
image, mask, batch=True, foreground=True, erosion_kernel=(3, 3), mask_multiplier=255
)
1%| | 2/294 [19:45<47:58:36, 591.50s/it]
We will convert the mask resulting from SAM into a shapefile:
os.makedirs('/content/Result', exist_ok=True)
shapefile = '/content/Result/segment.shp'
sam.tiff_to_vector(mask, shapefile)
To check the result, let's plot the image along with the segments:
src_test = rasterio.open(image)
mask_result = gpd.read_file(shapefile)
mask_result
| | value | geometry |
|---|---|---|
| 0 | 255.0 | POLYGON ((-119.96786 36.73463, -119.96783 36.7... |
| 1 | 255.0 | POLYGON ((-119.95462 36.73463, -119.95415 36.7... |
| 2 | 255.0 | POLYGON ((-119.95408 36.73463, -119.95394 36.7... |
| 3 | 255.0 | POLYGON ((-119.98776 36.73461, -119.98772 36.7... |
| 4 | 255.0 | POLYGON ((-119.98481 36.73461, -119.98470 36.7... |
| ... | ... | ... |
| 2360 | 255.0 | POLYGON ((-119.94186 36.69201, -119.94185 36.6... |
| 2361 | 255.0 | POLYGON ((-119.94057 36.69201, -119.94055 36.6... |
| 2362 | 255.0 | POLYGON ((-119.98925 36.69199, -119.98923 36.6... |
| 2363 | 255.0 | POLYGON ((-119.98910 36.69201, -119.98896 36.6... |
| 2364 | 255.0 | POLYGON ((-119.98847 36.69199, -119.98844 36.6... |
2365 rows × 2 columns
fig, ax = plt.subplots(figsize=(15, 15))
show(src_test, ax=ax)
mask_result.plot(ax=ax, facecolor='none', edgecolor='red')
<Axes: >
The next step is to collect sample segments for each class. In QGIS, we created two point vector files, one for viticulture areas and one for non-viticulture areas. To make collection easier, we displayed the segments together with the auxiliary mapping we had downloaded.
After collecting the points, we will open the two shapefiles and intersect the points with the segments:
no_wines = gpd.read_file('/content/drive/MyDrive/Datasets/California_Viticulture/No_Wines.shp')
wines = gpd.read_file('/content/drive/MyDrive/Datasets/California_Viticulture/Wines.shp')
no_wines['class'] = 0
wines['class'] = 1
samples = pd.concat([no_wines,wines])
segments = gpd.read_file('/content/drive/MyDrive/Datasets/California_Viticulture/segment.shp')
segments.reset_index(inplace=True)
seg_samples = segments.sjoin(samples, how='inner',predicate='intersects').copy()
These are the segments of the two classes that will be used for training:
seg_samples.plot('class',cmap='viridis')
<Axes: >
seg_samples = seg_samples[~seg_samples['index'].duplicated()]
seg_samples
| | index | value | geometry | index_right | id | class |
|---|---|---|---|---|---|---|
| 160 | 160 | 255.0 | POLYGON ((-119.97884 36.73457, -119.97462 36.7... | 35 | NaN | 1 |
| 232 | 232 | 255.0 | POLYGON ((-119.96073 36.73139, -119.96060 36.7... | 40 | NaN | 1 |
| 248 | 248 | 255.0 | POLYGON ((-119.94747 36.73147, -119.94623 36.7... | 41 | NaN | 1 |
| 286 | 286 | 255.0 | POLYGON ((-119.95164 36.73457, -119.94821 36.7... | 34 | NaN | 1 |
| 290 | 290 | 255.0 | POLYGON ((-119.93617 36.73156, -119.93615 36.7... | 42 | NaN | 1 |
| ... | ... | ... | ... | ... | ... | ... |
| 2209 | 2209 | 255.0 | POLYGON ((-119.93477 36.69505, -119.93450 36.6... | 39 | NaN | 0 |
| 2315 | 2315 | 255.0 | POLYGON ((-119.93543 36.69503, -119.93538 36.6... | 29 | NaN | 0 |
| 2316 | 2316 | 255.0 | POLYGON ((-119.97191 36.69525, -119.97164 36.6... | 20 | NaN | 1 |
| 2334 | 2334 | 255.0 | POLYGON ((-119.97723 36.69533, -119.97717 36.6... | 18 | NaN | 1 |
| 2338 | 2338 | 255.0 | POLYGON ((-119.94308 36.73463, -119.94285 36.7... | 30 | NaN | 0 |
110 rows × 6 columns
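The deduplication step above keeps the first row for each segment index, so a segment touched by points of both classes silently takes whichever label comes first. A minimal sketch of how such conflicts could be detected before training is shown below. The `(segment_id, class)` pairs are invented for illustration and are not taken from the notebook's data.

```python
from collections import defaultdict

# Hypothetical (segment_id, class) pairs, as an inner spatial join might return:
# one row per sample point that falls inside a segment.
joined = [(160, 1), (232, 1), (232, 0), (286, 1), (286, 1), (2209, 0)]

labels_per_segment = defaultdict(set)
for segment_id, label in joined:
    labels_per_segment[segment_id].add(label)

# Segments touched by points of both classes are ambiguous training samples.
conflicting = sorted(s for s, labels in labels_per_segment.items() if len(labels) > 1)
# Segments with a single, unambiguous label.
clean = {s: next(iter(labels)) for s, labels in labels_per_segment.items() if len(labels) == 1}

print(conflicting)  # [232]
print(clean)        # {160: 1, 286: 1, 2209: 0}
```

Dropping or manually reviewing the conflicting segments is one option. Keeping the first label, as the notebook does, is simpler but depends on row order.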
With these segments selected, we extract information from the image for each one. As features we use the mean of the first three bands (R, G, B) and a set of texture properties derived from the gray-level co-occurrence matrix (GLCM) of the first band:
import json
from skimage.feature import graycomatrix, graycoprops
from rasterio.mask import mask  # note: rebinds the earlier `mask` variable (the SAM output filename)

src = rasterio.open(path)
json_dataset = json.loads(seg_samples.to_json())['features']
X = []
Y = []
for i in range(len(json_dataset)):
    coords = json_dataset[i]['geometry']
    # Crop the raster to the segment; pixels outside the polygon get the fill value (0 by default)
    out_image, out_transform = mask(src, [coords], crop=True)
    out_image = out_image.transpose([1,2,0])
    out_image = np.nan_to_num(out_image)
    # Per-band mean over non-zero pixels only, so the fill value is ignored
    R_mean = out_image[:,:,0][np.nonzero(out_image[:,:,0])].mean()
    G_mean = out_image[:,:,1][np.nonzero(out_image[:,:,1])].mean()
    B_mean = out_image[:,:,2][np.nonzero(out_image[:,:,2])].mean()
    # GLCM of the first band: distance 1, angle pi/2
    norm_image = out_image[:,:,0].copy()
    distances = [1]
    angles = [np.pi/2]
    glcm = graycomatrix(norm_image,
                        distances=distances,
                        angles=angles,
                        symmetric=True,
                        normed=True,
                        levels=256)
    contrast = np.array(graycoprops(glcm, 'contrast')).ravel()[0]
    energy = np.array(graycoprops(glcm, 'energy')).ravel()[0]
    homogeneity = np.array(graycoprops(glcm, 'homogeneity')).ravel()[0]
    correlation = np.array(graycoprops(glcm, 'correlation')).ravel()[0]
    dissimilarity = np.array(graycoprops(glcm, 'dissimilarity')).ravel()[0]
    ASM = np.array(graycoprops(glcm,'ASM')).ravel()[0]
    X.append([R_mean,G_mean,B_mean,contrast,energy,homogeneity,correlation,dissimilarity,ASM])
    Y.append(json_dataset[i]['properties']['class'])
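To make the texture features less of a black box, here is a small pure-Python sketch of a symmetric, normalised GLCM for vertically adjacent pixels and a few of the properties derived from it. It follows the standard textbook definitions and is not the scikit-image implementation. The function names and the 2 × 2 test image are made up for illustration, and correlation is left out for brevity.

```python
import math


def glcm_vertical(image, levels):
    """Symmetric, normalised co-occurrence matrix of each pixel and the one directly below it."""
    counts = [[0.0] * levels for _ in range(levels)]
    for r in range(len(image) - 1):
        for c in range(len(image[0])):
            i, j = image[r][c], image[r + 1][c]
            counts[i][j] += 1
            counts[j][i] += 1  # symmetric: count each pair in both directions
    total = sum(map(sum, counts))
    return [[v / total for v in row] for row in counts]


def texture_features(p):
    """Texture properties computed from a normalised co-occurrence matrix p."""
    n = len(p)
    cells = [(i, j, p[i][j]) for i in range(n) for j in range(n)]
    asm = sum(v * v for _, _, v in cells)
    return {
        "contrast": sum(v * (i - j) ** 2 for i, j, v in cells),
        "dissimilarity": sum(v * abs(i - j) for i, j, v in cells),
        "homogeneity": sum(v / (1 + (i - j) ** 2) for i, j, v in cells),
        "ASM": asm,
        "energy": math.sqrt(asm),
    }


# Two gray levels: a dark row on top of a bright row, so every vertical pair is (0, 1).
tiny = [[0, 0],
        [1, 1]]
features = texture_features(glcm_vertical(tiny, levels=2))
print(features)
```

One caveat about the loop above: the whole cropped window is passed to `graycomatrix`, so the zero-filled pixels outside the polygon also contribute pairs to the matrix. Whether that matters for a given dataset is something to check, for example by comparing features for compact and irregular segments.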
This gives us a feature matrix and a label vector, which we split into training and test sets:
X = np.array(X)
Y = np.array(Y)
X.shape
(110, 9)
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from sklearn.metrics import cohen_kappa_score
from sklearn.metrics import confusion_matrix
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size = 0.3, random_state = 42)
We apply the random forest classifier to our data:
clf=RandomForestClassifier(n_estimators=50, n_jobs=8, verbose=3)
clf.fit(X_train,Y_train)
building tree 1 of 50 building tree 2 of 50 building tree 3 of 50 building tree 4 of 50 building tree 5 of 50 building tree 6 of 50building tree 7 of 50 building tree 8 of 50 building tree 9 of 50building tree 10 of 50 building tree 11 of 50 building tree 12 of 50building tree 13 of 50 building tree 14 of 50 building tree 15 of 50 building tree 16 of 50 building tree 17 of 50 building tree 18 of 50building tree 19 of 50 building tree 20 of 50 building tree 21 of 50 building tree 22 of 50building tree 23 of 50 building tree 24 of 50 building tree 25 of 50 building tree 26 of 50 building tree 27 of 50building tree 28 of 50 building tree 29 of 50 building tree 30 of 50 building tree 31 of 50 building tree 32 of 50 building tree 33 of 50 building tree 34 of 50building tree 35 of 50 building tree 36 of 50 building tree 37 of 50 building tree 38 of 50 building tree 39 of 50 building tree 40 of 50 building tree 41 of 50 building tree 42 of 50 building tree 43 of 50 building tree 44 of 50 building tree 45 of 50 building tree 46 of 50 building tree 47 of 50 building tree 48 of 50 building tree 49 of 50building tree 50 of 50
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers. [Parallel(n_jobs=8)]: Done 16 tasks | elapsed: 0.1s [Parallel(n_jobs=8)]: Done 50 out of 50 | elapsed: 0.1s finished
RandomForestClassifier(n_estimators=50, n_jobs=8, verbose=3)
y_pred=clf.predict(X_test)
We observe the accuracy values:
print("Accuracy on training set: {:.2f}".format(clf.score(X_train, Y_train)))
print("Accuracy on test set: {:.2f}".format(clf.score(X_test, Y_test)))
Accuracy on training set: 1.00
Accuracy on test set: 0.82
print(classification_report(Y_test, y_pred))
| | precision | recall | f1-score | support |
|---|---|---|---|---|
| 0 | 0.80 | 0.89 | 0.84 | 18 |
| 1 | 0.85 | 0.73 | 0.79 | 15 |
| accuracy | | | 0.82 | 33 |
| macro avg | 0.82 | 0.81 | 0.81 | 33 |
| weighted avg | 0.82 | 0.82 | 0.82 | 33 |
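The notebook imports `confusion_matrix` and `cohen_kappa_score` but does not call them. As a worked example, the sketch below recomputes the reported metrics, plus Cohen's kappa, from a confusion matrix. The matrix itself is an assumption: it is the one that reproduces the rounded precision, recall and support values above, not an output of the original run.

```python
# Confusion matrix inferred from the rounded report (rows = true class, cols = predicted).
# This is a reconstruction for illustration, not an output of the notebook.
cm = [[16, 2],
      [4, 11]]

total = sum(map(sum, cm))
accuracy = (cm[0][0] + cm[1][1]) / total


def prf(cm, k):
    """Precision, recall and F1 for class k."""
    tp = cm[k][k]
    predicted = sum(row[k] for row in cm)  # column sum
    actual = sum(cm[k])                    # row sum
    precision = tp / predicted
    recall = tp / actual
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


# Cohen's kappa: observed agreement corrected for the agreement expected by chance.
expected = sum(sum(cm[k]) * sum(row[k] for row in cm) for k in range(2)) / total ** 2
kappa = (accuracy - expected) / (1 - expected)

for k in range(2):
    print(k, [round(v, 2) for v in prf(cm, k)])
print("accuracy", round(accuracy, 2), "kappa", round(kappa, 2))
# 0 [0.8, 0.89, 0.84]
# 1 [0.85, 0.73, 0.79]
# accuracy 0.82 kappa 0.63
```

With only 33 test samples, each misclassified segment shifts accuracy by about three percentage points, so these figures are best read as rough indicators.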
The last step is to perform inference on all segments of the image using the trained model. Let's extract the features, but now for all segments:
json_dataset_full = json.loads(segments.to_json())['features']
X_pred = []
id_pred = []
for i in range(len(json_dataset_full)):
    coords = json_dataset_full[i]['geometry']
    out_image, out_transform = mask(src, [coords], crop=True)
    out_image = out_image.transpose([1,2,0])
    out_image = np.nan_to_num(out_image)
    R_mean = out_image[:,:,0][np.nonzero(out_image[:,:,0])].mean()
    G_mean = out_image[:,:,1][np.nonzero(out_image[:,:,1])].mean()
    B_mean = out_image[:,:,2][np.nonzero(out_image[:,:,2])].mean()
    norm_image = out_image[:,:,0].copy()
    distances = [1]
    angles = [np.pi/2]
    glcm = graycomatrix(norm_image,
                        distances=distances,
                        angles=angles,
                        symmetric=True,
                        normed=True,
                        levels=256)
    contrast = np.array(graycoprops(glcm, 'contrast')).ravel()[0]
    energy = np.array(graycoprops(glcm, 'energy')).ravel()[0]
    homogeneity = np.array(graycoprops(glcm, 'homogeneity')).ravel()[0]
    correlation = np.array(graycoprops(glcm, 'correlation')).ravel()[0]
    dissimilarity = np.array(graycoprops(glcm, 'dissimilarity')).ravel()[0]
    ASM = np.array(graycoprops(glcm,'ASM')).ravel()[0]
    X_pred.append([R_mean,G_mean,B_mean,contrast,energy,homogeneity,correlation,dissimilarity,ASM])
    id_pred.append(json_dataset_full[i]['properties']['index'])
X_pred = np.array(X_pred)
ID = np.array(id_pred)
X_pred = np.nan_to_num(X_pred)
img_pred=clf.predict(X_pred)
ID = ID.astype('uint16')
segments['class'] = img_pred
segments
| | index | value | geometry | class |
|---|---|---|---|---|
| 0 | 0 | 255.0 | POLYGON ((-119.96786 36.73463, -119.96783 36.7... | 0 |
| 1 | 1 | 255.0 | POLYGON ((-119.95462 36.73463, -119.95415 36.7... | 0 |
| 2 | 2 | 255.0 | POLYGON ((-119.95408 36.73463, -119.95394 36.7... | 0 |
| 3 | 3 | 255.0 | POLYGON ((-119.98776 36.73461, -119.98772 36.7... | 1 |
| 4 | 4 | 255.0 | POLYGON ((-119.98481 36.73461, -119.98470 36.7... | 1 |
| ... | ... | ... | ... | ... |
| 2360 | 2360 | 255.0 | POLYGON ((-119.94186 36.69201, -119.94185 36.6... | 1 |
| 2361 | 2361 | 255.0 | POLYGON ((-119.94057 36.69201, -119.94055 36.6... | 1 |
| 2362 | 2362 | 255.0 | POLYGON ((-119.98925 36.69199, -119.98923 36.6... | 1 |
| 2363 | 2363 | 255.0 | POLYGON ((-119.98910 36.69201, -119.98896 36.6... | 1 |
| 2364 | 2364 | 255.0 | POLYGON ((-119.98847 36.69199, -119.98844 36.6... | 1 |
2365 rows × 4 columns
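One detail of the inference step above is the `np.nan_to_num(X_pred)` call, which was not needed for the training samples. A likely reason is that some segments contain no non-zero pixels in a band, so the non-zero mean is taken over an empty selection and comes out as NaN. The sketch below mimics that behaviour in plain Python. The helper names and pixel values are illustrative only.

```python
import math


def nonzero_mean(values):
    """Mean of the non-zero entries; NaN when there are none."""
    kept = [v for v in values if v != 0]
    return sum(kept) / len(kept) if kept else math.nan


def nan_to_zero(x):
    """Replace NaN with 0.0, leaving other values untouched."""
    return 0.0 if math.isnan(x) else x


inside_field = [0, 0, 120, 130, 125]  # zeros stand for masked-out pixels
empty_segment = [0, 0, 0]             # e.g. a segment lying entirely on fill pixels

print(nonzero_mean(inside_field))                # 125.0
print(nan_to_zero(nonzero_mean(empty_segment)))  # 0.0
```

Replacing NaN with 0 lets the classifier run on every segment, but such segments still receive a class. Filtering them out before plotting may be preferable.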
Let's plot the result:
names = {0: 'No_Wine', 1: 'Wine'}
segments['class_name'] = segments['class'].replace(names)
fig, ax = plt.subplots(1, figsize=(14,8))
segments.plot(column='class_name', categorical=True, cmap='viridis', linewidth=.6, edgecolor='0.2',
              legend=True, legend_kwds={'bbox_to_anchor':(1, 0.5),'loc':'upper left','fontsize':16,'frameon':False}, ax=ax)
ax.axis('off')
ax.set_title('California Wines',fontsize=20)
plt.tight_layout()