from datetime import time as dtime
from pythongo.base import BaseParams, BaseState, Field
from pythongo.classdef import KLineData, OrderData, TickData, TradeData
from pythongo.core import KLineStyleType
from pythongo.ui import BaseStrategy
from pythongo.utils import KLineGeneratorArb


class Params(BaseParams):
    """参数映射模型"""
    exchange: str = Field(default="", title="交易所代码")
    instrument_id: str = Field(default="", title="合约代码")
    initial_price: float = Field(default=0.0, title="初始价格")
    buy_grid_size: float = Field(default=5.0, title="买入网格大小")
    sell_grid_size: float = Field(default=5.0, title="卖出网格大小")
    kline_style: KLineStyleType = Field(default="M1", title="K 线周期")
    pay_up: int | float = Field(default=0, title="超价")
    max_buy_position: int = Field(default=10, title="最大买入持仓数量")
    max_sell_position: int = Field(default=10, title="最大卖出持仓数量")
    price_unit: float = Field(default=0.01, title="价格单位")


class State(BaseState):
    """状态映射模型"""
    long_position: int = Field(default=0, title="多单持仓数量")
    short_position: int = Field(default=0, title="空单持仓数量")


class GridTradingStrategy(BaseStrategy):
    """网格交易策略"""

    def __init__(self):
        super().__init__()
        self.params_map = Params()
        """参数表"""

        self.state_map = State()
        """状态表"""

        self.order_id = None
        """报单 ID"""

        # 初始化时获取持仓
        self.get_initial_positions()

    def get_initial_positions(self):
        """获取初始持仓"""
        # position = self.get_position(self.params_map.instrument_id)  # 使用 get_position 函数获取持仓
        # self.state_map.long_position = position['long_position']  # 假设 get_position 返回的字典包含 'long_position' 键
        # self.state_map.short_position = position['short_position']  # 假设 get_position 返回的字典包含'short_position' 键
        position = self.get_position(self.params_map.instrument_id)
        # 初始化品种持仓
        # position.net_position
        self.output(f'{position}')

    @property
    def main_indicator_data(self) -> dict[str, float]:
        """主图指标"""
        return {}

    def on_tick(self, tick: TickData) -> None:
        """收到行情 tick 推送"""
        super().on_tick(tick)

        current_price = tick.close  # 获取当前价格

        if current_price >= self.params_map.initial_price + self.params_map.sell_grid_size * self.params_map.price_unit:
            if self.state_map.short_position > 0 and self.state_map.short_position < self.params_map.max_sell_position:  # 检查卖出限制
                self.order_id = self.send_order(
                    exchange=self.params_map.exchange,
                    instrument_id=self.params_map.instrument_id,
                    volume=1,
                    price=current_price - self.params_map.pay_up,
                    order_direction="sell"
                )
                self.state_map.short_position -= 1
                self.output(f"卖出一手空单,当前价格: {current_price}")
            elif self.state_map.short_position >= self.params_map.max_sell_position:  # 达到最大卖出持仓,不再卖出
                self.output("已达到最大空单卖出持仓,无法继续卖出")

        elif current_price <= self.params_map.initial_price - self.params_map.buy_grid_size * self.params_map.price_unit:
            if self.state_map.long_position < 0 and self.state_map.long_position > -self.params_map.max_buy_position:  # 检查买入限制
                self.order_id = self.send_order(
                    exchange=self.params_map.exchange,
                    instrument_id=self.params_map.instrument_id,
                    volume=1,
                    price=current_price + self.params_map.pay_up,
                    order_direction="buy"
                )
                self.state_map.long_position += 1
                self.output(f"买入一手多单,当前价格: {current_price}")
            elif self.state_map.long_position <= -self.params_map.max_buy_position:  # 达到最大买入持仓,不再买入
                self.output("已达到最大多单买入持仓,无法继续买入")

    def on_order_cancel(self, order: OrderData) -> None:
        """撤单推送回调"""
        super().on_order_cancel(order)
        self.order_id = None

    def on_trade(self, trade: TradeData, log: bool = False) -> None:
        """成交回调"""
        super().on_trade(trade, log)
        self.order_id = None

    def on_start(self):
        super().on_start()

        self.sub_market_data(
            exchange=self.params_map.exchange,
            instrument_id=self.params_map.instrument_id
        )

    def on_stop(self):
        super().on_stop()

        self.unsub_market_data(
            exchange=self.params_map.exchange,
            instrument_id=self.params_map.instrument_id
        )

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