# Sanity-check setup: cap both subsets at a handful of samples so we can
# confirm the DNN is able to overfit a tiny dataset (i.e. that the model
# and training loop actually work) before scaling up.
train_size = min(8, len(train))
dev_size = min(8, len(dev))

# Aim for at least `min_training_batches` batches per training epoch.
min_training_batches = 4
train_batch_size = min(32, max(1, train_size // min_training_batches))
evaluation_batch_size = min(1_024, dev_size)
def get_max_len(arrays):
    """Return the length of the longest array in `arrays`.

    Raises ValueError on an empty `arrays`, matching `max()` semantics.
    """
    return max(len(array) for array in arrays)


def pad(array, max_len):
    """Right-pad `array` with NaNs up to `max_len`; return a plain list."""
    return list(np.pad(array, pad_width=(0, max_len - len(array)), constant_values=np.nan))


# @torch.compile(mode="reduce-overhead")
def train_batch(model, optimizer, loss, x, y, train_dl_len, batch_idx, accum_iter=1, k_frac=None):
    """Run forward/backward on one batch and step the optimizer when due.

    Parameters
    ----------
    model : network to train (switched to train mode here).
    optimizer : torch optimizer over `model`'s parameters.
    loss : callable returning a per-element loss tensor for (proba, y).
    x, y : input batch and one-hot target batch.
    train_dl_len : total number of batches in the epoch (an int).
    batch_idx : zero-based index of this batch within the epoch.
    accum_iter : gradient-accumulation factor; the optimizer steps every
        `accum_iter` batches, and always on the epoch's final batch.
    k_frac : unused here; kept for interface compatibility with callers.
    """
    model.train()

    # forward pass
    proba = model(x)
    loss_array = loss(proba, y)
    loss_scalar = loss_array.mean()

    # backward pass — gradients add onto whatever is already stored,
    # which is what makes accumulation across batches possible.
    loss_scalar.backward()

    # weights update; if accum_iter != 1 this is gradient accumulation
    batch_num = batch_idx + 1
    # BUGFIX: `train_dl_len` is already an int batch count; the original
    # `len(train_dl_len)` raised TypeError whenever accum_iter > 1 forced
    # this operand to be evaluated.
    if (batch_num % accum_iter == 0) or (batch_num == train_dl_len):
        optimizer.step()
        # BUGFIX: gradients were previously zeroed before every backward
        # pass, so they could never accumulate. Zeroing only after a step
        # is behaviorally identical for accum_iter=1 and correct otherwise.
        optimizer.zero_grad(set_to_none=True)


# @torch.compile(mode="reduce-overhead")
def train_epoch(dl, model, optimizer, loss, train_dl_len, k_frac=None):
    """Train for one epoch over `dl`; return per-sample accuracy flags."""
    epoch_accuracies = []
    for batch_idx, (x, y) in enumerate(dl):
        train_batch(model, optimizer, loss, x, y, train_dl_len, batch_idx, accum_iter=1, k_frac=k_frac)
        epoch_accuracies += eval_batch(model, x, y)
    return epoch_accuracies


# @torch.compile(mode="reduce-overhead")
def eval_batch(model, x, y):
    """Evaluate one batch; return a bool tensor of per-sample correctness.

    Assumes `y` is one-hot encoded, so `argmax(axis=1)` recovers labels.
    """
    model.eval()
    with torch.inference_mode():  # turn off history tracking
        proba = model(x)
        true = y.argmax(axis=1)
        pred = proba.argmax(axis=1)
        epoch_accuracy_array = (pred == true)
        return epoch_accuracy_array


# NOTE(review): torch.compile disabled for consistency — every sibling
# train/eval helper keeps it commented out, and compiling a Python-level
# dataloader loop mostly produces graph breaks.
# @torch.compile(mode="reduce-overhead")
def eval_epoch(dl, model):
    """Evaluate all batches in `dl`; return per-sample accuracy flags."""
    epoch_accuracies = []
    for batch_idx, (x, y) in enumerate(dl):
        epoch_accuracies += eval_batch(model, x, y)
    return epoch_accuracies


def train_model(train_dl, dev_dl, model, loss, optimizer, n_epochs, eval_every=5, k_frac=None, agg=("mean",), log=False):
    """Train `model` for `n_epochs` and return an accuracy summary DataFrame.

    Parameters
    ----------
    train_dl, dev_dl : iterables yielding (x, y) batches.
    model, loss, optimizer : training components (see `train_batch`).
    n_epochs : number of epochs to run.
    eval_every : evaluate on `dev_dl` every this many epochs (plus epoch 1).
    k_frac : forwarded to `train_epoch` (unused downstream).
    agg : aggregations applied per (Epoch, Subset) group. Default mean.
        Pass a falsy value to return the raw per-sample rows instead.
        BUGFIX: previously this was a mutable default (`["mean"]`) and its
        contents were ignored in favor of a hard-coded "mean".
    log : print a progress line after each epoch when True.

    Returns
    -------
    pd.DataFrame — wide summary with one row per epoch (when `agg` is
    truthy) or long-form per-sample accuracies (when `agg` is falsy).
    """
    model.train()
    summary_list = []
    train_dl_len = len(train_dl)
    for epoch in range(1, n_epochs + 1):
        epoch_train_accuracies = train_epoch(train_dl, model, optimizer, loss, train_dl_len, k_frac)
        # Dev evaluation is throttled; epoch 1 always evaluates as a baseline.
        if epoch % eval_every == 0 or epoch == 1:
            epoch_dev_accuracies = eval_epoch(dev_dl, model)
        else:
            epoch_dev_accuracies = []
        for e in epoch_train_accuracies:
            summary_list.append([epoch, "Train", float(e)])
        for e in epoch_dev_accuracies:
            summary_list.append([epoch, "Dev", float(e)])
        if log:
            print(f"Epoch {epoch}/{n_epochs} Completed")
    model.eval()
    summary = pd.DataFrame(columns=["Epoch", "Subset", "Accuracy"], data=summary_list)
    if agg:
        summary = summary.groupby(["Epoch", "Subset"]).agg(list(agg))
        summary.columns = list(map('_'.join, summary.columns.values))
        # Pivot Subset into columns: one row per epoch, e.g.
        # Accuracy_mean_Train / Accuracy_mean_Dev.
        summary = (
            summary
            .reset_index()
            .pivot(index="Epoch", columns="Subset")
        )
        summary.columns = list(map('_'.join, summary.columns.values))
        summary = summary.reset_index()
    return summary
# Build the network. LazyLinear infers its input size on the first forward
# pass, so the flattened image dimension never needs to be hard-coded.
model = NeuralNet(
    train_dl,
    hidden_layers=[
        nn.Flatten(),
        nn.LazyLinear(100),
        nn.ReLU(),
        nn.LazyLinear(10),
        nn.ReLU(),
        # nn.Sigmoid() not required
    ],
)
# model = model.half()
# model = torch.compile(model, mode="reduce-overhead")
def plot_examples(data, plot_count=4, fig_size=(10, 5)):
    """Show the first `plot_count` samples of `data` with true vs. predicted labels.

    NOTE(review): predictions come from the module-level `model` global,
    and `data[:plot_count]` is assumed to yield an (images, one-hot labels)
    pair — confirm against the dataset class.
    """
    x, y = data[:plot_count]
    pred = model(x).argmax(axis=1)
    cols = 4
    rows = np.ceil(plot_count / cols).astype(int)
    fig, ax = plt.subplots(rows, cols, figsize=fig_size)
    for i in range(plot_count):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(x[i])
        plt.title(f"True: {y[i].argmax()}; Pred: {pred[i]}")
    fig.tight_layout()
    plt.show()


plot_examples(dev, 4)
Last Updated: 2024-05-12; Contributors: AhmedThahir