File size: 3,278 Bytes
29547e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#include "../../include/timer.h"

// HIPRT_CB annotates functions handed to the HIP runtime as stream callbacks
// (presumably a calling-convention hook, analogous to CUDA's CUDART_CB). When
// the headers pulled in above do not provide it, expand it to nothing.
#if !defined(HIPRT_CB)
#define HIPRT_CB
#endif

// KernelTimer itself is expected to be defined by the included timer.h (the
// code below uses its members); this declaration only names the type early.
class KernelTimer;

// Completion hook that KernelTimer::stop_timer() registers with
// hipStreamAddCallback. The runtime invokes it on a host thread after the work
// enqueued on the stream before it (including the stop event record) is done.
//
// It measures the time between the timer's start and stop events, converts the
// timer's operation count into GFLOP/s, writes both through the timer's
// optional output pointers, runs the user completion callback (if any) and
// marks the timer as completed so synchronize() does not repeat the work.
//
//   stream   - stream the callback was enqueued on (unused).
//   status   - stream status reported by the runtime.
//   userData - the KernelTimer* passed by stop_timer().
//
// On a bad status or a failed elapsed-time query the timer is deliberately
// left un-completed: KernelTimer::synchronize() then performs the measurement
// on the calling thread, instead of this callback reporting an indeterminate
// value or aborting from a runtime-owned thread.
//
// NOTE(review): the CUDA runtime forbids runtime API calls inside stream
// callbacks (they fail with cudaErrorNotPermitted), so on a CUDA-backed HIP
// build the EventElapsedTime query below may fail and fall back to
// synchronize(); confirm the behaviour on the ROCm backend.
static void HIPRT_CB eventCallback(hipStream_t stream, hipError_t status, void* userData) {
  (void)stream;
  if (status != hipSuccess || userData == nullptr) return;

  KernelTimer* timer = static_cast<KernelTimer*>(userData);

  // Use the getter methods to access the private members.
  HOST_TYPE(Event_t) start_event = timer->get_start_event();
  HOST_TYPE(Event_t) stop_event = timer->get_stop_event();

  // Checked locally rather than through LIB_CALL: a failure here is
  // recoverable (synchronize() will retry), and elapsed_time must never be
  // consumed uninitialised.
  float elapsed_time = 0.0f;
  if (HOST_TYPE(EventElapsedTime)(&elapsed_time, start_event, stop_event) != hipSuccess) {
    return;
  }

  const size_t calc_ops = timer->get_calc_ops();
  const double flops = static_cast<double>(calc_ops);
  // A zero duration would yield inf (or NaN when calc_ops == 0); report 0.
  const double gflops_val =
      elapsed_time > 0.0f ? (flops / (elapsed_time * 1e-3)) / 1e9 : 0.0;

  // Store results in the caller-provided pointers, when present.
  float* time_ptr = timer->get_time_ptr();
  float* gflops_ptr = timer->get_gflops_ptr();

  if (time_ptr != nullptr) {
    *time_ptr = elapsed_time;
  }
  if (gflops_ptr != nullptr) {
    *gflops_ptr = static_cast<float>(gflops_val);
  }

  // Run the user callback (execute_callback skips it if results were already
  // delivered), then record completion so synchronize() becomes a no-op.
  timer->execute_callback(elapsed_time);
  timer->set_callback_executed(true);
}

KernelTimer::KernelTimer(size_t calc_ops, float *time, float *gflops)
    : calc_ops(calc_ops),
      time_ptr(time),
      gflops_ptr(gflops),
      user_data(nullptr),
      callback(nullptr),
      callback_executed(false) {
  // The start/stop event pair is created once here; start_timer() and
  // stop_timer() only record into these events.
  LIB_CALL(HOST_TYPE(EventCreate)(&this->start));
  LIB_CALL(HOST_TYPE(EventCreate)(&this->stop));
}

void KernelTimer::start_timer(hipStream_t stream) {
  // Enqueue the start marker on the caller's stream; work submitted to
  // `stream` after this point falls inside the measured interval.
  const HOST_TYPE(Event_t) begin_marker = start;
  LIB_CALL(HOST_TYPE(EventRecord)(begin_marker, stream));

  // A new measurement begins, so its results have not been delivered yet.
  this->callback_executed = false;
}

void KernelTimer::stop_timer(hipStream_t stream) {
  // Enqueue the stop marker behind the timed work on the same stream.
  LIB_CALL(HOST_TYPE(EventRecord)(stop, stream));

  // Rather than blocking the host here, ask the runtime to run eventCallback
  // (with this timer as its user data) once everything queued so far is done.
  // NOTE(review): the runtime keeps this raw pointer, so the timer must stay
  // alive until that callback has run.
  void* const self = this;
  constexpr unsigned int kNoFlags = 0;
  LIB_CALL(hipStreamAddCallback(stream, eventCallback, self, kNoFlags));
}

void KernelTimer::set_callback(TimerCompletionCallback cb, void* data) {
  // Register the completion callback together with the opaque context that
  // is handed back to it as its last argument.
  this->callback = cb;
  this->user_data = data;
}

void KernelTimer::execute_callback(float elapsed_time) {
  // Skip when no callback is registered or the results were already delivered.
  if (!callback || callback_executed) {
    return;
  }
  callback(elapsed_time, calc_ops, time_ptr, gflops_ptr, user_data);
}

void KernelTimer::synchronize() {
  // If callback hasn't been executed yet, synchronize and wait for event completion, then manually execute callback
  if (!callback_executed) {
    LIB_CALL(HOST_TYPE(EventSynchronize)(stop));
    float elapsed_time;
    LIB_CALL(HOST_TYPE(EventElapsedTime)(&elapsed_time, start, stop));
    
    double flops = static_cast<double>(calc_ops);
    double gflops_val = (flops / (elapsed_time * 1e-3)) / 1e9;

    // Store results in the provided pointers
    if (time_ptr != nullptr) {
      *time_ptr = elapsed_time;
    }
    if (gflops_ptr != nullptr) {
      *gflops_ptr = static_cast<float>(gflops_val);
    }
    
    // Execute callback
    if (callback) {
      callback(elapsed_time, calc_ops, time_ptr, gflops_ptr, user_data);
    }
    callback_executed = true;
  }
}

KernelTimer::~KernelTimer() {
  // Deliver any outstanding results (blocking on the stop event if they have
  // not been delivered yet) before the events they depend on are released.
  this->synchronize();

  // NOTE(review): synchronize() waits on the stop event, not on the stream, so
  // the callback enqueued by stop_timer() may still be pending when the timer
  // is destroyed and would then use a dangling pointer. Confirm that callers
  // synchronize the stream before destroying the timer.
  LIB_CALL(HOST_TYPE(EventDestroy)(this->start));
  LIB_CALL(HOST_TYPE(EventDestroy)(this->stop));
}