Skip to content
Snippets Groups Projects
Commit 85f45f25 authored by Devon Harstrom's avatar Devon Harstrom
Browse files

Part 3 done

made two csv files training and testing

binary encoded famhist
used z score to normalize numeric features

then for the model added a bunch of stuff to stop overfitting
leaky relu, L2, dropout, and early stopping

+ more epochs

NOTE: Saved model achieved 76%+ on testing for the bonus mark
parent 02191f07
No related branches found
No related tags found
1 merge request!4Part 3 done
File added
from threading import activeCount
import numpy as np
import pandas as pd
import tensorflow as tf
# ---------------------------------------------------------------------------
# Data preparation
#
# Loads heart.csv, binary-encodes `famhist`, makes an 80/20 train/test split
# that is stratified on `chd`, z-score normalizes the numeric features, and
# writes the two splits to heart_train.csv / heart_test.csv.
#
# The normalization statistics are computed from the TRAINING rows only and
# then applied to every row, so no information about the test rows leaks into
# the features the model is trained on.
# ---------------------------------------------------------------------------
SPLIT_SEED = 42       # fixed seed so the exported train/test CSVs are reproducible
TRAIN_FRACTION = 0.8  # share of each class that goes to the training split

df = pd.read_csv("heart.csv")
df = df.drop('row.names', axis=1)  # drop unnecessary column

# convert categorical feature to binary [Absent: 0, Present: 1]
df['famhist'] = df['famhist'].map({'Present': 1, 'Absent': 0})

numeric_feature_names = ['sbp', 'tobacco', 'ldl', 'adiposity', 'typea', 'obesity', 'alcohol', 'age']

y = df['chd'].to_numpy()

# Stratified split: shuffle the row positions of each class separately, then
# take the first TRAIN_FRACTION of each class for training and the rest for
# testing, so both splits keep the class ratio of the full dataset.
rng = np.random.default_rng(SPLIT_SEED)
chd_indices = np.where(y == 1)[0]      # row positions of all chd-positive samples
non_chd_indices = np.where(y == 0)[0]  # row positions of all chd-negative samples
rng.shuffle(chd_indices)
rng.shuffle(non_chd_indices)

train_size_chd = int(TRAIN_FRACTION * len(chd_indices))
train_size_non_chd = int(TRAIN_FRACTION * len(non_chd_indices))

train_indices = np.concatenate([chd_indices[:train_size_chd], non_chd_indices[:train_size_non_chd]])
test_indices = np.concatenate([chd_indices[train_size_chd:], non_chd_indices[train_size_non_chd:]])

# Shuffle again so the two classes are interleaved instead of concatenated.
rng.shuffle(train_indices)
rng.shuffle(test_indices)

# z-score = (x - mean) / standard deviation, with mean and standard deviation
# taken from the training rows only.
for col in numeric_feature_names:
    train_values = df[col].iloc[train_indices]
    mean = train_values.mean()
    std = train_values.std()
    # A constant column has zero spread; divide by 1 so it becomes all zeros
    # instead of NaN/inf.
    scale = std if std > 0 else 1.0
    df[col] = (df[col] - mean) / scale

feature_columns = df.drop('chd', axis=1).columns
X = df[feature_columns].to_numpy()

X_train, y_train = X[train_indices], y[train_indices]  # 80% of each class
X_test, y_test = X[test_indices], y[test_indices]      # remaining 20% of each class

# Rebuild labelled frames (features + chd) and export them.
train_data = pd.DataFrame(X_train, columns=feature_columns)
train_data['chd'] = y_train  # add back chd
test_data = pd.DataFrame(X_test, columns=feature_columns)
test_data['chd'] = y_test    # add back chd

train_data.to_csv('heart_train.csv', index=False)  # no need for index
test_data.to_csv('heart_test.csv', index=False)
###################################################################################
def build_model(input_dim):
    """Return a fresh, uncompiled binary classifier for the CHD data.

    The network is three hidden blocks (64, 32 and 32 units), each made of
    Dense with L2(0.001) kernel regularization -> BatchNormalization ->
    LeakyReLU -> Dropout(0.2), followed by a single sigmoid output unit.

    Args:
        input_dim: number of input features per sample.

    Returns:
        A new ``tf.keras.Sequential`` model with freshly initialized weights.
    """
    layers = [tf.keras.Input(shape=(input_dim,))]
    for units in (64, 32, 32):
        layers.extend([
            tf.keras.layers.Dense(units, kernel_regularizer=tf.keras.regularizers.l2(0.001)),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.LeakyReLU(),
            tf.keras.layers.Dropout(0.2),
        ])
    layers.append(tf.keras.layers.Dense(1, activation='sigmoid'))
    return tf.keras.Sequential(layers)


