amaye15 committed on
Commit
3ab8cc6
·
verified ·
1 Parent(s): 2d0edb4

Update modeling_aimv2.py

Browse files
Files changed (1) hide show
  1. modeling_aimv2.py +12 -18
modeling_aimv2.py CHANGED
@@ -309,13 +309,6 @@ class AIMv2ForImageClassification(AIMv2PretrainedModel):
309
  '''
310
 
311
 
312
- import logging
313
-
314
-
315
- # Setup logging
316
- logging.basicConfig(level=logging.INFO)
317
- logger = logging.getLogger(__name__)
318
-
319
  class AIMv2ForImageClassification(AIMv2PretrainedModel):
320
  def __init__(self, config: AIMv2Config):
321
  super().__init__(config)
@@ -341,15 +334,15 @@ class AIMv2ForImageClassification(AIMv2PretrainedModel):
341
  output_hidden_states: Optional[bool] = None,
342
  return_dict: Optional[bool] = None,
343
  ) -> Union[tuple, ImageClassifierOutput]:
344
- logger.info("Forward pass initiated")
345
- logger.info(f"Input pixel_values shape: {pixel_values.shape if pixel_values is not None else 'None'}")
346
- logger.info(f"Head mask provided: {head_mask is not None}")
347
- logger.info(f"Labels provided: {labels is not None}")
348
 
349
  return_dict = (
350
  return_dict if return_dict is not None else self.config.use_return_dict
351
  )
352
- logger.info(f"Using return_dict: {return_dict}")
353
 
354
  # Call base model
355
  outputs = self.aimv2(
@@ -359,30 +352,31 @@ class AIMv2ForImageClassification(AIMv2PretrainedModel):
359
  return_dict=return_dict,
360
  )
361
  sequence_output = outputs[0]
362
- logger.info(f"Sequence output shape: {sequence_output.shape}")
363
 
364
  # Classifier head
365
  logits = self.classifier(sequence_output[:, 0, :])
366
- logger.info(f"Logits shape: {logits.shape}")
367
 
368
  loss = None
369
  if labels is not None:
370
  labels = labels.to(logits.device)
371
- logger.info(f"Labels shape: {labels.shape}")
372
 
373
  # Always use cross-entropy loss
374
  loss_fct = CrossEntropyLoss()
375
  loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
376
- logger.info(f"Loss computed: {loss.item()}")
377
 
378
  if not return_dict:
379
  output = (logits,) + outputs[1:]
380
- logger.info("Returning as tuple")
381
  return ((loss,) + output) if loss is not None else output
382
 
383
- logger.info("Returning as ImageClassifierOutput")
384
  return ImageClassifierOutput(
385
  loss=loss,
386
  logits=logits,
387
  hidden_states=outputs.hidden_states,
388
  )
 
 
309
  '''
310
 
311
 
 
 
 
 
 
 
 
312
  class AIMv2ForImageClassification(AIMv2PretrainedModel):
313
  def __init__(self, config: AIMv2Config):
314
  super().__init__(config)
 
334
  output_hidden_states: Optional[bool] = None,
335
  return_dict: Optional[bool] = None,
336
  ) -> Union[tuple, ImageClassifierOutput]:
337
+ print("Forward pass initiated")
338
+ print(f"Input pixel_values shape: {pixel_values.shape if pixel_values is not None else 'None'}")
339
+ print(f"Head mask provided: {head_mask is not None}")
340
+ print(f"Labels provided: {labels is not None}")
341
 
342
  return_dict = (
343
  return_dict if return_dict is not None else self.config.use_return_dict
344
  )
345
+ print(f"Using return_dict: {return_dict}")
346
 
347
  # Call base model
348
  outputs = self.aimv2(
 
352
  return_dict=return_dict,
353
  )
354
  sequence_output = outputs[0]
355
+ print(f"Sequence output shape: {sequence_output.shape}")
356
 
357
  # Classifier head
358
  logits = self.classifier(sequence_output[:, 0, :])
359
+ print(f"Logits shape: {logits.shape}")
360
 
361
  loss = None
362
  if labels is not None:
363
  labels = labels.to(logits.device)
364
+ print(f"Labels shape: {labels.shape}")
365
 
366
  # Always use cross-entropy loss
367
  loss_fct = CrossEntropyLoss()
368
  loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
369
+ print(f"Loss computed: {loss.item()}")
370
 
371
  if not return_dict:
372
  output = (logits,) + outputs[1:]
373
+ print("Returning as tuple")
374
  return ((loss,) + output) if loss is not None else output
375
 
376
+ print("Returning as ImageClassifierOutput")
377
  return ImageClassifierOutput(
378
  loss=loss,
379
  logits=logits,
380
  hidden_states=outputs.hidden_states,
381
  )
382
+