Jatin-tec committed on
Commit
65d7391
·
1 Parent(s): 140553b

Add application file

.gitignore ADDED
@@ -0,0 +1,42 @@
+ # Python artifacts
+ __pycache__/
+ *.py[cod]
+ *.so
+
+ # Virtual environments
+ .venv/
+ venv/
+ ENV/
+ env/
+ .env
+ .env.*
+
+ # Packaging / build outputs
+ build/
+ dist/
+ *.egg-info/
+
+ # Testing and type checking caches
+ .pytest_cache/
+ .mypy_cache/
+ .pytype/
+ .ruff_cache/
+ .coverage
+ coverage.xml
+
+ # Jupyter
+ .ipynb_checkpoints/
+
+ # Editor and OS cruft
+ .vscode/
+ .idea/
+ *.swp
+ .DS_Store
+
+ # Logs
+ *.log
+
+ # Docker scratch space
+ test_docker/data/
+ test_docker/data_out/
+
README.md ADDED
@@ -0,0 +1,32 @@
+ # Hugging Face Interface Demo
+
+ This Gradio app compares two detectors for image provenance:
+ - Hugging Face `Ateeqq/ai-vs-human-image-detector` estimates whether an image is AI-generated or human-made.
+ - A bundled TruFor backend estimates tampering and renders heatmaps when the required weights are present.
+
+ ## Requirements
+ - Python 3.9 or newer
+ - `pip install -r requirements.txt`
+
+ ## Getting Started
+ 1. Create or activate a virtual environment that uses Python 3.9+.
+ 2. Install dependencies:
+ ```bash
+ pip install -r requirements.txt
+ ```
+ 3. Launch the interface:
+ ```bash
+ python app.py
+ ```
+ Gradio prints a local URL in the terminal; open it in a browser and upload an image to view the AI/Human probabilities alongside TruFor diagnostics.
+
+ ## TruFor Weights
+ TruFor is released for non-commercial research use. Obtain the official `trufor.pth.tar` weight file from the upstream project and place it at `weights/trufor.pth.tar` (or set the environment variable `TRUFOR_WEIGHTS` to point to the file). When the weights are available, the app switches to the native TruFor backend and overlays tamper and confidence heatmaps next to the classifier output.
+
+ Optional environment variables:
+ - `TRUFOR_BACKEND`: force a backend (`native`, `docker`, or `auto`). The default is `auto`, which prefers the bundled native implementation.
+ - `TRUFOR_WEIGHTS`: absolute or relative path to `trufor.pth.tar` if you keep the file outside `weights/`.
+
+ ## Notes
+ - The TruFor assets are redistributed here as Python modules for convenience, but you must still respect the upstream license for any research or redistribution.
+ - Docker support remains available for legacy setups, but no container build steps are required when using the bundled backend.
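+
+ For reference, a launch that keeps the weights outside `weights/` and forces the bundled backend could look like the following; the checkpoint path is a placeholder.
+ ```bash
+ # Placeholder path; point it at wherever you store the TruFor checkpoint.
+ export TRUFOR_WEIGHTS=/data/checkpoints/trufor.pth.tar
+ export TRUFOR_BACKEND=native   # optional; the default "auto" also prefers the native backend
+ python app.py
+ ```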
app.py ADDED
@@ -0,0 +1,132 @@
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ from typing import Dict, Optional, Tuple
5
+ from transformers import AutoImageProcessor, SiglipForImageClassification
6
+
7
+ from trufor_runner import TruForEngine, TruForResult, TruForUnavailableError
8
+
9
+ MODEL_ID = "Ateeqq/ai-vs-human-image-detector"
10
+
11
+ # Use GPU when available so large batches stay responsive.
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+
14
+ try:
15
+ processor = AutoImageProcessor.from_pretrained(MODEL_ID)
16
+ model = SiglipForImageClassification.from_pretrained(MODEL_ID)
17
+ model.to(device)
18
+ model.eval()
19
+ except Exception as exc: # pragma: no cover - surface loading issues early.
20
+ raise RuntimeError(f"Failed to load model from {MODEL_ID}") from exc
21
+
22
+ try:
23
+ TRUFOR_ENGINE: Optional[TruForEngine] = TruForEngine(device="cpu")
24
+ TRUFOR_STATUS = TRUFOR_ENGINE.status_message
25
+ except TruForUnavailableError as exc:
26
+ TRUFOR_ENGINE = None
27
+ TRUFOR_STATUS = str(exc)
28
+
29
+
30
+ def analyze_ai_vs_human(image: Image.Image) -> Tuple[Dict[str, float], str]:
31
+ """Run the Hugging Face detector and return confidences with a readable summary."""
32
+ if image is None:
33
+ empty_scores = {label: 0.0 for label in model.config.id2label.values()}
34
+ return empty_scores, "No image provided."
35
+
36
+ image = image.convert("RGB")
37
+ inputs = processor(images=image, return_tensors="pt").to(device)
38
+
39
+ with torch.no_grad():
40
+ logits = model(**inputs).logits
41
+
42
+ probabilities = torch.softmax(logits, dim=-1)[0]
43
+ scores = {
44
+ model.config.id2label[idx]: float(probabilities[idx])
45
+ for idx in range(probabilities.size(0))
46
+ }
47
+
48
+ top_idx = int(probabilities.argmax().item())
49
+ top_label = model.config.id2label[top_idx]
50
+ top_score = scores[top_label]
51
+ summary = (f"**Predicted Label:** {top_label}  \n"
52
+ f"**Confidence:** {top_score:.4f}")
53
+
54
+ return scores, summary
55
+
56
+
57
+ def analyze_trufor(image: Image.Image) -> Tuple[str, Optional[Image.Image], Optional[Image.Image]]:
58
+ """Run TruFor inference when available, otherwise return diagnostics."""
59
+ if TRUFOR_ENGINE is None:
60
+ return TRUFOR_STATUS, None, None
61
+
62
+ if image is None:
63
+ return "Upload an image to run TruFor.", None, None
64
+
65
+ try:
66
+ result: TruForResult = TRUFOR_ENGINE.infer(image)
67
+ except TruForUnavailableError as exc:
68
+ return str(exc), None, None
69
+
70
+ summary_lines = []
71
+ if result.score is not None:
72
+ summary_lines.append(f"**Tamper Score:** {result.score:.4f}")
73
+ extras_dict = result.raw_scores.copy()
74
+ if result.score is not None:
75
+ extras_dict.pop("tamper_score", None)
76
+ if extras_dict:
77
+ extras = " ".join(f"{key}: {value:.4f}" for key, value in extras_dict.items())
78
+ summary_lines.append(f"`{extras}`")
79
+ if not summary_lines:
80
+ summary_lines.append("TruFor returned no scores for this image.")
81
+
82
+ return "\n".join(summary_lines), result.map_overlay, result.confidence_overlay
83
+
84
+
85
+ def analyze_image(image: Image.Image) -> Tuple[Dict[str, float], str, str, Optional[Image.Image], Optional[Image.Image]]:
86
+ ai_scores, ai_summary = analyze_ai_vs_human(image)
87
+ trufor_summary, tamper_overlay, conf_overlay = analyze_trufor(image)
88
+ return ai_scores, ai_summary, trufor_summary, tamper_overlay, conf_overlay
89
+
90
+
91
+ with gr.Blocks() as demo:
92
+ gr.Markdown(
93
+ """# Image Authenticity Workbench\nUpload an image to compare the AI-vs-human classifier with the TruFor forgery detector."""
94
+ )
95
+
96
+ status_box = gr.Markdown(f"`{TRUFOR_STATUS}`")
97
+
98
+ image_input = gr.Image(label="Input Image", type="pil")
99
+ analyze_button = gr.Button("Analyze", variant="primary", size="sm")
100
+
101
+ with gr.Tabs():
102
+ with gr.TabItem("AI vs Human"):
103
+ ai_label_output = gr.Label(label="Prediction", num_top_classes=2)
104
+ ai_summary_output = gr.Markdown("Upload an image to view the prediction.")
105
+ with gr.TabItem("TruFor Forgery Detection"):
106
+ trufor_summary_output = gr.Markdown("Configure TruFor assets to enable tamper analysis.")
107
+ tamper_overlay_output = gr.Image(label="Tamper Heatmap", type="pil", interactive=False)
108
+ conf_overlay_output = gr.Image(label="Confidence Heatmap", type="pil", interactive=False)
109
+
110
+ output_components = [
111
+ ai_label_output,
112
+ ai_summary_output,
113
+ trufor_summary_output,
114
+ tamper_overlay_output,
115
+ conf_overlay_output,
116
+ ]
117
+
118
+ analyze_button.click(
119
+ fn=analyze_image,
120
+ inputs=image_input,
121
+ outputs=output_components,
122
+ )
123
+
124
+ image_input.change(
125
+ fn=analyze_image,
126
+ inputs=image_input,
127
+ outputs=output_components,
128
+ )
129
+
130
+
131
+ if __name__ == "__main__":
132
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ gradio==4.44.1
+ pydantic==2.8.2
+ transformers==4.44.2
+ torch>=2.1,<3
+ Pillow>=10.0
+ numpy>=1.23
+ timm>=0.5.4
trufor_native/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ """Bundled TruFor model for native inference."""
2
+
3
+ from .inference import TruForBundledModel
4
+
5
+ __all__ = ["TruForBundledModel"]
trufor_native/inference.py ADDED
@@ -0,0 +1,130 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Optional
7
+
8
+ import numpy as np
9
+ import torch
10
+ from PIL import Image
11
+
12
+ from .models.cmx.builder_np_conf import myEncoderDecoder as TruForNetwork
13
+
14
+ LOGGER = logging.getLogger(__name__)
15
+
16
+
17
+ @dataclass(frozen=True)
18
+ class TruForOutputs:
19
+ """Lightweight container for TruFor inference outputs."""
20
+
21
+ tamper_map: np.ndarray
22
+ confidence_map: Optional[np.ndarray]
23
+ detection_score: Optional[float]
24
+
25
+
26
+ class TruForBundledModel:
27
+ """Loads the TruFor network from the vendored sources and runs inference."""
28
+
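+ # Illustrative standalone use (assumes the weights file exists at the path shown):
+ #   engine = TruForBundledModel("weights/trufor.pth.tar", device="cpu")
+ #   outputs = engine.predict(Image.open("example.jpg"))
+ #   print(outputs.detection_score, outputs.tamper_map.shape)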
29
+ def __init__(self, weights_path: Path | str, device: str = "cpu") -> None:
30
+ self.weights_path = Path(weights_path)
31
+ if not self.weights_path.exists():
32
+ raise FileNotFoundError(f"TruFor weights missing at {self.weights_path}")
33
+
34
+ try:
35
+ self.device = torch.device(device)
36
+ except RuntimeError as exc: # pragma: no cover - defensive path for invalid strings
37
+ raise ValueError(f"Unsupported torch device '{device}'") from exc
38
+
39
+ self.model = self._build_model().to(self.device)
40
+ self.model.eval()
41
+
42
+ # ------------------------------------------------------------------
43
+ # Public API
44
+ # ------------------------------------------------------------------
45
+ def predict(self, image: Image.Image) -> TruForOutputs:
46
+ if image is None:
47
+ raise ValueError("An input image is required for TruFor inference.")
48
+
49
+ tensor = self._prepare_tensor(image).to(self.device)
50
+
51
+ with torch.inference_mode():
52
+ pred, conf, det, _ = self.model(tensor)
53
+
54
+ tamper_map = torch.softmax(pred[0], dim=0)[1].cpu().numpy()
55
+
56
+ confidence_map: Optional[np.ndarray] = None
57
+ if conf is not None:
58
+ confidence_map = torch.sigmoid(conf[0][0]).cpu().numpy()
59
+
60
+ detection_score: Optional[float] = None
61
+ if det is not None:
62
+ detection_score = torch.sigmoid(det).item()
63
+
64
+ return TruForOutputs(
65
+ tamper_map=tamper_map,
66
+ confidence_map=confidence_map,
67
+ detection_score=detection_score,
68
+ )
69
+
70
+ # ------------------------------------------------------------------
71
+ # Internal helpers
72
+ # ------------------------------------------------------------------
73
+ def _build_model(self) -> torch.nn.Module:
74
+ cfg = self._default_config()
75
+ model = TruForNetwork(cfg=cfg)
76
+ checkpoint = torch.load(self.weights_path, map_location="cpu", weights_only=False)
77
+ state_dict = checkpoint.get("state_dict", checkpoint)
78
+
79
+ missing, unexpected = model.load_state_dict(state_dict, strict=False)
80
+ if missing:
81
+ LOGGER.warning("TruFor missing keys: %s", sorted(missing))
82
+ if unexpected:
83
+ LOGGER.warning("TruFor unexpected keys: %s", sorted(unexpected))
84
+
85
+ return model
86
+
87
+ @staticmethod
88
+ def _prepare_tensor(image: Image.Image) -> torch.Tensor:
89
+ rgb = np.asarray(image.convert("RGB"), dtype=np.float32)
90
+ tensor = torch.from_numpy(rgb.transpose(2, 0, 1)).unsqueeze(0)
91
+ tensor = tensor / 256.0 # scale 0-255 pixel values to roughly [0, 1); matches the reference implementation
92
+ return tensor
93
+
94
+ class AttrNamespace(dict):
95
+ def __getattr__(self, item):
96
+ try:
97
+ return self[item]
98
+ except KeyError as exc:
99
+ raise AttributeError(item) from exc
100
+
101
+ def __setattr__(self, key, value):
102
+ self[key] = value
103
+
104
+ def __contains__(self, item):
105
+ return item in self.keys()
106
+
107
+ @classmethod
108
+ def _default_config(cls) -> AttrNamespace:
109
+ extra = cls.AttrNamespace(
110
+ BACKBONE="mit_b2",
111
+ DECODER="MLPDecoder",
112
+ DECODER_EMBED_DIM=512,
113
+ PREPRC="imagenet",
114
+ BN_EPS=0.001,
115
+ BN_MOMENTUM=0.1,
116
+ DETECTION="confpool",
117
+ CONF=True,
118
+ NP_WEIGHTS="",
119
+ )
120
+
121
+ model = cls.AttrNamespace(
122
+ NAME="detconfcmx",
123
+ MODS=("RGB", "NP++"),
124
+ PRETRAINED="",
125
+ EXTRA=extra,
126
+ )
127
+
128
+ dataset = cls.AttrNamespace(NUM_CLASSES=2)
129
+
130
+ return cls.AttrNamespace(MODEL=model, DATASET=dataset)
trufor_native/models/DnCNN.py ADDED
@@ -0,0 +1,145 @@
1
+ # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
2
+ # Copyright (c) 2023 Image Processing Research Group of University Federico II of Naples ('GRIP-UNINA').
3
+ #
4
+ # All rights reserved.
5
+ # This work should only be used for nonprofit purposes.
6
+ #
7
+ # By downloading and/or using any of these files, you implicitly agree to all the
8
+ # terms of the license, as specified in the document LICENSE.txt
9
+ # (included in this package) and online at
10
+ # http://www.grip.unina.it/download/LICENSE_OPEN.txt
11
+
12
+ """
13
+ Created in September 2020
14
+ @author: davide.cozzolino
15
+ """
16
+
17
+ import math
18
+ import torch.nn as nn
19
+
20
+ def conv_with_padding(in_planes, out_planes, kernelsize, stride=1, dilation=1, bias=False, padding = None):
21
+ if padding is None:
22
+ padding = kernelsize//2
23
+ return nn.Conv2d(in_planes, out_planes, kernel_size=kernelsize, stride=stride, dilation=dilation, padding=padding, bias=bias)
24
+
25
+ def conv_init(conv, act='linear'):
26
+ r"""
27
+ Reproduces conv initialization from DnCNN
28
+ """
29
+ n = conv.kernel_size[0] * conv.kernel_size[1] * conv.out_channels
30
+ conv.weight.data.normal_(0, math.sqrt(2. / n))
31
+
32
+ def batchnorm_init(m, kernelsize=3):
33
+ r"""
34
+ Reproduces batchnorm initialization from DnCNN
35
+ """
36
+ n = kernelsize**2 * m.num_features
37
+ m.weight.data.normal_(0, math.sqrt(2. / (n)))
38
+ m.bias.data.zero_()
39
+
40
+ def make_activation(act):
41
+ if act is None:
42
+ return None
43
+ elif act == 'relu':
44
+ return nn.ReLU(inplace=True)
45
+ elif act == 'tanh':
46
+ return nn.Tanh()
47
+ elif act == 'leaky_relu':
48
+ return nn.LeakyReLU(inplace=True)
49
+ elif act == 'softmax':
50
+ return nn.Softmax()
51
+ elif act == 'linear':
52
+ return None
53
+ else:
54
+ assert(False)
55
+
56
+ def make_net(nplanes_in, kernels, features, bns, acts, dilats, bn_momentum = 0.1, padding=None):
57
+ r"""
58
+ :param nplanes_in: number of input feature channels
59
+ :param kernels: list of kernel size for convolution layers
60
+ :param features: list of hidden layer feature channels
61
+ :param bns: list of whether to add batchnorm layers
62
+ :param acts: list of activations
63
+ :param dilats: list of dilation factors
64
+ :param bn_momentum: momentum of batchnorm
65
+ :param padding: integer for padding (None for same padding)
66
+ """
67
+
68
+ depth = len(features)
69
+ assert(len(features)==len(kernels))
70
+
71
+ layers = list()
72
+ for i in range(0,depth):
73
+ if i==0:
74
+ in_feats = nplanes_in
75
+ else:
76
+ in_feats = features[i-1]
77
+
78
+ elem = conv_with_padding(in_feats, features[i], kernelsize=kernels[i], dilation=dilats[i], padding=padding, bias=not(bns[i]))
79
+ conv_init(elem, act=acts[i])
80
+ layers.append(elem)
81
+
82
+ if bns[i]:
83
+ elem = nn.BatchNorm2d(features[i], momentum = bn_momentum)
84
+ batchnorm_init(elem, kernelsize=kernels[i])
85
+ layers.append(elem)
86
+
87
+ elem = make_activation(acts[i])
88
+ if elem is not None:
89
+ layers.append(elem)
90
+
91
+ return nn.Sequential(*layers)
92
+
93
+ class DnCNN(nn.Module):
94
+ r"""
95
+ Implements a DnCNN network
96
+ """
97
+ def __init__(self, nplanes_in, nplanes_out, features, kernel, depth, activation, residual, bn, lastact=None, bn_momentum = 0.10, padding=None):
98
+ r"""
99
+ :param nplanes_in: number of input feature channels
100
+ :param nplanes_out: number of output feature channels
101
+ :param features: number of hidden layer feature channels
102
+ :param kernel: kernel size of convolution layers
103
+ :param depth: number of convolution layers (minimum 2)
104
+ :param bn: whether to add batchnorm layers
105
+ :param residual: whether to add a residual connection from input to output
106
+ :param bn_momentum: momentum of batchnorm
107
+ :param padding: integer for padding (None for same padding)
108
+ """
109
+ super(DnCNN, self).__init__()
110
+
111
+ self.residual = residual
112
+ self.nplanes_out = nplanes_out
113
+ self.nplanes_in = nplanes_in
114
+
115
+ kernels = [kernel, ] * depth
116
+ features = [features, ] * (depth-1) + [nplanes_out, ]
117
+ bns = [False, ] + [bn,] * (depth - 2) + [False, ]
118
+ dilats = [1, ] * depth
119
+ acts = [activation, ] * (depth - 1) + [lastact, ]
120
+ self.layers = make_net(nplanes_in, kernels, features, bns, acts, dilats=dilats, bn_momentum = bn_momentum, padding=padding)
121
+
122
+
123
+ def forward(self, x):
124
+ shortcut = x
125
+
126
+ x = self.layers(x)
127
+
128
+ if self.residual:
129
+ nshortcut = min(self.nplanes_in, self.nplanes_out)
130
+ x[:, :nshortcut, :, :] = x[:, :nshortcut, :, :] + shortcut[:, :nshortcut, :, :]
131
+
132
+ return x
133
+
134
+
135
+ def add_commandline_networkparams(parser, name, features, depth, kernel, activation, bn):
136
+ parser.add_argument("--{}.{}".format(name, "features" ), type=int, default=features )
137
+ parser.add_argument("--{}.{}".format(name, "depth" ), type=int, default=depth )
138
+ parser.add_argument("--{}.{}".format(name, "kernel" ), type=int, default=kernel )
139
+ parser.add_argument("--{}.{}".format(name, "activation"), type=str, default=activation)
140
+
141
+ bnarg = "{}.{}".format(name, "bn")
142
+ parser.add_argument("--"+bnarg , action="store_true", dest=bnarg)
143
+ parser.add_argument("--{}.{}".format(name, "no-bn"), action="store_false", dest=bnarg)
144
+ parser.set_defaults(**{bnarg: bn})
145
+
trufor_native/models/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
2
+ # Copyright (c) 2023 Image Processing Research Group of University Federico II of Naples ('GRIP-UNINA').
3
+ #
4
+ # All rights reserved.
5
+ # This work should only be used for nonprofit purposes.
6
+ #
7
+ # By downloading and/or using any of these files, you implicitly agree to all the
8
+ # terms of the license, as specified in the document LICENSE.txt
9
+ # (included in this package) and online at
10
+ # http://www.grip.unina.it/download/LICENSE_OPEN.txt
trufor_native/models/cmx/LICENSE_CMX.txt ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2022 Huayao Liu
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
trufor_native/models/cmx/__init__.py ADDED
File without changes
trufor_native/models/cmx/builder_np_conf.py ADDED
@@ -0,0 +1,175 @@
1
+ """
2
+ Edited in September 2022
3
+ @author: fabrizio.guillaro, davide.cozzolino
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import os
10
+
11
+ from .utils.init_func import init_weight
12
+
13
+ import logging
14
+
15
+
16
+ def preprc_imagenet_torch(x):
17
+ mean = torch.Tensor([0.485, 0.456, 0.406]).to(x.device)
18
+ std = torch.Tensor([0.229, 0.224, 0.225]).to(x.device)
19
+ x = (x-mean[None, :, None, None]) / std[None, :, None, None]
20
+ return x
21
+
22
+
23
+ def create_backbone(typ, norm_layer):
24
+ channels = [64, 128, 320, 512]
25
+ if typ == 'mit_b2':
26
+ logging.info('Using backbone: Segformer-B2')
27
+ from .encoders.dual_segformer import mit_b2 as backbone_
28
+ backbone = backbone_(norm_fuse=norm_layer)
29
+ else:
30
+ raise NotImplementedError('backbone not implemented')
31
+ return backbone, channels
32
+
33
+
34
+ class myEncoderDecoder(nn.Module):
35
+ def __init__(self, cfg=None, norm_layer=nn.BatchNorm2d):
36
+ super(myEncoderDecoder, self).__init__()
37
+
38
+ self.norm_layer = norm_layer
39
+ self.cfg = cfg.MODEL.EXTRA
40
+ self.mods = cfg.MODEL.MODS
41
+
42
+ # import backbone and decoder
43
+ self.backbone, self.channels = create_backbone(self.cfg.BACKBONE, norm_layer)
44
+
45
+ if 'CONF_BACKBONE' in self.cfg:
46
+ self.backbone_conf, self.channels_conf = create_backbone(self.cfg.CONF_BACKBONE, norm_layer)
47
+ else:
48
+ self.backbone_conf = None
49
+
50
+ if self.cfg.DECODER == 'MLPDecoder':
51
+ logging.info('Using MLP Decoder')
52
+ from .decoders.MLPDecoder import DecoderHead
53
+ self.decode_head = DecoderHead(in_channels=self.channels, num_classes=cfg.DATASET.NUM_CLASSES, norm_layer=norm_layer, embed_dim=self.cfg.DECODER_EMBED_DIM)
54
+
55
+ if self.cfg.CONF:
56
+ self.decode_head_conf = DecoderHead(in_channels=self.channels, num_classes=1, norm_layer=norm_layer, embed_dim=self.cfg.DECODER_EMBED_DIM)
57
+ else:
58
+ self.decode_head_conf = None
59
+
60
+ self.conf_detection = None
61
+ if self.cfg.DETECTION is not None:
62
+ if self.cfg.DETECTION == 'none':
63
+ pass
64
+ elif self.cfg.DETECTION == 'confpool':
65
+ self.conf_detection = 'confpool'
66
+ assert self.cfg.CONF
67
+ self.detection = nn.Sequential(
68
+ nn.Linear(in_features=8, out_features=128),
69
+ nn.ReLU(),
70
+ nn.Dropout(p=0.5),
71
+ nn.Linear(in_features=128, out_features=1),
72
+ )
73
+ else:
74
+ raise NotImplementedError('Detection mechanism not implemented')
75
+
76
+ else:
77
+ raise NotImplementedError('decoder not implemented')
78
+
79
+ from ..DnCNN import make_net
80
+ num_levels = 17
81
+ out_channel = 1
82
+ self.dncnn = make_net(3, kernels=[3, ] * num_levels,
83
+ features=[64, ] * (num_levels - 1) + [out_channel],
84
+ bns=[False, ] + [True, ] * (num_levels - 2) + [False, ],
85
+ acts=['relu', ] * (num_levels - 1) + ['linear', ],
86
+ dilats=[1, ] * num_levels,
87
+ bn_momentum=0.1, padding=1)
88
+
89
+ if self.cfg.PREPRC == 'imagenet': # normalize RGB with ImageNet mean and std
90
+ self.prepro = preprc_imagenet_torch
91
+ else:
92
+ assert False
93
+
94
+ self.init_weights(pretrained=cfg.MODEL.PRETRAINED)
95
+
96
+
97
+
98
+ def init_weights(self, pretrained=None):
99
+ if pretrained:
100
+ logging.info('Loading pretrained model: {}'.format(pretrained))
101
+ self.backbone.init_weights(pretrained=pretrained)
102
+ if self.backbone_conf is not None:
103
+ self.backbone_conf.init_weights(pretrained=pretrained)
104
+
105
+ np_weights = self.cfg.NP_WEIGHTS
106
+ assert os.path.isfile(np_weights)
107
+ dat = torch.load(np_weights, map_location=torch.device('cpu'))
108
+ logging.info(f'Noiseprint++ weights: {np_weights}')
109
+ if 'network' in dat:
110
+ dat = dat['network']
111
+ self.dncnn.load_state_dict(dat)
112
+
113
+ logging.info('Initing weights ...')
114
+ init_weight(self.decode_head, nn.init.kaiming_normal_,
115
+ self.norm_layer, self.cfg.BN_EPS, self.cfg.BN_MOMENTUM,
116
+ mode='fan_in', nonlinearity='relu')
117
+
118
+
119
+
120
+
121
+ def encode_decode(self, rgb, modal_x):
122
+
123
+ if rgb is not None:
124
+ orisize = rgb.shape
125
+ else:
126
+ orisize = modal_x.shape
127
+
128
+ # cmx
129
+ x = self.backbone(rgb, modal_x)
130
+ out, feats = self.decode_head(x, return_feats=True)
131
+ out = F.interpolate(out, size=orisize[2:], mode='bilinear', align_corners=False)
132
+
133
+ # confidence
134
+ if self.decode_head_conf is not None:
135
+ if self.backbone_conf is not None:
136
+ x_conf = self.backbone_conf(rgb, modal_x)
137
+ else:
138
+ x_conf = x # reuse the same encoder as the localization network
139
+
140
+ conf = self.decode_head_conf(x_conf)
141
+ conf = F.interpolate(conf, size=orisize[2:], mode='bilinear', align_corners=False)
142
+ else:
143
+ conf = None
144
+
145
+
146
+ # detection
147
+ if self.conf_detection is not None:
148
+ if self.conf_detection == 'confpool':
149
+ from .layer_utils import weighted_statistics_pooling
150
+ f1 = weighted_statistics_pooling(conf).view(out.shape[0],-1)
151
+ f2 = weighted_statistics_pooling(out[:,1:2,:,:]-out[:,0:1,:,:], F.logsigmoid(conf)).view(out.shape[0],-1)
152
+ det = self.detection(torch.cat((f1,f2),-1))
153
+ else:
154
+ assert False
155
+ else:
156
+ det = None
157
+
158
+ return out, conf, det
159
+
160
+
161
+ def forward(self, rgb):
162
+
163
+ # Noiseprint++ extraction
164
+ if 'NP++' in self.mods:
165
+ modal_x = self.dncnn(rgb)
166
+ modal_x = torch.tile(modal_x, (3, 1, 1))
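+ # The single-channel Noiseprint++ map is replicated to 3 channels so it can share the RGB-style patch embedding.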
167
+ else:
168
+ modal_x = None
169
+
170
+ if self.prepro is not None:
171
+ rgb = self.prepro(rgb)
172
+
173
+ out, conf, det = self.encode_decode(rgb, modal_x)
174
+ return out, conf, det, modal_x
175
+
trufor_native/models/cmx/decoders/MLPDecoder.py ADDED
@@ -0,0 +1,86 @@
1
+ import numpy as np
2
+ import torch.nn as nn
3
+ import torch
4
+
5
+ import torch.nn.functional as F
6
+
7
+ class MLP(nn.Module):
8
+ """
9
+ Linear Embedding:
10
+ """
11
+ def __init__(self, input_dim=2048, embed_dim=768):
12
+ super().__init__()
13
+ self.proj = nn.Linear(input_dim, embed_dim)
14
+
15
+ def forward(self, x):
16
+ x = x.flatten(2).transpose(1, 2)
17
+ x = self.proj(x)
18
+ return x
19
+
20
+
21
+ class DecoderHead(nn.Module):
22
+ def __init__(self,
23
+ in_channels=[64, 128, 320, 512],
24
+ num_classes=40,
25
+ dropout_ratio=0.1,
26
+ norm_layer=nn.BatchNorm2d,
27
+ embed_dim=768,
28
+ align_corners=False):
29
+
30
+ super(DecoderHead, self).__init__()
31
+ self.num_classes = num_classes
32
+ self.dropout_ratio = dropout_ratio
33
+ self.align_corners = align_corners
34
+
35
+ self.in_channels = in_channels
36
+
37
+ if dropout_ratio > 0:
38
+ self.dropout = nn.Dropout2d(dropout_ratio)
39
+ else:
40
+ self.dropout = None
41
+
42
+ c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels
43
+
44
+ embedding_dim = embed_dim
45
+ self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim)
46
+ self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim)
47
+ self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim)
48
+ self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim)
49
+
50
+ self.linear_fuse = nn.Sequential(
51
+ nn.Conv2d(in_channels=embedding_dim*4, out_channels=embedding_dim, kernel_size=1),
52
+ norm_layer(embedding_dim),
53
+ nn.ReLU(inplace=True)
54
+ )
55
+
56
+ self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1)
57
+
58
+ def forward(self, inputs, return_feats=False):
59
+ # len=4, 1/4,1/8,1/16,1/32
60
+ c1, c2, c3, c4 = inputs
61
+
62
+ ############## MLP decoder on C1-C4 ###########
63
+ n, _, h, w = c4.shape
64
+
65
+ _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3])
66
+ _c4 = F.interpolate(_c4, size=c1.size()[2:],mode='bilinear',align_corners=self.align_corners)
67
+
68
+ _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3])
69
+ _c3 = F.interpolate(_c3, size=c1.size()[2:],mode='bilinear',align_corners=self.align_corners)
70
+
71
+ _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3])
72
+ _c2 = F.interpolate(_c2, size=c1.size()[2:],mode='bilinear',align_corners=self.align_corners)
73
+
74
+ _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3])
75
+
76
+ _c = torch.cat([_c4, _c3, _c2, _c1], dim=1)
77
+ x = self.linear_fuse(_c)
78
+ x = self.dropout(x)
79
+ x = self.linear_pred(x)
80
+
81
+ if return_feats:
82
+ return x, _c
83
+ else:
84
+ return x
85
+
86
+
trufor_native/models/cmx/decoders/__init__.py ADDED
File without changes
trufor_native/models/cmx/encoders/__init__.py ADDED
File without changes
trufor_native/models/cmx/encoders/dual_segformer.py ADDED
@@ -0,0 +1,518 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from functools import partial
5
+
6
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
7
+ from ..net_utils import FeatureFusionModule as FFM
8
+ from ..net_utils import FeatureRectifyModule as FRM
9
+ import math
10
+ import time
11
+ #from engine.logger import get_logger
12
+ import logging as logger
13
+
14
+ #logger = get_logger()
15
+
16
+
17
+ class DWConv(nn.Module):
18
+ """
19
+ Depthwise convolution block: input x with size (B, N, C); output size (B, N, C)
20
+ """
21
+ def __init__(self, dim=768):
22
+ super(DWConv, self).__init__()
23
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=True, groups=dim)
24
+
25
+ def forward(self, x, H, W):
26
+ B, N, C = x.shape
27
+ x = x.permute(0, 2, 1).reshape(B, C, H, W).contiguous() # B N C -> B C N -> B C H W
28
+ x = self.dwconv(x)
29
+ x = x.flatten(2).transpose(1, 2) # B C H W -> B N C
30
+
31
+ return x
32
+
33
+
34
+ class Mlp(nn.Module):
35
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
36
+ super().__init__()
37
+ """
38
+ MLP Block:
39
+ """
40
+ out_features = out_features or in_features
41
+ hidden_features = hidden_features or in_features
42
+ self.fc1 = nn.Linear(in_features, hidden_features)
43
+ self.dwconv = DWConv(hidden_features)
44
+ self.act = act_layer()
45
+ self.fc2 = nn.Linear(hidden_features, out_features)
46
+ self.drop = nn.Dropout(drop)
47
+
48
+ self.apply(self._init_weights)
49
+
50
+ def _init_weights(self, m):
51
+ if isinstance(m, nn.Linear):
52
+ trunc_normal_(m.weight, std=.02)
53
+ if isinstance(m, nn.Linear) and m.bias is not None:
54
+ nn.init.constant_(m.bias, 0)
55
+ elif isinstance(m, nn.LayerNorm):
56
+ nn.init.constant_(m.bias, 0)
57
+ nn.init.constant_(m.weight, 1.0)
58
+ elif isinstance(m, nn.Conv2d):
59
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
60
+ fan_out //= m.groups
61
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
62
+ if m.bias is not None:
63
+ m.bias.data.zero_()
64
+
65
+ def forward(self, x, H, W):
66
+ x = self.fc1(x)
67
+ x = self.dwconv(x, H, W)
68
+ x = self.act(x)
69
+ x = self.drop(x)
70
+ x = self.fc2(x)
71
+ x = self.drop(x)
72
+ return x
73
+
74
+
75
+ class Attention(nn.Module):
76
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
77
+ super().__init__()
78
+ assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
79
+
80
+ self.dim = dim
81
+ self.num_heads = num_heads
82
+ head_dim = dim // num_heads
83
+ self.scale = qk_scale or head_dim ** -0.5
84
+
85
+ # Linear embedding
86
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
87
+ self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
88
+ self.attn_drop = nn.Dropout(attn_drop)
89
+ self.proj = nn.Linear(dim, dim)
90
+ self.proj_drop = nn.Dropout(proj_drop)
91
+
92
+ self.sr_ratio = sr_ratio
93
+ if sr_ratio > 1:
94
+ self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
95
+ self.norm = nn.LayerNorm(dim)
96
+
97
+ self.apply(self._init_weights)
98
+
99
+ def _init_weights(self, m):
100
+ if isinstance(m, nn.Linear):
101
+ trunc_normal_(m.weight, std=.02)
102
+ if isinstance(m, nn.Linear) and m.bias is not None:
103
+ nn.init.constant_(m.bias, 0)
104
+ elif isinstance(m, nn.LayerNorm):
105
+ nn.init.constant_(m.bias, 0)
106
+ nn.init.constant_(m.weight, 1.0)
107
+ elif isinstance(m, nn.Conv2d):
108
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
109
+ fan_out //= m.groups
110
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
111
+ if m.bias is not None:
112
+ m.bias.data.zero_()
113
+
114
+ def forward(self, x, H, W):
115
+ B, N, C = x.shape
116
+ # B N C -> B N num_head C//num_head -> B C//num_head N num_heads
117
+ q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
118
+
119
+ if self.sr_ratio > 1:
120
+ x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
121
+ x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
122
+ x_ = self.norm(x_)
123
+ kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
124
+ else:
125
+ kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
126
+ k, v = kv[0], kv[1]
127
+
128
+ attn = (q @ k.transpose(-2, -1)) * self.scale
129
+ attn = attn.softmax(dim=-1)
130
+ attn = self.attn_drop(attn)
131
+
132
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
133
+ x = self.proj(x)
134
+ x = self.proj_drop(x)
135
+
136
+ return x
137
+
138
+
139
+ class Block(nn.Module):
140
+ """
141
+ Transformer Block: Self-Attention -> Mix FFN -> OverLap Patch Merging
142
+ """
143
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
144
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
145
+ super().__init__()
146
+ self.norm1 = norm_layer(dim)
147
+ self.attn = Attention(
148
+ dim,
149
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
150
+ attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
151
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
152
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
153
+ self.norm2 = norm_layer(dim)
154
+ mlp_hidden_dim = int(dim * mlp_ratio)
155
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
156
+
157
+ self.apply(self._init_weights)
158
+
159
+ def _init_weights(self, m):
160
+ if isinstance(m, nn.Linear):
161
+ trunc_normal_(m.weight, std=.02)
162
+ if isinstance(m, nn.Linear) and m.bias is not None:
163
+ nn.init.constant_(m.bias, 0)
164
+ elif isinstance(m, nn.LayerNorm):
165
+ nn.init.constant_(m.bias, 0)
166
+ nn.init.constant_(m.weight, 1.0)
167
+ elif isinstance(m, nn.Conv2d):
168
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
169
+ fan_out //= m.groups
170
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
171
+ if m.bias is not None:
172
+ m.bias.data.zero_()
173
+
174
+ def forward(self, x, H, W):
175
+ x = x + self.drop_path(self.attn(self.norm1(x), H, W))
176
+ x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
177
+
178
+ return x
179
+
180
+
181
+ class OverlapPatchEmbed(nn.Module):
182
+ """ Image to Patch Embedding
183
+ """
184
+
185
+ def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
186
+ super().__init__()
187
+ img_size = to_2tuple(img_size)
188
+ patch_size = to_2tuple(patch_size)
189
+
190
+ self.img_size = img_size
191
+ self.patch_size = patch_size
192
+ self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
193
+ self.num_patches = self.H * self.W
194
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
195
+ padding=(patch_size[0] // 2, patch_size[1] // 2))
196
+ self.norm = nn.LayerNorm(embed_dim)
197
+
198
+ self.apply(self._init_weights)
199
+
200
+ def _init_weights(self, m):
201
+ if isinstance(m, nn.Linear):
202
+ trunc_normal_(m.weight, std=.02)
203
+ if isinstance(m, nn.Linear) and m.bias is not None:
204
+ nn.init.constant_(m.bias, 0)
205
+ elif isinstance(m, nn.LayerNorm):
206
+ nn.init.constant_(m.bias, 0)
207
+ nn.init.constant_(m.weight, 1.0)
208
+ elif isinstance(m, nn.Conv2d):
209
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
210
+ fan_out //= m.groups
211
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
212
+ if m.bias is not None:
213
+ m.bias.data.zero_()
214
+
215
+ def forward(self, x):
216
+ # B C H W
217
+ x = self.proj(x)
218
+ _, _, H, W = x.shape
219
+ x = x.flatten(2).transpose(1, 2)
220
+ # B H*W/16 C
221
+ x = self.norm(x)
222
+
223
+ return x, H, W
224
+
225
+
226
+ class RGBXTransformer(nn.Module):
227
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
228
+ num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
229
+ attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, norm_fuse=nn.BatchNorm2d,
230
+ depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], stride0=4):
231
+ super().__init__()
232
+ self.num_classes = num_classes
233
+ self.depths = depths
234
+
235
+ # patch_embed
236
+ self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=stride0, in_chans=in_chans,
237
+ embed_dim=embed_dims[0])
238
+ self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
239
+ embed_dim=embed_dims[1])
240
+ self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
241
+ embed_dim=embed_dims[2])
242
+ self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
243
+ embed_dim=embed_dims[3])
244
+
245
+ self.extra_patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=stride0, in_chans=in_chans,
246
+ embed_dim=embed_dims[0])
247
+ self.extra_patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
248
+ embed_dim=embed_dims[1])
249
+ self.extra_patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
250
+ embed_dim=embed_dims[2])
251
+ self.extra_patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
252
+ embed_dim=embed_dims[3])
253
+
254
+ # transformer encoder
255
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
256
+ cur = 0
257
+
258
+ self.block1 = nn.ModuleList([Block(
259
+ dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
260
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
261
+ sr_ratio=sr_ratios[0])
262
+ for i in range(depths[0])])
263
+ self.norm1 = norm_layer(embed_dims[0])
264
+
265
+ self.extra_block1 = nn.ModuleList([Block(
266
+ dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
267
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
268
+ sr_ratio=sr_ratios[0])
269
+ for i in range(depths[0])])
270
+ self.extra_norm1 = norm_layer(embed_dims[0])
271
+ cur += depths[0]
272
+
273
+ self.block2 = nn.ModuleList([Block(
274
+ dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
275
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur], norm_layer=norm_layer,
276
+ sr_ratio=sr_ratios[1])
277
+ for i in range(depths[1])])
278
+ self.norm2 = norm_layer(embed_dims[1])
279
+
280
+ self.extra_block2 = nn.ModuleList([Block(
281
+ dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
282
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur+1], norm_layer=norm_layer,
283
+ sr_ratio=sr_ratios[1])
284
+ for i in range(depths[1])])
285
+ self.extra_norm2 = norm_layer(embed_dims[1])
286
+
287
+ cur += depths[1]
288
+
289
+ self.block3 = nn.ModuleList([Block(
290
+ dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
291
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
292
+ sr_ratio=sr_ratios[2])
293
+ for i in range(depths[2])])
294
+ self.norm3 = norm_layer(embed_dims[2])
295
+
296
+ self.extra_block3 = nn.ModuleList([Block(
297
+ dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
298
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
299
+ sr_ratio=sr_ratios[2])
300
+ for i in range(depths[2])])
301
+ self.extra_norm3 = norm_layer(embed_dims[2])
302
+
303
+ cur += depths[2]
304
+
305
+ self.block4 = nn.ModuleList([Block(
306
+ dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
307
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
308
+ sr_ratio=sr_ratios[3])
309
+ for i in range(depths[3])])
310
+ self.norm4 = norm_layer(embed_dims[3])
311
+
312
+ self.extra_block4 = nn.ModuleList([Block(
313
+ dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
314
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
315
+ sr_ratio=sr_ratios[3])
316
+ for i in range(depths[3])])
317
+ self.extra_norm4 = norm_layer(embed_dims[3])
318
+
319
+ cur += depths[3]
320
+
321
+ self.FRMs = nn.ModuleList([
322
+ FRM(dim=embed_dims[0], reduction=1),
323
+ FRM(dim=embed_dims[1], reduction=1),
324
+ FRM(dim=embed_dims[2], reduction=1),
325
+ FRM(dim=embed_dims[3], reduction=1)])
326
+
327
+ self.FFMs = nn.ModuleList([
328
+ FFM(dim=embed_dims[0], reduction=1, num_heads=num_heads[0], norm_layer=norm_fuse),
329
+ FFM(dim=embed_dims[1], reduction=1, num_heads=num_heads[1], norm_layer=norm_fuse),
330
+ FFM(dim=embed_dims[2], reduction=1, num_heads=num_heads[2], norm_layer=norm_fuse),
331
+ FFM(dim=embed_dims[3], reduction=1, num_heads=num_heads[3], norm_layer=norm_fuse)])
332
+
333
+ self.apply(self._init_weights)
334
+
335
+ def _init_weights(self, m):
336
+ if isinstance(m, nn.Linear):
337
+ trunc_normal_(m.weight, std=.02)
338
+ if isinstance(m, nn.Linear) and m.bias is not None:
339
+ nn.init.constant_(m.bias, 0)
340
+ elif isinstance(m, nn.LayerNorm):
341
+ nn.init.constant_(m.bias, 0)
342
+ nn.init.constant_(m.weight, 1.0)
343
+ elif isinstance(m, nn.Conv2d):
344
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
345
+ fan_out //= m.groups
346
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
347
+ if m.bias is not None:
348
+ m.bias.data.zero_()
349
+
350
+ def init_weights(self, pretrained=None):
351
+ if isinstance(pretrained, str):
352
+ load_dualpath_model(self, pretrained)
353
+ else:
354
+ raise TypeError('pretrained must be a str or None')
355
+
356
+ def forward_features(self, x_rgb, x_e):
357
+ """
358
+ x_rgb: B x N x H x W
359
+ """
360
+ B = x_rgb.shape[0]
361
+ outs = []
362
+ outs_fused = []
363
+
364
+ # stage 1
365
+ x_rgb, H, W = self.patch_embed1(x_rgb)
366
+ # B H*W/16 C
367
+ x_e, _, _ = self.extra_patch_embed1(x_e)
368
+ for i, blk in enumerate(self.block1):
369
+ x_rgb = blk(x_rgb, H, W)
370
+ for i, blk in enumerate(self.extra_block1):
371
+ x_e = blk(x_e, H, W)
372
+ x_rgb = self.norm1(x_rgb)
373
+ x_e = self.extra_norm1(x_e)
374
+
375
+ x_rgb = x_rgb.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
376
+ x_e = x_e.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
377
+ x_rgb, x_e = self.FRMs[0](x_rgb, x_e)
378
+ x_fused = self.FFMs[0](x_rgb, x_e)
379
+ outs.append(x_fused)
380
+
381
+
382
+ # stage 2
383
+ x_rgb, H, W = self.patch_embed2(x_rgb)
384
+ x_e, _, _ = self.extra_patch_embed2(x_e)
385
+ for i, blk in enumerate(self.block2):
386
+ x_rgb = blk(x_rgb, H, W)
387
+ for i, blk in enumerate(self.extra_block2):
388
+ x_e = blk(x_e, H, W)
389
+ x_rgb = self.norm2(x_rgb)
390
+ x_e = self.extra_norm2(x_e)
391
+
392
+ x_rgb = x_rgb.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
393
+ x_e = x_e.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
394
+ x_rgb, x_e = self.FRMs[1](x_rgb, x_e)
395
+ x_fused = self.FFMs[1](x_rgb, x_e)
396
+ outs.append(x_fused)
397
+
398
+
399
+ # stage 3
400
+ x_rgb, H, W = self.patch_embed3(x_rgb)
401
+ x_e, _, _ = self.extra_patch_embed3(x_e)
402
+ for i, blk in enumerate(self.block3):
403
+ x_rgb = blk(x_rgb, H, W)
404
+ for i, blk in enumerate(self.extra_block3):
405
+ x_e = blk(x_e, H, W)
406
+ x_rgb = self.norm3(x_rgb)
407
+ x_e = self.extra_norm3(x_e)
408
+
409
+ x_rgb = x_rgb.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
410
+ x_e = x_e.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
411
+ x_rgb, x_e = self.FRMs[2](x_rgb, x_e)
412
+ x_fused = self.FFMs[2](x_rgb, x_e)
413
+ outs.append(x_fused)
414
+
415
+
416
+ # stage 4
417
+ x_rgb, H, W = self.patch_embed4(x_rgb)
418
+ x_e, _, _ = self.extra_patch_embed4(x_e)
419
+ for i, blk in enumerate(self.block4):
420
+ x_rgb = blk(x_rgb, H, W)
421
+ for i, blk in enumerate(self.extra_block4):
422
+ x_e = blk(x_e, H, W)
423
+ x_rgb = self.norm4(x_rgb)
424
+ x_e = self.extra_norm4(x_e)
425
+
426
+ x_rgb = x_rgb.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
427
+ x_e = x_e.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
428
+ x_rgb, x_e = self.FRMs[3](x_rgb, x_e)
429
+ x_fused = self.FFMs[3](x_rgb, x_e)
430
+ outs.append(x_fused)
431
+
432
+ return outs
433
+
434
+ def forward(self, x_rgb, x_e):
435
+ out = self.forward_features(x_rgb, x_e)
436
+ return out
437
+
438
+
439
+ def load_dualpath_model(model, model_file):
440
+ # load raw state_dict
441
+ t_start = time.time()
442
+ if isinstance(model_file, str):
443
+ raw_state_dict = torch.load(model_file, map_location=torch.device('cpu'))
444
+ #raw_state_dict = torch.load(model_file)
445
+ if 'model' in raw_state_dict.keys():
446
+ raw_state_dict = raw_state_dict['model']
447
+ else:
448
+ raw_state_dict = model_file
449
+
450
+ state_dict = {}
451
+ for k, v in raw_state_dict.items():
452
+ if k.find('patch_embed') >= 0:
453
+ state_dict[k] = v
454
+ state_dict[k.replace('patch_embed', 'extra_patch_embed')] = v
455
+ elif k.find('block') >= 0:
456
+ state_dict[k] = v
457
+ state_dict[k.replace('block', 'extra_block')] = v
458
+ elif k.find('norm') >= 0:
459
+ state_dict[k] = v
460
+ state_dict[k.replace('norm', 'extra_norm')] = v
461
+
462
+ t_ioend = time.time()
463
+
464
+ model.load_state_dict(state_dict, strict=False)
465
+ del state_dict
466
+
467
+ t_end = time.time()
468
+ logger.info(
469
+ "Load model, Time usage:\n\tIO: {}, initialize parameters: {}".format(
470
+ t_ioend - t_start, t_end - t_ioend))
471
+
472
+
473
+ class mit_b0(RGBXTransformer):
474
+ def __init__(self, fuse_cfg=None, stride0=4, **kwargs):
475
+ super(mit_b0, self).__init__(
476
+ patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
477
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
478
+ drop_rate=0.0, drop_path_rate=0.1, stride0=stride0)
479
+
480
+
481
+ class mit_b1(RGBXTransformer):
482
+ def __init__(self, fuse_cfg=None, stride0=4, **kwargs):
483
+ super(mit_b1, self).__init__(
484
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
485
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
486
+ drop_rate=0.0, drop_path_rate=0.1, stride0=stride0)
487
+
488
+
489
+ class mit_b2(RGBXTransformer):
490
+ def __init__(self, fuse_cfg=None, stride0=4, **kwargs):
491
+ super(mit_b2, self).__init__(
492
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
493
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
494
+ drop_rate=0.0, drop_path_rate=0.1, stride0=stride0)
495
+
496
+
497
+ class mit_b3(RGBXTransformer):
498
+ def __init__(self, fuse_cfg=None, stride0=4, **kwargs):
499
+ super(mit_b3, self).__init__(
500
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
501
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
502
+ drop_rate=0.0, drop_path_rate=0.1, stride0=stride0)
503
+
504
+
505
+ class mit_b4(RGBXTransformer):
506
+ def __init__(self, fuse_cfg=None, stride0=4, **kwargs):
507
+ super(mit_b4, self).__init__(
508
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
509
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
510
+ drop_rate=0.0, drop_path_rate=0.1, stride0=stride0)
511
+
512
+
513
+ class mit_b5(RGBXTransformer):
514
+ def __init__(self, fuse_cfg=None, stride0=4, **kwargs):
515
+ super(mit_b5, self).__init__(
516
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
517
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
518
+ drop_rate=0.0, drop_path_rate=0.1, stride0=stride0)
trufor_native/models/cmx/layer_utils.py ADDED
@@ -0,0 +1,45 @@
1
+ # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
2
+ # Copyright (c) 2023 Image Processing Research Group of University Federico II of Naples ('GRIP-UNINA').
3
+ #
4
+ # All rights reserved.
5
+ # This work should only be used for nonprofit purposes.
6
+ #
7
+ # By downloading and/or using any of these files, you implicitly agree to all the
8
+ # terms of the license, as specified in the document LICENSE.txt
9
+ # (included in this package) and online at
10
+ # http://www.grip.unina.it/download/LICENSE_OPEN.txt
11
+
12
+ """
13
+ Created in September 2022
14
+ @author: davide.cozzolino
15
+ """
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+
20
+
21
+ def weighted_statistics_pooling(x, log_w=None):
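+ # Pools each channel map into four statistics (soft-min, soft-max, weighted mean,
+ # weighted mean-square), with spatial weights given by softmax(log_w).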
22
+ b = x.shape[0]
23
+ c = x.shape[1]
24
+ x = x.view(b,c,-1)
25
+
26
+ if log_w is None:
27
+ log_w = torch.zeros((b,1,x.shape[-1]), device=x.device)
28
+ else:
29
+ assert log_w.shape[0]==b
30
+ assert log_w.shape[1]==1
31
+ log_w = log_w.view(b,1,-1)
32
+
33
+ assert log_w.shape[-1]==x.shape[-1]
34
+
35
+ log_w = F.log_softmax(log_w, dim=-1)
36
+ x_min = -torch.logsumexp(log_w-x, dim=-1)
37
+ x_max = torch.logsumexp(log_w+x, dim=-1)
38
+
39
+ w = torch.exp(log_w)
40
+ x_avg = torch.sum(w*x , dim=-1)
41
+ x_msq = torch.sum(w*x*x, dim=-1)
42
+
43
+ x = torch.cat((x_min, x_max, x_avg, x_msq), dim=1)
44
+
45
+ return x
trufor_native/models/cmx/net_utils.py ADDED
@@ -0,0 +1,193 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from timm.models.layers import trunc_normal_
5
+ import math
6
+
7
+
8
+ # Feature Rectify Module
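+ # (cross-corrects the two modality feature streams using learned channel- and spatial-attention weights)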
9
+ class ChannelWeights(nn.Module):
10
+ def __init__(self, dim, reduction=1):
11
+ super(ChannelWeights, self).__init__()
12
+ self.dim = dim
13
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
14
+ self.max_pool = nn.AdaptiveMaxPool2d(1)
15
+ self.mlp = nn.Sequential(
16
+ nn.Linear(self.dim * 4, self.dim * 4 // reduction),
17
+ nn.ReLU(inplace=True),
18
+ nn.Linear(self.dim * 4 // reduction, self.dim * 2),
19
+ nn.Sigmoid())
20
+
21
+ def forward(self, x1, x2):
22
+ B, _, H, W = x1.shape
23
+ x = torch.cat((x1, x2), dim=1)
24
+ avg = self.avg_pool(x).view(B, self.dim * 2)
25
+ max = self.max_pool(x).view(B, self.dim * 2)
26
+ y = torch.cat((avg, max), dim=1) # B 4C
27
+ y = self.mlp(y).view(B, self.dim * 2, 1)
28
+ channel_weights = y.reshape(B, 2, self.dim, 1, 1).permute(1, 0, 2, 3, 4) # 2 B C 1 1
29
+ return channel_weights
30
+
31
+
32
+ class SpatialWeights(nn.Module):
33
+ def __init__(self, dim, reduction=1):
34
+ super(SpatialWeights, self).__init__()
35
+ self.dim = dim
36
+ self.mlp = nn.Sequential(
37
+ nn.Conv2d(self.dim * 2, self.dim // reduction, kernel_size=1),
38
+ nn.ReLU(inplace=True),
39
+ nn.Conv2d(self.dim // reduction, 2, kernel_size=1),
40
+ nn.Sigmoid())
41
+
42
+ def forward(self, x1, x2):
43
+ B, _, H, W = x1.shape
44
+ x = torch.cat((x1, x2), dim=1) # B 2C H W
45
+ spatial_weights = self.mlp(x).reshape(B, 2, 1, H, W).permute(1, 0, 2, 3, 4) # 2 B 1 H W
46
+ return spatial_weights
47
+
48
+
49
+ class FeatureRectifyModule(nn.Module):
50
+ def __init__(self, dim, reduction=1, lambda_c=.5, lambda_s=.5):
51
+ super(FeatureRectifyModule, self).__init__()
52
+ self.lambda_c = lambda_c
53
+ self.lambda_s = lambda_s
54
+ self.channel_weights = ChannelWeights(dim=dim, reduction=reduction)
55
+ self.spatial_weights = SpatialWeights(dim=dim, reduction=reduction)
56
+
57
+ def _init_weights(self, m):
58
+ if isinstance(m, nn.Linear):
59
+ trunc_normal_(m.weight, std=.02)
60
+ if isinstance(m, nn.Linear) and m.bias is not None:
61
+ nn.init.constant_(m.bias, 0)
62
+ elif isinstance(m, nn.LayerNorm):
63
+ nn.init.constant_(m.bias, 0)
64
+ nn.init.constant_(m.weight, 1.0)
65
+ elif isinstance(m, nn.Conv2d):
66
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
67
+ fan_out //= m.groups
68
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
69
+ if m.bias is not None:
70
+ m.bias.data.zero_()
71
+
72
+ def forward(self, x1, x2):
73
+ channel_weights = self.channel_weights(x1, x2)
74
+ spatial_weights = self.spatial_weights(x1, x2)
75
+ out_x1 = x1 + self.lambda_c * channel_weights[1] * x2 + self.lambda_s * spatial_weights[1] * x2
76
+ out_x2 = x2 + self.lambda_c * channel_weights[0] * x1 + self.lambda_s * spatial_weights[0] * x1
77
+ return out_x1, out_x2
78
+
79
+
80
+ # Stage 1
81
+ class CrossAttention(nn.Module):
82
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None):
83
+ super(CrossAttention, self).__init__()
84
+ assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
85
+
86
+ self.dim = dim
87
+ self.num_heads = num_heads
88
+ head_dim = dim // num_heads
89
+ self.scale = qk_scale or head_dim ** -0.5
90
+ self.kv1 = nn.Linear(dim, dim * 2, bias=qkv_bias)
91
+ self.kv2 = nn.Linear(dim, dim * 2, bias=qkv_bias)
92
+
93
+ def forward(self, x1, x2):
94
+ B, N, C = x1.shape
95
+ q1 = x1.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous()
96
+ q2 = x2.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous()
97
+
98
+ k1, v1 = self.kv1(x1).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous()
99
+ k2, v2 = self.kv2(x2).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous()
100
+
101
+ # q,k,v B H N C
102
+
103
+ ctx1 = (k1.transpose(-2, -1) @ v1) * self.scale # B H C C
104
+ ctx1 = ctx1.softmax(dim=-2)
105
+ ctx2 = (k2.transpose(-2, -1) @ v2) * self.scale # B H C C
106
+ ctx2 = ctx2.softmax(dim=-2)
107
+
108
+ x1 = (q1 @ ctx2).permute(0, 2, 1, 3).reshape(B, N, C).contiguous()
109
+ x2 = (q2 @ ctx1).permute(0, 2, 1, 3).reshape(B, N, C).contiguous()
110
+
111
+ return x1, x2
112
+
113
+
114
+ class CrossPath(nn.Module):
115
+ def __init__(self, dim, reduction=1, num_heads=None, norm_layer=nn.LayerNorm):
116
+ super().__init__()
117
+ self.channel_proj1 = nn.Linear(dim, dim // reduction * 2)
118
+ self.channel_proj2 = nn.Linear(dim, dim // reduction * 2)
119
+ self.act1 = nn.ReLU(inplace=True)
120
+ self.act2 = nn.ReLU(inplace=True)
121
+ self.cross_attn = CrossAttention(dim // reduction, num_heads=num_heads)
122
+ self.end_proj1 = nn.Linear(dim // reduction * 2, dim)
123
+ self.end_proj2 = nn.Linear(dim // reduction * 2, dim)
124
+ self.norm1 = norm_layer(dim)
125
+ self.norm2 = norm_layer(dim)
126
+
127
+ def forward(self, x1, x2):
128
+ y1, u1 = self.act1(self.channel_proj1(x1)).chunk(2, dim=-1)
129
+ y2, u2 = self.act2(self.channel_proj2(x2)).chunk(2, dim=-1)
130
+ v1, v2 = self.cross_attn(u1, u2)
131
+ y1 = torch.cat((y1, v1), dim=-1)
132
+ y2 = torch.cat((y2, v2), dim=-1)
133
+ out_x1 = self.norm1(x1 + self.end_proj1(y1))
134
+ out_x2 = self.norm2(x2 + self.end_proj2(y2))
135
+ return out_x1, out_x2
136
+
137
+
138
+ # Stage 2
139
+ class ChannelEmbed(nn.Module):
140
+ def __init__(self, in_channels, out_channels, reduction=1, norm_layer=nn.BatchNorm2d):
141
+ super(ChannelEmbed, self).__init__()
142
+ self.out_channels = out_channels
143
+ self.residual = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
144
+ self.channel_embed = nn.Sequential(
145
+ nn.Conv2d(in_channels, out_channels//reduction, kernel_size=1, bias=True),
146
+ nn.Conv2d(out_channels//reduction, out_channels//reduction, kernel_size=3, stride=1, padding=1, bias=True, groups=out_channels//reduction),
147
+ nn.ReLU(inplace=True),
148
+ nn.Conv2d(out_channels//reduction, out_channels, kernel_size=1, bias=True),
149
+ norm_layer(out_channels)
150
+ )
151
+ self.norm = norm_layer(out_channels)
152
+
153
+ def forward(self, x, H, W):
154
+ B, N, _C = x.shape
155
+ x = x.permute(0, 2, 1).reshape(B, _C, H, W).contiguous()
156
+ residual = self.residual(x)
157
+ x = self.channel_embed(x)
158
+ out = self.norm(residual + x)
159
+ return out
160
+
161
+
162
+ class FeatureFusionModule(nn.Module):
163
+ def __init__(self, dim, reduction=1, num_heads=None, norm_layer=nn.BatchNorm2d):
164
+ super().__init__()
165
+ self.cross = CrossPath(dim=dim, reduction=reduction, num_heads=num_heads)
166
+ self.channel_emb = ChannelEmbed(in_channels=dim*2, out_channels=dim, reduction=reduction, norm_layer=norm_layer)
167
+ self.apply(self._init_weights)
168
+
169
+ def _init_weights(self, m):
170
+ if isinstance(m, nn.Linear):
171
+ trunc_normal_(m.weight, std=.02)
172
+ if isinstance(m, nn.Linear) and m.bias is not None:
173
+ nn.init.constant_(m.bias, 0)
174
+ elif isinstance(m, nn.LayerNorm):
175
+ nn.init.constant_(m.bias, 0)
176
+ nn.init.constant_(m.weight, 1.0)
177
+ elif isinstance(m, nn.Conv2d):
178
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
179
+ fan_out //= m.groups
180
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
181
+ if m.bias is not None:
182
+ m.bias.data.zero_()
183
+
184
+ def forward(self, x1, x2):
185
+ B, C, H, W = x1.shape
186
+ x1 = x1.flatten(2).transpose(1, 2)
187
+ x2 = x2.flatten(2).transpose(1, 2)
188
+ x1, x2 = self.cross(x1, x2)
189
+ merge = torch.cat((x1, x2), dim=-1)
190
+ merge = self.channel_emb(merge, H, W)
191
+
192
+ return merge
193
+
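As a quick sanity check for the modules above, here is a minimal editorial sketch (not part of the committed files) that runs the rectify/fuse pair on dummy feature maps. The import path is an assumption, since this file's header appears earlier in the diff; adjust it to the actual module name.

```python
# Editorial sketch: confirm that FeatureRectifyModule and FeatureFusionModule
# keep the B x C x H x W layout when combining two feature streams.
# The module path below is assumed; adjust it to the real file name in the repo.
import torch

from trufor_native.models.cmx.net_utils import (  # path assumed
    FeatureRectifyModule,
    FeatureFusionModule,
)

frm = FeatureRectifyModule(dim=64, reduction=1).eval()
ffm = FeatureFusionModule(dim=64, reduction=1, num_heads=8).eval()

x_a = torch.randn(2, 64, 16, 16)  # first-branch features
x_b = torch.randn(2, 64, 16, 16)  # second-branch features

with torch.no_grad():
    r_a, r_b = frm(x_a, x_b)      # rectified streams, same shapes as the inputs
    fused = ffm(r_a, r_b)         # single fused map: (2, 64, 16, 16)

print(r_a.shape, r_b.shape, fused.shape)
```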
trufor_native/models/cmx/utils/__init__.py ADDED
File without changes
trufor_native/models/cmx/utils/init_func.py ADDED
@@ -0,0 +1,58 @@
1
+ #!/usr/bin/env python3
2
+ # encoding: utf-8
3
+ # @Time : 2018/9/28 12:13 PM
4
+ # @Author : yuchangqian
5
+ # @Contact : changqian_yu@163.com
6
+ # @File : init_func.py
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ def __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum,
11
+ **kwargs):
12
+ for name, m in feature.named_modules():
13
+ if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
14
+ conv_init(m.weight, **kwargs)
15
+ elif isinstance(m, norm_layer):
16
+ m.eps = bn_eps
17
+ m.momentum = bn_momentum
18
+ nn.init.constant_(m.weight, 1)
19
+ nn.init.constant_(m.bias, 0)
20
+
21
+
22
+ def init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum,
23
+ **kwargs):
24
+ if isinstance(module_list, list):
25
+ for feature in module_list:
26
+ __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum,
27
+ **kwargs)
28
+ else:
29
+ __init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum,
30
+ **kwargs)
31
+
32
+
33
+ def group_weight(weight_group, module, norm_layer, lr):
34
+ group_decay = []
35
+ group_no_decay = []
36
+ count = 0
37
+ for m in module.modules():
38
+ if isinstance(m, nn.Linear):
39
+ group_decay.append(m.weight)
40
+ if m.bias is not None:
41
+ group_no_decay.append(m.bias)
42
+ elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
43
+ group_decay.append(m.weight)
44
+ if m.bias is not None:
45
+ group_no_decay.append(m.bias)
46
+ elif isinstance(m, norm_layer) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.BatchNorm2d) \
47
+ or isinstance(m, nn.BatchNorm3d) or isinstance(m, nn.GroupNorm) or isinstance(m, nn.LayerNorm):
48
+ if m.weight is not None:
49
+ group_no_decay.append(m.weight)
50
+ if m.bias is not None:
51
+ group_no_decay.append(m.bias)
52
+ elif isinstance(m, nn.Parameter):
53
+ group_decay.append(m)
54
+
55
+ assert len(list(module.parameters())) >= len(group_decay) + len(group_no_decay)
56
+ weight_group.append(dict(params=group_decay, lr=lr))
57
+ weight_group.append(dict(params=group_no_decay, weight_decay=.0, lr=lr))
58
+ return weight_group
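As a reference for how these helpers are typically wired up, here is a small editorial sketch (not part of the committed files): `group_weight` appends two parameter groups, one for conv/linear weights that receive weight decay and one for norm parameters and biases that do not, so an optimizer can consume the result directly.

```python
# Editorial sketch: build optimizer parameter groups with group_weight so that
# BatchNorm parameters and biases are excluded from weight decay.
import torch
import torch.nn as nn

from trufor_native.models.cmx.utils.init_func import group_weight

model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(inplace=True),
    nn.Conv2d(8, 8, kernel_size=1),
)

param_groups = group_weight([], model, nn.BatchNorm2d, lr=1e-3)
optimizer = torch.optim.SGD(param_groups, lr=1e-3, momentum=0.9, weight_decay=1e-4)
```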
trufor_runner.py ADDED
@@ -0,0 +1,309 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import os
5
+ import shutil
6
+ import subprocess
7
+ import tempfile
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ from typing import Dict, Optional
11
+
12
+ import numpy as np
13
+ from PIL import Image
14
+
15
+ LOGGER = logging.getLogger(__name__)
16
+
17
+
18
+ class TruForUnavailableError(RuntimeError):
19
+ """Raised when the TruFor assets are missing or inference fails."""
20
+
21
+
22
+ @dataclass
23
+ class TruForResult:
24
+ score: Optional[float]
25
+ map_overlay: Optional[Image.Image]
26
+ confidence_overlay: Optional[Image.Image]
27
+ raw_scores: Dict[str, float]
28
+
29
+
30
+ class TruForEngine:
31
+ """Wrapper that executes TruFor inference through the Docker or native Python backends."""
32
+
33
+ def __init__(
34
+ self,
35
+ repo_root: Optional[Path] = None,
36
+ weights_path: Optional[Path] = None,
37
+ device: str = "cpu",
38
+ ) -> None:
39
+ self.base_dir = Path(__file__).resolve().parent
40
+ self.device = device
41
+ self.backend: Optional[str] = None
42
+ self.status_message = "TruFor backend not initialized."
43
+
44
+ backend_pref = os.environ.get("TRUFOR_BACKEND", "auto").lower()
45
+ if backend_pref not in {"auto", "native", "docker"}:
46
+ backend_pref = "auto"
47
+
48
+ errors: list[str] = []
49
+
50
+ if backend_pref in {"auto", "native"}:
51
+ try:
52
+ self._configure_native_backend(repo_root, weights_path)
53
+ self.backend = "native"
54
+ self.status_message = "TruFor ready (bundled python backend)."
55
+ except TruForUnavailableError as exc:
56
+ errors.append(f"Native backend unavailable: {exc}")
57
+ if backend_pref == "native":
58
+ raise
59
+
60
+ if self.backend is None and backend_pref in {"auto", "docker"}:
61
+ try:
62
+ self._configure_docker_backend()
63
+ self.backend = "docker"
64
+ self.status_message = f'TruFor ready (docker image "{self.docker_image}").'
65
+ except TruForUnavailableError as exc:
66
+ errors.append(f"Docker backend unavailable: {exc}")
67
+ if backend_pref == "docker":
68
+ raise
69
+
70
+ if self.backend is None:
71
+ raise TruForUnavailableError(" | ".join(errors) if errors else "TruFor backend unavailable.")
72
+
73
+ # ------------------------------------------------------------------
74
+ # Backend configuration helpers
75
+ # ------------------------------------------------------------------
76
+ def _configure_docker_backend(self) -> None:
77
+ if shutil.which("docker") is None:
78
+ raise TruForUnavailableError("docker CLI not found on PATH.")
79
+
80
+ test_docker_dir = self.base_dir / "test_docker"
81
+ if not test_docker_dir.exists():
82
+ raise TruForUnavailableError("test_docker directory not found in workspace.")
83
+
84
+ image_name = os.environ.get("TRUFOR_DOCKER_IMAGE", "trufor")
85
+ inspect = subprocess.run(
86
+ ["docker", "image", "inspect", image_name],
87
+ stdout=subprocess.PIPE,
88
+ stderr=subprocess.PIPE,
89
+ text=True,
90
+ check=False,
91
+ )
92
+ if inspect.returncode != 0:
93
+ raise TruForUnavailableError(
94
+ f'Docker image "{image_name}" not found. Build it with "bash test_docker/docker_build.sh".'
95
+ )
96
+
97
+ weights_candidate = Path(os.environ.get("TRUFOR_DOCKER_WEIGHTS", self.base_dir / "weights")).expanduser()
98
+ weight_file = weights_candidate / "trufor.pth.tar"
99
+ self.docker_weights_dir: Optional[Path]
100
+ self.docker_weights_dir = weight_file.parent if weight_file.exists() else None
101
+
102
+ self.docker_runtime = os.environ.get("TRUFOR_DOCKER_RUNTIME")
103
+ gpu_pref = os.environ.get("TRUFOR_DOCKER_GPU")
104
+ if gpu_pref is None:
105
+ gpu_pref = "-1" if self.device == "cpu" else "0"
106
+ self.docker_gpu = gpu_pref
107
+
108
+ gpus_arg = os.environ.get("TRUFOR_DOCKER_GPUS_ARG")
109
+ if not gpus_arg and gpu_pref not in {"-1", "cpu", "none"}:
110
+ gpus_arg = "all"
111
+ self.docker_gpus_arg = gpus_arg
112
+
113
+ self.docker_image = image_name
114
+
115
+ def _configure_native_backend(self, _repo_root: Optional[Path], weights_path: Optional[Path]) -> None:
116
+ try:
117
+ from trufor_native import TruForBundledModel
118
+ except ImportError as exc: # pragma: no cover - packaging guard
119
+ raise TruForUnavailableError("Bundled TruFor modules are not available.") from exc
120
+
121
+ default_weights = self.base_dir / "weights" / "trufor.pth.tar"
122
+ weight_candidate = weights_path or os.environ.get("TRUFOR_WEIGHTS") or default_weights
123
+ weight_path = Path(weight_candidate).expanduser()
124
+ if not weight_path.exists():
125
+ raise TruForUnavailableError(
126
+ f"TruFor weights missing at {weight_path}. Place trufor.pth.tar under weights/ or set TRUFOR_WEIGHTS."
127
+ )
128
+
129
+ try:
130
+ self.native_model = TruForBundledModel(weight_path, device=self.device)
131
+ except Exception as exc: # pragma: no cover - propagate detailed failure
132
+ raise TruForUnavailableError(f"Failed to initialise bundled TruFor model: {exc}") from exc
133
+
134
+ # ------------------------------------------------------------------
135
+ # Public API
136
+ # ------------------------------------------------------------------
137
+ def infer(self, image: Image.Image) -> TruForResult:
138
+ if image is None:
139
+ raise TruForUnavailableError("No image supplied to TruFor inference.")
140
+
141
+ if self.backend == "docker":
142
+ return self._infer_docker(image)
143
+ if self.backend == "native":
144
+ return self._infer_native(image)
145
+
146
+ raise TruForUnavailableError("TruFor backend not configured.")
147
+
148
+ # ------------------------------------------------------------------
149
+ # Inference helpers
150
+ # ------------------------------------------------------------------
151
+ def _infer_native(self, image: Image.Image) -> TruForResult:
152
+ outputs = self.native_model.predict(image)
153
+
154
+ overlays: Dict[str, Optional[Image.Image]] = {"map": None, "conf": None}
155
+ try:
156
+ overlays["map"] = self._apply_heatmap(image, outputs.tamper_map)
157
+ except Exception as exc: # pragma: no cover - visualisation fallback
158
+ LOGGER.debug("Failed to build tamper heatmap: %s", exc)
159
+
160
+ if outputs.confidence_map is not None:
161
+ try:
162
+ overlays["conf"] = self._apply_heatmap(image, outputs.confidence_map)
163
+ except Exception as exc: # pragma: no cover
164
+ LOGGER.debug("Failed to build confidence heatmap: %s", exc)
165
+
166
+ raw_scores: Dict[str, float] = {
167
+ "tamper_mean": float(np.mean(outputs.tamper_map)),
168
+ "tamper_max": float(np.max(outputs.tamper_map)),
169
+ }
170
+
171
+ if outputs.confidence_map is not None:
172
+ raw_scores["confidence_mean"] = float(np.mean(outputs.confidence_map))
173
+ raw_scores["confidence_max"] = float(np.max(outputs.confidence_map))
174
+
175
+ if outputs.detection_score is not None:
176
+ raw_scores["tamper_score"] = float(outputs.detection_score)
177
+
178
+ return TruForResult(
179
+ score=outputs.detection_score,
180
+ map_overlay=overlays["map"],
181
+ confidence_overlay=overlays["conf"],
182
+ raw_scores=raw_scores,
183
+ )
184
+
185
+ def _infer_docker(self, image: Image.Image) -> TruForResult:
186
+ with tempfile.TemporaryDirectory(prefix="trufor_docker_") as workdir:
187
+ workdir_path = Path(workdir)
188
+ input_dir = workdir_path / "data"
189
+ output_dir = workdir_path / "data_out"
190
+ input_dir.mkdir(parents=True, exist_ok=True)
191
+ output_dir.mkdir(parents=True, exist_ok=True)
192
+ input_path = input_dir / "input.png"
193
+ image.convert("RGB").save(input_path)
194
+
195
+ cmd = ["docker", "run", "--rm"]
196
+ if self.docker_runtime:
197
+ cmd.extend(["--runtime", self.docker_runtime])
198
+
199
+ gpu_flag = str(self.docker_gpu)
200
+ if gpu_flag.lower() in {"cpu", "none"}:
201
+ gpu_flag = "-1"
202
+ if gpu_flag != "-1" and self.docker_gpus_arg:
203
+ cmd.extend(["--gpus", self.docker_gpus_arg])
204
+
205
+ cmd.extend([
206
+ "-v",
207
+ f"{input_dir.resolve()}:/data:ro",
208
+ "-v",
209
+ f"{output_dir.resolve()}:/data_out:rw",
210
+ ])
211
+
212
+ if self.docker_weights_dir is not None:
213
+ cmd.extend([
214
+ "-v",
215
+ f"{self.docker_weights_dir.resolve()}:/weights:ro",
216
+ ])
217
+
218
+ cmd.append(self.docker_image)
219
+ cmd.extend(
220
+ [
221
+ "-gpu",
222
+ gpu_flag,
223
+ "-in",
224
+ "data/input.png",
225
+ "-out",
226
+ "data_out",
227
+ ]
228
+ )
229
+
230
+ LOGGER.debug("Running TruFor docker command: %s", " ".join(cmd))
231
+ result = subprocess.run(
232
+ cmd,
233
+ text=True,
234
+ capture_output=True,
235
+ check=False,
236
+ )
237
+
238
+ return self._process_results(result, output_dir, image)
239
+
240
+ # ------------------------------------------------------------------
241
+ # Result parsing
242
+ # ------------------------------------------------------------------
243
+ def _process_results(self, run_result: subprocess.CompletedProcess[str], output_dir: Path, image: Image.Image) -> TruForResult:
244
+ if run_result.returncode != 0:
245
+ stderr_tail = "\n".join(run_result.stderr.strip().splitlines()[-8:]) if run_result.stderr else ""
246
+ LOGGER.error("TruFor stderr: %s", stderr_tail)
247
+ raise TruForUnavailableError(
248
+ "TruFor inference failed. Inspect dependencies and stderr:\n" + stderr_tail
249
+ )
250
+
251
+ npz_files = list(output_dir.rglob("*.npz"))
252
+ if not npz_files:
253
+ stdout_tail = "\n".join(run_result.stdout.strip().splitlines()[-8:]) if run_result.stdout else ""
254
+ raise TruForUnavailableError(
255
+ "TruFor inference produced no output files. Stdout tail:\n" + stdout_tail
256
+ )
257
+
258
+ data = np.load(npz_files[0], allow_pickle=False)
259
+ tamper_map = data.get("map")
260
+ conf_map = data.get("conf")
261
+ score = float(data["score"]) if "score" in data.files else None
262
+
263
+ overlays: Dict[str, Optional[Image.Image]] = {"map": None, "conf": None}
264
+ try:
265
+ overlays["map"] = self._apply_heatmap(image, tamper_map) if tamper_map is not None else None
266
+ except Exception as exc: # pragma: no cover
267
+ LOGGER.debug("Failed to build tamper heatmap: %s", exc)
268
+
269
+ try:
270
+ overlays["conf"] = self._apply_heatmap(image, conf_map) if conf_map is not None else None
271
+ except Exception as exc: # pragma: no cover
272
+ LOGGER.debug("Failed to build confidence heatmap: %s", exc)
273
+
274
+ raw_scores: Dict[str, float] = {}
275
+ if score is not None:
276
+ raw_scores["tamper_score"] = score
277
+ if tamper_map is not None:
278
+ raw_scores["tamper_mean"] = float(np.mean(tamper_map))
279
+ raw_scores["tamper_max"] = float(np.max(tamper_map))
280
+ if conf_map is not None:
281
+ raw_scores["confidence_mean"] = float(np.mean(conf_map))
282
+ raw_scores["confidence_max"] = float(np.max(conf_map))
283
+
284
+ return TruForResult(
285
+ score=score,
286
+ map_overlay=overlays["map"],
287
+ confidence_overlay=overlays["conf"],
288
+ raw_scores=raw_scores,
289
+ )
290
+
291
+ @staticmethod
292
+ def _apply_heatmap(base: Image.Image, data: np.ndarray, alpha: float = 0.55) -> Image.Image:
293
+ base_rgb = base.convert("RGB")
294
+ if data is None or data.ndim != 2:
295
+ raise ValueError("Expected a 2D map from TruFor")
296
+
297
+ data = np.asarray(data, dtype=np.float32)
298
+ if np.allclose(data.max(), data.min()):
299
+ norm = np.zeros_like(data, dtype=np.float32)
300
+ else:
301
+ norm = (data - data.min()) / (data.max() - data.min())
302
+
303
+ heat = np.zeros((*norm.shape, 3), dtype=np.uint8)
304
+ heat[..., 0] = np.clip(norm * 255, 0, 255).astype(np.uint8)
305
+ heat[..., 1] = np.clip(np.sqrt(norm) * 255, 0, 255).astype(np.uint8)
306
+ heat[..., 2] = np.clip((1.0 - norm) * 255, 0, 255).astype(np.uint8)
307
+
308
+ heat_img = Image.fromarray(heat, mode="RGB").resize(base_rgb.size, Image.BILINEAR)
309
+ return Image.blend(base_rgb, heat_img, alpha)
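
To close the loop on `trufor_runner.py`, a brief editorial usage sketch (not part of the committed files): constructing `TruForEngine` raises `TruForUnavailableError` when neither the bundled weights nor a prebuilt Docker image is available, so callers should guard both construction and inference. `example.jpg` is a placeholder path.

```python
# Editorial sketch: run TruFor on a single image and inspect the outputs.
# "example.jpg" is a placeholder; the engine needs weights/trufor.pth.tar
# or a prebuilt Docker image before it will initialise.
from PIL import Image

from trufor_runner import TruForEngine, TruForUnavailableError

try:
    engine = TruForEngine(device="cpu")
    print(engine.status_message)
    result = engine.infer(Image.open("example.jpg"))
    print("detection score:", result.score)
    print("raw scores:", result.raw_scores)
    if result.map_overlay is not None:
        result.map_overlay.save("tamper_overlay.png")
except TruForUnavailableError as exc:
    print(f"TruFor unavailable: {exc}")
```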