chuchichaestli.metrics.fid

FID evaluation metric, including InceptionV3 for feature mapping.

Classes:

FID: Frechet inception distance.
FIDInceptionV3: InceptionV3 model for calculating FIDs.

FID

FID(
    model: Module | None = None,
    feature_dim: int | None = None,
    device: torch.device | None = None,
    n_images: int = 0,
    **kwargs,
)

Bases: EvalMetric

Frechet inception distance.

Constructor.

Parameters:

model : Module | None, default None
    Model from which to extract features.
feature_dim : int | None, default None
    Feature dimension of the model output; if None, the feature dimension
    is determined automatically when possible.
device : torch.device | None, default None
    Tensor allocation/computation device.
n_images : int, default 0
    Number of images seen by the internal state.
**kwargs
    Additional keyword arguments (passed to the parent class).

Methods:

compute: Return current metric state total.
reset: Reset the current metrics state.
to: Perform tensor device conversion for all internal tensors.
update: Compute metric on new input and update current state.

Attributes:

feature_dim: Feature dimension of the output.

Source code in src/chuchichaestli/metrics/fid.py
def __init__(
    self,
    model: Module | None = None,
    feature_dim: int | None = None,
    device: torch.device | None = None,
    n_images: int = 0,
    **kwargs,
):
    """Constructor.

    Args:
        model: Model from which to extract features.
        feature_dim: Feature dimension of the model output; if `None`, the
          feature dimension is determined automatically when possible.
        device: Tensor allocation/computation device.
        n_images: Number of images seen by the internal state.
        kwargs: Additional keyword arguments (passed to parent class).
    """
    super().__init__(device=device, n_images=n_images, **kwargs)
    if model is None:
        model = FIDInceptionV3()
    self.model = model.to(self.device)
    self.model.eval()
    self.model.requires_grad_(False)
    self._feature_dim = feature_dim
    self.n_images_fake = torch.tensor(n_images // 2, device=self.device)
    self.n_images_real = torch.tensor(n_images - n_images // 2, device=self.device)
    self.aggregate_fake = torch.zeros(self.feature_dim, device=self.device)
    self.aggregate_real = torch.zeros(self.feature_dim, device=self.device)
    self.aggregate_cov_fake = torch.zeros(
        (self.feature_dim, self.feature_dim), device=self.device
    )
    self.aggregate_cov_real = torch.zeros(
        (self.feature_dim, self.feature_dim), device=self.device
    )
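
Example: a minimal construction sketch. This is illustrative only; it assumes
the import path shown in the module title above, and the value 2048 for
feature_dim is an assumption about the extractor's output width.

import torch
from chuchichaestli.metrics.fid import FID, FIDInceptionV3

# Default: features come from a pretrained FIDInceptionV3.
fid = FID(device=torch.device("cpu"))

# Explicit extractor; feature_dim is only needed when it cannot be
# inferred from the model automatically (2048 is assumed here).
fid_custom = FID(model=FIDInceptionV3(), feature_dim=2048)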

feature_dim property

feature_dim

Feature dimension of the output.

compute

compute() -> float

Return current metric state total.

Source code in src/chuchichaestli/metrics/fid.py
@torch.inference_mode()
def compute(self) -> float:
    """Return current metric state total."""
    if (self.n_images_fake < 1) or (self.n_images_real < 1):
        warnings.warn(
            "Computing FID requires at least 1 real images and 1 fake images."
        )
        return torch.tensor(0.0)
    mean_fake = (self.aggregate_fake / self.n_images_fake).unsqueeze(0)
    mean_real = (self.aggregate_real / self.n_images_real).unsqueeze(0)
    n_cov_fake = self.aggregate_cov_fake - self.n_images_fake * torch.matmul(
        mean_fake.T, mean_fake
    )
    cov_fake = n_cov_fake / max(self.n_images_fake - 1, 1)
    n_cov_real = self.aggregate_cov_real - self.n_images_real * torch.matmul(
        mean_real.T, mean_real
    )
    cov_real = n_cov_real / max(self.n_images_real - 1, 1)
    return self._calculate_frechet_distance(
        mean_real.squeeze(), cov_real, mean_fake.squeeze(), cov_fake
    )
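
For reference: compute() reconstructs unbiased covariance estimates from the
running aggregates via the identity below (this is exactly the
aggregate_cov_* - n * mu mu^T step above); _calculate_frechet_distance is not
shown on this page, but the quantity it presumably evaluates is the standard
Frechet distance between the two resulting Gaussians:

\Sigma = \frac{1}{n - 1}\left(\sum_{i=1}^{n} x_i x_i^\top - n\,\mu\mu^\top\right),
\qquad \mu = \frac{1}{n}\sum_{i=1}^{n} x_i

\mathrm{FID} = \lVert \mu_r - \mu_f \rVert_2^2
    + \operatorname{Tr}\left(\Sigma_r + \Sigma_f - 2\,(\Sigma_r \Sigma_f)^{1/2}\right)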

reset

reset(**kwargs)

Reset the current metrics state.

Source code in src/chuchichaestli/metrics/fid.py
def reset(self, **kwargs):
    """Reset the current metrics state."""
    self.__init__(
        min_value=self.min_value.item(),
        max_value=self.max_value.item(),
        device=self.device,
        model=self.model,
        feature_dim=self._feature_dim,
    )
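
Usage note (a small sketch): reset() re-initializes counters and aggregates
while keeping the model, feature_dim, and device, so it is cheap to call
between evaluation rounds.

fid.reset()  # counters and aggregates are zeroed; model and device are kept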

to

to(device: torch.device = None)

Perform tensor device conversion for all internal tensors.

Parameters:

device : torch.device, default None
    Tensor allocation/computation device.

Source code in src/chuchichaestli/metrics/fid.py
def to(self, device: torch.device = None):
    """Perform tensor device conversion for all internal tensors.

    Args:
        device: Tensor allocation/computation device.
    """
    super().to(device)
    self.model = self.model.to(device)
    self.n_images_fake = self.n_images_fake.to(device=device)
    self.n_images_real = self.n_images_real.to(device=device)
    self.aggregate_fake = self.aggregate_fake.to(device=device)
    self.aggregate_real = self.aggregate_real.to(device=device)
    # also move the covariance aggregates so compute() stays on one device
    self.aggregate_cov_fake = self.aggregate_cov_fake.to(device=device)
    self.aggregate_cov_real = self.aggregate_cov_real.to(device=device)
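
For illustration, a short sketch (assumes a CUDA device is available):

import torch
from chuchichaestli.metrics.fid import FID

fid = FID()
if torch.cuda.is_available():
    fid.to(torch.device("cuda"))  # model, counters, and aggregates move together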

update

update(
    data: torch.Tensor | None = None,
    prediction: torch.Tensor | None = None,
    **kwargs,
)

Compute metric on new input and update current state.

Parameters:

data : torch.Tensor | None, default None
    Observed (real) data.
prediction : torch.Tensor | None, default None
    Predicted (fake) data.
**kwargs
    Additional keyword arguments for parent class.

Source code in src/chuchichaestli/metrics/fid.py
@torch.inference_mode()
def update(
    self,
    data: torch.Tensor | None = None,
    prediction: torch.Tensor | None = None,
    **kwargs,
):
    """Compute metric on new input and update current state.

    Args:
        data: Observed (real) data.
        prediction: Predicted (fake) data.
        kwargs: Additional keyword arguments for parent class.
    """
    sample = kwargs.pop("sample", 0)
    # fake images
    if prediction is not None:
        if prediction.ndim == 5:
            prediction = as_batched_slices(prediction, sample=sample)
        super().update(prediction, prediction, **kwargs)
        features_fake = self.model(prediction)
        self.n_images_fake += prediction.shape[0]
        self.aggregate_fake += torch.sum(features_fake, dim=0)
        self.aggregate_cov_fake += torch.matmul(features_fake.T, features_fake)
    # real images
    if data is not None:
        if data.ndim == 5:
            data = as_batched_slices(data, sample=sample)
        super().update(data, data, **kwargs)
        features_real = self.model(data)
        self.n_images_real += data.shape[0]
        self.aggregate_real += torch.sum(features_real, dim=0)
        self.aggregate_cov_real += torch.matmul(features_real.T, features_real)
    return self
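
A usage sketch: real_batches, fake_batches, extra_real, and extra_fake are
hypothetical placeholders for image tensors in NCHW layout (5D NCDHW volumes
are sliced into batched 2D images via as_batched_slices).

fid = FID()
for real, fake in zip(real_batches, fake_batches):
    fid.update(data=real, prediction=fake)

# Real and fake images may also arrive in separate calls:
fid.update(data=extra_real)
fid.update(prediction=extra_fake)

score = fid.compute()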

FIDInceptionV3

FIDInceptionV3(
    weights: tv.Inception_V3_Weights | None = None,
    use_default_transforms: bool = True,
    mode: str = "bilinear",
    antialias: bool = False,
)

Bases: Module

InceptionV3 model for calculating FIDs.

Constructor.

Parameters:

weights : tv.Inception_V3_Weights | None, default None
    The pretrained weights for the model; for details and possible values see
    https://docs.pytorch.org/vision/stable/models/generated/torchvision.models.inception_v3.html#torchvision.models.Inception_V3_Weights.
use_default_transforms : bool, default True
    If True, uses the preprocessing transforms bundled with the chosen
    weights (e.g. Inception_V3_Weights.IMAGENET1K_V1.transforms).
mode : str, default 'bilinear'
    If use_default_transforms=False, the input is resized with a simple
    interpolation in the given mode.
antialias : bool, default False
    If True and use_default_transforms=False, antialiasing is applied
    during interpolation.

Methods:

forward: Forward method for the FIDInceptionV3 model.

Source code in src/chuchichaestli/metrics/fid.py
def __init__(
    self,
    weights: tv.Inception_V3_Weights | None = None,
    use_default_transforms: bool = True,
    mode: str = "bilinear",
    antialias: bool = False,
):
    """Constructor.

    Args:
        weights: The pretrained weights for the model; for details and possible values
          see https://docs.pytorch.org/vision/stable/models/generated/torchvision.models.inception_v3.html#torchvision.models.Inception_V3_Weights.
        use_default_transforms: If `True`, uses the preprocessing transforms
          bundled with the chosen weights (e.g.
          `Inception_V3_Weights.IMAGENET1K_V1.transforms`).
        mode: If `use_default_transforms=False`, the input is resized with a
          simple interpolation in the given mode.
        antialias: If `True` and `use_default_transforms=False`, antialiasing
          is applied during interpolation.
    """
    super().__init__()
    self.weights = (
        tv.Inception_V3_Weights.IMAGENET1K_V1 if weights is None else weights
    )
    self.model = tv.inception_v3(weights=self.weights)
    self.model.fc = torch.nn.Identity()
    if use_default_transforms:
        if isinstance(self.weights, str):
            weights = tv.get_model_weights("inception_v3")[self.weights]
            self.transforms = weights.transforms()
        else:
            self.transforms = self.weights.transforms()
    else:
        self.transforms = partial(
            interpolate,
            size=(299, 299),
            mode=mode,
            align_corners=False,
            antialias=antialias,
        )
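
Two construction variants, sketched directly from the signature above:

from chuchichaestli.metrics.fid import FIDInceptionV3

# Default: the weights' bundled preprocessing (resize, crop, normalization).
extractor = FIDInceptionV3()

# Plain antialiased bilinear resize to 299x299 instead of the bundled transforms.
extractor = FIDInceptionV3(use_default_transforms=False, mode="bilinear", antialias=True)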

forward

forward(x: torch.Tensor) -> torch.Tensor

Forward method for the FIDInceptionV3 model.

Source code in src/chuchichaestli/metrics/fid.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward method for the FIDInceptionV3 model."""
    x = sanitize_ndim(x)
    x = as_tri_channel(x)
    x = self.transforms(x)
    x = self.model(x)
    return x
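
An end-to-end sketch of feature extraction; the (4, 2048) output shape assumes
the fc head replaced by an identity, as done in the constructor above.

import torch
from chuchichaestli.metrics.fid import FIDInceptionV3

extractor = FIDInceptionV3().eval()
images = torch.rand(4, 1, 128, 128)  # grayscale; as_tri_channel tiles to 3 channels
with torch.inference_mode():
    features = extractor(images)
print(features.shape)  # expected: torch.Size([4, 2048])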