Source code for tensortrade.env.plotters.plotly_trading_chart

# Copyright 2024 The TensorTrade and TensorTrade-NG Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from __future__ import annotations

import os
import typing

import pandas as pd

import plotly.graph_objects as go
from IPython.display import display, clear_output
from plotly.subplots import make_subplots

from tensortrade.env.plotters.abstract import AbstractPlotter
from tensortrade.env.plotters.utils import check_valid_format, check_path, create_auto_file_name

if typing.TYPE_CHECKING:
    from typing import Tuple, Union

    from collections import OrderedDict

class PlotlyTradingChart(AbstractPlotter):
    """Trading visualization for TensorTrade using Plotly.

    Parameters
    ----------
    display : bool
        True to display the chart on the screen, False for not.
    height : int
        Chart height in pixels. Affects both display and saved file
        charts. Set to None for 100% height. Default is None.
    timestamp_format : str
        The format of the date shown in the chart title.
    save_format : str
        A format to save the chart to. Acceptable formats are
        html, png, jpeg, webp, svg, pdf, eps.
        All the formats except for 'html' require Orca.
        Default is None for no saving.
    path : str
        The path to save the chart to if save_format is not None.
        The folder (including missing parent folders) will be created
        if not found.
    filename_prefix : str
        A string that precedes automatically-created file name
        when charts are saved. Default 'chart_'.
    auto_open_html : bool
        Works for save_format='html' only. True to automatically
        open the saved chart HTML file in the default browser, False otherwise.
    include_plotlyjs : Union[bool, str]
        Whether to include/load the plotly.js library in the saved
        file. 'cdn' results in a smaller file by loading the library online but
        requires an Internet connection while True includes the library
        resulting in much larger file sizes. False to not include the library.
        For more details, refer to
        https://plot.ly/python-api-reference/generated/plotly.graph_objects.Figure.html

    Notes
    -----
    Possible Future Enhancements:
        - Saving images without using Orca.
        - Limit displayed step range for the case of a large number of steps
          and let the shown part of the chart slide after filling that range
          to keep showing recent data as it's being added.

    References
    ----------
    .. [1] https://plot.ly/python-api-reference/generated/plotly.graph_objects.Figure.html
    .. [2] https://plot.ly/python/figurewidget/
    .. [3] https://plot.ly/python/subplots/
    .. [4] https://plot.ly/python/reference/#candlestick
    .. [5] https://plot.ly/python/#chart-events
    """

    registered_name = "plotly_trading_chart"

    def __init__(self,
                 display: bool = True,
                 height: typing.Optional[int] = None,
                 timestamp_format: str = '%Y-%m-%d %H:%M:%S',
                 save_format: typing.Optional[str] = None,
                 path: str = 'charts',
                 filename_prefix: str = 'chart_',
                 auto_open_html: bool = False,
                 include_plotlyjs: Union[bool, str] = 'cdn') -> None:
        """Stores the chart settings and prepares the output folder.

        The figure itself is built lazily on the first call to `render`.
        """
        super().__init__()

        self._height = height
        self._timestamp_format = timestamp_format
        self._save_format = save_format
        self._path = path
        self._filename_prefix = filename_prefix
        self._include_plotlyjs = include_plotlyjs
        self._auto_open_html = auto_open_html

        # Only touch the file system when saving was requested. `makedirs`
        # also creates missing parents (e.g. 'charts/run1') and `exist_ok`
        # tolerates the folder appearing between the check and the call.
        if self._save_format and self._path and not os.path.exists(path):
            os.makedirs(path, exist_ok=True)

        self.fig = None
        self._price_chart = None
        self._volume_chart = None
        self._performance_chart = None
        self._net_worth_chart = None
        self._base_annotations = None
        self._last_trade_step = 0
        self._show_chart = display

    def _create_figure(self, performance_keys: typing.Iterable[str]) -> None:
        """Builds the four-row figure widget and caches its main traces.

        Rows, top to bottom: price candlesticks, volume bars, one line per
        performance key, and the net worth line.

        Parameters
        ----------
        performance_keys : iterable of str
            Names of the performance series; one line trace is added to the
            third row for each of them.
        """
        fig = make_subplots(
            rows=4, cols=1, shared_xaxes=True, vertical_spacing=0.03,
            row_heights=[0.55, 0.15, 0.15, 0.15],
        )

        fig.add_trace(go.Candlestick(name='Price', xaxis='x1', yaxis='y1',
                                     showlegend=False), row=1, col=1)
        fig.update_layout(xaxis_rangeslider_visible=False)

        fig.add_trace(go.Bar(name='Volume', showlegend=False,
                             marker={'color': 'DodgerBlue'}),
                      row=2, col=1)

        for k in performance_keys:
            fig.add_trace(go.Scatter(mode='lines', name=k), row=3, col=1)

        fig.add_trace(go.Scatter(mode='lines', name='Net Worth',
                                 marker={'color': 'DarkGreen'}),
                      row=4, col=1)

        fig.update_xaxes(linecolor='Grey', gridcolor='Gainsboro')
        fig.update_yaxes(linecolor='Grey', gridcolor='Gainsboro')
        fig.update_xaxes(title_text='Price', row=1)
        fig.update_xaxes(title_text='Volume', row=2)
        fig.update_xaxes(title_text='Performance', row=3)
        fig.update_xaxes(title_text='Net Worth', row=4)
        fig.update_xaxes(title_standoff=7, title_font=dict(size=12))

        self.fig = go.FigureWidget(fig)
        # Traces are stored in insertion order: price, volume, the
        # performance lines, and finally net worth as the last trace.
        self._price_chart = self.fig.data[0]
        self._volume_chart = self.fig.data[1]
        self._performance_chart = self.fig.data[2]
        self._net_worth_chart = self.fig.data[-1]

        self.fig.update_annotations({'font': {'size': 12}})
        self.fig.update_layout(template='plotly_white', height=self._height, margin=dict(t=50))
        # Remembered so `reset` can drop the trade arrows added by `render`.
        self._base_annotations = self.fig.layout.annotations

    def _create_trade_annotations(self,
                                  trades: OrderedDict,
                                  price_history: pd.DataFrame) -> Tuple[go.layout.Annotation]:
        """Creates annotations of the new trades after the last one in the chart.

        Parameters
        ----------
        trades : `OrderedDict`
            The history of trades for the current episode.
        price_history : `pd.DataFrame`
            The price history of the current episode.

        Returns
        -------
        `Tuple[go.layout.Annotation]`
            A tuple of annotations used in the rendering process.

        Raises
        ------
        ValueError
            If a trade side is neither 'buy' nor 'sell'.
        """
        annotations = []

        # Walk from the newest trade backwards and stop at the first one that
        # has already been drawn; only the first trade of each entry is used.
        for trade_group in reversed(trades.values()):
            trade = trade_group[0]
            if trade.step <= self._last_trade_step:
                break

            tp = float(trade.price)
            ts = float(trade.size)
            side = trade.side.value

            if side == 'buy':
                color = 'DarkGreen'
                ay = 15
                qty = round(ts / tp, trade.quote_instrument.precision)
                size = ts
            elif side == 'sell':
                color = 'FireBrick'
                ay = -15
                qty = ts
                size = round(ts * tp, trade.base_instrument.precision)
            else:
                raise ValueError(f"Valid trade side values are 'buy' and 'sell'. Found '{side}'.")

            text_info = dict(
                step=trade.step,
                datetime=price_history.iloc[trade.step - 1]['date'],
                side=side.upper(),
                qty=qty,
                size=size,
                quote_instrument=trade.quote_instrument,
                price=tp,
                base_instrument=trade.base_instrument,
                type=trade.type.value.upper(),
                commission=trade.commission
            )

            hovertext = 'Step {step} [{datetime}]<br>' \
                        '{side} {qty} {quote_instrument} @ {price} {base_instrument} {type}<br>' \
                        'Total: {size} {base_instrument} - Comm.: {commission}'.format(**text_info)

            annotations.append(go.layout.Annotation(
                x=trade.step - 1, y=tp,
                ax=0, ay=ay, xref='x1', yref='y1', showarrow=True,
                arrowhead=2, arrowcolor=color, arrowwidth=4,
                arrowsize=0.8, hovertext=hovertext, opacity=0.6,
                hoverlabel=dict(bgcolor=color)
            ))

        if trades:
            # The last entry holds the most recent trade.
            self._last_trade_step = next(reversed(trades.values()))[0].step

        return tuple(annotations)

    def render(self) -> None:
        """Refreshes the chart from the current state of the trading environment.

        Reads the price history, portfolio performance and broker trades from
        `self.trading_env`, creates the figure on first use, updates every
        trace, and then shows and/or saves the chart depending on the
        settings given at construction time.
        """
        # get price history
        price_history = self.trading_env.feed.meta_history

        # get performance data
        performance_df = pd.DataFrame.from_dict(self.trading_env.portfolio.performance, orient='index')
        net_worth = performance_df['net_worth']
        performance = performance_df.drop(columns=['base_symbol'])

        trades = self.trading_env.broker.trades

        if not self.fig:
            self._create_figure(performance.keys())

        if self._show_chart:  # ensure chart visibility through notebook cell reruns
            display(self.fig)

        self.fig.layout.title = 'Profit chart'
        self._price_chart.update(dict(
            open=price_history['open'],
            high=price_history['high'],
            low=price_history['low'],
            close=price_history['close']
        ))
        self.fig.layout.annotations += self._create_trade_annotations(trades, price_history)

        self._volume_chart.update({'y': price_history['volume']})

        # Each trace in the performance row is named after its column.
        for trace in self.fig.select_traces(row=3):
            trace.update({'y': performance[trace.name]})

        self._net_worth_chart.update({'y': net_worth})

        if self._show_chart:
            self.fig.show()

        if self._save_format:
            self.save()

    def save(self) -> None:
        """Saves the current chart to a file.

        Does nothing when no `save_format` was configured. HTML output honours
        the `include_plotlyjs` and `auto_open_html` settings given at
        construction time.

        Notes
        -----
        All formats other than HTML require Orca installed and server running.
        """
        if not self._save_format:
            return

        valid_formats = ['html', 'png', 'jpeg', 'webp', 'svg', 'pdf', 'eps']
        check_valid_format(valid_formats, self._save_format)

        check_path(self._path)

        filename = create_auto_file_name(self._filename_prefix, self._save_format)
        filename = os.path.join(self._path, filename)
        if self._save_format == 'html':
            # Pass the configured value through instead of a fixed 'cdn', so
            # the constructor argument actually takes effect.
            self.fig.write_html(file=filename,
                                include_plotlyjs=self._include_plotlyjs,
                                auto_open=self._auto_open_html)
        else:
            self.fig.write_image(filename)

    def reset(self) -> None:
        """Clears trade annotations and notebook output for a new episode.

        The last drawn trade step is reset to zero. If a figure has already
        been created, its annotations are restored to the ones captured when
        the figure was built and the current cell output is cleared.
        """
        self._last_trade_step = 0
        if self.fig is None:
            return

        self.fig.layout.annotations = self._base_annotations
        clear_output(wait=True)