from typing_extensions import override
from typing import Callable

import torch

import comfy.model_management
from comfy_api.latest import ComfyExtension, io

import nodes

class EmptyChromaRadianceLatentImage(io.ComfyNode):
    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="EmptyChromaRadianceLatentImage",
            category="latent/chroma_radiance",
            inputs=[
                io.Int.Input(id="width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input(id="height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input(id="batch_size", default=1, min=1, max=4096),
            ],
            outputs=[io.Latent.Output()],
        )

    @classmethod
    def execute(cls, *, width: int, height: int, batch_size: int = 1) -> io.NodeOutput:
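        # Chroma Radiance works directly in pixel space, so its "latent" is a
        # 3-channel tensor at full image resolution rather than a compressed
        # VAE latent.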
        latent = torch.zeros((batch_size, 3, height, width), device=comfy.model_management.intermediate_device())
        return io.NodeOutput({"samples": latent})


class ChromaRadianceOptions(io.ComfyNode):
    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="ChromaRadianceOptions",
            category="model_patches/chroma_radiance",
            description="Allows setting advanced options for the Chroma Radiance model.",
            inputs=[
                io.Model.Input(id="model"),
                io.Boolean.Input(
                    id="preserve_wrapper",
                    default=True,
                    tooltip="When enabled, will delegate to an existing model function wrapper if it exists. Generally should be left enabled.",
                ),
                io.Float.Input(
                    id="start_sigma",
                    default=1.0,
                    min=0.0,
                    max=1.0,
                    tooltip="First sigma that these options will be in effect.",
                ),
                io.Float.Input(
                    id="end_sigma",
                    default=0.0,
                    min=0.0,
                    max=1.0,
                    tooltip="Last sigma that these options will be in effect.",
                ),
                io.Int.Input(
                    id="nerf_tile_size",
                    default=-1,
                    min=-1,
                    tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).",
                ),
            ],
            outputs=[io.Model.Output()],
        )

    @classmethod
    def execute(
        cls,
        *,
        model: io.Model.Type,
        preserve_wrapper: bool,
        start_sigma: float,
        end_sigma: float,
        nerf_tile_size: int,
    ) -> io.NodeOutput:
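        # Only collect options the user actually overrode; if nothing was
        # overridden there is nothing to patch and the model passes through.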
        radiance_options = {}
        if nerf_tile_size >= 0:
            radiance_options["nerf_tile_size"] = nerf_tile_size

        if not radiance_options:
            return io.NodeOutput(model)

        old_wrapper = model.model_options.get("model_function_wrapper")

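        # The wrapper injects the collected options into transformer_options on
        # each model call, but only while the current sigma is in range.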
        def model_function_wrapper(apply_model: Callable, args: dict) -> torch.Tensor:
            c = args["c"].copy()
            sigma = args["timestep"].max().detach().cpu().item()
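            # Sampling proceeds from high to low sigma, so start_sigma is the
            # upper bound of the active range and end_sigma the lower bound.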
            if end_sigma <= sigma <= start_sigma:
                transformer_options = c.get("transformer_options", {}).copy()
                transformer_options["chroma_radiance_options"] = radiance_options.copy()
                c["transformer_options"] = transformer_options
            if not (preserve_wrapper and old_wrapper):
                return apply_model(args["input"], args["timestep"], **c)
            return old_wrapper(apply_model, args | {"c": c})

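        # Clone before patching so the caller's model is left unmodified.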
        model = model.clone()
        model.set_model_unet_function_wrapper(model_function_wrapper)
        return io.NodeOutput(model)


class ChromaRadianceExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            EmptyChromaRadianceLatentImage,
            ChromaRadianceOptions,
        ]


async def comfy_entrypoint() -> ChromaRadianceExtension:
    return ChromaRadianceExtension()