Update modeling_aimv2.py
Browse files- modeling_aimv2.py +2 -0
modeling_aimv2.py
CHANGED
|
@@ -357,11 +357,13 @@ class AIMv2ForImageClassification(AIMv2PretrainedModel):
|
|
| 357 |
# Classifier head
|
| 358 |
logits = self.classifier(sequence_output[:, 0, :])
|
| 359 |
print(f"Logits shape: {logits.shape}")
|
|
|
|
| 360 |
|
| 361 |
loss = None
|
| 362 |
if labels is not None:
|
| 363 |
labels = labels.to(logits.device)
|
| 364 |
print(f"Labels shape: {labels.shape}")
|
|
|
|
| 365 |
|
| 366 |
# Always use cross-entropy loss
|
| 367 |
loss_fct = CrossEntropyLoss()
|
|
|
|
| 357 |
# Classifier head
|
| 358 |
logits = self.classifier(sequence_output[:, 0, :])
|
| 359 |
print(f"Logits shape: {logits.shape}")
|
| 360 |
+
print(f"Logits shape: {logits}")
|
| 361 |
|
| 362 |
loss = None
|
| 363 |
if labels is not None:
|
| 364 |
labels = labels.to(logits.device)
|
| 365 |
print(f"Labels shape: {labels.shape}")
|
| 366 |
+
print(f"Labels shape: {labels}")
|
| 367 |
|
| 368 |
# Always use cross-entropy loss
|
| 369 |
loss_fct = CrossEntropyLoss()
|