diff --git a/lab5/code/main.py b/lab5/code/main.py index 6efa8fab..c3c14aa3 100644 --- a/lab5/code/main.py +++ b/lab5/code/main.py @@ -69,12 +69,11 @@ def get_optimizer(model): """ Return optimizer function """ if OPTIMIZER_TYPE == 'SGD': return optim.SGD(model.parameters(), lr=LEARNING_RATE) - elif OPTIMIZER_TYPE == 'SGD_Momentum': + if OPTIMIZER_TYPE == 'SGD_Momentum': return optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9) - elif OPTIMIZER_TYPE == 'Adam': + if OPTIMIZER_TYPE == 'Adam': return optim.Adam(model.parameters(), lr=LEARNING_RATE) - else: - raise ValueError("Unsupported optimizer type!") + raise ValueError("Unsupported optimizer type!") def initial_configuration():