@@ -32,7 +32,6 @@
 
 import torch
 import torch.nn.functional as F
-from examples.infinitebench import args
 import triton
 import triton.language as tl
 from torch import nn
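(The import removed above is the name the final hunk's `**args` bug was referencing; it reads like an accidental IDE auto-import, and once the call site is fixed to `**kwargs` below, nothing uses it.)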
@@ -269,13 +268,9 @@ def __init__(
         topk_group: int,
         routed_scaling_factor: float,
         dtype: Optional[torch.dtype] = None,
-<<<<<<< HEAD
         fuse_routing_kernel: bool = True,
         apply_routing: bool = False,
-=======
-        is_thop: bool = True,
         moe_backend: str = 'CUTLASS',
->>>>>>> 14626789cf (Add TRT-LLM Gen MOE to Deepseek)
     ):
         super().__init__()
         self.weight = nn.Parameter(torch.empty((num_experts, hidden_size),
@@ -358,12 +353,9 @@ def __init__(self,
             topk_group=config.topk_group,
             routed_scaling_factor=config.routed_scaling_factor,
             dtype=dtype,
-<<<<<<< HEAD
             fuse_routing_kernel=True,
-            apply_routing=False)
-=======
+            apply_routing=False,
             moe_backend=model_config.moe_backend)
->>>>>>> 14626789cf (Add TRT-LLM Gen MOE to Deepseek)
         self.experts = FusedMoE(
             num_experts=num_experts,
             routing_method=self.gate.routing_method,
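For readability, a minimal sketch of the gate after the two routing hunks above are resolved: both HEAD parameters (`fuse_routing_kernel`, `apply_routing`) and the incoming branch's `moe_backend` survive, while `is_thop` is dropped. Only what the diff shows is reliable; the parameters ahead of `topk_group`, the attribute assignments, and the rest of the class body are assumptions here.

from typing import Optional

import torch
from torch import nn


class DeepseekV3Gate(nn.Module):
    """Sketch of the merged signature; the body beyond self.weight is assumed."""

    def __init__(
        self,
        hidden_size: int,
        num_experts: int,
        topk_group: int,
        routed_scaling_factor: float,
        dtype: Optional[torch.dtype] = None,
        fuse_routing_kernel: bool = True,  # kept from HEAD
        apply_routing: bool = False,       # kept from HEAD
        moe_backend: str = 'CUTLASS',      # kept from the incoming branch
    ):
        super().__init__()
        # One row of routing weights per expert, as in the hunk above.
        self.weight = nn.Parameter(
            torch.empty((num_experts, hidden_size), dtype=dtype))
        self.moe_backend = moe_backend

Note the small mechanical detail at the call site (second hunk): HEAD closed the call with `apply_routing=False)`, so the resolution rewrites that line with a trailing comma, letting the kept `moe_backend=model_config.moe_backend)` close the call instead; that is exactly what the single `+` line does.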
@@ -602,7 +594,7 @@ def forward(
             attn_metadata=attn_metadata,
             all_reduce_params=AllReduceParams(
                 enable_allreduce=not self.disable_attn_allreduce),
-            **args,
+            **kwargs,
         )
 
         # deepseek allreduce kernel is better when m < 512, two shot(128~512) has acc bug, waive
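The `**args` → `**kwargs` fix is the counterpart of the import removal in the first hunk: whatever `from examples.infinitebench import args` bound (a module or an argparse namespace), it is not the mapping `**` requires, so unpacking it would raise a TypeError at call time rather than forwarding the caller's keyword arguments. A self-contained sketch of the intended forwarding pattern; all names here are illustrative, not the file's:

import torch
from torch import nn


class Inner(nn.Module):

    def forward(self, x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        return x * scale


class Outer(nn.Module):

    def __init__(self):
        super().__init__()
        self.inner = Inner()

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        # Forward the caller's keyword arguments untouched; writing **args
        # with the removed module-level import in scope would instead raise
        # "TypeError: ... argument after ** must be a mapping".
        return self.inner(x, **kwargs)


print(Outer()(torch.ones(2), scale=2.0))  # tensor([2., 2.])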