Deep Learning/Coding

nn.ModuleList vs nn.Sequential

반응형

파이토치 코드를 보다보면 자주 등장하는 두 가지 클래스다.

비슷하게 쓰이는것 같으면서도 그 차이점을 구별해라 하면 말하기 어려운데,

구글링을 해 보니 친절한 답변이 있어서 가져왔다. (링크)

더보기

In nn.Sequential, the nn.Module's stored inside are connected in a cascaded way. For instance, in the example that I gave, I define a neural network that receives as input an image with 3 channels and outputs 10 neurons. That network is composed by the following blocks, in the following order: Conv2D -> ReLU -> Linear layer. Moreover, an object of type nn.Sequential has a forward() method, so if I have an input image x I can directly call y = simple_cnn(x) to obtain the scores for x. When you define an nn.Sequential you must be careful to make sure that the output size of a block matches the input size of the following block. Basically, it behaves just like a nn.Module

On the other hand, nn.ModuleList does not have a forward() method, because it does not define any neural network, that is, there is no connection between each of the nn.Module's that it stores. You may use it to store nn.Module's, just like you use Python lists to store other types of objects (integers, strings, etc). The advantage of using nn.ModuleList's instead of using conventional Python lists to store nn.Module's is that Pytorch is “aware” of the existence of the nn.Module's inside an nn.ModuleList, which is not the case for Python lists. If you want to understand exactly what I mean, just try to redefine my class LinearNet using a Python list instead of a nn.ModuleList and train it. When defining the optimizer() for that net, you’ll get an error saying that your model has no parameters, because PyTorch does not see the parameters of the layers stored in a Python list. If you use a nn.ModuleList instead, you’ll get no error.

요약하자면,

nn.Sequential은 안에 들어가는 모듈들을 연결해주고, 하나의 뉴럴넷을 정의한다.

모듈들이 연결된다는 것은, 나열된 모듈들의 output shape와 input shape가 일치해야 한다는 것이다.

자동적으로 연결되므로 그냥 nn.Seqential로 정의된 모듈에 input x를 넣어주면 주루룩 통과돼서 output이 나온다.

반면에, nn.ModuleList는 말 그대로 개별적으로 모듈들이 담겨있는 리스트이다.

이들 간의 연결관계들은 정의되지 않았으며,

따라서 forward 함수에서 ModuleList 내의 모듈들을 이용하여 적절한 연결관계를 정의하는 과정이 필수적이다.

ModuleList가 단순히 모듈 클래스를 python list에 담는 것과 다른 점은,

파이토치 프레임워크가 리스트 안에 담긴것들이 모듈들이라는 것을 인식하고

optimizer를 정의할 때 파라미터들을 인식한다고 한다. (답변 참조...사실 정확히 무슨 뜻인지는 모르겠다)

다른 참고자료 : ttps://michigusa-nlp.tistory.com/26

반응형