Spaces:
Running
Running
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:
| import torch | |
def channel_saturation_penalty_loss(x: torch.Tensor) -> torch.Tensor:
    """Return the mean squared channel value of a batch of 3-channel images.

    For every pixel the squared values of the three channels (R, G, B) are
    summed. That per-pixel sum is averaged over all remaining positions and
    divided by 3. The result equals the mean of ``x ** 2`` over all elements,
    so the penalty is non-negative and is zero only when every channel
    value is zero.

    Args:
        x: Tensor whose dimension 1 holds exactly three channels, taken to
            be R, G, B in that order, e.g. ``(N, 3, H, W)``. Any further
            trailing dimensions are averaged over as well.

    Returns:
        A 0-dim tensor holding the penalty.

    Raises:
        ValueError: If ``x`` has fewer than 2 dimensions or its dimension 1
            does not have size 3.
    """
    # Explicit checks rather than ``assert``: asserts are stripped under
    # ``python -O``. A 4-channel (e.g. RGBA) input would then be silently
    # accepted with its extra channel ignored.
    if x.dim() < 2:
        raise ValueError(
            f"expected a tensor shaped (N, 3, ...), got shape {tuple(x.shape)}"
        )
    if x.shape[1] != 3:
        raise ValueError(f"expected 3 channels in dim 1, got {x.shape[1]}")

    # Per-position r^2 + g^2 + b^2, obtained by reducing over the channel dim.
    channel_accumulate = x.pow(2).sum(dim=1)
    # Average over every remaining position, then normalise by channel count.
    return channel_accumulate.mean() / 3