Adding a task family¶
A task family is the shape of the model's forward output: classification returns a (batch, num_classes) tensor, detection returns list[dict[str, Tensor]], segmentation a per-pixel mask, seq2seq variable-length token sequences. Each family is one strategy object: a TaskFamily subclass in src/raitap/task_families/ that owns every task-specific behaviour the pipeline phases used to branch on. ClassificationFamily and DetectionFamily are the two registered families; copy whichever is closer to your case.
This page covers adding a new family. Adding a new adapter to an existing family is much smaller: see Adding an adapter.
When to add a task family¶
Trigger: the forward output is not a scalar-per-sample tensor and the user-facing semantics of explanation / metrics / robustness differ from classification.
If the output reduces to a scalar via a thin wrapper without changing what the user sees (e.g. "explain the score of box #0" instead of "explain each box"), you do not need a new family. Write a wrapper and reuse TaskKind.classification (pattern: ScalarDetectionWrapper in src/raitap/models/task_wrappers.py). Detection earned a new family because per-box semantics differ from per-class: K boxes per sample, each with its own attribution map and ground-truth match.
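The thin-wrapper route can be sketched roughly as follows. This is an illustration of the idea, not ScalarDetectionWrapper itself: the class name FirstBoxScore, the "scores" key, the 0.0 fallback for samples without boxes, and the use of plain Python lists in place of tensors are all assumptions made so the example runs without torch.

```python
from typing import Callable

# Hypothetical stand-in for a detector output: one dict per sample, with
# plain lists where a real model would return tensors.
Detections = list[dict[str, list[float]]]


class FirstBoxScore:
    """Reduce a detection output to one scalar per sample (score of one box)."""

    def __init__(self, detector: Callable[[list], Detections], box_index: int = 0):
        self.detector = detector
        self.box_index = box_index

    def __call__(self, batch: list) -> list[float]:
        outputs = self.detector(batch)
        # Assumption: a sample with too few boxes contributes 0.0, so the
        # result always holds exactly one scalar per sample.
        return [
            out["scores"][self.box_index] if len(out["scores"]) > self.box_index else 0.0
            for out in outputs
        ]


def fake_detector(batch: list) -> Detections:
    # Two boxes for the first sample, none for the second.
    return [{"scores": [0.9, 0.4]}, {"scores": []}][: len(batch)]


wrapped = FirstBoxScore(fake_detector)
scores = wrapped(["img0", "img1"])
```

The wrapped callable now has the scalar-per-sample shape that the classification path expects, which is why no new family is needed in this case.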
1. Add a TaskKind member¶
TaskKind in src/raitap/types.py is a StrEnum. Add one member; name and value must be identical and lowercase so OmegaConf round-trips it through YAML. This is the only edit to a shared file.
class TaskKind(StrEnum):
    classification = "classification"
    detection = "detection"
    segmentation = "segmentation"
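The naming rule is easy to guard with a small check. The sketch below is a hypothetical helper, not part of raitap; it uses a (str, Enum) stand-in for the real StrEnum only so that it also runs on Python versions before 3.11.

```python
from enum import Enum


class TaskKind(str, Enum):
    # Stand-in for raitap.types.TaskKind, which is a StrEnum in the real code.
    classification = "classification"
    detection = "detection"
    segmentation = "segmentation"


def badly_named_members(enum_cls: type[Enum]) -> list[str]:
    """Return members whose name and value differ or are not lowercase."""
    return [
        member.name
        for member in enum_cls
        if member.name != member.value or member.value != member.value.lower()
    ]


offenders = badly_named_members(TaskKind)
# A plain string, as it would come back from YAML, resolves to the same member.
round_tripped = TaskKind("segmentation")
```

An empty offenders list means every member satisfies the rule stated above.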
2. Implement the TaskFamily subclass¶
Add src/raitap/task_families/<family>.py. The @task_family decorator instantiates the class once and registers that singleton under its kind. A family defines two class attributes plus the TaskFamily Protocol members declared in src/raitap/task_families/base.py:
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from raitap.task_families.registry import task_family
from raitap.transparency.contracts import ExplanationOutputSpace
from raitap.types import TaskKind
from raitap.utils.lazy import lazy_import

if TYPE_CHECKING:
    import torch

    from raitap.task_families.base import ExplainContext, ForwardContext
else:
    torch = lazy_import("torch")  # this module is imported at CLI startup, before torch exists


@task_family
class SegmentationFamily:
    kind: TaskKind = TaskKind.segmentation
    fixed_output_space: ExplanationOutputSpace | None = ExplanationOutputSpace.IMAGE_SPATIAL_MAP

    def validate_payload(self, payload: object) -> None:
        if not isinstance(payload, torch.Tensor):
            raise ValueError("segmentation payload must be a (N, C, H, W) mask tensor.")

    def explain(self, ctx: ExplainContext) -> list:
        ...  # the rest of the Protocol members below
Use lazy_import("torch"), never a top-level import torch: this file is imported eagerly for @task_family registration, and the bare CLI bootstrap imports it before installing torch. A top-level import crashes raitap --demo -y (see Adding an adapter).
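Why deferring the import helps can be seen with a small proxy. This sketches the general technique only; it is not raitap.utils.lazy.lazy_import, whose implementation is not shown on this page, and the names LazyModule and lazy_import_sketch are hypothetical.

```python
import importlib


class LazyModule:
    """Proxy that imports the named module on first attribute access."""

    def __init__(self, name: str) -> None:
        self._name = name
        self._module = None

    def __getattr__(self, attr: str):
        # __getattr__ only runs for attributes not found on the proxy itself,
        # so _name and _module never reach this method.
        if self._module is None:
            self._module = importlib.import_module(self._name)
        return getattr(self._module, attr)


def lazy_import_sketch(name: str) -> LazyModule:
    return LazyModule(name)


json = lazy_import_sketch("json")  # nothing is imported yet
missing = lazy_import_sketch("no_such_pkg_xyz")  # no error at this point either
encoded = json.dumps({"ok": True})  # first attribute access triggers the import
```

With a proxy like this, a missing package only fails when one of its attributes is first used, not when the family module is imported for registration.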
Class attributes:
- kind: TaskKind: the member from step 1.
- fixed_output_space: ExplanationOutputSpace | None: None means "infer the output space dynamically" (classification: CAM vs input-features). A fixed value pins it (detection: DETECTION_BOXES).
Methods (signatures take ctx / payload / tensor / cfg as the Protocol defines):
- validate_payload(payload): raise ValueError if the forward payload is the wrong type for this family.
- adapt_loaded_inputs(tensor): shape the freshly-loaded dense tensor. Classification returns it as-is; detection unbinds it into a ragged list[(C, H, W)].
- validate_inputs(tensor): raise if the post-adapt inputs break this family's contract.
- load_labels(cfg, *, tensor, sample_ids): load labels in this family's on-disk shape, or return None when no label source is set.
- validate_labels(labels): raise when loaded labels are the wrong shape (e.g. a tensor where this family wants list[dict]); a mismatch means model and data declare different families.
- extract_forward(ctx, *, batch_size): run the backend forward in chunks and return this family's payload.
- explain(ctx): return list[ExplanationResult]. Classification runs the explainer once over the batch; detection runs a per-box K-loop.
- metrics_inputs(config, forward_output, labels): adapt payload + labels into (preds, targets) for the metric adapters, or None to skip metrics.
- supports_robustness(): return whether the robustness phase runs for this family.
- prediction_summaries(payload, *, sample_ids, targets): per-sample summary rows, or None when the concept does not apply (detection returns None).
- allows_preprocessing (property): whether data/model preprocessing transforms are allowed.
- matches_model(model): optional. Return True if this family recognises model by architecture, for backend auto-inference (step 4). Implement it only for auto-detectable families. Classification is the fallback and returns False.
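While a family is being written, a completeness check helps track which members are still missing. The helper below is hypothetical and not part of raitap: REQUIRED_MEMBERS is transcribed from the list above rather than read from base.py, and HalfWrittenFamily is a placeholder with stub bodies.

```python
# Member names transcribed from the list above; matches_model is optional
# and therefore left out of the required set.
REQUIRED_MEMBERS = (
    "kind",
    "fixed_output_space",
    "validate_payload",
    "adapt_loaded_inputs",
    "validate_inputs",
    "load_labels",
    "validate_labels",
    "extract_forward",
    "explain",
    "metrics_inputs",
    "supports_robustness",
    "prediction_summaries",
    "allows_preprocessing",
)


def missing_members(family: object) -> list[str]:
    """List the required member names that the object does not define yet."""
    return [name for name in REQUIRED_MEMBERS if not hasattr(family, name)]


class HalfWrittenFamily:
    kind = "segmentation"  # the real class would use TaskKind.segmentation
    fixed_output_space = None

    def validate_payload(self, payload: object) -> None:
        pass

    def explain(self, ctx: object) -> list:
        return []

    @property
    def allows_preprocessing(self) -> bool:
        return True


todo = missing_members(HalfWrittenFamily())
```

Here todo names the eight members that HalfWrittenFamily still lacks. The check only confirms that a name exists; it says nothing about signatures or behaviour.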
ClassificationFamily and DetectionFamily are the two worked examples; read both before writing yours.
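The registration behaviour described at the start of this step (instantiate once, store the singleton under its kind) follows a common decorator pattern. The sketch below shows that pattern with hypothetical names (_REGISTRY, task_family_sketch, get_family). It is not the code in raitap.task_families.registry, and rejecting duplicate kinds is an assumption.

```python
_REGISTRY: dict[str, object] = {}


def task_family_sketch(cls):
    """Instantiate cls once and store the instance under its `kind`."""
    instance = cls()
    if instance.kind in _REGISTRY:
        # Assumption: duplicate kinds are rejected; the real registry may differ.
        raise ValueError(f"task family {instance.kind!r} is already registered")
    _REGISTRY[instance.kind] = instance
    return cls  # the class itself is returned unchanged


def get_family(kind: str) -> object:
    return _REGISTRY[kind]


@task_family_sketch
class SegmentationFamily:
    kind = "segmentation"  # plain string standing in for TaskKind.segmentation


first = get_family("segmentation")
second = get_family("segmentation")
```

Both lookups return the same object, which is what lets the pipeline phases share one strategy instance per family.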
3. Register the metric adapter and visualiser¶
Same decorators as any adapter (Adding an adapter). The metric computer ties to your family by registry_name (selected via metrics=<name>); the visualiser ties to it by supported_tasks, which makes the transparency pre-flight reject an incompatible config up front instead of failing deep in the explain loop.
from raitap.metrics.registration import metrics_adapter
from raitap.transparency.contracts import ExplanationOutputSpace
from raitap.transparency.visualisers.registration import transparency_visualiser
from raitap.types import TaskKind

# BaseMetricComputer, BaseVisualiser, ExplanationPayloadKind and
# SegmentationMetricsConfig are assumed to be imported or defined elsewhere.

@metrics_adapter(registry_name="segmentation", extra="metrics", schema=SegmentationMetricsConfig)
class SegmentationMetrics(BaseMetricComputer): ...

@transparency_visualiser(
    registry_name="segmentation_mask",
    supported_tasks={TaskKind.segmentation},  # the pre-flight gate
    supported_output_spaces={ExplanationOutputSpace.IMAGE_SPATIAL_MAP},
    supported_payload_kinds={ExplanationPayloadKind.ATTRIBUTIONS},
)
class SegmentationMaskVisualiser(BaseVisualiser): ...
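The pre-flight gate that supported_tasks enables can be pictured as a membership test that runs before any explanation is computed. This is a sketch of the idea only: VisualiserSpec and preflight are hypothetical names, and plain strings stand in for the TaskKind and ExplanationOutputSpace members.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class VisualiserSpec:
    # Hypothetical record of what a visualiser declared at registration.
    registry_name: str
    supported_tasks: frozenset[str]
    supported_output_spaces: frozenset[str]


def preflight(spec: VisualiserSpec, task: str, output_space: str) -> None:
    """Raise before the explain loop if the visualiser cannot serve this run."""
    if task not in spec.supported_tasks:
        raise ValueError(f"{spec.registry_name!r} does not support task {task!r}")
    if output_space not in spec.supported_output_spaces:
        raise ValueError(
            f"{spec.registry_name!r} does not support output space {output_space!r}"
        )


mask_vis = VisualiserSpec(
    registry_name="segmentation_mask",
    supported_tasks=frozenset({"segmentation"}),
    supported_output_spaces=frozenset({"IMAGE_SPATIAL_MAP"}),
)

preflight(mask_vis, "segmentation", "IMAGE_SPATIAL_MAP")  # compatible: no error

try:
    preflight(mask_vis, "detection", "IMAGE_SPATIAL_MAP")
    rejected = False
except ValueError:
    rejected = True
```

Failing at this point costs nothing, whereas the same mismatch discovered inside the explain loop would surface only after the model forward pass has run.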
4. Declare the model's task¶
The pipeline reads backend.task_kind. Pass it explicitly, or rely on auto-inference: TorchBackend._infer_task_kind calls each registered family's matches_model and raises if more than one matches.
backend = TorchBackend(model, task_kind=TaskKind.segmentation) # explicit
backend = TorchBackend(model) # auto: scans matches_model
Implement matches_model only for an architecture-detectable family (torchvision detectors are recognised this way); otherwise pass task_kind= explicitly. Backends have their own page: Adding a backend.
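The auto-inference rule (ask every family, fail on ambiguity) can be sketched as below. This is not TorchBackend._infer_task_kind: the class names, the is_detector marker attribute, and the choice to return classification when nothing matches are assumptions, and plain strings stand in for TaskKind members.

```python
class DetectionLike:
    kind = "detection"

    def matches_model(self, model: object) -> bool:
        # Hypothetical marker attribute; a real family inspects the architecture.
        return getattr(model, "is_detector", False)


class ClassificationLike:
    kind = "classification"

    def matches_model(self, model: object) -> bool:
        return False  # the fallback family never claims a model


def infer_task_kind(model: object, families: list) -> str:
    matches = [f.kind for f in families if f.matches_model(model)]
    if len(matches) > 1:
        raise ValueError(f"ambiguous task kind, candidates: {matches}")
    # Assumption: with no match, the classification fallback is used.
    return matches[0] if matches else "classification"


class FakeDetector:
    is_detector = True


class FakeClassifier:
    pass


families = [ClassificationLike(), DetectionLike()]
detector_kind = infer_task_kind(FakeDetector(), families)
classifier_kind = infer_task_kind(FakeClassifier(), families)
```

A matches_model that is too permissive produces the ambiguity error as soon as a second family also claims the model, which is one more reason to pass task_kind= explicitly when in doubt.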
5. Data-layer boundary¶
The Data layer is still image/file-centric: it loads images into dense tensors and routes by a task_kind file-loading check. Non-image families (text, seq2seq) also need the Data layer generalised before they work end to end. That generalisation is tracked in GH issue #253. Image families (classification, detection, segmentation) work within the current loader.
Worked example: detection¶
DetectionFamily (src/raitap/task_families/detection.py) is the complete reference for a non-classification family: fixed_output_space = DETECTION_BOXES, adapt_loaded_inputs unbinds to a ragged list, load_labels parses a JSON records file aligned by sample_id, explain drives a per-box K-loop via explain_detection, supports_robustness returns False (Phase 4 follow-up), and matches_model recognises torchvision detectors. Its runnable end-to-end config lives at contributor-configs/fasterrcnn-udacity/.
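Aligning label records by sample_id, as load_labels is described as doing, might look like the sketch below. The record schema (sample_id, boxes and labels keys) and the decision to raise on a missing record are assumptions for illustration, not the format that detection.py actually reads.

```python
import json

# Hypothetical records file content; the real on-disk schema may differ.
RECORDS_JSON = """
[
  {"sample_id": "b.jpg", "boxes": [[0, 0, 10, 10]], "labels": [2]},
  {"sample_id": "a.jpg", "boxes": [[5, 5, 20, 20], [1, 1, 4, 4]], "labels": [1, 3]}
]
"""


def align_labels(records_json: str, sample_ids: list[str]) -> list[dict]:
    """Return one label dict per sample, in the order of sample_ids."""
    by_id = {rec["sample_id"]: rec for rec in json.loads(records_json)}
    missing = [sid for sid in sample_ids if sid not in by_id]
    if missing:
        # Assumption: a sample without a record is an error rather than skipped.
        raise ValueError(f"no label record for: {missing}")
    return [
        {"boxes": by_id[sid]["boxes"], "labels": by_id[sid]["labels"]}
        for sid in sample_ids
    ]


labels = align_labels(RECORDS_JSON, ["a.jpg", "b.jpg"])
```

The result follows the order of the loaded samples rather than the order of the file, so each entry lines up with the corresponding image in the ragged input list.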