num_runs = 1  # setting it to 1 for right now, Was used to debug changes/improvements in the model
train_accuracies = []
test_accuracies = []

# Train `num_runs` independent models and average their accuracies, which makes
# it easier to compare adjustments to the model.
for i in range(num_runs):
    print(f"Run {i+1}/{num_runs}")

    # Build a new model for every run; reusing one model would continue from
    # the previous run's weights, so the runs would not be independent.
    model = build_model(X_train.shape[1])
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss', patience=10, restore_best_weights=True
    )
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    # Early stopping monitors a hold-out taken from the TRAINING data (Keras
    # uses the last 20% of the arrays passed to fit). The test set is kept out
    # of model selection so the test accuracy below is an unbiased estimate.
    model.fit(
        X_train,
        y_train,
        epochs=2000,
        validation_split=0.2,
        callbacks=[early_stopping],
        verbose=0,
    )

    model_loss1, model_acc1 = model.evaluate(X_train, y_train, verbose=2)
    model_loss2, model_acc2 = model.evaluate(X_test, y_test, verbose=2)
    train_accuracies.append(model_acc1)
    test_accuracies.append(model_acc2)
    print(f"Train / Test Accuracy: {model_acc1*100:.1f}% / {model_acc2*100:.1f}%\n")

avg_train_acc = np.mean(train_accuracies) * 100
avg_test_acc = np.mean(test_accuracies) * 100
print(f"Average Train Accuracy: {avg_train_acc:.1f}%")
print(f"Average Test Accuracy: {avg_test_acc:.1f}%")

