"""
Title: Imbalanced classification: credit card fraud detection
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2019/05/28
Last modified: 2020/04/17
Description: Demonstration of how to handle highly imbalanced classification problems.
Accelerator: GPU
"""
"""
## Introduction
This example looks at the
[Kaggle Credit Card Fraud Detection](https://www.kaggle.com/mlg-ulb/creditcardfraud/)
dataset to demonstrate how
to train a classification model on data with highly imbalanced classes.
"""
"""
## First, vectorize the CSV data
"""
import csv
import numpy as np

# Get the real data from https://www.kaggle.com/mlg-ulb/creditcardfraud/
fname = "/Users/fchollet/Downloads/creditcard.csv"

all_features = []
all_targets = []
# Use the `csv` module (already imported, previously unused) instead of a
# manual `split(",")`: it strips the surrounding quotes for us and would
# also handle quoted fields that contain commas.
with open(fname, newline="") as f:
    reader = csv.reader(f)
    header = next(reader)  # Skip (and show) the header row
    print("HEADER:", ",".join(header))
    for i, fields in enumerate(reader):
        # All columns except the last are float features; the last is the
        # integer class label (1 = fraud, 0 = legitimate).
        all_features.append([float(v) for v in fields[:-1]])
        all_targets.append([int(fields[-1])])
        if i == 0:
            print("EXAMPLE FEATURES:", all_features[-1])

features = np.array(all_features, dtype="float32")
targets = np.array(all_targets, dtype="uint8")
print("features.shape:", features.shape)
print("targets.shape:", targets.shape)
"""
## Prepare a validation set
"""
# Hold out the last 20% of samples for validation.
# Using an explicit split index avoids the `[:-0]` pitfall of the
# `features[:-num_val_samples]` form, which would silently produce an
# EMPTY training set if the validation count ever rounded down to zero.
num_val_samples = int(len(features) * 0.2)
split_index = len(features) - num_val_samples
train_features = features[:split_index]
train_targets = targets[:split_index]
val_features = features[split_index:]
val_targets = targets[split_index:]
print("Number of training samples:", len(train_features))
print("Number of validation samples:", len(val_features))
"""
## Analyze class imbalance in the targets
"""
# Count label occurrences in the training targets; index 0 is the
# legitimate class, index 1 is the fraud class.
counts = np.bincount(train_targets[:, 0])
num_positive = counts[1]
positive_pct = 100 * float(num_positive) / len(train_targets)
print(
    f"Number of positive samples in training data: {num_positive}"
    f" ({positive_pct:.2f}% of total)"
)

# Inverse-frequency class weights: each class contributes equally to the
# loss regardless of how many samples it has.
weight_for_0 = 1.0 / counts[0]
weight_for_1 = 1.0 / counts[1]
"""
## Normalize the data using training set statistics
"""
# Standardize features using statistics computed on the TRAINING split
# only, then apply them to both splits (avoids validation leakage).
# Note the order: the std is taken on the already-centered training data.
mean = train_features.mean(axis=0)
train_features -= mean
val_features -= mean

std = train_features.std(axis=0)
train_features /= std
val_features /= std
"""
## Build a binary classification model
"""
import keras

# A small fully-connected binary classifier: three ReLU hidden layers of
# width 256 with dropout for regularization, ending in a single sigmoid
# unit that outputs the fraud probability.
model = keras.Sequential()
model.add(keras.Input(shape=train_features.shape[1:]))
model.add(keras.layers.Dense(256, activation="relu"))
model.add(keras.layers.Dense(256, activation="relu"))
model.add(keras.layers.Dropout(0.3))
model.add(keras.layers.Dense(256, activation="relu"))
model.add(keras.layers.Dropout(0.3))
model.add(keras.layers.Dense(1, activation="sigmoid"))
model.summary()
"""
## Train the model with `class_weight` argument
"""
# Track the full confusion-matrix breakdown plus precision/recall, which
# are far more informative than accuracy on such an imbalanced dataset.
metrics = [
    keras.metrics.FalseNegatives(name="fn"),
    keras.metrics.FalsePositives(name="fp"),
    keras.metrics.TrueNegatives(name="tn"),
    keras.metrics.TruePositives(name="tp"),
    keras.metrics.Precision(name="precision"),
    keras.metrics.Recall(name="recall"),
]
model.compile(
    optimizer=keras.optimizers.Adam(1e-2),
    loss="binary_crossentropy",
    metrics=metrics,
)

# Save a checkpoint after every epoch, and pass the inverse-frequency
# class weights so the rare fraud class is not drowned out in the loss.
callbacks = [keras.callbacks.ModelCheckpoint("fraud_model_at_epoch_{epoch}.keras")]
class_weight = {0: weight_for_0, 1: weight_for_1}

model.fit(
    train_features,
    train_targets,
    batch_size=2048,
    epochs=30,
    verbose=2,
    callbacks=callbacks,
    validation_data=(val_features, val_targets),
    class_weight=class_weight,
)
"""
## Conclusions
At the end of training, out of 56,961 validation transactions, we are:
- Correctly identifying 66 of them as fraudulent
- Missing 9 fraudulent transactions
- At the cost of incorrectly flagging 441 legitimate transactions
In the real world, one would put an even higher weight on class 1,
so as to reflect that False Negatives are more costly than False Positives.
Next time your credit card gets declined in an online purchase -- this is why.
"""