# Source code for tensortrade.env.plotters.matplotlib_trading_chart

# Copyright 2024 The TensorTrade and TensorTrade-NG Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import os
import sys
import typing

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib import style
from pandas.plotting import register_matplotlib_converters

from tensortrade.env.plotters.abstract import AbstractPlotter
from tensortrade.env.plotters.utils import create_auto_file_name, check_path, check_valid_format
from tensortrade.oms.orders import TradeSide

# Let matplotlib place pandas datetime-like values on an axis, and give
# every chart produced by this module the "ggplot" look.
register_matplotlib_converters()
plt.style.use("ggplot")

if typing.TYPE_CHECKING:
    import pandas as pd

    from collections import OrderedDict


class MatplotlibTradingChart(AbstractPlotter):
    """Trading visualization for TensorTrade using Matplotlib.

    The figure has a net-worth panel on top and a price panel below it.
    The price panel carries a twin y-axis for traded volume. Executed
    trades are marked with arrows: red for sells, green for every other
    trade side.

    Parameters
    ----------
    display : bool
        True to display the chart on the screen, False for not.
    save_format : str, optional
        A format to save the chart to. Acceptable formats are png, jpg,
        jpeg, svg and pdf. ``None`` disables saving.
    path : str
        The path to save the chart to if `save_format` is not None. The
        folder (including missing parents) will be created if not found.
    filename_prefix : str
        A string that precedes automatically-created file names when charts
        are saved. Default 'chart_'.
    window_size : int
        Number of most recent steps drawn on each render. Default 20.

    Raises
    ------
    ValueError
        If `window_size` is smaller than 1.
    """

    registered_name = "matplotlib_trading_chart"

    def __init__(self,
                 display: bool = True,
                 save_format: str | None = None,
                 path: str = 'charts',
                 filename_prefix: str = 'chart_',
                 window_size: int = 20) -> None:
        super().__init__()
        if window_size < 1:
            raise ValueError("window_size must be at least 1")

        # The volume peak is drawn at this fraction of the price panel height,
        # and the price axis is padded downwards by the same fraction.
        self._volume_chart_height = 0.33
        self._window_size = window_size

        self._df = None
        self.fig = None
        self.price_ax = None
        self.volume_ax = None
        self.net_worth_ax = None

        self._show_chart = display
        self._save_format = save_format
        self._path = path
        self._filename_prefix = filename_prefix

        if self._save_format and self._path:
            # Unlike os.mkdir, this creates missing parent folders and
            # tolerates a folder that already exists.
            os.makedirs(self._path, exist_ok=True)

    def _create_figure(self) -> None:
        """Create the figure: a net-worth panel above a price/volume panel."""
        self.fig = plt.figure()
        # 6-row grid: rows 0-1 hold the net worth, rows 2-5 the price chart.
        self.net_worth_ax = plt.subplot2grid((6, 1), (0, 0), rowspan=2, colspan=1)
        self.price_ax = plt.subplot2grid((6, 1), (2, 0), rowspan=4, colspan=1,
                                         sharex=self.net_worth_ax)
        # Volume shares the price x-axis but has its own y scale.
        self.volume_ax = self.price_ax.twinx()
        self.fig.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90,
                                 wspace=0.2, hspace=0)

    def _render_trades(self, step_range: slice, trades: OrderedDict) -> None:
        """Mark every trade whose step falls inside `step_range` on the price axis.

        Parameters
        ----------
        step_range : slice
            Window of steps currently displayed.
        trades : OrderedDict
            Mapping whose values are lists of trades; each trade is read for
            its ``step`` and ``side`` attributes.
        """
        # Resolve the slice into concrete step numbers once, outside the loop.
        visible_steps = range(sys.maxsize)[step_range]
        dates = self._df.index.values
        closes = self._df['close'].values
        for sublist in trades.values():
            for trade in sublist:
                if trade.step not in visible_steps:
                    continue
                date = dates[trade.step]
                close = closes[trade.step]
                color = 'red' if trade.side is TradeSide.SELL else 'green'
                self.price_ax.annotate(' ', (date, close), xytext=(date, close), size="large",
                                       arrowprops=dict(arrowstyle='simple', facecolor=color))

    def _render_volume(self, step_range: slice, times: np.ndarray) -> None:
        """Draw the traded volume for `step_range` on the twin volume axis."""
        self.volume_ax.clear()
        volume = np.asarray(self._df['volume'].values[step_range])
        self.volume_ax.plot(times, volume, color='blue')
        self.volume_ax.fill_between(times, volume, color='blue', alpha=0.5)
        peak = np.nanmax(volume)
        # An all-zero (or all-NaN) window would give identical/invalid limits;
        # leave matplotlib's autoscaling in place in that case.
        if np.isfinite(peak) and peak > 0:
            self.volume_ax.set_ylim(0, peak / self._volume_chart_height)
        self.volume_ax.yaxis.set_ticks([])

    def _render_price(self, step_range: slice, times: np.ndarray, current_step: int) -> None:
        """Draw the close price for `step_range` and label the close at `current_step`."""
        self.price_ax.clear()
        closes = self._df['close'].values
        self.price_ax.plot(times, closes[step_range], color="black")

        # NOTE(review): `step_range` stops before `current_step`, so this label
        # sits one step past the end of the plotted line — confirm intended.
        last_time = self._df.index.values[current_step]
        last_close = closes[current_step]
        last_high = self._df['high'].values[current_step]
        self.price_ax.annotate('{0:.2f}'.format(last_close), (last_time, last_close),
                               xytext=(last_time, last_high),
                               bbox=dict(boxstyle='round', fc='w', ec='k', lw=1),
                               color="black", fontsize="small")

        # Pad the bottom of the axis to leave room for the volume series
        # drawn on the twin axis.
        low, high = self.price_ax.get_ylim()
        self.price_ax.set_ylim(low - (high - low) * self._volume_chart_height, high)

    def _render_net_worth(self, step_range: slice, times: np.ndarray, current_step: int,
                          net_worths) -> None:
        """Draw the net-worth curve for `step_range` and label its last value.

        `net_worths` is indexed positionally, whatever its index labels are.
        The y-limits are derived from the full history (not only the window)
        so the scale stays stable between renders.
        """
        self.net_worth_ax.clear()
        values = np.asarray(net_worths)
        window = values[step_range]
        self.net_worth_ax.plot(times, window, label='Net Worth', color="g")
        legend = self.net_worth_ax.legend(loc=2, ncol=2, prop={'size': 8})
        legend.get_frame().set_alpha(0.4)

        last_time = times[-1]
        last_net_worth = window[-1]
        self.net_worth_ax.annotate('{0:.2f}'.format(last_net_worth), (last_time, last_net_worth),
                                   xytext=(last_time, last_net_worth),
                                   bbox=dict(boxstyle='round', fc='w', ec='k', lw=1),
                                   color="black", fontsize="small")

        self.net_worth_ax.set_ylim(np.nanmin(values) / 1.25, np.nanmax(values) * 1.25)

    def render(self) -> None:
        """Draw the most recent window of the environment onto the chart.

        Reads the price history from ``trading_env.feed.meta_history``, the
        per-step performance records from ``trading_env.portfolio.performance``
        and the executed trades from ``trading_env.broker.trades``. Does
        nothing if no performance has been recorded yet.
        """
        # NOTE(review): the helpers index this by 'close', 'high' and 'volume'
        # columns and by `.index`, so it is assumed to be a DataFrame — confirm.
        price_history = self.trading_env.feed.meta_history

        performance_df = pd.DataFrame.from_dict(self.trading_env.portfolio.performance,
                                                orient='index')
        if performance_df.empty:
            return
        net_worth = performance_df['net_worth']
        trades = self.trading_env.broker.trades

        if self.fig is None:
            self._create_figure()
            if self._show_chart:
                plt.show(block=False)

        # Clamp at 0: with the clock still at step 0 a bound of -1 would make
        # the slice below count from the end of the history.
        current_step = max(self.trading_env.clock.step - 1, 0)
        self._df = price_history

        # Positional access: the performance index labels need not start at 0.
        initial_net_worth = float(net_worth.iloc[0])
        current_net_worth = float(net_worth.iloc[-1])
        if initial_net_worth:
            profit_percent = round((current_net_worth - initial_net_worth) / initial_net_worth * 100, 2)
        else:
            profit_percent = 0.0
        self.fig.suptitle('Net worth: $' + str(round(current_net_worth, 1)) +
                          ' | Profit: ' + str(profit_percent) + '%')

        window_start = max(current_step - self._window_size, 0)
        step_range = slice(window_start, current_step)
        times = self._df.index.values[step_range]

        if len(times) > 0:
            self._render_net_worth(step_range, times, current_step, net_worth)
            self._render_price(step_range, times, current_step)
            self._render_volume(step_range, times)
            self._render_trades(step_range, trades)
            # Rotate the labels matplotlib already placed; overwriting them with
            # `times` would detach the text from the actual tick positions.
            plt.setp(self.price_ax.get_xticklabels(), rotation=45, horizontalalignment='right')
            plt.setp(self.net_worth_ax.get_xticklabels(), visible=False)

        if self._show_chart:
            # plt.pause() also calls plt.show() on the active figure, so only
            # pump the GUI event loop when the chart is meant to be displayed.
            plt.pause(0.001)

    def save(self) -> None:
        """Save the current rendering of the `TradingEnv` under `path`.

        Does nothing when no `save_format` was configured or when nothing has
        been rendered yet. The format is passed to ``check_valid_format``
        (presumably raising for unsupported formats — confirm in utils).
        """
        if not self._save_format or self.fig is None:
            return
        # matplotlib accepts both spellings of JPEG; the class docs promise 'jpg'.
        valid_formats = ['png', 'jpg', 'jpeg', 'svg', 'pdf']
        check_valid_format(valid_formats, self._save_format)
        check_path(self._path)
        filename = create_auto_file_name(self._filename_prefix, self._save_format)
        filename = os.path.join(self._path, filename)
        self.fig.savefig(filename, format=self._save_format)

    def reset(self) -> None:
        """Reset the renderer, closing the figure created by a previous episode."""
        if self.fig is not None:
            # pyplot keeps every figure alive until it is closed explicitly;
            # without this, each reset leaks one figure.
            plt.close(self.fig)
        self.fig = None
        self.price_ax = None
        self.volume_ax = None
        self.net_worth_ax = None
        self._df = None