# Copyright 2024 The TensorTrade and TensorTrade-NG Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from __future__ import annotations

import os
import sys
import typing

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib import style
from pandas.plotting import register_matplotlib_converters

from tensortrade.env.plotters.abstract import AbstractPlotter
from tensortrade.env.plotters.utils import create_auto_file_name, check_path, check_valid_format
from tensortrade.oms.orders import TradeSide
# Module-level plotting setup. Both calls mutate global matplotlib state for
# the whole process, not just for the charts created by this module.
style.use("ggplot")
# Registers pandas' converters in matplotlib's units registry so datetime-like
# index values can be plotted on the x axis.
register_matplotlib_converters()

# pandas is imported unconditionally at the top of the file because `render`
# needs it at runtime; only annotation-only imports belong in this guard.
if typing.TYPE_CHECKING:
    from collections import OrderedDict
class MatplotlibTradingChart(AbstractPlotter):
    """Trading visualization for TensorTrade using Matplotlib.

    The figure has a net-worth panel on top and a close-price panel below
    it, with traded volume overlaid on a twin y axis and buy/sell trades
    marked with arrows. Both panels share the x axis.

    Parameters
    ----------
    display : bool
        True to display the chart on the screen, False for not.
    save_format : str, optional
        A format to save the chart to. Acceptable formats are
        png, jpg, jpeg, svg, pdf.
    path : str
        The path to save the chart to if save_format is not None. The folder
        (including missing parent folders) will be created if not found.
    filename_prefix : str
        A string that precedes automatically-created file name
        when charts are saved. Default 'chart_'.
    window_size : int
        Number of steps before the current step that are drawn.
        Default 20.
    """

    registered_name = "matplotlib_trading_chart"

    def __init__(self,
                 display: bool = True,
                 save_format: typing.Optional[str] = None,
                 path: str = 'charts',
                 filename_prefix: str = 'chart_',
                 window_size: int = 20) -> None:
        super().__init__()
        # Fraction used both to scale the volume overlay and to extend the
        # price axis downward so the two do not overlap.
        self._volume_chart_height = 0.33
        self._window_size = window_size
        self._df = None
        self.fig = None
        self.price_ax = None
        self.volume_ax = None
        self.net_worth_ax = None
        self._show_chart = display
        self._save_format = save_format
        self._path = path
        self._filename_prefix = filename_prefix

        if self._save_format and self._path:
            # makedirs also creates missing parents and tolerates an existing folder.
            os.makedirs(path, exist_ok=True)

    def _create_figure(self) -> None:
        """Create the figure: net worth in grid rows 0-1, price/volume in rows 2-5."""
        self.fig = plt.figure()
        # Pass fig explicitly so the axes land on this figure even if another
        # pyplot figure has become "current" in the meantime.
        self.net_worth_ax = plt.subplot2grid((6, 1), (0, 0), rowspan=2, colspan=1,
                                             fig=self.fig)
        # rowspan=4 fills the remaining rows of the 6-row grid (a larger span
        # is clamped to this by matplotlib anyway).
        self.price_ax = plt.subplot2grid((6, 1), (2, 0), rowspan=4, colspan=1,
                                         sharex=self.net_worth_ax, fig=self.fig)
        # Volume gets its own y scale but shares the price panel's x axis.
        self.volume_ax = self.price_ax.twinx()
        self.fig.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90,
                                 wspace=0.2, hspace=0)

    def _render_trades(self, step_range, trades) -> None:
        """Mark trades that fall inside the plotted window on the price panel.

        Parameters
        ----------
        step_range : slice
            Positional window of ``self._df`` that is currently plotted.
        trades : Mapping
            Mapping whose values are sequences of trades (as read from
            ``broker.trades``); each trade exposes ``step`` and ``side``.
        """
        # Exactly the row positions that were plotted, so trade.step is always
        # a valid position in self._df when it passes the membership test.
        visible_steps = range(len(self._df))[step_range]
        dates = self._df.index.values
        closes = self._df['close'].values

        for trade in (t for group in trades.values() for t in group):
            if trade.step not in visible_steps:
                continue
            date = dates[trade.step]
            close = closes[trade.step]
            color = 'red' if trade.side is TradeSide.SELL else 'green'
            self.price_ax.annotate(' ', (date, close),
                                   xytext=(date, close),
                                   size="large",
                                   arrowprops=dict(arrowstyle='simple', facecolor=color))

    def _render_volume(self, step_range, times) -> None:
        """Draw the volume of the window as a filled area on the twin axis."""
        self.volume_ax.clear()
        volume = np.asarray(self._df['volume'].values[step_range])
        self.volume_ax.plot(times, volume, color='blue')
        self.volume_ax.fill_between(times, volume, color='blue', alpha=0.5)

        peak = np.nanmax(volume)
        # Scaling by the fraction keeps the tallest bar in the bottom part of
        # the panel. All-zero (or all-NaN) volume would give identical/NaN
        # limits, so fall back to a unit axis in that case.
        top = peak / self._volume_chart_height if peak > 0 else 1.0
        self.volume_ax.set_ylim(0, top)
        self.volume_ax.yaxis.set_ticks([])

    def _render_price(self, step_range, times, current_step) -> None:
        """Plot close prices for the window and label the close at ``current_step``."""
        self.price_ax.clear()
        closes = self._df['close'].values
        self.price_ax.plot(times, closes[step_range], color="black")

        # NOTE(review): step_range stops before current_step, so this label
        # sits one step past the end of the plotted line - confirm intended.
        last_time = self._df.index.values[current_step]
        last_close = closes[current_step]
        last_high = self._df['high'].values[current_step]
        self.price_ax.annotate(f'{last_close:.2f}', (last_time, last_close),
                               xytext=(last_time, last_high),
                               bbox=dict(boxstyle='round', fc='w', ec='k', lw=1),
                               color="black",
                               fontsize="small")

        # Extend the bottom by the volume fraction to leave room for the
        # volume overlay below the price line.
        low, high = self.price_ax.get_ylim()
        self.price_ax.set_ylim(low - (high - low) * self._volume_chart_height, high)

    def _render_net_worth(self, step_range, times, current_step, net_worths) -> None:
        """Plot net worth over the window and label its last value.

        ``net_worths`` is indexed positionally with ``step_range``; the y
        limits are derived from the whole series, not just the window.
        ``current_step`` is accepted for signature compatibility and unused.
        """
        self.net_worth_ax.clear()
        all_values = np.asarray(net_worths)
        window = all_values[step_range]
        self.net_worth_ax.plot(times, window, label='Net Worth', color="g")

        legend = self.net_worth_ax.legend(loc=2, ncol=2, prop={'size': 8})
        legend.get_frame().set_alpha(0.4)

        last_time = times[-1]
        last_net_worth = window[-1]
        self.net_worth_ax.annotate(f'{last_net_worth:.2f}', (last_time, last_net_worth),
                                   xytext=(last_time, last_net_worth),
                                   bbox=dict(boxstyle='round', fc='w', ec='k', lw=1),
                                   color="black",
                                   fontsize="small")

        # nan-aware so a single missing value does not produce NaN limits.
        self.net_worth_ax.set_ylim(np.nanmin(all_values) / 1.25,
                                   np.nanmax(all_values) * 1.25)

    def render(self) -> None:
        """Redraw the chart for the current state of ``trading_env``.

        Reads prices from ``trading_env.feed.meta_history``, net worth from
        ``trading_env.portfolio.performance`` and trades from
        ``trading_env.broker.trades``, then draws the ``window_size`` steps
        preceding the current clock step. Does nothing while no performance
        has been recorded.
        """
        price_history = self.trading_env.feed.meta_history
        performance_df = pd.DataFrame.from_dict(self.trading_env.portfolio.performance,
                                                orient='index')
        if performance_df.empty:
            # No 'net_worth' column exists yet, so there is nothing to plot.
            return
        net_worth = performance_df['net_worth']
        trades = self.trading_env.broker.trades

        if self.fig is None:
            self._create_figure()
            if self._show_chart:
                plt.show(block=False)

        current_step = self.trading_env.clock.step - 1
        self._df = price_history

        # Positional access: the performance index labels need not be 0..n-1.
        initial_net_worth = net_worth.iloc[0]
        latest_net_worth = net_worth.iloc[-1]
        if initial_net_worth:
            profit_text = str(round((latest_net_worth - initial_net_worth)
                                    / initial_net_worth * 100, 2)) + '%'
        else:
            # Profit relative to a zero starting net worth is undefined.
            profit_text = 'n/a'
        self.fig.suptitle('Net worth: $' + str(round(latest_net_worth, 1)) +
                          ' | Profit: ' + profit_text)

        window_start = max(current_step - self._window_size, 0)
        step_range = slice(window_start, current_step)
        times = self._df.index.values[step_range]

        if len(times) > 0:
            self._render_net_worth(step_range, times, current_step, net_worth)
            self._render_price(step_range, times, current_step)
            self._render_volume(step_range, times)
            self._render_trades(step_range, trades)
            # Rotate the locator's own labels; set_xticklabels(times) without
            # set_xticks would attach labels to unrelated tick positions.
            plt.setp(self.price_ax.get_xticklabels(), rotation=45,
                     horizontalalignment='right')
            plt.setp(self.net_worth_ax.get_xticklabels(), visible=False)

        if self._show_chart:
            # plt.pause also shows the active figure, so only call it when
            # on-screen display was requested.
            plt.pause(0.001)

    def save(self) -> None:
        """Saves the rendering of the `TradingEnv`.

        Does nothing when no ``save_format`` was configured or when no
        figure has been rendered yet. The format is validated before the
        figure is written to an auto-generated file name inside ``path``.
        """
        if not self._save_format:
            return
        valid_formats = ['png', 'jpg', 'jpeg', 'svg', 'pdf']
        check_valid_format(valid_formats, self._save_format)
        if self.fig is None:
            # Nothing has been rendered yet, so there is no figure to write.
            return
        check_path(self._path)
        filename = create_auto_file_name(self._filename_prefix, self._save_format)
        filename = os.path.join(self._path, filename)
        self.fig.savefig(filename, format=self._save_format)

    def reset(self) -> None:
        """Resets the renderer.

        Closes the current figure, if any, and drops all axes so the next
        ``render`` call builds a fresh figure.
        """
        if self.fig is not None:
            # Release the figure from pyplot's manager; otherwise every reset
            # leaves one more open figure behind.
            plt.close(self.fig)
        self.fig = None
        self.price_ax = None
        self.volume_ax = None
        self.net_worth_ax = None
        self._df = None