Source code for tensortrade.env.utils.feed_controller

# Copyright 2024 The TensorTrade-NG Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from __future__ import annotations

import typing

import pandas as pd

from dataclasses import dataclass

from warnings import warn

from pandas.core.interchange.dataframe_protocol import DataFrame

from tensortrade.core import Observable, TimeIndexed
from tensortrade.feed import Stream, DataFeed
from tensortrade.feed.core.base import IterableStream, Group, NameSpace

if typing.TYPE_CHECKING:
    from pandas import DataFrame

    from typing import Any, Dict, List, Optional

    from tensortrade.oms.wallets import Portfolio, Wallet


@dataclass
class State:
    """The actual state of the environment.

    :param features: The features at this point in time.
    :type features: Dict[str, Any]
    :param meta: The metadata at this point in time. Could be None.
    :type meta: Optional[Dict[str, Any]]
    :param portfolio: The portfolio data at this point in time.
    :type portfolio: Dict[str, Any]
    :param step: The step counter value recorded when this state was created.
    :type step: int
    """
    features: Dict[str, Any]
    meta: Optional[Dict[str, Any]]
    portfolio: Dict[str, Any]
    step: int
class FeedController(Observable, TimeIndexed):
    """This class is responsible for controlling the global feed in a :class:`TradingEnv`.

    The :class:`FeedController` prepares the feed on initialization and provides access to the
    features, metadata and portfolio data. It has to be reset when the environment is reset and
    for new data ``next()`` has to be executed.

    .. note::
        The feed needs to be a group of streams compiled together to a :class:`DataFeed`.
        It should consist of:

        * features (Group): The features shown to the environment as observation for learning.
          This data should be normalized and prepared for the environment. If it's missing an
          :class:`AttributeError` is raised.
        * meta (Group): The metadata used by the components, like plotters. This contains data
          like raw ohlcv data. It can be omitted but this will display a warning.

    :param feed: The feed to use.
    :type feed: DataFeed
    :param portfolio: The portfolio to fetch data from.
    :type portfolio: Portfolio
    """

    def __init__(
            self,
            feed: DataFeed,
            portfolio: Portfolio
    ):
        super().__init__()

        # One metadata dict per emitted state; cleared again on ``reset()``.
        self._meta_history: List[Dict[str, Any]] = []

        self._prepare_feed(feed, portfolio)
        # Pull the first data point so that ``state`` is available right after construction.
        self._update_data()

    @property
    def state(self) -> State:
        """Gets the state at this point in time.

        :return: The state.
        :rtype: State
        """
        return self._state

    @property
    def meta_history(self) -> DataFrame:
        """Gets the metadata history of this episode.

        A new :class:`pandas.DataFrame` is built from the collected metadata on every access.

        :return: The metadata history.
        :rtype: DataFrame
        """
        return pd.DataFrame(self._meta_history)

    @property
    def features_len(self) -> int:
        """Gets the number of data points available in the features feed.

        This is the length of the iterable behind the first feature stream.

        :return: The number of available data points.
        :rtype: int
        """
        return self._feature_len

    @property
    def features_size(self) -> int:
        """Gets the number of features per state.

        :return: The number of features per state.
        :rtype: int
        """
        return self._feature_size

    def has_next(self) -> bool:
        """Checks if there is more data available.

        :return: Whether there is new data available.
        :rtype: bool
        """
        return self._feed.has_next()

    def next(self) -> None:
        """Advances the feed by one step and updates the current state."""
        self._update_data()

    def reset(self, random_start: int = 0) -> None:
        """Resets the feed, clears the metadata history and gets the first data.

        :param random_start: Forwarded unchanged to the ``reset`` of the underlying feed.
        :type random_start: int
        """
        self._meta_history = []
        self._feed.reset(random_start=random_start)
        self._update_data()

    def _update_data(self) -> None:
        """Fetches the next data point, rebuilds the state and notifies all listeners."""
        data = self._feed.next()
        meta = data.get('meta')

        self._state = State(
            features=data.get('features'),
            meta=meta,
            portfolio=data.get('portfolio'),
            step=self.clock.step
        )

        # Only record metadata when the feed actually provides a meta group.
        if meta is not None:
            self._meta_history.append(meta)

        for listener in self.listeners:
            listener.on_next(self._state)

    @staticmethod
    def _check_iterable_streams(group: Group) -> None:
        """Ensures that every input stream of a group is an :class:`IterableStream`.

        :param group: The group whose input streams are checked.
        :type group: Group
        :raises ValueError: If one of the input streams is not an :class:`IterableStream`.
        """
        for stream in group.inputs:
            if not isinstance(stream, IterableStream):
                raise ValueError('Environment only supports IterableStreams.')

    def _prepare_feed(self, input_feed: DataFeed, portfolio: Portfolio) -> None:
        """Prepares the feed to be used by the environment and all other components.

        .. note::
            The DataFeed is a group of streams compiled together to be used by the environment.
            It consists of:

            * features (Group): The features shown to the environment as observation for
              learning. This data should be normalized and prepared for the environment.
            * meta (Group): The metadata used by the components, like plotters. This contains
              data like raw OHLCV data.
            * portfolio (Group): This contains data about the portfolio and will be created by
              the environment itself.

        :param input_feed: The input feed to be used by the environment. Must contain at least
            a features group of streams.
        :type input_feed: DataFeed
        :param portfolio: The portfolio the portfolio streams are created from. It is also
            attached to this controller as a listener.
        :type portfolio: Portfolio
        :raises AttributeError: If an attribute access fails while selecting and inspecting the
            features group (reported as a missing features feed).
        :raises ValueError: If a features or meta stream is not an :class:`IterableStream`.
        """
        feed = []

        # select features feed and do some checks
        try:
            features_feed = Stream.select(
                input_feed.inputs,
                lambda s: s.name == 'features' and isinstance(s, Group)
            )
            self._check_iterable_streams(features_feed)

            self._feature_len = len(features_feed.inputs[0].iterable)
            self._feature_size = len(features_feed.inputs)

            # display a warning when the user selects too many features.
            if self._feature_size > 20:
                warn('Your feature set contains more than 20 features. This may introduce noise and lead to overfitting, '
                     'potentially reducing your model\'s effectiveness. Consider reducing the number of features to improve '
                     'performance and accuracy.', UserWarning)

            feed += [features_feed]
        except AttributeError:
            raise AttributeError('Feed has no features feed.')

        # select meta feed and do checks
        try:
            meta_feed = Stream.select(
                input_feed.inputs,
                lambda s: s.name == 'meta' and isinstance(s, Group)
            )
            # Validate the meta group itself. Checking the features group a second time here
            # would let an invalid or missing meta group slip into the compiled feed unnoticed.
            self._check_iterable_streams(meta_feed)

            feed += [meta_feed]
        except AttributeError:
            warn('Feed has no meta feed. Therefor some components may not work.', UserWarning)

        # add portfolio
        feed += [Stream.group(self.create_portfolio_streams(portfolio)).rename('portfolio')]

        self._feed = DataFeed(feed)
        self._feed.compile()

        self.attach(portfolio)

    @staticmethod
    def create_wallet_source(wallet: Wallet, include_worth: bool = True) -> List[Stream[float]]:
        """Creates a list of streams to describe a :class:`Wallet`.

        The streams are created inside the namespace ``<exchange name>:/<symbol>`` and are
        named ``free``, ``locked``, ``total`` and, optionally, ``worth``.

        :param wallet: The wallet to create the streams for.
        :type wallet: Wallet
        :param include_worth: Whether to include the worth of the wallet.
        :type include_worth: bool
        :return: A list of streams to describe a :class:`Wallet`.
        :rtype: List[Stream[float]]
        """
        exchange_name = wallet.exchange.name
        symbol = wallet.instrument.symbol

        streams = []

        with NameSpace(exchange_name + ":/" + symbol):
            free_balance = Stream.sensor(wallet, lambda w: w.balance.as_float(), dtype="float").rename("free")
            locked_balance = Stream.sensor(wallet, lambda w: w.locked_balance.as_float(), dtype="float").rename(
                "locked")
            total_balance = Stream.sensor(wallet, lambda w: w.total_balance.as_float(), dtype="float").rename("total")

            streams += [free_balance, locked_balance, total_balance]

            if include_worth:
                # The worth is the total balance multiplied by the exchange stream whose name
                # ends with the wallet's instrument symbol.
                price = Stream.select(wallet.exchange.streams(), lambda node: node.name.endswith(symbol))
                worth = price.mul(total_balance).rename('worth')
                streams += [worth]

        return streams

    @staticmethod
    def create_portfolio_streams(portfolio: Portfolio) -> List[Stream[float]]:
        """Creates a list of streams to describe a :class:`Portfolio`.

        For every wallet the streams of its exchange and the wallet streams are collected.
        Finally a ``net_worth`` stream is appended which sums up the total balance of the base
        instrument wallets and the worth of all other wallets.

        :param portfolio: The portfolio to create the streams for.
        :type portfolio: Portfolio
        :return: A list of streams to describe a :class:`Portfolio`.
        :rtype: List[Stream[float]]
        """
        base_symbol = portfolio.base_instrument.symbol
        sources = []

        for wallet in portfolio.wallets:
            symbol = wallet.instrument.symbol
            sources += wallet.exchange.streams()
            # Wallets holding the base instrument need no worth stream, their total balance
            # already is their worth.
            sources += FeedController.create_wallet_source(wallet, include_worth=(symbol != base_symbol))

        worth_suffixes = (base_symbol + ":/total", "worth")
        worth_streams = [s for s in sources if s.name.endswith(worth_suffixes)]

        net_worth = Stream.reduce(worth_streams).sum().rename("net_worth")
        sources += [net_worth]

        return sources