mirror of
https://github.com/kuhyx/WUT_Computer_Science.git
synced 2026-07-04 18:03:14 +02:00
feat: initial commit,
This commit is contained in:
parent
78ea388285
commit
dd2601fc39
162
.gitignore
vendored
Normal file
162
.gitignore
vendored
Normal file
@ -0,0 +1,162 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
vid
|
||||
3
.pylintrc
Normal file
3
.pylintrc
Normal file
@ -0,0 +1,3 @@
|
||||
[DESIGN]
|
||||
# Maximum number of statements in function / method body
|
||||
max-statements=16
|
||||
161
main.py
161
main.py
@ -1,28 +1,145 @@
|
||||
"""
|
||||
Code used to solve MountainCar-v0 gymnasium problem using Q-Learning algorithm
|
||||
"""
|
||||
from datetime import datetime
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
|
||||
if __name__ == "__main__":
|
||||
# init env
|
||||
env = gym.make("MountainCar-v0", render_mode="rgb_array")
|
||||
|
||||
# wrapper to record the video at 3rd episode and saves it to the folder
|
||||
# 'vid'
|
||||
def initialize_environment():
|
||||
"""
|
||||
Initialize environment and video recording
|
||||
"""
|
||||
# Initialize environment
|
||||
env = gym.make('MountainCar-v0', render_mode='rgb_array')
|
||||
# Save video
|
||||
now = datetime.now()
|
||||
time_string = now.strftime("%H:%M:%S")
|
||||
env = gym.wrappers.RecordVideo(
|
||||
env, video_folder="vid", episode_trigger=lambda x: x == 3
|
||||
)
|
||||
env,
|
||||
video_folder='vid',
|
||||
episode_trigger=lambda x: x == 1,
|
||||
disable_logger=False,
|
||||
name_prefix=time_string)
|
||||
return env
|
||||
|
||||
# an episode ends if goal is reached or other game ending factors (e.g.
|
||||
# reached max steps)
|
||||
n_episodes = 4
|
||||
for episode in range(n_episodes): # iterate episodes
|
||||
state, info = env.reset() # reset the env to an initial state
|
||||
done = False # boolean to stop an episode
|
||||
|
||||
while not done: # iterate steps
|
||||
# randomly choose a sample
|
||||
action = env.action_space.sample()
|
||||
# take the action (step) and observe the state and reward
|
||||
next_state, reward, terminated, truncated, info = env.step(action)
|
||||
# condition to stop an episode
|
||||
done = terminated or truncated
|
||||
|
||||
env.close()
|
||||
def initialize_q_table(env):
|
||||
"""
|
||||
Initialize "empty" Q-table
|
||||
"""
|
||||
# Initialize Q-table
|
||||
n_actions = env.action_space.n # Number of possible actions, should be 3
|
||||
# 0 accelerate left
|
||||
# 1 dont accelerate
|
||||
# 2 accelerate to the right
|
||||
q_table = np.zeros((n_actions,))
|
||||
return q_table
|
||||
|
||||
|
||||
def initialize_hyperparameters():
|
||||
"""
|
||||
Initialize hyperparameters used by algorithm
|
||||
"""
|
||||
hyperparameters = {
|
||||
"learning_rate": 0.1,
|
||||
"discount_factor": 0.99,
|
||||
"epsilon": 0.2,
|
||||
"max_episodes": 1
|
||||
}
|
||||
return hyperparameters
|
||||
|
||||
|
||||
def choose_action(hyperparameters, env, q_table):
|
||||
"""
|
||||
Choose one of 3 actions possible for the algorithm
|
||||
"""
|
||||
# hyperparameters["epsilon"]-greedy exploration-exploitation tradeoff
|
||||
if np.random.uniform(0, 1) < hyperparameters["epsilon"]:
|
||||
action = env.action_space.sample() # Choose a random action
|
||||
else:
|
||||
# Choose the action with the highest Q-value
|
||||
action = np.argmax(q_table)
|
||||
return action
|
||||
|
||||
|
||||
def update_q_table(q_table, action, hyperparameters, reward):
|
||||
"""
|
||||
Update q_table with newest reward
|
||||
"""
|
||||
# Q-table update
|
||||
q_value = q_table[action]
|
||||
max_q_value = np.max(q_table)
|
||||
new_q_value = (1 - hyperparameters["learning_rate"]) * q_value + \
|
||||
hyperparameters["learning_rate"] * \
|
||||
(reward + hyperparameters["discount_factor"] * max_q_value)
|
||||
q_table[action] = new_q_value
|
||||
return q_table
|
||||
|
||||
|
||||
def movement(hyperparameters, env, q_table, total_reward=0):
|
||||
"""
|
||||
Choose action and observe consequences
|
||||
"""
|
||||
action = choose_action(hyperparameters, env, q_table)
|
||||
# Take the action and observe the next state
|
||||
next_state, reward, terminated, truncated, info = env.step(action)
|
||||
done = terminated or truncated
|
||||
q_table = update_q_table(q_table, action, hyperparameters, reward)
|
||||
|
||||
total_reward += reward
|
||||
return hyperparameters, env, q_table, done, total_reward
|
||||
|
||||
|
||||
def episode_step(env, hyperparameters, q_table, episode_rewards):
|
||||
"""
|
||||
Actions done with every episode
|
||||
"""
|
||||
state, _ = env.reset() # Reset the environment to an initial state
|
||||
done = False # Boolean to indicate episode completion
|
||||
total_reward = 0 # Accumulate rewards for the episode
|
||||
|
||||
while not done:
|
||||
hyperparameters, env, q_table, done, total_reward = movement(
|
||||
hyperparameters, env, q_table, total_reward)
|
||||
|
||||
episode_rewards.append(total_reward)
|
||||
return env, hyperparameters, q_table, episode_rewards
|
||||
|
||||
|
||||
def training_loop(hyperparameters, env, q_table):
|
||||
"""
|
||||
Actual training for MountainCar
|
||||
"""
|
||||
episode_rewards = [] # List to store episode rewards
|
||||
|
||||
for episode in range(hyperparameters["max_episodes"]):
|
||||
env, hyperparameters, q_table, episode_rewards = episode_step(
|
||||
env, hyperparameters, q_table, episode_rewards)
|
||||
|
||||
return env, q_table
|
||||
|
||||
|
||||
def inference(env, q_table):
|
||||
"""
|
||||
Inference using the updated Q-table
|
||||
"""
|
||||
state, _ = env.reset()
|
||||
done = False
|
||||
|
||||
while not done:
|
||||
# Choose the action with the highest Q-value
|
||||
action = np.argmax(q_table)
|
||||
# Take the action and observe the next state
|
||||
next_state, reward, terminated, truncated, info = env.step(action)
|
||||
done = terminated or truncated
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ENV = initialize_environment()
|
||||
Q_TABLE = initialize_q_table(ENV)
|
||||
HYPERPARAMETERS = initialize_hyperparameters()
|
||||
ENV, Q_TABLE = training_loop(HYPERPARAMETERS, ENV, Q_TABLE)
|
||||
inference(ENV, Q_TABLE)
|
||||
|
||||
ENV.close()
|
||||
|
||||
@ -9,3 +9,4 @@ dependencies:
|
||||
- numpy
|
||||
- python=3.9
|
||||
- pygame
|
||||
- opencv-python
|
||||
|
||||
Loading…
Reference in New Issue
Block a user