File size: 2,083 Bytes
66f2fb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import numpy as np
from scipy.sparse import csr_matrix
from typing import Dict, List


def convert_sparse_to_csr(sparse_dict: Dict[str, List]) -> csr_matrix:
    """
    Build a single-row scipy CSR matrix from one sparse embedding.

    Input (API format):  {"indices": [10, 25, 42], "values": [0.85, 0.62, 0.91]}
    Output (Milvus format): scipy.sparse.csr_matrix of shape (1, max_dimension)

    Args:
        sparse_dict: Mapping holding the column positions under 'indices'
            and the matching weights under 'values'.

    Returns:
        A CSR matrix with one row. Its width is the largest index plus one,
        or 1 when 'indices' is empty.
    """
    columns = sparse_dict["indices"]
    weights = sparse_dict["values"]

    # The width must be just large enough to hold the highest column index.
    if columns:
        width = max(columns) + 1
    else:
        width = 1

    # This is a single vector, so every entry is placed in row 0.
    rows = [0] * len(columns)

    return csr_matrix((weights, (rows, columns)), shape=(1, width))


def batch_convert_sparse_to_csr(sparse_list: List[Dict[str, List]]) -> csr_matrix:
    """
    Convert a batch of sparse embeddings to a single CSR matrix.

    Each element of the batch becomes one row of the result, in order.

    Args:
        sparse_list: List of sparse dicts, each with 'indices' (column
            positions) and 'values' (the matching weights).

    Returns:
        scipy CSR matrix with shape (batch_size, max_dim), where max_dim is
        the largest index seen in the batch plus one. An empty batch yields
        a (0, 0) matrix. A batch whose rows are all empty falls back to a
        width of 30000.

    Raises:
        ValueError: If any row has a different number of indices and values.
    """
    if not sparse_list:
        return csr_matrix((0, 0))

    # Flattened COO-style triplets, accumulated across all rows.
    row_indices = []
    col_indices = []
    values = []
    max_dim = 0

    for row_idx, sparse_dict in enumerate(sparse_list):
        indices = sparse_dict["indices"]
        vals = sparse_dict["values"]

        # Validate per row: the triplets are flattened into shared lists, so
        # a short row and a long row could otherwise cancel out and leave
        # values silently attached to the wrong row/column.
        if len(indices) != len(vals):
            raise ValueError(
                f"sparse_list[{row_idx}] has {len(indices)} indices "
                f"but {len(vals)} values"
            )

        if indices:
            max_dim = max(max_dim, max(indices) + 1)

        row_indices.extend([row_idx] * len(indices))
        col_indices.extend(indices)
        values.extend(vals)

    if max_dim == 0:
        max_dim = 30000  # Default vocab size for SPLADE

    return csr_matrix(
        (values, (row_indices, col_indices)),
        shape=(len(sparse_list), max_dim),
    )