Python API¶
This page explains how to use RAITAP from Python directly, without YAML files or the CLI.
When to use which¶
Use YAML + CLI when the job should be reproducible from files on disk, when you sweep parameters with Hydra multirun, or when you launch on Slurm. The CLI gives you --multirun, --help, dotted overrides, and persistent output directories, and the configs checked into version control stay the source of truth.
Use Python when you work in a notebook, embed RAITAP into another tool, want type-checking against the dataclass schema, or build configs dynamically (e.g. a sweep generated in a for loop where each iteration depends on the previous result). The Python path skips Hydra's chdir and its logging hijack, so it composes cleanly with your own logging and working-directory conventions.
API surface¶
The general-purpose objects are exported by the raitap package:
from raitap import AppConfig, Hardware, run
Module-specific objects are exported by their respective submodules:
from raitap.models import ModelConfig
from raitap.data import DataConfig, LabelsConfig
from raitap.metrics import multiclass_classification
from raitap.robustness import image_pair, torchattacks
from raitap.transparency import captum, captum_image
Install + quickstart¶
The Python equivalent of raitap --demo is roughly twenty lines. Build an AppConfig, pass it to run, and read the structured RunOutputs (raitap.pipeline.outputs.RunOutputs) it returns:
from raitap import AppConfig, Hardware, run
from raitap.data import DataConfig, LabelsConfig
from raitap.metrics import multiclass_classification
from raitap.models import ModelConfig
from raitap.robustness import image_pair, torchattacks
from raitap.transparency import captum, captum_image
cfg = ...  # build an AppConfig here (config code omitted)
outputs = run(cfg, verbose=False)
verbose=False suppresses the rich console summary panel but does not silence Python logging; configure the root logger yourself if you want quiet output.
Conversely, with verbose=True (the default), only the summary panel renders: per-step progress messages flow through Python logging and stay silent until you attach a handler. The CLI attaches one via Hydra; the Python path leaves logging to you. A single basicConfig call near the top of the script is enough:
import logging
logging.basicConfig(level=logging.INFO, format="%(message)s")
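Setting the root logger to INFO also lets every third-party library's INFO messages through. If you only want RAITAP's progress messages, you can scope the level to a logger subtree instead. The sketch below assumes RAITAP names its loggers under the "raitap" prefix (the usual logging.getLogger(__name__) convention); check the logger names in your installed version before relying on it.

```python
import logging
import sys

# Keep the root logger at WARNING so third-party INFO chatter stays hidden.
# force=True replaces any handlers a notebook or earlier call may have installed.
logging.basicConfig(level=logging.WARNING, format="%(name)s: %(message)s", stream=sys.stderr, force=True)

# Assumption: RAITAP logs through loggers named "raitap" or "raitap.<module>".
# Lowering the level on the "raitap" parent lets its INFO records through; they
# propagate up to the root handler that basicConfig installed.
logging.getLogger("raitap").setLevel(logging.INFO)

# Stand-in records that show the effect without running a pipeline.
logging.getLogger("raitap.pipeline").info("visible: raitap child logger at INFO")
logging.getLogger("some_other_lib").info("hidden: root level is WARNING")
```

To silence RAITAP's progress messages entirely under the same assumption, set the "raitap" logger to logging.WARNING or higher instead.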
Auto-installing extras from Python¶
Pass auto_install_deps=True to run:
# imports omitted
cfg = ...  # build an AppConfig here (config code omitted)
run(cfg, auto_install_deps=True)
auto_install_deps is opt-in. Without it, run(cfg) assumes the extras the config references are already installed, which is the typical case after a CLI bootstrap or a manual uv sync. A missing adapter library then surfaces as the usual ModuleNotFoundError from the adapter import chain.
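If you would rather fail fast than hit a ModuleNotFoundError partway through a run, a pre-flight check with importlib.util.find_spec can report which adapter libraries are missing before you call run. The list of import names below is an assumption for illustration (captum and torchattacks are the import names of those libraries); adjust it to the adapters your config actually references.

```python
import importlib.util


def missing_modules(module_names: list[str]) -> list[str]:
    """Return the top-level module names that cannot be found on the current path.

    find_spec only locates a module without importing it, so the check is cheap.
    """
    return [name for name in module_names if importlib.util.find_spec(name) is None]


# Hypothetical list: import names of the adapter libraries a config might reference.
required = ["captum", "torchattacks"]

missing = missing_modules(required)
if missing:
    print(f"missing extras: {', '.join(missing)}; install them or pass auto_install_deps=True")
```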
Translation rules¶
| YAML pattern | Python builder |
|---|---|
| | |
| | Not needed: |
| Group/name selection ( | Use the dict key on the Python side too: |
| List of visualisers | One builder per visualiser ( |
| | Builder kwargs are required-or-optional based on the wrapped constructor signature; your editor surfaces the missing ones. |
| CLI overrides ( | Mutate the dataclass: |
See Examples for worked YAML-to-Python translations.
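Translating a CLI override into a dataclass mutation can also be done mechanically when overrides arrive as strings, for example from argparse in a wrapper tool. The helper below is a hypothetical sketch, not part of RAITAP: it walks a dotted path, using item access for dict keys (such as the robustness["pgd"] entry) and attribute access for dataclass fields. Values are assigned as given, so type conversion and validation are left to you. The stand-in dataclasses only mimic the nesting of AppConfig.

```python
from dataclasses import dataclass, field
from typing import Any


def set_dotted(obj: Any, path: str, value: Any) -> None:
    """Assign value at a dotted path such as "robustness.pgd.constructor.eps".

    Each segment is resolved as a dict key when the current object is a dict,
    otherwise as an attribute (dataclass field).
    """
    *parents, last = path.split(".")
    for part in parents:
        obj = obj[part] if isinstance(obj, dict) else getattr(obj, part)
    if isinstance(obj, dict):
        obj[last] = value
    else:
        setattr(obj, last, value)


# Stand-ins with the same nesting as the real configs (illustrative only).
@dataclass
class FakeRobustnessConfig:
    constructor: dict[str, Any] = field(default_factory=dict)


@dataclass
class FakeAppConfig:
    experiment_name: str = "demo"
    robustness: dict[str, FakeRobustnessConfig] = field(default_factory=dict)


cfg = FakeAppConfig(robustness={"pgd": FakeRobustnessConfig(constructor={"eps": 0.03})})
set_dotted(cfg, "robustness.pgd.constructor.eps", 0.1)
set_dotted(cfg, "experiment_name", "pgd-eps=0.1")
print(cfg.robustness["pgd"].constructor["eps"], cfg.experiment_name)
```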
Type safety¶
Due to how Hydra works, only some fields are typed.
Fully typed¶
- hardware: Hardware, data.labels.encoding: LabelEncoding, data.labels.id_strategy: IdStrategy, and model.task_kind: TaskKind. All four are enum.StrEnum subclasses.
- The nested dataclass dicts on AppConfig.transparency and AppConfig.robustness. Keys are arbitrary user-chosen strings; values must be TransparencyConfig / RobustnessConfig instances (or dicts with the right keys).
- All scalar fields on ModelConfig, DataConfig, LabelsConfig, MetricsConfig, TrackingConfig, and ReportingConfig, which OmegaConf's structured-config validation checks when the orchestrator boots.
Library-forwarded kwargs (unchecked at schema time):
- TransparencyConfig.constructor / .call / .raitap: dicts forwarded verbatim to the underlying explainer library (Captum's IntegratedGradients(...), SHAP's GradientExplainer(...), and the corresponding .attribute() / .shap_values() call).
- RobustnessConfig.constructor / .call / .raitap: the same, for torchattacks, Foolbox, and Marabou.
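Forwarding verbatim means the dict is unpacked into the library call, so a misspelled key is only rejected when the explainer or attack is constructed, not when the config is built. The sketch below uses a hypothetical stand-in class in place of a real explainer. A Python callable without **kwargs rejects an unknown keyword with TypeError; one that accepts **kwargs may silently ignore it, so what you see depends on the library's signature.

```python
from typing import Any


class FakeExplainer:
    """Hypothetical stand-in for a library explainer such as Captum's IntegratedGradients."""

    def __init__(self, model: Any, multiply_by_inputs: bool = True) -> None:
        self.model = model
        self.multiply_by_inputs = multiply_by_inputs


def build_explainer(model: Any, constructor: dict[str, Any]) -> FakeExplainer:
    # The dict is unpacked verbatim, mirroring how constructor kwargs reach the library.
    return FakeExplainer(model, **constructor)


# Both dicts are valid Python values: nothing validates their keys at this point.
good = {"multiply_by_inputs": False}
typo = {"multiply_by_input": False}  # missing trailing "s"

explainer = build_explainer(model=None, constructor=good)

late_error = None
try:
    build_explainer(model=None, constructor=typo)
except TypeError as exc:
    # The error surfaces only at construction time, from the callee's signature.
    late_error = str(exc)
    print(f"rejected late: {late_error}")
```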
RunOutputs shape¶
raitap.run returns a frozen RunOutputs dataclass:
| Field | Type | Meaning |
|---|---|---|
| | | Typed model forward output (predictions tensor or detection predictions) and batch size. |
| | | Each configured assessment phase's result, keyed by phase name (see below). Only configured phases appear. |
| | | Stable ids aligned with |
| | | Ground-truth labels when configured: a tensor (classification) or |
| | | Per-sample |
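Because RunOutputs is frozen, you cannot attach your own annotations to it by assignment. The sketch below shows standard frozen-dataclass behaviour on a hypothetical stand-in whose fields are made up (they are not RunOutputs' real fields): assignment raises dataclasses.FrozenInstanceError, and dataclasses.replace builds a modified copy. For bookkeeping such as sweep parameters, pairing the outputs with your own data, as the (eps, outputs) tuples in the multirun example do, is usually simpler.

```python
import dataclasses
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeOutputs:
    """Hypothetical stand-in for a frozen result object; fields are made up."""

    note: str
    count: int


outputs = FakeOutputs(note="demo", count=8)

try:
    outputs.count = 16  # frozen dataclasses reject attribute assignment
except dataclasses.FrozenInstanceError:
    print("read-only: assignment raised FrozenInstanceError")

# replace() returns a new instance; the original is left untouched.
relabelled = dataclasses.replace(outputs, note="demo-copy")
print(outputs.note, relabelled.note)
```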
Typed accessors (the common path) return each phase's result data:
| Accessor | Type | Notes |
|---|---|---|
| | | |
| | | one per explainer; |
| | | one per assessor; |
Every per-adapter result (ExplanationResult, RobustnessResult) shares one envelope, the AdapterResult contract: .name (the config key, e.g. "ig"), .adapter_target (the _target_ class), .algorithm, .semantics, .run_dir, and .visualisations (the figures that result owns). Each result also carries its own domain payload (.attributions / .verdicts / …).
Mapping-style access reaches the underlying PhaseResult wrapper for any phase, including future ones: outputs.get(name), outputs[name], and name in outputs.
from raitap import run
result = run(cfg)  # cfg is the AppConfig built in the quickstart
if result.metrics is not None:
print(result.metrics.scalars) # dict[str, float]
for explanation in result.transparency: # list[ExplanationResult]
print(explanation.name, explanation.algorithm, len(explanation.visualisations))
for assessment in result.robustness: # list[RobustnessResult]
print(assessment.name, assessment.metrics.attack_success_rate)
fairness = result.get("fairness") # any future phase → PhaseResult | None
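Because every per-adapter result shares the AdapterResult envelope, a helper written against the envelope fields alone works for transparency and robustness results alike. The sketch below types that contract with a typing.Protocol that lists only .name and .visualisations; the Protocol is a hypothetical annotation, not a RAITAP export, and it treats .visualisations as any sized collection. SimpleNamespace objects stand in for real results so the snippet runs on its own.

```python
from collections.abc import Iterable, Sized
from types import SimpleNamespace
from typing import Protocol


class HasEnvelope(Protocol):
    """Hypothetical structural type for the two envelope fields this helper reads."""

    name: str
    visualisations: Sized


def figure_counts(results: Iterable[HasEnvelope]) -> dict[str, int]:
    """Map each result's config key (.name) to how many figures it owns."""
    return {r.name: len(r.visualisations) for r in results}


# Stand-ins; with a real run, pass the list from the transparency or robustness accessor.
fake_transparency = [SimpleNamespace(name="ig", visualisations=["heatmap.png", "overlay.png"])]
fake_robustness = [SimpleNamespace(name="pgd", visualisations=[])]

print(figure_counts(fake_transparency))
print(figure_counts([*fake_transparency, *fake_robustness]))
```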
Multiruns in Python¶
Hydra's multirun feature (see Batch runs) is only available from the CLI, but you can write the equivalent loop in Python directly:
from copy import deepcopy
from raitap import run
cfg = ...  # build the AppConfig as in the quickstart (omitted)
results = []
for eps in (0.01, 0.03, 0.06, 0.1):
copied_cfg = deepcopy(cfg)
copied_cfg.robustness["pgd"].constructor["eps"] = eps
copied_cfg.experiment_name = f"pgd-eps={eps}"
results.append((eps, run(copied_cfg, verbose=False)))
for eps, outputs in results:
metrics = outputs.get("metrics")
print(eps, metrics.result.scalars if metrics else None)
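The loop above evaluates a fixed grid. When each iteration depends on the previous result, which Hydra's multirun cannot express, plain Python control flow takes over. The sketch below bisects for the smallest eps at which an attack's success rate reaches a target, assuming the rate rises monotonically with eps (an assumption worth checking for your model). The measure argument is a hypothetical callable you supply; with RAITAP it would deepcopy cfg, set robustness["pgd"].constructor["eps"], call run, and read attack_success_rate from the matching robustness result. A toy function stands in here so the snippet runs without a model.

```python
from collections.abc import Callable


def smallest_eps(
    measure: Callable[[float], float],
    target: float,
    lo: float = 0.0,
    hi: float = 0.1,
    steps: int = 8,
) -> float:
    """Bisect [lo, hi] for the smallest eps where measure(eps) >= target.

    Assumes measure is non-decreasing in eps and that measure(hi) >= target.
    Each probe costs one full pipeline run, so keep steps small.
    """
    if measure(hi) < target:
        raise ValueError("target not reached even at the upper bound")
    for _ in range(steps):
        mid = (lo + hi) / 2
        if measure(mid) >= target:
            hi = mid  # target reached: search the lower half
        else:
            lo = mid  # not reached yet: search the upper half
    return hi


# Toy stand-in for "run the pipeline and read attack_success_rate".
def toy_success_rate(eps: float) -> float:
    return min(1.0, eps * 10)


print(round(smallest_eps(toy_success_rate, target=0.5), 4))
```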