qubvel-org / segmentation_models.pytorch

Semantic segmentation models with 500+ pretrained convolutional and transformer-based backbones.

Top Related Projects

🤗 Transformers: the model-definition framework for state-of-the-art machine learning models in text, vision, audio, and multimodal models, for both inference and training.

Detectron2 is a platform for object detection, segmentation and other visual recognition tasks.

Mask R-CNN for object detection and instance segmentation on Keras and TensorFlow

OpenMMLab Detection Toolbox and Benchmark

Models and examples built with TensorFlow

YOLOv5 🚀 in PyTorch > ONNX > CoreML > TFLite

Quick Overview

Segmentation Models PyTorch is a Python library that provides a collection of deep learning models for image segmentation tasks. It offers pre-trained models, flexible encoder-decoder architectures, and various loss functions, making it easier for researchers and practitioners to implement and experiment with state-of-the-art segmentation algorithms.

Pros

  • Wide range of pre-trained models and architectures
  • Easy-to-use API for quick implementation and experimentation
  • Supports various backbones and loss functions
  • Regularly updated with new models and features

Cons

  • Limited to PyTorch framework
  • May require significant computational resources for training large models
  • Documentation could be more comprehensive for advanced usage
  • Some reported issues with compatibility across different PyTorch versions

Code Examples

  1. Importing and initializing a model:
import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    classes=1,
    activation="sigmoid",
)
  2. Preparing data and training the model:
import torch
from torch.utils.data import DataLoader

# Assume 'train_dataset' is your custom dataset
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# NOTE: smp.utils is the legacy high-level training API; it is deprecated in
# recent releases, where a plain PyTorch training loop is preferred.
loss_fn = smp.utils.losses.DiceLoss()
metrics = [smp.utils.metrics.IoU(threshold=0.5)]

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

train_epoch = smp.utils.train.TrainEpoch(
    model,
    loss=loss_fn,
    metrics=metrics,
    optimizer=optimizer,
    device='cuda',
    verbose=True,
)

for epoch in range(40):
    train_logs = train_epoch.run(train_loader)
  3. Making predictions:
import numpy as np
import torch

# Convert an HWC uint8 image to the NCHW float tensor the model expects
image = np.random.randint(0, 255, (256, 256, 3)).astype('uint8')
x = torch.from_numpy(image).permute(2, 0, 1).float().unsqueeze(0) / 255.0
mask = model.predict(x)  # runs in eval mode with no_grad; shape (1, classes, 256, 256)

Getting Started

To get started with Segmentation Models PyTorch:

  1. Install the library:
pip install segmentation-models-pytorch
  2. Import the library and create a model:
import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    classes=1,
    activation="sigmoid",
)
  3. Prepare your data, define a loss function and optimizer, then train the model as shown in the code examples above.

Competitor Comparisons

🤗 Transformers: the model-definition framework for state-of-the-art machine learning models in text, vision, audio, and multimodal models, for both inference and training.

Pros of transformers

  • Broader scope, covering NLP, vision, audio, and multimodal tasks beyond segmentation
  • Larger community and more frequent updates
  • Extensive documentation and examples

Cons of transformers

  • Steeper learning curve due to its comprehensive nature
  • Potentially heavier and more resource-intensive for simple tasks

Code Comparison

transformers:

from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

segmentation_models.pytorch:

import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    classes=1,
    activation="sigmoid",
)

The transformers library offers a generalized framework spanning NLP, vision, audio, and multimodal tasks, while segmentation_models.pytorch focuses specifically on image segmentation models. transformers provides pre-trained models and easy-to-use interfaces for a wide range of tasks, making it versatile but potentially more complex for beginners. segmentation_models.pytorch, on the other hand, offers a streamlined experience for image segmentation, with a simpler API and focused functionality.
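
For a closer like-for-like comparison, transformers also ships semantic segmentation models. A minimal sketch (the checkpoint name is just one example from the Hub):

from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor

checkpoint = "nvidia/segformer-b0-finetuned-ade-512-512"
processor = SegformerImageProcessor.from_pretrained(checkpoint)
model = SegformerForSemanticSegmentation.from_pretrained(checkpoint)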

Detectron2 is a platform for object detection, segmentation and other visual recognition tasks.

Pros of Detectron2

  • More comprehensive, supporting object detection, instance segmentation, and other tasks beyond semantic segmentation
  • Highly modular and extensible architecture, allowing for easy customization and experimentation
  • Backed by Facebook AI Research, with frequent updates and a large community

Cons of Detectron2

  • Steeper learning curve due to its complexity and extensive feature set
  • Heavier resource requirements, potentially overkill for simpler segmentation tasks
  • Less focused on semantic segmentation specifically compared to segmentation_models.pytorch

Code Comparison

segmentation_models.pytorch:

import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    classes=1,
    activation="sigmoid",
)

Detectron2:

from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor

# Build a config for a pretrained Mask R-CNN and wrap it in a predictor
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
predictor = DefaultPredictor(cfg)

Mask R-CNN for object detection and instance segmentation on Keras and TensorFlow

Pros of Mask_RCNN

  • Specialized for instance segmentation tasks
  • Includes pre-trained models on COCO dataset
  • Provides visualization tools for results

Cons of Mask_RCNN

  • Less flexible for other segmentation tasks
  • Requires more computational resources
  • Limited to Keras/TensorFlow backend

Code Comparison

Mask_RCNN:

import mrcnn.model as modellib

# 'config' is an InferenceConfig instance as defined in the Mask_RCNN samples
model = modellib.MaskRCNN(mode="inference", config=config, model_dir=MODEL_DIR)
model.load_weights(COCO_MODEL_PATH, by_name=True)
results = model.detect([image], verbose=1)

segmentation_models.pytorch:

import segmentation_models_pytorch as smp

model = smp.Unet('resnet34', encoder_weights='imagenet', classes=1, activation='sigmoid')
pred_mask = model(image)  # image: float tensor of shape (N, 3, H, W)

The Mask_RCNN code focuses on instance segmentation with pre-trained weights, while segmentation_models.pytorch offers a more flexible approach for various segmentation tasks using PyTorch.

OpenMMLab Detection Toolbox and Benchmark

Pros of mmdetection

  • Broader scope: Supports object detection, instance segmentation, and panoptic segmentation
  • Extensive model zoo with pre-trained weights for various architectures
  • Highly modular design allowing easy customization and extension

Cons of mmdetection

  • Steeper learning curve due to its comprehensive nature
  • Potentially more complex setup and configuration process
  • May be overkill for simple segmentation tasks

Code Comparison

mmdetection:

from mmdet.apis import init_detector, inference_detector

config_file = 'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
checkpoint_file = 'checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
model = init_detector(config_file, checkpoint_file, device='cuda:0')
result = inference_detector(model, 'test.jpg')

segmentation_models.pytorch:

import segmentation_models_pytorch as smp

model = smp.Unet('resnet34', encoder_weights='imagenet', classes=1, activation='sigmoid')
pred_mask = model(image)  # image: float tensor of shape (N, 3, H, W)

The code comparison shows that mmdetection requires more setup but offers greater flexibility, while segmentation_models.pytorch provides a simpler interface for segmentation tasks.

Models and examples built with TensorFlow

Pros of models

  • Broader scope: Covers various ML tasks beyond just segmentation
  • Official TensorFlow repository: Well-maintained with regular updates
  • Extensive documentation and examples for multiple model architectures

Cons of models

  • Steeper learning curve due to its comprehensive nature
  • May be overwhelming for users focused solely on segmentation tasks
  • Requires more setup and configuration for specific use cases

Code Comparison

models (TensorFlow):

import tensorflow as tf
from official.vision.image_classification import resnet_model

model = resnet_model.resnet50(num_classes=1000)

segmentation_models.pytorch:

import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    classes=1,
)

Summary

models offers a comprehensive suite of TensorFlow implementations for various ML tasks, while segmentation_models.pytorch focuses specifically on image segmentation using PyTorch. The former provides broader functionality but may be more complex for beginners, while the latter offers a streamlined experience for segmentation tasks. Code-wise, models requires more setup, whereas segmentation_models.pytorch provides a more concise API for creating segmentation models.

YOLOv5 🚀 in PyTorch > ONNX > CoreML > TFLite

Pros of YOLOv5

  • Specialized for object detection tasks with real-time performance
  • Extensive documentation and community support
  • Includes pre-trained models and easy-to-use inference scripts

Cons of YOLOv5

  • Limited to object detection, not suitable for general segmentation tasks
  • May require more computational resources for training and inference
  • Less flexibility in terms of backbone architectures compared to segmentation_models.pytorch

Code Comparison

YOLOv5:

import torch

# The documented usage loads pretrained weights via torch.hub
model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
results = model('image.jpg')
results.show()

segmentation_models.pytorch:

import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    classes=1,
    activation="sigmoid",
)

YOLOv5 focuses on simplicity and ease of use for object detection, while segmentation_models.pytorch offers more flexibility for various segmentation tasks with different architectures and encoders. YOLOv5 is better suited for real-time object detection, whereas segmentation_models.pytorch provides a broader range of segmentation models for diverse applications.

README

Python library with Neural Networks for Image Semantic Segmentation based on PyTorch.

The main features of the library are:

  • Super simple high-level API (just two lines to create a neural network)
  • 12 encoder-decoder model architectures (Unet, Unet++, Segformer, DPT, ...)
  • 800+ pretrained convolution- and transformer-based encoders, including timm support
  • Popular metrics and losses for training routines (Dice, Jaccard, Tversky, ...); a short usage sketch follows this list
  • ONNX export and torch script/trace/compile friendly
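
As a taste of the metrics and losses mentioned above, here is a minimal sketch using the smp.losses and smp.metrics modules (the exact module layout can vary between releases):

import torch
import segmentation_models_pytorch as smp

# Dice loss computed on raw logits for a binary task
loss_fn = smp.losses.DiceLoss(mode="binary")
logits = torch.randn(4, 1, 64, 64)                    # model output
target = torch.randint(0, 2, (4, 1, 64, 64)).float()  # ground-truth masks
loss = loss_fn(logits, target)

# IoU computed from confusion-matrix statistics
tp, fp, fn, tn = smp.metrics.get_stats(
    (logits.sigmoid() > 0.5).long(), target.long(), mode="binary"
)
iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")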

Community-Driven Project, Supported By

withoutBG API
https://withoutbg.com

High-quality background removal API

📚 Project Documentation 📚

Visit the Read The Docs project page or read the following README to learn more about the Segmentation Models PyTorch (SMP for short) library

📋 Table of contents

  1. Quick start
  2. Examples
  3. Models and encoders
  4. Models API
    1. Input channels
    2. Auxiliary classification output
    3. Depth
  5. Installation
  6. Competitions won with the library
  7. Contributing
  8. Citing
  9. License

⏳ Quick start

1. Create your first Segmentation model with SMP

The segmentation model is just a PyTorch torch.nn.Module, which can be created as easily as:

import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet34",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=1,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=3,                      # model output channels (number of classes in your dataset)
)
  • see table with available model architectures
  • see table with available encoders and their corresponding weights

2. Configure data preprocessing

All encoders have pretrained weights. Preparing your data the same way as during weight pre-training may give you better results (a higher metric score and faster convergence). This is not necessary if you train the whole model rather than only the decoder.

from segmentation_models_pytorch.encoders import get_preprocessing_fn

preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')

Congratulations! You are done! Now you can train your model with your favorite framework!
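
As a minimal sketch of what that can look like in plain PyTorch (the image/mask pair below is random data, and all names are illustrative):

import numpy as np
import torch
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import get_preprocessing_fn

model = smp.Unet("resnet18", encoder_weights="imagenet", in_channels=3, classes=1)
preprocess_input = get_preprocessing_fn("resnet18", pretrained="imagenet")

# One random HWC uint8 image and a binary mask stand in for a real dataset
image = np.random.randint(0, 255, (256, 256, 3)).astype("uint8")
mask = np.random.randint(0, 2, (256, 256, 1)).astype("float32")

x = torch.from_numpy(preprocess_input(image)).permute(2, 0, 1).float().unsqueeze(0)
y = torch.from_numpy(mask).permute(2, 0, 1).unsqueeze(0)

# A single optimization step
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = torch.nn.BCEWithLogitsLoss()

optimizer.zero_grad()
loss = loss_fn(model(x), y)
loss.backward()
optimizer.step()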

💡 Examples

| Name | Link | Colab |
|------|------|-------|
| Train pets binary segmentation on OxfordPets | Notebook | Open In Colab |
| Train cars binary segmentation on CamVid | Notebook | Open In Colab |
| Train multiclass segmentation on CamVid | Notebook | Open In Colab |
| Train clothes binary segmentation by @ternaus | Repo | |
| Load and inference pretrained Segformer | Notebook | Open In Colab |
| Load and inference pretrained DPT | Notebook | Open In Colab |
| Load and inference pretrained UPerNet | Notebook | Open In Colab |
| Save and load models locally / to HuggingFace Hub | Notebook | Open In Colab |
| Export trained model to ONNX | Notebook | Open In Colab |

📦 Models and encoders

Architectures

| Architecture | Paper | Documentation | Checkpoints |
|--------------|-------|---------------|-------------|
| Unet | paper | docs | |
| Unet++ | paper | docs | |
| MAnet | paper | docs | |
| Linknet | paper | docs | |
| FPN | paper | docs | |
| PSPNet | paper | docs | |
| PAN | paper | docs | |
| DeepLabV3 | paper | docs | |
| DeepLabV3+ | paper | docs | |
| UPerNet | paper | docs | checkpoints |
| Segformer | paper | docs | checkpoints |
| DPT | paper | docs | checkpoints |

Encoders

The library provides a wide range of pretrained encoders (also known as backbones) for segmentation models. Instead of using features from the final layer of a classification model, we extract intermediate features and feed them into the decoder for segmentation tasks.

All encoders come with pretrained weights, which help achieve faster and more stable convergence when training segmentation models.

Given the extensive selection of supported encoders, you can choose the best one for your specific use case, for example:

  • Lightweight encoders for low-latency applications or real-time inference on edge devices (mobilenet/mobileone).
  • High-capacity architectures for complex tasks involving a large number of segmented classes, providing superior accuracy (convnext/swin/mit).

By selecting the right encoder, you can balance efficiency, performance, and model complexity to suit your project needs.

All encoders and their corresponding pretrained weights are listed in the documentation:

  • table with natively ported encoders
  • table with supported timm encoders
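
timm encoders are selected with the tu- prefix. A hedged example (whether the specific encoder name below is available depends on your installed timm version):

import segmentation_models_pytorch as smp

# The "tu-" prefix routes encoder creation through timm
model = smp.Unet(
    encoder_name="tu-convnext_tiny",
    encoder_weights="imagenet",
    in_channels=3,
    classes=2,
)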

🔁 Models API

Input channels

The in_channels parameter allows you to create a model that can process a tensor with an arbitrary number of channels. If you use pretrained weights from ImageNet, the weights of the first convolution will be reused:

  • For the 1-channel case, the new weight is the sum of the pretrained weights over the three input channels.
  • Otherwise, channels are populated cyclically, as new_weight[:, i] = pretrained_weight[:, i % 3], and then scaled by new_weight * 3 / new_in_channels.

model = smp.FPN('resnet34', in_channels=1)
mask = model(torch.ones([1, 1, 64, 64]))
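
To make the channel-repetition rule concrete, here is a standalone sketch of the rule described above (illustrative only, not SMP's internal code):

import torch

# Pretend first-conv weights of a pretrained model: (out_channels, in_channels=3, k, k)
pretrained_weight = torch.randn(64, 3, 7, 7)
new_in_channels = 5

new_weight = torch.empty(64, new_in_channels, 7, 7)
for i in range(new_in_channels):
    new_weight[:, i] = pretrained_weight[:, i % 3]  # repeat RGB filters cyclically
new_weight = new_weight * 3 / new_in_channels       # rescale to keep activations comparable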

Auxiliary classification output

All models support the aux_params parameter, which defaults to None. If aux_params = None, no auxiliary classification output is created; otherwise, the model produces not only a mask but also a label output of shape (N, C). The classification head consists of GlobalPooling->Dropout(optional)->Linear->Activation(optional) layers and can be configured via aux_params as follows:

aux_params=dict(
    pooling='avg',             # one of 'avg', 'max'
    dropout=0.5,               # dropout ratio, default is None
    activation='sigmoid',      # activation function, default is None
    classes=4,                 # define number of output labels
)
model = smp.Unet('resnet34', classes=4, aux_params=aux_params)
mask, label = model(x)
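
A hedged, runnable completion of the snippet above, showing the output shapes:

import torch
import segmentation_models_pytorch as smp

aux_params = dict(pooling='avg', dropout=0.5, activation='sigmoid', classes=4)
model = smp.Unet('resnet34', classes=4, aux_params=aux_params)

x = torch.randn(2, 3, 256, 256)
mask, label = model(x)
print(mask.shape)   # torch.Size([2, 4, 256, 256])
print(label.shape)  # torch.Size([2, 4])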

Depth

The encoder_depth parameter specifies the number of downsampling operations in the encoder, so you can make your model lighter by specifying a smaller depth. Note that decoder_channels must then be given the same number of entries as encoder_depth.

model = smp.Unet('resnet34', encoder_depth=4, decoder_channels=[256, 128, 64, 32])

🛠 Installation

PyPI version:

$ pip install segmentation-models-pytorch

The latest version from GitHub:

$ pip install git+https://github.com/qubvel/segmentation_models.pytorch

🏆 Competitions won with the library

Segmentation Models package is widely used in image segmentation competitions. Here you can find competitions, names of the winners and links to their solutions.

🤝 Contributing

  1. Install SMP in dev mode
make install_dev  # Create .venv, install SMP in dev mode
  2. Run tests and code checks
make test          # Run test suite with pytest
make fixup         # Ruff for formatting and lint checks
  3. Update a table (in case you added an encoder)
make table        # Generates a table with encoders and prints it to stdout

📝 Citing

@misc{Iakubovskii:2019,
  Author = {Pavel Iakubovskii},
  Title = {Segmentation Models Pytorch},
  Year = {2019},
  Publisher = {GitHub},
  Journal = {GitHub repository},
  Howpublished = {\url{https://github.com/qubvel/segmentation_models.pytorch}}
}

🛡️ License

The project is primarily distributed under the MIT License, while some files are subject to other licenses. Please refer to the LICENSES directory and the license statements in each file for a careful check, especially for commercial use.