2
\$\begingroup\$

Problem Statement

You are given a weighted undirected graph G with N vertices, numbered 1 to N. Initially, G has no edges.

You will perform M operations to add edges to G. The i-th operation (1≤i≤M) is as follows:

You are given a subset of vertices S_i = {A_{i,1}, A_{i,2}, …, A_{i,K_i}} consisting of K_i vertices. For every pair u, v such that u, v ∈ S_i and u < v, add an edge between vertices u and v with weight C_i.

After performing all M operations, determine whether G is connected. If it is, find the total weight of the edges in a minimum spanning tree of G.
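To make the operation concrete, here is a small illustration of the edges a single operation adds. The helper name `edges_for_operation` is hypothetical and used only for this sketch; it is not part of the problem statement.

```python
from itertools import combinations
from typing import List, Tuple


def edges_for_operation(vertices: List[int], weight: int) -> List[Tuple[int, int, int]]:
    """List every (u, v, weight) edge with u < v that one operation adds."""
    # Sorting first guarantees u < v in every pair that combinations() yields.
    return [(u, v, weight) for u, v in combinations(sorted(vertices), 2)]


edges = edges_for_operation([2, 6, 7, 9], 602436426)
print(len(edges))  # a subset of K = 4 vertices adds K * (K - 1) / 2 = 6 edges
```

So one operation on K vertices adds K(K-1)/2 edges, which grows quadratically with the subset size.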

Code:

The code runs correctly (Ideone).

from collections import defaultdict
from heapq import heappush, heappop


def solution(A):
    def prim(G):
        vis = set()
        start = next(iter(G))
        vis.add(start)
        Q, mst = [], []
        for w, nei in G[start]:
            heappush(Q, (w, start, nei))
        while len(vis) < len(G):
            w, src, dest = heappop(Q)
            if dest in vis:
                continue
            vis.add(dest)
            mst.append((src, dest, w))
            for w, nei in G[dest]:
                heappush(Q, (w, dest, nei))
        return mst

    N, M = A[0]
    graph = defaultdict(list)
    for i in range(1, len(A)):
        if i % 2 == 1:
            k, c = A[i]
        else:
            edges = A[i]
            for ii in range(len(edges)):
                for jj in range(ii + 1, len(edges)):
                    if edges[ii] < edges[jj]:
                        graph[edges[jj]].append((c, edges[ii]))
                        graph[edges[ii]].append((c, edges[jj]))

    mst = prim(graph)
    res = 0
    s = set()
    for x, y, w in mst:
        res += w
        s.update({x, y})

    if sorted(s) != list(range(1, N + 1)):
        print(-1)
    else:
        print(res)


A = [[10, 5], [6, 158260522], [1, 3, 6, 8, 9, 10], [10, 877914575], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
     [4, 602436426], [2, 6, 7, 9], [6, 24979445], [2, 3, 4, 5, 8, 10], [4, 861648772], [2, 4, 8, 9]]
solution(A)

Question

  • How can this be optimized to avoid a TLE (time limit exceeded)? The running time must be under 2 seconds.


\$\endgroup\$
2
  • 3
    \$\begingroup\$ Welcome to Code Review! Can you confirm that the code is complete and that it produces the correct results? If so, I recommend that you edit to add a summary of the testing (ideally as reproducible unit-test code). If it's not working, it isn't ready for review (see help center) and the question may be deleted. \$\endgroup\$ Commented May 4 at 15:54
  • 1
    \$\begingroup\$ @TobySpeight Yes, it does. I added a link to Ideone. Thanks! \$\endgroup\$ Commented May 4 at 16:01

1 Answer

2
\$\begingroup\$

With the example dataset, your code makes only 645 function calls, which is very fast. Still, I cleaned it up and added some comments. I don't think any special tricks, such as using a JIT compiler, are needed.

A few things could be better: the variable names, the function defined inside another function, and the unused variables. I suggest you read PEP 8, the Style Guide for Python Code.

Here is the new code:

from collections import defaultdict
from heapq import heappush, heappop
from typing import List, Tuple


def prim_minimum_spanning_tree(graph: dict) -> List[Tuple[int, int, int]]:
    """
    Computes the Minimum Spanning Tree (MST) using Prim's algorithm.
    :param graph: Maps each node to a list of (weight, neighbor) tuples.
    :return: A list of tuples (src, dest, weight) representing the MST edges.
    """
    visited = set()
    start_node = next(iter(graph))
    visited.add(start_node)
    min_heap, mst = [], []

    for weight, neighbor in graph[start_node]:
        heappush(min_heap, (weight, start_node, neighbor))

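    # Also stop when the heap is empty: the graph is then disconnected, and
    # popping from an empty heap would raise IndexError.
    while min_heap and len(visited) < len(graph):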
        weight, src, dest = heappop(min_heap)
        if dest in visited:
            continue
        visited.add(dest)
        mst.append((src, dest, weight))
        for weight, neighbor in graph[dest]:
            heappush(min_heap, (weight, dest, neighbor))

    return mst


def solution(A: List[List[int]]) -> int:
    """
    Computes the sum of weights of the MST for the given graph.
    :param A: The input rows: [N, M] first, then for each operation a
        [K, C] row followed by the row of its K vertices.
    :return: The total MST weight, or -1 if the graph is not connected.
    """
    graph = defaultdict(list)
    num_nodes, _ = A[0]

    for i in range(1, len(A)):
        if i % 2 == 1:
            c = A[i][1]
        else:
            edges = A[i]
            for ii in range(len(edges)):
                for jj in range(ii + 1, len(edges)):
                    if edges[ii] < edges[jj]:
                        graph[edges[jj]].append((c, edges[ii]))
                        graph[edges[ii]].append((c, edges[jj]))

    mst = prim_minimum_spanning_tree(graph)
    total_weight = sum(weight for _, _, weight in mst)

    if sorted(set(node for edge in mst for node in edge[:2])) != list(
        range(1, num_nodes + 1)
    ):
        return -1
    else:
        return total_weight


def main():
    A = [
        [10, 5],
        [6, 158260522],
        [1, 3, 6, 8, 9, 10],
        [10, 877914575],
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        [4, 602436426],
        [2, 6, 7, 9],
        [6, 24979445],
        [2, 3, 4, 5, 8, 10],
        [4, 861648772],
        [2, 4, 8, 9],
    ]
    print(solution(A))


if __name__ == "__main__":
    main()
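
If the input gets large, building every pair of edges is likely to dominate the running time, because one operation on K vertices adds K(K-1)/2 edges. One possible alternative, which is a different technique from the Prim-based code above, is Kruskal's algorithm with a union-find structure: process the operations from cheapest to most expensive and, within each subset, only try to join every vertex to the first one. All edges in a subset share the same weight, so those K-1 candidate edges should be enough. This is a minimal sketch under that assumption; the function name `mst_weight_union_find` is my own, and it reuses the input layout of `solution()`.

```python
from typing import List


def mst_weight_union_find(A: List[List[int]]) -> int:
    """Return the MST weight, or -1 if the graph is not connected.

    Same input layout as solution(): A[0] is [N, M]; after that, each
    operation is a [K, C] row followed by a row listing its K vertices.
    """
    num_nodes, _ = A[0]
    parent = list(range(num_nodes + 1))  # vertices are numbered 1..N

    def find(x: int) -> int:
        # Iterative find with path halving to keep the trees shallow.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    # Pair each weight with its vertex subset, cheapest weight first.
    operations = sorted(
        ((A[i][1], A[i + 1]) for i in range(1, len(A) - 1, 2)),
        key=lambda op: op[0],
    )

    total_weight = 0
    components = num_nodes
    for weight, vertices in operations:
        root = find(vertices[0])
        for vertex in vertices[1:]:
            other = find(vertex)
            if other != root:
                # Attach the other tree under root, so root stays valid.
                parent[other] = root
                total_weight += weight
                components -= 1

    return total_weight if components == 1 else -1
```

Calling `mst_weight_union_find(A)` with the example list from `main()` returns 1202115217. The work per operation is proportional to K rather than K squared, apart from the initial sort of the M operations.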
\$\endgroup\$
