"class myMSE(nn.Module): # plain MSE; used because the built-in nn.MSELoss reportedly misbehaved here (cause unconfirmed -- TODO verify, possibly a shape/broadcasting issue)\n",
" \"\"\"Mean squared error loss: returns mean((input - target)**2) as a scalar tensor.\"\"\"\n",
" def forward(self, input, target):\n",
" return ((input-target)**2).mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Crear red neuronal"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"net = NeuralNet( #skorch wrapper facility: exposes the torch model through a sklearn-style fit/predict API\n",
" model,\n",
" criterion=myMSE, #custom MSE loss defined above\n",
" batch_size=batch_size,\n",
" lr=learning_rate,\n",
" max_epochs=num_epochs,\n",
" optimizer=optim.Adam, \n",
" iterator_train__shuffle=True,\n",
" iterator_train__num_workers=num_workers,\n",
" iterator_valid__shuffle=False,\n",
" iterator_valid__num_workers=num_workers,\n",
" train_split=predefined_split(validset), #strange naming, but validset will be used for validation, not training -- see skorch.helper.predefined_split documentation\n",
" callbacks=[Checkpoint(dirname=f'training',f_params='best_params.pt'), #saves the best parameters to training/best_params.pt\n",
" EarlyStopping(patience=patience, threshold=1e-3, threshold_mode='abs')], #stops training if valid loss did not improve by more than 1e-3 (absolute) for `patience` epochs \n",
" device=device\n",
")\n",
"\n",
"tstart = time.time()\n",
"net.fit(trainset) #no y argument: targets presumably come from the dataset itself -- TODO confirm\n",
"print('Time for training', (time.time()-tstart)/60, 'min')"
class myMSE(nn.Module):
    """Plain mean-squared-error loss module.

    Used in place of torch.nn.MSELoss, which reportedly misbehaved in this
    setup (cause unconfirmed -- possibly an input/target shape mismatch that
    MSELoss broadcasts or warns about differently; TODO verify).
    """

    def forward(self, input, target):
        # Element-wise squared difference, reduced to a single scalar mean.
        return ((input - target) ** 2).mean()
```
%% Cell type:markdown id: tags:
# Crear red neuronal
%% Cell type:code id: tags:
``` python
net = NeuralNet(  # skorch wrapper facility: exposes the torch model through a sklearn-style fit/predict API
    model,
    criterion=myMSE,  # custom MSE loss defined above
    batch_size=batch_size,
    lr=learning_rate,
    max_epochs=num_epochs,
    optimizer=optim.Adam,
    iterator_train__shuffle=True,
    iterator_train__num_workers=num_workers,
    iterator_valid__shuffle=False,
    iterator_valid__num_workers=num_workers,
    # Strange naming, but validset will be used for validation, not training --
    # see skorch.helper.predefined_split documentation.
    train_split=predefined_split(validset),
    callbacks=[
        # Saves the best parameters to training/best_params.pt.
        Checkpoint(dirname='training', f_params='best_params.pt'),
        # Stops training if valid loss did not improve by more than 1e-3
        # (absolute) for `patience` consecutive epochs.
        EarlyStopping(patience=patience, threshold=1e-3, threshold_mode='abs'),
    ],
    device=device,
)

tstart = time.time()
net.fit(trainset)  # no y argument: targets presumably come from the dataset itself -- TODO confirm
print('Time for training', (time.time()-tstart)/60, 'min')