#include "checker.h"

#include <dlfcn.h>
#include <getopt.h>
#include <unistd.h>

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
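
// Element-wise comparison in the spirit of torch.allclose: an element matches when
// |received - expected| <= atol + rtol * |expected|, and NaN/+inf/-inf must appear
// at the same positions in both tensors. Returns {ok, human-readable message}; on
// failure the message lists up to `max_print` mismatched elements.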
std::pair<bool, std::string> verbose_allclose(const torch::Tensor &received, const torch::Tensor &expected,
                                              float rtol = 1e-05, float atol = 1e-08, int max_print = 5) {
    if (received.sizes() != expected.sizes()) {
        std::string expected_shape_str = "[";
        std::string received_shape_str = "[";
        auto expected_sizes = expected.sizes();
        auto received_sizes = received.sizes();

        for (size_t i = 0; i < expected_sizes.size(); i++) {
            expected_shape_str += std::to_string(expected_sizes[i]);
            if (i < expected_sizes.size() - 1)
                expected_shape_str += ", ";
        }
        expected_shape_str += "]";

        for (size_t i = 0; i < received_sizes.size(); i++) {
            received_shape_str += std::to_string(received_sizes[i]);
            if (i < received_sizes.size() - 1)
                received_shape_str += ", ";
        }
        received_shape_str += "]";

        return {false, "SIZE MISMATCH: expected " + expected_shape_str + " but got " + received_shape_str};
    }

    // Compare in float32 so low-precision inputs (e.g. fp16/bf16) don't limit the math.
    auto diff = torch::abs(received.to(torch::kFloat32) - expected.to(torch::kFloat32));
    auto tolerance = atol + rtol * torch::abs(expected.to(torch::kFloat32));

    auto tol_mismatched = diff > tolerance;
    auto nan_mismatched = torch::logical_xor(torch::isnan(received), torch::isnan(expected));
    auto posinf_mismatched = torch::logical_xor(torch::isposinf(received), torch::isposinf(expected));
    auto neginf_mismatched = torch::logical_xor(torch::isneginf(received), torch::isneginf(expected));

    auto mismatched = torch::logical_or(torch::logical_or(tol_mismatched, nan_mismatched),
                                        torch::logical_or(posinf_mismatched, neginf_mismatched));

    auto mismatched_indices = torch::nonzero(mismatched);
    int64_t num_mismatched = mismatched.sum().item<int64_t>();

    if (num_mismatched >= 1) {
        std::stringstream mismatch_details;
        auto sizes = received.sizes();
        mismatch_details << "Mismatch found in tensors with shape [";
        for (size_t i = 0; i < sizes.size(); i++) {
            mismatch_details << sizes[i];
            if (i < sizes.size() - 1)
                mismatch_details << ", ";
        }
        mismatch_details << "]:\n";
        mismatch_details << "Number of mismatched elements: " << num_mismatched << "\n";

        for (int64_t i = 0; i < std::min<int64_t>(max_print, mismatched_indices.size(0)); i++) {
            auto index = mismatched_indices[i];
            std::vector<int64_t> idx_vec;
            for (int64_t j = 0; j < index.size(0); j++) {
                idx_vec.push_back(index[j].item<int64_t>());
            }

            std::string idx_str = "(";
            for (size_t j = 0; j < idx_vec.size(); j++) {
                idx_str += std::to_string(idx_vec[j]);
                if (j < idx_vec.size() - 1)
                    idx_str += ", ";
            }
            idx_str += ")";

            // Walk down one dimension per index component to reach the scalar element.
            torch::Tensor received_elem = received;
            torch::Tensor expected_elem = expected;
            for (int64_t d : idx_vec) {
                received_elem = received_elem[d];
                expected_elem = expected_elem[d];
            }
            float received_val = received_elem.item<float>();
            float expected_val = expected_elem.item<float>();

            mismatch_details << "ERROR at " << idx_str << ": " << received_val << " != " << expected_val << "\n";
        }

        if (num_mismatched > max_print) {
            mismatch_details << "... and " << (num_mismatched - max_print) << " more mismatched elements.";
        }

        return {false, mismatch_details.str()};
    }

    return {true, "Maximum error: " + std::to_string(diff.max().item<float>())};
}
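
// Compares a kernel output against the reference according to `mode`:
//  - kElementWise: plain element-wise allclose.
//  - kRowIndex:    each row is sorted before comparison, so rows are treated as
//                  unordered collections of values.
//  - kJustDump:    no comparison; both tensors are written to `fout`, which must
//                  be non-null in this mode (it may be null otherwise).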
std::pair<bool, std::string> check_implementation(std::ofstream *fout, const torch::Tensor &output,
                                                  const torch::Tensor &expected, float rtol = 2e-02,
                                                  float atol = 1e-03, CheckerMode mode = CheckerMode::kElementWise) {
    if (mode == CheckerMode::kRowIndex) {
        // Sort each row so the comparison is insensitive to per-row ordering.
        auto sorted_output = std::get<0>(torch::sort(output, 1));
        auto sorted_expected = std::get<0>(torch::sort(expected, 1));
        return verbose_allclose(sorted_output, sorted_expected, rtol, atol);
    } else if (mode == CheckerMode::kJustDump) {
        auto dump = [&](const char *label, const torch::Tensor &t) {
            *fout << "=====" << label << "=====" << std::endl;
            *fout << t.sizes() << std::endl;

            auto sizes = t.sizes();
            if (sizes.size() == 2) {
                // Pretty-print 2-D tensors element by element.
                for (int64_t i = 0; i < sizes[0]; i++) {
                    for (int64_t j = 0; j < sizes[1]; j++) {
                        *fout << std::setw(12) << std::setprecision(6) << t[i][j].item<float>() << " ";
                    }
                    *fout << std::endl;
                }
            } else {
                // Fall back to torch's default printer for other ranks.
                *fout << t << std::endl;
            }
        };
        dump("OUTPUT", output);
        dump("EXPECTED", expected);
        return {true, ""};
    }
    return verbose_allclose(output, expected, rtol, atol);
}
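
// Number of timed repetitions per kernel in benchmark mode; the summary reports
// the best and worst metrics over these runs.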
constexpr int BENCHMARK_ITERS = 5;
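
// Eagerly loads the ROCm BLAS libraries with RTLD_GLOBAL so their symbols are
// resolved before any kernel library is opened. Not called from main() below;
// presumably invoked elsewhere in the harness or kept for manual use. Note that
// dlerror() only reports the most recent dlopen failure.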
void preload() {
    void *handle_rocblas = dlopen("/usr/local/lib/python3.10/dist-packages/torch/lib/librocblas.so", RTLD_NOW | RTLD_GLOBAL);
    void *handle_hipblas = dlopen("/usr/local/lib/python3.10/dist-packages/torch/lib/libhipblas.so", RTLD_NOW | RTLD_GLOBAL);
    void *handle_hipblaslt = dlopen("/usr/local/lib/python3.10/dist-packages/torch/lib/libhipblaslt.so", RTLD_NOW | RTLD_GLOBAL);

    if (!handle_rocblas || !handle_hipblas || !handle_hipblaslt) {
        fprintf(stderr, "Failed to load required libraries: %s\n", dlerror());
        exit(1);
    }
}
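
// Entry point. Example invocations (flags match the usage message below):
//   ./checker            run every test case, timing each BENCHMARK_ITERS times
//   ./checker -b         single run per case (benchmark mode off)
//   ./checker -p         profile only: skip reference kernels and comparison
//   ./checker -c 2       run only test case 2
//   ./checker -c 2 -t 0  print just the best time (in us) of sub-case 0 of case 2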
int main(int argc, char **argv) {
    bool benchmark = true;
    bool profile_mode = false;
    int target_test_case = -1;
    int target_sub_case = -1;
    int opt;

    while ((opt = getopt(argc, argv, "bpt:c:")) != -1) {
        switch (opt) {
        case 'b':
            benchmark = false;
            break;
        case 'p':
            profile_mode = true;
            break;
        case 't':
            target_sub_case = std::stoi(optarg);
            break;
        case 'c':
            target_test_case = std::stoi(optarg);
            break;
        default:
            fprintf(stderr, "Usage: %s [-b] [-p] [-t subcase_index] [-c test_case_index]\n", argv[0]);
            fprintf(stderr, "  -b: Disable benchmark mode\n");
            fprintf(stderr, "  -p: Enable profile mode (skips reference kernel and comparison)\n");
            fprintf(stderr, "  -t: Run only the specified subcase index\n");
            fprintf(stderr, "  -c: Run only the specified test case index\n");
            exit(EXIT_FAILURE);
        }
    }

    case_initialize();
    int num_params = get_params_count();
    int passed_cases = 0;

    if (target_test_case >= 0 && target_test_case >= num_params) {
        std::cerr << "Error: Test case index " << target_test_case << " is out of range (0-" << (num_params - 1)
                  << ")" << std::endl;
        exit(EXIT_FAILURE);
    }

    std::vector<std::vector<PerfMetrics>> run_times(num_params);
    std::vector<std::tuple<bool, std::string, std::vector<std::pair<float, float>>>> results;
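
    // Fast path for "-c <case> -t <subcase>": time a single sub-case and print
    // only its best time in microseconds, with no other output.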
    if (target_test_case >= 0 && target_sub_case >= 0) {
        void *input = case_get_input(target_test_case);
        std::vector<Checkee> output;
        float best_time = std::numeric_limits<float>::max();

        for (int j = 0; j < BENCHMARK_ITERS; j++) {
            PerfMetrics metrics;
            output = case_run_kernel(input, &metrics);

            if (metrics.count <= target_sub_case) {
                std::cerr << "Error: Subcase index " << target_sub_case << " is out of range (0-"
                          << (metrics.count - 1) << ")" << std::endl;
                exit(EXIT_FAILURE);
            }

            best_time = std::min(best_time, metrics.entries[target_sub_case].time);
        }

        std::cout << std::fixed << std::setprecision(6) << best_time * 1e3 << std::endl;
        case_destroy(input);
        return 0;
    }

    if (!profile_mode && target_test_case < 0) {
        std::cout << "Found " << num_params << " test cases for " << case_get_name() << '\n';
    }
    if (benchmark) {
        std::cout << "Benchmark mode enabled\n";
    }
    if (profile_mode) {
        std::cout << "Profile mode enabled (skipping reference kernels and comparison)\n";
    }

    // Either the single case selected with -c, or every case.
    std::vector<int> test_cases_to_run;
    if (target_test_case >= 0) {
        test_cases_to_run.push_back(target_test_case);
    } else {
        for (int i = 0; i < num_params; i++) {
            test_cases_to_run.push_back(i);
        }
    }
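
    // Run each selected case: compute the reference output once (unless in profile
    // mode), time the kernel under test, and compare sub-case by sub-case.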
    for (int i : test_cases_to_run) {
        std::ofstream *fout = nullptr;
        void *input = case_get_input(i);
        if (!profile_mode && target_test_case < 0) {
            std::cerr << "Running test case " << i << std::flush;
        }
        std::vector<Checkee> reference;
        if (!profile_mode) {
            reference = case_run_ref_kernel(input);
        }
        std::vector<Checkee> output;
        for (int j = 0; j < (benchmark ? BENCHMARK_ITERS : 1); j++) {
            PerfMetrics metrics;
            output = case_run_kernel(input, &metrics);
            run_times[i].push_back(metrics);
        }

        bool match = true;
        std::string case_message;

        if (!profile_mode) {
            if (reference.size() != output.size()) {
                std::cerr << "Wrong test definition: reference and output have different sizes" << '\n';
                abort();
            }

            for (size_t j = 0; j < reference.size(); j++) {
                float rtol, atol;
                get_error_tolerance(&rtol, &atol);
                if (output[j].mode == CheckerMode::kJustDump) {
                    if (!fout) {
                        fout = new std::ofstream(std::string("case_") + std::to_string(i) + ".txt");
                    }
                    *fout << "===== SUBCASE " << output[j].name << " =====" << std::endl;
                }
                auto [match_sub, message_sub] =
                    check_implementation(fout, *output[j].tensor, *reference[j].tensor, rtol, atol, output[j].mode);
                if (!match_sub) {
                    case_message += "Error on sub-case " + std::to_string(j) + ": " + message_sub + "\n";
                    match = false;
                }
            }
            if (match) {
                passed_cases++;
            }
        } else {
            match = true;
            passed_cases++;
        }
        delete fout;  // closes the dump file if one was opened

        std::vector<std::pair<float, float>> case_metrics;
        for (const auto &run : run_times[i]) {
            // entries[0] holds the aggregate metrics for the whole case;
            // per-sub-case entries (if any) start at index 1.
            case_metrics.push_back({run.entries[0].time, run.entries[0].gflops});
        }

        results.push_back(std::make_tuple(match, case_message, case_metrics));
        case_destroy(input);
        if (!profile_mode && target_test_case < 0) {
            std::cout << "\033[2K\r" << std::flush;  // clear the "Running test case" line
        }
    }
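
    // Summary: per-case best/worst metrics, then a geometric mean across cases.
    // Skipped when a single case was selected with -c.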
    if (target_test_case < 0) {
        std::cout << "=======================" << '\n';
        if (!profile_mode) {
            if (passed_cases == num_params) {
                std::cout << "✅ All " << num_params << " test cases passed!" << '\n';
            } else {
                std::cout << "❌ [" << num_params - passed_cases << "/" << num_params << "] test cases failed!" << '\n';
            }
        } else {
            std::cout << "Profile mode: results comparison skipped" << '\n';
        }
        std::cout << "-----------------------" << '\n';

        for (int i = 0; i < num_params; i++) {
            auto [match, message, metrics] = results[i];

            float best_time = std::numeric_limits<float>::max();
            float best_gflops = 0.0f;
            float worst_time = 0.0f;
            float worst_gflops = std::numeric_limits<float>::max();

            for (const auto &[time, gflops] : metrics) {
                best_time = std::min(best_time, time);
                best_gflops = std::max(best_gflops, gflops);
                worst_time = std::max(worst_time, time);
                worst_gflops = std::min(worst_gflops, gflops);
            }

            std::string timing_info;
            if (benchmark) {
                std::stringstream ss;
                ss << std::fixed << std::setprecision(2);
                ss << "Best: [\033[1m" << best_time * 1e3 << "\033[0m us, \033[1m" << best_gflops / 1e3
                   << "\033[0m TFLOPS], "
                   << "\033[2mSlowest: [" << worst_time * 1e3 << " us, " << worst_gflops / 1e3 << " TFLOPS]\033[0m";
                timing_info = ss.str();
            } else {
                std::stringstream ss;
                ss << std::fixed << std::setprecision(2);
                ss << "Time: " << best_time * 1e3 << " us, TFLOPS: " << best_gflops / 1e3;
                timing_info = ss.str();
            }

            if (!profile_mode && !match) {
                std::cout << "❌ Test case " << i << ": " << timing_info << "\n" << message << '\n';
            } else {
                std::cout << "✅ Test case " << i << ": " << timing_info << "\n";
            }

            if (run_times[i][0].count > 1) {
                for (int j = 1; j < run_times[i][0].count; j++) {
                    std::stringstream ss;
                    ss << std::fixed << std::setprecision(2);
                    ss << " - Sub-case " << run_times[i][0].entries[j].name << ": ";

                    if (benchmark) {
                        float sub_best_time = std::numeric_limits<float>::max();
                        float sub_best_gflops = 0.0f;
                        float sub_worst_time = 0.0f;
                        float sub_worst_gflops = std::numeric_limits<float>::max();

                        for (const auto &run : run_times[i]) {
                            sub_best_time = std::min(sub_best_time, run.entries[j].time);
                            sub_best_gflops = std::max(sub_best_gflops, run.entries[j].gflops);
                            sub_worst_time = std::max(sub_worst_time, run.entries[j].time);
                            sub_worst_gflops = std::min(sub_worst_gflops, run.entries[j].gflops);
                        }

                        ss << "Best: [\033[1m" << sub_best_time * 1e3 << "\033[0m us, \033[1m" << sub_best_gflops / 1e3
                           << "\033[0m TFLOPS], "
                           << "\033[2mSlowest: [" << sub_worst_time * 1e3 << " us, " << sub_worst_gflops / 1e3
                           << " TFLOPS]\033[0m";
                    } else {
                        ss << "Time: " << run_times[i][0].entries[j].time * 1e3
                           << " us, TFLOPS: " << run_times[i][0].entries[j].gflops / 1e3;
                    }

                    std::cout << ss.str() << std::endl;
                }
            }
        }
        std::cout << "-----------------------" << '\n';
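
        // Geometric mean of each case's best time / best TFLOPS across all cases.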
        double geo_mean_time = 1.0;
        double geo_mean_gflops = 1.0;

        for (int i = 0; i < num_params; i++) {
            auto [match, message, metrics] = results[i];

            float best_time = std::numeric_limits<float>::max();
            float best_gflops = 0.0f;

            for (const auto &[time, gflops] : metrics) {
                best_time = std::min(best_time, time);
                best_gflops = std::max(best_gflops, gflops);
            }

            geo_mean_time *= best_time;
            geo_mean_gflops *= best_gflops;
        }

        geo_mean_time = std::pow(geo_mean_time, 1.0 / num_params);
        geo_mean_gflops = std::pow(geo_mean_gflops, 1.0 / num_params);

        if (benchmark) {
            std::stringstream ss;
            ss << std::fixed << std::setprecision(2);
            ss << "GeoMean - Best Time: \033[1m" << geo_mean_time * 1e3 << "\033[0m us, Best TFLOPS: \033[1m"
               << geo_mean_gflops / 1e3 << "\033[0m";
            std::cout << ss.str() << std::endl;
        } else {
            std::stringstream ss;
            ss << std::fixed << std::setprecision(2);
            ss << "GeoMean - Time: " << geo_mean_time * 1e3 << " us, TFLOPS: " << geo_mean_gflops / 1e3;
            std::cout << ss.str() << std::endl;
        }
        std::cout << "=======================" << '\n';
    }

    return 0;
}