# Save the model trained in the final run.
model.save("CHDModel.keras")
This diff is collapsed.
sbp,tobacco,ldl,adiposity,famhist,typea,obesity,alcohol,age,chd
-1.1868883379672235,-0.7915589573639271,-1.3522199466693332,-1.8490282619897394,0.0,0.09127586271194567,-1.394057507405407,0.8960235885582912,-1.8355874642926113,0
-0.6014173044508709,1.1135039925867627,0.8641978886992121,1.1070045050805342,0.0,-0.4180169944489111,0.9981505998104523,-0.6962278124516513,-0.12430848999447097,1
-0.11352477652057717,-0.4214324413735074,-0.5844412193848046,-0.67818223990234,0.0,0.2949930055762884,-1.5673027373922699,-0.10801795676814246,0.8340077356124876,0
-0.5038387988648122,-0.7915589573639271,0.7724507451872242,-1.7397835727719249,0.0,0.39685157700845974,-0.6892104758150199,-0.6962278124516513,-1.7671363053206859,0
-0.40626029327875346,0.18818770261071338,0.5406684878937819,1.5452684935896481,0.0,0.8042858627371452,1.2259799433548202,0.6231595721717746,1.0393612125282645,0
-0.6014173044508709,1.4945165825769007,-0.12087670479791911,-1.0380470985021992,0.0,1.4154372913301734,-1.5839152936923802,-0.6962278124516513,0.423300781780934,1
-0.21110328210663593,0.5365420706016966,0.014329611956588788,0.086530584622361,0.0,-0.6217341373132538,-0.527831357471093,-0.6541544686076225,0.423300781780934,0
1.1549957960981867,0.4276813306045143,-0.24159663047158716,-0.36972664681674594,0.0,-2.455188423092338,-0.7437945893725249,-0.4650286608427166,1.1762635304721158,0
-0.015946270934518416,-0.7915589573639271,-0.7244763331662594,0.29216764667942363,0.0,0.7024272913049738,-0.08166555969670672,2.925347590666397,-0.9457223976575784,0
1.5453098184424217,-0.7044703653661812,-0.3043709918218945,2.140330741917273,1.0,0.2949930055762884,1.674518963457794,-0.6120811247635938,0.970910053556339,0
-0.40626029327875346,-0.7915589573639271,0.0722751762799495,0.07367826824379488,0.0,1.92473014849103,1.0384953793964344,-0.09535510570828913,-1.3564293514891321,0
-1.2844668435532822,1.3094533245816908,-1.183212050726198,-1.0547551097943355,1.0,0.09127586271194567,-0.5966548050001205,-0.6684512359332634,0.6971054176686364,0
0.9598387849260692,-0.22548310937857924,1.3132760122052571,1.113430663269817,1.0,0.8042858627371452,0.7726944785946721,1.4801486535773314,1.313165848415967,1
0.47194625699577536,0.9937571785898621,1.4533111259867122,1.163554697146226,1.0,-0.7235927087454251,0.1153118935760282,-0.4495064563177351,1.4500681663598183,1
-0.991731326795106,-0.7915589573639271,-0.8451962588399274,-1.1871339684935691,0.0,-0.4180169944489111,-0.5254581351425063,-0.5646975530557555,-1.014173556629504,0
0.08163223465154035,0.9806938897902006,0.0915903643877364,2.1955957023451083,0.0,-0.01058270872022569,4.669525542134792,-0.43357577272630676,0.6971054176686364,1
-0.8941528212090473,0.014010518615221776,-0.34783016506441533,1.8318751488316785,0.0,0.8042858627371452,1.0740937143252416,-0.6962278124516513,1.4500681663598183,1
1.3501528072703042,-0.6609260693673084,-1.1204376893758907,1.1082897367183904,1.0,-0.01058270872022569,0.21736045370527612,1.5095591463615068,1.1762635304721158,0
-0.7965743156229885,-0.09485022138196049,3.177191664606692,1.2792255445533234,1.0,0.19313443414411702,0.24346589931973475,-0.6962278124516513,0.5602030997247852,1
0.5695247625818342,2.2129974665583036,0.17367991384583065,0.519653646580049,1.0,-0.11244128015239704,0.4451897972496439,2.4801054082392966,0.8340077356124876,1
-0.991731326795106,-0.7915589573639271,-1.1349240804567307,-1.706367550187652,0.0,-0.4180169944489111,-1.799878525593812,-0.6856073567240324,-1.7671363053206859,1
-0.6989958100369298,-0.3996602933740709,-0.4830364818189232,-1.1267280815143073,1.0,-1.1310269944741107,-0.8980540407306901,-0.267733271748873,-1.561782828404909,0
-1.1868883379672235,0.09674468101308029,-0.07258873452845206,-1.388915335637062,1.0,0.9061444341693166,-0.6963301428007818,-0.42172987979934723,1.0393612125282645,1
2.7162518854751267,-0.23636918337829751,1.038034581669294,1.089011262150541,1.0,1.619154434194516,0.7798141455804329,-0.6962278124516513,-0.12430848999447097,0
1.838045335200598,-0.5302931813706897,1.709237368414888,1.3820440755818544,1.0,-1.1310269944741107,0.4190843516351844,-0.2223920953732692,1.0393612125282645,1
1.2525743016842454,-0.6826982173667449,1.0670073638309745,1.8293046855559656,1.0,-0.6217341373132538,3.7297295000142756,-0.5405972881353895,0.21794730486515712,1
-0.21110328210663593,1.9299595425656297,-0.9707449815405422,1.7920329680581233,0.0,-0.5198755658810824,2.2630781009474097,-0.6962278124516513,0.35484962280900834,0
0.47194625699577536,1.8646430985673204,-0.45889249668418974,1.1237125163726698,0.0,0.39685157700845974,0.07971355864722025,-0.10801795676814246,0.970910053556339,1
-0.015946270934518416,0.5147699226022602,1.2070424776124293,1.4964296913510953,0.0,-1.538461280202796,0.6279279165508553,-0.6962278124516513,1.1078123715001902,0
-0.11352477652057717,1.6469216185729558,0.5165245027590478,0.8281092396656429,1.0,2.230305862787544,0.38823246136355155,0.24082316597749406,1.0393612125282645,1
-1.0893098323811647,-0.5673058329697316,-0.922457011271075,-1.870877199833302,0.0,-0.8254512801775965,-1.0570599367460303,-0.6247439758234471,-1.4933316694329832,0
-0.0647355237275478,-0.5302931813706897,-0.7727643034357266,-0.19750560734395617,0.0,1.313578719898002,-0.45426146528489125,1.1419279865593137,-0.39811312588217346,0
0.6671032681678929,-0.13839451738083344,-0.048444749393718546,0.7561362679456708,0.0,-1.2328855659062818,0.7821873679090205,-0.5111867953512141,-0.19275964896639658,0
1.3501528072703042,-0.13839451738083344,-0.44440610560334964,0.17264110435875613,0.0,-0.8254512801775965,-1.2303051667328933,-0.6962278124516513,1.3816170073878926,1
-0.6989958100369298,-0.7915589573639271,0.023987206010482456,1.195685488092642,0.0,-0.4180169944489111,0.010890111118192748,-0.31797619692183937,0.28639846383708273,0
1.057417290512128,0.12287125861240404,0.9752602203189865,1.61724146530962,1.0,0.8042858627371452,1.6294277392146368,-0.570007780919565,0.7655565766405621,1
-0.21110328210663593,1.7753772917696309,-0.352658962091362,0.14950693487733663,1.0,-1.538461280202796,-1.0095954901742872,0.8919387978938227,1.2447146894440413,1
-1.5772023603114587,0.4276813306045143,-0.7437915212740462,-1.6845186123440892,0.0,-0.4180169944489111,-1.3655788394623616,-0.6962278124516513,-0.2612108079383222,0
-0.11352477652057717,0.47122562660338724,0.5599836760015687,0.27545963538728746,0.0,1.2117201484658306,-0.07929233736811919,-0.10801795676814246,1.1078123715001902,0
-0.11352477652057717,-0.5302931813706897,-0.9466009964058087,-2.3502686007538287,0.0,-0.11244128015239704,-0.838723482516011,-0.5569364507932648,-1.0826247156014297,0
-0.5038387988648122,0.5147699226022602,-0.17882226912127983,-0.3118912231131972,1.0,-0.31615842301673974,-0.008095667510504147,-0.6962278124516513,0.28639846383708273,0
3.008987402233303,3.388693458527872,-0.14984948695959968,1.9526869227902033,1.0,0.19313443414411702,1.4229573966275542,0.7742968267571207,1.1762635304721158,1
-0.991731326795106,-0.7654323797646033,-0.28022700668716094,-2.061091482236085,0.0,0.39685157700845974,-1.5269579578062886,-0.6962278124516513,-1.7671363053206859,0
0.3743677514097166,-0.0077616293842146964,-0.5940988134386981,-0.35173340388675295,0.0,-0.2142998515845684,-0.8909343737449291,1.0892341869876663,-0.05585733102254535,0
-0.6014173044508709,-0.7915589573639271,0.5986140522171426,0.4695296127036399,1.0,0.2949930055762884,-0.15523545188290844,-0.22565992790484427,1.4500681663598183,1
3.3993014245775384,0.30575730180767025,0.2171390870883515,-0.6036388049066546,1.0,-0.11244128015239704,-0.21931245475476174,0.2951508818149292,1.0393612125282645,1
-0.8941528212090473,-0.7915589573639271,-0.5120092639806036,-1.692230002171229,0.0,-0.2142998515845684,-1.3109947259048567,-0.6753953800628603,-1.561782828404909,0
0.2767892458236579,-0.7044703653661812,-0.048444749393718546,0.601908471402874,0.0,-2.353329851660167,0.319409013834524,-0.6660003615345822,0.8340077356124876,0
-0.40626029327875346,-0.7806728833642089,-1.1107800953219973,0.3654258500372522,1.0,1.4154372913301734,1.142917161854269,0.9507597834621735,-0.6034666027979503,0
-0.21110328210663593,1.8210988025684476,0.10607675546857669,0.5633515222671747,0.0,-0.01058270872022569,-0.2810162352980283,-0.36004954076586815,0.970910053556339,0
0.3743677514097166,-0.5390020405704643,-1.1880408477531448,1.1725513186112224,0.0,-0.31615842301673974,0.6326743612080296,1.1419279865593137,0.423300781780934,0
-0.3086817876926947,-0.7915589573639271,-1.3860215258579602,-1.0534698781564786,0.0,-0.4180169944489111,-0.5729225817142494,-0.6566053430063039,-1.904038623264537,0
-0.991731326795106,-0.6282678473681537,-1.043176936944743,-0.6627594602480601,0.0,0.6005687198728025,-0.3759451284415145,-0.6962278124516513,-0.7403689207418015,0
0.3743677514097166,1.4945165825769007,1.7140661654418348,1.2792255445533234,1.0,2.535881577084058,1.5867097373000678,-0.1288503891569334,0.6971054176686364,1
2.5210948743030093,0.3297066646070504,0.6227580373518761,1.4373090360096907,0.0,-1.1310269944741107,1.012389933781975,2.244821465965893,0.4917519407528596,0
-0.3086817876926947,-0.356115997375198,-0.801737085597407,1.2830812394668936,0.0,-0.8254512801775965,1.2805640569123249,2.564252095927354,1.0393612125282645,1
-0.6014173044508709,0.20995985061014974,1.284303230043577,0.8461024825956355,1.0,0.39685157700845974,0.6231814718936811,-0.6811140869931167,1.1762635304721158,1
1.447731312856363,1.6904659145718288,0.1640223197919374,0.16107401961804638,1.0,0.2949930055762884,0.236346232333973,-0.5896147761090154,1.1078123715001902,1
-0.11352477652057717,-0.7915589573639271,-1.1880408477531448,-0.9339433358358111,0.0,0.19313443414411702,-0.819737703887314,-0.6962278124516513,-1.7671363053206859,0
1.2525743016842454,2.043174712162699,0.7338203689716505,0.5106570251150524,1.0,-0.6217341373132538,-0.7817661466299195,0.8163701706011496,0.7655565766405621,1
-1.0893098323811647,-0.20371096137914277,-0.507180466953657,-1.5277203525255791,0.0,0.19313443414411702,-1.16622816386104,0.059866939541525735,-0.7403689207418015,0
-0.45504954607178283,-0.3234577753760433,0.20748149303445784,0.27803009866300066,0.0,-0.11244128015239704,-0.1481157848971467,-0.6120811247635938,-0.2612108079383222,0
1.1062065433051573,1.167934362585354,-0.043615952366771495,-1.31694236391709,1.0,0.4987101484406311,-0.5420706914426157,1.0684017545988753,0.21794730486515712,0
-0.11352477652057717,0.07714974781358751,-1.043176936944743,-1.16014410409858,1.0,-0.01058270872022569,-0.8624557058018829,0.43403376440686886,-0.46656428485409907,0
-0.11352477652057717,-0.18193881337970638,-1.0673209220794768,-2.0726585669767945,1.0,0.8042858627371452,-1.268276723990288,-0.5103698372183204,-1.2195270335452808,0
0.6671032681678929,-0.595609625368999,2.1148563186784126,0.619901714332867,0.0,0.2949930055762884,0.6160618049079193,-0.6811140869931167,-0.05585733102254535,1
-0.21110328210663593,-0.13839451738083344,-0.17882226912127983,-0.3003241383724875,0.0,0.2949930055762884,-1.3062482812476826,-0.302045513330411,1.313165848415967,0
1.1549957960981867,-0.46497673737238027,-1.1011225012681038,-0.7732893811037311,1.0,-0.4180169944489111,-0.4091702410417349,-0.6962278124516513,1.1078123715001902,1
-0.3086817876926947,0.5583142186011332,0.8352251065375316,1.388470233771138,1.0,0.9061444341693166,1.1215581608969847,-0.11863841249576136,0.1494961458932315,0
0.9598387849260692,0.5517825742013022,1.6319766159837403,0.6867337595014122,0.0,-0.7235927087454251,0.42620401862094615,3.087105300979362,0.35484962280900834,0
-0.21110328210663593,-0.7806728833642089,1.58851744274122,0.3268689009015529,0.0,-0.5198755658810824,0.19837467507657838,-0.6962278124516513,1.1762635304721158,0
0.17921074023759911,-0.7915589573639271,-0.20296625425601336,-0.023999336233309996,0.0,-0.6217341373132538,0.6825120301083601,-0.4294909820618379,-0.6034666027979503,1
0.7646817737539516,-0.4867488853718167,-1.4584534812621612,-0.8414066579101332,0.0,0.4987101484406311,-0.8007519252586164,1.0684017545988753,1.1078123715001902,0
-0.6014173044508709,-0.7915589573639271,-0.09190392263623898,0.4849523923579197,0.0,-0.5198755658810824,-0.2620304566693306,0.7742968267571207,-0.12430848999447097,0
-0.6014173044508709,-0.3996602933740709,0.7145051808638635,-0.7321619686923186,0.0,1.2117201484658306,-0.2928823469409643,-0.6680427568668165,-0.8088200797137272,0
-0.11352477652057717,-0.03824263658342576,0.7917659332950111,0.8795185051799084,1.0,-1.0291684230419393,0.6374208058652038,-0.5683738646537776,0.012593827949380266,1
0.17921074023759911,0.0902130366132493,-0.6568731747890054,-1.1832782735799996,0.0,0.6005687198728025,-1.2421712783758292,-0.5892062970425684,-0.32966196691024785,0
-1.2844668435532822,-0.7915589573639271,-1.463282278289108,-1.2141238328885586,0.0,-1.1310269944741107,-0.9526381542881949,-0.5532601391952429,-1.8355874642926113,0
0.47194625699577536,-0.6826982173667449,0.5937852551901955,0.960488098364877,0.0,0.09127586271194567,0.7655748116089103,-0.43357577272630676,-0.05585733102254535,0
-1.0893098323811647,0.9502128825909893,0.9607738292381467,0.4373988217572239,1.0,-1.2328855659062818,0.16514956247635795,0.97608548558188,0.35484962280900834,1
1.3501528072703042,0.5147699226022602,1.9603348138161183,1.6043891489310536,0.0,-1.4366027087706246,0.630301138879442,1.0684017545988753,0.6286542586967109,0
-0.8941528212090473,-0.7915589573639271,-0.5651260312770177,-0.2810456638046381,0.0,0.4987101484406311,0.27431778959136766,-0.6962278124516513,-0.7403689207418015,0
-1.8211486242766055,-0.687052646966632,1.2167000716663225,-1.5945523976941245,0.0,-0.31615842301673974,-1.477120288905958,-0.48422717696571993,-1.8355874642926113,0
0.2767892458236579,-0.7828500981641525,-0.6568731747890054,-0.230921629928229,0.0,-2.353329851660167,-0.5444439137712033,-0.5058765674874047,-0.8772712386856528,0
-0.3086817876926947,1.886415246566757,0.5889564581632488,0.9489210136241668,1.0,0.39685157700845974,0.9672987095388195,0.18200218040914312,1.313165848415967,1
0.7646817737539516,-0.2690274053774522,0.429606156274007,2.1544682899336958,1.0,0.6005687198728025,2.1420437621894646,-0.1709237330009622,0.4917519407528596,1
-0.8941528212090473,-0.7915589573639271,-0.7920794915435135,0.20091620039160207,0.0,-1.2328855659062818,-0.295255569269551,-0.6962278124516513,-1.8355874642926113,0
0.17921074023759911,-0.5041666037713659,1.3953655616633514,0.587770923386451,1.0,0.39685157700845974,1.2141138317118843,2.2828100191454532,-0.6719177617698759,0
-0.6989958100369298,-0.6587488545673648,-0.9900601696483292,-1.0611812679836188,1.0,0.8042858627371452,-0.7793929243013319,-0.22443449070550356,-1.561782828404909,0
0.47194625699577536,-0.7915589573639271,-0.4347485115494562,-1.0470437199671954,0.0,1.7210130056266872,-1.434402286991389,-0.6962278124516513,-1.4933316694329832,0
0.17921074023759911,-0.2690274053774522,-1.0576633280255834,-0.19493514406824294,0.0,0.09127586271194567,0.010890111118192748,1.7195173865152038,-0.39811312588217346,0
-0.5038387988648122,0.331883879406994,0.07710397330689656,0.7638476577728109,1.0,0.39685157700845974,0.08920644796156954,-0.6962278124516513,1.4500681663598183,0
-0.11352477652057717,-0.4954577445715913,-0.7631067093818331,-1.341361765036366,1.0,0.2949930055762884,-0.2525375673549821,-0.39803809394542805,-1.2879781925172065,0
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment