Commit dae47727 authored by David Akim

Update file 3_train_model.ipynb
%% Cell type:markdown id: tags:
# Configure directories
%% Cell type:code id: tags:
``` python
DATADIR = 'processed_data/'
```
%% Cell type:markdown id: tags:
# Load libraries
%% Cell type:code id: tags:
``` python
import numpy as np
import sys, time, copy
# DL-specific imports below
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable # legacy import; Variable is no longer needed in modern PyTorch
from torch.utils.data import Dataset, DataLoader
# from torchvision import models, transforms
from skorch.net import NeuralNet # skorch: a scikit-learn-style wrapper around PyTorch
from skorch.helper import predefined_split
from skorch.callbacks import Checkpoint, EarlyStopping
```
%% Cell type:markdown id: tags:
# Configuration variables
%% Cell type:code id: tags:
``` python
months = np.array([(1,31),(2,28),(12,31)]) # (month, number of days) pairs covered by the data
valid, test = [2015,2018], [2019,2022] # years used for validation and testing; do not use these years to compute normalization constants
train = [x for x in np.arange(2002,2015) if x not in valid+test]
targetVar = 'fdimrk'
var = ['u10','v10','t2m','lai_hv','lai_lv','tp']
means = np.array([np.load(f'norm_consts/input_{v}.npy')[0] for v in var])
stds = np.array([np.load(f'norm_consts/input_{v}.npy')[1] for v in var])
targetMean, targetStd = np.load(f'norm_consts/target_{targetVar}.npy')
```
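%% Cell type:markdown id: tags:
The constants in `norm_consts/` are assumed to come from an earlier preprocessing step. Below is a minimal sketch (not part of the original pipeline) of how they could be recomputed from the training years only, using the same `processed_data/<variable>/<year>_<month>_<day>.npy` layout that the dataset class further down relies on.
%% Cell type:code id: tags:
``` python
# Hedged sketch: per-variable mean/std over the TRAINING years only,
# in the file layout assumed above. compute_norm_consts is hypothetical.
def compute_norm_consts(v, years):
    fields = []
    for y in years:
        for month, ndays in months: # months as defined above
            for d in range(1, ndays+1):
                fields.append(np.load(f'{DATADIR}{v}/{y:d}_{month:02d}_{d:02d}.npy'))
    fields = np.stack(fields)
    return np.array([fields.mean(), fields.std()])

# for v in var:
#     np.save(f'norm_consts/input_{v}.npy', compute_norm_consts(v, train))
# np.save(f'norm_consts/target_{targetVar}.npy', compute_norm_consts(targetVar, train))
```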
%% Cell type:markdown id: tags:
# Deep learning parameters
%% Cell type:code id: tags:
``` python
##### DL parameters
batch_size = 64
learning_rate = 1e-3
num_epochs = 200
num_workers = 8
weight_decay = 0.
patience = 30 # early stopping if the validation loss does not improve for 30 epochs
```
%% Cell type:markdown id: tags:
# PyTorch parameters
%% Cell type:code id: tags:
``` python
member = 0 # ensemble member = seed for weight initialization
torch.manual_seed(member) # for reproducibility and creation of a seed ensemble
np.random.seed(member)
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True) # replaces the deprecated torch.set_deterministic(True)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
```
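%% Cell type:markdown id: tags:
One caveat worth noting (an environment assumption, not part of the original notebook): with `torch.use_deterministic_algorithms(True)` on CUDA 10.2 or later, cuBLAS requires a fixed workspace configuration, otherwise some operations raise a `RuntimeError`. Setting the environment variable avoids this:
%% Cell type:code id: tags:
``` python
import os
# must be set before the first CUDA call to take effect
os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')
```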
%% Cell type:markdown id: tags:
# Count the number of trainable parameters in a PyTorch model
%% Cell type:code id: tags:
``` python
def pytorch_count_params(model): # counts the number of trainable parameters in a PyTorch model
    tot = 0
    for x in model.parameters():
        tot += np.prod(x.size())
    return tot
```
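%% Cell type:markdown id: tags:
A quick sanity check of the helper on a toy module (not part of the original notebook): `nn.Linear(10, 5)` has 10*5 weights plus 5 biases, i.e. 55 trainable parameters.
%% Cell type:code id: tags:
``` python
# 10*5 weights + 5 biases = 55 parameters
print(pytorch_count_params(nn.Linear(10, 5))) # -> 55
```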
%% Cell type:markdown id: tags:
# Fire danger index and ERA5 dataset
%% Cell type:code id: tags:
``` python
class copernicusDataset(Dataset):
    def __init__(self, years, aug=False): # aug = use rotation and flipping as data augmentation
        self.years = years
        self.length = len(self.years)*sum([m[1] for m in months]) # nb of years * nb of time steps per year
        self.aug = aug
    def idxToFile(self, idx): # convert a time step index to the (year, month, day) input and target file names
        year = self.years[idx//(sum([m[1] for m in months]))]
        tInYear = idx%(sum([m[1] for m in months])) # time step within this year
        monthIdx = np.argmax(tInYear < np.array([sum(months[:m,1]) for m in range(1,len(months)+1)]))
        month = months[monthIdx,0]
        tInMonth = tInYear - sum(months[:monthIdx,1]) # time step within this month
        day = tInMonth + 1 # day numbering starts with 1
        return f"/{year:d}_{month:02d}_{day:02d}.npy", f"/{year:d}_{month:02d}_{day:02d}.npy"
    def normalize(self, x): # normalize the input fields
        return ((x.transpose()-means)/stds).transpose()
    def __len__(self):
        return self.length
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        inpFile, targetFile = self.idxToFile(idx)
        inp = []
        for v in var:
            inp += [np.load(DATADIR+v+inpFile)]
        inp = self.normalize(np.stack(inp))
        target = ((np.load(DATADIR+targetVar+targetFile)-targetMean)/targetStd).reshape((1,100,100))
        if self.aug: # 50 % probability to rotate by 180 deg, 50 % probability to flip left-right
            rot = np.random.randint(2) # 0 -> no rotation, 1 -> rotate
            inp = np.rot90(inp, k=2*rot, axes=(1,2))
            target = np.rot90(target, k=2*rot, axes=(1,2))
            if np.random.randint(2): # 0 -> no flip, 1 -> flip
                inp = np.flip(inp, axis=2)
                target = np.flip(target, axis=2)
        return torch.tensor(inp.astype(np.float32)), torch.tensor(target.astype(np.float32))
trainset = copernicusDataset(train, aug=True)
validset = copernicusDataset(valid)
```
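%% Cell type:markdown id: tags:
A hedged sanity check (it assumes the processed `.npy` files are present under `DATADIR`): each sample should be a `(len(var), 140, 140)` input stack and a `(1, 100, 100)` target, matching the U-Net shape arithmetic below.
%% Cell type:code id: tags:
``` python
# requires the processed data on disk; expected shapes follow from the
# target reshape above and the U-Net shape comments below
inp, target = trainset[0]
print(len(trainset), len(validset)) # samples per split
print(inp.shape, target.shape)      # expected: torch.Size([6, 140, 140]) torch.Size([1, 100, 100])
```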
%% Cell type:markdown id: tags:
# U-Net model
%% Cell type:code id: tags:
``` python
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__() # in: len(var) x 140 x 140
        self.conv1 = nn.Conv2d(in_channels=len(var),out_channels=64,kernel_size=3) # out: 64 x 138 x 138
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64,64,3) # out: 64 x 136 x 136
        self.bn2 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(2) # out: 64 x 68 x 68
        self.conv3 = nn.Conv2d(64,128,3) # out: 128 x 66 x 66
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128,128,3) # out: 128 x 64 x 64
        self.bn4 = nn.BatchNorm2d(128)
        self.pool2 = nn.MaxPool2d(2) # out: 128 x 32 x 32
        self.conv5 = nn.Conv2d(128,256,3) # out: 256 x 30 x 30
        self.bn5 = nn.BatchNorm2d(256)
        self.conv6 = nn.Conv2d(256,256,3) # out: 256 x 28 x 28
        self.bn6 = nn.BatchNorm2d(256)
        self.upconv1 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1) # out: 128 x 56 x 56
        ### concat with crop(conv4) -> out: 256 x 56 x 56
        self.conv7 = nn.Conv2d(256,128,3) # out: 128 x 54 x 54
        self.bn7 = nn.BatchNorm2d(128)
        self.conv8 = nn.Conv2d(128,128,3) # out: 128 x 52 x 52
        self.bn8 = nn.BatchNorm2d(128)
        self.upconv2 = nn.ConvTranspose2d(128,64,4,2,1) # out: 64 x 104 x 104
        ### concat with crop(conv2) -> out: 128 x 104 x 104
        self.conv9 = nn.Conv2d(128,64,3) # out: 64 x 102 x 102
        self.bn9 = nn.BatchNorm2d(64)
        self.conv10 = nn.Conv2d(64,64,3) # out: 64 x 100 x 100
        self.bn10 = nn.BatchNorm2d(64)
        self.conv11 = nn.Conv2d(64,1,1) # out: 1 x 100 x 100
    def forward(self, x):
        level1 = F.relu(self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x))))))
        level2 = F.relu(self.bn4(self.conv4(F.relu(self.bn3(self.conv3(self.pool1(level1)))))))
        level3 = F.relu(self.bn6(self.conv6(F.relu(self.bn5(self.conv5(self.pool2(level2)))))))
        ### going up again - to center-crop the skip connection before concatenating, use F.pad with negative padding
        level2 = F.relu(self.bn8(self.conv8(F.relu(self.bn7(self.conv7(torch.cat((F.pad(level2,[-4,-4,-4,-4]), self.upconv1(level3)), dim=1)))))))
        level1 = F.relu(self.bn10(self.conv10(F.relu(self.bn9(self.conv9(torch.cat((F.pad(level1,[-16,-16,-16,-16]), self.upconv2(level2)), dim=1)))))))
        return self.conv11(level1)
model = UNet()
print('Number of parameters in the model:', pytorch_count_params(model))
```
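%% Cell type:markdown id: tags:
A quick shape check with a dummy batch confirms the valid-convolution arithmetic above: `140 x 140` inputs come out as `100 x 100` fields, matching the target shape.
%% Cell type:code id: tags:
``` python
# batch of 2 so the BatchNorm layers also work in training mode
with torch.no_grad():
    out = model(torch.zeros(2, len(var), 140, 140))
print(out.shape) # expected: torch.Size([2, 1, 100, 100])
```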
%% Cell type:code id: tags:
``` python
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)
```
%% Cell type:markdown id: tags:
# Mean squared error
%% Cell type:code id: tags:
``` python
class myMSE(nn.Module): # plain MSE; the built-in PyTorch implementation did not work properly in this setup
    def forward(self, input, target):
        return ((input-target)**2).mean()
```
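%% Cell type:markdown id: tags:
A quick check (not part of the original notebook) that `myMSE` agrees with the built-in `nn.MSELoss` on random tensors; both default to mean reduction.
%% Cell type:code id: tags:
``` python
a, b = torch.randn(4, 1, 100, 100), torch.randn(4, 1, 100, 100)
print(torch.allclose(myMSE()(a, b), nn.MSELoss()(a, b))) # -> True
```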
%% Cell type:markdown id: tags:
# Create the neural network
%% Cell type:code id: tags:
``` python
net = NeuralNet( # skorch wrapper facility
    model,
    criterion=myMSE,
    batch_size=batch_size,
    lr=learning_rate,
    max_epochs=num_epochs,
    optimizer=optim.Adam,
    optimizer__weight_decay=weight_decay, # wires in the otherwise unused weight_decay defined above (0. = no change in behavior)
    iterator_train__shuffle=True,
    iterator_train__num_workers=num_workers,
    iterator_valid__shuffle=False,
    iterator_valid__num_workers=num_workers,
    train_split=predefined_split(validset), # despite the name, validset is used for validation, not training; see the skorch.helper.predefined_split documentation
    callbacks=[Checkpoint(dirname='training', f_params='best_params.pt'), # saves the best parameters to training/best_params.pt
               EarlyStopping(patience=patience, threshold=1e-3, threshold_mode='abs')], # stops training if the validation loss did not improve for `patience` epochs
    device=device
)
tstart = time.time()
net.fit(trainset)
print('Time for training:', (time.time()-tstart)/60, 'min')
```
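%% Cell type:markdown id: tags:
After training, the best checkpoint can be reloaded and used for prediction. This is a sketch: the parameter path follows the `Checkpoint` settings above, and predictions are mapped back to physical units with the stored normalization constants.
%% Cell type:code id: tags:
``` python
# reload the best weights and predict on the validation split
net.load_params(f_params='training/best_params.pt')
pred = net.predict(validset)*targetStd + targetMean # denormalize back to physical units
print(pred.shape) # one (1, 100, 100) field per validation day
```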