Adding a backend
A backend wraps a model runtime and exposes a uniform interface to the pipeline. Add a backend when you need a new runtime (e.g. TensorFlow, TFLite, scikit-learn).
Steps
A backend is a one-file plugin. We will use a fictional backend called MyBackend. It depends on Torch, but your real backend might not; the following is just an example.
1. Subclass ModelBackend and decorate with @backends.register(provides=..., extensions=...).
2. Declare provides and extensions: provides is the frozenset[Capability] your backend offers; extensions is the set of file suffixes it loads. The decorator type-checks both, sets them as class variables, and indexes the backend by extension so model loading resolves the right backend for a given file.
3. Declare extra (and supported_hardware if hardware-split): extra is the uv extra that installs your runtime library (e.g. "torch", "xgboost"). raitap-deps reads it without importing your module (via an AST scan of the decorator) to tell users which extra to install for your file format. Add supported_hardware={ResolvedHardware.cpu, ...} only if your library ships a distinct wheel per accelerator; the installable extra is then f"{extra}-{hw.pyproject_extra_suffix}" (e.g. torch-cpu). Omit it for single-wheel runtimes (the extra is the bare extra on all hardware). A file-backed backend without extra is invisible to deps inference and falls back to the torch default.
4. Implement the abstract methods: from_path (construct from a model file), __call__ (run inference), and the hardware_label property. predict_callable is inherited: it returns self.__call__, the universal forward-only shape that model-agnostic explainers consume. You do not implement it. autograd_module is opt-in: implement it (return the live torch nn.Module) and declare Capability.AUTOGRAD ONLY if your backend exposes a differentiable torch module. Gradient explainers and attacks get this shape; model-agnostic ones get the predict callable.
```python
from pathlib import Path
from typing import Any

import torch
from torch import nn

from raitap import backends
from raitap.models.backend import ModelBackend
from raitap.types import Capability, ResolvedHardware


@backends.register(
    provides={Capability.AUTOGRAD},
    extensions={".pth", ".pt"},
    extra="mybackend",
    supported_hardware={ResolvedHardware.cpu, ResolvedHardware.cuda},  # ships cpu + cuda wheels
)
class MyBackend(ModelBackend):
    def __init__(self, model: nn.Module) -> None:
        self.model = model

    @classmethod
    def from_path(
        cls, path: Path, *, model_cfg: Any, hardware: str, allow_unsafe_pickle: bool = False
    ) -> ModelBackend:  # load a model file into this backend
        ...

    def __call__(self, inputs: torch.Tensor) -> Any:  # run inference, return raw output
        return self.model(inputs)

    @property
    def hardware_label(self) -> str:  # free-form display label for the run summary
        # Placeholder: neither this helper nor self.device is defined in this example.
        return get_hardware_label_for_mybackend(self.device)

    def autograd_module(self) -> nn.Module:  # only if AUTOGRAD-capable
        return self.model
```
A non-torch or forward-only backend (e.g. ONNX) declares provides=FORWARD_ONLY (the empty capability set, imported from raitap.types), skips autograd_module, and runs model-agnostic explainers only.
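To make the "indexes the backend by extension" step concrete, here is a minimal, self-contained sketch of a registration decorator. It is not raitap's implementation: register, resolve_backend, _BY_EXTENSION, and the two toy classes are hypothetical stand-ins, and capabilities are plain strings rather than Capability members.

```python
from pathlib import Path

# Hypothetical extension index; raitap's real registry may look different.
_BY_EXTENSION: dict[str, type] = {}


def register(*, provides, extensions):
    """Toy decorator: check arguments, set class variables, index by suffix."""
    provides = frozenset(provides)
    extensions = frozenset(extensions)
    for ext in extensions:
        if not (isinstance(ext, str) and ext.startswith(".")):
            raise TypeError(f"extension must be a dotted suffix, got {ext!r}")

    def decorator(cls):
        cls.provides = provides
        cls.extensions = extensions
        for ext in extensions:
            _BY_EXTENSION[ext.lower()] = cls
        return cls

    return decorator


def resolve_backend(path):
    """Return the class registered for the file's suffix (case-insensitive)."""
    suffix = Path(path).suffix.lower()
    try:
        return _BY_EXTENSION[suffix]
    except KeyError:
        raise LookupError(f"no backend registered for {suffix!r}") from None


@register(provides={"autograd"}, extensions={".pth", ".pt"})
class ToyTorchBackend:
    pass


@register(provides=set(), extensions={".onnx"})  # empty set: forward-only
class ToyOnnxBackend:
    pass
```

In this sketch, resolve_backend("model.pt") returns ToyTorchBackend, and a suffix nobody registered raises LookupError.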
Tree / tabular backend
Tree-ensemble runtimes (XGBoost, LightGBM, scikit-learn) follow a separate base class: TabularTreeBackend. It owns the torch-to-numpy bridge, the fitted_estimator() accessor, and the (N, C) probability output shape. Subclass it instead of raw ModelBackend.
The concrete subclass implements two methods:
- from_path: defer the library import inside the method so the import error surfaces only when the backend is actually used, not at module load. Raise ImportError with an install hint if the library is absent.
- _predict_proba: call the fitted estimator and return an (N, C) numpy array of class probabilities.
Register with provides={Capability.TREE_MODEL, Capability.PREDICT_PROBA}, a file extension, and extra="xgboost". No supported_hardware: XGBoost ships a single wheel (the bare xgboost extra on all hardware). The fitted_estimator() accessor satisfies the EstimatorProvider protocol, which shap.TreeExplainer consumes directly. The predict_callable method (inherited) returns a callable over the numpy-bridge probabilities, which enables model-agnostic SHAP explainers (e.g. KernelExplainer) on tree backends for free.
```python
from pathlib import Path
from typing import Any

import numpy as np

from raitap import backends
from raitap.models.tree_backend import TabularTreeBackend
from raitap.types import Capability


@backends.register(
    provides={Capability.TREE_MODEL, Capability.PREDICT_PROBA},
    extensions={".ubj"},
    extra="xgboost",  # single wheel -> bare extra, no supported_hardware
)
class XGBoostBackend(TabularTreeBackend):
    # __init__(estimator) is inherited from TabularTreeBackend.

    @classmethod
    def from_path(
        cls, path: Path, *, model_cfg: Any, hardware: str, allow_unsafe_pickle: bool = False
    ) -> "XGBoostBackend":
        try:
            import xgboost  # deferred: only required with --extra xgboost
        except ImportError as exc:
            raise ImportError(
                "XGBoost is not installed. Run: uv sync --extra xgboost"
            ) from exc
        estimator = xgboost.XGBClassifier()
        estimator.load_model(str(path))
        return cls(estimator)

    def _predict_proba(self, x: np.ndarray) -> np.ndarray:  # (N, C)
        return self._estimator.predict_proba(x)
```
TabularTreeBackend provides hardware_label and the CPU/classification defaults, so your subclass inherits them and you do not need to override them unless your runtime supports GPU placement.
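The two registrations on this page differ in how their installable extra is named: the XGBoost backend uses the bare extra, while a hardware-split backend gets a per-hardware suffix. The sketch below shows one way a tool could read those keywords without importing the backend module, using only the standard ast module. It is an illustration under assumptions, not the raitap-deps code: scan_registrations and installable_extra are hypothetical names, and the hardware suffix is passed in explicitly because the mapping behind pyproject_extra_suffix is not shown here.

```python
import ast

# Source text is only parsed, never executed, so the undefined names are harmless.
SOURCE = '''
@backends.register(
    provides={Capability.TREE_MODEL, Capability.PREDICT_PROBA},
    extensions={".ubj"},
    extra="xgboost",
)
class XGBoostBackend(TabularTreeBackend):
    pass

@backends.register(
    provides={Capability.AUTOGRAD},
    extensions={".pth", ".pt"},
    extra="mybackend",
    supported_hardware={ResolvedHardware.cpu, ResolvedHardware.cuda},
)
class MyBackend(ModelBackend):
    pass
'''


def scan_registrations(source):
    """Map each extension to (extra, sorted hardware member names)."""
    found = {}
    for node in ast.walk(ast.parse(source)):
        if not isinstance(node, ast.ClassDef):
            continue
        for deco in node.decorator_list:
            is_register = (
                isinstance(deco, ast.Call)
                and isinstance(deco.func, ast.Attribute)
                and deco.func.attr == "register"
            )
            if not is_register:
                continue
            kwargs = {kw.arg: kw.value for kw in deco.keywords}
            extra = ast.literal_eval(kwargs["extra"]) if "extra" in kwargs else None
            hw_node = kwargs.get("supported_hardware")
            hardware = []
            if isinstance(hw_node, ast.Set):
                # Members look like ResolvedHardware.cpu; keep the attribute name.
                hardware = sorted(
                    elt.attr for elt in hw_node.elts if isinstance(elt, ast.Attribute)
                )
            for ext in ast.literal_eval(kwargs["extensions"]):
                found[ext] = (extra, hardware)
    return found


def installable_extra(extra, hardware_suffix=None):
    """Bare extra for single-wheel runtimes, extra-suffix when hardware-split."""
    return extra if hardware_suffix is None else f"{extra}-{hardware_suffix}"
```

Here scan_registrations(SOURCE)[".ubj"] is ("xgboost", []), so the bare extra applies, while installable_extra("torch", "cpu") gives torch-cpu as in the example from the steps.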
Which capabilities to declare
Most backends provide {Capability.AUTOGRAD} (torch) or FORWARD_ONLY (forward-only, e.g. ONNX). See Backend capabilities for the full list, what each means, and which algorithms require it.
Optional attributes
Override these class or instance attributes as needed:
| Attribute | Type | Default | Purpose |
|---|---|---|---|
| | | | Per-sample input shape. |
| | | | Class id-to-name table (e.g. from model weights metadata). |
| | | | Task family this backend serves. Override for detection etc. |
No gate code needed
Do not write compatibility checks in your backend. The shared gate (AdapterMixin.check_backend_compat) is inherited by every adapter and raises BackendIncompatibilityError automatically when algorithm.requires - backend.provides is non-empty. Rule: an algorithm runs on a backend iff algorithm.requires <= backend.provides.
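The rule is plain set arithmetic, which the following sketch illustrates. The Capability enum, FORWARD_ONLY, BackendIncompatibilityError, and check_backend_compat defined here are stand-ins that borrow the names from this page; the real signature of AdapterMixin.check_backend_compat is an assumption and may differ.

```python
from enum import Enum, auto


class Capability(Enum):  # stand-in; the real enum lives in raitap.types
    AUTOGRAD = auto()
    TREE_MODEL = auto()
    PREDICT_PROBA = auto()


FORWARD_ONLY = frozenset()  # the empty capability set


class BackendIncompatibilityError(Exception):
    """Stand-in for the error raised by the shared gate."""


def check_backend_compat(requires, provides):
    """Raise if the algorithm requires a capability the backend does not provide."""
    missing = frozenset(requires) - frozenset(provides)
    if missing:
        names = ", ".join(sorted(cap.name for cap in missing))
        raise BackendIncompatibilityError(f"backend lacks: {names}")


# A model-agnostic explainer requires nothing, so it passes on any backend.
check_backend_compat(FORWARD_ONLY, FORWARD_ONLY)

# A tree explainer passes on a backend that provides a superset of its needs.
check_backend_compat(
    {Capability.TREE_MODEL},
    {Capability.TREE_MODEL, Capability.PREDICT_PROBA},
)
```

A gradient explainer that requires {Capability.AUTOGRAD} would raise BackendIncompatibilityError on a FORWARD_ONLY backend in this sketch, because the set difference is non-empty.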