IJCAI Track 2 Forecasting Future Turn-Based Strokes in Badminton Rallies - Baseline Walkthrough

Follow my WeChat official account YueTan for discussion and exchange.
My competition solutions repo is also worth a look: https://github.com/hongyingyue/Competition-solutions

4th-place code: https://github.com/LongxingTan/Data-competitions/tree/master/ijcai-badminton


Background

The forecasting of future turn-based strokes in badminton rallies (Track 2) is the task of designing predictive models that forecast future strokes, including shot types and locations, from past strokes. For more details, see the organizers' repo and their previous work.

Input: landing_x, landing_y, shot type, and metadata of the past 4 strokes

Output: landing_x, landing_y, and shot type of the future strokes

For each singles rally, given the 4 observed strokes as type-area pairs and the two players, the goal is to predict the future strokes, including shot types and area coordinates, for the next n steps, where n varies with the length of the rally.
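To make the contract concrete, here is a toy sketch (shot-type names and coordinate values are invented for illustration; the real types come from the dataset's type column):

# a rally's observed prefix: 4 strokes as (shot_type, landing_x, landing_y, player)
past_strokes = [
    ('short service', 0.20, 0.10, 'A'),   # hypothetical values
    ('net shot',      0.70, 0.90, 'B'),
    ('lob',           0.30, 0.40, 'A'),
    ('smash',         0.80, 0.60, 'B'),
]

# to predict: shot type and landing coordinates for every remaining stroke;
# the number of future steps n depends on how long the rally actually ran
future_strokes = [('?', None, None, 'A'), ('?', None, None, 'B')]  # ... up to n steps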

Rules

Testing Data Release: June 13, 2023

Testing Submission Deadline: June 20, 2023

Winner Announcement: June 27, 2023

Paper Submission Deadline: July 11, 2023

Note that the test set was released on June 13, so the leaderboard scores before that date appear to be a public (stage A) leaderboard and don't matter much.

Baseline

Baseline link


import os

import torch
import torch.nn as nn

# get_argument, set_seed, prepare_dataset and draw_loss come from the baseline repo
if __name__ == "__main__":
    config = get_argument()
    config['data_folder'] = '../data/'
    config['model_folder'] = './model/'
    model_type = config['model_type']
    set_seed(config['seed_value'])

    # Clean data and Prepare dataset
    config, train_dataloader, val_dataloader, test_dataloader, train_matches, val_matches, test_matches = prepare_dataset(config)

    device = torch.device(f"cuda:{config['gpu_num']}" if torch.cuda.is_available() else "cpu")
    print("Model path: {}".format(config['output_folder_name']))
    if not os.path.exists(config['output_folder_name']):
        os.makedirs(config['output_folder_name'])

    # build the model; the embedding weights below are tied between encoder and decoder
    from ShuttleNet.ShuttleNet import ShotGenEncoder, ShotGenPredictor
    from ShuttleNet.ShuttleNet_runner import shotGen_trainer
    encoder = ShotGenEncoder(config)
    decoder = ShotGenPredictor(config)
    encoder.area_embedding.weight = decoder.shotgen_decoder.area_embedding.weight
    encoder.shot_embedding.weight = decoder.shotgen_decoder.shot_embedding.weight
    encoder.player_embedding.weight = decoder.shotgen_decoder.player_embedding.weight
    decoder.player_embedding.weight = decoder.shotgen_decoder.player_embedding.weight

    encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=config['lr'])
    decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=config['lr'])
    encoder.to(device), decoder.to(device)

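    # class 0 is reserved for padding, hence ignore_index=0; the 'sum' reductions
    # are divided by the number of real strokes in the trainer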
    criterion = {
        'entropy': nn.CrossEntropyLoss(ignore_index=0, reduction='sum'),
        'mae': nn.L1Loss(reduction='sum')
    }
    for key, value in criterion.items():
        criterion[key].to(device)

    record_train_loss = shotGen_trainer(data_loader=train_dataloader, encoder=encoder, decoder=decoder, criterion=criterion, encoder_optimizer=encoder_optimizer, decoder_optimizer=decoder_optimizer, config=config, device=device)

    draw_loss(record_train_loss, config)

Data

  • Both the input and the output are sequences, with several features per step
  • The maximum ball round is set to 70, so every rally is padded to 70 strokes
    • Rallies longer than 70 strokes are truncated to the first 70 by the baseline
    • Shorter rallies are padded with 0 at the end
  • The input/output layout is worth confirming:
    • Both input and output carry shot_type, x, y, player per step
    • The input is strokes 0 ~ (n-1)
    • The output is strokes 1 ~ n, i.e. the input shifted one step ahead (see the usage sketch after the dataset class below)
class BadmintonDataset(Dataset):
    def __init__(self, matches, config):
        super().__init__()
        self.max_ball_round = config['max_ball_round']   # max_ball_round=70
        group = matches[['rally_id', 'ball_round', 'type', 'landing_x', 'landing_y', 'player', 'set']].groupby('rally_id').apply(lambda r: (r['ball_round'].values, r['type'].values, r['landing_x'].values, r['landing_y'].values, r['player'].values, r['set'].values))

        self.sequences, self.rally_ids = {}, []
        for i, rally_id in enumerate(group.index):
            ball_round, shot_type, landing_x, landing_y, player, sets = group[rally_id]
            self.sequences[rally_id] = (ball_round, shot_type, landing_x, landing_y, player, sets)
            self.rally_ids.append(rally_id)

    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, index):
        rally_id = self.rally_ids[index]
        ball_round, shot_type, landing_x, landing_y, player, sets = self.sequences[rally_id]

        pad_input_shot = np.full(self.max_ball_round, fill_value=PAD, dtype=int)
        pad_input_x = np.full(self.max_ball_round, fill_value=PAD, dtype=float)
        pad_input_y = np.full(self.max_ball_round, fill_value=PAD, dtype=float)
        pad_input_player = np.full(self.max_ball_round, fill_value=PAD, dtype=int)
        pad_output_shot = np.full(self.max_ball_round, fill_value=PAD, dtype=int)
        pad_output_x = np.full(self.max_ball_round, fill_value=PAD, dtype=float)
        pad_output_y = np.full(self.max_ball_round, fill_value=PAD, dtype=float)
        pad_output_player = np.full(self.max_ball_round, fill_value=PAD, dtype=int)

        # pad or trim based on the max ball round
        if len(ball_round) > self.max_ball_round:
            rally_len = self.max_ball_round

            pad_input_shot[:] = shot_type[0:-1:1][:rally_len]                                   # 0, 1, ..., max_ball_round-1
            pad_input_x[:] = landing_x[0:-1:1][:rally_len]
            pad_input_y[:] = landing_y[0:-1:1][:rally_len]
            pad_input_player[:] = player[0:-1:1][:rally_len]
            pad_output_shot[:] = shot_type[1::1][:rally_len]                                    # 1, 2, ..., max_ball_round
            pad_output_x[:] = landing_x[1::1][:rally_len]
            pad_output_y[:] = landing_y[1::1][:rally_len]
            pad_output_player[:] = player[1::1][:rally_len]
        else:
            rally_len = len(ball_round) - 1                                                     # 0 ~ (n-2)
            
            pad_input_shot[:rally_len] = shot_type[0:-1:1]                                      # 0, 1, ..., n-1
            pad_input_x[:rally_len] = landing_x[0:-1:1]
            pad_input_y[:rally_len] = landing_y[0:-1:1]
            pad_input_player[:rally_len] = player[0:-1:1]
            pad_output_shot[:rally_len] = shot_type[1::1]                                       # 1, 2, ..., n
            pad_output_x[:rally_len] = landing_x[1::1]
            pad_output_y[:rally_len] = landing_y[1::1]
            pad_output_player[:rally_len] = player[1::1]

        return (pad_input_shot, pad_input_x, pad_input_y, pad_input_player,
                pad_output_shot, pad_output_x, pad_output_y, pad_output_player,
                rally_len, sets[0])
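A minimal usage sketch of the one-step shift (assuming the module-level constant PAD is 0, consistent with ignore_index=0 in the loss; the toy values are invented):

import pandas as pd

PAD = 0  # assumed padding value

# a toy 5-stroke rally
toy = pd.DataFrame({
    'rally_id':   [1, 1, 1, 1, 1],
    'ball_round': [1, 2, 3, 4, 5],
    'type':       [3, 1, 2, 1, 4],
    'landing_x':  [0.2, 0.7, 0.3, 0.8, 0.5],
    'landing_y':  [0.1, 0.9, 0.4, 0.6, 0.2],
    'player':     [1, 2, 1, 2, 1],
    'set':        [1, 1, 1, 1, 1],
})

dataset = BadmintonDataset(toy, config={'max_ball_round': 70})
(in_shot, in_x, in_y, in_player,
 out_shot, out_x, out_y, out_player, rally_len, first_set) = dataset[0]

print(rally_len)      # 4, i.e. 5 strokes - 1
print(in_shot[:5])    # [3 1 2 1 0] -> strokes 0..n-2, then PAD
print(out_shot[:5])   # [1 2 1 4 0] -> strokes 1..n-1, then PAD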


def prepare_dataset(config):
    train_matches = pd.read_csv(f"{config['data_folder']}train.csv")
    val_matches = pd.read_csv(f"{config['data_folder']}val_given.csv")
    test_matches = pd.read_csv(f"{config['data_folder']}test_given.csv")

    # encode shot type
    codes_type, uniques_type = pd.factorize(train_matches['type'])
    train_matches['type'] = codes_type + 1                                # reserve code 0 for padding
    val_matches['type'] = val_matches['type'].apply(lambda x: list(uniques_type).index(x)+1)
    test_matches['type'] = test_matches['type'].apply(lambda x: list(uniques_type).index(x)+1)
    config['uniques_type'] = uniques_type.to_list()
    config['shot_num'] = len(uniques_type) + 1                            # Add padding

    # encode player
    train_matches['player'] = train_matches['player'].apply(lambda x: x+1)
    val_matches['player'] = val_matches['player'].apply(lambda x: x+1)
    test_matches['player'] = test_matches['player'].apply(lambda x: x+1)
    config['player_num'] = 35 + 1                                         # Add padding

    train_dataset = BadmintonDataset(train_matches, config)
    train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)

    val_dataset = BadmintonDataset(val_matches, config)
    val_dataloader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False)

    test_dataset = BadmintonDataset(test_matches, config)
    test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False)

    return config, train_dataloader, val_dataloader, test_dataloader, train_matches, val_matches, test_matches
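One caveat in prepare_dataset: list(uniques_type).index(x) is an O(n) scan per row and raises a ValueError if val/test contains a shot type unseen in training. A hedged alternative (not part of the baseline) is a dict lookup with an explicit fallback:

# hypothetical variant: O(1) lookups; unseen types map to 0, the padding code
type_to_code = {t: i + 1 for i, t in enumerate(uniques_type)}
val_matches['type'] = val_matches['type'].map(lambda x: type_to_code.get(x, 0))
test_matches['type'] = test_matches['type'].map(lambda x: type_to_code.get(x, 0))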


Model

from ShuttleNet.ShuttleNet import ShotGenEncoder, ShotGenPredictor
from ShuttleNet.ShuttleNet_runner import shotGen_trainer
encoder = ShotGenEncoder(config)
decoder = ShotGenPredictor(config)
encoder.area_embedding.weight = decoder.shotgen_decoder.area_embedding.weight
encoder.shot_embedding.weight = decoder.shotgen_decoder.shot_embedding.weight
encoder.player_embedding.weight = decoder.shotgen_decoder.player_embedding.weight
decoder.player_embedding.weight = decoder.shotgen_decoder.player_embedding.weight

encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=config['lr'])
decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=config['lr'])
encoder.to(device), decoder.to(device)

The encoder

  • Call: encoder(input_shot, input_x, input_y, input_player)
  • shot type embedding: nn.Embedding
  • player embedding: nn.Embedding
  • the two players' strokes are additionally encoded as separate streams (the global A/B branches below)
  • input x, y: concatenated into a pair, then projected with nn.Linear(2, config['area_dim'])
  • positional encoding
  • attention layers
class ShotGenEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.area_embedding = nn.Linear(2, config['area_dim'])
        self.shot_embedding = ShotEmbedding(config['shot_num'], config['shot_dim'])
        self.player_embedding = PlayerEmbedding(config['player_num'], config['player_dim'])

        n_heads = 2
        d_k = config['encode_dim']
        d_v = config['encode_dim']
        d_model = config['encode_dim']
        d_inner = config['encode_dim'] * 2
        dropout = 0.1
        self.d_model = d_model

        self.position_embedding = PositionalEncoding(config['shot_dim'], config['encode_length'], n_position=config['max_ball_round'])
        self.dropout = nn.Dropout(p=dropout)

        self.global_layer = EncoderLayer(d_model, d_inner, n_heads, d_k, d_v, dropout=dropout)
        self.local_layer = EncoderLayer(d_model, d_inner, n_heads, d_k, d_v, dropout=dropout)

    def forward(self, input_shot, input_x, input_y, input_player, src_mask=None, return_attns=False):
        enc_slf_attn_list = []

        area = torch.cat((input_x.unsqueeze(-1), input_y.unsqueeze(-1)), dim=-1).float()
        
        embedded_area = F.relu(self.area_embedding(area))  # batch, seq, embed
        embedded_shot = self.shot_embedding(input_shot)  # batch, seq, embed
        embedded_player = self.player_embedding(input_player)  # batch, seq, embed

        h_a = embedded_area + embedded_player
        h_s = embedded_shot + embedded_player

        # split player
        h_a_A = h_a[:, ::2]
        h_a_B = h_a[:, 1::2]
        h_s_A = h_s[:, ::2]
        h_s_B = h_s[:, 1::2]

        # local
        encode_output_area = self.dropout(self.position_embedding(h_a, mode='encode'))
        encode_output_shot = self.dropout(self.position_embedding(h_s, mode='encode'))
        # global
        encode_output_area_A = self.dropout(self.position_embedding(h_a_A, mode='encode'))
        encode_output_area_B = self.dropout(self.position_embedding(h_a_B, mode='encode'))
        encode_output_shot_A = self.dropout(self.position_embedding(h_s_A, mode='encode'))
        encode_output_shot_B = self.dropout(self.position_embedding(h_s_B, mode='encode'))

        encode_global_A, enc_slf_attn_A = self.global_layer(encode_output_area_A, encode_output_shot_A, slf_attn_mask=src_mask)
        encode_global_B, enc_slf_attn_B = self.global_layer(encode_output_area_B, encode_output_shot_B, slf_attn_mask=src_mask)
        
        encode_local_output, enc_slf_attn = self.local_layer(encode_output_area, encode_output_shot, slf_attn_mask=src_mask)

        if return_attns:
            return encode_local_output, encode_global_A, encode_global_B, enc_slf_attn_list
        return encode_local_output, encode_global_A, encode_global_B
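The [:, ::2] / [:, 1::2] slicing relies on players strictly alternating in a singles rally: even time steps are one player's strokes, odd steps the other's. A tiny demo:

import torch

h = torch.arange(6).reshape(1, 6, 1)  # (batch=1, seq=6, dim=1): steps 0..5
h_A = h[:, ::2]                       # steps 0, 2, 4 -> strokes by player A
h_B = h[:, 1::2]                      # steps 1, 3, 5 -> strokes by player B
print(h_A.flatten().tolist())         # [0, 2, 4]
print(h_B.flatten().tolist())         # [1, 3, 5]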

The predictor

  • Call: decoder(input_shot, input_x, input_y, input_player, encode_local_output, encode_global_A, encode_global_B, target_player)
  • Judging from the sampling code in the inference section, the area head's config['area_num'] outputs are the parameters of a 2D Gaussian (mu_x, mu_y, log sigma_x, log sigma_y, pre-tanh correlation) rather than raw coordinates
class ShotGenDecoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.area_embedding = nn.Linear(2, config['area_dim'])
        self.shot_embedding = ShotEmbedding(config['shot_num'], config['shot_dim'])
        self.player_embedding = PlayerEmbedding(config['player_num'], config['player_dim'])

        n_heads = 2
        d_k = config['encode_dim']
        d_v = config['encode_dim']
        d_model = config['encode_dim']
        d_inner = config['encode_dim'] * 2
        dropout = 0.1
        self.d_model = d_model

        self.position_embedding = PositionalEncoding(config['shot_dim'], config['encode_length'], n_position=config['max_ball_round']+1)
        self.dropout = nn.Dropout(p=dropout)

        self.global_layer = DecoderLayer(d_model, d_inner, n_heads, d_k, d_v, dropout=dropout)
        self.local_layer = DecoderLayer(d_model, d_inner, n_heads, d_k, d_v, dropout=dropout)

        self.gated_fusion = GatedFusionLayer(d_model, d_model, config['encode_length'], config['max_ball_round']+1)

    def forward(self, input_shot, input_x, input_y, input_player, encode_local_output, encode_global_A, encode_global_B, trg_mask=None, return_attns=False):
        decoder_self_attention_list, decoder_encoder_self_attention_list = [], []

        area = torch.cat((input_x.unsqueeze(-1), input_y.unsqueeze(-1)), dim=-1).float()

        # split player only for masking
        mask_A = input_shot[:, ::2]
        mask_B = input_shot[:, 1::2]

        # triangular mask
        trg_local_mask = get_pad_mask(input_shot) & get_subsequent_mask(input_shot)
        trg_global_A_mask = get_pad_mask(mask_A) & get_subsequent_mask(mask_A)
        trg_global_B_mask = get_pad_mask(mask_B) & get_subsequent_mask(mask_B)
        
        embedded_area = F.relu(self.area_embedding(area))
        embedded_shot = self.shot_embedding(input_shot)
        embedded_player = self.player_embedding(input_player)

        h_a = embedded_area + embedded_player
        h_s = embedded_shot + embedded_player

        # split player
        h_a_A = h_a[:, ::2]
        h_a_B = h_a[:, 1::2]
        h_s_A = h_s[:, ::2]
        h_s_B = h_s[:, 1::2]

        # local
        decode_output_area = self.dropout(self.position_embedding(h_a, mode='decode'))
        decode_output_shot = self.dropout(self.position_embedding(h_s, mode='decode'))
        # global
        decode_output_area_A = self.dropout(self.position_embedding(h_a_A, mode='decode'))
        decode_output_area_B = self.dropout(self.position_embedding(h_a_B, mode='decode'))
        decode_output_shot_A = self.dropout(self.position_embedding(h_s_A, mode='decode'))
        decode_output_shot_B = self.dropout(self.position_embedding(h_s_B, mode='decode'))

        decode_global_A, dec_slf_attn_A, dec_enc_attn_A, disentangled_weight_A = self.global_layer(decode_output_area_A, decode_output_shot_A, encode_global_A, slf_attn_mask=trg_global_A_mask, return_attns=return_attns)
        if decode_output_area_B.shape[1] != 0:
            decode_global_B, dec_slf_attn_B, dec_enc_attn_B, disentangled_weight_B = self.global_layer(decode_output_area_B, decode_output_shot_B, encode_global_B, slf_attn_mask=trg_global_B_mask, return_attns=return_attns)

        decode_local_output, dec_slf_attn, dec_enc_attn, disentangled_weight_local = self.local_layer(decode_output_area, decode_output_shot, encode_local_output, slf_attn_mask=trg_local_mask, return_attns=return_attns)
        decoder_self_attention_list = dec_slf_attn if return_attns else []
        decoder_encoder_self_attention_list = dec_enc_attn if return_attns else []

        if decode_output_area_B.shape[1] != 0:
            decode_output_A = alternatemerge(decode_global_A, decode_global_A, decode_local_output.shape[1], 'A')
            decode_output_B = alternatemerge(decode_global_B, decode_global_B, decode_local_output.shape[1], 'B')
        else:
            decode_output_A = decode_global_A.clone()
            decode_output_B = torch.zeros(decode_local_output.shape, device=decode_local_output.device)
        decode_output = self.gated_fusion(decode_output_A, decode_output_B, decode_local_output)

        # (batch, seq_len, encode_dim)
        if return_attns:
            return decode_output, decoder_self_attention_list, decoder_encoder_self_attention_list, disentangled_weight_local
        return decode_output
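get_pad_mask and get_subsequent_mask are not shown in the post; a sketch of the standard transformer masks they most likely correspond to (assuming PAD = 0), in the style of common attention-is-all-you-need PyTorch implementations:

import torch

PAD = 0  # assumed padding index

def get_pad_mask(seq):
    # (batch, seq_len) -> (batch, 1, seq_len); True where the step is a real stroke
    return (seq != PAD).unsqueeze(-2)

def get_subsequent_mask(seq):
    # (batch, seq_len) -> (1, seq_len, seq_len) causal mask: True at or below the diagonal
    len_s = seq.size(1)
    ones = torch.ones((1, len_s, len_s), device=seq.device)
    return (1 - torch.triu(ones, diagonal=1)).bool()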


class ShotGenPredictor(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.shotgen_decoder = ShotGenDecoder(config)
        self.area_decoder = nn.Sequential(
            nn.Linear(config['encode_dim'], config['area_num'], bias=False)
        )
        self.shot_decoder = nn.Sequential(
            nn.Linear(config['encode_dim'], config['shot_num'], bias=False)
        )
        self.player_embedding = PlayerEmbedding(config['player_num'], config['player_dim'])

    def forward(self, input_shot, input_x, input_y, input_player, encode_local_output, encode_global_A, encode_global_B, target_player, return_attns=False):
        embedded_target_player = self.player_embedding(target_player)
        if return_attns:
            decode_output, decoder_self_attention_list, decoder_encoder_self_attention_list, disentangled_weight_local = self.shotgen_decoder(input_shot, input_x, input_y, input_player, encode_local_output, encode_global_A, encode_global_B, return_attns=return_attns)
        else:
            decode_output = self.shotgen_decoder(input_shot, input_x, input_y, input_player, encode_local_output, encode_global_A, encode_global_B, return_attns=return_attns)  # keyword arg: passed positionally it would land in trg_mask
        
        decode_output = (decode_output + embedded_target_player)

        area_logits = self.area_decoder(decode_output)
        shot_logits = self.shot_decoder(decode_output)

        if return_attns:
            return area_logits, shot_logits, decoder_self_attention_list, decoder_encoder_self_attention_list, disentangled_weight_local
        else:
            return area_logits, shot_logits

Training

  • Note that the encoder and the decoder each get their own optimizer. In NLP one often sets different learning rates for different layers, but still inside a single optimizer via parameter groups; what is the benefit of two GAN-style optimizers here? One side effect worth noticing: the embedding weights tied between encoder and decoder appear in both parameter lists, so those tensors are stepped twice per batch (see the single-optimizer sketch below). Also note that the area loss actually used is Gaussian2D_loss, not the 'mae' criterion defined in the main script.
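For comparison, a single-optimizer sketch (not the baseline). The dedup step matters because of the weight tying above, which would otherwise put the shared embedding tensors into the parameter list twice:

import itertools
import torch

# hypothetical single-optimizer variant: one Adam over the union of parameters;
# dict.fromkeys de-duplicates the tied embedding weights while keeping order
params = list(dict.fromkeys(itertools.chain(encoder.parameters(), decoder.parameters())))
optimizer = torch.optim.Adam(params, lr=config['lr'])
# per-module learning rates would still be possible via parameter groups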
def shotGen_trainer(data_loader, encoder, decoder, criterion, encoder_optimizer, decoder_optimizer, config, device="cpu"):
    encode_length = config['encode_length'] - 1         # the encoder consumes the first encode_length - 1 observed strokes (3 here)
    record_loss = {
        'total': [],
        'shot': [],
        'area': []
    }

    for epoch in tqdm(range(config['epochs']), desc='Epoch: '):
        encoder.train(), decoder.train()
        total_loss, total_shot_loss, total_area_loss = 0, 0, 0
        total_instance = 0

        for loader_idx, item in enumerate(data_loader):
            batch_input_shot, batch_input_x, batch_input_y, batch_input_player = item[0].to(device), item[1].to(device), item[2].to(device), item[3].to(device)
            batch_target_shot, batch_target_x, batch_target_y, batch_target_player = item[4].to(device), item[5].to(device), item[6].to(device), item[7].to(device)
            seq_len, seq_sets = item[8].to(device), item[9].to(device)

            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            input_shot = batch_input_shot[:, :encode_length]
            input_x = batch_input_x[:, :encode_length]
            input_y = batch_input_y[:, :encode_length]
            input_player = batch_input_player[:, :encode_length]
            encode_local_output, encode_global_A, encode_global_B = encoder(input_shot, input_x, input_y, input_player)

            input_shot = batch_input_shot[:, encode_length:]
            input_x = batch_input_x[:, encode_length:]
            input_y = batch_input_y[:, encode_length:]
            input_player = batch_input_player[:, encode_length:]
            target_shot = batch_target_shot[:, encode_length:]
            target_x = batch_target_x[:, encode_length:]
            target_y = batch_target_y[:, encode_length:]
            target_player = batch_target_player[:, encode_length:]
            output_xy, output_shot_logits = decoder(input_shot, input_x, input_y, input_player, encode_local_output, encode_global_A, encode_global_B, target_player)
            
            pad_mask = (input_shot!=PAD)
            output_shot_logits = output_shot_logits[pad_mask]
            target_shot = target_shot[pad_mask]
            output_xy = output_xy[pad_mask]
            target_x = target_x[pad_mask]
            target_y = target_y[pad_mask]

            _, output_shot = torch.topk(output_shot_logits, 1)
            gold_xy = torch.cat((target_x.unsqueeze(-1), target_y.unsqueeze(-1)), dim=-1).to(device, dtype=torch.float)

            total_instance += len(target_shot)

            loss_shot = criterion['entropy'](output_shot_logits, target_shot)
            loss_area = Gaussian2D_loss(output_xy, gold_xy)

            loss = loss_shot + loss_area
            loss.backward()

            encoder_optimizer.step()
            decoder_optimizer.step()

            total_loss += loss.item()
            total_shot_loss += loss_shot.item()
            total_area_loss += loss_area.item()

        total_loss = round(total_loss / total_instance, 4)
        total_shot_loss = round(total_shot_loss / total_instance, 4)
        total_area_loss = round(total_area_loss / total_instance, 4)

        record_loss['total'].append(total_loss)
        record_loss['shot'].append(total_shot_loss)
        record_loss['area'].append(total_area_loss)

    config['total_loss'] = total_loss
    config['total_shot_loss'] = total_shot_loss
    config['total_area_loss'] = total_area_loss
    save(encoder, decoder, config)

    return record_loss
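Gaussian2D_loss is not shown in the post. A minimal sketch consistent with the 5-channel parametrization used in the sampling code below, namely (mu_x, mu_y, log sigma_x, log sigma_y, pre-tanh correlation), is the negative log-likelihood of a bivariate Gaussian:

import math
import torch

def gaussian_2d_loss_sketch(pred, target, eps=1e-8):
    """pred: (N, 5) Gaussian parameters; target: (N, 2) true landing (x, y)."""
    mux, muy = pred[:, 0], pred[:, 1]
    sx, sy = torch.exp(pred[:, 2]), torch.exp(pred[:, 3])  # standard deviations
    corr = torch.tanh(pred[:, 4])                          # correlation in (-1, 1)

    nx = (target[:, 0] - mux) / sx
    ny = (target[:, 1] - muy) / sy
    one_minus_r2 = (1 - corr ** 2).clamp(min=eps)

    z = nx ** 2 + ny ** 2 - 2 * corr * nx * ny
    nll = (z / (2 * one_minus_r2)
           + torch.log(sx) + torch.log(sy)
           + 0.5 * torch.log(one_minus_r2)
           + math.log(2 * math.pi))
    return nll.sum()  # 'sum' to match the CE reduction; normalized by stroke count above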

Inference

  • Predictions are generated autoregressively, one stroke at a time, like text generation: each sampled output is fed back as the next input (a usage sketch follows the function)
def shotgen_generator(given_seq, encoder, decoder, config, samples, device):
    encode_length = config['encode_length'] - 1
    encoder.eval(), decoder.eval()
    generated_shot_logits, generated_area_coordinates = [], []

    with torch.no_grad():
        # encoding stage
        input_shot = given_seq['given_shot'][:encode_length].unsqueeze(0)
        input_x = given_seq['given_x'][:encode_length].unsqueeze(0)
        input_y = given_seq['given_y'][:encode_length].unsqueeze(0)
        input_player = given_seq['given_player'][:encode_length].unsqueeze(0)

        encode_local_output, encode_global_A, encode_global_B = encoder(input_shot, input_x, input_y, input_player)

        for sample_id in range(samples):
            current_generated_shot, current_generated_area = [], []
            total_instance = len(given_seq['given_shot']) - len(given_seq['given_shot'][:encode_length])
            for seq_idx in range(encode_length, given_seq['rally_length']-1):
                if seq_idx == encode_length:
                    input_shot = given_seq['given_shot'][seq_idx].unsqueeze(0).unsqueeze(0)
                    input_x = given_seq['given_x'][seq_idx].unsqueeze(0).unsqueeze(0)
                    input_y = given_seq['given_y'][seq_idx].unsqueeze(0).unsqueeze(0)
                    input_player = given_seq['given_player'][seq_idx].unsqueeze(0).unsqueeze(0)
                else:
                    # use its own predictions as the next input
                    input_shot = torch.cat((input_shot, prev_shot), dim=-1)
                    input_x = torch.cat((input_x, prev_x), dim=-1)
                    input_y = torch.cat((input_y, prev_y), dim=-1)
                    input_player = torch.cat((input_player, prev_player), dim=-1)
                target_player = given_seq['target_player'][seq_idx-encode_length].unsqueeze(0).unsqueeze(0)

                output_xy, output_shot_logits = decoder(input_shot, input_x, input_y, input_player, encode_local_output, encode_global_A, encode_global_B, target_player)

                # sample area coordinates
                sx = torch.exp(output_xy[:, -1, 2]) #sx
                sy = torch.exp(output_xy[:, -1, 3]) #sy
                corr = torch.tanh(output_xy[:, -1, 4]) #corr
                
                cov = torch.zeros(2, 2, device=output_xy.device)  # device= works on CPU too, unlike the original .cuda(...)
                cov[0, 0]= sx * sx
                cov[0, 1]= corr * sx * sy
                cov[1, 0]= corr * sx * sy
                cov[1, 1]= sy * sy
                mean = output_xy[:, -1, 0:2]
                
                mvnormal = torchdist.MultivariateNormal(mean, cov)
                output_xy = mvnormal.sample().unsqueeze(0)

                # sampling
                shot_prob = F.softmax(output_shot_logits, dim=-1)
                output_shot = shot_prob[0].multinomial(num_samples=1).unsqueeze(0)

                while output_shot[0, -1, 0] == 0:
                    output_shot = shot_prob[0].multinomial(num_samples=1).unsqueeze(0)

                prev_shot = output_shot[:, -1, :]
                prev_x = output_xy[:, -1, 0].unsqueeze(1)
                prev_y = output_xy[:, -1, 1].unsqueeze(1)
                prev_player = target_player.clone()

                # transform to original format
                ori_shot = config['uniques_type'][prev_shot.item()-1]
                ori_x = prev_x.item()
                ori_y = prev_y.item()

                current_generated_shot.append(shot_prob[0][-1][1:].cpu().tolist())      # 0 is pad
                current_generated_area.append((ori_x, ori_y))

            generated_shot_logits.append(current_generated_shot), generated_area_coordinates.append(current_generated_area)

    return generated_shot_logits, generated_area_coordinates
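A hedged sketch of invoking the generator for one rally; the dictionary keys follow the accesses inside shotgen_generator, the tensor values are invented, and samples=6 is just an example count:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

given_seq = {
    'given_shot':    torch.tensor([3, 1, 2, 1], device=device),         # encoded shot types
    'given_x':       torch.tensor([0.2, 0.7, 0.3, 0.8], device=device),
    'given_y':       torch.tensor([0.1, 0.9, 0.4, 0.6], device=device),
    'given_player':  torch.tensor([1, 2, 1, 2], device=device),
    'target_player': torch.tensor([1, 2, 1, 2, 1, 2], device=device),   # hitters of future strokes
    'rally_length':  10,
}

shot_probs, area_coords = shotgen_generator(
    given_seq, encoder, decoder, config, samples=6, device=device)
# shot_probs:  per sample, a list of per-step shot-type probability vectors
# area_coords: per sample, a list of sampled (x, y) landing coordinates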

my EDA
