Source code for tensortrade.env.environment

# Copyright 2024 The TensorTrade and TensorTrade-NG Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import random
import typing
import uuid
from typing import List
from warnings import warn

import gymnasium
from gymnasium.core import ActType, ObsType, RenderFrame

from tensortrade.core import TimeIndexed, Clock, Component
from tensortrade.env.plotters.utils import AggregatePlotter
from tensortrade.env.utils import FeedController, ObsState
from tensortrade.feed import DataFeed
from tensortrade.oms.orders import Broker

if typing.TYPE_CHECKING:
    from typing import Dict, Tuple, Any, SupportsFloat, Optional, Union

    from tensortrade.env.actions.abstract import AbstractActionScheme
    from tensortrade.env.observers.abstract import AbstractObserver
    from tensortrade.env.rewards.abstract import AbstractRewardScheme
    from tensortrade.env.renderers.abstract import AbstractRenderer
    from tensortrade.env.plotters.abstract import AbstractPlotter
    from tensortrade.env.stoppers.abstract import AbstractStopper
    from tensortrade.env.informers.abstract import AbstractInformer
    from tensortrade.oms.wallets import Portfolio


class TradingEnv(gymnasium.Env, TimeIndexed):
    """A trading environment made for use with Gym-compatible reinforcement
    learning algorithms.

    Parameters
    ----------
    portfolio : `Portfolio`
        The portfolio of wallets to be traded and tracked throughout an episode.
    feed : `DataFeed`
        The data feed that provides the data for each step of the environment.
    action_scheme : `AbstractActionScheme`
        A component for generating an action to perform at each step of the environment.
    reward_scheme : `AbstractRewardScheme`
        A component for computing the reward after each step of the environment.
    observer : `AbstractObserver`
        A component for generating observations after each step of the environment.
    stopper : `AbstractStopper`
        A component that decides whether an episode should be terminated.
    informer : `AbstractInformer`
        A component for providing information after each step of the environment.
    renderer : `AbstractRenderer`
        A component for rendering the environment.
    render_mode : str
        The chosen render mode, e.g. ``'human'``.
    plotter : `AbstractPlotter` or `List[AbstractPlotter]`
        A component (or list of components) for plotting the environment.
    random_start_pct : float
        The fraction of the data feed by which an episode start may be randomly
        offset. Defaults to ``0.00`` (no random start).
    """

    def __init__(self,
                 portfolio: Portfolio,
                 feed: DataFeed,
                 action_scheme: AbstractActionScheme,
                 reward_scheme: AbstractRewardScheme,
                 observer: AbstractObserver,
                 *,
                 stopper: Optional[AbstractStopper] = None,
                 informer: Optional[AbstractInformer] = None,
                 renderer: Optional[AbstractRenderer] = None,
                 render_mode: Optional[str] = None,
                 plotter: Optional[Union[AbstractPlotter, List[AbstractPlotter]]] = None,
                 random_start_pct: float = 0.00) -> None:
        super().__init__()

        self.random_start_pct = random_start_pct

        # public variables
        self.agent_id: Optional[str] = None
        self.episode_id: Optional[str] = None
        self.n_episode: int = -1
        self.metadata: Dict[str, Any] = {}

        # private variables
        self._action_scheme = action_scheme
        self._reward_scheme = reward_scheme
        self._observer = observer
        self._stopper = stopper
        self._informer = informer
        self._portfolio = portfolio
        self._renderer = renderer

        # plotter can be a list of multiple plotters
        if plotter is not None and isinstance(plotter, list):
            self._plotter = AggregatePlotter(renderers=plotter)
        else:
            self._plotter = plotter

        # init portfolio
        self._portfolio.clock = self._clock

        # internal attributes
        self._broker = Broker()
        self._feed = FeedController(feed, self._portfolio)
        self._last_state: Optional[ObsState] = None

        # init components
        self._action_scheme.trading_env = self
        self._reward_scheme.trading_env = self
        self._observer.trading_env = self

        if self._renderer is not None:
            self._renderer.trading_env = self
            self.metadata['render_modes'] = self._renderer.render_modes

        if self._plotter is not None:
            self._plotter.trading_env = self

        if self._stopper is not None:
            self._stopper.trading_env = self

        if self._informer is not None:
            self._informer.trading_env = self

        # set action and observation space
        self.action_space = self._action_scheme.action_space
        self.observation_space = self._observer.observation_space

        # configure renderer
        if render_mode is not None:
            if render_mode in self.metadata.get('render_modes', []):
                self.render_mode = render_mode
            else:
                warn(f'Render mode "{render_mode}" is not supported.', UserWarning)

    @property
    def clock(self) -> Clock:
        return self._clock

    @property
    def portfolio(self) -> Portfolio:
        return self._portfolio

    @property
    def broker(self) -> Broker:
        return self._broker

    @property
    def feed(self) -> FeedController:
        return self._feed

    @property
    def last_state(self) -> Optional[ObsState]:
        return self._last_state

    @property
    def components(self) -> Dict[str, Component]:
        """The components of the environment. (`Dict[str, Component]`, read-only)"""
        return {
            "action_scheme": self._action_scheme,
            "reward_scheme": self._reward_scheme,
            "observer": self._observer,
            "stopper": self._stopper,
            "informer": self._informer,
            "renderer": self._renderer
        }
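    # Illustrative sketch (not part of the original module): the wiring in
    # ``__init__`` gives each component a back-reference to the environment,
    # so a configured component can be retrieved and inspected via
    # ``components``, e.g.:
    #
    #     env = TradingEnv(portfolio, feed, action_scheme, reward_scheme, observer)
    #     env.components["reward_scheme"] is reward_scheme        # True
    #     env.components["reward_scheme"].trading_env is env      # True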
    def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
        """Run one timestep of the environment's dynamics using the agent's action.

        When the end of an episode is reached (``terminated or truncated``), it is necessary to call
        :meth:`reset` to reset this environment's state for the next episode.

        .. note::
            The tuple returned contains the data according to the :class:`gymnasium.Env` specification:

            * observation (ObsType): An element of the environment's :attr:`observation_space` serving as
              the next observation after the agent's action, e.g. a numpy array with the features observed
              at that time.
            * reward (SupportsFloat): The reward as a result of taking the action.
            * terminated (bool): Whether the agent has reached a terminal state, which can be associated
              with a positive or negative reward. This happens when there is no training data left or when
              the condition defined by the :class:`AbstractStopper` is met.
            * truncated (bool): Whether a truncation condition outside the scope of the MDP is satisfied.
              This is not used by TensorTrade-NG and is always ``False``.
            * info (Dict[str, Any]): Auxiliary diagnostic information (helpful for debugging, learning and
              logging). It is controlled by the :class:`AbstractInformer`.

        .. note::
            Because the internals of this method may look a bit unusual, here is a brief explanation:

            #. First, the action defined by the :class:`AbstractActionScheme` is executed. Executing it may
               place orders, so new data is needed afterwards.
            #. ``self.feed.next()`` then fetches the newest data, including the changes (such as orders)
               that the action made to the environment. A new state begins.
            #. The new state is rewarded using the :class:`AbstractRewardScheme`.
            #. After rewarding the agent, a new observation and info are created from this new state.
            #. Finally, it is checked whether it is time to terminate this episode. This happens either
               because the :class:`AbstractStopper` decides so, or because there is no more data to begin
               a new state.

        :param action: An action provided by the agent to update the environment state.
        :type action: ActType
        :return: The :class:`gymnasium.Env` step tuple ``(observation, reward, terminated, truncated, info)``.
        :rtype: Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]
        """
        # Execute the action decided by the agent
        self._action_scheme.perform_action(action)

        # Get new data and begin a new state
        self.clock.increment()
        self.feed.next()

        # Reward the agent, get a new observation for the next decision and add the info
        reward = self._reward_scheme.reward()
        obs, info = self._get_obs()

        # Now we decide if we need to end this episode
        if self._stopper is not None:
            terminated = self._stopper.stop()
        else:
            terminated = False

        # If we are not terminated right now, check if there is still data available
        if not terminated:
            terminated = not self.feed.has_next()

        # Save last state
        self._last_state = ObsState(
            observation=obs,
            info=info,
            reward=reward,
            terminated=terminated
        )

        if self.render_mode == 'human':
            self._renderer.render()

        return obs, reward, terminated, False, info
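    # Illustrative sketch (not part of the original module): ``step()`` is used
    # in the standard gymnasium interaction loop. ``truncated`` is always False
    # here, but checking both flags keeps the loop portable across environments:
    #
    #     obs, info = env.reset(seed=42)
    #     terminated, truncated = False, False
    #     while not (terminated or truncated):
    #         action = env.action_space.sample()  # stand-in for a trained agent
    #         obs, reward, terminated, truncated, info = env.step(action)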
    def reset(self, *,
              seed: Optional[int] = None,
              options: Optional[Dict[str, Any]] = None) -> Tuple[ObsType, Dict[str, Any]]:
        """Resets the environment to an initial internal state, returning an initial observation and info.

        This method resets all components of the environment to their initial state and begins a new
        episode. The ``seed`` parameter resets the PRNG of the environment. It should always be used to
        initialize the environment and after the environment is terminated or truncated. The method then
        returns the first observation and info as produced by the configured components.

        .. note::
            The tuple returned contains the data according to the :class:`gymnasium.Env` specification:

            * observation (ObsType): The first observation of the environment, as in :meth:`step`.
            * info (Dict[str, Any]): The info dict, as in :meth:`step`.

        :return: A :class:`gymnasium.Env` initial observation and info.
        :rtype: Tuple[ObsType, Dict[str, Any]]
        """
        super().reset(seed=seed)

        # reset all components
        self._reset_env(seed=seed)

        # return new observation
        obs, info = self._get_obs()

        # Save last state
        self._last_state = ObsState(
            observation=obs,
            info=info
        )

        if self.render_mode == 'human':
            self._renderer.render()

        return obs, info
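    # Illustrative sketch (not part of the original module): because ``reset()``
    # seeds Python's ``random`` module via ``_reset_env`` and forwards the seed
    # to ``gymnasium.Env.reset``, reseeding with the same value reproduces the
    # same episode start (including the random start offset, if enabled),
    # assuming the configured components are themselves deterministic:
    #
    #     obs_a, _ = env.reset(seed=7)
    #     obs_b, _ = env.reset(seed=7)
    #     # obs_a and obs_b describe the same initial state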
    def render(self) -> Optional[Union[RenderFrame, List[RenderFrame]]]:
        """Renders the environment according to the :class:`gymnasium.Env` specification.

        :return: A `RenderFrame` or a list of `RenderFrame` instances.
        :rtype: Optional[Union[RenderFrame, List[RenderFrame]]]
        """
        if self._renderer is not None:
            return self._renderer.render()
    def plot(self, **kwargs) -> None:
        """Plots the environment using the configured plotter, if any."""
        if self._plotter is not None:
            self._plotter.render()
    def close(self) -> None:
        """Closes the environment and releases plotter resources."""
        if self._plotter is not None:
            self._plotter.close()
    def _get_obs(self) -> Tuple[ObsType, Dict[str, Any]]:
        obs = self._observer.observe()

        if self._informer is not None:
            info = self._informer.info()
        else:
            info = {}

        return obs, info

    def _reset_env(self, seed: Optional[int] = None) -> None:
        if seed is not None:
            random.seed(seed)

        if self.random_start_pct > 0:
            # offset the episode start by a random number of steps, bounded by
            # random_start_pct of the feed length
            random_start = random.randint(0, int(self.random_start_pct * self.feed.features_len))
        else:
            random_start = 0

        # reset env state
        self.episode_id = str(uuid.uuid4())
        self.n_episode += 1

        self._clock.reset()
        self._portfolio.reset()
        self._broker.reset()
        self._feed.reset(random_start=random_start)

        # reset component state
        self._action_scheme.reset()
        self._observer.reset()
        self._reward_scheme.reset()

        if self._stopper is not None:
            self._stopper.reset()

        if self._informer is not None:
            self._informer.reset()

        if self._renderer is not None:
            self._renderer.reset()

        if self._plotter is not None:
            self._plotter.reset()
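# Illustrative end-to-end sketch (not part of the original module). The
# concrete component classes and their constructors are assumptions here;
# substitute the AbstractActionScheme/AbstractRewardScheme/AbstractObserver
# implementations available in your installation.
#
#     from tensortrade.feed import DataFeed
#     from tensortrade.oms.wallets import Portfolio
#
#     portfolio = Portfolio(...)           # wallets and base instrument
#     feed = DataFeed(...)                 # streams of market features
#     env = TradingEnv(
#         portfolio=portfolio,
#         feed=feed,
#         action_scheme=my_action_scheme,  # an AbstractActionScheme
#         reward_scheme=my_reward_scheme,  # an AbstractRewardScheme
#         observer=my_observer,            # an AbstractObserver
#         random_start_pct=0.10,           # e.g. feed of 1000 steps -> start offset in [0, 100]
#     )
#
#     obs, info = env.reset(seed=0)
#     obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
#     env.plot()   # render the configured plotter(s), if any
#     env.close()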